op_sum_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 //! \addtogroup op_sum
00018 //! @{
00019 
00020 //! \brief
00021 //! Immediate sum of elements of a matrix along a specified dimension (either rows or columns).
00022 //! The result is stored in a dense matrix that has either one column or one row.
00023 //! See the sum() function for more details.
00024 template<typename T1>
00025 inline
00026 void
00027 op_sum::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_sum>& in)
00028   {
00029   arma_extra_debug_sigprint();
00030   
00031   const u32 dim = in.aux_u32_a;
00032   arma_debug_check( (dim > 1), "sum(): incorrect usage. dim must be 0 or 1");
00033   
00034   typedef typename T1::elem_type eT;
00035   
00036   const unwrap_check<T1> tmp(in.m, out);
00037   const Mat<eT>& X     = tmp.M;
00038   
00039   arma_debug_check( (X.n_elem < 1), "sum(): given object has no elements");
00040   
00041   
00042   if(dim == 0)  // traverse across rows (i.e. find the sum in each column)
00043     {
00044     out.set_size(1, X.n_cols);
00045     
00046     for(u32 col=0; col < X.n_cols; ++col)
00047       {
00048       const eT* X_colptr = X.colptr(col);
00049       
00050       eT val = eT(0);
00051       
00052       for(u32 row=0; row < X.n_rows; ++row)
00053         {
00054         val += X_colptr[row];
00055         }
00056     
00057       out.at(0,col) = val;
00058       }
00059     }
00060   else  // traverse across columns (i.e. find the sum in each row)
00061     {
00062     out.set_size(X.n_rows, 1);
00063     
00064     for(u32 row=0; row < X.n_rows; ++row)
00065       {
00066       eT val = eT(0);
00067       
00068       for(u32 col=0; col<X.n_cols; ++col)
00069         {
00070         val += X.at(row,col);
00071         }
00072     
00073       out.at(row,0) = val;
00074       }
00075     
00076     }
00077   
00078   }
00079 
00080 
00081 
00082 //! @}