00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 template<u32 N>
00023 template<typename T1, typename T2>
00024 inline
00025 void
00026 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00027 {
00028 arma_extra_debug_sigprint();
00029
00030 typedef typename T1::elem_type eT;
00031
00032 const partial_unwrap_check<T1> tmp1(X.A, out);
00033 const partial_unwrap_check<T2> tmp2(X.B, out);
00034
00035 const Mat<eT>& A = tmp1.M;
00036 const Mat<eT>& B = tmp2.M;
00037
00038 const bool do_trans_A = tmp1.do_trans;
00039 const bool do_trans_B = tmp2.do_trans;
00040
00041 const bool use_alpha = tmp1.do_times | tmp2.do_times;
00042 const eT alpha = use_alpha ? (tmp1.val * tmp2.val) : eT(0);
00043
00044 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
00045 }
00046
00047
00048
00049 template<typename T1, typename T2, typename T3>
00050 inline
00051 void
00052 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
00053 {
00054 arma_extra_debug_sigprint();
00055
00056 typedef typename T1::elem_type eT;
00057
00058
00059
00060
00061 const partial_unwrap_check<T1> tmp1(X.A.A, out);
00062 const partial_unwrap_check<T2> tmp2(X.A.B, out);
00063 const partial_unwrap_check<T3> tmp3(X.B, out);
00064
00065 const Mat<eT>& A = tmp1.M;
00066 const Mat<eT>& B = tmp2.M;
00067 const Mat<eT>& C = tmp3.M;
00068
00069 const bool do_trans_A = tmp1.do_trans;
00070 const bool do_trans_B = tmp2.do_trans;
00071 const bool do_trans_C = tmp3.do_trans;
00072
00073 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times;
00074 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val) : eT(0);
00075
00076 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00077 }
00078
00079
00080
00081 template<typename T1, typename T2, typename T3, typename T4>
00082 inline
00083 void
00084 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
00085 {
00086 arma_extra_debug_sigprint();
00087
00088 typedef typename T1::elem_type eT;
00089
00090
00091
00092
00093 const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
00094 const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
00095 const partial_unwrap_check<T3> tmp3(X.A.B, out);
00096 const partial_unwrap_check<T4> tmp4(X.B, out);
00097
00098 const Mat<eT>& A = tmp1.M;
00099 const Mat<eT>& B = tmp2.M;
00100 const Mat<eT>& C = tmp3.M;
00101 const Mat<eT>& D = tmp4.M;
00102
00103 const bool do_trans_A = tmp1.do_trans;
00104 const bool do_trans_B = tmp2.do_trans;
00105 const bool do_trans_C = tmp3.do_trans;
00106 const bool do_trans_D = tmp4.do_trans;
00107
00108 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times | tmp4.do_times;
00109 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val * tmp4.val) : eT(0);
00110
00111 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00112 }
00113
00114
00115
00116 template<typename T1, typename T2>
00117 inline
00118 void
00119 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00120 {
00121 arma_extra_debug_sigprint();
00122
00123 typedef typename T1::elem_type eT;
00124
00125 const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00126
00127 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00128
00129 glue_times_redirect<N_mat>::apply(out, X);
00130 }
00131
00132
00133
00134 template<typename T1>
00135 inline
00136 void
00137 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
00138 {
00139 arma_extra_debug_sigprint();
00140
00141 typedef typename T1::elem_type eT;
00142
00143 const unwrap_check<T1> tmp(X, out);
00144 const Mat<eT>& B = tmp.M;
00145
00146 arma_debug_assert_mul_size(out, B, "matrix multiply");
00147
00148 if(out.n_cols == B.n_cols)
00149 {
00150 podarray<eT> tmp(out.n_cols);
00151 eT* tmp_rowdata = tmp.memptr();
00152
00153 for(u32 out_row=0; out_row < out.n_rows; ++out_row)
00154 {
00155 for(u32 out_col=0; out_col < out.n_cols; ++out_col)
00156 {
00157 tmp_rowdata[out_col] = out.at(out_row,out_col);
00158 }
00159
00160 for(u32 B_col=0; B_col < B.n_cols; ++B_col)
00161 {
00162 const eT* B_coldata = B.colptr(B_col);
00163
00164 eT val = eT(0);
00165 for(u32 i=0; i < B.n_rows; ++i)
00166 {
00167 val += tmp_rowdata[i] * B_coldata[i];
00168 }
00169
00170 out.at(out_row,B_col) = val;
00171 }
00172 }
00173
00174 }
00175 else
00176 {
00177 const Mat<eT> tmp(out);
00178 glue_times::apply(out, tmp, B, eT(1), false, false, false);
00179 }
00180
00181 }
00182
00183
00184
00185 template<typename T1, typename T2>
00186 arma_hot
00187 inline
00188 void
00189 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const s32 sign)
00190 {
00191 arma_extra_debug_sigprint();
00192
00193 typedef typename T1::elem_type eT;
00194
00195 const partial_unwrap_check<T1> tmp1(X.A, out);
00196 const partial_unwrap_check<T2> tmp2(X.B, out);
00197
00198 const Mat<eT>& A = tmp1.M;
00199 const Mat<eT>& B = tmp2.M;
00200 const eT alpha = tmp1.val * tmp2.val * ( (sign > s32(0)) ? eT(1) : eT(-1) );
00201
00202 const bool do_trans_A = tmp1.do_trans;
00203 const bool do_trans_B = tmp2.do_trans;
00204 const bool use_alpha = tmp1.do_times | tmp2.do_times | (sign < s32(0));
00205
00206 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply");
00207
00208 const u32 result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00209 const u32 result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00210
00211 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "matrix addition");
00212
00213 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00214 {
00215 if(A.n_rows == 1)
00216 {
00217 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00218 }
00219 else
00220 if(B.n_cols == 1)
00221 {
00222 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00223 }
00224 else
00225 {
00226 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
00227 }
00228 }
00229 else
00230 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00231 {
00232 if(A.n_rows == 1)
00233 {
00234 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00235 }
00236 else
00237 if(B.n_cols == 1)
00238 {
00239 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00240 }
00241 else
00242 {
00243 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
00244 }
00245 }
00246 else
00247 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00248 {
00249 if(A.n_cols == 1)
00250 {
00251 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00252 }
00253 else
00254 if(B.n_cols == 1)
00255 {
00256 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00257 }
00258 else
00259 {
00260 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
00261 }
00262 }
00263 else
00264 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00265 {
00266 if(A.n_cols == 1)
00267 {
00268 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00269 }
00270 else
00271 if(B.n_cols == 1)
00272 {
00273 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00274 }
00275 else
00276 {
00277 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
00278 }
00279 }
00280 else
00281 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00282 {
00283 if(A.n_rows == 1)
00284 {
00285 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00286 }
00287 else
00288 if(B.n_rows == 1)
00289 {
00290 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00291 }
00292 else
00293 {
00294 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
00295 }
00296 }
00297 else
00298 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00299 {
00300 if(A.n_rows == 1)
00301 {
00302 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00303 }
00304 else
00305 if(B.n_rows == 1)
00306 {
00307 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00308 }
00309 else
00310 {
00311 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
00312 }
00313 }
00314 else
00315 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00316 {
00317 if(A.n_cols == 1)
00318 {
00319 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00320 }
00321 else
00322 if(B.n_rows == 1)
00323 {
00324 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00325 }
00326 else
00327 {
00328 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
00329 }
00330 }
00331 else
00332 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00333 {
00334 if(A.n_cols == 1)
00335 {
00336 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00337 }
00338 else
00339 if(B.n_rows == 1)
00340 {
00341 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00342 }
00343 else
00344 {
00345 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
00346 }
00347 }
00348
00349
00350 }
00351
00352
00353
00354 template<typename eT>
00355 arma_inline
00356 u32
00357 glue_times::mul_storage_cost(const Mat<eT>& A, const Mat<eT>& B, const bool do_trans_A, const bool do_trans_B)
00358 {
00359 const u32 final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00360 const u32 final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00361
00362 return final_A_n_rows * final_B_n_cols;
00363 }
00364
00365
00366
00367 template<typename eT>
00368 arma_hot
00369 inline
00370 void
00371 glue_times::apply
00372 (
00373 Mat<eT>& out,
00374 const Mat<eT>& A,
00375 const Mat<eT>& B,
00376 const eT alpha,
00377 const bool do_trans_A,
00378 const bool do_trans_B,
00379 const bool use_alpha
00380 )
00381 {
00382 arma_extra_debug_sigprint();
00383
00384 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply");
00385
00386 const u32 final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00387 const u32 final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00388
00389 out.set_size(final_n_rows, final_n_cols);
00390
00391
00392
00393 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00394 {
00395 if(A.n_rows == 1)
00396 {
00397 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
00398 }
00399 else
00400 if(B.n_cols == 1)
00401 {
00402 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
00403 }
00404 else
00405 {
00406 gemm<false, false, false, false>::apply(out, A, B);
00407 }
00408 }
00409 else
00410 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00411 {
00412 if(A.n_rows == 1)
00413 {
00414 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00415 }
00416 else
00417 if(B.n_cols == 1)
00418 {
00419 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00420 }
00421 else
00422 {
00423 gemm<false, false, true, false>::apply(out, A, B, alpha);
00424 }
00425 }
00426 else
00427 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00428 {
00429 if(A.n_cols == 1)
00430 {
00431 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
00432 }
00433 else
00434 if(B.n_cols == 1)
00435 {
00436 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
00437 }
00438 else
00439 {
00440 gemm<true, false, false, false>::apply(out, A, B);
00441 }
00442 }
00443 else
00444 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00445 {
00446 if(A.n_cols == 1)
00447 {
00448 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00449 }
00450 else
00451 if(B.n_cols == 1)
00452 {
00453 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00454 }
00455 else
00456 {
00457 gemm<true, false, true, false>::apply(out, A, B, alpha);
00458 }
00459 }
00460 else
00461 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00462 {
00463 if(A.n_rows == 1)
00464 {
00465 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
00466 }
00467 else
00468 if(B.n_rows == 1)
00469 {
00470 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
00471 }
00472 else
00473 {
00474 gemm<false, true, false, false>::apply(out, A, B);
00475 }
00476 }
00477 else
00478 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00479 {
00480 if(A.n_rows == 1)
00481 {
00482 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00483 }
00484 else
00485 if(B.n_rows == 1)
00486 {
00487 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00488 }
00489 else
00490 {
00491 gemm<false, true, true, false>::apply(out, A, B, alpha);
00492 }
00493 }
00494 else
00495 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00496 {
00497 if(A.n_cols == 1)
00498 {
00499 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
00500 }
00501 else
00502 if(B.n_rows == 1)
00503 {
00504 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
00505 }
00506 else
00507 {
00508 gemm<true, true, false, false>::apply(out, A, B);
00509 }
00510 }
00511 else
00512 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00513 {
00514 if(A.n_cols == 1)
00515 {
00516 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00517 }
00518 else
00519 if(B.n_rows == 1)
00520 {
00521 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00522 }
00523 else
00524 {
00525 gemm<true, true, true, false>::apply(out, A, B, alpha);
00526 }
00527 }
00528 }
00529
00530
00531
00532 template<typename eT>
00533 inline
00534 void
00535 glue_times::apply
00536 (
00537 Mat<eT>& out,
00538 const Mat<eT>& A,
00539 const Mat<eT>& B,
00540 const Mat<eT>& C,
00541 const eT alpha,
00542 const bool do_trans_A,
00543 const bool do_trans_B,
00544 const bool do_trans_C,
00545 const bool use_alpha
00546 )
00547 {
00548 arma_extra_debug_sigprint();
00549
00550 Mat<eT> tmp;
00551
00552 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) )
00553 {
00554
00555 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
00556 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false );
00557 }
00558 else
00559 {
00560
00561 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha);
00562 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false );
00563 }
00564 }
00565
00566
00567
00568 template<typename eT>
00569 inline
00570 void
00571 glue_times::apply
00572 (
00573 Mat<eT>& out,
00574 const Mat<eT>& A,
00575 const Mat<eT>& B,
00576 const Mat<eT>& C,
00577 const Mat<eT>& D,
00578 const eT alpha,
00579 const bool do_trans_A,
00580 const bool do_trans_B,
00581 const bool do_trans_C,
00582 const bool do_trans_D,
00583 const bool use_alpha
00584 )
00585 {
00586 arma_extra_debug_sigprint();
00587
00588 Mat<eT> tmp;
00589
00590 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) )
00591 {
00592
00593 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00594
00595 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false);
00596 }
00597 else
00598 {
00599
00600 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00601
00602 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false);
00603 }
00604 }
00605
00606
00607
00608
00609
00610
00611
00612 template<typename T1, typename T2>
00613 arma_hot
00614 inline
00615 void
00616 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
00617 {
00618 arma_extra_debug_sigprint();
00619
00620 typedef typename T1::elem_type eT;
00621
00622 const strip_diagmat<T1> S1(X.A);
00623 const strip_diagmat<T2> S2(X.B);
00624
00625 typedef typename strip_diagmat<T1>::stored_type T1_stripped;
00626 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
00627
00628 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) )
00629 {
00630 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00631
00632 const unwrap_check<T2> tmp(X.B, out);
00633 const Mat<eT>& B = tmp.M;
00634
00635 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiply");
00636
00637 out.set_size(A.n_elem, B.n_cols);
00638
00639 for(u32 col=0; col<B.n_cols; ++col)
00640 {
00641 eT* out_coldata = out.colptr(col);
00642 const eT* B_coldata = B.colptr(col);
00643
00644 for(u32 row=0; row<B.n_rows; ++row)
00645 {
00646 out_coldata[row] = A[row] * B_coldata[row];
00647 }
00648 }
00649 }
00650 else
00651 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) )
00652 {
00653 const unwrap_check<T1> tmp(X.A, out);
00654 const Mat<eT>& A = tmp.M;
00655
00656 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00657
00658 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiply");
00659
00660 out.set_size(A.n_rows, B.n_elem);
00661
00662 for(u32 col=0; col<A.n_cols; ++col)
00663 {
00664 const eT val = B[col];
00665
00666 eT* out_coldata = out.colptr(col);
00667 const eT* A_coldata = A.colptr(col);
00668
00669 for(u32 row=0; row<A.n_rows; ++row)
00670 {
00671 out_coldata[row] = A_coldata[row] * val;
00672 }
00673 }
00674 }
00675 else
00676 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) )
00677 {
00678 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00679 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00680
00681 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiply");
00682
00683 out.zeros(A.n_elem, A.n_elem);
00684
00685 for(u32 i=0; i<A.n_elem; ++i)
00686 {
00687 out.at(i,i) = A[i] * B[i];
00688 }
00689 }
00690 }
00691
00692
00693
00694