00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00027 class gemm_emul_cache
00028 {
00029 public:
00030
00031 template<typename eT>
00032 arma_hot
00033 inline
00034 static
00035 void
00036 apply
00037 (
00038 Mat<eT>& C,
00039 const Mat<eT>& A,
00040 const Mat<eT>& B,
00041 const eT alpha = eT(1),
00042 const eT beta = eT(0)
00043 )
00044 {
00045 arma_extra_debug_sigprint();
00046
00047 const u32 A_n_rows = A.n_rows;
00048 const u32 A_n_cols = A.n_cols;
00049
00050 const u32 B_n_rows = B.n_rows;
00051 const u32 B_n_cols = B.n_cols;
00052
00053 if( (do_trans_A == false) && (do_trans_B == false) )
00054 {
00055 arma_aligned podarray<eT> tmp(A_n_cols);
00056 eT* A_rowdata = tmp.memptr();
00057
00058 for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00059 {
00060
00061 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00062 {
00063 A_rowdata[col_A] = A.at(row_A,col_A);
00064 }
00065
00066 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00067 {
00068 const eT* B_coldata = B.colptr(col_B);
00069
00070 eT acc = eT(0);
00071 for(u32 i=0; i < B_n_rows; ++i)
00072 {
00073 acc += A_rowdata[i] * B_coldata[i];
00074 }
00075
00076 if( (use_alpha == false) && (use_beta == false) )
00077 {
00078 C.at(row_A,col_B) = acc;
00079 }
00080 else
00081 if( (use_alpha == true) && (use_beta == false) )
00082 {
00083 C.at(row_A,col_B) = alpha * acc;
00084 }
00085 else
00086 if( (use_alpha == false) && (use_beta == true) )
00087 {
00088 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00089 }
00090 else
00091 if( (use_alpha == true) && (use_beta == true) )
00092 {
00093 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00094 }
00095
00096 }
00097 }
00098 }
00099 else
00100 if( (do_trans_A == true) && (do_trans_B == false) )
00101 {
00102 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00103 {
00104
00105
00106 const eT* A_coldata = A.colptr(col_A);
00107
00108 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00109 {
00110 const eT* B_coldata = B.colptr(col_B);
00111
00112 eT acc = eT(0);
00113 for(u32 i=0; i < B_n_rows; ++i)
00114 {
00115 acc += A_coldata[i] * B_coldata[i];
00116 }
00117
00118 if( (use_alpha == false) && (use_beta == false) )
00119 {
00120 C.at(col_A,col_B) = acc;
00121 }
00122 else
00123 if( (use_alpha == true) && (use_beta == false) )
00124 {
00125 C.at(col_A,col_B) = alpha * acc;
00126 }
00127 else
00128 if( (use_alpha == false) && (use_beta == true) )
00129 {
00130 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00131 }
00132 else
00133 if( (use_alpha == true) && (use_beta == true) )
00134 {
00135 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00136 }
00137
00138 }
00139 }
00140 }
00141 else
00142 if( (do_trans_A == false) && (do_trans_B == true) )
00143 {
00144 Mat<eT> B_tmp = trans(B);
00145 gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00146 }
00147 else
00148 if( (do_trans_A == true) && (do_trans_B == true) )
00149 {
00150
00151
00152
00153
00154
00155
00156
00157 arma_aligned podarray<eT> tmp(B.n_cols);
00158 eT* B_rowdata = tmp.memptr();
00159
00160 for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00161 {
00162
00163 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00164 {
00165 B_rowdata[col_B] = B.at(row_B,col_B);
00166 }
00167
00168 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00169 {
00170 const eT* A_coldata = A.colptr(col_A);
00171
00172 eT acc = eT(0);
00173 for(u32 i=0; i < A_n_rows; ++i)
00174 {
00175 acc += B_rowdata[i] * A_coldata[i];
00176 }
00177
00178 if( (use_alpha == false) && (use_beta == false) )
00179 {
00180 C.at(col_A,row_B) = acc;
00181 }
00182 else
00183 if( (use_alpha == true) && (use_beta == false) )
00184 {
00185 C.at(col_A,row_B) = alpha * acc;
00186 }
00187 else
00188 if( (use_alpha == false) && (use_beta == true) )
00189 {
00190 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00191 }
00192 else
00193 if( (use_alpha == true) && (use_beta == true) )
00194 {
00195 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00196 }
00197
00198 }
00199 }
00200
00201 }
00202 }
00203
00204 };
00205
00206
00207
00208
00209
00210 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00211 class gemm_emul_simple
00212 {
00213 public:
00214
00215 template<typename eT>
00216 arma_hot
00217 inline
00218 static
00219 void
00220 apply
00221 (
00222 Mat<eT>& C,
00223 const Mat<eT>& A,
00224 const Mat<eT>& B,
00225 const eT alpha = eT(1),
00226 const eT beta = eT(0)
00227 )
00228 {
00229 arma_extra_debug_sigprint();
00230
00231 const u32 A_n_rows = A.n_rows;
00232 const u32 A_n_cols = A.n_cols;
00233
00234 const u32 B_n_rows = B.n_rows;
00235 const u32 B_n_cols = B.n_cols;
00236
00237 if( (do_trans_A == false) && (do_trans_B == false) )
00238 {
00239 for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00240 {
00241 for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00242 {
00243 const eT* B_coldata = B.colptr(col_B);
00244
00245 eT acc = eT(0);
00246 for(u32 i = 0; i < B_n_rows; ++i)
00247 {
00248 acc += A.at(row_A,i) * B_coldata[i];
00249 }
00250
00251 if( (use_alpha == false) && (use_beta == false) )
00252 {
00253 C.at(row_A,col_B) = acc;
00254 }
00255 else
00256 if( (use_alpha == true) && (use_beta == false) )
00257 {
00258 C.at(row_A,col_B) = alpha * acc;
00259 }
00260 else
00261 if( (use_alpha == false) && (use_beta == true) )
00262 {
00263 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00264 }
00265 else
00266 if( (use_alpha == true) && (use_beta == true) )
00267 {
00268 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00269 }
00270 }
00271 }
00272 }
00273 else
00274 if( (do_trans_A == true) && (do_trans_B == false) )
00275 {
00276 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00277 {
00278
00279
00280 const eT* A_coldata = A.colptr(col_A);
00281
00282 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00283 {
00284 const eT* B_coldata = B.colptr(col_B);
00285
00286 eT acc = eT(0);
00287 for(u32 i=0; i < B_n_rows; ++i)
00288 {
00289 acc += A_coldata[i] * B_coldata[i];
00290 }
00291
00292 if( (use_alpha == false) && (use_beta == false) )
00293 {
00294 C.at(col_A,col_B) = acc;
00295 }
00296 else
00297 if( (use_alpha == true) && (use_beta == false) )
00298 {
00299 C.at(col_A,col_B) = alpha * acc;
00300 }
00301 else
00302 if( (use_alpha == false) && (use_beta == true) )
00303 {
00304 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00305 }
00306 else
00307 if( (use_alpha == true) && (use_beta == true) )
00308 {
00309 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00310 }
00311
00312 }
00313 }
00314 }
00315 else
00316 if( (do_trans_A == false) && (do_trans_B == true) )
00317 {
00318 for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00319 {
00320 for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00321 {
00322 eT acc = eT(0);
00323 for(u32 i = 0; i < B_n_cols; ++i)
00324 {
00325 acc += A.at(row_A,i) * B.at(row_B,i);
00326 }
00327
00328 if( (use_alpha == false) && (use_beta == false) )
00329 {
00330 C.at(row_A,row_B) = acc;
00331 }
00332 else
00333 if( (use_alpha == true) && (use_beta == false) )
00334 {
00335 C.at(row_A,row_B) = alpha * acc;
00336 }
00337 else
00338 if( (use_alpha == false) && (use_beta == true) )
00339 {
00340 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00341 }
00342 else
00343 if( (use_alpha == true) && (use_beta == true) )
00344 {
00345 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00346 }
00347 }
00348 }
00349 }
00350 else
00351 if( (do_trans_A == true) && (do_trans_B == true) )
00352 {
00353 for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00354 {
00355
00356 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00357 {
00358 const eT* A_coldata = A.colptr(col_A);
00359
00360 eT acc = eT(0);
00361 for(u32 i=0; i < A_n_rows; ++i)
00362 {
00363 acc += B.at(row_B,i) * A_coldata[i];
00364 }
00365
00366 if( (use_alpha == false) && (use_beta == false) )
00367 {
00368 C.at(col_A,row_B) = acc;
00369 }
00370 else
00371 if( (use_alpha == true) && (use_beta == false) )
00372 {
00373 C.at(col_A,row_B) = alpha * acc;
00374 }
00375 else
00376 if( (use_alpha == false) && (use_beta == true) )
00377 {
00378 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00379 }
00380 else
00381 if( (use_alpha == true) && (use_beta == true) )
00382 {
00383 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00384 }
00385
00386 }
00387 }
00388
00389 }
00390 }
00391
00392 };
00393
00394
00395
00396
00397
00398
00399
00400 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00401 class gemm
00402 {
00403 public:
00404
00405 template<typename eT>
00406 inline
00407 static
00408 void
00409 apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00410 {
00411 arma_extra_debug_sigprint();
00412
00413 if( ((A.n_elem <= 64u) && (B.n_elem <= 64u)) )
00414 {
00415 gemm_emul_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00416 }
00417 else
00418 {
00419 #if defined(ARMA_USE_ATLAS)
00420 {
00421 arma_extra_debug_print("atlas::cblas_gemm()");
00422
00423 atlas::cblas_gemm<eT>
00424 (
00425 atlas::CblasColMajor,
00426 (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
00427 (do_trans_B) ? atlas::CblasTrans : atlas::CblasNoTrans,
00428 C.n_rows,
00429 C.n_cols,
00430 (do_trans_A) ? A.n_rows : A.n_cols,
00431 (use_alpha) ? alpha : eT(1),
00432 A.mem,
00433 (do_trans_A) ? A.n_rows : C.n_rows,
00434 B.mem,
00435 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
00436 (use_beta) ? beta : eT(0),
00437 C.memptr(),
00438 C.n_rows
00439 );
00440 }
00441 #elif defined(ARMA_USE_BLAS)
00442 {
00443 arma_extra_debug_print("blas::gemm_()");
00444
00445 const char trans_A = (do_trans_A) ? 'T' : 'N';
00446 const char trans_B = (do_trans_B) ? 'T' : 'N';
00447
00448 const int m = C.n_rows;
00449 const int n = C.n_cols;
00450 const int k = (do_trans_A) ? A.n_rows : A.n_cols;
00451
00452 const eT local_alpha = (use_alpha) ? alpha : eT(1);
00453
00454 const int lda = (do_trans_A) ? k : m;
00455 const int ldb = (do_trans_B) ? n : k;
00456
00457 const eT local_beta = (use_beta) ? beta : eT(0);
00458
00459 arma_extra_debug_print( arma_boost::format("blas::gemm_(): trans_A = %c") % trans_A );
00460 arma_extra_debug_print( arma_boost::format("blas::gemm_(): trans_B = %c") % trans_B );
00461
00462 blas::gemm_<eT>
00463 (
00464 &trans_A,
00465 &trans_B,
00466 &m,
00467 &n,
00468 &k,
00469 &local_alpha,
00470 A.mem,
00471 &lda,
00472 B.mem,
00473 &ldb,
00474 &local_beta,
00475 C.memptr(),
00476 &m
00477 );
00478 }
00479 #else
00480 {
00481 gemm_emul_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00482 }
00483 #endif
00484 }
00485 }
00486
00487
00488
00489
00490 template<typename eT>
00491 inline
00492 static
00493 void
00494 apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00495 {
00496 if( (A.n_elem <= 64u) && (B.n_elem <= 64u) )
00497 {
00498 gemm_emul_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00499 }
00500 else
00501 {
00502 gemm_emul_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00503 }
00504 }
00505
00506
00507
00508 arma_inline
00509 static
00510 void
00511 apply
00512 (
00513 Mat<float>& C,
00514 const Mat<float>& A,
00515 const Mat<float>& B,
00516 const float alpha = float(1),
00517 const float beta = float(0)
00518 )
00519 {
00520 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00521 }
00522
00523
00524
00525 arma_inline
00526 static
00527 void
00528 apply
00529 (
00530 Mat<double>& C,
00531 const Mat<double>& A,
00532 const Mat<double>& B,
00533 const double alpha = double(1),
00534 const double beta = double(0)
00535 )
00536 {
00537 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00538 }
00539
00540
00541
00542 arma_inline
00543 static
00544 void
00545 apply
00546 (
00547 Mat< std::complex<float> >& C,
00548 const Mat< std::complex<float> >& A,
00549 const Mat< std::complex<float> >& B,
00550 const std::complex<float> alpha = std::complex<float>(1),
00551 const std::complex<float> beta = std::complex<float>(0)
00552 )
00553 {
00554 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00555 }
00556
00557
00558
00559 arma_inline
00560 static
00561 void
00562 apply
00563 (
00564 Mat< std::complex<double> >& C,
00565 const Mat< std::complex<double> >& A,
00566 const Mat< std::complex<double> >& B,
00567 const std::complex<double> alpha = std::complex<double>(1),
00568 const std::complex<double> beta = std::complex<double>(0)
00569 )
00570 {
00571 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00572 }
00573
00574 };
00575
00576
00577
00578