op_dot_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 //! \addtogroup op_dot
00018 //! @{
00019 
00020 
00021 
00022 
00023 //! for two arrays, generic version
00024 template<typename eT>
00025 arma_hot
00026 arma_pure
00027 inline
00028 eT
00029 op_dot::direct_dot_arma(const u32 n_elem, const eT* const A, const eT* const B)
00030   {
00031   arma_extra_debug_sigprint();
00032   
00033   eT val1 = eT(0);
00034   eT val2 = eT(0);
00035   
00036   u32 i, j;
00037   
00038   for(i=0, j=1; j<n_elem; i+=2, j+=2)
00039     {
00040     val1 += A[i] * B[i];
00041     val2 += A[j] * B[j];
00042     }
00043   
00044   if(i < n_elem)
00045     {
00046     val1 += A[i] * B[i];
00047     }
00048   
00049   return val1 + val2;
00050   }
00051 
00052 
00053 
00054 //! for two arrays, float and double version
00055 template<typename eT>
00056 arma_hot
00057 arma_pure
00058 inline
00059 typename arma_float_only<eT>::result
00060 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B)
00061   {
00062   arma_extra_debug_sigprint();
00063   
00064   if( n_elem <= (128/sizeof(eT)) )
00065     {
00066     return op_dot::direct_dot_arma(n_elem, A, B);
00067     }
00068   else
00069     {
00070     #if defined(ARMA_USE_ATLAS)
00071       {
00072       return atlas::cblas_dot(n_elem, A, B);
00073       }
00074     #elif defined(ARMA_USE_BLAS)
00075       {
00076       const int n = n_elem;
00077       return blas::dot_(&n, A, B);
00078       }
00079     #else
00080       {
00081       return op_dot::direct_dot_arma(n_elem, A, B);
00082       }
00083     #endif
00084     }
00085   }
00086 
00087 
00088 
00089 //! for two arrays, complex version
00090 template<typename eT>
00091 inline
00092 arma_hot
00093 arma_pure
00094 typename arma_cx_only<eT>::result
00095 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B)
00096   {
00097   #if defined(ARMA_USE_ATLAS)
00098     {
00099     return atlas::cx_cblas_dot(n_elem, A, B);
00100     }
00101   #elif defined(ARMA_USE_BLAS)
00102     {
00103     // TODO: work out the mess with zdotu() and zdotu_sub() in BLAS
00104     return op_dot::direct_dot_arma(n_elem, A, B);
00105     }
00106   #else
00107     {
00108     return op_dot::direct_dot_arma(n_elem, A, B);
00109     }
00110   #endif
00111   }
00112 
00113 
00114 
00115 //! for two arrays, integral version
00116 template<typename eT>
00117 arma_hot
00118 arma_pure
00119 inline
00120 typename arma_integral_only<eT>::result
00121 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B)
00122   {
00123   return op_dot::direct_dot_arma(n_elem, A, B);
00124   }
00125 
00126 
00127 
00128 
00129 //! for three arrays
00130 template<typename eT>
00131 arma_hot
00132 arma_pure
00133 inline
00134 eT
00135 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B, const eT* C)
00136   {
00137   arma_extra_debug_sigprint();
00138   
00139   eT val = eT(0);
00140   
00141   for(u32 i=0; i<n_elem; ++i)
00142     {
00143     val += A[i] * B[i] * C[i];
00144     }
00145 
00146   return val;
00147   }
00148 
00149 
00150 
00151 template<typename T1, typename T2>
00152 arma_hot
00153 arma_inline
00154 typename T1::elem_type
00155 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00156   {
00157   arma_extra_debug_sigprint();
00158   
00159   if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
00160     {
00161     return op_dot::apply_unwrap(X,Y);
00162     }
00163   else
00164     {
00165     return op_dot::apply_proxy(X,Y);
00166     }
00167   }
00168 
00169 
00170 
00171 template<typename T1, typename T2>
00172 arma_hot
00173 arma_inline
00174 typename T1::elem_type
00175 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00176   {
00177   arma_extra_debug_sigprint();
00178   
00179   typedef typename T1::elem_type eT;
00180   
00181   const unwrap<T1> tmp1(X.get_ref());
00182   const unwrap<T2> tmp2(Y.get_ref());
00183   
00184   const Mat<eT>& A = tmp1.M;
00185   const Mat<eT>& B = tmp2.M;
00186   
00187   arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
00188   
00189   return op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00190   }
00191 
00192 
00193 
00194 template<typename T1, typename T2>
00195 arma_hot
00196 inline
00197 typename T1::elem_type
00198 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00199   {
00200   arma_extra_debug_sigprint();
00201   
00202   typedef typename T1::elem_type eT;
00203   
00204   const Proxy<T1> A(X.get_ref());
00205   const Proxy<T2> B(Y.get_ref());
00206   
00207   arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
00208   
00209   const u32 n_elem = A.n_elem;
00210   eT val = eT(0);
00211   
00212   for(u32 i=0; i<n_elem; ++i)
00213     {
00214     val += A[i] * B[i];
00215     }
00216   
00217   return val;
00218   }
00219 
00220 
00221 
00222 //
00223 
00224 
00225 
00226 template<typename T1, typename T2>
00227 arma_hot
00228 arma_inline
00229 typename T1::elem_type
00230 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00231   {
00232   arma_extra_debug_sigprint();
00233   
00234   if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
00235     {
00236     return op_norm_dot::apply_unwrap(X,Y);
00237     }
00238   else
00239     {
00240     return op_norm_dot::apply_proxy(X,Y);
00241     }
00242   }
00243 
00244 
00245 
00246 template<typename T1, typename T2>
00247 arma_hot
00248 inline
00249 typename T1::elem_type
00250 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00251   {
00252   arma_extra_debug_sigprint();
00253   
00254   typedef typename T1::elem_type eT;
00255   
00256   const unwrap<T1> tmp1(X.get_ref());
00257   const unwrap<T2> tmp2(Y.get_ref());
00258   
00259   const Mat<eT>& A = tmp1.M;
00260   const Mat<eT>& B = tmp2.M;
00261 
00262   arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
00263   
00264   const eT* A_mem = A.memptr();
00265   const eT* B_mem = B.memptr();
00266   
00267   const u32 N = A.n_elem;
00268   
00269   eT acc1 = eT(0);
00270   eT acc2 = eT(0);
00271   eT acc3 = eT(0);
00272   
00273   for(u32 i=0; i<N; ++i)
00274     {
00275     const eT tmpA = A_mem[i];
00276     const eT tmpB = B_mem[i];
00277     
00278     acc1 += tmpA * tmpA;
00279     acc2 += tmpB * tmpB;
00280     acc3 += tmpA * tmpB;
00281     }
00282     
00283   return acc3 / ( std::sqrt(acc1 * acc2) );
00284   }
00285 
00286 
00287 
00288 template<typename T1, typename T2>
00289 arma_hot
00290 inline
00291 typename T1::elem_type
00292 op_norm_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00293   {
00294   arma_extra_debug_sigprint();
00295   
00296   typedef typename T1::elem_type eT;
00297   
00298   const Proxy<T1> A(X.get_ref());
00299   const Proxy<T2> B(Y.get_ref());
00300 
00301   arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
00302   
00303   const u32 N = A.n_elem;
00304   
00305   eT acc1 = eT(0);
00306   eT acc2 = eT(0);
00307   eT acc3 = eT(0);
00308   
00309   for(u32 i=0; i<N; ++i)
00310     {
00311     const eT tmpA = A[i];
00312     const eT tmpB = B[i];
00313     
00314     acc1 += tmpA * tmpA;
00315     acc2 += tmpB * tmpB;
00316     acc3 += tmpA * tmpB;
00317     }
00318     
00319   return acc3 / ( std::sqrt(acc1 * acc2) );
00320   }
00321 
00322 
00323 
00324 //! @}