25 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
26 #define EIGEN_TRIANGULARMATRIXVECTOR_H
// Declaration header of the generic triangular-matrix * vector kernel template.
// Template parameters:
//   - the Index type;
//   - the triangular Mode flags;
//   - the lhs/rhs scalar types, each with its conjugation flag;
//   - the lhs StorageOrder;
//   - a Version that defaults to Specialized.
// NOTE(review): the declared name and the terminating ';' are not visible in this
// excerpt. The two kernels that follow look like its column-major and row-major
// specializations (their lhs Maps are ColMajor / RowMajor) — confirm against the full file.
32 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int StorageOrder,
int Version=Specialized>
// Column-major triangular matrix * vector kernel (the lhs is mapped as ColMajor below).
// run() accumulates alpha * op(lhs) * op(rhs) into res.
//   - op() conjugates its operand when ConjLhs / ConjRhs is set.
//   - Only the triangle selected by Mode is read from the lhs.
35 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int Version>
38 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// HasZeroDiag: Mode requests an implicit zero diagonal, so diagonal entries are not read.
42 HasZeroDiag = (Mode &
ZeroDiag)==ZeroDiag
// Arguments:
//   - _rows x _cols: dimensions of the (possibly trapezoidal) lhs stored at _lhs
//     with outer stride lhsStride.
//   - _rhs: read with increment rhsIncr.
//   - _res: accumulated into. resIncr is only forwarded to the
//     rectangular-block calls further down.
44 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols,
const LhsScalar* _lhs, Index lhsStride,
45 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
// size is the length of the diagonal.
//   - Lower triangle: all _rows rows are kept and the columns are clipped to the diagonal.
//   - Upper triangle: the rows are clipped and all _cols columns are kept.
48 Index size = (std::min)(_rows,_cols);
49 Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
50 Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
// Wrap the raw buffers in Eigen Maps; cjLhs / cjRhs are conjugated views when requested.
52 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
53 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
54 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
56 typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
57 const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
58 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
// NOTE(review): res is mapped as a contiguous vector even though a resIncr parameter
// exists. The column-major trmv_selector in this file passes a resIncr of 1.
60 typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
61 ResMap res(_res,rows);
// Walk the diagonal in panels of at most PanelWidth columns.
//   - Inside a panel, the triangular part is applied column by column (axpy updates).
//   - The rectangular block beside the panel is then handled in one call.
63 for (Index pi=0; pi<size; pi+=PanelWidth)
65 Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
66 for (Index k=0; k<actualPanelWidth; ++k)
// For column i (presumably pi+k; its definition is not visible here):
//   - s is the first row of that column's triangular part inside the panel.
//   - r is the number of rows of that part.
// With a unit or zero diagonal, r is decremented so the diagonal entry is skipped.
// If no rows remain, the update is skipped entirely.
69 Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
70 Index r = IsLower ? actualPanelWidth-k : k+1;
71 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
72 res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
// Adds the implicit unit-diagonal contribution alpha*rhs(i) to res(i).
// NOTE(review): the condition guarding this line is not visible in this excerpt.
// It should apply only for UnitDiag (not ZeroDiag) — confirm.
74 res.coeffRef(i) += alpha * cjRhs.coeff(i);
// Rectangular block beside the current panel, starting at lhs column pi and
// rhs entry pi:
//   - Lower triangle: the r rows below the panel, starting at s = pi+actualPanelWidth.
//   - Upper triangle: the r = pi rows above it, starting at s = 0.
// The routine invoked with these arguments is not visible in this excerpt.
76 Index r = IsLower ? rows - pi - actualPanelWidth : pi;
79 Index s = IsLower ? pi+actualPanelWidth : 0;
82 &lhs.coeffRef(s,pi), lhsStride,
83 &rhs.coeffRef(pi), rhsIncr,
84 &res.coeffRef(s), resIncr, alpha);
// Upper trapezoidal case: columns size.._cols-1 lie entirely right of the triangle.
// They are applied with the rhs entries from index size onward and accumulated
// directly into _res.
87 if((!IsLower) && cols>size)
91 &lhs.coeffRef(0,size), lhsStride,
92 &rhs.coeffRef(size), rhsIncr,
93 _res, resIncr, alpha);
// Row-major triangular matrix * vector kernel (the lhs is mapped as RowMajor below).
// run() accumulates alpha * op(lhs) * op(rhs) into res.
//   - op() conjugates its operand when ConjLhs / ConjRhs is set.
//   - Each result entry's triangular part is a dot product of an lhs row segment
//     with the matching rhs segment.
98 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int Version>
101 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// HasZeroDiag: Mode requests an implicit zero diagonal, so diagonal entries are not read.
105 HasZeroDiag = (Mode &
ZeroDiag)==ZeroDiag
// Arguments:
//   - _rows x _cols: dimensions of the (possibly trapezoidal) lhs stored at _lhs
//     with outer stride lhsStride.
//   - _res: written with increment resIncr.
//   - _rhs: mapped contiguously (see the note on RhsMap).
107 static void run(Index _rows, Index _cols,
const LhsScalar* _lhs, Index lhsStride,
108 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
// diagSize is the length of the diagonal.
//   - Lower triangle: all _rows rows are kept and the columns are clipped to it.
//   - Upper triangle: the rows are clipped and all _cols columns are kept.
111 Index diagSize = (std::min)(_rows,_cols);
112 Index rows = IsLower ? _rows : diagSize;
113 Index cols = IsLower ? diagSize : _cols;
// Wrap the raw buffers in Eigen Maps; cjLhs / cjRhs are conjugated views when requested.
115 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
116 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
117 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
// NOTE(review): rhs is mapped without a stride even though rhsIncr is a parameter.
// rhsIncr is only forwarded to the rectangular-block calls below. The row-major
// trmv_selector in this file copies a strided rhs into a contiguous buffer before
// invoking triangular_matrix_vector_product.
119 typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
120 const RhsMap rhs(_rhs,cols);
121 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
123 typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
124 ResMap res(_res,rows,InnerStride<>(resIncr));
// Walk the diagonal in panels of at most PanelWidth rows.
//   - Inside a panel, the triangular part is applied row by row (dot products).
//   - The rectangular block beside the panel is then handled in one call.
126 for (Index pi=0; pi<diagSize; pi+=PanelWidth)
128 Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
129 for (Index k=0; k<actualPanelWidth; ++k)
// For row i (presumably pi+k; its definition is not visible here):
//   - s is the first column of that row's triangular part inside the panel.
//   - r is the number of columns of that part.
// With a unit or zero diagonal, r is decremented so the diagonal entry is skipped.
// If no columns remain, the dot product is skipped entirely.
132 Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
133 Index r = IsLower ? k+1 : actualPanelWidth-k;
134 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
135 res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
// Adds the implicit unit-diagonal contribution alpha*rhs(i) to res(i).
// NOTE(review): the condition guarding this line is not visible in this excerpt.
// It should apply only for UnitDiag (not ZeroDiag) — confirm.
137 res.coeffRef(i) += alpha * cjRhs.coeff(i);
// Rectangular block beside the current panel, for the panel's rows starting at pi:
//   - Lower triangle: the r = pi columns to its left, starting at s = 0.
//   - Upper triangle: the r columns to its right, starting at s = pi+actualPanelWidth.
// The routine invoked with these arguments is not visible in this excerpt.
139 Index r = IsLower ? pi : cols - pi - actualPanelWidth;
142 Index s = IsLower ? 0 : pi + actualPanelWidth;
145 &lhs.coeffRef(pi,s), lhsStride,
146 &rhs.coeffRef(s), rhsIncr,
147 &res.coeffRef(pi), resIncr, alpha);
// Lower trapezoidal case: rows diagSize.._rows-1 lie entirely below the triangle.
// They are applied against the rhs from index 0 and accumulated into res
// starting at row diagSize.
150 if(IsLower && rows>diagSize)
154 &lhs.coeffRef(diagSize,0), lhsStride,
155 &rhs.coeffRef(0), rhsIncr,
156 &res.coeffRef(diagSize), resIncr, alpha);
// traits for a TriangularProduct whose 4th/6th template flags are false/true.
// This is presumably a matrix times a vector; the flag meaning is inferred from context.
// The traits are inherited from the corresponding ProductBase traits.
165 template<
int Mode,
bool LhsIsTriangular,
typename Lhs,
typename Rhs>
166 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
167 :
traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
// traits for a TriangularProduct whose 4th/6th template flags are true/false.
// This is presumably a vector times a matrix; the flag meaning is inferred from context.
// The traits are inherited from the corresponding ProductBase traits.
170 template<
int Mode,
bool LhsIsTriangular,
typename Lhs,
typename Rhs>
171 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
172 :
traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
// Dispatcher for triangular matrix * vector products. It is specialized later in this
// file on the lhs storage order (ColMajor / RowMajor).
176 template<
int StorageOrder>
177 struct trmv_selector;
// Product of a triangular lhs (LhsIsTriangular = true) with a dense rhs.
// NOTE(review): the 4th/6th boolean template arguments (false/true) are presumably
// "lhs is a vector" / "rhs is a vector" — inferred from context, confirm.
181 template<
int Mode,
typename Lhs,
typename Rhs>
182 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
183 :
public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
// Forwards both operands to the ProductBase.
187 TriangularProduct(const Lhs& lhs, const Rhs& rhs) :
Base(lhs,rhs) {}
// Requires dst to have lhs.rows() rows and rhs.cols() columns.
// Presumably performs dst += alpha * lhs * rhs; the body past the size check is not
// visible in this excerpt.
189 template<
typename Dest>
void scaleAndAddTo(Dest& dst,
Scalar alpha)
const
191 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
// Product of a dense lhs with a triangular rhs (LhsIsTriangular = false).
// It is evaluated through the triangular-lhs product by transposing both operands.
197 template<
int Mode,
typename Lhs,
typename Rhs>
198 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
199 :
public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
// Forwards both operands to the ProductBase.
203 TriangularProduct(const Lhs& lhs, const Rhs& rhs) :
Base(lhs,rhs) {}
// Requires dst to have lhs.rows() rows and rhs.cols() columns.
205 template<
typename Dest>
void scaleAndAddTo(Dest& dst,
Scalar alpha)
const
207 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
// Uses lhs * rhs == (rhs^T * lhs^T)^T.
// Transposing the triangular factor swaps Lower and Upper.
// The UnitDiag / ZeroDiag bits of Mode are kept unchanged.
209 typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,
false,
Transpose<const Lhs>,
true> TriangularProductTranspose;
// dstT is declared in lines not visible here; presumably it is a transposed view of dst.
212 TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
// trmv_selector for a column-major triangular lhs.
// run() prepares the operands and a destination buffer, then calls
// internal::triangular_matrix_vector_product.
220 template<>
struct trmv_selector<
ColMajor>
222 template<
int Mode,
typename Lhs,
typename Rhs,
typename Dest>
223 static void run(
const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest,
typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
// Shorthands for the product's index, scalar, expression and BLAS-traits types.
225 typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
226 typedef typename ProductType::Index Index;
227 typedef typename ProductType::LhsScalar LhsScalar;
228 typedef typename ProductType::RhsScalar RhsScalar;
229 typedef typename ProductType::Scalar ResScalar;
230 typedef typename ProductType::RealScalar RealScalar;
231 typedef typename ProductType::ActualLhsType ActualLhsType;
232 typedef typename ProductType::ActualRhsType ActualRhsType;
233 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
234 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
// Views a raw ResScalar buffer as an aligned dense vector.
235 typedef Map<Matrix<ResScalar,Dynamic,1>,
Aligned> MappedDest;
// The BLAS traits extract the underlying operands.
// Their scalar factors are folded into actualAlpha.
// Conjugation is passed to the kernel as NeedToConjugate.
237 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
238 typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
240 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
241 * RhsBlasTraits::extractScalarFactor(prod.rhs());
// The kernel is called with a result increment of 1, so it can write straight into
// dest only when dest has unit inner stride. ComplexByReal (defined in a line not
// visible here) also makes the temporary-buffer path necessary.
246 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
248 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
// Temporary result storage for when the result cannot be written into dest directly.
// It is presumably only materialized when MightCannotUseDest is true.
251 gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
// The kernel's alpha is converted to RhsScalar. In the ComplexByReal case this is
// only valid when actualAlpha has no imaginary part. Otherwise:
//   - the kernel runs with alpha = 1 into a zeroed temporary;
//   - actualAlpha is applied afterwards.
253 bool alphaIsCompatible = (!ComplexByReal) || (
imag(actualAlpha)==RealScalar(0));
254 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
256 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
// actualDestPtr (declaration start not visible here) points at dest's own storage
// when evalToDest is true, otherwise at the temporary buffer.
259 evalToDest ? dest.data() : static_dest.data());
// When the storage-constructor plugin is enabled, a local `size` is defined for the
// plugin macro (presumably the plugin refers to it).
263 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
264 int size = dest.size();
265 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
267 if(!alphaIsCompatible)
269 MappedDest(actualDestPtr, dest.size()).setZero();
270 compatibleAlpha = RhsScalar(1);
// Seeds the temporary with dest's current values so the kernel accumulates onto them.
// Presumably reached only when the result is not evaluated in place; the guarding
// condition is not visible here.
273 MappedDest(actualDestPtr, dest.size()) = dest;
// Runs the kernel on the extracted operands, writing contiguously (increment 1)
// into actualDestPtr.
276 internal::triangular_matrix_vector_product
278 LhsScalar, LhsBlasTraits::NeedToConjugate,
279 RhsScalar, RhsBlasTraits::NeedToConjugate,
281 ::run(actualLhs.rows(),actualLhs.cols(),
282 actualLhs.data(),actualLhs.outerStride(),
283 actualRhs.data(),actualRhs.innerStride(),
284 actualDestPtr,1,compatibleAlpha);
// Write-back:
//   - With an incompatible alpha, the temporary is scaled by actualAlpha and added to dest.
//   - Otherwise the temporary is copied back into dest. This presumably happens only
//     when a temporary was used; that guard is not visible here.
288 if(!alphaIsCompatible)
289 dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
291 dest = MappedDest(actualDestPtr, dest.size());
// trmv_selector for a row-major triangular lhs.
// The row-major kernel in this file maps its rhs as a contiguous vector, so a strided
// rhs is first copied into a contiguous buffer. The result is written with dest's
// own inner stride.
296 template<>
struct trmv_selector<
RowMajor>
298 template<
int Mode,
typename Lhs,
typename Rhs,
typename Dest>
299 static void run(
const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest,
typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
// Shorthands for the product's scalar, index, expression and BLAS-traits types.
301 typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
302 typedef typename ProductType::LhsScalar LhsScalar;
303 typedef typename ProductType::RhsScalar RhsScalar;
304 typedef typename ProductType::Scalar ResScalar;
305 typedef typename ProductType::Index Index;
306 typedef typename ProductType::ActualLhsType ActualLhsType;
307 typedef typename ProductType::ActualRhsType ActualRhsType;
308 typedef typename ProductType::_ActualRhsType _ActualRhsType;
309 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
310 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
// NOTE(review): the column-major selector in this file uses
// internal::add_const_on_value_type for these same declarations. If ActualLhsType /
// ActualRhsType are reference types, add_const presumably leaves them unchanged
// rather than const-qualifying the referred value. Consider aligning the two.
312 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
313 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
// The operands' extracted scalar factors are folded into actualAlpha.
315 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
316 * RhsBlasTraits::extractScalarFactor(prod.rhs());
// DirectlyUseRhs: the rhs expression already has unit inner stride, so its data can
// be handed to the kernel as-is.
319 DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
// Temporary contiguous rhs storage, used when DirectlyUseRhs is false.
322 gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
// actualRhsPtr (declaration start not visible here) points at the rhs data itself or
// at static_rhs. The const_cast gives both branches of the conditional the same
// non-const pointer type.
325 DirectlyUseRhs ?
const_cast<RhsScalar*
>(actualRhs.data()) : static_rhs.data());
// When the storage-constructor plugin is enabled, a local `size` is defined for the
// plugin macro (presumably the plugin refers to it).
329 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
330 int size = actualRhs.size();
331 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// Copies the (strided) rhs into the contiguous buffer.
// NOTE(review): this is expected to run only when !DirectlyUseRhs; otherwise it would
// write through the const_cast pointer into the caller's rhs. The guarding condition
// is not visible in this excerpt — confirm.
333 Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
// Runs the kernel, writing the result with dest's own inner stride, so no temporary
// destination is needed.
// NOTE(review): the rhs arguments and the trailing alpha argument of this call are not
// visible in this excerpt.
336 internal::triangular_matrix_vector_product
338 LhsScalar, LhsBlasTraits::NeedToConjugate,
339 RhsScalar, RhsBlasTraits::NeedToConjugate,
341 ::run(actualLhs.rows(),actualLhs.cols(),
342 actualLhs.data(),actualLhs.outerStride(),
344 dest.data(),dest.innerStride(),
353 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H