glue_times_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2009 NICTA
00002 // 
00003 // Authors:
00004 // - Conrad Sanderson (conradsand at ieee dot org)
00005 // 
00006 // This file is part of the Armadillo C++ library.
00007 // It is provided without any warranty of fitness
00008 // for any purpose. You can redistribute this file
00009 // and/or modify it under the terms of the GNU
00010 // Lesser General Public License (LGPL) as published
00011 // by the Free Software Foundation, either version 3
00012 // of the License or (at your option) any later version.
00013 // (see http://www.opensource.org/licenses for more info)
00014 
00015 
00016 //! \addtogroup glue_times
00017 //! @{
00018 
00019 
00020 
00021 template<typename T1, typename T2>
00022 void
00023 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00024   {
00025   arma_extra_debug_sigprint();
00026 
00027   typedef typename T1::elem_type eT;
00028 
00029   const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00030 
00031   arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00032 
00033   if(N_mat == 2)
00034     {
00035     const unwrap<T1> tmp1(X.A);
00036     const unwrap<T2> tmp2(X.B);
00037     
00038     glue_times::apply(out, tmp1.M, tmp2.M);
00039     }
00040   else
00041     {
00042     // we have at least three matrices
00043 
00044     const Mat<eT>* ptrs[N_mat];
00045     bool            del[N_mat];
00046   
00047     // takes care of any aliasing problems
00048     mat_ptrs_outcheck<glue_times, Glue<T1,T2,glue_times> >::get_ptrs(ptrs, del, X, &out);
00049   
00050     for(s32 i=0; i<N_mat; ++i)  arma_extra_debug_print( arma_boost::format("ptrs[%d] = %x") % i % ptrs[i] );
00051     for(s32 i=0; i<N_mat; ++i)  arma_extra_debug_print( arma_boost::format(" del[%d] = %d") % i %  del[i] );
00052   
00053   
00054     arma_extra_debug_print( arma_boost::format("required size of 'out':  %d, %d") % ptrs[0]->n_rows % ptrs[N_mat-1]->n_cols );
00055   
00056     int order[N_mat];  for(s32 i=0; i<N_mat; ++i)  order[i] = -1;
00057   
00058     int first_id = 0;
00059     int last_id  = N_mat-1;
00060     int starting_id = -1;
00061   
00062     int mat_count = N_mat;
00063   
00064     int largest_size = 0;
00065   
00066     while(mat_count != 0)
00067       {
00068   
00069       for(s32 i=first_id; i != N_mat; ++i)
00070         {
00071         if(order[i] == -1)  { first_id = i; break; }
00072         }
00073   
00074       for(s32 i=last_id; i != -1; --i)
00075         {
00076         if(order[i] == -1)  { last_id = i; break; }
00077         }
00078   
00079       arma_extra_debug_print();
00080       arma_extra_debug_print(arma_boost::format("mat_count = %d") % mat_count );
00081       arma_extra_debug_print(arma_boost::format("first_id  = %d") % first_id  );
00082       arma_extra_debug_print(arma_boost::format("last_id   = %d") % last_id   );
00083   
00084       if(first_id == last_id)  { order[first_id] = 0; starting_id = first_id; break; }
00085   
00086       s32 storage_cost_wo_last  = mul_storage_cost( *ptrs[ first_id   ], *ptrs[ last_id-1 ] );
00087       s32 storage_cost_wo_first = mul_storage_cost( *ptrs[ first_id+1 ], *ptrs[ last_id   ] );
00088   
00089       if(storage_cost_wo_last < storage_cost_wo_first)
00090         {
00091         order[last_id]  = mat_count-1;
00092         if(storage_cost_wo_last > largest_size)  largest_size = storage_cost_wo_last;
00093         }
00094       else
00095         {
00096         order[first_id] = mat_count-1;
00097         if(storage_cost_wo_first > largest_size)  largest_size = storage_cost_wo_first;
00098         }
00099   
00100       arma_extra_debug_print(arma_boost::format("storage_cost_wo_last  = %d") % storage_cost_wo_last  );
00101       arma_extra_debug_print(arma_boost::format("storage_cost_wo_first = %d") % storage_cost_wo_first );
00102   
00103       arma_extra_debug_print("order = ");
00104       for(s32 i=0; i != N_mat; ++i)  arma_extra_debug_print(order[i]);
00105   
00106       --mat_count;
00107       }
00108   
00109     arma_extra_debug_print("final order = ");
00110     for(s32 i=0; i != N_mat; ++i)  arma_extra_debug_print(order[i]);
00111   
00112     arma_extra_debug_print(arma_boost::format("*** largest_size = %d") % largest_size);
00113     arma_extra_debug_print(arma_boost::format("starting_id  = %d") % starting_id);
00114   
00115   
00116     // multiply based on order
00117     // if there are only three matrices, we need only one temporary store:
00118     //   out = a*b*c translates to:  tmp1 = a*b,  out = tmp1*c
00119     //
00120     // if there are four matrices, we need two temporary stores
00121     //   out = a*b*c*d translates to:  tmp1 = a*b, tmp2 = tmp1*c, out = tmp2*d
00122     //
00123     // if there are five matrices, we need two temporary stores
00124     //   out = a*b*c*d*e translates to:  tmp1 = a*b, tmp2 = tmp1*c, tmp1 = tmp2*d, out = tmp1*e
00125     //
00126     // if there are six matrices, we need two temporary stores
00127     //   out = a*b*c*d*e*f translates to:  tmp1 = a*b, tmp2 = tmp1*c, tmp1 = tmp2*d, tmp2 = tmp1*e, out = tmp2*f
00128     //
00129   
00130     
00131     const u32 N_mul = N_mat - 1;
00132     int mul_count = N_mul;
00133     int current_id = starting_id;
00134   
00135     const Mat<eT>* src_mat_1_ptr = ptrs[current_id];
00136     const Mat<eT>* src_mat_2_ptr = 0;
00137   
00138     // TODO:
00139     // allocate two storage areas (of size 'largest_size'), not two matrices
00140   
00141     
00142     Mat<eT> tmp_mat_1;
00143     Mat<eT> tmp_mat_2;
00144     
00145     Mat<eT>* tmp_mat_1_ptr = &tmp_mat_1;
00146     Mat<eT>* tmp_mat_2_ptr = (N_mul <= 2) ? 0 : &tmp_mat_2;
00147   
00148     Mat<eT>* dest_mat_ptr  = tmp_mat_2_ptr;
00149   
00150     arma_extra_debug_print(arma_boost::format("tmp_mat_1_ptr = %x") % tmp_mat_1_ptr );
00151     arma_extra_debug_print(arma_boost::format("tmp_mat_2_ptr = %x") % tmp_mat_2_ptr );
00152     arma_extra_debug_print(arma_boost::format("&out          = %x") % &out );
00153   
00154     while(mul_count != 0)
00155       {
00156       arma_extra_debug_print("");
00157       arma_extra_debug_print("");
00158       arma_extra_debug_print(arma_boost::format("mul_count = %d") % mul_count);
00159   
00160       arma_extra_debug_print("order = ");
00161       for(s32 i=0; i != N_mat; ++i)  arma_extra_debug_print(order[i]);
00162       arma_extra_debug_print("");
00163   
00164       // only one multiplication left, hence destination matrix is the out matrix
00165       if(mul_count == 1)
00166         {
00167         arma_extra_debug_print("dest_mat = &out");
00168         dest_mat_ptr = &out;
00169         }
00170       else
00171         {
00172         if(dest_mat_ptr == tmp_mat_2_ptr)
00173           {
00174           arma_extra_debug_print("dest_mat_ptr = tmp_mat_2_ptr");
00175           dest_mat_ptr = tmp_mat_1_ptr;
00176           }
00177         else
00178           {
00179           arma_extra_debug_print("dest_mat_ptr = tmp_mat_1_ptr");
00180           dest_mat_ptr = tmp_mat_2_ptr;
00181           }
00182         }
00183   
00184       arma_extra_debug_print(arma_boost::format("dest_mat_ptr = %x") % dest_mat_ptr );
00185   
00186       // search on either side of current_pos for a useable value.  unuseable values are equal to -1
00187       s32 left_val = N_mat;
00188       s32 left_id = -1;
00189   
00190       s32 right_val = N_mat;
00191       s32 right_id = -1;
00192   
00193       // go left from current_pos
00194       for(s32 i=current_id-1; i >= 0; --i)
00195         if( order[i] > order[current_id] ) { left_val = order[i]; left_id = i; break; }
00196   
00197       // go right from current_pos
00198       for(s32 i=current_id+1; i < N_mat; ++i)
00199         if( order[current_id] < order[i] ) { right_val = order[i]; right_id = i; break; }
00200   
00201       arma_extra_debug_print("");
00202       arma_extra_debug_print(arma_boost::format("left_id  = %d") % left_id  );
00203       arma_extra_debug_print(arma_boost::format("left_val = %f") % left_val );
00204   
00205       arma_extra_debug_print("");
00206       arma_extra_debug_print(arma_boost::format("right_id  = %d") % right_id  );
00207       arma_extra_debug_print(arma_boost::format("right_val = %f") % right_val );
00208   
00209   
00210       if(left_val < right_val)
00211         {
00212         // a pre-multiply
00213         src_mat_2_ptr = ptrs[left_id];
00214   
00215         arma_extra_debug_print("");
00216         arma_extra_debug_print(arma_boost::format("case pre-multiply with matrix %d") % left_id);
00217         arma_extra_debug_print(arma_boost::format("required destination size: %d, %d  (%d)") %  src_mat_2_ptr->n_rows % src_mat_1_ptr->n_cols % (src_mat_2_ptr->n_rows * src_mat_1_ptr->n_cols) );
00218   
00219         glue_times::apply_noalias(*dest_mat_ptr, *src_mat_2_ptr, *src_mat_1_ptr);
00220   
00221         order[current_id] = -1;
00222         current_id = left_id;
00223         }
00224       else
00225         {
00226         // a post-multiply
00227         src_mat_2_ptr = ptrs[right_id];
00228   
00229         arma_extra_debug_print("");
00230         arma_extra_debug_print(arma_boost::format("case post-multiply with matrix %d") % right_id);
00231         arma_extra_debug_print(arma_boost::format("required destination size: %d, %d  (%d)") % src_mat_1_ptr->n_rows % src_mat_2_ptr->n_cols % (src_mat_1_ptr->n_rows * src_mat_2_ptr->n_cols) );
00232   
00233         glue_times::apply_noalias(*dest_mat_ptr, *src_mat_1_ptr, *src_mat_2_ptr);
00234   
00235         order[current_id] = -1;
00236         current_id = right_id;
00237         }
00238   
00239       // update pointer to source matrix: must point to last multiplication result
00240       src_mat_1_ptr = dest_mat_ptr;
00241   
00242       --mul_count;
00243       }
00244   
00245   
00246     for(s32 i=0; i<N_mat; ++i)
00247       {
00248       if(del[i] == true)
00249         {
00250         arma_extra_debug_print(arma_boost::format("delete mat_ptr[%d]") % i );
00251         delete ptrs[i];
00252         }
00253       }
00254     }
00255   }
00256 
00257 
00258 
00259 template<typename T1>
00260 inline
00261 void
00262 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
00263   {
00264   arma_extra_debug_sigprint();
00265   
00266   typedef typename T1::elem_type eT;
00267   
00268   const unwrap<T1>   tmp(X);
00269   const Mat<eT>& B = tmp.M;
00270   
00271   arma_debug_assert_mul_size(out, B, "matrix multiply");
00272   
00273   if(out.n_cols == B.n_cols)
00274     {
00275     podarray<eT> tmp(out.n_cols);
00276     eT* tmp_rowdata = tmp.memptr();
00277     
00278     for(u32 out_row=0; out_row < out.n_rows; ++out_row)
00279       {
00280       for(u32 out_col=0; out_col < out.n_cols; ++out_col)
00281         {
00282         tmp_rowdata[out_col] = out.at(out_row,out_col);
00283         }
00284       
00285       for(u32 B_col=0; B_col < B.n_cols; ++B_col)
00286         {
00287         const eT* B_coldata = B.colptr(B_col);
00288         
00289         eT val = eT(0);
00290         for(u32 i=0; i < B.n_rows; ++i)
00291           {
00292           val += tmp_rowdata[i] * B_coldata[i];
00293           }
00294         
00295         out.at(out_row,B_col) = val;
00296         }
00297       }
00298     
00299     }
00300   else
00301     {
00302     const Mat<eT> tmp(out);
00303     glue_times::apply(out, tmp, B);
00304     }
00305   
00306   }
00307 
00308 
00309 
00310 //! matrix multiplication with different element types
00311 template<typename eT1, typename eT2>
00312 inline
00313 void
00314 glue_times::apply_mixed(Mat<typename promote_type<eT1,eT2>::result>& out, const Mat<eT1>& X, const Mat<eT2>& Y)
00315   {
00316   arma_extra_debug_sigprint();
00317   
00318   typedef typename promote_type<eT1,eT2>::result out_eT;
00319   
00320   arma_debug_assert_mul_size(X,Y, "matrix multiply");
00321   
00322   out.set_size(X.n_rows,Y.n_cols);
00323   gemm_mixed<>::apply(out, X, Y);
00324   }
00325 
00326 
00327 
00328 template<typename eT>
00329 arma_inline
00330 u32 glue_times::mul_storage_cost(const Mat<eT>& X, const Mat<eT>& Y)
00331   {
00332   return X.n_rows * Y.n_cols;
00333   }
00334 
00335 
00336 
00337 //! multiply matrices A and B, storing the result in 'out'
00338 //! assumes that A and B are not aliases of 'out'
00339 template<typename eT>
00340 inline
00341 void
00342 glue_times::apply_noalias(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B)
00343   {
00344   arma_extra_debug_sigprint();
00345   
00346   arma_debug_assert_mul_size(A, B, "matrix multiply");
00347   
00348   out.set_size(A.n_rows,B.n_cols);
00349   gemm<>::apply(out,A,B);
00350   }
00351 
00352 
00353 
00354 template<typename eT>
00355 inline
00356 void
00357 glue_times::apply(Mat<eT>& out, const Mat<eT>& A_in, const Mat<eT>& B_in)
00358   {
00359   arma_extra_debug_sigprint();
00360   
00361   if( (&out != &A_in) && (&out != &B_in) )
00362     {
00363     glue_times::apply_noalias(out,A_in,B_in);
00364     }
00365   else
00366     {
00367     
00368     if( (&out == &A_in) && (&out != &B_in) )
00369       {
00370       Mat<eT> A_copy(A_in);
00371       glue_times::apply_noalias(out,A_copy,B_in);
00372       }
00373     else
00374     if( (&out != &A_in) && (&out == &B_in) )
00375       {
00376       Mat<eT> B_copy(B_in);
00377       glue_times::apply_noalias(out,A_in,B_copy);
00378       }
00379     else
00380     if( (&out == &A_in) && (&out == &B_in) )
00381       {
00382       Mat<eT> tmp(A_in);
00383       glue_times::apply_noalias(out,tmp,tmp);
00384       }
00385 
00386     }
00387     
00388   }
00389 
00390 
00391 template<typename eT>
00392 inline
00393 void
00394 glue_times::apply(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
00395   {
00396   arma_extra_debug_sigprint();
00397 
00398   arma_debug_assert_mul_size(A, B, "matrix multiply");
00399   arma_debug_assert_mul_size(B, C, "matrix multiply");
00400   
00401   if( mul_storage_cost(A,B) <= mul_storage_cost(B,C) )
00402     {
00403     Mat<eT> tmp;
00404     glue_times::apply_noalias(tmp, A, B);
00405     
00406     if(&out != &C)
00407       {
00408       glue_times::apply_noalias(out, tmp, C);
00409       }
00410     else
00411       {
00412       Mat<eT> C_copy = C;
00413       glue_times::apply_noalias(out, tmp, C_copy);
00414       }
00415       
00416     }
00417   else
00418     {
00419     Mat<eT> tmp;
00420     glue_times::apply_noalias(tmp, B, C);
00421     
00422     if(&out != &A)
00423       {
00424       glue_times::apply_noalias(out, A, tmp);
00425       }
00426     else
00427       {
00428       Mat<eT> A_copy = A;
00429       glue_times::apply_noalias(out, A_copy, tmp);
00430       }
00431     }
00432   
00433   }
00434 
00435 
00436 
00437 template<typename eT>
00438 inline
00439 eT
00440 glue_times::direct_rowvec_mat_colvec
00441   (
00442   const eT*      A_mem,
00443   const Mat<eT>& B,
00444   const eT*      C_mem
00445   )
00446   {
00447   arma_extra_debug_sigprint();
00448   
00449   const u32 cost_AB = B.n_cols;
00450   const u32 cost_BC = B.n_rows;
00451   
00452   if(cost_AB <= cost_BC)
00453     {
00454     podarray<eT> tmp(B.n_cols);
00455     
00456     for(u32 col=0; col<B.n_cols; ++col)
00457       {
00458       const eT* B_coldata = B.colptr(col);
00459       
00460       eT val = eT(0);
00461       for(u32 i=0; i<B.n_rows; ++i)
00462         {
00463         val += A_mem[i] * B_coldata[i];
00464         }
00465         
00466       tmp[col] = val;
00467       }
00468     
00469     return op_dot::direct_dot(B.n_cols, tmp.mem, C_mem);
00470     }
00471   else
00472     {
00473     podarray<eT> tmp(B.n_rows);
00474     
00475     for(u32 row=0; row<B.n_rows; ++row)
00476       {
00477       eT val = eT(0);
00478       for(u32 col=0; col<B.n_cols; ++col)
00479         {
00480         val += B.at(row,col) * C_mem[col];
00481         }
00482       
00483       tmp[row] = val;
00484       }
00485     
00486     return op_dot::direct_dot(B.n_rows, A_mem, tmp.mem);
00487     }
00488   
00489   
00490   }
00491 
00492 
00493 
00494 template<typename eT>
00495 inline
00496 eT
00497 glue_times::direct_rowvec_diagmat_colvec
00498   (
00499   const eT*      A_mem,
00500   const Mat<eT>& B,
00501   const eT*      C_mem
00502   )
00503   {
00504   arma_extra_debug_sigprint();
00505   
00506   eT val = eT(0);
00507 
00508   for(u32 i=0; i<B.n_rows; ++i)
00509     {
00510     val += A_mem[i] * B.at(i,i) * C_mem[i];
00511     }
00512 
00513   return val;
00514   }
00515 
00516 
00517 
00518 template<typename eT>
00519 inline
00520 eT
00521 glue_times::direct_rowvec_invdiagmat_colvec
00522   (
00523   const eT*      A_mem,
00524   const Mat<eT>& B,
00525   const eT*      C_mem
00526   )
00527   {
00528   arma_extra_debug_sigprint();
00529   
00530   eT val = eT(0);
00531 
00532   for(u32 i=0; i<B.n_rows; ++i)
00533     {
00534     val += (A_mem[i] * C_mem[i]) / B.at(i,i);
00535     }
00536 
00537   return val;
00538   }
00539 
00540 
00541 
00542 template<typename eT>
00543 inline
00544 eT
00545 glue_times::direct_rowvec_invdiagvec_colvec
00546   (
00547   const eT*      A_mem,
00548   const Mat<eT>& B,
00549   const eT*      C_mem
00550   )
00551   {
00552   arma_extra_debug_sigprint();
00553   
00554   const eT* B_mem = B.mem;
00555   
00556   eT val = eT(0);
00557 
00558   for(u32 i=0; i<B.n_elem; ++i)
00559     {
00560     val += (A_mem[i] * C_mem[i]) / B_mem[i];
00561     }
00562 
00563   return val;
00564   }
00565 
00566 
00567 
00568 #if defined(ARMA_GOOD_COMPILER)
00569 
00570 
00571 template<typename eT>
00572 inline
00573 void
00574 glue_times::apply(Mat<eT>& out, const Glue<Mat<eT>,Mat<eT>,glue_times>& X)
00575   {
00576   glue_times::apply(out, X.A, X.B);
00577   }
00578 
00579 
00580 
00581 template<typename eT>
00582 inline
00583 void
00584 glue_times::apply(Mat<eT>& out, const Glue< Glue<Mat<eT>,Mat<eT>, glue_times>, Mat<eT>, glue_times>& X)
00585   {
00586   glue_times::apply(out, X.A.A, X.A.B, X.B);
00587   }
00588 
00589 
00590 
00591 //! out = T1 * trans(T2)
00592 template<typename T1, typename T2>
00593 inline
00594 void
00595 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Op<T2,op_trans>, glue_times>& X)
00596   {
00597   arma_extra_debug_sigprint();
00598   
00599   
00600   typedef typename T1::elem_type eT;
00601   
00602   // checks for aliases are done later
00603   
00604   const unwrap<T1> tmp1(X.A);
00605   const unwrap<T2> tmp2(X.B.m);
00606   
00607   const Mat<eT>& A = tmp1.M;
00608   const Mat<eT>& B = tmp2.M;
00609   
00610   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_cols, B.n_rows, "matrix multiply");
00611     
00612   if( (A.n_rows*B.n_rows) > 0)
00613     {
00614     if(&A != &B)   // A*B'
00615       {
00616       unwrap_check< Mat<eT> > A_safe_tmp(A, out);
00617       unwrap_check< Mat<eT> > B_safe_tmp(B, out);
00618       
00619       const Mat<eT>& A_safe = A_safe_tmp.M;
00620       const Mat<eT>& B_safe = B_safe_tmp.M;
00621       
00622       out.set_size(A_safe.n_rows, B_safe.n_rows);
00623       
00624       gemm<false,true>::apply(out, A, B);
00625       }
00626     else   // A*A'
00627       {
00628       arma_extra_debug_print("glue_times::apply(): detected A*A'");
00629       
00630       Mat<eT> tmp;
00631       op_trans::apply(tmp,A);
00632       
00633       // no aliasing problem
00634       out.set_size(A.n_rows, A.n_rows);
00635       
00636       for(u32 row=0; row != A.n_rows; ++row)
00637         {
00638         for(u32 col=0; col <= row; ++col)
00639           {
00640           const eT* coldata1 = tmp.colptr(row);
00641           const eT* coldata2 = tmp.colptr(col);
00642         
00643           eT val = eT(0);
00644           for(u32 i=0; i < tmp.n_rows; ++i)
00645             {
00646             val += coldata1[i] * coldata2[i];
00647             }
00648           
00649           out.at(row,col) = val;
00650           out.at(col,row) = val;
00651           }
00652         }
00653         
00654       }
00655     
00656     }
00657   
00658   }
00659 
00660 
00661 
00662 //! out = trans(T1) * T2
00663 template<typename T1, typename T2>
00664 inline
00665 void
00666 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1,op_trans>, T2, glue_times>& X)
00667   {
00668   arma_extra_debug_sigprint();
00669   
00670   typedef typename T1::elem_type eT;
00671   
00672   const unwrap_check<T1> tmp1(X.A.m, out);
00673   const unwrap_check<T2> tmp2(X.B,   out);
00674   
00675   const Mat<eT>& A = tmp1.M;
00676   const Mat<eT>& B = tmp2.M;
00677   
00678   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiply");
00679   
00680   if( (A.n_cols*B.n_cols) > 0 )
00681     {
00682     out.set_size(A.n_cols, B.n_cols);
00683     
00684     gemm<true,false>::apply(out, A, B);
00685     }
00686     
00687   }
00688 
00689 
00690 
00691 //! out = trans(T1) * trans(T2)
00692 template<typename T1, typename T2>
00693 inline
00694 void
00695 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1,op_trans>, Op<T2,op_trans>, glue_times>& X)
00696   {
00697   arma_extra_debug_sigprint();
00698   
00699   typedef typename T1::elem_type eT;
00700   
00701   const unwrap_check<T1> tmp1(X.A.m, out);
00702   const unwrap_check<T2> tmp2(X.B.m, out);
00703   
00704   const Mat<eT>& A = tmp1.M;
00705   const Mat<eT>& B = tmp2.M;
00706   
00707   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_cols, B.n_rows, "matrix multiply");
00708   
00709   if( (A.n_cols*B.n_rows) > 0 )
00710     {
00711     out.set_size(A.n_cols, B.n_rows);
00712     
00713     gemm<true,true>::apply(out, A, B);
00714     
00715     }
00716     
00717   }
00718 
00719 
00720 
00721 
00722 //! out = -T1 * T2
00723 template<typename T1, typename T2>
00724 inline
00725 void
00726 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1, op_neg>, T2, glue_times>& X)
00727   {
00728   arma_extra_debug_sigprint();
00729   
00730   typedef typename T1::elem_type eT;
00731   
00732   const unwrap_check<T1> tmp1(X.A.m, out);
00733   const unwrap_check<T2> tmp2(X.B,   out);
00734   
00735   const Mat<eT>& A = tmp1.M;
00736   const Mat<eT>& B = tmp2.M;
00737 
00738   glue_times::apply(out, A, B);
00739   
00740   const u32 n_elem = out.n_elem;
00741   for(u32 i=0; i<n_elem; ++i)
00742     {
00743     out[i] = -out[i];
00744     }
00745   }
00746 
00747 
00748 
00749 template<typename T1, typename T2>
00750 inline
00751 void
00752 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X)
00753   {
00754   arma_extra_debug_sigprint();
00755   
00756   out = out * X;
00757   }
00758 
00759 
00760 
00761 #endif
00762 
00763 
00764 
00765 //
00766 // glue_times_diag
00767 
00768 
00769 template<typename T1, typename T2>
00770 inline
00771 void
00772 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const T1& A_orig, const Op<T2,op_diagmat>& B_orig)
00773   {
00774   arma_extra_debug_sigprint();
00775     
00776   isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00777   
00778   const unwrap_check<T1> tmp1(A_orig,   out);
00779   const unwrap_check<T2> tmp2(B_orig.m, out);
00780   
00781   typedef typename T1::elem_type eT;
00782   
00783   const Mat<eT>& A = tmp1.M;
00784   const Mat<eT>& B = tmp2.M;
00785   
00786   arma_debug_check( (B.is_square() == false), "glue_times_diag::apply(): incompatible matrix dimensions" );
00787   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00788   
00789   out.set_size(A.n_rows, B.n_cols);
00790   
00791   for(u32 col=0; col<A.n_cols; ++col)
00792     {
00793     const eT val = B.at(col,col);
00794     
00795     const eT*   A_coldata =   A.colptr(col);
00796           eT* out_coldata = out.colptr(col);
00797     
00798     for(u32 row=0; row<B.n_rows; ++row)
00799       {
00800       out_coldata[row] = A_coldata[row] * val;
00801       }
00802     
00803     }
00804   
00805   }
00806 
00807 
00808 
00809 template<typename T1, typename T2>
00810 inline
00811 void
00812 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_diagmat>& A_orig, const T2& B_orig)
00813   {
00814   arma_extra_debug_sigprint();
00815   
00816   isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00817   
00818   const unwrap_check<T1> tmp1(A_orig.m, out);
00819   const unwrap_check<T2> tmp2(B_orig,   out);
00820   
00821   typedef typename T1::elem_type eT;
00822   
00823   const Mat<eT>& A = tmp1.M;
00824   const Mat<eT>& B = tmp2.M;
00825   
00826   arma_debug_check( (A.is_square() == false), "glue_times_diag::apply(): incompatible matrix dimensions" );
00827   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00828   
00829   out.set_size(A.n_rows, B.n_cols);
00830   
00831   
00832   for(u32 col=0; col<A.n_cols; ++col)
00833     {
00834     const eT*   B_coldata =   B.colptr(col);
00835           eT* out_coldata = out.colptr(col);
00836     
00837     for(u32 row=0; row<B.n_rows; ++row)
00838       {
00839       out_coldata[row] = A.at(row,row) * B_coldata[row];
00840       }
00841     
00842     }
00843 
00844   }
00845 
00846 
00847 
00848 template<typename T1, typename T2>
00849 inline
00850 void
00851 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_diagmat>& A_orig, const Op<T2,op_diagmat>& B_orig)
00852   {
00853   arma_extra_debug_sigprint();
00854   
00855   isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00856   
00857   unwrap_check<T1> tmp1(A_orig.m, out);
00858   unwrap_check<T2> tmp2(B_orig.m, out);
00859   
00860   typedef typename T1::elem_type eT;
00861   
00862   const Mat<eT>& A = tmp1.M;
00863   const Mat<eT>& B = tmp2.M;
00864   
00865   arma_debug_check( !A.is_square() || !B.is_square(), "glue_times_diag::apply(): incompatible matrix dimensions" );
00866   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00867   
00868   out.zeros(A.n_rows, B.n_cols);
00869   
00870   for(u32 i=0; i<A.n_rows; ++i)
00871     {
00872     out.at(i,i) = A.at(i,i) * B.at(i,i);
00873     }
00874   }
00875 
00876 
00877 
00878 template<typename T1, typename T2>
00879 inline
00880 void 
00881 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Op<T2,op_diagmat>, glue_times_diag>& X)
00882   {
00883   glue_times_diag::apply(out, X.A, X.B);
00884   }
00885 
00886 
00887 
00888 template<typename T1, typename T2>
00889 inline
00890 void
00891 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1,op_diagmat>, T2, glue_times_diag>& X)
00892   {
00893   glue_times_diag::apply(out, X.A, X.B);
00894   }
00895 
00896 
00897 
00898 template<typename T1, typename T2>
00899 inline
00900 void
00901 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1,op_diagmat>, Op<T2,op_diagmat>, glue_times_diag>& X)
00902   {
00903   glue_times_diag::apply(out, X.A, X.B);
00904   }
00905 
00906 
00907 
00908 //
00909 // glue_times_vec
00910 
00911 
00912 //! at least one of T1 and T2 is a vector (both could be vectors)
00913 template<typename T1, typename T2>
00914 inline
00915 void
00916 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_vec>& X)
00917   {
00918   arma_extra_debug_sigprint();
00919   
00920   typedef typename T1::elem_type eT;
00921   
00922   unwrap_check<T1> tmp1(X.A, out);
00923   unwrap_check<T2> tmp2(X.B, out);
00924   
00925   const Mat<eT>& A = tmp1.M;
00926   const Mat<eT>& B = tmp2.M;
00927   
00928   arma_debug_assert_mul_size(A, B, "vector multiply");
00929   
00930   // col * row  --> outer product
00931   // mat * row  --> only makes sense if mat is a col vector, hence equiv to col * row
00932   // col * mat  --> only makes sense if mat is a row vector, hence equiv to col * row
00933   
00934   // row * col  --> dot product
00935   // row * mat  --> ok
00936   
00937   // mat * col  --> ok
00938   
00939   out.set_size(A.n_rows, B.n_cols);
00940   
00941   if(A.n_cols == 1)    // A is a column vector
00942     {
00943     glue_times_vec::mul_col_row(out, A.mem, B.mem);
00944     }
00945   else
00946     {
00947     if(A.n_rows == 1)  // A is a row vector
00948       {
00949       if(B.n_cols == 1)
00950         {
00951         out[0] = op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00952         }
00953       else
00954         {
00955         gemv<true>::apply(out.memptr(), B, A.mem);
00956         }
00957       }
00958     else               // A is a matrix
00959       {
00960       gemv<>::apply(out.memptr(), A, B.mem);
00961       }
00962     }
00963   
00964   }
00965 
00966 
00967 
00968 template<typename eT>
00969 inline
00970 void
00971 glue_times_vec::mul_col_row(Mat<eT>& out, const eT* A, const eT* B)
00972   {
00973   const u32 n_rows = out.n_rows;
00974   const u32 n_cols = out.n_cols;
00975   
00976   for(u32 col=0; col < n_cols; ++col)
00977     {
00978     const eT val = B[col];
00979     
00980     eT* out_coldata = out.colptr(col);
00981     
00982     for(u32 row=0; row < n_rows; ++row)
00983       {
00984       out_coldata[row] = A[row] * val;
00985       }
00986     }
00987   
00988   }
00989 
00990 
00991 
00992 template<typename eT>
00993 inline
00994 void
00995 glue_times_vec::mul_col_row_inplace_add(Mat<eT>& out, const eT* A, const eT* B)
00996   {
00997   const u32 n_rows = out.n_rows;
00998   const u32 n_cols = out.n_cols;
00999   
01000   for(u32 col=0; col < n_cols; ++col)
01001     {
01002     const eT val = B[col];
01003     
01004     eT* out_coldata = out.colptr(col);
01005     
01006     for(u32 row=0; row < n_rows; ++row)
01007       {
01008       out_coldata[row] += A[row] * val;
01009       }
01010     }
01011   
01012   }
01013 
01014 
01015 
01016 #if defined(ARMA_GOOD_COMPILER)
01017 
01018 
01019 
01020 template<typename eT>
01021 inline
01022 void
01023 glue_times_vec::apply(Mat<eT>& out, const Glue<Col<eT>,Row<eT>,glue_times_vec>& X)
01024   {
01025   arma_extra_debug_sigprint();
01026   
01027   unwrap_check< Col<eT> > tmp1(X.A, out);
01028   unwrap_check< Row<eT> > tmp2(X.B, out);
01029   
01030   const Col<eT>& A = tmp1.M;
01031   const Row<eT>& B = tmp2.M;
01032   
01033   arma_debug_assert_mul_size(A, B, "vector multiply");
01034   
01035   out.set_size(A.n_rows, B.n_cols);
01036   
01037   glue_times_vec::mul_col_row(out, A.mem, B.mem);
01038   }
01039 
01040 
01041 
01042 template<typename eT>
01043 inline
01044 void
01045 glue_times_vec::apply(Mat<eT>& out, const Glue< Op<Row<eT>, op_trans>, Row<eT>, glue_times_vec>& X)
01046   {
01047   arma_extra_debug_sigprint();
01048   
01049   unwrap_check< Row<eT> > tmp1(X.A.m, out);
01050   unwrap_check< Row<eT> > tmp2(X.B,   out);
01051   
01052   const Row<eT>& A = tmp1.M;
01053   const Row<eT>& B = tmp2.M;
01054   
01055   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "vector multiply");
01056   
01057   out.set_size(A.n_cols, B.n_cols);
01058   
01059   glue_times_vec::mul_col_row(out, A.mem, B.mem);
01060   }
01061 
01062 
01063 
01064 template<typename eT>
01065 inline
01066 void
01067 glue_times_vec::apply(Mat<eT>& out, const Glue< Col<eT>, Op<Col<eT>, op_trans>, glue_times_vec>& X)
01068   {
01069   arma_extra_debug_sigprint();
01070   
01071   unwrap_check< Col<eT> > tmp1(X.A,   out);
01072   unwrap_check< Col<eT> > tmp2(X.B.m, out);
01073   
01074   const Col<eT>& A = tmp1.M;
01075   const Col<eT>& B = tmp2.M;
01076   
01077   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_cols, B.n_rows, "vector multiply");
01078   
01079   out.set_size(A.n_rows, B.n_rows);
01080   
01081   glue_times_vec::mul_col_row(out, A.mem, B.mem);
01082   }
01083 
01084 
01085 
01086 template<typename T1>
01087 inline
01088 void
01089 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1, op_trans>, Col<typename T1::elem_type>,glue_times_vec>& X)
01090   {
01091   arma_extra_debug_sigprint();
01092   
01093   typedef typename T1::elem_type eT;
01094   
01095   unwrap_check< T1 >      tmp1(X.A.m, out);
01096   unwrap_check< Col<eT> > tmp2(X.B,   out);
01097   
01098   const Mat<eT>& A = tmp1.M;
01099   const Col<eT>& B = tmp2.M;
01100 
01101   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "vector multiply");
01102   
01103   out.set_size(A.n_cols, B.n_cols);
01104   
01105 //         eT* out_mem = out.memptr();
01106 //   const eT* B_mem   = B.mem;
01107 //   
01108 //   const u32 A_n_cols = A.n_cols;
01109 //   const u32 B_n_rows = B.n_rows;
01110 //   
01111 //   for(u32 col=0; col < A_n_cols; ++col)
01112 //     {
01113 //     const eT* A_col = A.colptr(col);
01114 //     
01115 //     eT val = eT(0);
01116 //     for(u32 row=0; row<B_n_rows; ++row)
01117 //       {
01118 //       val += A_col[row] * B_mem[row];
01119 //       }
01120 //     
01121 //     out_mem[col] = val;
01122 //     }
01123   
01124   gemv<true>::apply(out.memptr(), A, B.mem);
01125   }
01126 
01127 
01128 
01129 #endif
01130 
01131 
01132 
01133 //! @}