op_pinv_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2009 NICTA
00002 // Copyright (C) 2009 Dimitrios Bouzas
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // - Dimitrios Bouzas (dimitris dot mpouzas at gmail dot com)
00007 // 
00008 // This file is part of the Armadillo C++ library.
00009 // It is provided without any warranty of fitness
00010 // for any purpose. You can redistribute this file
00011 // and/or modify it under the terms of the GNU
00012 // Lesser General Public License (LGPL) as published
00013 // by the Free Software Foundation, either version 3
00014 // of the License or (at your option) any later version.
00015 // (see http://www.opensource.org/licenses for more info)
00016 
00017 
00018 
00019 //! \addtogroup op_pinv
00020 //! @{
00021 
00022 
00023 
00024 template<typename eT>
00025 inline
00026 void
00027 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, eT tol)
00028   {
00029   arma_extra_debug_sigprint();
00030   
00031   const u32 n_rows = A.n_rows;
00032   const u32 n_cols = A.n_cols;
00033   
00034   // SVD decomposition 
00035   Mat<eT> U;
00036   Col<eT> s;
00037   Mat<eT> V;
00038   
00039   (n_cols > n_rows) ? svd(U,s,V,trans(A)) : svd(U,s,V,A);
00040    
00041   // set tolerance to default if it hasn't been specified as an argument
00042   if(tol == eT(0))
00043     {
00044     tol = (std::max)(n_rows,n_cols) * op_eps::direct_eps(max(s));
00045     }
00046    
00047   // count non zero valued elements in s
00048   u32 count = 0; 
00049   for(u32 i = 0; i < s.n_elem; ++i)
00050     {
00051     if(s[i] > tol)
00052       {
00053       ++count;
00054       }
00055     }
00056   
00057   
00058   if(count != 0)
00059     {
00060     // reduce the length of s in order to contain only the values above tolerance
00061     s = s.rows(0,count-1);
00062     
00063     // set the elements of s equal to their reciprocals
00064     s = eT(1) / s;
00065     
00066     if(A.n_cols <= A.n_rows)
00067       {
00068       out = V.cols(0,count-1) * diagmat(s) * trans(U.cols(0,count-1));
00069       }
00070     else
00071       {
00072       out = U.cols(0,count-1) * diagmat(s) * trans(V.cols(0,count-1));
00073       }
00074     }
00075   else
00076     {
00077     out.zeros(n_cols, n_rows);
00078     }
00079   }
00080 
00081 
00082 
00083 template<typename T>
00084 inline
00085 void
00086 op_pinv::direct_pinv(Mat< std::complex<T> >& out, const Mat< std::complex<T> >& A, T tol)
00087   {
00088   arma_extra_debug_sigprint();
00089   
00090   const u32 n_rows = A.n_rows;
00091   const u32 n_cols = A.n_cols;
00092   
00093   typedef typename std::complex<T> eT;
00094  
00095   // SVD decomposition 
00096   Mat<eT> U;
00097   Col< T> s;
00098   Mat<eT> V;
00099   
00100   (n_cols > n_rows) ? svd(U,s,V,htrans(A)) : svd(U,s,V,A);
00101  
00102   // set tolerance to default if it hasn't been specified as an argument 
00103   if(tol == T(0))
00104     {
00105     tol = (std::max)(n_rows,n_cols) * op_eps::direct_eps(max(s));
00106     }
00107   
00108   
00109   // count non zero valued elements in s
00110   u32 count = 0;
00111   for(u32 i = 0; i < s.n_elem; ++i)
00112     {
00113     if(s[i] > tol)
00114       {
00115       ++count;
00116       }
00117     }
00118   
00119   if(count != 0)
00120     {
00121     // reduce the length of s in order to contain only the values above tolerance
00122     s = s.rows(0,count-1);
00123 
00124     // set the elements of s equal to their reciprocals
00125     s = T(1) / s;
00126     
00127     if(n_rows >= n_cols)
00128       {
00129       out = V.cols(0,count-1) * diagmat(s) * htrans(U.cols(0,count-1));
00130       }
00131     else
00132       {
00133       out = U.cols(0,count-1) * diagmat(s) * htrans(V.cols(0,count-1));
00134       }
00135     }
00136   else
00137     {
00138     out.zeros(n_cols, n_rows);
00139     }
00140   }
00141 
00142 
00143 
00144 template<typename T1>
00145 inline
00146 void
00147 op_pinv::apply(Mat<typename T1::pod_type>& out, const Op<T1,op_pinv>& in)
00148   {
00149   arma_extra_debug_sigprint();
00150   
00151   typedef typename T1::pod_type eT;
00152   
00153   const eT tol = in.aux; 
00154   
00155   arma_debug_check((tol < eT(0)), "pinv(): tol must be >= 0");
00156   
00157   const unwrap_check<T1> tmp(in.m, out);
00158   const Mat<eT>& A     = tmp.M;
00159   
00160   op_pinv::direct_pinv(out, A, tol);
00161   }
00162 
00163 
00164 
00165 template<typename T1>
00166 inline
00167 void
00168 op_pinv::apply(Mat< std::complex<typename T1::pod_type> >& out, const Op<T1,op_pinv>& in)
00169   {
00170   arma_extra_debug_sigprint();
00171   
00172   typedef typename T1::pod_type T;
00173   
00174   typedef typename std::complex<typename T1::pod_type> eT;
00175   
00176   const T tol = in.aux.real();
00177   
00178   arma_debug_check((tol < T(0)), "pinv(): tol must be >= 0");
00179   
00180   const unwrap_check<T1> tmp(in.m, out);
00181   const Mat<eT>& A     = tmp.M;
00182   
00183   op_pinv::direct_pinv(out, A, tol);
00184   }
00185 
00186 
00187 
00188 //! @}