SparseSparseProductWithPruning.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // Eigen is free software; you can redistribute it and/or
7 // modify it under the terms of the GNU Lesser General Public
8 // License as published by the Free Software Foundation; either
9 // version 3 of the License, or (at your option) any later version.
10 //
11 // Alternatively, you can redistribute it and/or
12 // modify it under the terms of the GNU General Public License as
13 // published by the Free Software Foundation; either version 2 of
14 // the License, or (at your option) any later version.
15 //
16 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
17 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
18 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
19 // GNU General Public License for more details.
20 //
21 // You should have received a copy of the GNU Lesser General Public
22 // License and a copy of the GNU General Public License along with
23 // Eigen. If not, see <http://www.gnu.org/licenses/>.
24 
25 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
26 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
27 
28 namespace Eigen {
29 
30 namespace internal {
31 
32 
33 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
34 template<typename Lhs, typename Rhs, typename ResultType>
35 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, typename ResultType::RealScalar tolerance)
36 {
37  // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
38 
39  typedef typename remove_all<Lhs>::type::Scalar Scalar;
40  typedef typename remove_all<Lhs>::type::Index Index;
41 
42  // make sure to call innerSize/outerSize since we fake the storage order.
43  Index rows = lhs.innerSize();
44  Index cols = rhs.outerSize();
45  //int size = lhs.outerSize();
46  eigen_assert(lhs.outerSize() == rhs.innerSize());
47 
48  // allocate a temporary buffer
49  AmbiVector<Scalar,Index> tempVector(rows);
50 
51  // estimate the number of non zero entries
52  // given a rhs column containing Y non zeros, we assume that the respective Y columns
53  // of the lhs differs in average of one non zeros, thus the number of non zeros for
54  // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
55  // per column of the lhs.
56  // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
57  Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
58 
59  // mimics a resizeByInnerOuter:
60  if(ResultType::IsRowMajor)
61  res.resize(cols, rows);
62  else
63  res.resize(rows, cols);
64 
65  res.reserve(estimated_nnz_prod);
66  double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
67  for (Index j=0; j<cols; ++j)
68  {
69  // FIXME:
70  //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
71  // let's do a more accurate determination of the nnz ratio for the current column j of res
72  tempVector.init(ratioColRes);
73  tempVector.setZero();
74  for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
75  {
76  // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
77  tempVector.restart();
78  Scalar x = rhsIt.value();
79  for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
80  {
81  tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
82  }
83  }
84  res.startVec(j);
85  for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
86  res.insertBackByOuterInner(j,it.index()) = it.value();
87  }
88  res.finalize();
89 }
90 
91 template<typename Lhs, typename Rhs, typename ResultType,
92  int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
93  int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
94  int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
95 struct sparse_sparse_product_with_pruning_selector;
96 
97 template<typename Lhs, typename Rhs, typename ResultType>
98 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
99 {
100  typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
101  typedef typename ResultType::RealScalar RealScalar;
102 
103  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
104  {
105  typename remove_all<ResultType>::type _res(res.rows(), res.cols());
106  internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
107  res.swap(_res);
108  }
109 };
110 
111 template<typename Lhs, typename Rhs, typename ResultType>
112 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
113 {
114  typedef typename ResultType::RealScalar RealScalar;
115  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
116  {
117  // we need a col-major matrix to hold the result
118  typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
119  SparseTemporaryType _res(res.rows(), res.cols());
120  internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
121  res = _res;
122  }
123 };
124 
125 template<typename Lhs, typename Rhs, typename ResultType>
126 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
127 {
128  typedef typename ResultType::RealScalar RealScalar;
129  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
130  {
131  // let's transpose the product to get a column x column product
132  typename remove_all<ResultType>::type _res(res.rows(), res.cols());
133  internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
134  res.swap(_res);
135  }
136 };
137 
138 template<typename Lhs, typename Rhs, typename ResultType>
139 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
140 {
141  typedef typename ResultType::RealScalar RealScalar;
142  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
143  {
144  typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
145  ColMajorMatrix colLhs(lhs);
146  ColMajorMatrix colRhs(rhs);
147  internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res, tolerance);
148 
149  // let's transpose the product to get a column x column product
150 // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
151 // SparseTemporaryType _res(res.cols(), res.rows());
152 // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
153 // res = _res.transpose();
154  }
155 };
156 
// NOTE: the two other cases (col row *) must never occur, since they are caught
// by ProductReturnType, which transforms them into (col col *) by evaluating the rhs.
159 
160 } // end namespace internal
161 
162 } // end namespace Eigen
163 
164 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H