00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 template<u32 N>
00023 struct as_scalar_redirect
00024 {
00025 template<typename T1>
00026 inline static typename T1::elem_type apply(const T1& X);
00027 };
00028
00029
00030
00031 template<>
00032 struct as_scalar_redirect<2>
00033 {
00034 template<typename T1, typename T2>
00035 inline static typename T1::elem_type apply(const Glue<T1,T2,glue_times>& X);
00036 };
00037
00038
00039 template<>
00040 struct as_scalar_redirect<3>
00041 {
00042 template<typename T1, typename T2, typename T3>
00043 inline static typename T1::elem_type apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times>& X);
00044 };
00045
00046
00047
00048 template<u32 N>
00049 template<typename T1>
00050 inline
00051 typename T1::elem_type
00052 as_scalar_redirect<N>::apply(const T1& X)
00053 {
00054 arma_extra_debug_sigprint();
00055
00056 typedef typename T1::elem_type eT;
00057
00058 const unwrap<T1> tmp(X);
00059 const Mat<eT>& A = tmp.M;
00060
00061 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00062
00063 return A.mem[0];
00064 }
00065
00066
00067
00068 template<typename T1, typename T2>
00069 inline
00070 typename T1::elem_type
00071 as_scalar_redirect<2>::apply(const Glue<T1, T2, glue_times>& X)
00072 {
00073 arma_extra_debug_sigprint();
00074
00075 typedef typename T1::elem_type eT;
00076
00077
00078
00079
00080 const partial_unwrap<T1> tmp1(X.A);
00081 const partial_unwrap<T2> tmp2(X.B);
00082
00083 const Mat<eT>& A = tmp1.M;
00084 const Mat<eT>& B = tmp2.M;
00085
00086 const u32 A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00087 const u32 A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00088
00089 const u32 B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols;
00090 const u32 B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows;
00091
00092 const eT val = tmp1.val * tmp2.val;
00093
00094 arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" );
00095
00096 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00097 }
00098
00099
00100
00101 template<typename T1, typename T2, typename T3>
00102 inline
00103 typename T1::elem_type
00104 as_scalar_redirect<3>::apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times >& X)
00105 {
00106 arma_extra_debug_sigprint();
00107
00108 typedef typename T1::elem_type eT;
00109
00110
00111
00112
00113 typedef typename strip_inv <T2 >::stored_type T2_stripped_1;
00114 typedef typename strip_diagmat<T2_stripped_1>::stored_type T2_stripped_2;
00115
00116 const strip_inv <T2> strip1(X.A.B);
00117 const strip_diagmat<T2_stripped_1> strip2(strip1.M);
00118
00119 const bool tmp2_do_inv = strip1.do_inv;
00120 const bool tmp2_do_diagmat = strip2.do_diagmat;
00121
00122 const partial_unwrap<T1> tmp1(X.A.A);
00123 const partial_unwrap<T2_stripped_2> tmp2(strip2.M);
00124 const partial_unwrap<T3> tmp3(X.B);
00125
00126 const Mat<eT>& A = tmp1.M;
00127 const Mat<eT>& B = tmp2.M;
00128 const Mat<eT>& C = tmp3.M;
00129
00130
00131 if(tmp2_do_diagmat == false)
00132 {
00133 const u32 A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00134 const u32 A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00135
00136 const u32 B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols;
00137 const u32 B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows;
00138
00139 const u32 C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
00140 const u32 C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
00141
00142 const eT val = tmp1.val * tmp2.val * tmp3.val;
00143
00144 arma_debug_check
00145 (
00146 (A_n_rows != 1) ||
00147 (C_n_cols != 1) ||
00148 (A_n_cols != B_n_rows) ||
00149 (B_n_cols != C_n_rows)
00150 ,
00151 "as_scalar(): incompatible dimensions"
00152 );
00153
00154
00155 if(tmp2_do_inv == true)
00156 {
00157 arma_debug_check( (B.is_square() == false), "as_scalar(): incompatible dimensions" );
00158
00159 Mat<eT> B_inv;
00160
00161 if(tmp2.do_trans == false)
00162 {
00163 op_inv::apply(B_inv, B);
00164 }
00165 else
00166 {
00167 const Mat<eT> B_trans = trans(B);
00168 op_inv::apply(B_inv, B_trans);
00169 }
00170
00171 return val * op_dotext::direct_rowvec_mat_colvec(A.mem, B_inv, C.mem);
00172 }
00173 else
00174 {
00175 if(tmp2.do_trans == false)
00176 {
00177 return val * op_dotext::direct_rowvec_mat_colvec(A.mem, B, C.mem);
00178 }
00179 else
00180 {
00181 return val * op_dotext::direct_rowvec_transmat_colvec(A.mem, B, C.mem);
00182 }
00183 }
00184 }
00185 else
00186 {
00187 const u32 A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00188 const u32 A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00189
00190 const bool B_is_vec = B.is_vec();
00191
00192 const u32 B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
00193 const u32 B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
00194
00195 const u32 C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
00196 const u32 C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
00197
00198 const eT val = tmp1.val * tmp2.val * tmp3.val;
00199
00200 arma_debug_check
00201 (
00202 (A_n_rows != 1) ||
00203 (C_n_cols != 1) ||
00204 (A_n_cols != B_n_rows) ||
00205 (B_n_cols != C_n_rows)
00206 ,
00207 "as_scalar(): incompatible dimensions"
00208 );
00209
00210
00211 if(B_is_vec == true)
00212 {
00213 if(tmp2_do_inv == true)
00214 {
00215 return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem);
00216 }
00217 else
00218 {
00219 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
00220 }
00221 }
00222 else
00223 {
00224 if(tmp2_do_inv == true)
00225 {
00226 return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem);
00227 }
00228 else
00229 {
00230 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
00231 }
00232 }
00233 }
00234 }
00235
00236
00237
00238 template<typename T1>
00239 inline
00240 typename T1::elem_type
00241 as_scalar_diag(const Base<typename T1::elem_type,T1>& X)
00242 {
00243 arma_extra_debug_sigprint();
00244
00245 typedef typename T1::elem_type eT;
00246
00247 const unwrap<T1> tmp(X.get_ref());
00248 const Mat<eT>& A = tmp.M;
00249
00250 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00251
00252 return A.mem[0];
00253 }
00254
00255
00256
00257 template<typename T1, typename T2, typename T3>
00258 inline
00259 typename T1::elem_type
00260 as_scalar_diag(const Glue< Glue<T1, T2, glue_times_diag>, T3, glue_times >& X)
00261 {
00262 arma_extra_debug_sigprint();
00263
00264 typedef typename T1::elem_type eT;
00265
00266
00267
00268
00269 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
00270
00271 const strip_diagmat<T2> strip(X.A.B);
00272
00273 const partial_unwrap<T1> tmp1(X.A.A);
00274 const partial_unwrap<T2_stripped> tmp2(strip.M);
00275 const partial_unwrap<T3> tmp3(X.B);
00276
00277 const Mat<eT>& A = tmp1.M;
00278 const Mat<eT>& B = tmp2.M;
00279 const Mat<eT>& C = tmp3.M;
00280
00281
00282 const u32 A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00283 const u32 A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00284
00285 const bool B_is_vec = B.is_vec();
00286
00287 const u32 B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
00288 const u32 B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
00289
00290 const u32 C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
00291 const u32 C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
00292
00293 const eT val = tmp1.val * tmp2.val * tmp3.val;
00294
00295 arma_debug_check
00296 (
00297 (A_n_rows != 1) ||
00298 (C_n_cols != 1) ||
00299 (A_n_cols != B_n_rows) ||
00300 (B_n_cols != C_n_rows)
00301 ,
00302 "as_scalar(): incompatible dimensions"
00303 );
00304
00305
00306 if(B_is_vec == true)
00307 {
00308 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
00309 }
00310 else
00311 {
00312 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
00313 }
00314 }
00315
00316
00317
00318 template<typename T1, typename T2>
00319 arma_inline
00320 arma_warn_unused
00321 typename T1::elem_type
00322 as_scalar(const Glue<T1, T2, glue_times>& X)
00323 {
00324 arma_extra_debug_sigprint();
00325
00326 if(is_glue_times_diag<T1>::value == false)
00327 {
00328 const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00329
00330 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00331
00332 return as_scalar_redirect<N_mat>::apply(X);
00333 }
00334 else
00335 {
00336 return as_scalar_diag(X);
00337 }
00338 }
00339
00340
00341
00342 template<typename T1>
00343 inline
00344 arma_warn_unused
00345 typename T1::elem_type
00346 as_scalar(const Base<typename T1::elem_type,T1>& X)
00347 {
00348 arma_extra_debug_sigprint();
00349
00350 typedef typename T1::elem_type eT;
00351
00352 const unwrap<T1> tmp(X.get_ref());
00353 const Mat<eT>& A = tmp.M;
00354
00355 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00356
00357 return A.mem[0];
00358 }
00359
00360
00361
00362 template<typename T1>
00363 arma_inline
00364 arma_warn_unused
00365 typename T1::elem_type
00366 as_scalar(const eOp<T1, eop_neg>& X)
00367 {
00368 arma_extra_debug_sigprint();
00369
00370 return -(as_scalar(X.P.Q));
00371 }
00372
00373
00374
00375 template<typename T1>
00376 inline
00377 arma_warn_unused
00378 typename T1::elem_type
00379 as_scalar(const BaseCube<typename T1::elem_type,T1>& X)
00380 {
00381 arma_extra_debug_sigprint();
00382
00383 typedef typename T1::elem_type eT;
00384
00385 const unwrap_cube<T1> tmp(X.get_ref());
00386 const Cube<eT>& A = tmp.M;
00387
00388 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00389
00390 return A.mem[0];
00391 }
00392
00393
00394
00395 template<typename T>
00396 arma_inline
00397 arma_warn_unused
00398 const typename arma_scalar_only<T>::result &
00399 as_scalar(const T& x)
00400 {
00401 return x;
00402 }
00403
00404
00405
00406