00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00027 class gemv_arma
00028 {
00029 public:
00030
00031 template<typename eT>
00032 arma_hot
00033 inline
00034 static
00035 void
00036 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00037 {
00038 arma_extra_debug_sigprint();
00039
00040 const u32 A_n_rows = A.n_rows;
00041 const u32 A_n_cols = A.n_cols;
00042
00043 if(do_trans_A == false)
00044 {
00045 for(u32 row=0; row < A_n_rows; ++row)
00046 {
00047
00048 eT acc = eT(0);
00049 for(u32 col=0; col < A_n_cols; ++col)
00050 {
00051 acc += A.at(row,col) * x[col];
00052 }
00053
00054 if( (use_alpha == false) && (use_beta == false) )
00055 {
00056 y[row] = acc;
00057 }
00058 else
00059 if( (use_alpha == true) && (use_beta == false) )
00060 {
00061 y[row] = alpha * acc;
00062 }
00063 else
00064 if( (use_alpha == false) && (use_beta == true) )
00065 {
00066 y[row] = acc + beta*y[row];
00067 }
00068 else
00069 if( (use_alpha == true) && (use_beta == true) )
00070 {
00071 y[row] = alpha*acc + beta*y[row];
00072 }
00073 }
00074 }
00075 else
00076 if(do_trans_A == true)
00077 {
00078 for(u32 col=0; col < A_n_cols; ++col)
00079 {
00080
00081
00082 const eT* A_coldata = A.colptr(col);
00083
00084 eT acc = eT(0);
00085 for(u32 row=0; row < A_n_rows; ++row)
00086 {
00087 acc += A_coldata[row] * x[row];
00088 }
00089
00090 if( (use_alpha == false) && (use_beta == false) )
00091 {
00092 y[col] = acc;
00093 }
00094 else
00095 if( (use_alpha == true) && (use_beta == false) )
00096 {
00097 y[col] = alpha * acc;
00098 }
00099 else
00100 if( (use_alpha == false) && (use_beta == true) )
00101 {
00102 y[col] = acc + beta*y[col];
00103 }
00104 else
00105 if( (use_alpha == true) && (use_beta == true) )
00106 {
00107 y[col] = alpha*acc + beta*y[col];
00108 }
00109
00110 }
00111 }
00112 }
00113
00114 };
00115
00116
00117
00118
00119
00120
00121
00122 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00123 class gemv
00124 {
00125 public:
00126
00127 template<typename eT>
00128 inline
00129 static
00130 void
00131 apply_blas_type( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00132 {
00133 arma_extra_debug_sigprint();
00134
00135 if(A.n_elem <= 256u)
00136 {
00137 gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00138 }
00139 else
00140 {
00141 #if defined(ARMA_USE_ATLAS)
00142 {
00143 arma_extra_debug_print("atlas::cblas_gemv()");
00144
00145 atlas::cblas_gemv<eT>
00146 (
00147 atlas::CblasColMajor,
00148 (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
00149 A.n_rows,
00150 A.n_cols,
00151 (use_alpha) ? alpha : eT(1),
00152 A.mem,
00153 A.n_rows,
00154 x,
00155 1,
00156 (use_beta) ? beta : eT(0),
00157 y,
00158 1
00159 );
00160 }
00161 #elif defined(ARMA_USE_BLAS)
00162 {
00163 arma_extra_debug_print("blas::gemv_()");
00164
00165 const char trans_A = (do_trans_A) ? 'T' : 'N';
00166 const int m = A.n_rows;
00167 const int n = A.n_cols;
00168 const eT local_alpha = (use_alpha) ? alpha : eT(1);
00169
00170 const int inc = 1;
00171 const eT local_beta = (use_beta) ? beta : eT(0);
00172
00173 arma_extra_debug_print( arma_boost::format("blas::gemv_(): trans_A = %c") % trans_A );
00174
00175 blas::gemv_<eT>
00176 (
00177 &trans_A,
00178 &m,
00179 &n,
00180 &local_alpha,
00181 A.mem,
00182 &m,
00183 x,
00184 &inc,
00185 &local_beta,
00186 y,
00187 &inc
00188 );
00189 }
00190 #else
00191 {
00192 gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00193 }
00194 #endif
00195 }
00196
00197 }
00198
00199
00200
00201 template<typename eT>
00202 arma_inline
00203 static
00204 void
00205 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00206 {
00207 gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00208 }
00209
00210
00211
00212 arma_inline
00213 static
00214 void
00215 apply
00216 (
00217 float* y,
00218 const Mat<float>& A,
00219 const float* x,
00220 const float alpha = float(1),
00221 const float beta = float(0)
00222 )
00223 {
00224 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00225 }
00226
00227
00228
00229 arma_inline
00230 static
00231 void
00232 apply
00233 (
00234 double* y,
00235 const Mat<double>& A,
00236 const double* x,
00237 const double alpha = double(1),
00238 const double beta = double(0)
00239 )
00240 {
00241 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00242 }
00243
00244
00245
00246 arma_inline
00247 static
00248 void
00249 apply
00250 (
00251 std::complex<float>* y,
00252 const Mat< std::complex<float > >& A,
00253 const std::complex<float>* x,
00254 const std::complex<float> alpha = std::complex<float>(1),
00255 const std::complex<float> beta = std::complex<float>(0)
00256 )
00257 {
00258 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00259 }
00260
00261
00262
00263 arma_inline
00264 static
00265 void
00266 apply
00267 (
00268 std::complex<double>* y,
00269 const Mat< std::complex<double> >& A,
00270 const std::complex<double>* x,
00271 const std::complex<double> alpha = std::complex<double>(1),
00272 const std::complex<double> beta = std::complex<double>(0)
00273 )
00274 {
00275 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00276 }
00277
00278
00279
00280 };
00281
00282
00283