op_sort_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 //! \addtogroup op_sort
00018 //! @{
00019 
00020 
00021 // using qsort() rather than std::sort() for now.
00022 // std::sort() will be used when a Random Access Iterator wrapper for plain arrays is ready,
00023 // otherwise using std::sort() would currently entail copying elements to/from std::vector
00024 
//! Comparison callbacks with the signature required by std::qsort().
//! The generic version relies on operator< and operator> of the element type.
template<typename eT>
class arma_qsort_helper
  {
  public:
  
  //! Comparator for ascending order.
  //! Both arguments must point to objects of type eT.
  //! Returns -1 if A < B, +1 if A > B, and 0 if neither relation holds.
  static
  int
  ascend_compare(const void* A_orig, const void* B_orig)
    {
    const eT* const ptr_A = static_cast<const eT*>(A_orig);
    const eT* const ptr_B = static_cast<const eT*>(B_orig);
    
    if( (*ptr_A) < (*ptr_B) )  { return -1; }
    if( (*ptr_A) > (*ptr_B) )  { return +1; }
    
    return 0;
    }
  
  
  
  //! Comparator for descending order (signs reversed relative to ascend_compare).
  //! Both arguments must point to objects of type eT.
  //! Returns +1 if A < B, -1 if A > B, and 0 if neither relation holds.
  static
  int
  descend_compare(const void* A_orig, const void* B_orig)
    {
    const eT* const ptr_A = static_cast<const eT*>(A_orig);
    const eT* const ptr_B = static_cast<const eT*>(B_orig);
    
    if( (*ptr_A) < (*ptr_B) )  { return +1; }
    if( (*ptr_A) > (*ptr_B) )  { return -1; }
    
    return 0;
    }
  
  
  };
00078 
00079 
00080 
00081 //template<>
00082 template<typename T>
00083 class arma_qsort_helper< std::complex<T> >
00084   {
00085   public:
00086   
00087   typedef typename std::complex<T> eT;
00088   
00089   
00090   static
00091   int
00092   ascend_compare(const void* A_orig, const void* B_orig)
00093     {
00094     const eT& A = *(static_cast<const eT*>(A_orig));
00095     const eT& B = *(static_cast<const eT*>(B_orig));
00096     
00097     const T abs_A = std::abs(A);
00098     const T abs_B = std::abs(B);
00099     
00100     if(abs_A < abs_B)
00101       {
00102       return -1;
00103       }
00104     else
00105     if(abs_A > abs_B)
00106       {
00107       return +1;
00108       }
00109     else
00110       {
00111       return 0;
00112       }
00113     }
00114   
00115   
00116   
00117   static
00118   int
00119   descend_compare(const void* A_orig, const void* B_orig)
00120     {
00121     const eT& A = *(static_cast<const eT*>(A_orig));
00122     const eT& B = *(static_cast<const eT*>(B_orig));
00123     
00124     const T abs_A = std::abs(A);
00125     const T abs_B = std::abs(B);
00126     
00127     if(abs_A < abs_B)
00128       {
00129       return +1;
00130       }
00131     else
00132     if(abs_A > abs_B)
00133       {
00134       return -1;
00135       }
00136     else
00137       {
00138       return 0;
00139       }
00140     }
00141   
00142   
00143   };
00144 
00145 
00146 
00147 template<typename eT>
00148 inline 
00149 void
00150 op_sort::direct_sort(eT* X, const u32 n_elem, const u32 sort_type)
00151   {
00152   arma_extra_debug_sigprint();
00153   
00154   if(sort_type == 0)
00155     {
00156     std::qsort(X, n_elem, sizeof(eT), arma_qsort_helper<eT>::ascend_compare);
00157     }
00158   else
00159     {
00160     std::qsort(X, n_elem, sizeof(eT), arma_qsort_helper<eT>::descend_compare);
00161     }
00162   }
00163 
00164 
00165 
00166 template<typename T1>
00167 inline
00168 void
00169 op_sort::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_sort>& in)
00170   {
00171   arma_extra_debug_sigprint();
00172   
00173   typedef typename T1::elem_type eT;
00174   
00175   const unwrap<T1>   tmp(in.m);
00176   const Mat<eT>& X = tmp.M;
00177   
00178   const u32 sort_type = in.aux_u32_a;
00179   const u32 dim       = in.aux_u32_b;
00180   
00181   arma_debug_check( (X.is_finite() == false), "sort(): given object has non-finite elements"     );
00182   arma_debug_check( (sort_type > 1),          "sort(): incorrect usage. sort_type must be 0 or 1");
00183   arma_debug_check( (dim > 1),                "sort(): incorrect usage. dim must be 0 or 1"      );
00184   
00185   
00186   if(dim == 0)  // column-wise
00187     {
00188     arma_extra_debug_print("op_sort::apply(), dim = 0");
00189     
00190     out = X;
00191     
00192     for(u32 col=0; col<out.n_cols; ++col)
00193       {
00194       op_sort::direct_sort( out.colptr(col), out.n_rows, sort_type );
00195       }
00196     }
00197   else
00198   if(dim == 1)  // row-wise
00199     {
00200     if(X.n_rows != 1)  // not a row vector
00201       {
00202       arma_extra_debug_print("op_sort::apply(), dim = 1, generic");
00203       
00204       //out.set_size(X.n_rows, X.n_cols);
00205       out.copy_size(X);
00206       
00207       podarray<eT> tmp_array(X.n_cols);
00208       
00209       for(u32 row=0; row<out.n_rows; ++row)
00210         {
00211         
00212         for(u32 col=0; col<out.n_cols; ++col)
00213           {
00214           tmp_array[col] = X.at(row,col);
00215           }
00216         
00217         op_sort::direct_sort( tmp_array.memptr(), out.n_cols, sort_type );
00218         
00219         for(u32 col=0; col<out.n_cols; ++col)
00220           {
00221           out.at(row,col) = tmp_array[col];
00222           }
00223         
00224         }
00225       }
00226     else  // a row vector
00227       {
00228       arma_extra_debug_print("op_sort::apply(), dim = 1, vector specific");
00229       
00230       out = X;
00231       op_sort::direct_sort(out.memptr(), out.n_elem, sort_type);
00232       }
00233     }
00234   
00235   }
00236 
00237 
00238 //! @}