fn_as_scalar.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 //! \addtogroup fn_as_scalar
00018 //! @{
00019 
00020 
00021 
00022 template<u32 N>
00023 struct as_scalar_redirect
00024   {
00025   template<typename T1>
00026   inline static typename T1::elem_type apply(const T1& X);
00027   };
00028 
00029 
00030 
00031 template<>
00032 struct as_scalar_redirect<2>
00033   {
00034   template<typename T1, typename T2>
00035   inline static typename T1::elem_type apply(const Glue<T1,T2,glue_times>& X);
00036   };
00037 
00038 
00039 template<>
00040 struct as_scalar_redirect<3>
00041   {
00042   template<typename T1, typename T2, typename T3>
00043   inline static typename T1::elem_type apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times>& X);
00044   };
00045 
00046 
00047 
00048 template<u32 N>
00049 template<typename T1>
00050 inline
00051 typename T1::elem_type
00052 as_scalar_redirect<N>::apply(const T1& X)
00053   {
00054   arma_extra_debug_sigprint();
00055   
00056   typedef typename T1::elem_type eT;
00057   
00058   const unwrap<T1>   tmp(X);
00059   const Mat<eT>& A = tmp.M;
00060   
00061   arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00062   
00063   return A.mem[0];
00064   }
00065 
00066 
00067 
00068 template<typename T1, typename T2>
00069 inline
00070 typename T1::elem_type
00071 as_scalar_redirect<2>::apply(const Glue<T1, T2, glue_times>& X)
00072   {
00073   arma_extra_debug_sigprint();
00074   
00075   typedef typename T1::elem_type eT;
00076   
00077   // T1 must result in a matrix with one row
00078   // T2 must result in a matrix with one column
00079   
00080   const partial_unwrap<T1> tmp1(X.A);
00081   const partial_unwrap<T2> tmp2(X.B);
00082   
00083   const Mat<eT>& A = tmp1.M;
00084   const Mat<eT>& B = tmp2.M;
00085   
00086   const u32 A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00087   const u32 A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00088   
00089   const u32 B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols;
00090   const u32 B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows;
00091   
00092   const eT val = tmp1.val * tmp2.val;
00093   
00094   arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" );
00095   
00096   return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00097   }
00098 
00099 
00100 
00101 template<typename T1, typename T2, typename T3>
00102 inline
00103 typename T1::elem_type
00104 as_scalar_redirect<3>::apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times >& X)
00105   {
00106   arma_extra_debug_sigprint();
00107   
00108   typedef typename T1::elem_type eT;
00109   
00110   // T1 * T2 must result in a matrix with one row
00111   // T3 must result in a matrix with one column
00112   
00113   typedef typename strip_inv    <T2           >::stored_type T2_stripped_1;
00114   typedef typename strip_diagmat<T2_stripped_1>::stored_type T2_stripped_2;
00115   
00116   const strip_inv    <T2>            strip1(X.A.B);
00117   const strip_diagmat<T2_stripped_1> strip2(strip1.M);
00118   
00119   const bool tmp2_do_inv     = strip1.do_inv;
00120   const bool tmp2_do_diagmat = strip2.do_diagmat;
00121   
00122   const partial_unwrap<T1>            tmp1(X.A.A);
00123   const partial_unwrap<T2_stripped_2> tmp2(strip2.M);
00124   const partial_unwrap<T3>            tmp3(X.B);
00125   
00126   const Mat<eT>& A = tmp1.M;
00127   const Mat<eT>& B = tmp2.M;
00128   const Mat<eT>& C = tmp3.M;
00129   
00130   
00131   if(tmp2_do_diagmat == false)
00132     {
00133     const u32 A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00134     const u32 A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00135     
00136     const u32 B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols;
00137     const u32 B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows;
00138     
00139     const u32 C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
00140     const u32 C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
00141     
00142     const eT val = tmp1.val * tmp2.val * tmp3.val;
00143     
00144     arma_debug_check
00145       (
00146       (A_n_rows != 1)        ||
00147       (C_n_cols != 1)        ||
00148       (A_n_cols != B_n_rows) ||
00149       (B_n_cols != C_n_rows)
00150       ,
00151       "as_scalar(): incompatible dimensions"
00152       );
00153     
00154     
00155     if(tmp2_do_inv == true)
00156       {
00157       arma_debug_check( (B.is_square() == false), "as_scalar(): incompatible dimensions" );
00158       
00159       Mat<eT> B_inv;
00160       
00161       if(tmp2.do_trans == false)
00162         {
00163         op_inv::apply(B_inv, B);
00164         }
00165       else
00166         {
00167         const Mat<eT> B_trans = trans(B);
00168         op_inv::apply(B_inv, B_trans);
00169         }
00170       
00171       return val * op_dotext::direct_rowvec_mat_colvec(A.mem, B_inv, C.mem);
00172       }
00173     else
00174       {
00175       if(tmp2.do_trans == false)
00176         {
00177         return val * op_dotext::direct_rowvec_mat_colvec(A.mem, B, C.mem);
00178         }
00179       else
00180         {
00181         return val * op_dotext::direct_rowvec_transmat_colvec(A.mem, B, C.mem);
00182         }
00183       }
00184     }
00185   else
00186     {
00187     const u32 A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00188     const u32 A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00189     
00190     const bool B_is_vec = B.is_vec();
00191     
00192     const u32 B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
00193     const u32 B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
00194     
00195     const u32 C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
00196     const u32 C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
00197     
00198     const eT val = tmp1.val * tmp2.val * tmp3.val;
00199     
00200     arma_debug_check
00201       (
00202       (A_n_rows != 1)        ||
00203       (C_n_cols != 1)        ||
00204       (A_n_cols != B_n_rows) ||
00205       (B_n_cols != C_n_rows)
00206       ,
00207       "as_scalar(): incompatible dimensions"
00208       );
00209     
00210     
00211     if(B_is_vec == true)
00212       {
00213       if(tmp2_do_inv == true)
00214         {
00215         return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem);
00216         }
00217       else
00218         {
00219         return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
00220         }
00221       }
00222     else
00223       {
00224       if(tmp2_do_inv == true)
00225         {
00226         return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem);
00227         }
00228       else
00229         {
00230         return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
00231         }
00232       }
00233     }
00234   }
00235 
00236 
00237 
00238 template<typename T1>
00239 inline
00240 typename T1::elem_type
00241 as_scalar_diag(const Base<typename T1::elem_type,T1>& X)
00242   {
00243   arma_extra_debug_sigprint();
00244   
00245   typedef typename T1::elem_type eT;
00246   
00247   const unwrap<T1>   tmp(X.get_ref());
00248   const Mat<eT>& A = tmp.M;
00249   
00250   arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00251   
00252   return A.mem[0];
00253   }
00254 
00255 
00256 
00257 template<typename T1, typename T2, typename T3>
00258 inline
00259 typename T1::elem_type
00260 as_scalar_diag(const Glue< Glue<T1, T2, glue_times_diag>, T3, glue_times >& X)
00261   {
00262   arma_extra_debug_sigprint();
00263   
00264   typedef typename T1::elem_type eT;
00265   
00266   // T1 * T2 must result in a matrix with one row
00267   // T3 must result in a matrix with one column
00268   
00269   typedef typename strip_diagmat<T2>::stored_type T2_stripped;
00270   
00271   const strip_diagmat<T2> strip(X.A.B);
00272   
00273   const partial_unwrap<T1>          tmp1(X.A.A);
00274   const partial_unwrap<T2_stripped> tmp2(strip.M);
00275   const partial_unwrap<T3>          tmp3(X.B);
00276   
00277   const Mat<eT>& A = tmp1.M;
00278   const Mat<eT>& B = tmp2.M;
00279   const Mat<eT>& C = tmp3.M;
00280   
00281   
00282   const u32 A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00283   const u32 A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00284   
00285   const bool B_is_vec = B.is_vec();
00286   
00287   const u32 B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
00288   const u32 B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
00289   
00290   const u32 C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
00291   const u32 C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
00292   
00293   const eT val = tmp1.val * tmp2.val * tmp3.val;
00294   
00295   arma_debug_check
00296     (
00297     (A_n_rows != 1)        ||
00298     (C_n_cols != 1)        ||
00299     (A_n_cols != B_n_rows) ||
00300     (B_n_cols != C_n_rows)
00301     ,
00302     "as_scalar(): incompatible dimensions"
00303     );
00304   
00305   
00306   if(B_is_vec == true)
00307     {
00308     return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
00309     }
00310   else
00311     {
00312     return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
00313     }
00314   }
00315 
00316 
00317 
00318 template<typename T1, typename T2>
00319 arma_inline
00320 arma_warn_unused
00321 typename T1::elem_type
00322 as_scalar(const Glue<T1, T2, glue_times>& X)
00323   {
00324   arma_extra_debug_sigprint();
00325   
00326   if(is_glue_times_diag<T1>::value == false)
00327     {
00328     const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00329     
00330     arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00331     
00332     return as_scalar_redirect<N_mat>::apply(X);
00333     }
00334   else
00335     {
00336     return as_scalar_diag(X);
00337     }
00338   }
00339 
00340 
00341 
00342 template<typename T1>
00343 inline
00344 arma_warn_unused
00345 typename T1::elem_type
00346 as_scalar(const Base<typename T1::elem_type,T1>& X)
00347   {
00348   arma_extra_debug_sigprint();
00349   
00350   typedef typename T1::elem_type eT;
00351   
00352   const unwrap<T1>   tmp(X.get_ref());
00353   const Mat<eT>& A = tmp.M;
00354   
00355   arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00356   
00357   return A.mem[0];
00358   }
00359 
00360 
00361 
00362 template<typename T1>
00363 arma_inline
00364 arma_warn_unused
00365 typename T1::elem_type
00366 as_scalar(const eOp<T1, eop_neg>& X)
00367   {
00368   arma_extra_debug_sigprint();
00369   
00370   return -(as_scalar(X.P.Q));
00371   }
00372 
00373 
00374 
00375 template<typename T1>
00376 inline
00377 arma_warn_unused
00378 typename T1::elem_type
00379 as_scalar(const BaseCube<typename T1::elem_type,T1>& X)
00380   {
00381   arma_extra_debug_sigprint();
00382   
00383   typedef typename T1::elem_type eT;
00384   
00385   const unwrap_cube<T1> tmp(X.get_ref());
00386   const Cube<eT>& A   = tmp.M;
00387   
00388   arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00389   
00390   return A.mem[0];
00391   }
00392 
00393 
00394 
00395 template<typename T>
00396 arma_inline
00397 arma_warn_unused
00398 const typename arma_scalar_only<T>::result &
00399 as_scalar(const T& x)
00400   {
00401   return x;
00402   }
00403 
00404 
00405 
00406 //! @}