00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00028 class gemm_mixed_cache
00029 {
00030 public:
00031
00032 template<typename out_eT, typename in_eT1, typename in_eT2>
00033 arma_hot
00034 inline
00035 static
00036 void
00037 apply
00038 (
00039 Mat<out_eT>& C,
00040 const Mat<in_eT1>& A,
00041 const Mat<in_eT2>& B,
00042 const out_eT alpha = out_eT(1),
00043 const out_eT beta = out_eT(0)
00044 )
00045 {
00046 arma_extra_debug_sigprint();
00047
00048 const u32 A_n_rows = A.n_rows;
00049 const u32 A_n_cols = A.n_cols;
00050
00051 const u32 B_n_rows = B.n_rows;
00052 const u32 B_n_cols = B.n_cols;
00053
00054 if( (do_trans_A == false) && (do_trans_B == false) )
00055 {
00056 podarray<in_eT1> tmp(A_n_cols);
00057 in_eT1* A_rowdata = tmp.memptr();
00058
00059 for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00060 {
00061
00062 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00063 {
00064 A_rowdata[col_A] = A.at(row_A,col_A);
00065 }
00066
00067 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00068 {
00069 const in_eT2* B_coldata = B.colptr(col_B);
00070
00071 out_eT acc = out_eT(0);
00072 for(u32 i=0; i < B_n_rows; ++i)
00073 {
00074 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00075 }
00076
00077 if( (use_alpha == false) && (use_beta == false) )
00078 {
00079 C.at(row_A,col_B) = acc;
00080 }
00081 else
00082 if( (use_alpha == true) && (use_beta == false) )
00083 {
00084 C.at(row_A,col_B) = alpha * acc;
00085 }
00086 else
00087 if( (use_alpha == false) && (use_beta == true) )
00088 {
00089 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00090 }
00091 else
00092 if( (use_alpha == true) && (use_beta == true) )
00093 {
00094 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00095 }
00096
00097 }
00098 }
00099 }
00100 else
00101 if( (do_trans_A == true) && (do_trans_B == false) )
00102 {
00103 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00104 {
00105
00106
00107 const in_eT1* A_coldata = A.colptr(col_A);
00108
00109 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00110 {
00111 const in_eT2* B_coldata = B.colptr(col_B);
00112
00113 out_eT acc = out_eT(0);
00114 for(u32 i=0; i < B_n_rows; ++i)
00115 {
00116 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00117 }
00118
00119 if( (use_alpha == false) && (use_beta == false) )
00120 {
00121 C.at(col_A,col_B) = acc;
00122 }
00123 else
00124 if( (use_alpha == true) && (use_beta == false) )
00125 {
00126 C.at(col_A,col_B) = alpha * acc;
00127 }
00128 else
00129 if( (use_alpha == false) && (use_beta == true) )
00130 {
00131 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00132 }
00133 else
00134 if( (use_alpha == true) && (use_beta == true) )
00135 {
00136 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00137 }
00138
00139 }
00140 }
00141 }
00142 else
00143 if( (do_trans_A == false) && (do_trans_B == true) )
00144 {
00145 Mat<in_eT2> B_tmp = trans(B);
00146 gemm_mixed_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00147 }
00148 else
00149 if( (do_trans_A == true) && (do_trans_B == true) )
00150 {
00151
00152
00153
00154
00155
00156
00157
00158 podarray<in_eT2> tmp(B.n_cols);
00159 in_eT2* B_rowdata = tmp.memptr();
00160
00161 for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00162 {
00163
00164 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00165 {
00166 B_rowdata[col_B] = B.at(row_B,col_B);
00167 }
00168
00169 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00170 {
00171 const in_eT1* A_coldata = A.colptr(col_A);
00172
00173 out_eT acc = out_eT(0);
00174 for(u32 i=0; i < A_n_rows; ++i)
00175 {
00176 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00177 }
00178
00179 if( (use_alpha == false) && (use_beta == false) )
00180 {
00181 C.at(col_A,row_B) = acc;
00182 }
00183 else
00184 if( (use_alpha == true) && (use_beta == false) )
00185 {
00186 C.at(col_A,row_B) = alpha * acc;
00187 }
00188 else
00189 if( (use_alpha == false) && (use_beta == true) )
00190 {
00191 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00192 }
00193 else
00194 if( (use_alpha == true) && (use_beta == true) )
00195 {
00196 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00197 }
00198
00199 }
00200 }
00201
00202 }
00203 }
00204
00205 };
00206
00207
00208
00209
00210
00211
00212 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00213 class gemm_mixed_simple
00214 {
00215 public:
00216
00217 template<typename out_eT, typename in_eT1, typename in_eT2>
00218 arma_hot
00219 inline
00220 static
00221 void
00222 apply
00223 (
00224 Mat<out_eT>& C,
00225 const Mat<in_eT1>& A,
00226 const Mat<in_eT2>& B,
00227 const out_eT alpha = out_eT(1),
00228 const out_eT beta = out_eT(0)
00229 )
00230 {
00231 arma_extra_debug_sigprint();
00232
00233 const u32 A_n_rows = A.n_rows;
00234 const u32 A_n_cols = A.n_cols;
00235
00236 const u32 B_n_rows = B.n_rows;
00237 const u32 B_n_cols = B.n_cols;
00238
00239 if( (do_trans_A == false) && (do_trans_B == false) )
00240 {
00241 for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00242 {
00243 for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00244 {
00245 const in_eT2* B_coldata = B.colptr(col_B);
00246
00247 out_eT acc = out_eT(0);
00248 for(u32 i = 0; i < B_n_rows; ++i)
00249 {
00250 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
00251 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00252 acc += val1 * val2;
00253
00254 }
00255
00256 if( (use_alpha == false) && (use_beta == false) )
00257 {
00258 C.at(row_A,col_B) = acc;
00259 }
00260 else
00261 if( (use_alpha == true) && (use_beta == false) )
00262 {
00263 C.at(row_A,col_B) = alpha * acc;
00264 }
00265 else
00266 if( (use_alpha == false) && (use_beta == true) )
00267 {
00268 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00269 }
00270 else
00271 if( (use_alpha == true) && (use_beta == true) )
00272 {
00273 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00274 }
00275 }
00276 }
00277 }
00278 else
00279 if( (do_trans_A == true) && (do_trans_B == false) )
00280 {
00281 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00282 {
00283
00284
00285 const in_eT1* A_coldata = A.colptr(col_A);
00286
00287 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00288 {
00289 const in_eT2* B_coldata = B.colptr(col_B);
00290
00291 out_eT acc = out_eT(0);
00292 for(u32 i=0; i < B_n_rows; ++i)
00293 {
00294 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00295 }
00296
00297 if( (use_alpha == false) && (use_beta == false) )
00298 {
00299 C.at(col_A,col_B) = acc;
00300 }
00301 else
00302 if( (use_alpha == true) && (use_beta == false) )
00303 {
00304 C.at(col_A,col_B) = alpha * acc;
00305 }
00306 else
00307 if( (use_alpha == false) && (use_beta == true) )
00308 {
00309 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00310 }
00311 else
00312 if( (use_alpha == true) && (use_beta == true) )
00313 {
00314 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00315 }
00316
00317 }
00318 }
00319 }
00320 else
00321 if( (do_trans_A == false) && (do_trans_B == true) )
00322 {
00323 for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00324 {
00325 for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00326 {
00327 out_eT acc = out_eT(0);
00328 for(u32 i = 0; i < B_n_cols; ++i)
00329 {
00330 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
00331 }
00332
00333 if( (use_alpha == false) && (use_beta == false) )
00334 {
00335 C.at(row_A,row_B) = acc;
00336 }
00337 else
00338 if( (use_alpha == true) && (use_beta == false) )
00339 {
00340 C.at(row_A,row_B) = alpha * acc;
00341 }
00342 else
00343 if( (use_alpha == false) && (use_beta == true) )
00344 {
00345 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00346 }
00347 else
00348 if( (use_alpha == true) && (use_beta == true) )
00349 {
00350 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00351 }
00352 }
00353 }
00354 }
00355 else
00356 if( (do_trans_A == true) && (do_trans_B == true) )
00357 {
00358 for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00359 {
00360
00361 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00362 {
00363 const in_eT1* A_coldata = A.colptr(col_A);
00364
00365 out_eT acc = out_eT(0);
00366 for(u32 i=0; i < A_n_rows; ++i)
00367 {
00368 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00369 }
00370
00371 if( (use_alpha == false) && (use_beta == false) )
00372 {
00373 C.at(col_A,row_B) = acc;
00374 }
00375 else
00376 if( (use_alpha == true) && (use_beta == false) )
00377 {
00378 C.at(col_A,row_B) = alpha * acc;
00379 }
00380 else
00381 if( (use_alpha == false) && (use_beta == true) )
00382 {
00383 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00384 }
00385 else
00386 if( (use_alpha == true) && (use_beta == true) )
00387 {
00388 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00389 }
00390
00391 }
00392 }
00393
00394 }
00395 }
00396
00397 };
00398
00399
00400
00401
00402
00403
00404
00405
00406 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00407 class gemm_mixed
00408 {
00409 public:
00410
00411
00412 template<typename out_eT, typename in_eT1, typename in_eT2>
00413 inline
00414 static
00415 void
00416 apply
00417 (
00418 Mat<out_eT>& C,
00419 const Mat<in_eT1>& A,
00420 const Mat<in_eT2>& B,
00421 const out_eT alpha = out_eT(1),
00422 const out_eT beta = out_eT(0)
00423 )
00424 {
00425 arma_extra_debug_sigprint();
00426
00427 if( (A.n_elem <= 64u) && (B.n_elem <= 64u) )
00428 {
00429 gemm_mixed_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00430 }
00431 else
00432 {
00433 gemm_mixed_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00434 }
00435 }
00436
00437 };
00438
00439
00440
00441