00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 template<typename T1, typename T2>
00022 void
00023 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00024 {
00025 arma_extra_debug_sigprint();
00026
00027 typedef typename T1::elem_type eT;
00028
00029 const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00030
00031 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00032
00033 if(N_mat == 2)
00034 {
00035 const unwrap<T1> tmp1(X.A);
00036 const unwrap<T2> tmp2(X.B);
00037
00038 glue_times::apply(out, tmp1.M, tmp2.M);
00039 }
00040 else
00041 {
00042
00043
00044 const Mat<eT>* ptrs[N_mat];
00045 bool del[N_mat];
00046
00047
00048 mat_ptrs_outcheck<glue_times, Glue<T1,T2,glue_times> >::get_ptrs(ptrs, del, X, &out);
00049
00050 for(s32 i=0; i<N_mat; ++i) arma_extra_debug_print( arma_boost::format("ptrs[%d] = %x") % i % ptrs[i] );
00051 for(s32 i=0; i<N_mat; ++i) arma_extra_debug_print( arma_boost::format(" del[%d] = %d") % i % del[i] );
00052
00053
00054 arma_extra_debug_print( arma_boost::format("required size of 'out': %d, %d") % ptrs[0]->n_rows % ptrs[N_mat-1]->n_cols );
00055
00056 int order[N_mat]; for(s32 i=0; i<N_mat; ++i) order[i] = -1;
00057
00058 int first_id = 0;
00059 int last_id = N_mat-1;
00060 int starting_id = -1;
00061
00062 int mat_count = N_mat;
00063
00064 int largest_size = 0;
00065
00066 while(mat_count != 0)
00067 {
00068
00069 for(s32 i=first_id; i != N_mat; ++i)
00070 {
00071 if(order[i] == -1) { first_id = i; break; }
00072 }
00073
00074 for(s32 i=last_id; i != -1; --i)
00075 {
00076 if(order[i] == -1) { last_id = i; break; }
00077 }
00078
00079 arma_extra_debug_print();
00080 arma_extra_debug_print(arma_boost::format("mat_count = %d") % mat_count );
00081 arma_extra_debug_print(arma_boost::format("first_id = %d") % first_id );
00082 arma_extra_debug_print(arma_boost::format("last_id = %d") % last_id );
00083
00084 if(first_id == last_id) { order[first_id] = 0; starting_id = first_id; break; }
00085
00086 s32 storage_cost_wo_last = mul_storage_cost( *ptrs[ first_id ], *ptrs[ last_id-1 ] );
00087 s32 storage_cost_wo_first = mul_storage_cost( *ptrs[ first_id+1 ], *ptrs[ last_id ] );
00088
00089 if(storage_cost_wo_last < storage_cost_wo_first)
00090 {
00091 order[last_id] = mat_count-1;
00092 if(storage_cost_wo_last > largest_size) largest_size = storage_cost_wo_last;
00093 }
00094 else
00095 {
00096 order[first_id] = mat_count-1;
00097 if(storage_cost_wo_first > largest_size) largest_size = storage_cost_wo_first;
00098 }
00099
00100 arma_extra_debug_print(arma_boost::format("storage_cost_wo_last = %d") % storage_cost_wo_last );
00101 arma_extra_debug_print(arma_boost::format("storage_cost_wo_first = %d") % storage_cost_wo_first );
00102
00103 arma_extra_debug_print("order = ");
00104 for(s32 i=0; i != N_mat; ++i) arma_extra_debug_print(order[i]);
00105
00106 --mat_count;
00107 }
00108
00109 arma_extra_debug_print("final order = ");
00110 for(s32 i=0; i != N_mat; ++i) arma_extra_debug_print(order[i]);
00111
00112 arma_extra_debug_print(arma_boost::format("*** largest_size = %d") % largest_size);
00113 arma_extra_debug_print(arma_boost::format("starting_id = %d") % starting_id);
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131 const u32 N_mul = N_mat - 1;
00132 int mul_count = N_mul;
00133 int current_id = starting_id;
00134
00135 const Mat<eT>* src_mat_1_ptr = ptrs[current_id];
00136 const Mat<eT>* src_mat_2_ptr = 0;
00137
00138
00139
00140
00141
00142 Mat<eT> tmp_mat_1;
00143 Mat<eT> tmp_mat_2;
00144
00145 Mat<eT>* tmp_mat_1_ptr = &tmp_mat_1;
00146 Mat<eT>* tmp_mat_2_ptr = (N_mul <= 2) ? 0 : &tmp_mat_2;
00147
00148 Mat<eT>* dest_mat_ptr = tmp_mat_2_ptr;
00149
00150 arma_extra_debug_print(arma_boost::format("tmp_mat_1_ptr = %x") % tmp_mat_1_ptr );
00151 arma_extra_debug_print(arma_boost::format("tmp_mat_2_ptr = %x") % tmp_mat_2_ptr );
00152 arma_extra_debug_print(arma_boost::format("&out = %x") % &out );
00153
00154 while(mul_count != 0)
00155 {
00156 arma_extra_debug_print("");
00157 arma_extra_debug_print("");
00158 arma_extra_debug_print(arma_boost::format("mul_count = %d") % mul_count);
00159
00160 arma_extra_debug_print("order = ");
00161 for(s32 i=0; i != N_mat; ++i) arma_extra_debug_print(order[i]);
00162 arma_extra_debug_print("");
00163
00164
00165 if(mul_count == 1)
00166 {
00167 arma_extra_debug_print("dest_mat = &out");
00168 dest_mat_ptr = &out;
00169 }
00170 else
00171 {
00172 if(dest_mat_ptr == tmp_mat_2_ptr)
00173 {
00174 arma_extra_debug_print("dest_mat_ptr = tmp_mat_2_ptr");
00175 dest_mat_ptr = tmp_mat_1_ptr;
00176 }
00177 else
00178 {
00179 arma_extra_debug_print("dest_mat_ptr = tmp_mat_1_ptr");
00180 dest_mat_ptr = tmp_mat_2_ptr;
00181 }
00182 }
00183
00184 arma_extra_debug_print(arma_boost::format("dest_mat_ptr = %x") % dest_mat_ptr );
00185
00186
00187 s32 left_val = N_mat;
00188 s32 left_id = -1;
00189
00190 s32 right_val = N_mat;
00191 s32 right_id = -1;
00192
00193
00194 for(s32 i=current_id-1; i >= 0; --i)
00195 if( order[i] > order[current_id] ) { left_val = order[i]; left_id = i; break; }
00196
00197
00198 for(s32 i=current_id+1; i < N_mat; ++i)
00199 if( order[current_id] < order[i] ) { right_val = order[i]; right_id = i; break; }
00200
00201 arma_extra_debug_print("");
00202 arma_extra_debug_print(arma_boost::format("left_id = %d") % left_id );
00203 arma_extra_debug_print(arma_boost::format("left_val = %f") % left_val );
00204
00205 arma_extra_debug_print("");
00206 arma_extra_debug_print(arma_boost::format("right_id = %d") % right_id );
00207 arma_extra_debug_print(arma_boost::format("right_val = %f") % right_val );
00208
00209
00210 if(left_val < right_val)
00211 {
00212
00213 src_mat_2_ptr = ptrs[left_id];
00214
00215 arma_extra_debug_print("");
00216 arma_extra_debug_print(arma_boost::format("case pre-multiply with matrix %d") % left_id);
00217 arma_extra_debug_print(arma_boost::format("required destination size: %d, %d (%d)") % src_mat_2_ptr->n_rows % src_mat_1_ptr->n_cols % (src_mat_2_ptr->n_rows * src_mat_1_ptr->n_cols) );
00218
00219 glue_times::apply_noalias(*dest_mat_ptr, *src_mat_2_ptr, *src_mat_1_ptr);
00220
00221 order[current_id] = -1;
00222 current_id = left_id;
00223 }
00224 else
00225 {
00226
00227 src_mat_2_ptr = ptrs[right_id];
00228
00229 arma_extra_debug_print("");
00230 arma_extra_debug_print(arma_boost::format("case post-multiply with matrix %d") % right_id);
00231 arma_extra_debug_print(arma_boost::format("required destination size: %d, %d (%d)") % src_mat_1_ptr->n_rows % src_mat_2_ptr->n_cols % (src_mat_1_ptr->n_rows * src_mat_2_ptr->n_cols) );
00232
00233 glue_times::apply_noalias(*dest_mat_ptr, *src_mat_1_ptr, *src_mat_2_ptr);
00234
00235 order[current_id] = -1;
00236 current_id = right_id;
00237 }
00238
00239
00240 src_mat_1_ptr = dest_mat_ptr;
00241
00242 --mul_count;
00243 }
00244
00245
00246 for(s32 i=0; i<N_mat; ++i)
00247 {
00248 if(del[i] == true)
00249 {
00250 arma_extra_debug_print(arma_boost::format("delete mat_ptr[%d]") % i );
00251 delete ptrs[i];
00252 }
00253 }
00254 }
00255 }
00256
00257
00258
00259 template<typename T1>
00260 inline
00261 void
00262 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
00263 {
00264 arma_extra_debug_sigprint();
00265
00266 typedef typename T1::elem_type eT;
00267
00268 const unwrap<T1> tmp(X);
00269 const Mat<eT>& B = tmp.M;
00270
00271 arma_debug_assert_mul_size(out, B, "matrix multiply");
00272
00273 if(out.n_cols == B.n_cols)
00274 {
00275 podarray<eT> tmp(out.n_cols);
00276 eT* tmp_rowdata = tmp.memptr();
00277
00278 for(u32 out_row=0; out_row < out.n_rows; ++out_row)
00279 {
00280 for(u32 out_col=0; out_col < out.n_cols; ++out_col)
00281 {
00282 tmp_rowdata[out_col] = out.at(out_row,out_col);
00283 }
00284
00285 for(u32 B_col=0; B_col < B.n_cols; ++B_col)
00286 {
00287 const eT* B_coldata = B.colptr(B_col);
00288
00289 eT val = eT(0);
00290 for(u32 i=0; i < B.n_rows; ++i)
00291 {
00292 val += tmp_rowdata[i] * B_coldata[i];
00293 }
00294
00295 out.at(out_row,B_col) = val;
00296 }
00297 }
00298
00299 }
00300 else
00301 {
00302 const Mat<eT> tmp(out);
00303 glue_times::apply(out, tmp, B);
00304 }
00305
00306 }
00307
00308
00309
00310
00311 template<typename eT1, typename eT2>
00312 inline
00313 void
00314 glue_times::apply_mixed(Mat<typename promote_type<eT1,eT2>::result>& out, const Mat<eT1>& X, const Mat<eT2>& Y)
00315 {
00316 arma_extra_debug_sigprint();
00317
00318 typedef typename promote_type<eT1,eT2>::result out_eT;
00319
00320 arma_debug_assert_mul_size(X,Y, "matrix multiply");
00321
00322 out.set_size(X.n_rows,Y.n_cols);
00323 gemm_mixed<>::apply(out, X, Y);
00324 }
00325
00326
00327
00328 template<typename eT>
00329 arma_inline
00330 u32 glue_times::mul_storage_cost(const Mat<eT>& X, const Mat<eT>& Y)
00331 {
00332 return X.n_rows * Y.n_cols;
00333 }
00334
00335
00336
00337
00338
00339 template<typename eT>
00340 inline
00341 void
00342 glue_times::apply_noalias(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B)
00343 {
00344 arma_extra_debug_sigprint();
00345
00346 arma_debug_assert_mul_size(A, B, "matrix multiply");
00347
00348 out.set_size(A.n_rows,B.n_cols);
00349 gemm<>::apply(out,A,B);
00350 }
00351
00352
00353
00354 template<typename eT>
00355 inline
00356 void
00357 glue_times::apply(Mat<eT>& out, const Mat<eT>& A_in, const Mat<eT>& B_in)
00358 {
00359 arma_extra_debug_sigprint();
00360
00361 if( (&out != &A_in) && (&out != &B_in) )
00362 {
00363 glue_times::apply_noalias(out,A_in,B_in);
00364 }
00365 else
00366 {
00367
00368 if( (&out == &A_in) && (&out != &B_in) )
00369 {
00370 Mat<eT> A_copy(A_in);
00371 glue_times::apply_noalias(out,A_copy,B_in);
00372 }
00373 else
00374 if( (&out != &A_in) && (&out == &B_in) )
00375 {
00376 Mat<eT> B_copy(B_in);
00377 glue_times::apply_noalias(out,A_in,B_copy);
00378 }
00379 else
00380 if( (&out == &A_in) && (&out == &B_in) )
00381 {
00382 Mat<eT> tmp(A_in);
00383 glue_times::apply_noalias(out,tmp,tmp);
00384 }
00385
00386 }
00387
00388 }
00389
00390
00391 template<typename eT>
00392 inline
00393 void
00394 glue_times::apply(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
00395 {
00396 arma_extra_debug_sigprint();
00397
00398 arma_debug_assert_mul_size(A, B, "matrix multiply");
00399 arma_debug_assert_mul_size(B, C, "matrix multiply");
00400
00401 if( mul_storage_cost(A,B) <= mul_storage_cost(B,C) )
00402 {
00403 Mat<eT> tmp;
00404 glue_times::apply_noalias(tmp, A, B);
00405
00406 if(&out != &C)
00407 {
00408 glue_times::apply_noalias(out, tmp, C);
00409 }
00410 else
00411 {
00412 Mat<eT> C_copy = C;
00413 glue_times::apply_noalias(out, tmp, C_copy);
00414 }
00415
00416 }
00417 else
00418 {
00419 Mat<eT> tmp;
00420 glue_times::apply_noalias(tmp, B, C);
00421
00422 if(&out != &A)
00423 {
00424 glue_times::apply_noalias(out, A, tmp);
00425 }
00426 else
00427 {
00428 Mat<eT> A_copy = A;
00429 glue_times::apply_noalias(out, A_copy, tmp);
00430 }
00431 }
00432
00433 }
00434
00435
00436
00437 template<typename eT>
00438 inline
00439 eT
00440 glue_times::direct_rowvec_mat_colvec
00441 (
00442 const eT* A_mem,
00443 const Mat<eT>& B,
00444 const eT* C_mem
00445 )
00446 {
00447 arma_extra_debug_sigprint();
00448
00449 const u32 cost_AB = B.n_cols;
00450 const u32 cost_BC = B.n_rows;
00451
00452 if(cost_AB <= cost_BC)
00453 {
00454 podarray<eT> tmp(B.n_cols);
00455
00456 for(u32 col=0; col<B.n_cols; ++col)
00457 {
00458 const eT* B_coldata = B.colptr(col);
00459
00460 eT val = eT(0);
00461 for(u32 i=0; i<B.n_rows; ++i)
00462 {
00463 val += A_mem[i] * B_coldata[i];
00464 }
00465
00466 tmp[col] = val;
00467 }
00468
00469 return op_dot::direct_dot(B.n_cols, tmp.mem, C_mem);
00470 }
00471 else
00472 {
00473 podarray<eT> tmp(B.n_rows);
00474
00475 for(u32 row=0; row<B.n_rows; ++row)
00476 {
00477 eT val = eT(0);
00478 for(u32 col=0; col<B.n_cols; ++col)
00479 {
00480 val += B.at(row,col) * C_mem[col];
00481 }
00482
00483 tmp[row] = val;
00484 }
00485
00486 return op_dot::direct_dot(B.n_rows, A_mem, tmp.mem);
00487 }
00488
00489
00490 }
00491
00492
00493
00494 template<typename eT>
00495 inline
00496 eT
00497 glue_times::direct_rowvec_diagmat_colvec
00498 (
00499 const eT* A_mem,
00500 const Mat<eT>& B,
00501 const eT* C_mem
00502 )
00503 {
00504 arma_extra_debug_sigprint();
00505
00506 eT val = eT(0);
00507
00508 for(u32 i=0; i<B.n_rows; ++i)
00509 {
00510 val += A_mem[i] * B.at(i,i) * C_mem[i];
00511 }
00512
00513 return val;
00514 }
00515
00516
00517
00518 template<typename eT>
00519 inline
00520 eT
00521 glue_times::direct_rowvec_invdiagmat_colvec
00522 (
00523 const eT* A_mem,
00524 const Mat<eT>& B,
00525 const eT* C_mem
00526 )
00527 {
00528 arma_extra_debug_sigprint();
00529
00530 eT val = eT(0);
00531
00532 for(u32 i=0; i<B.n_rows; ++i)
00533 {
00534 val += (A_mem[i] * C_mem[i]) / B.at(i,i);
00535 }
00536
00537 return val;
00538 }
00539
00540
00541
00542 template<typename eT>
00543 inline
00544 eT
00545 glue_times::direct_rowvec_invdiagvec_colvec
00546 (
00547 const eT* A_mem,
00548 const Mat<eT>& B,
00549 const eT* C_mem
00550 )
00551 {
00552 arma_extra_debug_sigprint();
00553
00554 const eT* B_mem = B.mem;
00555
00556 eT val = eT(0);
00557
00558 for(u32 i=0; i<B.n_elem; ++i)
00559 {
00560 val += (A_mem[i] * C_mem[i]) / B_mem[i];
00561 }
00562
00563 return val;
00564 }
00565
00566
00567
00568 #if defined(ARMA_GOOD_COMPILER)
00569
00570
00571 template<typename eT>
00572 inline
00573 void
00574 glue_times::apply(Mat<eT>& out, const Glue<Mat<eT>,Mat<eT>,glue_times>& X)
00575 {
00576 glue_times::apply(out, X.A, X.B);
00577 }
00578
00579
00580
00581 template<typename eT>
00582 inline
00583 void
00584 glue_times::apply(Mat<eT>& out, const Glue< Glue<Mat<eT>,Mat<eT>, glue_times>, Mat<eT>, glue_times>& X)
00585 {
00586 glue_times::apply(out, X.A.A, X.A.B, X.B);
00587 }
00588
00589
00590
00591
00592 template<typename T1, typename T2>
00593 inline
00594 void
00595 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Op<T2,op_trans>, glue_times>& X)
00596 {
00597 arma_extra_debug_sigprint();
00598
00599
00600 typedef typename T1::elem_type eT;
00601
00602
00603
00604 const unwrap<T1> tmp1(X.A);
00605 const unwrap<T2> tmp2(X.B.m);
00606
00607 const Mat<eT>& A = tmp1.M;
00608 const Mat<eT>& B = tmp2.M;
00609
00610 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_cols, B.n_rows, "matrix multiply");
00611
00612 if( (A.n_rows*B.n_rows) > 0)
00613 {
00614 if(&A != &B)
00615 {
00616 unwrap_check< Mat<eT> > A_safe_tmp(A, out);
00617 unwrap_check< Mat<eT> > B_safe_tmp(B, out);
00618
00619 const Mat<eT>& A_safe = A_safe_tmp.M;
00620 const Mat<eT>& B_safe = B_safe_tmp.M;
00621
00622 out.set_size(A_safe.n_rows, B_safe.n_rows);
00623
00624 gemm<false,true>::apply(out, A, B);
00625 }
00626 else
00627 {
00628 arma_extra_debug_print("glue_times::apply(): detected A*A'");
00629
00630 Mat<eT> tmp;
00631 op_trans::apply(tmp,A);
00632
00633
00634 out.set_size(A.n_rows, A.n_rows);
00635
00636 for(u32 row=0; row != A.n_rows; ++row)
00637 {
00638 for(u32 col=0; col <= row; ++col)
00639 {
00640 const eT* coldata1 = tmp.colptr(row);
00641 const eT* coldata2 = tmp.colptr(col);
00642
00643 eT val = eT(0);
00644 for(u32 i=0; i < tmp.n_rows; ++i)
00645 {
00646 val += coldata1[i] * coldata2[i];
00647 }
00648
00649 out.at(row,col) = val;
00650 out.at(col,row) = val;
00651 }
00652 }
00653
00654 }
00655
00656 }
00657
00658 }
00659
00660
00661
00662
00663 template<typename T1, typename T2>
00664 inline
00665 void
00666 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1,op_trans>, T2, glue_times>& X)
00667 {
00668 arma_extra_debug_sigprint();
00669
00670 typedef typename T1::elem_type eT;
00671
00672 const unwrap_check<T1> tmp1(X.A.m, out);
00673 const unwrap_check<T2> tmp2(X.B, out);
00674
00675 const Mat<eT>& A = tmp1.M;
00676 const Mat<eT>& B = tmp2.M;
00677
00678 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiply");
00679
00680 if( (A.n_cols*B.n_cols) > 0 )
00681 {
00682 out.set_size(A.n_cols, B.n_cols);
00683
00684 gemm<true,false>::apply(out, A, B);
00685 }
00686
00687 }
00688
00689
00690
00691
00692 template<typename T1, typename T2>
00693 inline
00694 void
00695 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1,op_trans>, Op<T2,op_trans>, glue_times>& X)
00696 {
00697 arma_extra_debug_sigprint();
00698
00699 typedef typename T1::elem_type eT;
00700
00701 const unwrap_check<T1> tmp1(X.A.m, out);
00702 const unwrap_check<T2> tmp2(X.B.m, out);
00703
00704 const Mat<eT>& A = tmp1.M;
00705 const Mat<eT>& B = tmp2.M;
00706
00707 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_cols, B.n_rows, "matrix multiply");
00708
00709 if( (A.n_cols*B.n_rows) > 0 )
00710 {
00711 out.set_size(A.n_cols, B.n_rows);
00712
00713 gemm<true,true>::apply(out, A, B);
00714
00715 }
00716
00717 }
00718
00719
00720
00721
00722
00723 template<typename T1, typename T2>
00724 inline
00725 void
00726 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1, op_neg>, T2, glue_times>& X)
00727 {
00728 arma_extra_debug_sigprint();
00729
00730 typedef typename T1::elem_type eT;
00731
00732 const unwrap_check<T1> tmp1(X.A.m, out);
00733 const unwrap_check<T2> tmp2(X.B, out);
00734
00735 const Mat<eT>& A = tmp1.M;
00736 const Mat<eT>& B = tmp2.M;
00737
00738 glue_times::apply(out, A, B);
00739
00740 const u32 n_elem = out.n_elem;
00741 for(u32 i=0; i<n_elem; ++i)
00742 {
00743 out[i] = -out[i];
00744 }
00745 }
00746
00747
00748
00749 template<typename T1, typename T2>
00750 inline
00751 void
00752 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X)
00753 {
00754 arma_extra_debug_sigprint();
00755
00756 out = out * X;
00757 }
00758
00759
00760
00761 #endif
00762
00763
00764
00765
00766
00767
00768
00769 template<typename T1, typename T2>
00770 inline
00771 void
00772 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const T1& A_orig, const Op<T2,op_diagmat>& B_orig)
00773 {
00774 arma_extra_debug_sigprint();
00775
00776 isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00777
00778 const unwrap_check<T1> tmp1(A_orig, out);
00779 const unwrap_check<T2> tmp2(B_orig.m, out);
00780
00781 typedef typename T1::elem_type eT;
00782
00783 const Mat<eT>& A = tmp1.M;
00784 const Mat<eT>& B = tmp2.M;
00785
00786 arma_debug_check( (B.is_square() == false), "glue_times_diag::apply(): incompatible matrix dimensions" );
00787 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00788
00789 out.set_size(A.n_rows, B.n_cols);
00790
00791 for(u32 col=0; col<A.n_cols; ++col)
00792 {
00793 const eT val = B.at(col,col);
00794
00795 const eT* A_coldata = A.colptr(col);
00796 eT* out_coldata = out.colptr(col);
00797
00798 for(u32 row=0; row<B.n_rows; ++row)
00799 {
00800 out_coldata[row] = A_coldata[row] * val;
00801 }
00802
00803 }
00804
00805 }
00806
00807
00808
00809 template<typename T1, typename T2>
00810 inline
00811 void
00812 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_diagmat>& A_orig, const T2& B_orig)
00813 {
00814 arma_extra_debug_sigprint();
00815
00816 isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00817
00818 const unwrap_check<T1> tmp1(A_orig.m, out);
00819 const unwrap_check<T2> tmp2(B_orig, out);
00820
00821 typedef typename T1::elem_type eT;
00822
00823 const Mat<eT>& A = tmp1.M;
00824 const Mat<eT>& B = tmp2.M;
00825
00826 arma_debug_check( (A.is_square() == false), "glue_times_diag::apply(): incompatible matrix dimensions" );
00827 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00828
00829 out.set_size(A.n_rows, B.n_cols);
00830
00831
00832 for(u32 col=0; col<A.n_cols; ++col)
00833 {
00834 const eT* B_coldata = B.colptr(col);
00835 eT* out_coldata = out.colptr(col);
00836
00837 for(u32 row=0; row<B.n_rows; ++row)
00838 {
00839 out_coldata[row] = A.at(row,row) * B_coldata[row];
00840 }
00841
00842 }
00843
00844 }
00845
00846
00847
00848 template<typename T1, typename T2>
00849 inline
00850 void
00851 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_diagmat>& A_orig, const Op<T2,op_diagmat>& B_orig)
00852 {
00853 arma_extra_debug_sigprint();
00854
00855 isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00856
00857 unwrap_check<T1> tmp1(A_orig.m, out);
00858 unwrap_check<T2> tmp2(B_orig.m, out);
00859
00860 typedef typename T1::elem_type eT;
00861
00862 const Mat<eT>& A = tmp1.M;
00863 const Mat<eT>& B = tmp2.M;
00864
00865 arma_debug_check( !A.is_square() || !B.is_square(), "glue_times_diag::apply(): incompatible matrix dimensions" );
00866 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00867
00868 out.zeros(A.n_rows, B.n_cols);
00869
00870 for(u32 i=0; i<A.n_rows; ++i)
00871 {
00872 out.at(i,i) = A.at(i,i) * B.at(i,i);
00873 }
00874 }
00875
00876
00877
00878 template<typename T1, typename T2>
00879 inline
00880 void
00881 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Op<T2,op_diagmat>, glue_times_diag>& X)
00882 {
00883 glue_times_diag::apply(out, X.A, X.B);
00884 }
00885
00886
00887
00888 template<typename T1, typename T2>
00889 inline
00890 void
00891 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1,op_diagmat>, T2, glue_times_diag>& X)
00892 {
00893 glue_times_diag::apply(out, X.A, X.B);
00894 }
00895
00896
00897
00898 template<typename T1, typename T2>
00899 inline
00900 void
00901 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1,op_diagmat>, Op<T2,op_diagmat>, glue_times_diag>& X)
00902 {
00903 glue_times_diag::apply(out, X.A, X.B);
00904 }
00905
00906
00907
00908
00909
00910
00911
00912
00913 template<typename T1, typename T2>
00914 inline
00915 void
00916 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_vec>& X)
00917 {
00918 arma_extra_debug_sigprint();
00919
00920 typedef typename T1::elem_type eT;
00921
00922 unwrap_check<T1> tmp1(X.A, out);
00923 unwrap_check<T2> tmp2(X.B, out);
00924
00925 const Mat<eT>& A = tmp1.M;
00926 const Mat<eT>& B = tmp2.M;
00927
00928 arma_debug_assert_mul_size(A, B, "vector multiply");
00929
00930
00931
00932
00933
00934
00935
00936
00937
00938
00939 out.set_size(A.n_rows, B.n_cols);
00940
00941 if(A.n_cols == 1)
00942 {
00943 glue_times_vec::mul_col_row(out, A.mem, B.mem);
00944 }
00945 else
00946 {
00947 if(A.n_rows == 1)
00948 {
00949 if(B.n_cols == 1)
00950 {
00951 out[0] = op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00952 }
00953 else
00954 {
00955 gemv<true>::apply(out.memptr(), B, A.mem);
00956 }
00957 }
00958 else
00959 {
00960 gemv<>::apply(out.memptr(), A, B.mem);
00961 }
00962 }
00963
00964 }
00965
00966
00967
00968 template<typename eT>
00969 inline
00970 void
00971 glue_times_vec::mul_col_row(Mat<eT>& out, const eT* A, const eT* B)
00972 {
00973 const u32 n_rows = out.n_rows;
00974 const u32 n_cols = out.n_cols;
00975
00976 for(u32 col=0; col < n_cols; ++col)
00977 {
00978 const eT val = B[col];
00979
00980 eT* out_coldata = out.colptr(col);
00981
00982 for(u32 row=0; row < n_rows; ++row)
00983 {
00984 out_coldata[row] = A[row] * val;
00985 }
00986 }
00987
00988 }
00989
00990
00991
00992 template<typename eT>
00993 inline
00994 void
00995 glue_times_vec::mul_col_row_inplace_add(Mat<eT>& out, const eT* A, const eT* B)
00996 {
00997 const u32 n_rows = out.n_rows;
00998 const u32 n_cols = out.n_cols;
00999
01000 for(u32 col=0; col < n_cols; ++col)
01001 {
01002 const eT val = B[col];
01003
01004 eT* out_coldata = out.colptr(col);
01005
01006 for(u32 row=0; row < n_rows; ++row)
01007 {
01008 out_coldata[row] += A[row] * val;
01009 }
01010 }
01011
01012 }
01013
01014
01015
01016 #if defined(ARMA_GOOD_COMPILER)
01017
01018
01019
01020 template<typename eT>
01021 inline
01022 void
01023 glue_times_vec::apply(Mat<eT>& out, const Glue<Col<eT>,Row<eT>,glue_times_vec>& X)
01024 {
01025 arma_extra_debug_sigprint();
01026
01027 unwrap_check< Col<eT> > tmp1(X.A, out);
01028 unwrap_check< Row<eT> > tmp2(X.B, out);
01029
01030 const Col<eT>& A = tmp1.M;
01031 const Row<eT>& B = tmp2.M;
01032
01033 arma_debug_assert_mul_size(A, B, "vector multiply");
01034
01035 out.set_size(A.n_rows, B.n_cols);
01036
01037 glue_times_vec::mul_col_row(out, A.mem, B.mem);
01038 }
01039
01040
01041
01042 template<typename eT>
01043 inline
01044 void
01045 glue_times_vec::apply(Mat<eT>& out, const Glue< Op<Row<eT>, op_trans>, Row<eT>, glue_times_vec>& X)
01046 {
01047 arma_extra_debug_sigprint();
01048
01049 unwrap_check< Row<eT> > tmp1(X.A.m, out);
01050 unwrap_check< Row<eT> > tmp2(X.B, out);
01051
01052 const Row<eT>& A = tmp1.M;
01053 const Row<eT>& B = tmp2.M;
01054
01055 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "vector multiply");
01056
01057 out.set_size(A.n_cols, B.n_cols);
01058
01059 glue_times_vec::mul_col_row(out, A.mem, B.mem);
01060 }
01061
01062
01063
01064 template<typename eT>
01065 inline
01066 void
01067 glue_times_vec::apply(Mat<eT>& out, const Glue< Col<eT>, Op<Col<eT>, op_trans>, glue_times_vec>& X)
01068 {
01069 arma_extra_debug_sigprint();
01070
01071 unwrap_check< Col<eT> > tmp1(X.A, out);
01072 unwrap_check< Col<eT> > tmp2(X.B.m, out);
01073
01074 const Col<eT>& A = tmp1.M;
01075 const Col<eT>& B = tmp2.M;
01076
01077 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_cols, B.n_rows, "vector multiply");
01078
01079 out.set_size(A.n_rows, B.n_rows);
01080
01081 glue_times_vec::mul_col_row(out, A.mem, B.mem);
01082 }
01083
01084
01085
01086 template<typename T1>
01087 inline
01088 void
01089 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1, op_trans>, Col<typename T1::elem_type>,glue_times_vec>& X)
01090 {
01091 arma_extra_debug_sigprint();
01092
01093 typedef typename T1::elem_type eT;
01094
01095 unwrap_check< T1 > tmp1(X.A.m, out);
01096 unwrap_check< Col<eT> > tmp2(X.B, out);
01097
01098 const Mat<eT>& A = tmp1.M;
01099 const Col<eT>& B = tmp2.M;
01100
01101 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "vector multiply");
01102
01103 out.set_size(A.n_cols, B.n_cols);
01104
01105
01106
01107
01108
01109
01110
01111
01112
01113
01114
01115
01116
01117
01118
01119
01120
01121
01122
01123
01124 gemv<true>::apply(out.memptr(), A, B.mem);
01125 }
01126
01127
01128
01129 #endif
01130
01131
01132
01133