Matrix multiplication where the matrices have different element types. Uses caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes).
#include <gemm_mixed.hpp>
Static Public Member Functions

template<typename out_eT, typename in_eT1, typename in_eT2>
static void apply (Mat< out_eT > &C, const Mat< in_eT1 > &A, const Mat< in_eT2 > &B, const out_eT alpha=out_eT(1), const out_eT beta=out_eT(0))
Matrix multiplication where the matrices have different element types. Uses caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes).
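As a rough usage sketch (illustrative only: gemm_mixed_cache is an internal Armadillo class, normally invoked by the library's expression machinery rather than directly by user code, and the matrix sizes below are made up), the four boolean template arguments select the transpose and scaling variants at compile time:

  // mixed-type product C = A*B with a double matrix and a float matrix;
  // C must already have the correct size, since apply() does not resize it
  Mat<double> A(4, 3);  A.fill(1.5);
  Mat<float>  B(3, 5);  B.fill(2.0f);
  Mat<double> C(4, 5);

  // do_trans_A = false, do_trans_B = false, use_alpha = false, use_beta = false
  gemm_mixed_cache<false, false, false, false>::apply(C, A, B);

  // with appropriately sized operands, C = alpha*trans(A)*B + beta*C would instead be:
  // gemm_mixed_cache<true, false, true, true>::apply(C, A, B, alpha, beta);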
Definition at line 27 of file gemm_mixed.hpp.
static void gemm_mixed_cache< do_trans_A, do_trans_B, use_alpha, use_beta >::apply
  (
  Mat< out_eT >&        C,
  const Mat< in_eT1 >&  A,
  const Mat< in_eT2 >&  B,
  const out_eT          alpha = out_eT(1),
  const out_eT          beta  = out_eT(0)
  )   [inline, static]
Definition at line 36 of file gemm_mixed.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), podarray< T1 >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().
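In BLAS gemm terms, the listing below computes

  C \leftarrow \alpha\,\mathrm{op}(A)\,\mathrm{op}(B) + \beta\,C, \qquad \mathrm{op}(X) \in \{ X,\; X^{T} \}

where op() is selected by the do_trans_A / do_trans_B template arguments, and the alpha and beta terms are compiled in only when use_alpha / use_beta are set. The upgrade_val<in_eT1,in_eT2>::apply() calls promote each operand to the common (wider) of the two input element types before the multiply, so mixed-type accumulation does not silently truncate.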
{
arma_extra_debug_sigprint();

const u32 A_n_rows = A.n_rows;
const u32 A_n_cols = A.n_cols;

const u32 B_n_rows = B.n_rows;
const u32 B_n_cols = B.n_cols;

if( (do_trans_A == false) && (do_trans_B == false) )
  {
  // cache each row of A in a contiguous buffer, so that the
  // accumulation loop reads both operands with unit stride
  podarray<in_eT1> tmp(A_n_cols);
  in_eT1* A_rowdata = tmp.memptr();

  for(u32 row_A=0; row_A < A_n_rows; ++row_A)
    {
    for(u32 col_A=0; col_A < A_n_cols; ++col_A)
      {
      A_rowdata[col_A] = A.at(row_A,col_A);
      }

    for(u32 col_B=0; col_B < B_n_cols; ++col_B)
      {
      const in_eT2* B_coldata = B.colptr(col_B);

      out_eT acc = out_eT(0);
      for(u32 i=0; i < B_n_rows; ++i)
        {
        acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
        }

      // the four alpha/beta combinations are resolved at compile time
      if( (use_alpha == false) && (use_beta == false) )
        {
        C.at(row_A,col_B) = acc;
        }
      else
      if( (use_alpha == true) && (use_beta == false) )
        {
        C.at(row_A,col_B) = alpha * acc;
        }
      else
      if( (use_alpha == false) && (use_beta == true) )
        {
        C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
        }
      else
      if( (use_alpha == true) && (use_beta == true) )
        {
        C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
        }
      }
    }
  }
else
if( (do_trans_A == true) && (do_trans_B == false) )
  {
  // columns of A and B are already contiguous, so no caching is needed
  for(u32 col_A=0; col_A < A_n_cols; ++col_A)
    {
    // col_A is interpreted as row_A when storing the results in matrix C
    const in_eT1* A_coldata = A.colptr(col_A);

    for(u32 col_B=0; col_B < B_n_cols; ++col_B)
      {
      const in_eT2* B_coldata = B.colptr(col_B);

      out_eT acc = out_eT(0);
      for(u32 i=0; i < B_n_rows; ++i)
        {
        acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
        }

      if( (use_alpha == false) && (use_beta == false) )
        {
        C.at(col_A,col_B) = acc;
        }
      else
      if( (use_alpha == true) && (use_beta == false) )
        {
        C.at(col_A,col_B) = alpha * acc;
        }
      else
      if( (use_alpha == false) && (use_beta == true) )
        {
        C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
        }
      else
      if( (use_alpha == true) && (use_beta == true) )
        {
        C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
        }
      }
    }
  }
else
if( (do_trans_A == false) && (do_trans_B == true) )
  {
  // explicitly transpose B once, then reuse the no-transpose kernel
  Mat<in_eT2> B_tmp = trans(B);
  gemm_mixed_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
  }
else
if( (do_trans_A == true) && (do_trans_B == true) )
  {
  // by using the trans(A)*trans(B) = trans(B*A) equivalency,
  // transpose operations are not needed

  podarray<in_eT2> tmp(B.n_cols);
  in_eT2* B_rowdata = tmp.memptr();

  for(u32 row_B=0; row_B < B_n_rows; ++row_B)
    {
    // cache the current row of B in a contiguous buffer
    for(u32 col_B=0; col_B < B_n_cols; ++col_B)
      {
      B_rowdata[col_B] = B.at(row_B,col_B);
      }

    for(u32 col_A=0; col_A < A_n_cols; ++col_A)
      {
      const in_eT1* A_coldata = A.colptr(col_A);

      out_eT acc = out_eT(0);
      for(u32 i=0; i < A_n_rows; ++i)
        {
        acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
        }

      if( (use_alpha == false) && (use_beta == false) )
        {
        C.at(col_A,row_B) = acc;
        }
      else
      if( (use_alpha == true) && (use_beta == false) )
        {
        C.at(col_A,row_B) = alpha * acc;
        }
      else
      if( (use_alpha == false) && (use_beta == true) )
        {
        C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
        }
      else
      if( (use_alpha == true) && (use_beta == true) )
        {
        C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
        }
      }
    }
  }
}
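The last branch relies on the identity mentioned in its comment. Briefly:

  (A^{T} B^{T})_{ij} = \sum_{k} (A^{T})_{ik} (B^{T})_{kj} = \sum_{k} A_{ki}\, B_{jk} = (B A)_{ji} = \big( (B A)^{T} \big)_{ij}

so each accumulated dot product of a cached row of B with a column of A is an element of B*A, and writing it to C.at(col_A, row_B) places it at the transposed position, yielding trans(A)*trans(B) without materialising any transposed matrix. The trans(A)*B case (branch two) needs no row caching at all, since columns of both A and B are already stored contiguously.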