gemv.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 //! \addtogroup gemv
00018 //! @{
00019 
00020 
00021 
00022 //! \brief
00023 //! Partial emulation of ATLAS/BLAS gemv().
00024 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
00025 
00026 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00027 class gemv_arma
00028   {
00029   public:
00030   
00031   template<typename eT>
00032   arma_hot
00033   inline
00034   static
00035   void
00036   apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00037     {
00038     arma_extra_debug_sigprint();
00039     
00040     const u32 A_n_rows = A.n_rows;
00041     const u32 A_n_cols = A.n_cols;
00042     
00043     if(do_trans_A == false)
00044       {
00045       for(u32 row=0; row < A_n_rows; ++row)
00046         {
00047         
00048         eT acc = eT(0);
00049         for(u32 col=0; col < A_n_cols; ++col)
00050           {
00051           acc += A.at(row,col) * x[col];
00052           }
00053           
00054         if( (use_alpha == false) && (use_beta == false) )
00055           {
00056           y[row] = acc;
00057           }
00058         else
00059         if( (use_alpha == true) && (use_beta == false) )
00060           {
00061           y[row] = alpha * acc;
00062           }
00063         else
00064         if( (use_alpha == false) && (use_beta == true) )
00065           {
00066           y[row] = acc + beta*y[row];
00067           }
00068         else
00069         if( (use_alpha == true) && (use_beta == true) )
00070           {
00071           y[row] = alpha*acc + beta*y[row];
00072           }
00073         }
00074       }
00075     else
00076     if(do_trans_A == true)
00077       {
00078       for(u32 col=0; col < A_n_cols; ++col)
00079         {
00080         // col is interpreted as row when storing the results in 'y'
00081         
00082         const eT* A_coldata = A.colptr(col);
00083         
00084         eT acc = eT(0);
00085         for(u32 row=0; row < A_n_rows; ++row)
00086           {
00087           acc += A_coldata[row] * x[row];
00088           }
00089       
00090         if( (use_alpha == false) && (use_beta == false) )
00091           {
00092           y[col] = acc;
00093           }
00094         else
00095         if( (use_alpha == true) && (use_beta == false) )
00096           {
00097           y[col] = alpha * acc;
00098           }
00099         else
00100         if( (use_alpha == false) && (use_beta == true) )
00101           {
00102           y[col] = acc + beta*y[col];
00103           }
00104         else
00105         if( (use_alpha == true) && (use_beta == true) )
00106           {
00107           y[col] = alpha*acc + beta*y[col];
00108           }
00109         
00110         }
00111       }
00112     }
00113     
00114   };
00115 
00116 
00117 
00118 //! \brief
00119 //! Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gemv.
00120 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
00121 
00122 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00123 class gemv
00124   {
00125   public:
00126   
00127   template<typename eT>
00128   inline
00129   static
00130   void
00131   apply_blas_type( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00132     {
00133     arma_extra_debug_sigprint();
00134     
00135     if(A.n_elem <= 256u)
00136      {
00137      gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00138      }
00139     else
00140       {
00141       #if defined(ARMA_USE_ATLAS)
00142         {
00143         arma_extra_debug_print("atlas::cblas_gemv()");
00144         
00145         atlas::cblas_gemv<eT>
00146           (
00147           atlas::CblasColMajor,
00148           (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
00149           A.n_rows,
00150           A.n_cols,
00151           (use_alpha) ? alpha : eT(1),
00152           A.mem,
00153           A.n_rows,
00154           x,
00155           1,
00156           (use_beta) ? beta : eT(0),
00157           y,
00158           1
00159           );
00160         }
00161       #elif defined(ARMA_USE_BLAS)
00162         {
00163         arma_extra_debug_print("blas::gemv_()");
00164         
00165         const char trans_A     = (do_trans_A) ? 'T' : 'N';
00166         const int  m           = A.n_rows;
00167         const int  n           = A.n_cols;
00168         const eT   local_alpha = (use_alpha) ? alpha : eT(1);
00169         //const int  lda         = A.n_rows;
00170         const int  inc         = 1;
00171         const eT   local_beta  = (use_beta) ? beta : eT(0);
00172         
00173         arma_extra_debug_print( arma_boost::format("blas::gemv_(): trans_A = %c") % trans_A );
00174 
00175         blas::gemv_<eT>
00176           (
00177           &trans_A,
00178           &m,
00179           &n,
00180           &local_alpha,
00181           A.mem,
00182           &m,  // lda
00183           x,
00184           &inc,
00185           &local_beta,
00186           y,
00187           &inc
00188           );
00189         }
00190       #else
00191         {
00192         gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00193         }
00194       #endif
00195       }
00196     
00197     }
00198   
00199   
00200   
00201   template<typename eT>
00202   arma_inline
00203   static
00204   void
00205   apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00206     {
00207     gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00208     }
00209   
00210   
00211   
00212   arma_inline
00213   static
00214   void
00215   apply
00216     (
00217           float*      y,
00218     const Mat<float>& A,
00219     const float*      x,
00220     const float       alpha = float(1),
00221     const float       beta  = float(0)
00222     )
00223     {
00224     gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00225     }
00226 
00227 
00228   
00229   arma_inline
00230   static
00231   void
00232   apply
00233     (
00234           double*      y,
00235     const Mat<double>& A,
00236     const double*      x,
00237     const double       alpha = double(1),
00238     const double       beta  = double(0)
00239     )
00240     {
00241     gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00242     }
00243 
00244 
00245   
00246   arma_inline
00247   static
00248   void
00249   apply
00250     (
00251           std::complex<float>*         y,
00252     const Mat< std::complex<float > >& A,
00253     const std::complex<float>*         x,
00254     const std::complex<float>          alpha = std::complex<float>(1),
00255     const std::complex<float>          beta  = std::complex<float>(0)
00256     )
00257     {
00258     gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00259     }
00260 
00261 
00262   
00263   arma_inline
00264   static
00265   void
00266   apply
00267     (
00268           std::complex<double>*        y,
00269     const Mat< std::complex<double> >& A,
00270     const std::complex<double>*        x,
00271     const std::complex<double>         alpha = std::complex<double>(1),
00272     const std::complex<double>         beta  = std::complex<double>(0)
00273     )
00274     {
00275     gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00276     }
00277 
00278 
00279   
00280   };
00281 
00282 
00283 //! @}