SparseCwiseBinaryOp.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // Eigen is free software; you can redistribute it and/or
7 // modify it under the terms of the GNU Lesser General Public
8 // License as published by the Free Software Foundation; either
9 // version 3 of the License, or (at your option) any later version.
10 //
11 // Alternatively, you can redistribute it and/or
12 // modify it under the terms of the GNU General Public License as
13 // published by the Free Software Foundation; either version 2 of
14 // the License, or (at your option) any later version.
15 //
16 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
17 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
18 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
19 // GNU General Public License for more details.
20 //
21 // You should have received a copy of the GNU Lesser General Public
22 // License and a copy of the GNU General Public License along with
23 // Eigen. If not, see <http://www.gnu.org/licenses/>.
24 
25 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H
26 #define EIGEN_SPARSE_CWISE_BINARY_OP_H
27 
28 namespace Eigen {
29 
// Here we have to handle 3 cases:
//  1 - sparse op dense
//  2 - dense op sparse
//  3 - sparse op sparse
// We also need to implement a 4th iterator for:
//  4 - dense op dense
// Finally, we also need to distinguish between the product and other operations:
//                         configuration    returned mode
//  1 - sparse op dense    product          sparse
//                         generic          dense
//  2 - dense op sparse    product          sparse
//                         generic          dense
//  3 - sparse op sparse   product          sparse
//                         generic          sparse
//  4 - dense op dense     product          dense
//                         generic          dense
46 
47 namespace internal {
48 
49 template<> struct promote_storage_type<Dense,Sparse>
50 { typedef Sparse ret; };
51 
52 template<> struct promote_storage_type<Sparse,Dense>
53 { typedef Sparse ret; };
54 
55 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived,
56  typename _LhsStorageMode = typename traits<Lhs>::StorageKind,
57  typename _RhsStorageMode = typename traits<Rhs>::StorageKind>
58 class sparse_cwise_binary_op_inner_iterator_selector;
59 
60 } // end namespace internal
61 
// Sparse specialization of the CwiseBinaryOp implementation base: it makes a
// coefficient-wise binary expression with Sparse storage a SparseMatrixBase
// and declares the nested iterator types used to walk its entries.
// NOTE(review): listing lines 69-71 and 75 are missing from this view; they
// presumably hold the public-interface typedefs, the constructor head and the
// opening of the static assertion below -- confirm against the full header.
62 template<typename BinaryOp, typename Lhs, typename Rhs>
63 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse>
64  : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
65 {
66  public:
67  class InnerIterator;
// Declared only; no definition of ReverseInnerIterator is visible in this part of the file.
68  class ReverseInnerIterator;
72  {
73  typedef typename internal::traits<Lhs>::StorageKind LhsStorageKind;
74  typedef typename internal::traits<Rhs>::StorageKind RhsStorageKind;
// Compile-time check: when both operands share the same storage kind, their
// storage orders (RowMajorBit) must agree; mixed storage kinds are exempt.
76  (!internal::is_same<LhsStorageKind,RhsStorageKind>::value)
77  || ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))),
78  THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH);
79  }
80 };
81 
82 template<typename BinaryOp, typename Lhs, typename Rhs>
83 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator
84  : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator>
85 {
86  public:
87  typedef typename Lhs::Index Index;
88  typedef internal::sparse_cwise_binary_op_inner_iterator_selector<
89  BinaryOp,Lhs,Rhs, InnerIterator> Base;
90 
91  EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename CwiseBinaryOpImpl::Index outer)
92  : Base(binOp.derived(),outer)
93  {}
94 };
95 
96 /***************************************************************************
97 * Implementation of inner-iterators
98 ***************************************************************************/
99 
100 // template<typename T> struct internal::func_is_conjunction { enum { ret = false }; };
101 // template<typename T> struct internal::func_is_conjunction<internal::scalar_product_op<T> > { enum { ret = true }; };
102 
103 // TODO generalize the internal::scalar_product_op specialization to all conjunctions if any !
104 
105 namespace internal {
106 
107 // sparse - sparse (generic)
108 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived>
109 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse>
110 {
111  typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr;
112  typedef typename traits<CwiseBinaryXpr>::Scalar Scalar;
113  typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
114  typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
115  typedef typename _LhsNested::InnerIterator LhsIterator;
116  typedef typename _RhsNested::InnerIterator RhsIterator;
117  typedef typename Lhs::Index Index;
118 
119  public:
120 
121  EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
122  : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
123  {
124  this->operator++();
125  }
126 
127  EIGEN_STRONG_INLINE Derived& operator++()
128  {
129  if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index()))
130  {
131  m_id = m_lhsIter.index();
132  m_value = m_functor(m_lhsIter.value(), m_rhsIter.value());
133  ++m_lhsIter;
134  ++m_rhsIter;
135  }
136  else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index())))
137  {
138  m_id = m_lhsIter.index();
139  m_value = m_functor(m_lhsIter.value(), Scalar(0));
140  ++m_lhsIter;
141  }
142  else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index())))
143  {
144  m_id = m_rhsIter.index();
145  m_value = m_functor(Scalar(0), m_rhsIter.value());
146  ++m_rhsIter;
147  }
148  else
149  {
150  m_value = 0; // this is to avoid a compilation warning
151  m_id = -1;
152  }
153  return *static_cast<Derived*>(this);
154  }
155 
156  EIGEN_STRONG_INLINE Scalar value() const { return m_value; }
157 
158  EIGEN_STRONG_INLINE Index index() const { return m_id; }
159  EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); }
160  EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); }
161 
162  EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; }
163 
164  protected:
165  LhsIterator m_lhsIter;
166  RhsIterator m_rhsIter;
167  const BinaryOp& m_functor;
168  Scalar m_value;
169  Index m_id;
170 };
171 
172 // sparse - sparse (product)
173 template<typename T, typename Lhs, typename Rhs, typename Derived>
174 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse>
175 {
176  typedef scalar_product_op<T> BinaryFunc;
177  typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
178  typedef typename CwiseBinaryXpr::Scalar Scalar;
179  typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
180  typedef typename _LhsNested::InnerIterator LhsIterator;
181  typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
182  typedef typename _RhsNested::InnerIterator RhsIterator;
183  typedef typename Lhs::Index Index;
184  public:
185 
186  EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
187  : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
188  {
189  while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
190  {
191  if (m_lhsIter.index() < m_rhsIter.index())
192  ++m_lhsIter;
193  else
194  ++m_rhsIter;
195  }
196  }
197 
198  EIGEN_STRONG_INLINE Derived& operator++()
199  {
200  ++m_lhsIter;
201  ++m_rhsIter;
202  while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
203  {
204  if (m_lhsIter.index() < m_rhsIter.index())
205  ++m_lhsIter;
206  else
207  ++m_rhsIter;
208  }
209  return *static_cast<Derived*>(this);
210  }
211 
212  EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); }
213 
214  EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
215  EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
216  EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
217 
218  EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); }
219 
220  protected:
221  LhsIterator m_lhsIter;
222  RhsIterator m_rhsIter;
223  const BinaryFunc& m_functor;
224 };
225 
226 // sparse - dense (product)
227 template<typename T, typename Lhs, typename Rhs, typename Derived>
228 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense>
229 {
230  typedef scalar_product_op<T> BinaryFunc;
231  typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
232  typedef typename CwiseBinaryXpr::Scalar Scalar;
233  typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
234  typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested;
235  typedef typename _LhsNested::InnerIterator LhsIterator;
236  typedef typename Lhs::Index Index;
237  enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit };
238  public:
239 
240  EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
241  : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer)
242  {}
243 
244  EIGEN_STRONG_INLINE Derived& operator++()
245  {
246  ++m_lhsIter;
247  return *static_cast<Derived*>(this);
248  }
249 
250  EIGEN_STRONG_INLINE Scalar value() const
251  { return m_functor(m_lhsIter.value(),
252  m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); }
253 
254  EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
255  EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
256  EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
257 
258  EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; }
259 
260  protected:
261  RhsNested m_rhs;
262  LhsIterator m_lhsIter;
263  const BinaryFunc m_functor;
264  const Index m_outer;
265 };
266 
// dense - sparse (product)
268 template<typename T, typename Lhs, typename Rhs, typename Derived>
269 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse>
270 {
271  typedef scalar_product_op<T> BinaryFunc;
272  typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
273  typedef typename CwiseBinaryXpr::Scalar Scalar;
274  typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
275  typedef typename _RhsNested::InnerIterator RhsIterator;
276  typedef typename Lhs::Index Index;
277 
278  enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit };
279  public:
280 
281  EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
282  : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer)
283  {}
284 
285  EIGEN_STRONG_INLINE Derived& operator++()
286  {
287  ++m_rhsIter;
288  return *static_cast<Derived*>(this);
289  }
290 
291  EIGEN_STRONG_INLINE Scalar value() const
292  { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); }
293 
294  EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); }
295  EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); }
296  EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); }
297 
298  EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; }
299 
300  protected:
301  const CwiseBinaryXpr& m_xpr;
302  RhsIterator m_rhsIter;
303  const BinaryFunc& m_functor;
304  const Index m_outer;
305 };
306 
307 } // end namespace internal
308 
309 /***************************************************************************
310 * Implementation of SparseMatrixBase and SparseCwise functions/operators
311 ***************************************************************************/
312 
// NOTE(review): the signature line (listing line 316) is missing from this
// view. From the body, this is a member of the Derived-templated base that
// assigns `derived() - other.derived()` back into *this and returns the result
// of that assignment -- presumably SparseMatrixBase<Derived>::operator-=;
// confirm against the full header.
313 template<typename Derived>
314 template<typename OtherDerived>
315 EIGEN_STRONG_INLINE Derived &
317 {
318  return *this = derived() - other.derived();
319 }
320 
// NOTE(review): the signature line (listing line 324) is missing from this
// view. From the body, this is a member of the Derived-templated base that
// assigns `derived() + other.derived()` back into *this and returns the result
// of that assignment -- presumably SparseMatrixBase<Derived>::operator+=;
// confirm against the full header.
321 template<typename Derived>
322 template<typename OtherDerived>
323 EIGEN_STRONG_INLINE Derived &
325 {
326  return *this = derived() + other.derived();
327 }
328 
// NOTE(review): listing lines 331-332 (return type and signature) are missing
// from this view. The body builds the expression produced by
// EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE from derived() and other.derived() --
// presumably the coefficient-wise product member (cwiseProduct); confirm
// against the full header.
329 template<typename Derived>
330 template<typename OtherDerived>
333 {
334  return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived());
335 }
336 
337 } // end namespace Eigen
338 
339 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H