Classes | |
struct | depth_lhs< glue_type, T1 > |
Template metaprogram depth_lhs calculates the number of Glue<Tx,Ty, glue_type> instances on the left hand side argument of Glue<Tx,Ty, glue_type>, i.e. it recursively expands each Tx until the type of Tx is not "Glue<..,.., glue_type>" (i.e. the "glue_type" changes). More... | |
struct | depth_lhs< glue_type, Glue< T1, T2, glue_type > > |
struct | glue_times_redirect< N > |
struct | glue_times_redirect< 3 > |
struct | glue_times_redirect< 4 > |
class | glue_times |
Class which implements the immediate multiplication of two or more matrices. More... | |
class | glue_times_diag |
Functions | |
template<typename T1 , typename T2 > | |
static void | glue_times_redirect::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X) |
template<typename T1 , typename T2 , typename T3 > | |
static void | glue_times_redirect< 3 >::apply (Mat< typename T1::elem_type > &out, const Glue< Glue< T1, T2, glue_times >, T3, glue_times > &X) |
template<typename T1 , typename T2 , typename T3 , typename T4 > | |
static void | glue_times_redirect< 4 >::apply (Mat< typename T1::elem_type > &out, const Glue< Glue< Glue< T1, T2, glue_times >, T3, glue_times >, T4, glue_times > &X) |
template<typename T1 , typename T2 > | |
static void | glue_times::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X) |
template<typename T1 > | |
static void | glue_times::apply_inplace (Mat< typename T1::elem_type > &out, const T1 &X) |
template<typename T1 , typename T2 > | |
static arma_hot void | glue_times::apply_inplace_plus (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X, const s32 sign) |
template<typename eT1 , typename eT2 > | |
static void | glue_times::apply_mixed (Mat< typename promote_type< eT1, eT2 >::result > &out, const Mat< eT1 > &X, const Mat< eT2 > &Y) |
matrix multiplication with different element types | |
template<typename eT > | |
static arma_inline u32 | glue_times::mul_storage_cost (const Mat< eT > &A, const Mat< eT > &B, const bool do_trans_A, const bool do_trans_B) |
template<typename eT > | |
static arma_hot void | glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_scalar_times) |
template<typename eT > | |
static void | glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const Mat< eT > &C, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_trans_C, const bool do_scalar_times) |
template<typename eT > | |
static void | glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const Mat< eT > &C, const Mat< eT > &D, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_trans_C, const bool do_trans_D, const bool do_scalar_times) |
template<typename T1 , typename T2 > | |
static arma_hot void | glue_times_diag::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times_diag > &X) |
void glue_times_redirect< N >::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 26 of file glue_times_meat.hpp.
References Glue< T1, T2, glue_type >::A, Glue< T1, T2, glue_type >::B, partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.
Referenced by glue_times_redirect< 4 >::apply(), and glue_times_redirect< 3 >::apply().
00027 { 00028 arma_extra_debug_sigprint(); 00029 00030 typedef typename T1::elem_type eT; 00031 00032 const partial_unwrap_check<T1> tmp1(X.A, out); 00033 const partial_unwrap_check<T2> tmp2(X.B, out); 00034 00035 const Mat<eT>& A = tmp1.M; 00036 const Mat<eT>& B = tmp2.M; 00037 00038 const bool do_trans_A = tmp1.do_trans; 00039 const bool do_trans_B = tmp2.do_trans; 00040 00041 const bool use_alpha = tmp1.do_times | tmp2.do_times; 00042 const eT alpha = use_alpha ? (tmp1.val * tmp2.val) : eT(0); 00043 00044 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha); 00045 }
void glue_times_redirect< 3 >::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< Glue< T1, T2, glue_times >, T3, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 52 of file glue_times_meat.hpp.
References glue_times_redirect< N >::apply(), partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.
00053 { 00054 arma_extra_debug_sigprint(); 00055 00056 typedef typename T1::elem_type eT; 00057 00058 // there is exactly 3 objects 00059 // hence we can safely expand X as X.A.A, X.A.B and X.B 00060 00061 const partial_unwrap_check<T1> tmp1(X.A.A, out); 00062 const partial_unwrap_check<T2> tmp2(X.A.B, out); 00063 const partial_unwrap_check<T3> tmp3(X.B, out); 00064 00065 const Mat<eT>& A = tmp1.M; 00066 const Mat<eT>& B = tmp2.M; 00067 const Mat<eT>& C = tmp3.M; 00068 00069 const bool do_trans_A = tmp1.do_trans; 00070 const bool do_trans_B = tmp2.do_trans; 00071 const bool do_trans_C = tmp3.do_trans; 00072 00073 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times; 00074 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val) : eT(0); 00075 00076 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); 00077 }
void glue_times_redirect< 4 >::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< Glue< Glue< T1, T2, glue_times >, T3, glue_times >, T4, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 84 of file glue_times_meat.hpp.
References glue_times_redirect< N >::apply(), partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.
00085 { 00086 arma_extra_debug_sigprint(); 00087 00088 typedef typename T1::elem_type eT; 00089 00090 // there is exactly 4 objects 00091 // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B 00092 00093 const partial_unwrap_check<T1> tmp1(X.A.A.A, out); 00094 const partial_unwrap_check<T2> tmp2(X.A.A.B, out); 00095 const partial_unwrap_check<T3> tmp3(X.A.B, out); 00096 const partial_unwrap_check<T4> tmp4(X.B, out); 00097 00098 const Mat<eT>& A = tmp1.M; 00099 const Mat<eT>& B = tmp2.M; 00100 const Mat<eT>& C = tmp3.M; 00101 const Mat<eT>& D = tmp4.M; 00102 00103 const bool do_trans_A = tmp1.do_trans; 00104 const bool do_trans_B = tmp2.do_trans; 00105 const bool do_trans_C = tmp3.do_trans; 00106 const bool do_trans_D = tmp4.do_trans; 00107 00108 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times | tmp4.do_times; 00109 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val * tmp4.val) : eT(0); 00110 00111 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha); 00112 }
void glue_times::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 119 of file glue_times_meat.hpp.
Referenced by apply(), apply_inplace(), apply_inplace_plus(), and apply_mixed().
00120 { 00121 arma_extra_debug_sigprint(); 00122 00123 typedef typename T1::elem_type eT; 00124 00125 const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num; 00126 00127 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); 00128 00129 glue_times_redirect<N_mat>::apply(out, X); 00130 }
void glue_times::apply_inplace | ( | Mat< typename T1::elem_type > & | out, | |
const T1 & | X | |||
) | [inline, static, inherited] |
Definition at line 137 of file glue_times_meat.hpp.
References apply(), Mat< eT >::at(), Mat< eT >::colptr(), unwrap_check< T1 >::M, podarray< eT >::memptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.
Referenced by Mat< eT >::operator*=().
00138 { 00139 arma_extra_debug_sigprint(); 00140 00141 typedef typename T1::elem_type eT; 00142 00143 const unwrap_check<T1> tmp(X, out); 00144 const Mat<eT>& B = tmp.M; 00145 00146 arma_debug_assert_mul_size(out, B, "matrix multiply"); 00147 00148 if(out.n_cols == B.n_cols) 00149 { 00150 podarray<eT> tmp(out.n_cols); 00151 eT* tmp_rowdata = tmp.memptr(); 00152 00153 for(u32 out_row=0; out_row < out.n_rows; ++out_row) 00154 { 00155 for(u32 out_col=0; out_col < out.n_cols; ++out_col) 00156 { 00157 tmp_rowdata[out_col] = out.at(out_row,out_col); 00158 } 00159 00160 for(u32 B_col=0; B_col < B.n_cols; ++B_col) 00161 { 00162 const eT* B_coldata = B.colptr(B_col); 00163 00164 eT val = eT(0); 00165 for(u32 i=0; i < B.n_rows; ++i) 00166 { 00167 val += tmp_rowdata[i] * B_coldata[i]; 00168 } 00169 00170 out.at(out_row,B_col) = val; 00171 } 00172 } 00173 00174 } 00175 else 00176 { 00177 const Mat<eT> tmp(out); 00178 glue_times::apply(out, tmp, B, eT(1), false, false, false); 00179 } 00180 00181 }
arma_hot void glue_times::apply_inplace_plus | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times > & | X, | |||
const s32 | sign | |||
) | [inline, static, inherited] |
Definition at line 189 of file glue_times_meat.hpp.
References Glue< T1, T2, glue_type >::A, apply(), arma_assert_same_size(), Glue< T1, T2, glue_type >::B, partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, Mat< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and partial_unwrap_check< T1 >::val.
Referenced by Mat< eT >::operator+=(), and Mat< eT >::operator-=().
00190 { 00191 arma_extra_debug_sigprint(); 00192 00193 typedef typename T1::elem_type eT; 00194 00195 const partial_unwrap_check<T1> tmp1(X.A, out); 00196 const partial_unwrap_check<T2> tmp2(X.B, out); 00197 00198 const Mat<eT>& A = tmp1.M; 00199 const Mat<eT>& B = tmp2.M; 00200 const eT alpha = tmp1.val * tmp2.val * ( (sign > s32(0)) ? eT(1) : eT(-1) ); 00201 00202 const bool do_trans_A = tmp1.do_trans; 00203 const bool do_trans_B = tmp2.do_trans; 00204 const bool use_alpha = tmp1.do_times | tmp2.do_times | (sign < s32(0)); 00205 00206 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply"); 00207 00208 const u32 result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; 00209 const u32 result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; 00210 00211 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "matrix addition"); 00212 00213 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) 00214 { 00215 if(A.n_rows == 1) 00216 { 00217 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00218 } 00219 else 00220 if(B.n_cols == 1) 00221 { 00222 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00223 } 00224 else 00225 { 00226 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1)); 00227 } 00228 } 00229 else 00230 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) 00231 { 00232 if(A.n_rows == 1) 00233 { 00234 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00235 } 00236 else 00237 if(B.n_cols == 1) 00238 { 00239 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00240 } 00241 else 00242 { 00243 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1)); 00244 } 00245 } 00246 else 00247 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) 00248 { 00249 if(A.n_cols == 1) 00250 { 00251 gemv<true, false, true>::apply(out.memptr(), B, 
A.memptr(), alpha, eT(1)); 00252 } 00253 else 00254 if(B.n_cols == 1) 00255 { 00256 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00257 } 00258 else 00259 { 00260 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1)); 00261 } 00262 } 00263 else 00264 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) 00265 { 00266 if(A.n_cols == 1) 00267 { 00268 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00269 } 00270 else 00271 if(B.n_cols == 1) 00272 { 00273 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00274 } 00275 else 00276 { 00277 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1)); 00278 } 00279 } 00280 else 00281 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) 00282 { 00283 if(A.n_rows == 1) 00284 { 00285 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00286 } 00287 else 00288 if(B.n_rows == 1) 00289 { 00290 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00291 } 00292 else 00293 { 00294 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1)); 00295 } 00296 } 00297 else 00298 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) 00299 { 00300 if(A.n_rows == 1) 00301 { 00302 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00303 } 00304 else 00305 if(B.n_rows == 1) 00306 { 00307 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00308 } 00309 else 00310 { 00311 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1)); 00312 } 00313 } 00314 else 00315 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) 00316 { 00317 if(A.n_cols == 1) 00318 { 00319 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00320 } 00321 else 00322 if(B.n_rows == 1) 00323 { 00324 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00325 } 
00326 else 00327 { 00328 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1)); 00329 } 00330 } 00331 else 00332 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) 00333 { 00334 if(A.n_cols == 1) 00335 { 00336 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00337 } 00338 else 00339 if(B.n_rows == 1) 00340 { 00341 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00342 } 00343 else 00344 { 00345 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1)); 00346 } 00347 } 00348 00349 00350 }
void glue_times::apply_mixed | ( | Mat< typename promote_type< eT1, eT2 >::result > & | out, | |
const Mat< eT1 > & | X, | |||
const Mat< eT2 > & | Y | |||
) | [inline, static, inherited] |
matrix multiplication with different element types
Definition at line 358 of file glue_times_meat.hpp.
References apply(), Mat< eT >::n_cols, and Mat< eT >::n_rows.
Referenced by operator*().
00359 { 00360 arma_extra_debug_sigprint(); 00361 00362 typedef typename promote_type<eT1,eT2>::result out_eT; 00363 00364 arma_debug_assert_mul_size(X,Y, "matrix multiply"); 00365 00366 out.set_size(X.n_rows,Y.n_cols); 00367 gemm_mixed<>::apply(out, X, Y); 00368 }
arma_inline u32 glue_times::mul_storage_cost | ( | const Mat< eT > & | A, | |
const Mat< eT > & | B, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B | |||
) | [inline, static, inherited] |
Definition at line 375 of file glue_times_meat.hpp.
References Mat< eT >::n_cols, and Mat< eT >::n_rows.
Referenced by apply().
arma_hot void glue_times::apply | ( | Mat< eT > & | out, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const eT | val, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B, | |||
const bool | do_scalar_times | |||
) | [inline, static, inherited] |
Definition at line 390 of file glue_times_meat.hpp.
References gemm< do_trans_A, do_trans_B, use_alpha, use_beta >::apply(), gemv< do_trans_A, use_alpha, use_beta >::apply(), Mat< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and Mat< eT >::set_size().
00399 { 00400 arma_extra_debug_sigprint(); 00401 00402 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply"); 00403 00404 const u32 final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; 00405 const u32 final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; 00406 00407 out.set_size(final_n_rows, final_n_cols); 00408 00409 // TODO: thoroughly test all combinations 00410 00411 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) 00412 { 00413 if(A.n_rows == 1) 00414 { 00415 gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); 00416 } 00417 else 00418 if(B.n_cols == 1) 00419 { 00420 gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); 00421 } 00422 else 00423 { 00424 gemm<false, false, false, false>::apply(out, A, B); 00425 } 00426 } 00427 else 00428 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) 00429 { 00430 if(A.n_rows == 1) 00431 { 00432 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00433 } 00434 else 00435 if(B.n_cols == 1) 00436 { 00437 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 00438 } 00439 else 00440 { 00441 gemm<false, false, true, false>::apply(out, A, B, alpha); 00442 } 00443 } 00444 else 00445 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) 00446 { 00447 if(A.n_cols == 1) 00448 { 00449 gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); 00450 } 00451 else 00452 if(B.n_cols == 1) 00453 { 00454 gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); 00455 } 00456 else 00457 { 00458 gemm<true, false, false, false>::apply(out, A, B); 00459 } 00460 } 00461 else 00462 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) 00463 { 00464 if(A.n_cols == 1) 00465 { 00466 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00467 } 00468 else 00469 if(B.n_cols == 1) 00470 { 00471 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), 
alpha); 00472 } 00473 else 00474 { 00475 gemm<true, false, true, false>::apply(out, A, B, alpha); 00476 } 00477 } 00478 else 00479 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) 00480 { 00481 if(A.n_rows == 1) 00482 { 00483 gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); 00484 } 00485 else 00486 if(B.n_rows == 1) 00487 { 00488 gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); 00489 } 00490 else 00491 { 00492 gemm<false, true, false, false>::apply(out, A, B); 00493 } 00494 } 00495 else 00496 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) 00497 { 00498 if(A.n_rows == 1) 00499 { 00500 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00501 } 00502 else 00503 if(B.n_rows == 1) 00504 { 00505 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 00506 } 00507 else 00508 { 00509 gemm<false, true, true, false>::apply(out, A, B, alpha); 00510 } 00511 } 00512 else 00513 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) 00514 { 00515 if(A.n_cols == 1) 00516 { 00517 gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); 00518 } 00519 else 00520 if(B.n_rows == 1) 00521 { 00522 gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); 00523 } 00524 else 00525 { 00526 gemm<true, true, false, false>::apply(out, A, B); 00527 } 00528 } 00529 else 00530 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) 00531 { 00532 if(A.n_cols == 1) 00533 { 00534 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00535 } 00536 else 00537 if(B.n_rows == 1) 00538 { 00539 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 00540 } 00541 else 00542 { 00543 gemm<true, true, true, false>::apply(out, A, B, alpha); 00544 } 00545 } 00546 }
void glue_times::apply | ( | Mat< eT > & | out, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const Mat< eT > & | C, | |||
const eT | val, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B, | |||
const bool | do_trans_C, | |||
const bool | do_scalar_times | |||
) | [inline, static, inherited] |
Definition at line 554 of file glue_times_meat.hpp.
References apply(), and mul_storage_cost().
00565 { 00566 arma_extra_debug_sigprint(); 00567 00568 Mat<eT> tmp; 00569 00570 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) ) 00571 { 00572 // out = (A*B)*C 00573 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha); 00574 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false ); 00575 } 00576 else 00577 { 00578 // out = A*(B*C) 00579 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha); 00580 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false ); 00581 } 00582 }
void glue_times::apply | ( | Mat< eT > & | out, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const Mat< eT > & | C, | |||
const Mat< eT > & | D, | |||
const eT | val, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B, | |||
const bool | do_trans_C, | |||
const bool | do_trans_D, | |||
const bool | do_scalar_times | |||
) | [inline, static, inherited] |
Definition at line 590 of file glue_times_meat.hpp.
References apply(), and mul_storage_cost().
00603 { 00604 arma_extra_debug_sigprint(); 00605 00606 Mat<eT> tmp; 00607 00608 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) ) 00609 { 00610 // out = (A*B*C)*D 00611 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); 00612 00613 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false); 00614 } 00615 else 00616 { 00617 // out = A*(B*C*D) 00618 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha); 00619 00620 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false); 00621 } 00622 }
arma_hot void glue_times_diag::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times_diag > & | X | |||
) | [inline, static, inherited] |
Definition at line 634 of file glue_times_meat.hpp.
References Glue< T1, T2, glue_type >::A, Mat< eT >::at(), Glue< T1, T2, glue_type >::B, Mat< eT >::colptr(), strip_diagmat< T1 >::do_diagmat, unwrap_check< T1 >::M, strip_diagmat< T1 >::M, Mat< eT >::n_cols, diagmat_proxy_check< T1 >::n_elem, Mat< eT >::n_rows, Mat< eT >::set_size(), and Mat< eT >::zeros().
00635 { 00636 arma_extra_debug_sigprint(); 00637 00638 typedef typename T1::elem_type eT; 00639 00640 const strip_diagmat<T1> S1(X.A); 00641 const strip_diagmat<T2> S2(X.B); 00642 00643 typedef typename strip_diagmat<T1>::stored_type T1_stripped; 00644 typedef typename strip_diagmat<T2>::stored_type T2_stripped; 00645 00646 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) ) 00647 { 00648 const diagmat_proxy_check<T1_stripped> A(S1.M, out); 00649 00650 const unwrap_check<T2> tmp(X.B, out); 00651 const Mat<eT>& B = tmp.M; 00652 00653 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiply"); 00654 00655 out.set_size(A.n_elem, B.n_cols); 00656 00657 for(u32 col=0; col<B.n_cols; ++col) 00658 { 00659 eT* out_coldata = out.colptr(col); 00660 const eT* B_coldata = B.colptr(col); 00661 00662 for(u32 row=0; row<B.n_rows; ++row) 00663 { 00664 out_coldata[row] = A[row] * B_coldata[row]; 00665 } 00666 } 00667 } 00668 else 00669 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) ) 00670 { 00671 const unwrap_check<T1> tmp(X.A, out); 00672 const Mat<eT>& A = tmp.M; 00673 00674 const diagmat_proxy_check<T2_stripped> B(S2.M, out); 00675 00676 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiply"); 00677 00678 out.set_size(A.n_rows, B.n_elem); 00679 00680 for(u32 col=0; col<A.n_cols; ++col) 00681 { 00682 const eT val = B[col]; 00683 00684 eT* out_coldata = out.colptr(col); 00685 const eT* A_coldata = A.colptr(col); 00686 00687 for(u32 row=0; row<A.n_rows; ++row) 00688 { 00689 out_coldata[row] = A_coldata[row] * val; 00690 } 00691 } 00692 } 00693 else 00694 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) ) 00695 { 00696 const diagmat_proxy_check<T1_stripped> A(S1.M, out); 00697 const diagmat_proxy_check<T2_stripped> B(S2.M, out); 00698 00699 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiply"); 00700 00701 out.zeros(A.n_elem, A.n_elem); 00702 00703 for(u32 i=0; i<A.n_elem; 
++i) 00704 { 00705 out.at(i,i) = A[i] * B[i]; 00706 } 00707 } 00708 }