Matrix multiplication where the matrices have different element types. Simple version (no caching). Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes). More...
#include <gemm_mixed.hpp>
Static Public Member Functions | |
template<typename out_eT , typename in_eT1 , typename in_eT2 > | |
static void | apply (Mat< out_eT > &C, const Mat< in_eT1 > &A, const Mat< in_eT2 > &B, const out_eT alpha=out_eT(1), const out_eT beta=out_eT(0)) |
Matrix multiplication where the matrices have different element types. Simple version (no caching). Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes).
Definition at line 211 of file gemm_mixed.hpp.
static void gemm_mixed_simple< do_trans_A, do_trans_B, use_alpha, use_beta >::apply(
    Mat< out_eT > & C,
    const Mat< in_eT1 > & A,
    const Mat< in_eT2 > & B,
    const out_eT alpha = out_eT(1),
    const out_eT beta = out_eT(0)
) [inline, static]
Definition at line 220 of file gemm_mixed.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.
// Body of gemm_mixed_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta).
// NOTE: the five-digit numbers (00227 ... 00392) are line markers from the source listing, not code.
//
// For every output element the body accumulates a sum of products in an out_eT accumulator,
// passing each operand through upgrade_val<in_eT1,in_eT2>::apply() before multiplying.
// The top-level if/else chain selects one of four loop nests from (do_trans_A, do_trans_B):
//   (false, false): C.at(row_A,col_B) <- sum over i < B.n_rows of A.at(row_A,i) * B.colptr(col_B)[i]
//   (true,  false): C.at(col_A,col_B) <- sum over i < B.n_rows of A.colptr(col_A)[i] * B.colptr(col_B)[i]
//   (false, true ): C.at(row_A,row_B) <- sum over i < B.n_cols of A.at(row_A,i) * B.at(row_B,i)
//   (true,  true ): C.at(col_A,row_B) <- sum over i < A.n_rows of B.at(row_B,i) * A.colptr(col_A)[i]
// (colptr(c)[i] is presumably element (i,c) of the matrix -- confirm against Mat<eT>::colptr.)
//
// Inside each nest a second if/else chain on (use_alpha, use_beta) decides what is stored:
//   (false, false): acc
//   (true,  false): alpha * acc
//   (false, true ): acc + beta * C.at(...)
//   (true,  true ): alpha * acc + beta * C.at(...)
// When use_beta is true the current value of C is read before being overwritten.
// When a flag is false the corresponding alpha/beta argument is not used.
//
// This body performs no size checks and never resizes C; the loop bounds are taken only from
// A.n_rows, A.n_cols, B.n_rows and B.n_cols as captured at the top.
00227 { 00228 arma_extra_debug_sigprint(); 00229 00230 const u32 A_n_rows = A.n_rows; 00231 const u32 A_n_cols = A.n_cols; 00232 00233 const u32 B_n_rows = B.n_rows; 00234 const u32 B_n_cols = B.n_cols; 00235 00236 if( (do_trans_A == false) && (do_trans_B == false) ) 00237 { 00238 for(u32 row_A = 0; row_A < A_n_rows; ++row_A) 00239 { 00240 for(u32 col_B = 0; col_B < B_n_cols; ++col_B) 00241 { 00242 const in_eT2* B_coldata = B.colptr(col_B); 00243 00244 out_eT acc = out_eT(0); 00245 for(u32 i = 0; i < B_n_rows; ++i) 00246 { 00247 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)); 00248 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00249 acc += val1 * val2; 00250 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00251 } 00252 00253 if( (use_alpha == false) && (use_beta == false) ) 00254 { 00255 C.at(row_A,col_B) = acc; 00256 } 00257 else 00258 if( (use_alpha == true) && (use_beta == false) ) 00259 { 00260 C.at(row_A,col_B) = alpha * acc; 00261 } 00262 else 00263 if( (use_alpha == false) && (use_beta == true) ) 00264 { 00265 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00266 } 00267 else 00268 if( (use_alpha == true) && (use_beta == true) ) 00269 { 00270 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00271 } 00272 } 00273 } 00274 } 00275 else 00276 if( (do_trans_A == true) && (do_trans_B == false) ) 00277 { 00278 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00279 { 00280 // col_A is interpreted as row_A when storing the results in matrix C 00281 00282 const in_eT1* A_coldata = A.colptr(col_A); 00283 00284 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00285 { 00286 const in_eT2* B_coldata = B.colptr(col_B); 00287 00288 out_eT acc = out_eT(0); 00289 for(u32 i=0; i < B_n_rows; ++i) 00290 { 00291 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00292 } 00293 00294 if( (use_alpha == false) && (use_beta == 
false) ) 00295 { 00296 C.at(col_A,col_B) = acc; 00297 } 00298 else 00299 if( (use_alpha == true) && (use_beta == false) ) 00300 { 00301 C.at(col_A,col_B) = alpha * acc; 00302 } 00303 else 00304 if( (use_alpha == false) && (use_beta == true) ) 00305 { 00306 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00307 } 00308 else 00309 if( (use_alpha == true) && (use_beta == true) ) 00310 { 00311 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00312 } 00313 00314 } 00315 } 00316 } 00317 else 00318 if( (do_trans_A == false) && (do_trans_B == true) ) 00319 { 00320 for(u32 row_A = 0; row_A < A_n_rows; ++row_A) 00321 { 00322 for(u32 row_B = 0; row_B < B_n_rows; ++row_B) 00323 { 00324 out_eT acc = out_eT(0); 00325 for(u32 i = 0; i < B_n_cols; ++i) 00326 { 00327 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)); 00328 } 00329 00330 if( (use_alpha == false) && (use_beta == false) ) 00331 { 00332 C.at(row_A,row_B) = acc; 00333 } 00334 else 00335 if( (use_alpha == true) && (use_beta == false) ) 00336 { 00337 C.at(row_A,row_B) = alpha * acc; 00338 } 00339 else 00340 if( (use_alpha == false) && (use_beta == true) ) 00341 { 00342 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B); 00343 } 00344 else 00345 if( (use_alpha == true) && (use_beta == true) ) 00346 { 00347 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B); 00348 } 00349 } 00350 } 00351 } 00352 else 00353 if( (do_trans_A == true) && (do_trans_B == true) ) 00354 { 00355 for(u32 row_B=0; row_B < B_n_rows; ++row_B) 00356 { 00357 00358 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00359 { 00360 const in_eT1* A_coldata = A.colptr(col_A); 00361 00362 out_eT acc = out_eT(0); 00363 for(u32 i=0; i < A_n_rows; ++i) 00364 { 00365 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); 00366 } 00367 00368 if( (use_alpha == false) && (use_beta == false) ) 00369 { 00370 C.at(col_A,row_B) = acc; 00371 } 00372 else 00373 if( 
(use_alpha == true) && (use_beta == false) ) 00374 { 00375 C.at(col_A,row_B) = alpha * acc; 00376 } 00377 else 00378 if( (use_alpha == false) && (use_beta == true) ) 00379 { 00380 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00381 } 00382 else 00383 if( (use_alpha == true) && (use_beta == true) ) 00384 { 00385 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00386 } 00387 00388 } 00389 } 00390 00391 } 00392 }