00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 template<u32 N>
00023 template<typename T1, typename T2>
00024 inline
00025 void
00026 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00027 {
00028 arma_extra_debug_sigprint();
00029
00030 typedef typename T1::elem_type eT;
00031
00032 const partial_unwrap_check<T1> tmp1(X.A, out);
00033 const partial_unwrap_check<T2> tmp2(X.B, out);
00034
00035 const Mat<eT>& A = tmp1.M;
00036 const Mat<eT>& B = tmp2.M;
00037
00038 const bool do_trans_A = tmp1.do_trans;
00039 const bool do_trans_B = tmp2.do_trans;
00040
00041 const bool use_alpha = tmp1.do_times | tmp2.do_times;
00042 const eT alpha = use_alpha ? (tmp1.val * tmp2.val) : eT(0);
00043
00044 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
00045 }
00046
00047
00048
00049 template<typename T1, typename T2, typename T3>
00050 inline
00051 void
00052 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
00053 {
00054 arma_extra_debug_sigprint();
00055
00056 typedef typename T1::elem_type eT;
00057
00058
00059
00060
00061 const partial_unwrap_check<T1> tmp1(X.A.A, out);
00062 const partial_unwrap_check<T2> tmp2(X.A.B, out);
00063 const partial_unwrap_check<T3> tmp3(X.B, out);
00064
00065 const Mat<eT>& A = tmp1.M;
00066 const Mat<eT>& B = tmp2.M;
00067 const Mat<eT>& C = tmp3.M;
00068
00069 const bool do_trans_A = tmp1.do_trans;
00070 const bool do_trans_B = tmp2.do_trans;
00071 const bool do_trans_C = tmp3.do_trans;
00072
00073 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times;
00074 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val) : eT(0);
00075
00076 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00077 }
00078
00079
00080
00081 template<typename T1, typename T2, typename T3, typename T4>
00082 inline
00083 void
00084 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
00085 {
00086 arma_extra_debug_sigprint();
00087
00088 typedef typename T1::elem_type eT;
00089
00090
00091
00092
00093 const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
00094 const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
00095 const partial_unwrap_check<T3> tmp3(X.A.B, out);
00096 const partial_unwrap_check<T4> tmp4(X.B, out);
00097
00098 const Mat<eT>& A = tmp1.M;
00099 const Mat<eT>& B = tmp2.M;
00100 const Mat<eT>& C = tmp3.M;
00101 const Mat<eT>& D = tmp4.M;
00102
00103 const bool do_trans_A = tmp1.do_trans;
00104 const bool do_trans_B = tmp2.do_trans;
00105 const bool do_trans_C = tmp3.do_trans;
00106 const bool do_trans_D = tmp4.do_trans;
00107
00108 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times | tmp4.do_times;
00109 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val * tmp4.val) : eT(0);
00110
00111 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00112 }
00113
00114
00115
00116 template<typename T1, typename T2>
00117 inline
00118 void
00119 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00120 {
00121 arma_extra_debug_sigprint();
00122
00123 typedef typename T1::elem_type eT;
00124
00125 const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00126
00127 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00128
00129 glue_times_redirect<N_mat>::apply(out, X);
00130 }
00131
00132
00133
00134 template<typename T1>
00135 inline
00136 void
00137 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
00138 {
00139 arma_extra_debug_sigprint();
00140
00141 typedef typename T1::elem_type eT;
00142
00143 const unwrap_check<T1> tmp(X, out);
00144 const Mat<eT>& B = tmp.M;
00145
00146 arma_debug_assert_mul_size(out, B, "matrix multiply");
00147
00148 if(out.n_cols == B.n_cols)
00149 {
00150 podarray<eT> tmp(out.n_cols);
00151 eT* tmp_rowdata = tmp.memptr();
00152
00153 for(u32 out_row=0; out_row < out.n_rows; ++out_row)
00154 {
00155 for(u32 out_col=0; out_col < out.n_cols; ++out_col)
00156 {
00157 tmp_rowdata[out_col] = out.at(out_row,out_col);
00158 }
00159
00160 for(u32 B_col=0; B_col < B.n_cols; ++B_col)
00161 {
00162 const eT* B_coldata = B.colptr(B_col);
00163
00164 eT val = eT(0);
00165 for(u32 i=0; i < B.n_rows; ++i)
00166 {
00167 val += tmp_rowdata[i] * B_coldata[i];
00168 }
00169
00170 out.at(out_row,B_col) = val;
00171 }
00172 }
00173
00174 }
00175 else
00176 {
00177 const Mat<eT> tmp(out);
00178 glue_times::apply(out, tmp, B, eT(1), false, false, false);
00179 }
00180
00181 }
00182
00183
00184
00185 template<typename T1, typename T2>
00186 arma_hot
00187 inline
00188 void
00189 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const s32 sign)
00190 {
00191 arma_extra_debug_sigprint();
00192
00193 typedef typename T1::elem_type eT;
00194
00195 const partial_unwrap_check<T1> tmp1(X.A, out);
00196 const partial_unwrap_check<T2> tmp2(X.B, out);
00197
00198 const Mat<eT>& A = tmp1.M;
00199 const Mat<eT>& B = tmp2.M;
00200 const eT alpha = tmp1.val * tmp2.val * ( (sign > s32(0)) ? eT(1) : eT(-1) );
00201
00202 const bool do_trans_A = tmp1.do_trans;
00203 const bool do_trans_B = tmp2.do_trans;
00204 const bool use_alpha = tmp1.do_times | tmp2.do_times | (sign < s32(0));
00205
00206 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply");
00207
00208 const u32 result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00209 const u32 result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00210
00211 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "matrix addition");
00212
00213 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00214 {
00215 if(A.n_rows == 1)
00216 {
00217 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00218 }
00219 else
00220 if(B.n_cols == 1)
00221 {
00222 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00223 }
00224 else
00225 {
00226 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
00227 }
00228 }
00229 else
00230 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00231 {
00232 if(A.n_rows == 1)
00233 {
00234 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00235 }
00236 else
00237 if(B.n_cols == 1)
00238 {
00239 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00240 }
00241 else
00242 {
00243 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
00244 }
00245 }
00246 else
00247 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00248 {
00249 if(A.n_cols == 1)
00250 {
00251 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00252 }
00253 else
00254 if(B.n_cols == 1)
00255 {
00256 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00257 }
00258 else
00259 {
00260 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
00261 }
00262 }
00263 else
00264 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00265 {
00266 if(A.n_cols == 1)
00267 {
00268 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00269 }
00270 else
00271 if(B.n_cols == 1)
00272 {
00273 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00274 }
00275 else
00276 {
00277 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
00278 }
00279 }
00280 else
00281 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00282 {
00283 if(A.n_rows == 1)
00284 {
00285 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00286 }
00287 else
00288 if(B.n_rows == 1)
00289 {
00290 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00291 }
00292 else
00293 {
00294 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
00295 }
00296 }
00297 else
00298 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00299 {
00300 if(A.n_rows == 1)
00301 {
00302 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00303 }
00304 else
00305 if(B.n_rows == 1)
00306 {
00307 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00308 }
00309 else
00310 {
00311 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
00312 }
00313 }
00314 else
00315 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00316 {
00317 if(A.n_cols == 1)
00318 {
00319 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00320 }
00321 else
00322 if(B.n_rows == 1)
00323 {
00324 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00325 }
00326 else
00327 {
00328 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
00329 }
00330 }
00331 else
00332 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00333 {
00334 if(A.n_cols == 1)
00335 {
00336 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00337 }
00338 else
00339 if(B.n_rows == 1)
00340 {
00341 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00342 }
00343 else
00344 {
00345 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
00346 }
00347 }
00348
00349
00350 }
00351
00352
00353
00354
00355 template<typename eT1, typename eT2>
00356 inline
00357 void
00358 glue_times::apply_mixed(Mat<typename promote_type<eT1,eT2>::result>& out, const Mat<eT1>& X, const Mat<eT2>& Y)
00359 {
00360 arma_extra_debug_sigprint();
00361
00362 typedef typename promote_type<eT1,eT2>::result out_eT;
00363
00364 arma_debug_assert_mul_size(X,Y, "matrix multiply");
00365
00366 out.set_size(X.n_rows,Y.n_cols);
00367 gemm_mixed<>::apply(out, X, Y);
00368 }
00369
00370
00371
00372 template<typename eT>
00373 arma_inline
00374 u32
00375 glue_times::mul_storage_cost(const Mat<eT>& A, const Mat<eT>& B, const bool do_trans_A, const bool do_trans_B)
00376 {
00377 const u32 final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00378 const u32 final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00379
00380 return final_A_n_rows * final_B_n_cols;
00381 }
00382
00383
00384
00385 template<typename eT>
00386 arma_hot
00387 inline
00388 void
00389 glue_times::apply
00390 (
00391 Mat<eT>& out,
00392 const Mat<eT>& A,
00393 const Mat<eT>& B,
00394 const eT alpha,
00395 const bool do_trans_A,
00396 const bool do_trans_B,
00397 const bool use_alpha
00398 )
00399 {
00400 arma_extra_debug_sigprint();
00401
00402 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply");
00403
00404 const u32 final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00405 const u32 final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00406
00407 out.set_size(final_n_rows, final_n_cols);
00408
00409
00410
00411 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00412 {
00413 if(A.n_rows == 1)
00414 {
00415 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
00416 }
00417 else
00418 if(B.n_cols == 1)
00419 {
00420 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
00421 }
00422 else
00423 {
00424 gemm<false, false, false, false>::apply(out, A, B);
00425 }
00426 }
00427 else
00428 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00429 {
00430 if(A.n_rows == 1)
00431 {
00432 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00433 }
00434 else
00435 if(B.n_cols == 1)
00436 {
00437 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00438 }
00439 else
00440 {
00441 gemm<false, false, true, false>::apply(out, A, B, alpha);
00442 }
00443 }
00444 else
00445 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00446 {
00447 if(A.n_cols == 1)
00448 {
00449 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
00450 }
00451 else
00452 if(B.n_cols == 1)
00453 {
00454 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
00455 }
00456 else
00457 {
00458 gemm<true, false, false, false>::apply(out, A, B);
00459 }
00460 }
00461 else
00462 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00463 {
00464 if(A.n_cols == 1)
00465 {
00466 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00467 }
00468 else
00469 if(B.n_cols == 1)
00470 {
00471 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00472 }
00473 else
00474 {
00475 gemm<true, false, true, false>::apply(out, A, B, alpha);
00476 }
00477 }
00478 else
00479 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00480 {
00481 if(A.n_rows == 1)
00482 {
00483 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
00484 }
00485 else
00486 if(B.n_rows == 1)
00487 {
00488 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
00489 }
00490 else
00491 {
00492 gemm<false, true, false, false>::apply(out, A, B);
00493 }
00494 }
00495 else
00496 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00497 {
00498 if(A.n_rows == 1)
00499 {
00500 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00501 }
00502 else
00503 if(B.n_rows == 1)
00504 {
00505 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00506 }
00507 else
00508 {
00509 gemm<false, true, true, false>::apply(out, A, B, alpha);
00510 }
00511 }
00512 else
00513 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00514 {
00515 if(A.n_cols == 1)
00516 {
00517 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
00518 }
00519 else
00520 if(B.n_rows == 1)
00521 {
00522 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
00523 }
00524 else
00525 {
00526 gemm<true, true, false, false>::apply(out, A, B);
00527 }
00528 }
00529 else
00530 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00531 {
00532 if(A.n_cols == 1)
00533 {
00534 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00535 }
00536 else
00537 if(B.n_rows == 1)
00538 {
00539 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00540 }
00541 else
00542 {
00543 gemm<true, true, true, false>::apply(out, A, B, alpha);
00544 }
00545 }
00546 }
00547
00548
00549
00550 template<typename eT>
00551 inline
00552 void
00553 glue_times::apply
00554 (
00555 Mat<eT>& out,
00556 const Mat<eT>& A,
00557 const Mat<eT>& B,
00558 const Mat<eT>& C,
00559 const eT alpha,
00560 const bool do_trans_A,
00561 const bool do_trans_B,
00562 const bool do_trans_C,
00563 const bool use_alpha
00564 )
00565 {
00566 arma_extra_debug_sigprint();
00567
00568 Mat<eT> tmp;
00569
00570 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) )
00571 {
00572
00573 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
00574 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false );
00575 }
00576 else
00577 {
00578
00579 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha);
00580 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false );
00581 }
00582 }
00583
00584
00585
00586 template<typename eT>
00587 inline
00588 void
00589 glue_times::apply
00590 (
00591 Mat<eT>& out,
00592 const Mat<eT>& A,
00593 const Mat<eT>& B,
00594 const Mat<eT>& C,
00595 const Mat<eT>& D,
00596 const eT alpha,
00597 const bool do_trans_A,
00598 const bool do_trans_B,
00599 const bool do_trans_C,
00600 const bool do_trans_D,
00601 const bool use_alpha
00602 )
00603 {
00604 arma_extra_debug_sigprint();
00605
00606 Mat<eT> tmp;
00607
00608 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) )
00609 {
00610
00611 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00612
00613 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false);
00614 }
00615 else
00616 {
00617
00618 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00619
00620 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false);
00621 }
00622 }
00623
00624
00625
00626
00627
00628
00629
00630 template<typename T1, typename T2>
00631 arma_hot
00632 inline
00633 void
00634 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
00635 {
00636 arma_extra_debug_sigprint();
00637
00638 typedef typename T1::elem_type eT;
00639
00640 const strip_diagmat<T1> S1(X.A);
00641 const strip_diagmat<T2> S2(X.B);
00642
00643 typedef typename strip_diagmat<T1>::stored_type T1_stripped;
00644 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
00645
00646 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) )
00647 {
00648 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00649
00650 const unwrap_check<T2> tmp(X.B, out);
00651 const Mat<eT>& B = tmp.M;
00652
00653 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiply");
00654
00655 out.set_size(A.n_elem, B.n_cols);
00656
00657 for(u32 col=0; col<B.n_cols; ++col)
00658 {
00659 eT* out_coldata = out.colptr(col);
00660 const eT* B_coldata = B.colptr(col);
00661
00662 for(u32 row=0; row<B.n_rows; ++row)
00663 {
00664 out_coldata[row] = A[row] * B_coldata[row];
00665 }
00666 }
00667 }
00668 else
00669 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) )
00670 {
00671 const unwrap_check<T1> tmp(X.A, out);
00672 const Mat<eT>& A = tmp.M;
00673
00674 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00675
00676 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiply");
00677
00678 out.set_size(A.n_rows, B.n_elem);
00679
00680 for(u32 col=0; col<A.n_cols; ++col)
00681 {
00682 const eT val = B[col];
00683
00684 eT* out_coldata = out.colptr(col);
00685 const eT* A_coldata = A.colptr(col);
00686
00687 for(u32 row=0; row<A.n_rows; ++row)
00688 {
00689 out_coldata[row] = A_coldata[row] * val;
00690 }
00691 }
00692 }
00693 else
00694 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) )
00695 {
00696 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00697 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00698
00699 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiply");
00700
00701 out.zeros(A.n_elem, A.n_elem);
00702
00703 for(u32 i=0; i<A.n_elem; ++i)
00704 {
00705 out.at(i,i) = A[i] * B[i];
00706 }
00707 }
00708 }
00709
00710
00711
00712