op_dot_meat.hpp
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 template<typename eT>
00024 inline
00025 arma_hot
00026 arma_pure
00027 eT
00028 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B)
00029 {
00030 arma_extra_debug_sigprint();
00031
00032 eT val1 = eT(0);
00033 eT val2 = eT(0);
00034
00035 u32 i,j;
00036 for(i=0, j=1; j<n_elem; i+=2, j+=2)
00037 {
00038 val1 += A[i] * B[i];
00039 val2 += A[j] * B[j];
00040 }
00041
00042 if(i < n_elem)
00043 {
00044 val1 += A[i] * B[i];
00045 }
00046
00047 return val1+val2;
00048 }
00049
00050
00051
00052
00053 template<typename eT>
00054 inline
00055 arma_hot
00056 arma_pure
00057 eT
00058 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B, const eT* C)
00059 {
00060 arma_extra_debug_sigprint();
00061
00062 eT val = eT(0);
00063
00064 for(u32 i=0; i<n_elem; ++i)
00065 {
00066 val += A[i] * B[i] * C[i];
00067 }
00068
00069 return val;
00070 }
00071
00072
00073
00074 template<typename T1, typename T2>
00075 arma_inline
00076 arma_hot
00077 typename T1::elem_type
00078 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00079 {
00080 arma_extra_debug_sigprint();
00081
00082 if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
00083 {
00084 return op_dot::apply_unwrap(X,Y);
00085 }
00086 else
00087 {
00088 return op_dot::apply_proxy(X,Y);
00089 }
00090 }
00091
00092
00093
00094 template<typename T1, typename T2>
00095 arma_inline
00096 arma_hot
00097 typename T1::elem_type
00098 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00099 {
00100 arma_extra_debug_sigprint();
00101
00102 typedef typename T1::elem_type eT;
00103
00104 const unwrap<T1> tmp1(X.get_ref());
00105 const unwrap<T2> tmp2(Y.get_ref());
00106
00107 const Mat<eT>& A = tmp1.M;
00108 const Mat<eT>& B = tmp2.M;
00109
00110 arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
00111
00112 return op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00113 }
00114
00115
00116
00117 template<typename T1, typename T2>
00118 inline
00119 arma_hot
00120 typename T1::elem_type
00121 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00122 {
00123 arma_extra_debug_sigprint();
00124
00125 typedef typename T1::elem_type eT;
00126
00127 const Proxy<T1> A(X.get_ref());
00128 const Proxy<T2> B(Y.get_ref());
00129
00130 arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
00131
00132 const u32 n_elem = A.n_elem;
00133 eT val = eT(0);
00134
00135 for(u32 i=0; i<n_elem; ++i)
00136 {
00137 val += A[i] * B[i];
00138 }
00139
00140 return val;
00141 }
00142
00143
00144
00145
00146
00147
00148
00149 template<typename T1, typename T2>
00150 arma_inline
00151 arma_hot
00152 typename T1::elem_type
00153 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00154 {
00155 arma_extra_debug_sigprint();
00156
00157 if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
00158 {
00159 return op_norm_dot::apply_unwrap(X,Y);
00160 }
00161 else
00162 {
00163 return op_norm_dot::apply_proxy(X,Y);
00164 }
00165 }
00166
00167
00168
00169 template<typename T1, typename T2>
00170 inline
00171 arma_hot
00172 typename T1::elem_type
00173 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00174 {
00175 arma_extra_debug_sigprint();
00176
00177 typedef typename T1::elem_type eT;
00178
00179 const unwrap<T1> tmp1(X.get_ref());
00180 const unwrap<T2> tmp2(Y.get_ref());
00181
00182 const Mat<eT>& A = tmp1.M;
00183 const Mat<eT>& B = tmp2.M;
00184
00185 arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
00186
00187 const eT* A_mem = A.memptr();
00188 const eT* B_mem = B.memptr();
00189
00190 const u32 N = A.n_elem;
00191
00192 eT acc1 = eT(0);
00193 eT acc2 = eT(0);
00194 eT acc3 = eT(0);
00195
00196 for(u32 i=0; i<N; ++i)
00197 {
00198 const eT tmpA = A_mem[i];
00199 const eT tmpB = B_mem[i];
00200
00201 acc1 += tmpA * tmpA;
00202 acc2 += tmpB * tmpB;
00203 acc3 += tmpA * tmpB;
00204 }
00205
00206 return acc3 / ( std::sqrt(acc1 * acc2) );
00207 }
00208
00209
00210
00211 template<typename T1, typename T2>
00212 inline
00213 arma_hot
00214 typename T1::elem_type
00215 op_norm_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00216 {
00217 arma_extra_debug_sigprint();
00218
00219 typedef typename T1::elem_type eT;
00220
00221 const Proxy<T1> A(X.get_ref());
00222 const Proxy<T2> B(Y.get_ref());
00223
00224 arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
00225
00226 const u32 N = A.n_elem;
00227
00228 eT acc1 = eT(0);
00229 eT acc2 = eT(0);
00230 eT acc3 = eT(0);
00231
00232 for(u32 i=0; i<N; ++i)
00233 {
00234 const eT tmpA = A[i];
00235 const eT tmpB = B[i];
00236
00237 acc1 += tmpA * tmpA;
00238 acc2 += tmpB * tmpB;
00239 acc3 += tmpA * tmpB;
00240 }
00241
00242 return acc3 / ( std::sqrt(acc1 * acc2) );
00243 }
00244
00245
00246
00247