Partial emulation of ATLAS/BLAS gemm(), using caching for speedup. Matrix 'C' is assumed to have been set to the correct size beforehand (i.e., sized to account for any requested transposes of 'A' and 'B'). More...
#include <gemm.hpp>
Static Public Member Functions | |
template<typename eT > | |
static void | apply (Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0)) |
Partial emulation of ATLAS/BLAS gemm(), using caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes).
Definition at line 26 of file gemm.hpp.
static void gemm_emul_cache< do_trans_A, do_trans_B, use_alpha, use_beta >::apply | ( | Mat< eT > & | C, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const eT | alpha = eT(1) , |
|||
const eT | beta = eT(0) | |||
) | [inline, static] |
Definition at line 35 of file gemm.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().
00042 { 00043 arma_extra_debug_sigprint(); 00044 00045 const u32 A_n_rows = A.n_rows; 00046 const u32 A_n_cols = A.n_cols; 00047 00048 const u32 B_n_rows = B.n_rows; 00049 const u32 B_n_cols = B.n_cols; 00050 00051 if( (do_trans_A == false) && (do_trans_B == false) ) 00052 { 00053 arma_aligned podarray<eT> tmp(A_n_cols); 00054 eT* A_rowdata = tmp.memptr(); 00055 00056 for(u32 row_A=0; row_A < A_n_rows; ++row_A) 00057 { 00058 00059 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00060 { 00061 A_rowdata[col_A] = A.at(row_A,col_A); 00062 } 00063 00064 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00065 { 00066 const eT* B_coldata = B.colptr(col_B); 00067 00068 eT acc = eT(0); 00069 for(u32 i=0; i < B_n_rows; ++i) 00070 { 00071 acc += A_rowdata[i] * B_coldata[i]; 00072 } 00073 00074 if( (use_alpha == false) && (use_beta == false) ) 00075 { 00076 C.at(row_A,col_B) = acc; 00077 } 00078 else 00079 if( (use_alpha == true) && (use_beta == false) ) 00080 { 00081 C.at(row_A,col_B) = alpha * acc; 00082 } 00083 else 00084 if( (use_alpha == false) && (use_beta == true) ) 00085 { 00086 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00087 } 00088 else 00089 if( (use_alpha == true) && (use_beta == true) ) 00090 { 00091 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00092 } 00093 00094 } 00095 } 00096 } 00097 else 00098 if( (do_trans_A == true) && (do_trans_B == false) ) 00099 { 00100 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00101 { 00102 // col_A is interpreted as row_A when storing the results in matrix C 00103 00104 const eT* A_coldata = A.colptr(col_A); 00105 00106 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00107 { 00108 const eT* B_coldata = B.colptr(col_B); 00109 00110 eT acc = eT(0); 00111 for(u32 i=0; i < B_n_rows; ++i) 00112 { 00113 acc += A_coldata[i] * B_coldata[i]; 00114 } 00115 00116 if( (use_alpha == false) && (use_beta == false) ) 00117 { 00118 C.at(col_A,col_B) = acc; 00119 } 00120 else 00121 if( (use_alpha == true) && (use_beta == false) ) 00122 { 00123 
C.at(col_A,col_B) = alpha * acc; 00124 } 00125 else 00126 if( (use_alpha == false) && (use_beta == true) ) 00127 { 00128 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00129 } 00130 else 00131 if( (use_alpha == true) && (use_beta == true) ) 00132 { 00133 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00134 } 00135 00136 } 00137 } 00138 } 00139 else 00140 if( (do_trans_A == false) && (do_trans_B == true) ) 00141 { 00142 Mat<eT> B_tmp = trans(B); 00143 gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00144 } 00145 else 00146 if( (do_trans_A == true) && (do_trans_B == true) ) 00147 { 00148 // mat B_tmp = trans(B); 00149 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00150 00151 00152 // By using the trans(A)*trans(B) = trans(B*A) equivalency, 00153 // transpose operations are not needed 00154 00155 arma_aligned podarray<eT> tmp(B.n_cols); 00156 eT* B_rowdata = tmp.memptr(); 00157 00158 for(u32 row_B=0; row_B < B_n_rows; ++row_B) 00159 { 00160 00161 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00162 { 00163 B_rowdata[col_B] = B.at(row_B,col_B); 00164 } 00165 00166 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00167 { 00168 const eT* A_coldata = A.colptr(col_A); 00169 00170 eT acc = eT(0); 00171 for(u32 i=0; i < A_n_rows; ++i) 00172 { 00173 acc += B_rowdata[i] * A_coldata[i]; 00174 } 00175 00176 if( (use_alpha == false) && (use_beta == false) ) 00177 { 00178 C.at(col_A,row_B) = acc; 00179 } 00180 else 00181 if( (use_alpha == true) && (use_beta == false) ) 00182 { 00183 C.at(col_A,row_B) = alpha * acc; 00184 } 00185 else 00186 if( (use_alpha == false) && (use_beta == true) ) 00187 { 00188 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00189 } 00190 else 00191 if( (use_alpha == true) && (use_beta == true) ) 00192 { 00193 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00194 } 00195 00196 } 00197 } 00198 00199 } 00200 }