op_pinv_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // - Dimitrios Bouzas (dimitris dot mpouzas at gmail dot com)
00007 // 
00008 // This file is part of the Armadillo C++ library.
00009 // It is provided without any warranty of fitness
00010 // for any purpose. You can redistribute this file
00011 // and/or modify it under the terms of the GNU
00012 // Lesser General Public License (LGPL) as published
00013 // by the Free Software Foundation, either version 3
00014 // of the License or (at your option) any later version.
00015 // (see http://www.opensource.org/licenses for more info)
00016 
00017 
00018 
00019 //! \addtogroup op_pinv
00020 //! @{
00021 
00022 
00023 
00024 template<typename eT>
00025 inline
00026 void
00027 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, eT tol)
00028   {
00029   arma_extra_debug_sigprint();
00030   
00031   const u32 n_rows = A.n_rows;
00032   const u32 n_cols = A.n_cols;
00033   
00034   // SVD decomposition 
00035   Mat<eT> U;
00036   Col<eT> s;
00037   Mat<eT> V;
00038   
00039   const bool status = (n_cols > n_rows) ? svd(U,s,V,trans(A)) : svd(U,s,V,A);
00040   
00041   if(status == false)
00042     {
00043     out.set_size(0,0);
00044     return;
00045     }
00046   
00047   // set tolerance to default if it hasn't been specified as an argument
00048   if(tol == eT(0))
00049     {
00050     tol = (std::max)(n_rows,n_cols) * eop_aux::direct_eps(max(s));
00051     }
00052    
00053   // count non zero valued elements in s
00054   u32 count = 0; 
00055   for(u32 i = 0; i < s.n_elem; ++i)
00056     {
00057     if(s[i] > tol)
00058       {
00059       ++count;
00060       }
00061     }
00062   
00063   
00064   if(count != 0)
00065     {
00066     // reduce the length of s in order to contain only the values above tolerance
00067     s = s.rows(0,count-1);
00068     
00069     // set the elements of s equal to their reciprocals
00070     s = eT(1) / s;
00071     
00072     if(A.n_cols <= A.n_rows)
00073       {
00074       out = V.cols(0,count-1) * diagmat(s) * trans(U.cols(0,count-1));
00075       }
00076     else
00077       {
00078       out = U.cols(0,count-1) * diagmat(s) * trans(V.cols(0,count-1));
00079       }
00080     }
00081   else
00082     {
00083     out.zeros(n_cols, n_rows);
00084     }
00085   }
00086 
00087 
00088 
00089 template<typename T>
00090 inline
00091 void
00092 op_pinv::direct_pinv(Mat< std::complex<T> >& out, const Mat< std::complex<T> >& A, T tol)
00093   {
00094   arma_extra_debug_sigprint();
00095   
00096   const u32 n_rows = A.n_rows;
00097   const u32 n_cols = A.n_cols;
00098   
00099   typedef typename std::complex<T> eT;
00100  
00101   // SVD decomposition 
00102   Mat<eT> U;
00103   Col< T> s;
00104   Mat<eT> V;
00105   
00106   const bool status = (n_cols > n_rows) ? svd(U,s,V,htrans(A)) : svd(U,s,V,A);
00107   
00108   if(status == false)
00109     {
00110     out.set_size(0,0);
00111     return;
00112     }
00113  
00114   // set tolerance to default if it hasn't been specified as an argument 
00115   if(tol == T(0))
00116     {
00117     tol = (std::max)(n_rows,n_cols) * eop_aux::direct_eps(max(s));
00118     }
00119   
00120   
00121   // count non zero valued elements in s
00122   u32 count = 0;
00123   for(u32 i = 0; i < s.n_elem; ++i)
00124     {
00125     if(s[i] > tol)
00126       {
00127       ++count;
00128       }
00129     }
00130   
00131   if(count != 0)
00132     {
00133     // reduce the length of s in order to contain only the values above tolerance
00134     s = s.rows(0,count-1);
00135 
00136     // set the elements of s equal to their reciprocals
00137     s = T(1) / s;
00138     
00139     if(n_rows >= n_cols)
00140       {
00141       out = V.cols(0,count-1) * diagmat(s) * htrans(U.cols(0,count-1));
00142       }
00143     else
00144       {
00145       out = U.cols(0,count-1) * diagmat(s) * htrans(V.cols(0,count-1));
00146       }
00147     }
00148   else
00149     {
00150     out.zeros(n_cols, n_rows);
00151     }
00152   }
00153 
00154 
00155 
00156 template<typename T1>
00157 inline
00158 void
00159 op_pinv::apply(Mat<typename T1::pod_type>& out, const Op<T1,op_pinv>& in)
00160   {
00161   arma_extra_debug_sigprint();
00162   
00163   typedef typename T1::pod_type eT;
00164   
00165   const eT tol = in.aux; 
00166   
00167   arma_debug_check((tol < eT(0)), "pinv(): tol must be >= 0");
00168   
00169   const unwrap_check<T1> tmp(in.m, out);
00170   const Mat<eT>& A     = tmp.M;
00171   
00172   op_pinv::direct_pinv(out, A, tol);
00173   }
00174 
00175 
00176 
00177 template<typename T1>
00178 inline
00179 void
00180 op_pinv::apply(Mat< std::complex<typename T1::pod_type> >& out, const Op<T1,op_pinv>& in)
00181   {
00182   arma_extra_debug_sigprint();
00183   
00184   typedef typename T1::pod_type T;
00185   
00186   typedef typename std::complex<typename T1::pod_type> eT;
00187   
00188   const T tol = in.aux.real();
00189   
00190   arma_debug_check((tol < T(0)), "pinv(): tol must be >= 0");
00191   
00192   const unwrap_check<T1> tmp(in.m, out);
00193   const Mat<eT>& A     = tmp.M;
00194   
00195   op_pinv::direct_pinv(out, A, tol);
00196   }
00197 
00198 
00199 
00200 //! @}