op_reshape_meat.hpp
Go to the documentation of this file.
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 template<typename T1>
00023 inline
00024 void
00025 op_reshape::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_reshape>& in)
00026 {
00027 arma_extra_debug_sigprint();
00028
00029 typedef typename T1::elem_type eT;
00030
00031 const unwrap<T1> tmp(in.m);
00032 const Mat<eT>& A = tmp.M;
00033
00034 const u32 in_n_rows = in.aux_u32_a;
00035 const u32 in_n_cols = in.aux_u32_b;
00036
00037 const u32 in_n_elem = in_n_rows * in_n_cols;
00038
00039 arma_debug_check( (A.n_elem != in_n_elem), "reshape(): incompatible dimensions");
00040
00041 if(in.aux == eT(0))
00042 {
00043 if(&out != &A)
00044 {
00045 out.set_size(in_n_rows, in_n_cols);
00046 syslib::copy_elem( out.memptr(), A.memptr(), out.n_elem );
00047 }
00048 else
00049 {
00050 access::rw(out.n_rows) = in_n_rows;
00051 access::rw(out.n_cols) = in_n_cols;
00052 }
00053 }
00054 else
00055 {
00056 unwrap_check< Mat<eT> > tmp(A, out);
00057 const Mat<eT>& B = tmp.M;
00058
00059 out.set_size(in_n_rows, in_n_cols);
00060
00061 eT* out_mem = out.memptr();
00062 u32 i = 0;
00063
00064 for(u32 row=0; row<B.n_rows; ++row)
00065 {
00066 for(u32 col=0; col<B.n_cols; ++col)
00067 {
00068 out_mem[i] = B.at(row,col);
00069 ++i;
00070 }
00071 }
00072
00073 }
00074
00075 }
00076
00077
00078
00079 template<typename T1>
00080 inline
00081 void
00082 op_reshape::apply(Cube<typename T1::elem_type>& out, const OpCube<T1,op_reshape>& in)
00083 {
00084 arma_extra_debug_sigprint();
00085
00086 typedef typename T1::elem_type eT;
00087
00088 const unwrap_cube<T1> tmp(in.m);
00089 const Cube<eT>& A = tmp.M;
00090
00091 const u32 in_n_rows = in.aux_u32_a;
00092 const u32 in_n_cols = in.aux_u32_b;
00093 const u32 in_n_slices = in.aux_u32_c;
00094
00095 const u32 in_n_elem = in_n_rows * in_n_cols * in_n_slices;
00096
00097 arma_debug_check( (A.n_elem != in_n_elem), "reshape(): incompatible dimensions");
00098
00099 if(in.aux == eT(0))
00100 {
00101 if(&out != &A)
00102 {
00103 out.set_size(in_n_rows, in_n_cols, in_n_slices);
00104 syslib::copy_elem( out.memptr(), A.memptr(), out.n_elem );
00105 }
00106 else
00107 {
00108 access::rw(out.n_rows) = in_n_rows;
00109 access::rw(out.n_cols) = in_n_cols;
00110 access::rw(out.n_slices) = in_n_slices;
00111 }
00112 }
00113 else
00114 {
00115 unwrap_cube_check< Cube<eT> > tmp(A, out);
00116 const Cube<eT>& B = tmp.M;
00117
00118 out.set_size(in_n_rows, in_n_cols, in_n_slices);
00119
00120 eT* out_mem = out.memptr();
00121 u32 i = 0;
00122
00123 for(u32 slice=0; slice<B.n_slices; ++slice)
00124 {
00125 for(u32 row=0; row<B.n_rows; ++row)
00126 {
00127 for(u32 col=0; col<B.n_cols; ++col)
00128 {
00129 out_mem[i] = B.at(row,col,slice);
00130 ++i;
00131 }
00132 }
00133 }
00134
00135 }
00136
00137 }
00138
00139
00140
00141