SparseDenseProduct.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // Eigen is free software; you can redistribute it and/or
7 // modify it under the terms of the GNU Lesser General Public
8 // License as published by the Free Software Foundation; either
9 // version 3 of the License, or (at your option) any later version.
10 //
11 // Alternatively, you can redistribute it and/or
12 // modify it under the terms of the GNU General Public License as
13 // published by the Free Software Foundation; either version 2 of
14 // the License, or (at your option) any later version.
15 //
16 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
17 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
18 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
19 // GNU General Public License for more details.
20 //
21 // You should have received a copy of the GNU Lesser General Public
22 // License and a copy of the GNU General Public License along with
23 // Eigen. If not, see <http://www.gnu.org/licenses/>.
24 
25 #ifndef EIGEN_SPARSEDENSEPRODUCT_H
26 #define EIGEN_SPARSEDENSEPRODUCT_H
27 
28 namespace Eigen {
29 
// Selects the expression type used for a sparse * dense product (general
// InnerSize; the InnerSize == 1 case is specialized separately).
// NOTE(review): this listing jumps from line 31 to line 33, so the struct's
// member declaration (line 32) is missing. operator* at the end of this file
// reads SparseDenseProductReturnType<...>::Type, so a nested `Type` typedef is
// expected here -- restore it from upstream Eigen.
30 template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
31 {
33 };
34 
// Specialization of SparseDenseProductReturnType for InnerSize == 1
// (presumably the outer-product case -- confirm against upstream).
// NOTE(review): original line 37 (the nested `Type` declaration) is missing
// from this listing.
35 template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
36 {
38 };
39 
// Selects the expression type used for a dense * sparse product (general
// InnerSize; the InnerSize == 1 case is specialized separately).
// NOTE(review): original line 42 is missing from this listing; presumably it
// declares a nested `Type` like SparseDenseProductReturnType -- restore it
// from upstream Eigen.
40 template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
41 {
43 };
44 
// Specialization of DenseSparseProductReturnType for InnerSize == 1
// (presumably the outer-product case -- confirm against upstream).
// NOTE(review): original line 47 (the nested `Type` declaration) is missing
// from this listing.
45 template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
46 {
48 };
49 
50 namespace internal {
51 
52 template<typename Lhs, typename Rhs, bool Tr>
53 struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
54 {
55  typedef Sparse StorageKind;
56  typedef typename scalar_product_traits<typename traits<Lhs>::Scalar,
57  typename traits<Rhs>::Scalar>::ReturnType Scalar;
58  typedef typename Lhs::Index Index;
59  typedef typename Lhs::Nested LhsNested;
60  typedef typename Rhs::Nested RhsNested;
61  typedef typename remove_all<LhsNested>::type _LhsNested;
62  typedef typename remove_all<RhsNested>::type _RhsNested;
63 
64  enum {
65  LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost,
66  RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost,
67 
68  RowsAtCompileTime = Tr ? int(traits<Rhs>::RowsAtCompileTime) : int(traits<Lhs>::RowsAtCompileTime),
69  ColsAtCompileTime = Tr ? int(traits<Lhs>::ColsAtCompileTime) : int(traits<Rhs>::ColsAtCompileTime),
70  MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime) : int(traits<Lhs>::MaxRowsAtCompileTime),
71  MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime) : int(traits<Rhs>::MaxColsAtCompileTime),
72 
73  Flags = Tr ? RowMajorBit : 0,
74 
75  CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost
76  };
77 };
78 
79 } // end namespace internal
80 
// Sparse expression for the product of the sparse operand lhs() with the dense
// operand rhs(); used by the InnerSize == 1 return-type specializations
// (presumably sparse vector x dense vector -- confirm against upstream).
// Tr selects the transposed layout: rows()/cols() swap which operand they
// come from, and the traits mark the result RowMajor.
81 template<typename Lhs, typename Rhs, bool Tr>
// NOTE(review): original line 82 (the `class SparseDenseOuterProduct` head)
// is missing from this listing.
83  : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
84 {
85  public:
86 
// NOTE(review): original lines 87-89 are missing. `Traits` is used below but
// never declared in the visible lines, so these presumably held the public
// interface macro and the Base/Traits typedefs -- restore from upstream.
90 
91  private:
92 
93  typedef typename Traits::LhsNested LhsNested;
94  typedef typename Traits::RhsNested RhsNested;
95  typedef typename Traits::_LhsNested _LhsNested;
96  typedef typename Traits::_RhsNested _RhsNested;
97 
98  public:
99 
100  class InnerIterator;
101 
// NOTE(review): the constructor declarator (original line 102) is missing.
// This constructor is only valid for the non-transposed form (asserts !Tr).
103  : m_lhs(lhs), m_rhs(rhs)
104  {
105  EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
106  }
107 
// NOTE(review): the constructor declarator (original line 108) is missing.
// This constructor is only valid for the transposed form (asserts Tr).
109  : m_lhs(lhs), m_rhs(rhs)
110  {
111  EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
112  }
113 
// Without Tr, rows come from the sparse lhs and columns from the dense rhs;
// with Tr the roles are swapped.
114  EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); }
115  EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); }
116 
117  EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
118  EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
119 
120  protected:
121  LhsNested m_lhs;
122  RhsNested m_rhs;
123 };
124 
125 template<typename Lhs, typename Rhs, bool Transpose>
126 class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator
127 {
128  typedef typename _LhsNested::InnerIterator Base;
129  public:
130  EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
131  : Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer))
132  {
133  }
134 
135  inline Index outer() const { return m_outer; }
136  inline Index row() const { return Transpose ? Base::row() : m_outer; }
137  inline Index col() const { return Transpose ? m_outer : Base::row(); }
138 
139  inline Scalar value() const { return Base::value() * m_factor; }
140 
141  protected:
142  int m_outer;
143  Scalar m_factor;
144 };
145 
146 namespace internal {
147 template<typename Lhs, typename Rhs>
148 struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
149  : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> >
150 {
151  typedef Dense StorageKind;
152  typedef MatrixXpr XprKind;
153 };
154 
155 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
156  int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
157  bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
158 struct sparse_time_dense_product_impl;
159 
160 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
161 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
162 {
163  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
164  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
165  typedef typename internal::remove_all<DenseResType>::type Res;
166  typedef typename Lhs::Index Index;
167  typedef typename Lhs::InnerIterator LhsInnerIterator;
168  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
169  {
170  for(Index c=0; c<rhs.cols(); ++c)
171  {
172  int n = lhs.outerSize();
173  for(Index j=0; j<n; ++j)
174  {
175  typename Res::Scalar tmp(0);
176  for(LhsInnerIterator it(lhs,j); it ;++it)
177  tmp += it.value() * rhs.coeff(it.index(),c);
178  res.coeffRef(j,c) = alpha * tmp;
179  }
180  }
181  }
182 };
183 
184 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
185 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
186 {
187  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
188  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
189  typedef typename internal::remove_all<DenseResType>::type Res;
190  typedef typename Lhs::InnerIterator LhsInnerIterator;
191  typedef typename Lhs::Index Index;
192  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
193  {
194  for(Index c=0; c<rhs.cols(); ++c)
195  {
196  for(Index j=0; j<lhs.outerSize(); ++j)
197  {
198  typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
199  for(LhsInnerIterator it(lhs,j); it ;++it)
200  res.coeffRef(it.index(),c) += it.value() * rhs_j;
201  }
202  }
203  }
204 };
205 
206 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
207 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
208 {
209  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
210  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
211  typedef typename internal::remove_all<DenseResType>::type Res;
212  typedef typename Lhs::InnerIterator LhsInnerIterator;
213  typedef typename Lhs::Index Index;
214  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
215  {
216  for(Index j=0; j<lhs.outerSize(); ++j)
217  {
218  typename Res::RowXpr res_j(res.row(j));
219  for(LhsInnerIterator it(lhs,j); it ;++it)
220  res_j += (alpha*it.value()) * rhs.row(it.index());
221  }
222  }
223 };
224 
225 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
226 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
227 {
228  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
229  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
230  typedef typename internal::remove_all<DenseResType>::type Res;
231  typedef typename Lhs::InnerIterator LhsInnerIterator;
232  typedef typename Lhs::Index Index;
233  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
234  {
235  for(Index j=0; j<lhs.outerSize(); ++j)
236  {
237  typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
238  for(LhsInnerIterator it(lhs,j); it ;++it)
239  res.row(it.index()) += (alpha*it.value()) * rhs_j;
240  }
241  }
242 };
243 
244 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
245 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
246 {
247  sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
248 }
249 
250 } // end namespace internal
251 
// Dense-valued expression for (sparse Lhs) * (dense Rhs), built on ProductBase.
252 template<typename Lhs, typename Rhs>
// NOTE(review): original line 253 (the `class SparseTimeDenseProduct` head)
// is missing from this listing.
254  : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs>
255 {
256  public:
// NOTE(review): original line 257 is missing. `Base` and `Scalar` are used
// below without a visible declaration, so it presumably held the product
// public-interface macro -- confirm against upstream.
258 
259  SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
260  {}
261 
// Expected to compute dest += alpha * lhs * rhs.
// NOTE(review): the body (original line 264) is missing, so as listed this is
// a no-op. Presumably it forwards to internal::sparse_time_dense_product on
// the stored operands -- restore from upstream.
262  template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
263  {
265  }
266 
267  private:
// NOTE(review): the private member on original line 268 is missing from this
// listing.
269 };
270 
271 
272 // dense = dense * sparse
273 namespace internal {
274 template<typename Lhs, typename Rhs>
275 struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
276  : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
277 {
278  typedef Dense StorageKind;
279 };
280 } // end namespace internal
281 
// Dense-valued expression for (dense Lhs) * (sparse Rhs), built on ProductBase.
282 template<typename Lhs, typename Rhs>
// NOTE(review): original line 283 (the `class DenseTimeSparseProduct` head)
// is missing from this listing.
284  : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs>
285 {
286  public:
// NOTE(review): original line 287 is missing. `Base` and `Scalar` are used
// below without a visible declaration, so it presumably held the product
// public-interface macro -- confirm against upstream.
288 
289  DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
290  {}
291 
// Reuses the sparse * dense kernel through the identity (A*B)^T = B^T * A^T.
// The result is written into a Transpose view of dest, with the (transposed)
// sparse operand passed first.
// NOTE(review): original lines 294-295, which declare rhs_t and lhs_t, are
// missing from this listing. They are presumably transposed views of the
// stored rhs and lhs -- confirm against upstream.
292  template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
293  {
296  Transpose<Dest> dest_t(dest);
297  internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
298  }
299 
300  private:
// NOTE(review): the private member on original line 301 is missing from this
// listing.
302 };
303 
304 // sparse * dense
// Member-template operator* of the class template parameterized on Derived
// (presumably SparseMatrixBase<Derived> -- confirm). It returns the
// expression type chosen by SparseDenseProductReturnType<Derived,OtherDerived>,
// built from derived() and other.derived().
// NOTE(review): the declarator (original lines 307-308) is missing from this
// listing.
305 template<typename Derived>
306 template<typename OtherDerived>
309 {
310  return typename SparseDenseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
311 }
312 
313 } // end namespace Eigen
314 
315 #endif // EIGEN_SPARSEDENSEPRODUCT_H