gemm.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 //! \addtogroup gemm
00018 //! @{
00019 
00020 
00021 
00022 //! \brief
00023 //! Partial emulation of ATLAS/BLAS gemm(), using caching for speedup.
00024 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00025 
00026 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00027 class gemm_emul_cache
00028   {
00029   public:
00030   
00031   template<typename eT>
00032   arma_hot
00033   inline
00034   static
00035   void
00036   apply
00037     (
00038           Mat<eT>& C,
00039     const Mat<eT>& A,
00040     const Mat<eT>& B,
00041     const eT alpha = eT(1),
00042     const eT beta  = eT(0)
00043     )
00044     {
00045     arma_extra_debug_sigprint();
00046 
00047     const u32 A_n_rows = A.n_rows;
00048     const u32 A_n_cols = A.n_cols;
00049     
00050     const u32 B_n_rows = B.n_rows;
00051     const u32 B_n_cols = B.n_cols;
00052     
00053     if( (do_trans_A == false) && (do_trans_B == false) )
00054       {
00055       arma_aligned podarray<eT> tmp(A_n_cols);
00056       eT* A_rowdata = tmp.memptr();
00057       
00058       for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00059         {
00060         
00061         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00062           {
00063           A_rowdata[col_A] = A.at(row_A,col_A);
00064           }
00065         
00066         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00067           {
00068           const eT* B_coldata = B.colptr(col_B);
00069           
00070           eT acc = eT(0);
00071           for(u32 i=0; i < B_n_rows; ++i)
00072             {
00073             acc += A_rowdata[i] * B_coldata[i];
00074             }
00075         
00076           if( (use_alpha == false) && (use_beta == false) )
00077             {
00078             C.at(row_A,col_B) = acc;
00079             }
00080           else
00081           if( (use_alpha == true) && (use_beta == false) )
00082             {
00083             C.at(row_A,col_B) = alpha * acc;
00084             }
00085           else
00086           if( (use_alpha == false) && (use_beta == true) )
00087             {
00088             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00089             }
00090           else
00091           if( (use_alpha == true) && (use_beta == true) )
00092             {
00093             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00094             }
00095           
00096           }
00097         }
00098       }
00099     else
00100     if( (do_trans_A == true) && (do_trans_B == false) )
00101       {
00102       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00103         {
00104         // col_A is interpreted as row_A when storing the results in matrix C
00105         
00106         const eT* A_coldata = A.colptr(col_A);
00107         
00108         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00109           {
00110           const eT* B_coldata = B.colptr(col_B);
00111           
00112           eT acc = eT(0);
00113           for(u32 i=0; i < B_n_rows; ++i)
00114             {
00115             acc += A_coldata[i] * B_coldata[i];
00116             }
00117         
00118           if( (use_alpha == false) && (use_beta == false) )
00119             {
00120             C.at(col_A,col_B) = acc;
00121             }
00122           else
00123           if( (use_alpha == true) && (use_beta == false) )
00124             {
00125             C.at(col_A,col_B) = alpha * acc;
00126             }
00127           else
00128           if( (use_alpha == false) && (use_beta == true) )
00129             {
00130             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00131             }
00132           else
00133           if( (use_alpha == true) && (use_beta == true) )
00134             {
00135             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00136             }
00137           
00138           }
00139         }
00140       }
00141     else
00142     if( (do_trans_A == false) && (do_trans_B == true) )
00143       {
00144       Mat<eT> B_tmp = trans(B);
00145       gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00146       }
00147     else
00148     if( (do_trans_A == true) && (do_trans_B == true) )
00149       {
00150       // mat B_tmp = trans(B);
00151       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00152       
00153       
00154       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00155       // transpose operations are not needed
00156       
00157       arma_aligned podarray<eT> tmp(B.n_cols);
00158       eT* B_rowdata = tmp.memptr();
00159       
00160       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00161         {
00162         
00163         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00164           {
00165           B_rowdata[col_B] = B.at(row_B,col_B);
00166           }
00167         
00168         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00169           {
00170           const eT* A_coldata = A.colptr(col_A);
00171           
00172           eT acc = eT(0);
00173           for(u32 i=0; i < A_n_rows; ++i)
00174             {
00175             acc += B_rowdata[i] * A_coldata[i];
00176             }
00177         
00178           if( (use_alpha == false) && (use_beta == false) )
00179             {
00180             C.at(col_A,row_B) = acc;
00181             }
00182           else
00183           if( (use_alpha == true) && (use_beta == false) )
00184             {
00185             C.at(col_A,row_B) = alpha * acc;
00186             }
00187           else
00188           if( (use_alpha == false) && (use_beta == true) )
00189             {
00190             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00191             }
00192           else
00193           if( (use_alpha == true) && (use_beta == true) )
00194             {
00195             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00196             }
00197           
00198           }
00199         }
00200       
00201       }
00202     }
00203     
00204   };
00205 
00206 
00207 
00208 //! Partial emulation of ATLAS/BLAS gemm(), non-cached version.
00209 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00210 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00211 class gemm_emul_simple
00212   {
00213   public:
00214   
00215   template<typename eT>
00216   arma_hot
00217   inline
00218   static
00219   void
00220   apply
00221     (
00222           Mat<eT>& C,
00223     const Mat<eT>& A,
00224     const Mat<eT>& B,
00225     const eT alpha = eT(1),
00226     const eT beta  = eT(0)
00227     )
00228     {
00229     arma_extra_debug_sigprint();
00230     
00231     const u32 A_n_rows = A.n_rows;
00232     const u32 A_n_cols = A.n_cols;
00233     
00234     const u32 B_n_rows = B.n_rows;
00235     const u32 B_n_cols = B.n_cols;
00236     
00237     if( (do_trans_A == false) && (do_trans_B == false) )
00238       {
00239       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00240         {
00241         for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00242           {
00243           const eT* B_coldata = B.colptr(col_B);
00244           
00245           eT acc = eT(0);
00246           for(u32 i = 0; i < B_n_rows; ++i)
00247             {
00248             acc += A.at(row_A,i) * B_coldata[i];
00249             }
00250           
00251           if( (use_alpha == false) && (use_beta == false) )
00252             {
00253             C.at(row_A,col_B) = acc;
00254             }
00255           else
00256           if( (use_alpha == true) && (use_beta == false) )
00257             {
00258             C.at(row_A,col_B) = alpha * acc;
00259             }
00260           else
00261           if( (use_alpha == false) && (use_beta == true) )
00262             {
00263             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00264             }
00265           else
00266           if( (use_alpha == true) && (use_beta == true) )
00267             {
00268             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00269             }
00270           }
00271         }
00272       }
00273     else
00274     if( (do_trans_A == true) && (do_trans_B == false) )
00275       {
00276       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00277         {
00278         // col_A is interpreted as row_A when storing the results in matrix C
00279         
00280         const eT* A_coldata = A.colptr(col_A);
00281         
00282         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00283           {
00284           const eT* B_coldata = B.colptr(col_B);
00285           
00286           eT acc = eT(0);
00287           for(u32 i=0; i < B_n_rows; ++i)
00288             {
00289             acc += A_coldata[i] * B_coldata[i];
00290             }
00291         
00292           if( (use_alpha == false) && (use_beta == false) )
00293             {
00294             C.at(col_A,col_B) = acc;
00295             }
00296           else
00297           if( (use_alpha == true) && (use_beta == false) )
00298             {
00299             C.at(col_A,col_B) = alpha * acc;
00300             }
00301           else
00302           if( (use_alpha == false) && (use_beta == true) )
00303             {
00304             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00305             }
00306           else
00307           if( (use_alpha == true) && (use_beta == true) )
00308             {
00309             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00310             }
00311           
00312           }
00313         }
00314       }
00315     else
00316     if( (do_trans_A == false) && (do_trans_B == true) )
00317       {
00318       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00319         {
00320         for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00321           {
00322           eT acc = eT(0);
00323           for(u32 i = 0; i < B_n_cols; ++i)
00324             {
00325             acc += A.at(row_A,i) * B.at(row_B,i);
00326             }
00327           
00328           if( (use_alpha == false) && (use_beta == false) )
00329             {
00330             C.at(row_A,row_B) = acc;
00331             }
00332           else
00333           if( (use_alpha == true) && (use_beta == false) )
00334             {
00335             C.at(row_A,row_B) = alpha * acc;
00336             }
00337           else
00338           if( (use_alpha == false) && (use_beta == true) )
00339             {
00340             C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00341             }
00342           else
00343           if( (use_alpha == true) && (use_beta == true) )
00344             {
00345             C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00346             }
00347           }
00348         }
00349       }
00350     else
00351     if( (do_trans_A == true) && (do_trans_B == true) )
00352       {
00353       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00354         {
00355         
00356         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00357           {
00358           const eT* A_coldata = A.colptr(col_A);
00359           
00360           eT acc = eT(0);
00361           for(u32 i=0; i < A_n_rows; ++i)
00362             {
00363             acc += B.at(row_B,i) * A_coldata[i];
00364             }
00365         
00366           if( (use_alpha == false) && (use_beta == false) )
00367             {
00368             C.at(col_A,row_B) = acc;
00369             }
00370           else
00371           if( (use_alpha == true) && (use_beta == false) )
00372             {
00373             C.at(col_A,row_B) = alpha * acc;
00374             }
00375           else
00376           if( (use_alpha == false) && (use_beta == true) )
00377             {
00378             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00379             }
00380           else
00381           if( (use_alpha == true) && (use_beta == true) )
00382             {
00383             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00384             }
00385           
00386           }
00387         }
00388       
00389       }
00390     }
00391     
00392   };
00393 
00394 
00395 
00396 //! \brief
00397 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
00398 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00399 
00400 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00401 class gemm
00402   {
00403   public:
00404   
00405   template<typename eT>
00406   inline
00407   static
00408   void
00409   apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00410     {
00411     arma_extra_debug_sigprint();
00412     
00413     if( ((A.n_elem <= 64u) && (B.n_elem <= 64u)) )
00414       {
00415       gemm_emul_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00416       }
00417     else
00418       {
00419       #if defined(ARMA_USE_ATLAS)
00420         {
00421         arma_extra_debug_print("atlas::cblas_gemm()");
00422         
00423         atlas::cblas_gemm<eT>
00424           (
00425           atlas::CblasColMajor,
00426           (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
00427           (do_trans_B) ? atlas::CblasTrans : atlas::CblasNoTrans,
00428           C.n_rows,
00429           C.n_cols,
00430           (do_trans_A) ? A.n_rows : A.n_cols,
00431           (use_alpha) ? alpha : eT(1),
00432           A.mem,
00433           (do_trans_A) ? A.n_rows : C.n_rows,
00434           B.mem,
00435           (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
00436           (use_beta) ? beta : eT(0),
00437           C.memptr(),
00438           C.n_rows
00439           );
00440         }
00441       #elif defined(ARMA_USE_BLAS)
00442         {
00443         arma_extra_debug_print("blas::gemm_()");
00444         
00445         const char trans_A = (do_trans_A) ? 'T' : 'N';
00446         const char trans_B = (do_trans_B) ? 'T' : 'N';
00447         
00448         const int m   = C.n_rows;
00449         const int n   = C.n_cols;
00450         const int k   = (do_trans_A) ? A.n_rows : A.n_cols;
00451         
00452         const eT local_alpha = (use_alpha) ? alpha : eT(1);
00453         
00454         const int lda = (do_trans_A) ? k : m;
00455         const int ldb = (do_trans_B) ? n : k;
00456         
00457         const eT local_beta  = (use_beta) ? beta : eT(0);
00458         
00459         arma_extra_debug_print( arma_boost::format("blas::gemm_(): trans_A = %c") % trans_A );
00460         arma_extra_debug_print( arma_boost::format("blas::gemm_(): trans_B = %c") % trans_B );
00461         
00462         blas::gemm_<eT>
00463           (
00464           &trans_A,
00465           &trans_B,
00466           &m,
00467           &n,
00468           &k,
00469           &local_alpha,
00470           A.mem,
00471           &lda,
00472           B.mem,
00473           &ldb,
00474           &local_beta,
00475           C.memptr(),
00476           &m
00477           );
00478         }
00479       #else
00480         {
00481         gemm_emul_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00482         }
00483       #endif
00484       }
00485     }
00486   
00487   
00488   
00489   //! immediate multiplication of matrices A and B, storing the result in C
00490   template<typename eT>
00491   inline
00492   static
00493   void
00494   apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00495     {
00496     if( (A.n_elem <= 64u) && (B.n_elem <= 64u) )
00497       {
00498       gemm_emul_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00499       }
00500     else
00501       {
00502       gemm_emul_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00503       }
00504     }
00505   
00506   
00507   
00508   arma_inline
00509   static
00510   void
00511   apply
00512     (
00513           Mat<float>& C,
00514     const Mat<float>& A,
00515     const Mat<float>& B,
00516     const float alpha = float(1),
00517     const float beta  = float(0)
00518     )
00519     {
00520     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00521     }
00522   
00523   
00524   
00525   arma_inline
00526   static
00527   void
00528   apply
00529     (
00530           Mat<double>& C,
00531     const Mat<double>& A,
00532     const Mat<double>& B,
00533     const double alpha = double(1),
00534     const double beta  = double(0)
00535     )
00536     {
00537     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00538     }
00539   
00540   
00541   
00542   arma_inline
00543   static
00544   void
00545   apply
00546     (
00547           Mat< std::complex<float> >& C,
00548     const Mat< std::complex<float> >& A,
00549     const Mat< std::complex<float> >& B,
00550     const std::complex<float> alpha = std::complex<float>(1),
00551     const std::complex<float> beta  = std::complex<float>(0)
00552     )
00553     {
00554     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00555     }
00556   
00557   
00558   
00559   arma_inline
00560   static
00561   void
00562   apply
00563     (
00564           Mat< std::complex<double> >& C,
00565     const Mat< std::complex<double> >& A,
00566     const Mat< std::complex<double> >& B,
00567     const std::complex<double> alpha = std::complex<double>(1),
00568     const std::complex<double> beta  = std::complex<double>(0)
00569     )
00570     {
00571     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00572     }
00573   
00574   };
00575 
00576 
00577 
00578 //! @}