gemm_mixed.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 //! \addtogroup gemm_mixed
00018 //! @{
00019 
00020 
00021 
00022 //! \brief
00023 //! Matrix multiplication where the matrices have different element types.
00024 //! Uses caching for speedup.
00025 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00026 
00027 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00028 class gemm_mixed_cache
00029   {
00030   public:
00031   
00032   template<typename out_eT, typename in_eT1, typename in_eT2>
00033   arma_hot
00034   inline
00035   static
00036   void
00037   apply
00038     (
00039           Mat<out_eT>& C,
00040     const Mat<in_eT1>& A,
00041     const Mat<in_eT2>& B,
00042     const out_eT alpha = out_eT(1),
00043     const out_eT beta  = out_eT(0)
00044     )
00045     {
00046     arma_extra_debug_sigprint();
00047     
00048     const u32 A_n_rows = A.n_rows;
00049     const u32 A_n_cols = A.n_cols;
00050     
00051     const u32 B_n_rows = B.n_rows;
00052     const u32 B_n_cols = B.n_cols;
00053     
00054     if( (do_trans_A == false) && (do_trans_B == false) )
00055       {
00056       podarray<in_eT1> tmp(A_n_cols);
00057       in_eT1* A_rowdata = tmp.memptr();
00058       
00059       for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00060         {
00061         
00062         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00063           {
00064           A_rowdata[col_A] = A.at(row_A,col_A);
00065           }
00066         
00067         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00068           {
00069           const in_eT2* B_coldata = B.colptr(col_B);
00070           
00071           out_eT acc = out_eT(0);
00072           for(u32 i=0; i < B_n_rows; ++i)
00073             {
00074             acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00075             }
00076         
00077           if( (use_alpha == false) && (use_beta == false) )
00078             {
00079             C.at(row_A,col_B) = acc;
00080             }
00081           else
00082           if( (use_alpha == true) && (use_beta == false) )
00083             {
00084             C.at(row_A,col_B) = alpha * acc;
00085             }
00086           else
00087           if( (use_alpha == false) && (use_beta == true) )
00088             {
00089             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00090             }
00091           else
00092           if( (use_alpha == true) && (use_beta == true) )
00093             {
00094             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00095             }
00096           
00097           }
00098         }
00099       }
00100     else
00101     if( (do_trans_A == true) && (do_trans_B == false) )
00102       {
00103       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00104         {
00105         // col_A is interpreted as row_A when storing the results in matrix C
00106         
00107         const in_eT1* A_coldata = A.colptr(col_A);
00108         
00109         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00110           {
00111           const in_eT2* B_coldata = B.colptr(col_B);
00112           
00113           out_eT acc = out_eT(0);
00114           for(u32 i=0; i < B_n_rows; ++i)
00115             {
00116             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00117             }
00118         
00119           if( (use_alpha == false) && (use_beta == false) )
00120             {
00121             C.at(col_A,col_B) = acc;
00122             }
00123           else
00124           if( (use_alpha == true) && (use_beta == false) )
00125             {
00126             C.at(col_A,col_B) = alpha * acc;
00127             }
00128           else
00129           if( (use_alpha == false) && (use_beta == true) )
00130             {
00131             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00132             }
00133           else
00134           if( (use_alpha == true) && (use_beta == true) )
00135             {
00136             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00137             }
00138           
00139           }
00140         }
00141       }
00142     else
00143     if( (do_trans_A == false) && (do_trans_B == true) )
00144       {
00145       Mat<in_eT2> B_tmp = trans(B);
00146       gemm_mixed_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00147       }
00148     else
00149     if( (do_trans_A == true) && (do_trans_B == true) )
00150       {
00151       // mat B_tmp = trans(B);
00152       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00153       
00154       
00155       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00156       // transpose operations are not needed
00157       
00158       podarray<in_eT2> tmp(B.n_cols);
00159       in_eT2* B_rowdata = tmp.memptr();
00160       
00161       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00162         {
00163         
00164         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00165           {
00166           B_rowdata[col_B] = B.at(row_B,col_B);
00167           }
00168         
00169         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00170           {
00171           const in_eT1* A_coldata = A.colptr(col_A);
00172           
00173           out_eT acc = out_eT(0);
00174           for(u32 i=0; i < A_n_rows; ++i)
00175             {
00176             acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00177             }
00178         
00179           if( (use_alpha == false) && (use_beta == false) )
00180             {
00181             C.at(col_A,row_B) = acc;
00182             }
00183           else
00184           if( (use_alpha == true) && (use_beta == false) )
00185             {
00186             C.at(col_A,row_B) = alpha * acc;
00187             }
00188           else
00189           if( (use_alpha == false) && (use_beta == true) )
00190             {
00191             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00192             }
00193           else
00194           if( (use_alpha == true) && (use_beta == true) )
00195             {
00196             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00197             }
00198           
00199           }
00200         }
00201       
00202       }
00203     }
00204     
00205   };
00206 
00207 
00208 
00209 //! Matrix multiplication where the matrices have different element types.
00210 //! Simple version (no caching).
00211 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00212 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00213 class gemm_mixed_simple
00214   {
00215   public:
00216   
00217   template<typename out_eT, typename in_eT1, typename in_eT2>
00218   arma_hot
00219   inline
00220   static
00221   void
00222   apply
00223     (
00224           Mat<out_eT>& C,
00225     const Mat<in_eT1>& A,
00226     const Mat<in_eT2>& B,
00227     const out_eT alpha = out_eT(1),
00228     const out_eT beta  = out_eT(0)
00229     )
00230     {
00231     arma_extra_debug_sigprint();
00232     
00233     const u32 A_n_rows = A.n_rows;
00234     const u32 A_n_cols = A.n_cols;
00235     
00236     const u32 B_n_rows = B.n_rows;
00237     const u32 B_n_cols = B.n_cols;
00238     
00239     if( (do_trans_A == false) && (do_trans_B == false) )
00240       {
00241       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00242         {
00243         for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00244           {
00245           const in_eT2* B_coldata = B.colptr(col_B);
00246           
00247           out_eT acc = out_eT(0);
00248           for(u32 i = 0; i < B_n_rows; ++i)
00249             {
00250             const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
00251             const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00252             acc += val1 * val2;
00253             //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00254             }
00255           
00256           if( (use_alpha == false) && (use_beta == false) )
00257             {
00258             C.at(row_A,col_B) = acc;
00259             }
00260           else
00261           if( (use_alpha == true) && (use_beta == false) )
00262             {
00263             C.at(row_A,col_B) = alpha * acc;
00264             }
00265           else
00266           if( (use_alpha == false) && (use_beta == true) )
00267             {
00268             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00269             }
00270           else
00271           if( (use_alpha == true) && (use_beta == true) )
00272             {
00273             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00274             }
00275           }
00276         }
00277       }
00278     else
00279     if( (do_trans_A == true) && (do_trans_B == false) )
00280       {
00281       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00282         {
00283         // col_A is interpreted as row_A when storing the results in matrix C
00284         
00285         const in_eT1* A_coldata = A.colptr(col_A);
00286         
00287         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00288           {
00289           const in_eT2* B_coldata = B.colptr(col_B);
00290           
00291           out_eT acc = out_eT(0);
00292           for(u32 i=0; i < B_n_rows; ++i)
00293             {
00294             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00295             }
00296         
00297           if( (use_alpha == false) && (use_beta == false) )
00298             {
00299             C.at(col_A,col_B) = acc;
00300             }
00301           else
00302           if( (use_alpha == true) && (use_beta == false) )
00303             {
00304             C.at(col_A,col_B) = alpha * acc;
00305             }
00306           else
00307           if( (use_alpha == false) && (use_beta == true) )
00308             {
00309             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00310             }
00311           else
00312           if( (use_alpha == true) && (use_beta == true) )
00313             {
00314             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00315             }
00316           
00317           }
00318         }
00319       }
00320     else
00321     if( (do_trans_A == false) && (do_trans_B == true) )
00322       {
00323       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00324         {
00325         for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00326           {
00327           out_eT acc = out_eT(0);
00328           for(u32 i = 0; i < B_n_cols; ++i)
00329             {
00330             acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
00331             }
00332           
00333           if( (use_alpha == false) && (use_beta == false) )
00334             {
00335             C.at(row_A,row_B) = acc;
00336             }
00337           else
00338           if( (use_alpha == true) && (use_beta == false) )
00339             {
00340             C.at(row_A,row_B) = alpha * acc;
00341             }
00342           else
00343           if( (use_alpha == false) && (use_beta == true) )
00344             {
00345             C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00346             }
00347           else
00348           if( (use_alpha == true) && (use_beta == true) )
00349             {
00350             C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00351             }
00352           }
00353         }
00354       }
00355     else
00356     if( (do_trans_A == true) && (do_trans_B == true) )
00357       {
00358       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00359         {
00360         
00361         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00362           {
00363           const in_eT1* A_coldata = A.colptr(col_A);
00364           
00365           out_eT acc = out_eT(0);
00366           for(u32 i=0; i < A_n_rows; ++i)
00367             {
00368             acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00369             }
00370         
00371           if( (use_alpha == false) && (use_beta == false) )
00372             {
00373             C.at(col_A,row_B) = acc;
00374             }
00375           else
00376           if( (use_alpha == true) && (use_beta == false) )
00377             {
00378             C.at(col_A,row_B) = alpha * acc;
00379             }
00380           else
00381           if( (use_alpha == false) && (use_beta == true) )
00382             {
00383             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00384             }
00385           else
00386           if( (use_alpha == true) && (use_beta == true) )
00387             {
00388             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00389             }
00390           
00391           }
00392         }
00393       
00394       }
00395     }
00396     
00397   };
00398 
00399 
00400 
00401 
00402 
00403 //! \brief
00404 //! Matrix multiplication where the matrices have different element types.
00405 
00406 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00407 class gemm_mixed
00408   {
00409   public:
00410   
00411   //! immediate multiplication of matrices A and B, storing the result in C
00412   template<typename out_eT, typename in_eT1, typename in_eT2>
00413   inline
00414   static
00415   void
00416   apply
00417     (
00418           Mat<out_eT>& C,
00419     const Mat<in_eT1>& A,
00420     const Mat<in_eT2>& B,
00421     const out_eT alpha = out_eT(1),
00422     const out_eT beta  = out_eT(0)
00423     )
00424     {
00425     arma_extra_debug_sigprint();
00426     
00427     if( (A.n_elem <= 64u) && (B.n_elem <= 64u) )
00428       {
00429       gemm_mixed_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00430       }
00431     else
00432       {
00433       gemm_mixed_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00434       }
00435     }
00436   
00437   };
00438 
00439 
00440 
00441 //! @}