25 #ifndef EIGEN_PASTIXSUPPORT_H
26 #define EIGEN_PASTIXSUPPORT_H
// Forward declarations of the PaStiX solver front-ends. Only the LU variant
// carries a default template argument: IsStrSym is false unless requested.
template <class MatrixT, bool IsStrSym = false>
class PastixLU;

template <class MatrixT, int Options>
class PastixLLT;

template <class MatrixT, int Options>
class PastixLDLT;

// Maps a solver type to its matrix / scalar / index types; the per-solver
// partial specializations follow.
template <class SolverT>
struct pastix_traits;
48 template<
typename _MatrixType>
49 struct pastix_traits< PastixLU<_MatrixType> >
51 typedef _MatrixType MatrixType;
52 typedef typename _MatrixType::Scalar Scalar;
53 typedef typename _MatrixType::RealScalar RealScalar;
54 typedef typename _MatrixType::Index Index;
57 template<
typename _MatrixType,
int Options>
58 struct pastix_traits< PastixLLT<_MatrixType,Options> >
60 typedef _MatrixType MatrixType;
61 typedef typename _MatrixType::Scalar Scalar;
62 typedef typename _MatrixType::RealScalar RealScalar;
63 typedef typename _MatrixType::Index Index;
66 template<
typename _MatrixType,
int Options>
67 struct pastix_traits< PastixLDLT<_MatrixType,Options> >
69 typedef _MatrixType MatrixType;
70 typedef typename _MatrixType::Scalar Scalar;
71 typedef typename _MatrixType::RealScalar RealScalar;
72 typedef typename _MatrixType::Index Index;
75 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
float *vals,
int *perm,
int * invp,
float *x,
int nbrhs,
int *iparm,
double *dparm)
77 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
78 if (nbrhs == 0) x = NULL;
79 s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
82 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
double *vals,
int *perm,
int * invp,
double *x,
int nbrhs,
int *iparm,
double *dparm)
84 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
85 if (nbrhs == 0) x = NULL;
86 d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
89 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx, std::complex<float> *vals,
int *perm,
int * invp, std::complex<float> *x,
int nbrhs,
int *iparm,
double *dparm)
91 c_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<COMPLEX*>(vals), perm, invp, reinterpret_cast<COMPLEX*>(x), nbrhs, iparm, dparm);
94 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx, std::complex<double> *vals,
int *perm,
int * invp, std::complex<double> *x,
int nbrhs,
int *iparm,
double *dparm)
96 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
97 if (nbrhs == 0) x = NULL;
98 z_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<DCOMPLEX*>(vals), perm, invp, reinterpret_cast<DCOMPLEX*>(x), nbrhs, iparm, dparm);
// NOTE(review): fragment. The function's signature and the declaration of `i`
// (original lines 103-107) are not visible in this view.
// Adds 1 to every outer index (entries 0..rows()) and every inner index
// (entries 0..nonZeros()-1) of `mat`, i.e. shifts 0-based indices to 1-based.
// This only happens when the first outer index is 0; otherwise `mat` is left unchanged.
102 template <
typename MatrixType>
105 if ( !(mat.outerIndexPtr()[0]) )
// NOTE(review): the outer index array has outerSize()+1 entries. Bounding by rows()
// is only right when rows() == outerSize() (e.g. square matrices). Confirm callers guarantee this.
108 for(i = 0; i <= mat.rows(); ++i)
109 ++mat.outerIndexPtr()[i];
110 for(i = 0; i < mat.nonZeros(); ++i)
111 ++mat.innerIndexPtr()[i];
// NOTE(review): fragment. The function's signature and the declaration of `i`
// (original lines 117-122) are not visible in this view.
// Subtracts 1 from every outer index (entries 0..rows()) and every inner index
// (entries 0..nonZeros()-1) of `mat`, i.e. shifts 1-based indices back to 0-based.
// This only happens when the first outer index is 1; otherwise `mat` is left unchanged.
116 template <
typename MatrixType>
120 if ( mat.outerIndexPtr()[0] == 1 )
// NOTE(review): same rows() vs outerSize()+1 bound caveat as the 0-based -> 1-based shift.
123 for(i = 0; i <= mat.rows(); ++i)
124 --mat.outerIndexPtr()[i];
125 for(i = 0; i < mat.nonZeros(); ++i)
126 --mat.innerIndexPtr()[i];
// Symmetrizes the graph of a square sparse matrix `In`. It takes the transpose into
// StrMatTrans, walks its entries, and stores StrMatTrans + In (evaluated) in `Out`.
// NOTE(review): fragment. The signature and the declaration of StrMatTrans
// (original lines 138-142) are not visible in this view.
137 template <
typename MatrixType>
140 eigen_assert(In.cols()==In.rows() &&
" Can only symmetrize the graph of a square matrix");
143 StrMatTrans = In.transpose();
// NOTE(review): the body of this double loop (original lines 148-151) is not visible.
// Presumably it rewrites the transpose's values so that the sum below only
// contributes the transpose's sparsity pattern. Confirm against the full source.
145 for (
int i = 0; i < StrMatTrans.rows(); i++)
147 for (
typename MatrixType::InnerIterator it(StrMatTrans, i); it; ++it)
152 Out = (StrMatTrans + In).eval();
// Head of the PastixBase class template (the members below refer to it as
// PastixBase), parameterized on the concrete solver type Derived. The matrix type
// is taken from internal::pastix_traits<Derived>.
// NOTE(review): fragment. The class declaration line and, presumably, a
// `typedef _MatrixType MatrixType;` (original line 164, used by the typedefs below)
// are not visible in this view.
159 template <
class Derived>
163 typedef typename internal::pastix_traits<Derived>::MatrixType
_MatrixType;
165 typedef typename MatrixType::Scalar
Scalar;
167 typedef typename MatrixType::Index
Index;
// Dense right-hand-side solve: returns a lazy internal::solve_retval expression
// built from *this and b. Evaluating it forwards to _solve() (see the
// solve_retval specialization for PastixBase).
// NOTE(review): fragment. The signature and the first part of the row-count
// assertion (original lines 199-202) are not visible in this view.
197 template<
typename Rhs>
198 inline const internal::solve_retval<PastixBase, Rhs>
203 &&
"PastixBase::solve(): invalid number of rows of the right hand side matrix b");
204 return internal::solve_retval<PastixBase, Rhs>(*
this, b.derived());
// NOTE(review): template header of a member whose declaration (original line 208)
// is not visible. Presumably it is the declaration of _solve(Rhs, Dest).
207 template<
typename Rhs,
typename Dest>
// Sparse right-hand-side solve (presumably _solve_sparse, called by sparse_solve_retval).
// Columns of b are copied NbColsAtOnce at a time into the temporary `tmp`
// (declared on lines not visible here), presumably to be solved in dense form.
// NOTE(review): fragment. The signature and the rest of the loop body are not visible.
211 template<
typename Rhs,
typename DestScalar,
int DestOptions,
typename DestIndex>
214 eigen_assert(
m_factorizationIsOk &&
"The decomposition is not in a valid state for solving, you must first call either compute() or symbolic()/numeric()");
// Block width: one right-hand-side column per pass.
218 static const int NbColsAtOnce = 1;
// NOTE(review): b.cols() is narrowed to int here; extremely wide right-hand sides could be truncated.
219 int rhsCols = b.cols();
222 for(
int k=0; k<rhsCols; k+=NbColsAtOnce)
224 int actualCols = std::min<int>(rhsCols-k, NbColsAtOnce);
225 tmp.
leftCols(actualCols) = b.middleCols(k,actualCols);
// Returns *this statically downcast to the concrete solver type Derived.
// NOTE(review): fragment. The function signatures (presumably the non-const and
// const derived() accessors) are not visible in this view.
233 return *
static_cast<Derived*
>(
this);
// const counterpart of the downcast above.
237 return *
static_cast<const Derived*
>(
this);
// Sparse right-hand-side solve: returns a lazy internal::sparse_solve_retval
// expression built from *this and b. Evaluating it forwards to _solve_sparse()
// (see the sparse_solve_retval specialization for PastixBase).
// NOTE(review): fragment. The signature and the first part of the row-count
// assertion (original lines 301-304) are not visible in this view.
299 template<
typename Rhs>
300 inline const internal::sparse_solve_retval<PastixBase, Rhs>
305 &&
"PastixBase::solve(): invalid number of rows of the right hand side matrix b");
306 return internal::sparse_solve_retval<PastixBase, Rhs>(*
this, b.
derived());
// NOTE(review): fragment, presumably from the destructor. It sets both the start
// and end task to API_TASK_CLEAN, presumably ahead of a PaStiX call (not visible)
// that releases the solver's internal data.
314 m_iparm(IPARM_START_TASK) = API_TASK_CLEAN;
315 m_iparm(IPARM_END_TASK) = API_TASK_CLEAN;
// Initialization step (presumably PastixBase::init(); the signature is not visible).
// It sizes the integer/real parameter arrays and runs PaStiX's API_TASK_INIT step
// with an empty (n = 0) matrix and nbrhs = 1, then checks IPARM_ERROR_NUMBER.
350 template <
class Derived>
354 m_iparm.resize(IPARM_SIZE);
355 m_dparm.resize(DPARM_SIZE);
// NOTE(review): original lines 358-362, between the two IPARM_MODIFY_PARAMETER
// assignments, are not visible. Presumably they hold a call that fills in default parameters.
357 m_iparm(IPARM_MODIFY_PARAMETER) = API_NO;
363 m_iparm(IPARM_MODIFY_PARAMETER) = API_YES;
364 m_hasTranspose =
false;
367 m_iparm(IPARM_START_TASK) = API_TASK_INIT;
368 m_iparm(IPARM_END_TASK) = API_TASK_INIT;
369 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
370 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, m_mat_null.outerIndexPtr(), m_mat_null.innerIndexPtr(),
371 m_mat_null.valuePtr(), m_perm.data(), m_invp.data(), m_vec_null.data(), 1, m_iparm.data(), m_dparm.data());
// NOTE(review): IPARM_MATRIX_VERIFICATION was already set to API_NO before the call.
// This repeat only matters if the call overwrites it. Confirm before removing.
373 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
// Error handling for the init step (branch body not visible).
376 if(m_iparm(IPARM_ERROR_NUMBER)) {
// Full decomposition entry point (presumably compute(); the signature is not visible).
// Requires a square matrix. It marks the solver initialized once factorization has succeeded.
// NOTE(review): original lines 391-397 are not visible. Presumably they run the
// analysis and factorization steps on `mat`.
386 template <
class Derived>
389 eigen_assert(mat.rows() == mat.cols() &&
"The input matrix should be squared");
// NOTE(review): this local Scalar typedef is not used in the visible lines.
390 typedef typename MatrixType::Scalar
Scalar;
398 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
399 if (m_factorizationIsOk) m_isInitialized =
true;
// Symbolic phase (presumably analyzePattern(); the signature is not visible).
// Requires init to have run, sizes the permutation vectors to m_size, runs PaStiX
// tasks API_TASK_ORDERING through API_TASK_ANALYSE with nbrhs = 0, and records
// success in m_analysisIsOk.
408 template <
class Derived>
411 eigen_assert(m_initisOk &&
"PastixInit should be called first to set the default parameters");
413 m_perm.resize(m_size);
414 m_invp.resize(m_size);
419 m_iparm(IPARM_START_TASK) = API_TASK_ORDERING;
420 m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE;
// NOTE(review): the start of this eigen_pastix call (original line 422) is not visible;
// only its trailing arguments are.
423 mat.valuePtr(), m_perm.data(), m_invp.data(), m_vec_null.data(), 0, m_iparm.data(), m_dparm.data());
426 if(m_iparm(IPARM_ERROR_NUMBER)) {
428 m_analysisIsOk =
false;
432 m_analysisIsOk =
true;
// Numerical phase (presumably factorize(); the signature is not visible).
// Requires a completed analysis and runs the single task API_TASK_NUMFACT.
// On a PaStiX error both m_factorizationIsOk and m_isInitialized become false;
// otherwise both become true.
// NOTE(review): original lines 443-448, including the start of the eigen_pastix
// call, are not visible.
437 template <
class Derived>
440 eigen_assert(m_analysisIsOk &&
"The analysis phase should be called before the factorization phase");
441 m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT;
442 m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT;
449 mat.valuePtr(), m_perm.data(), m_invp.data(), m_vec_null.data(), 0, m_iparm.data(), m_dparm.data());
452 if(m_iparm(IPARM_ERROR_NUMBER)) {
454 m_factorizationIsOk =
false;
455 m_isInitialized =
false;
459 m_factorizationIsOk =
true;
460 m_isInitialized =
true;
// Dense solve (out-of-class member template, presumably _solve(b, x)). It requires
// a factorized matrix and column-major storage. For each column i of b it runs
// PaStiX tasks API_TASK_SOLVE through API_TASK_REFINE on column i of x, passed
// as &x(0, i).
// NOTE(review): fragment. The signature, the start of the column-major static
// assertion, the declaration of `rhs` and any copy of b into x (original lines
// 468-476) are not visible.
466 template<
typename Base>
467 template<
typename Rhs,
typename Dest>
470 eigen_assert(m_isInitialized &&
"The matrix should be factorized first");
472 THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
// NOTE(review): the three iparm assignments below are loop-invariant in the visible code.
// They may be re-set on purpose in case PaStiX modifies iparm; confirm before hoisting.
// The matrix arrays passed are m_mat_null's (not nulled here, since n = x.rows()).
477 for (
int i = 0; i < b.cols(); i++){
478 m_iparm(IPARM_START_TASK) = API_TASK_SOLVE;
479 m_iparm(IPARM_END_TASK) = API_TASK_REFINE;
480 m_iparm(IPARM_RHS_MAKING) = API_RHS_B;
481 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, x.rows(), m_mat_null.outerIndexPtr(), m_mat_null.innerIndexPtr(),
482 m_mat_null.valuePtr(), m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data());
485 if(m_iparm(IPARM_ERROR_NUMBER)) {
// PastixLU: LU front-end. At each visible configuration point it sets
// IPARM_SYM = API_SYM_NO and IPARM_FACTORIZATION = API_FACT_LU.
// NOTE(review): fragment. The class declaration and most member bodies are not visible.
513 template<
typename _MatrixType,
bool IsStrSym>
519 typedef typename MatrixType::Scalar
Scalar;
538 PaStiXType temp(matrix.rows(), matrix.cols());
// NOTE(review): this line indexes m_iparm with operator[] while every other access uses
// operator(). Presumably m_iparm is an Eigen vector (it has resize()/data()), where both
// reach the same element. Consider making it consistent.
546 m_iparm[IPARM_SYM] = API_SYM_NO;
547 m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
569 m_iparm(IPARM_SYM) = API_SYM_NO;
570 m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
590 m_iparm(IPARM_SYM) = API_SYM_NO;
591 m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
// PastixLLT: LL^T front-end. At each visible configuration point it copies the
// self-adjoint view of `matrix` (stored triangle UpLo) into the lower triangle of
// `temp`, then sets IPARM_SYM = API_SYM_YES and IPARM_FACTORIZATION = API_FACT_LLT.
// NOTE(review): fragment. `UpLo`, `temp` and `pnull` are declared on lines not
// visible here. Presumably pnull is an empty permutation, so twistedBy() does not reorder.
618 template<
typename _MatrixType,
int _UpLo>
624 typedef typename MatrixType::Scalar
Scalar;
625 typedef typename MatrixType::Index
Index;
645 temp.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>().twistedBy(pnull);
646 m_iparm(IPARM_SYM) = API_SYM_YES;
647 m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
661 temp.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>().twistedBy(pnull);
662 m_iparm(IPARM_SYM) = API_SYM_YES;
663 m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
674 temp.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>().twistedBy(pnull);
675 m_iparm(IPARM_SYM) = API_SYM_YES;
676 m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
// PastixLDLT: LDL^T front-end. At each visible configuration point it copies the
// self-adjoint view of `matrix` (stored triangle UpLo) into the lower triangle of
// `temp`, then sets IPARM_SYM = API_SYM_YES and IPARM_FACTORIZATION = API_FACT_LDLT.
// NOTE(review): fragment. `UpLo`, `temp` and `pnull` are declared on lines not
// visible here, and original lines 744 and 758 are also missing.
700 template<
typename _MatrixType,
int _UpLo>
706 typedef typename MatrixType::Scalar
Scalar;
707 typedef typename MatrixType::Index
Index;
727 temp.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>().twistedBy(pnull);
728 m_iparm(IPARM_SYM) = API_SYM_YES;
729 m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
743 temp.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>().twistedBy(pnull);
745 m_iparm(IPARM_SYM) = API_SYM_YES;
746 m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
757 temp.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>().twistedBy(pnull);
759 m_iparm(IPARM_SYM) = API_SYM_YES;
760 m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
// Evaluator for the expression returned by PastixBase::solve() with a dense
// right-hand side: evalTo() forwards to the decomposition's _solve().
// NOTE(review): fragment. The braces and original line 778 (presumably the helper
// that defines dec() and rhs()) are not visible.
773 template<
typename _MatrixType,
typename Rhs>
774 struct solve_retval<PastixBase<_MatrixType>, Rhs>
775 : solve_retval_base<PastixBase<_MatrixType>, Rhs>
777 typedef PastixBase<_MatrixType> Dec;
780 template<typename Dest>
void evalTo(Dest& dst)
const
782 dec()._solve(rhs(),dst);
// Evaluator for the expression returned by PastixBase::solve() with a sparse
// right-hand side: evalTo() forwards to the decomposition's _solve_sparse().
// NOTE(review): fragment. The braces and original line 791 (presumably the helper
// that defines dec() and rhs()) are not visible.
786 template<
typename _MatrixType,
typename Rhs>
787 struct sparse_solve_retval<PastixBase<_MatrixType>, Rhs>
788 : sparse_solve_retval_base<PastixBase<_MatrixType>, Rhs>
790 typedef PastixBase<_MatrixType> Dec;
793 template<typename Dest>
void evalTo(Dest& dst)
const
795 dec()._solve_sparse(rhs(),dst);