Glue_times

Classes

struct  depth_lhs< glue_type, T1 >
 Template metaprogram depth_lhs calculates the number of Glue<Tx,Ty, glue_type> instances in the left-hand-side argument of Glue<Tx,Ty, glue_type>, i.e. it recursively expands each Tx until the type of Tx is no longer "Glue<..,.., glue_type>" (i.e. until the "glue_type" changes). More...
struct  depth_lhs< glue_type, Glue< T1, T2, glue_type > >
struct  glue_times_redirect< N >
struct  glue_times_redirect< 3 >
struct  glue_times_redirect< 4 >
class  glue_times
 Class which implements the immediate multiplication of two or more matrices. More...
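class  glue_times_diag
 Class which implements the multiplication of two matrices where at least one of the operands is a diagonal matrix.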

Functions

template<typename T1 , typename T2 >
static void glue_times_redirect::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X)
template<typename T1 , typename T2 , typename T3 >
static void glue_times_redirect< 3 >::apply (Mat< typename T1::elem_type > &out, const Glue< Glue< T1, T2, glue_times >, T3, glue_times > &X)
template<typename T1 , typename T2 , typename T3 , typename T4 >
static void glue_times_redirect< 4 >::apply (Mat< typename T1::elem_type > &out, const Glue< Glue< Glue< T1, T2, glue_times >, T3, glue_times >, T4, glue_times > &X)
template<typename T1 , typename T2 >
static void glue_times::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X)
template<typename T1 >
static void glue_times::apply_inplace (Mat< typename T1::elem_type > &out, const T1 &X)
template<typename T1 , typename T2 >
static arma_hot void glue_times::apply_inplace_plus (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X, const s32 sign)
template<typename eT1 , typename eT2 >
static void glue_times::apply_mixed (Mat< typename promote_type< eT1, eT2 >::result > &out, const Mat< eT1 > &X, const Mat< eT2 > &Y)
 matrix multiplication with different element types
template<typename eT >
static arma_inline u32 glue_times::mul_storage_cost (const Mat< eT > &A, const Mat< eT > &B, const bool do_trans_A, const bool do_trans_B)
template<typename eT >
static arma_hot void glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_scalar_times)
template<typename eT >
static void glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const Mat< eT > &C, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_trans_C, const bool do_scalar_times)
template<typename eT >
static void glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const Mat< eT > &C, const Mat< eT > &D, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_trans_C, const bool do_trans_D, const bool do_scalar_times)
template<typename T1 , typename T2 >
static arma_hot void glue_times_diag::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times_diag > &X)

Function Documentation

template<u32 N>
template<typename T1 , typename T2 >
void glue_times_redirect< N >::apply ( Mat< typename T1::elem_type > &  out,
const Glue< T1, T2, glue_times > &  X 
) [inline, static, inherited]

Definition at line 26 of file glue_times_meat.hpp.

References Glue< T1, T2, glue_type >::A, Glue< T1, T2, glue_type >::B, partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.

Referenced by glue_times_redirect< 4 >::apply(), and glue_times_redirect< 3 >::apply().

00027   {
00028   arma_extra_debug_sigprint();
00029   
00030   typedef typename T1::elem_type eT;
00031   
00032   const partial_unwrap_check<T1> tmp1(X.A, out);
00033   const partial_unwrap_check<T2> tmp2(X.B, out);
00034   
00035   const Mat<eT>& A = tmp1.M;
00036   const Mat<eT>& B = tmp2.M;
00037   
00038   const bool do_trans_A = tmp1.do_trans;
00039   const bool do_trans_B = tmp2.do_trans;
00040 
00041   const bool use_alpha = tmp1.do_times | tmp2.do_times;
00042   const eT       alpha = use_alpha ? (tmp1.val * tmp2.val) : eT(0);
00043   
00044   glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
00045   }

template<typename T1 , typename T2 , typename T3 >
void glue_times_redirect< 3 >::apply ( Mat< typename T1::elem_type > &  out,
const Glue< Glue< T1, T2, glue_times >, T3, glue_times > &  X 
) [inline, static, inherited]

Definition at line 52 of file glue_times_meat.hpp.

References glue_times_redirect< N >::apply(), partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.

00053   {
00054   arma_extra_debug_sigprint();
00055   
00056   typedef typename T1::elem_type eT;
00057   
00058   // there are exactly 3 objects
00059   // hence we can safely expand X as X.A.A, X.A.B and X.B
00060   
00061   const partial_unwrap_check<T1> tmp1(X.A.A, out);
00062   const partial_unwrap_check<T2> tmp2(X.A.B, out);
00063   const partial_unwrap_check<T3> tmp3(X.B,   out);
00064   
00065   const Mat<eT>& A = tmp1.M;
00066   const Mat<eT>& B = tmp2.M;
00067   const Mat<eT>& C = tmp3.M;
00068   
00069   const bool do_trans_A = tmp1.do_trans;
00070   const bool do_trans_B = tmp2.do_trans;
00071   const bool do_trans_C = tmp3.do_trans;
00072   
00073   const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times;
00074   const eT       alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val) : eT(0);
00075   
00076   glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00077   }

template<typename T1 , typename T2 , typename T3 , typename T4 >
void glue_times_redirect< 4 >::apply ( Mat< typename T1::elem_type > &  out,
const Glue< Glue< Glue< T1, T2, glue_times >, T3, glue_times >, T4, glue_times > &  X 
) [inline, static, inherited]

Definition at line 84 of file glue_times_meat.hpp.

References glue_times_redirect< N >::apply(), partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.

00085   {
00086   arma_extra_debug_sigprint();
00087   
00088   typedef typename T1::elem_type eT;
00089   
00090   // there are exactly 4 objects
00091   // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B
00092   
00093   const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
00094   const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
00095   const partial_unwrap_check<T3> tmp3(X.A.B,   out);
00096   const partial_unwrap_check<T4> tmp4(X.B,     out);
00097   
00098   const Mat<eT>& A = tmp1.M;
00099   const Mat<eT>& B = tmp2.M;
00100   const Mat<eT>& C = tmp3.M;
00101   const Mat<eT>& D = tmp4.M;
00102   
00103   const bool do_trans_A = tmp1.do_trans;
00104   const bool do_trans_B = tmp2.do_trans;
00105   const bool do_trans_C = tmp3.do_trans;
00106   const bool do_trans_D = tmp4.do_trans;
00107   
00108   const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times | tmp4.do_times;
00109   const eT       alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val * tmp4.val) : eT(0);
00110   
00111   glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00112   }

template<typename T1 , typename T2 >
void glue_times::apply ( Mat< typename T1::elem_type > &  out,
const Glue< T1, T2, glue_times > &  X 
) [inline, static, inherited]

Definition at line 119 of file glue_times_meat.hpp.

Referenced by apply(), apply_inplace(), apply_inplace_plus(), and apply_mixed().

00120   {
00121   arma_extra_debug_sigprint();
00122 
00123   typedef typename T1::elem_type eT;
00124 
00125   const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00126 
00127   arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00128 
00129   glue_times_redirect<N_mat>::apply(out, X);
00130   }

template<typename T1 >
void glue_times::apply_inplace ( Mat< typename T1::elem_type > &  out,
const T1 &  X 
) [inline, static, inherited]

Definition at line 137 of file glue_times_meat.hpp.

References apply(), Mat< eT >::at(), Mat< eT >::colptr(), unwrap_check< T1 >::M, podarray< eT >::memptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.

Referenced by Mat< eT >::operator*=().

00138   {
00139   arma_extra_debug_sigprint();
00140   
00141   typedef typename T1::elem_type eT;
00142   
00143   const unwrap_check<T1> tmp(X, out);
00144   const Mat<eT>& B     = tmp.M;
00145   
00146   arma_debug_assert_mul_size(out, B, "matrix multiply");
00147   
00148   if(out.n_cols == B.n_cols)
00149     {
00150     podarray<eT> tmp(out.n_cols);
00151     eT* tmp_rowdata = tmp.memptr();
00152     
00153     for(u32 out_row=0; out_row < out.n_rows; ++out_row)
00154       {
00155       for(u32 out_col=0; out_col < out.n_cols; ++out_col)
00156         {
00157         tmp_rowdata[out_col] = out.at(out_row,out_col);
00158         }
00159       
00160       for(u32 B_col=0; B_col < B.n_cols; ++B_col)
00161         {
00162         const eT* B_coldata = B.colptr(B_col);
00163         
00164         eT val = eT(0);
00165         for(u32 i=0; i < B.n_rows; ++i)
00166           {
00167           val += tmp_rowdata[i] * B_coldata[i];
00168           }
00169         
00170         out.at(out_row,B_col) = val;
00171         }
00172       }
00173     
00174     }
00175   else
00176     {
00177     const Mat<eT> tmp(out);
00178     glue_times::apply(out, tmp, B, eT(1), false, false, false);
00179     }
00180   
00181   }

template<typename T1 , typename T2 >
arma_hot void glue_times::apply_inplace_plus ( Mat< typename T1::elem_type > &  out,
const Glue< T1, T2, glue_times > &  X,
const s32  sign 
) [inline, static, inherited]

Definition at line 189 of file glue_times_meat.hpp.

References Glue< T1, T2, glue_type >::A, apply(), arma_assert_same_size(), Glue< T1, T2, glue_type >::B, partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, Mat< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and partial_unwrap_check< T1 >::val.

Referenced by Mat< eT >::operator+=(), and Mat< eT >::operator-=().

00190   {
00191   arma_extra_debug_sigprint();
00192   
00193   typedef typename T1::elem_type eT;
00194   
00195   const partial_unwrap_check<T1> tmp1(X.A, out);
00196   const partial_unwrap_check<T2> tmp2(X.B, out);
00197   
00198   const Mat<eT>& A     = tmp1.M;
00199   const Mat<eT>& B     = tmp2.M;
00200   const eT       alpha = tmp1.val * tmp2.val * ( (sign > s32(0)) ? eT(1) : eT(-1) );
00201   
00202   const bool do_trans_A = tmp1.do_trans;
00203   const bool do_trans_B = tmp2.do_trans;
00204   const bool use_alpha  = tmp1.do_times | tmp2.do_times | (sign < s32(0));
00205   
00206   arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply");
00207   
00208   const u32 result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00209   const u32 result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00210   
00211   arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "matrix addition");
00212   
00213   if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00214     {
00215     if(A.n_rows == 1)
00216       {
00217       gemv<true,         false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00218       }
00219     else
00220     if(B.n_cols == 1)
00221       {
00222       gemv<false,        false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00223       }
00224     else
00225       {
00226       gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
00227       }
00228     }
00229   else
00230   if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00231     {
00232     if(A.n_rows == 1)
00233       {
00234       gemv<true,         true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00235       }
00236     else
00237     if(B.n_cols == 1)
00238       {
00239       gemv<false,        true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00240       }
00241     else
00242       {
00243       gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
00244       }
00245     }
00246   else
00247   if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00248     {
00249     if(A.n_cols == 1)
00250       {
00251       gemv<true,        false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00252       }
00253     else
00254     if(B.n_cols == 1)
00255       {
00256       gemv<true,        false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00257       }
00258     else
00259       {
00260       gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
00261       }
00262     }
00263   else
00264   if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00265     {
00266     if(A.n_cols == 1)
00267       {
00268       gemv<true,        true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00269       }
00270     else
00271     if(B.n_cols == 1)
00272       {
00273       gemv<true,        true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00274       }
00275     else
00276       {
00277       gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
00278       }
00279     }
00280   else
00281   if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00282     {
00283     if(A.n_rows == 1)
00284       {
00285       gemv<false,       false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00286       }
00287     else
00288     if(B.n_rows == 1)
00289       {
00290       gemv<false,       false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00291       }
00292     else
00293       {
00294       gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
00295       }
00296     }
00297   else
00298   if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00299     {
00300     if(A.n_rows == 1)
00301       {
00302       gemv<false,       true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00303       }
00304     else
00305     if(B.n_rows == 1)
00306       {
00307       gemv<false,       true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00308       }
00309     else
00310       {
00311       gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
00312       }
00313     }
00314   else
00315   if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00316     {
00317     if(A.n_cols == 1)
00318       {
00319       gemv<false,      false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00320       }
00321     else
00322     if(B.n_rows == 1)
00323       {
00324       gemv<true,       false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00325       }
00326     else
00327       {
00328       gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
00329       }
00330     }
00331   else
00332   if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00333     {
00334     if(A.n_cols == 1)
00335       {
00336       gemv<false,      true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00337       }
00338     else
00339     if(B.n_rows == 1)
00340       {
00341       gemv<true,       true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00342       }
00343     else
00344       {
00345       gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
00346       }
00347     }
00348   
00349   
00350   }

template<typename eT1 , typename eT2 >
void glue_times::apply_mixed ( Mat< typename promote_type< eT1, eT2 >::result > &  out,
const Mat< eT1 > &  X,
const Mat< eT2 > &  Y 
) [inline, static, inherited]

matrix multiplication with different element types

Definition at line 358 of file glue_times_meat.hpp.

References apply(), Mat< eT >::n_cols, and Mat< eT >::n_rows.

Referenced by operator*().

00359   {
00360   arma_extra_debug_sigprint();
00361   
00362   typedef typename promote_type<eT1,eT2>::result out_eT;
00363   
00364   arma_debug_assert_mul_size(X,Y, "matrix multiply");
00365   
00366   out.set_size(X.n_rows,Y.n_cols);
00367   gemm_mixed<>::apply(out, X, Y);
00368   }

template<typename eT >
arma_inline u32 glue_times::mul_storage_cost ( const Mat< eT > &  A,
const Mat< eT > &  B,
const bool  do_trans_A,
const bool  do_trans_B 
) [inline, static, inherited]

Definition at line 375 of file glue_times_meat.hpp.

References Mat< eT >::n_cols, and Mat< eT >::n_rows.

Referenced by apply().

00376   {
00377   const u32 final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00378   const u32 final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00379   
00380   return final_A_n_rows * final_B_n_cols;
00381   }

template<typename eT >
arma_hot void glue_times::apply ( Mat< eT > &  out,
const Mat< eT > &  A,
const Mat< eT > &  B,
const eT  val,
const bool  do_trans_A,
const bool  do_trans_B,
const bool  do_scalar_times 
) [inline, static, inherited]

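Definition at line 390 of file glue_times_meat.hpp.

Note: the signature above shows the parameter names from the declaration. The listing below refers to the val and do_scalar_times parameters as alpha and use_alpha, the names used in the definition.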

References gemm< do_trans_A, do_trans_B, use_alpha, use_beta >::apply(), gemv< do_trans_A, use_alpha, use_beta >::apply(), Mat< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and Mat< eT >::set_size().

00399   {
00400   arma_extra_debug_sigprint();
00401   
00402   arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply");
00403   
00404   const u32 final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00405   const u32 final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00406   
00407   out.set_size(final_n_rows, final_n_cols);
00408   
00409   // TODO: thoroughly test all combinations
00410   
00411   if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00412     {
00413     if(A.n_rows == 1)
00414       {
00415       gemv<true,         false, false>::apply(out.memptr(), B, A.memptr());
00416       }
00417     else
00418     if(B.n_cols == 1)
00419       {
00420       gemv<false,        false, false>::apply(out.memptr(), A, B.memptr());
00421       }
00422     else
00423       {
00424       gemm<false, false, false, false>::apply(out, A, B);
00425       }
00426     }
00427   else
00428   if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00429     {
00430     if(A.n_rows == 1)
00431       {
00432       gemv<true,         true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00433       }
00434     else
00435     if(B.n_cols == 1)
00436       {
00437       gemv<false,        true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00438       }
00439     else
00440       {
00441       gemm<false, false, true, false>::apply(out, A, B, alpha);
00442       }
00443     }
00444   else
00445   if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00446     {
00447     if(A.n_cols == 1)
00448       {
00449       gemv<true,        false, false>::apply(out.memptr(), B, A.memptr());
00450       }
00451     else
00452     if(B.n_cols == 1)
00453       {
00454       gemv<true,        false, false>::apply(out.memptr(), A, B.memptr());
00455       }
00456     else
00457       {
00458       gemm<true, false, false, false>::apply(out, A, B);
00459       }
00460     }
00461   else
00462   if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00463     {
00464     if(A.n_cols == 1)
00465       {
00466       gemv<true,        true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00467       }
00468     else
00469     if(B.n_cols == 1)
00470       {
00471       gemv<true,        true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00472       }
00473     else
00474       {
00475       gemm<true, false, true, false>::apply(out, A, B, alpha);
00476       }
00477     }
00478   else
00479   if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00480     {
00481     if(A.n_rows == 1)
00482       {
00483       gemv<false,       false, false>::apply(out.memptr(), B, A.memptr());
00484       }
00485     else
00486     if(B.n_rows == 1)
00487       {
00488       gemv<false,       false, false>::apply(out.memptr(), A, B.memptr());
00489       }
00490     else
00491       {
00492       gemm<false, true, false, false>::apply(out, A, B);
00493       }
00494     }
00495   else
00496   if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00497     {
00498     if(A.n_rows == 1)
00499       {
00500       gemv<false,       true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00501       }
00502     else
00503     if(B.n_rows == 1)
00504       {
00505       gemv<false,       true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00506       }
00507     else
00508       {
00509       gemm<false, true, true, false>::apply(out, A, B, alpha);
00510       }
00511     }
00512   else
00513   if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00514     {
00515     if(A.n_cols == 1)
00516       {
00517       gemv<false,      false, false>::apply(out.memptr(), B, A.memptr());
00518       }
00519     else
00520     if(B.n_rows == 1)
00521       {
00522       gemv<true,       false, false>::apply(out.memptr(), A, B.memptr());
00523       }
00524     else
00525       {
00526       gemm<true, true, false, false>::apply(out, A, B);
00527       }
00528     }
00529   else
00530   if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00531     {
00532     if(A.n_cols == 1)
00533       {
00534       gemv<false,      true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00535       }
00536     else
00537     if(B.n_rows == 1)
00538       {
00539       gemv<true,       true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00540       }
00541     else
00542       {
00543       gemm<true, true, true, false>::apply(out, A, B, alpha);
00544       }
00545     }
00546   }

template<typename eT >
void glue_times::apply ( Mat< eT > &  out,
const Mat< eT > &  A,
const Mat< eT > &  B,
const Mat< eT > &  C,
const eT  val,
const bool  do_trans_A,
const bool  do_trans_B,
const bool  do_trans_C,
const bool  do_scalar_times 
) [inline, static, inherited]

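Definition at line 554 of file glue_times_meat.hpp.

Note: the signature above shows the parameter names from the declaration. The listing below refers to the val and do_scalar_times parameters as alpha and use_alpha, the names used in the definition.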

References apply(), and mul_storage_cost().

00565   {
00566   arma_extra_debug_sigprint();
00567   
00568   Mat<eT> tmp;
00569   
00570   if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) )
00571     {
00572     // out = (A*B)*C
00573     glue_times::apply(tmp, A,   B, alpha, do_trans_A, do_trans_B, use_alpha);
00574     glue_times::apply(out, tmp, C, eT(0), false,      do_trans_C, false    );
00575     }
00576   else
00577     {
00578     // out = A*(B*C)
00579     glue_times::apply(tmp, B, C,   alpha, do_trans_B, do_trans_C, use_alpha);
00580     glue_times::apply(out, A, tmp, eT(0), do_trans_A, false,      false    );
00581     }
00582   }

template<typename eT >
void glue_times::apply ( Mat< eT > &  out,
const Mat< eT > &  A,
const Mat< eT > &  B,
const Mat< eT > &  C,
const Mat< eT > &  D,
const eT  val,
const bool  do_trans_A,
const bool  do_trans_B,
const bool  do_trans_C,
const bool  do_trans_D,
const bool  do_scalar_times 
) [inline, static, inherited]

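Definition at line 590 of file glue_times_meat.hpp.

Note: the signature above shows the parameter names from the declaration. The listing below refers to the val and do_scalar_times parameters as alpha and use_alpha, the names used in the definition.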

References apply(), and mul_storage_cost().

00603   {
00604   arma_extra_debug_sigprint();
00605   
00606   Mat<eT> tmp;
00607   
00608   if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) )
00609     {
00610     // out = (A*B*C)*D
00611     glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00612     
00613     glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false);
00614     }
00615   else
00616     {
00617     // out = A*(B*C*D)
00618     glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00619     
00620     glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false);
00621     }
00622   }

template<typename T1 , typename T2 >
arma_hot void glue_times_diag::apply ( Mat< typename T1::elem_type > &  out,
const Glue< T1, T2, glue_times_diag > &  X 
) [inline, static, inherited]

Definition at line 634 of file glue_times_meat.hpp.

References Glue< T1, T2, glue_type >::A, Mat< eT >::at(), Glue< T1, T2, glue_type >::B, Mat< eT >::colptr(), strip_diagmat< T1 >::do_diagmat, unwrap_check< T1 >::M, strip_diagmat< T1 >::M, Mat< eT >::n_cols, diagmat_proxy_check< T1 >::n_elem, Mat< eT >::n_rows, Mat< eT >::set_size(), and Mat< eT >::zeros().

00635   {
00636   arma_extra_debug_sigprint();
00637   
00638   typedef typename T1::elem_type eT;
00639   
00640   const strip_diagmat<T1> S1(X.A);
00641   const strip_diagmat<T2> S2(X.B);
00642   
00643   typedef typename strip_diagmat<T1>::stored_type T1_stripped;
00644   typedef typename strip_diagmat<T2>::stored_type T2_stripped;
00645   
00646   if( (S1.do_diagmat == true) && (S2.do_diagmat == false) )
00647     {
00648     const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00649     
00650     const unwrap_check<T2> tmp(X.B, out);
00651     const Mat<eT>& B     = tmp.M;
00652     
00653     arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiply");
00654     
00655     out.set_size(A.n_elem, B.n_cols);
00656     
00657     for(u32 col=0; col<B.n_cols; ++col)
00658       {
00659             eT* out_coldata = out.colptr(col);
00660       const eT* B_coldata   = B.colptr(col);
00661       
00662       for(u32 row=0; row<B.n_rows; ++row)
00663         {
00664         out_coldata[row] = A[row] * B_coldata[row];
00665         }
00666       }
00667     }
00668   else
00669   if( (S1.do_diagmat == false) && (S2.do_diagmat == true) )
00670     {
00671     const unwrap_check<T1> tmp(X.A, out);
00672     const Mat<eT>& A     = tmp.M;
00673     
00674     const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00675     
00676     arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiply");
00677     
00678     out.set_size(A.n_rows, B.n_elem);
00679     
00680     for(u32 col=0; col<A.n_cols; ++col)
00681       {
00682       const eT  val = B[col];
00683       
00684             eT* out_coldata = out.colptr(col);
00685       const eT*   A_coldata =   A.colptr(col);
00686     
00687       for(u32 row=0; row<A.n_rows; ++row)
00688         {
00689         out_coldata[row] = A_coldata[row] * val;
00690         }
00691       }
00692     }
00693   else
00694   if( (S1.do_diagmat == true) && (S2.do_diagmat == true) )
00695     {
00696     const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00697     const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00698     
00699     arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiply");
00700     
00701     out.zeros(A.n_elem, A.n_elem);
00702     
00703     for(u32 i=0; i<A.n_elem; ++i)
00704       {
00705       out.at(i,i) = A[i] * B[i];
00706       }
00707     }
00708   }