10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11 #define EIGEN_TRIANGULARMATRIXVECTOR_H
// Forward declaration of the triangular matrix * vector kernel.
// Only the partial specializations on StorageOrder (ColMajor and RowMajor,
// defined in this file) provide an implementation; Version defaults to
// Specialized.
17 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int StorageOrder,
int Version=Specialized>
18 struct triangular_matrix_vector_product;
// Column-major specialization of the triangular matrix * vector kernel.
// run() accumulates alpha * lhs * rhs into res (res is updated with +=, not
// overwritten), where lhs is the triangular part of a _rows x _cols
// column-major matrix selected by Mode, and ConjLhs / ConjRhs optionally
// conjugate the operands.
// NOTE(review): this listing is an extract of the original header; braces, the
// enum header, the IsLower / HasUnitDiag definitions, the loop index `i` and
// several `if` guards are not visible here.
20 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int Version>
21 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,
ColMajor,Version>
23 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// Compile-time flag: Mode requests a zero diagonal, i.e. diagonal entries are
// excluded from the product.
27 HasZeroDiag = (Mode &
ZeroDiag)==ZeroDiag
// Parameters:
//   _rows, _cols     dimensions of the (possibly rectangular) triangular lhs
//   _lhs, lhsStride  column-major lhs data and its outer stride
//   _rhs, rhsIncr    rhs vector data and its element increment
//   _res, resIncr    result vector, accumulated into
//   alpha            scale factor applied to the product
29 static EIGEN_DONT_INLINE
void run(Index _rows, Index _cols,
const LhsScalar* _lhs, Index lhsStride,
30 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
// The diagonal is walked in panels of PanelWidth columns: the triangle inside
// a panel is handled with small column updates, the rectangular block outside
// it with a general matrix-vector product.
32 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
// size is the diagonal length. A lower lhs keeps all its rows but only `size`
// columns; an upper lhs keeps all its columns but only `size` rows.
33 Index size = (std::min)(_rows,_cols);
34 Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
35 Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
// Strided views over the raw buffers; cjLhs / cjRhs apply optional conjugation.
37 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
38 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
39 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
41 typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
42 const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
43 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
// NOTE(review): res is mapped without an InnerStride, so the in-panel updates
// below treat the result as contiguous; resIncr is only forwarded to the
// general_matrix_vector_product calls. The trmv_selector<ColMajor> caller in
// this file passes resIncr == 1.
45 typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
46 ResMap res(_res,rows);
// Loop over diagonal panels; the last panel may be narrower than PanelWidth.
48 for (Index pi=0; pi<size; pi+=PanelWidth)
50 Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
51 for (Index k=0; k<actualPanelWidth; ++k)
// `i` is presumably the current column pi+k (its declaration is not visible in
// this extract). Rows [s, s+r) of column i form that column's share of the
// panel's triangle: from the diagonal down to the panel end (lower), or from
// the panel start down to the diagonal (upper).
54 Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
55 Index r = IsLower ? actualPanelWidth-k : k+1;
// With a unit or zero diagonal the diagonal entry is excluded, so the range
// shrinks by one (r is only decremented in that case) and may become empty.
56 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
57 res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
// Contribution of an implicit unit diagonal entry; its guard is not visible
// here (presumably HasUnitDiag).
59 res.coeffRef(i) += alpha * cjRhs.coeff(i);
// Rectangular block of the panel's columns outside its triangle: r rows below
// the panel for a lower lhs, the pi rows above it for an upper lhs, starting
// at row s.
61 Index r = IsLower ? rows - pi - actualPanelWidth : pi;
64 Index s = IsLower ? pi+actualPanelWidth : 0;
65 general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
67 &lhs.coeffRef(s,pi), lhsStride,
68 &rhs.coeffRef(pi), rhsIncr,
69 &res.coeffRef(s), resIncr, alpha);
// An upper lhs with more columns than rows has cols-size trailing columns
// that are entirely off the diagonal; they are added with a single gemv.
72 if((!IsLower) && cols>size)
74 general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
76 &lhs.coeffRef(0,size), lhsStride,
77 &rhs.coeffRef(size), rhsIncr,
78 _res, resIncr, alpha);
// Row-major specialization of the triangular matrix * vector kernel.
// run() accumulates alpha * lhs * rhs into res (res is updated with +=), where
// lhs is the triangular part of a _rows x _cols row-major matrix selected by
// Mode; each result entry is built from dot products of lhs rows with rhs.
// NOTE(review): this listing is an extract of the original header; braces, the
// enum header, the IsLower / HasUnitDiag definitions, the loop index `i` and
// several `if` guards are not visible here.
83 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int Version>
84 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,
RowMajor,Version>
86 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// Compile-time flag: Mode requests a zero diagonal, i.e. diagonal entries are
// excluded from the product.
90 HasZeroDiag = (Mode &
ZeroDiag)==ZeroDiag
// Parameters mirror the ColMajor specialization: lhs data + outer stride,
// rhs data + increment, result data + increment, and the scale factor alpha.
// NOTE(review): unlike the ColMajor specialization, run is not marked
// EIGEN_DONT_INLINE — confirm whether that difference is intended.
92 static void run(Index _rows, Index _cols,
const LhsScalar* _lhs, Index lhsStride,
93 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
// The diagonal is walked in panels of PanelWidth rows.
95 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
// A lower lhs keeps all its rows but only diagSize columns; an upper lhs keeps
// all its columns but only diagSize rows.
96 Index diagSize = (std::min)(_rows,_cols);
97 Index rows = IsLower ? _rows : diagSize;
98 Index cols = IsLower ? diagSize : _cols;
100 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
101 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
102 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
// NOTE(review): rhs is mapped without an InnerStride, so the in-panel dot
// products treat it as contiguous while rhsIncr is only forwarded to the gemv
// calls. trmv_selector<RowMajor> in this file copies a non-unit-stride rhs into
// a contiguous buffer before calling this kernel.
104 typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
105 const RhsMap rhs(_rhs,cols);
106 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
// The result, by contrast, honours resIncr through an InnerStride.
108 typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
109 ResMap res(_res,rows,InnerStride<>(resIncr));
// Loop over diagonal panels; the last panel may be narrower than PanelWidth.
111 for (Index pi=0; pi<diagSize; pi+=PanelWidth)
113 Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
114 for (Index k=0; k<actualPanelWidth; ++k)
// `i` is presumably the current row pi+k (its declaration is not visible in
// this extract). Columns [s, s+r) of row i form that row's share of the
// panel's triangle: from the panel start to the diagonal (lower), or from the
// diagonal to the panel end (upper).
117 Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
118 Index r = IsLower ? k+1 : actualPanelWidth-k;
// With a unit or zero diagonal the diagonal entry is excluded, so the range
// shrinks by one (r is only decremented in that case) and may become empty.
119 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
120 res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
// Contribution of an implicit unit diagonal entry; its guard is not visible
// here (presumably HasUnitDiag).
122 res.coeffRef(i) += alpha * cjRhs.coeff(i);
// Rectangular block of the panel's rows outside its triangle: the pi columns
// to the left for a lower lhs, the cols-pi-actualPanelWidth columns to the
// right for an upper lhs, starting at column s.
124 Index r = IsLower ? pi : cols - pi - actualPanelWidth;
127 Index s = IsLower ? 0 : pi + actualPanelWidth;
128 general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
130 &lhs.coeffRef(pi,s), lhsStride,
131 &rhs.coeffRef(s), rhsIncr,
132 &res.coeffRef(pi), resIncr, alpha);
// A lower lhs with more rows than columns has rows-diagSize trailing rows that
// are entirely off the diagonal; they are added with a single gemv.
135 if(IsLower && rows>diagSize)
137 general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
139 &lhs.coeffRef(diagSize,0), lhsStride,
140 &rhs.coeffRef(0), rhsIncr,
141 &res.coeffRef(diagSize), resIncr, alpha);
// Traits for a TriangularProduct whose 4th / 6th boolean arguments are
// false / true (presumably "lhs is not a vector" / "rhs is a vector"; the
// TriangularProduct declaration is not visible here): they are inherited from
// the corresponding ProductBase traits. The struct body is not visible in
// this extract.
150 template<
int Mode,
bool LhsIsTriangular,
typename Lhs,
typename Rhs>
151 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
152 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
// Traits for a TriangularProduct whose 4th / 6th boolean arguments are
// true / false (presumably "lhs is a vector" / "rhs is not a vector"; the
// TriangularProduct declaration is not visible here): they are inherited from
// the corresponding ProductBase traits. The struct body is not visible in
// this extract.
155 template<
int Mode,
bool LhsIsTriangular,
typename Lhs,
typename Rhs>
156 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
157 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
// Dispatcher for evaluating a triangular matrix * vector product, selected by
// the storage order of the triangular operand. The ColMajor and RowMajor
// specializations in this file provide the implementations.
161 template<
int StorageOrder>
162 struct trmv_selector;
// Product expression: triangular matrix (Lhs, LhsIsTriangular == true) times
// Rhs. trmv_selector<…>::run in this file takes exactly this product type, so
// it is presumably what scaleAndAddTo dispatches to; the dispatch call itself
// is not visible in this extract.
166 template<
int Mode,
typename Lhs,
typename Rhs>
167 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
168 :
public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
170 EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
172 TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
// Presumably accumulates alpha * lhs * rhs into dst (the name follows Eigen's
// ProductBase convention; the body beyond the size check is not visible).
// dst must be lhs.rows() x rhs.cols().
174 template<
typename Dest>
void scaleAndAddTo(Dest& dst, Scalar alpha)
const
176 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
// Product expression: Lhs times a triangular matrix (Rhs; LhsIsTriangular ==
// false). It is evaluated through the transposed problem
// (lhs * rhs)^T = rhs^T * lhs^T, so the triangular-lhs path can be reused.
182 template<
int Mode,
typename Lhs,
typename Rhs>
183 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
184 :
public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
186 EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
188 TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
// dst must be lhs.rows() x rhs.cols().
190 template<
typename Dest>
void scaleAndAddTo(Dest& dst, Scalar alpha)
const
192 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
// Transposing a triangular matrix swaps Lower and Upper, so the Mode is
// flipped while the UnitDiag / ZeroDiag bits are kept; rhs^T becomes the
// triangular lhs of the transposed product.
194 typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,
false,Transpose<const Lhs>,
true> TriangularProductTranspose;
// The transposed product is written into a transposed view of dst.
195 Transpose<Dest> dstT(dst);
// NOTE(review): the opening of this dispatch call is not visible in this
// extract.
197 TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
// Evaluates a column-major triangular matrix * vector product into dest.
// Scalar factors are stripped from both operands and folded into alpha. If
// dest cannot be used directly, or alpha cannot be passed to the kernel
// (complex lhs times real rhs with a non-real alpha), the kernel writes into a
// temporary instead, and the temporary is merged back into dest afterwards.
205 template<>
struct trmv_selector<
ColMajor>
207 template<
int Mode,
typename Lhs,
typename Rhs,
typename Dest>
208 static void run(
const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest,
typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
210 typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
211 typedef typename ProductType::Index Index;
212 typedef typename ProductType::LhsScalar LhsScalar;
213 typedef typename ProductType::RhsScalar RhsScalar;
214 typedef typename ProductType::Scalar ResScalar;
215 typedef typename ProductType::RealScalar RealScalar;
216 typedef typename ProductType::ActualLhsType ActualLhsType;
217 typedef typename ProductType::ActualRhsType ActualRhsType;
218 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
219 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
// Contiguous, aligned view used for the temporary result buffer.
220 typedef Map<Matrix<ResScalar,Dynamic,1>,
Aligned> MappedDest;
// Underlying operands as seen by the BLAS-like kernel, plus the scalar
// factors folded into the effective alpha.
222 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
223 typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
225 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
226 * RhsBlasTraits::extractScalarFactor(prod.rhs());
// EvalToDestAtCompileTime: dest has unit inner stride, which is what the kernel
// call below (resIncr == 1) requires.
// ComplexByReal: complex lhs times real rhs; the kernel's alpha is then
// RhsScalar (real), see compatibleAlpha below.
// MightCannotUseDest: a temporary result buffer may be needed.
231 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
232 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
233 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
// Optional fixed-size buffer for the temporary result, enabled only when
// MightCannotUseDest holds (presumably empty otherwise — confirm).
236 gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
// alpha can be handed to the kernel only if it is representable as RhsScalar,
// i.e. unless this is the complex-by-real case with a non-zero imaginary part.
238 bool alphaIsCompatible = (!ComplexByReal) || (
imag(actualAlpha)==RealScalar(0));
239 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
241 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
// actualDestPtr points at dest itself when evaluating in place, otherwise at a
// temporary (presumably the static buffer when available, else an aligned
// stack/heap allocation made by this macro).
243 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
244 evalToDest ? dest.data() : static_dest.data());
// Expose `size` to the optional EIGEN_DENSE_STORAGE_CTOR_PLUGIN hook.
// NOTE(review): dest.size() is presumably Index-typed; storing it in an int may
// narrow for very large vectors. The matching #endif is not visible here.
248 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
249 int size = dest.size();
250 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// When the temporary is used (the enclosing guard is not visible here):
// - incompatible alpha: start from zero and run the kernel with alpha 1; the
//   real scaling by actualAlpha is applied when merging back into dest;
// - otherwise: seed the temporary with the current dest values so that the
//   kernel's accumulation matches an in-place evaluation.
252 if(!alphaIsCompatible)
254 MappedDest(actualDestPtr, dest.size()).setZero();
255 compatibleAlpha = RhsScalar(1);
258 MappedDest(actualDestPtr, dest.size()) = dest;
// Run the column-major kernel on the extracted operands; the result is
// accumulated with unit stride into actualDestPtr.
261 internal::triangular_matrix_vector_product
263 LhsScalar, LhsBlasTraits::NeedToConjugate,
264 RhsScalar, RhsBlasTraits::NeedToConjugate,
266 ::run(actualLhs.rows(),actualLhs.cols(),
267 actualLhs.data(),actualLhs.outerStride(),
268 actualRhs.data(),actualRhs.innerStride(),
269 actualDestPtr,1,compatibleAlpha);
// Merge the temporary back (guard for the temporary case not visible here):
// scale-and-add when alpha had to be deferred, plain copy otherwise.
273 if(!alphaIsCompatible)
274 dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
276 dest = MappedDest(actualDestPtr, dest.size());
// Evaluates a row-major triangular matrix * vector product into dest.
// Scalar factors are stripped from both operands and folded into alpha. The
// rhs is copied into a contiguous buffer when its inner stride is not 1 at
// compile time. dest is passed to the kernel with its own inner stride.
// NOTE(review): the tail of this function (remaining kernel arguments and
// closing lines) is not visible in this extract.
281 template<>
struct trmv_selector<
RowMajor>
283 template<
int Mode,
typename Lhs,
typename Rhs,
typename Dest>
284 static void run(
const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest,
typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
286 typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
287 typedef typename ProductType::LhsScalar LhsScalar;
288 typedef typename ProductType::RhsScalar RhsScalar;
289 typedef typename ProductType::Scalar ResScalar;
290 typedef typename ProductType::Index Index;
291 typedef typename ProductType::ActualLhsType ActualLhsType;
292 typedef typename ProductType::ActualRhsType ActualRhsType;
293 typedef typename ProductType::_ActualRhsType _ActualRhsType;
294 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
295 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
// Underlying operands as seen by the BLAS-like kernel.
// NOTE(review): the ColMajor selector uses internal::add_const_on_value_type
// for these two declarations; confirm the plain add_const here is intended.
297 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
298 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
// Effective alpha including the scalar factors extracted from both operands
// (its use in the kernel call is not visible in this extract).
300 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
301 * RhsBlasTraits::extractScalarFactor(prod.rhs());
// The rhs can be used as-is only when its inner stride is 1 at compile time.
304 DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
// Optional fixed-size buffer for a packed copy of rhs, enabled only when rhs
// cannot be used directly.
307 gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
// actualRhsPtr aliases the rhs data when it can be used directly (the
// const_cast only adapts the pointer type for this macro), otherwise it points
// at a contiguous temporary.
309 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
310 DirectlyUseRhs ?
const_cast<RhsScalar*
>(actualRhs.data()) : static_rhs.data());
// Expose `size` to the optional EIGEN_DENSE_STORAGE_CTOR_PLUGIN hook.
// NOTE(review): actualRhs.size() is presumably Index-typed; storing it in an
// int may narrow for very large vectors. The matching #endif is not visible.
314 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
315 int size = actualRhs.size();
316 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// Pack the strided rhs into the contiguous temporary (the guard is not visible
// here; presumably only when !DirectlyUseRhs).
318 Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
// Run the row-major kernel; dest is passed with its own inner stride, which
// the RowMajor kernel supports for the result vector.
321 internal::triangular_matrix_vector_product
323 LhsScalar, LhsBlasTraits::NeedToConjugate,
324 RhsScalar, RhsBlasTraits::NeedToConjugate,
326 ::run(actualLhs.rows(),actualLhs.cols(),
327 actualLhs.data(),actualLhs.outerStride(),
329 dest.data(),dest.innerStride(),
338 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H