op_pinv_meat.hpp
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 template<typename eT>
00025 inline
00026 void
00027 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, eT tol)
00028 {
00029 arma_extra_debug_sigprint();
00030
00031 const u32 n_rows = A.n_rows;
00032 const u32 n_cols = A.n_cols;
00033
00034
00035 Mat<eT> U;
00036 Col<eT> s;
00037 Mat<eT> V;
00038
00039 (n_cols > n_rows) ? svd(U,s,V,trans(A)) : svd(U,s,V,A);
00040
00041
00042 if(tol == eT(0))
00043 {
00044 tol = (std::max)(n_rows,n_cols) * op_eps::direct_eps(max(s));
00045 }
00046
00047
00048 u32 count = 0;
00049 for(u32 i = 0; i < s.n_elem; ++i)
00050 {
00051 if(s[i] > tol)
00052 {
00053 ++count;
00054 }
00055 }
00056
00057
00058 if(count != 0)
00059 {
00060
00061 s = s.rows(0,count-1);
00062
00063
00064 s = eT(1) / s;
00065
00066 if(A.n_cols <= A.n_rows)
00067 {
00068 out = V.cols(0,count-1) * diagmat(s) * trans(U.cols(0,count-1));
00069 }
00070 else
00071 {
00072 out = U.cols(0,count-1) * diagmat(s) * trans(V.cols(0,count-1));
00073 }
00074 }
00075 else
00076 {
00077 out.zeros(n_cols, n_rows);
00078 }
00079 }
00080
00081
00082
00083 template<typename T>
00084 inline
00085 void
00086 op_pinv::direct_pinv(Mat< std::complex<T> >& out, const Mat< std::complex<T> >& A, T tol)
00087 {
00088 arma_extra_debug_sigprint();
00089
00090 const u32 n_rows = A.n_rows;
00091 const u32 n_cols = A.n_cols;
00092
00093 typedef typename std::complex<T> eT;
00094
00095
00096 Mat<eT> U;
00097 Col< T> s;
00098 Mat<eT> V;
00099
00100 (n_cols > n_rows) ? svd(U,s,V,htrans(A)) : svd(U,s,V,A);
00101
00102
00103 if(tol == T(0))
00104 {
00105 tol = (std::max)(n_rows,n_cols) * op_eps::direct_eps(max(s));
00106 }
00107
00108
00109
00110 u32 count = 0;
00111 for(u32 i = 0; i < s.n_elem; ++i)
00112 {
00113 if(s[i] > tol)
00114 {
00115 ++count;
00116 }
00117 }
00118
00119 if(count != 0)
00120 {
00121
00122 s = s.rows(0,count-1);
00123
00124
00125 s = T(1) / s;
00126
00127 if(n_rows >= n_cols)
00128 {
00129 out = V.cols(0,count-1) * diagmat(s) * htrans(U.cols(0,count-1));
00130 }
00131 else
00132 {
00133 out = U.cols(0,count-1) * diagmat(s) * htrans(V.cols(0,count-1));
00134 }
00135 }
00136 else
00137 {
00138 out.zeros(n_cols, n_rows);
00139 }
00140 }
00141
00142
00143
00144 template<typename T1>
00145 inline
00146 void
00147 op_pinv::apply(Mat<typename T1::pod_type>& out, const Op<T1,op_pinv>& in)
00148 {
00149 arma_extra_debug_sigprint();
00150
00151 typedef typename T1::pod_type eT;
00152
00153 const eT tol = in.aux;
00154
00155 arma_debug_check((tol < eT(0)), "pinv(): tol must be >= 0");
00156
00157 const unwrap_check<T1> tmp(in.m, out);
00158 const Mat<eT>& A = tmp.M;
00159
00160 op_pinv::direct_pinv(out, A, tol);
00161 }
00162
00163
00164
00165 template<typename T1>
00166 inline
00167 void
00168 op_pinv::apply(Mat< std::complex<typename T1::pod_type> >& out, const Op<T1,op_pinv>& in)
00169 {
00170 arma_extra_debug_sigprint();
00171
00172 typedef typename T1::pod_type T;
00173
00174 typedef typename std::complex<typename T1::pod_type> eT;
00175
00176 const T tol = in.aux.real();
00177
00178 arma_debug_check((tol < T(0)), "pinv(): tol must be >= 0");
00179
00180 const unwrap_check<T1> tmp(in.m, out);
00181 const Mat<eT>& A = tmp.M;
00182
00183 op_pinv::direct_pinv(out, A, tol);
00184 }
00185
00186
00187
00188