7 #ifndef SPECTRA_SYM_SHIFT_INVERT_H
8 #define SPECTRA_SYM_SHIFT_INVERT_H
#include <Eigen/Core>
#include <Eigen/SparseCore>
#include <Eigen/SparseLU>
#include <stdexcept>    // std::invalid_argument
#include <type_traits>  // std::conditional, std::is_same

#include "../LinAlg/BKLDLT.h"
#include "../Util/CompInfo.h"
25 template <
bool AIsSparse,
bool BIsSparse,
int UploA,
int UploB>
26 class SymShiftInvertHelper
29 template <
typename Scalar,
typename Fac,
typename ArgA,
typename ArgB>
30 static bool factorize(Fac& fac,
const ArgA& A,
const ArgB& B,
const Scalar& sigma)
32 using SpMat =
typename ArgA::PlainObject;
33 SpMat matA = A.template selfadjointView<UploA>();
34 SpMat matB = B.template selfadjointView<UploB>();
35 SpMat mat = matA - sigma * matB;
37 fac.isSymmetric(
true);
40 return fac.info() == Eigen::Success;
45 template <
bool BIsSparse,
int UploA,
int UploB>
46 class SymShiftInvertHelper<false, BIsSparse, UploA, UploB>
49 template <
typename Scalar,
typename Fac,
typename ArgA,
typename ArgB>
50 static bool factorize(Fac& fac,
const ArgA& A,
const ArgB& B,
const Scalar& sigma)
52 using Matrix =
typename ArgA::PlainObject;
54 Matrix mat(A.rows(), A.cols());
55 mat.template triangularView<UploA>() = A;
58 mat -= (B * sigma).
template triangularView<UploA>();
60 mat -= (B * sigma).
template triangularView<UploB>().transpose();
62 fac.compute(mat, UploA);
69 template <
int UploA,
int UploB>
70 class SymShiftInvertHelper<true, false, UploA, UploB>
73 template <
typename Scalar,
typename Fac,
typename ArgA,
typename ArgB>
74 static bool factorize(Fac& fac,
const ArgA& A,
const ArgB& B,
const Scalar& sigma)
76 using Matrix =
typename ArgB::PlainObject;
78 Matrix mat(B.rows(), B.cols());
79 mat.template triangularView<UploB>() = -sigma * B;
82 mat += A.template triangularView<UploB>();
84 mat += A.template triangularView<UploA>().transpose();
86 fac.compute(mat, UploB);
123 template <
typename Scalar_,
typename TypeA = Eigen::Sparse,
typename TypeB = Eigen::Sparse,
124 int UploA = Eigen::Lower,
int UploB = Eigen::Lower,
125 int FlagsA = Eigen::ColMajor,
int FlagsB = Eigen::ColMajor,
126 typename StorageIndexA = int,
typename StorageIndexB =
int>
136 using Index = Eigen::Index;
139 using DenseTypeA = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, FlagsA>;
140 using SparseTypeA = Eigen::SparseMatrix<Scalar, FlagsA, StorageIndexA>;
142 using ASparse = std::is_same<TypeA, Eigen::Sparse>;
144 using MatrixA =
typename std::conditional<ASparse::value, SparseTypeA, DenseTypeA>::type;
147 using DenseTypeB = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, FlagsB>;
148 using SparseTypeB = Eigen::SparseMatrix<Scalar, FlagsB, StorageIndexB>;
150 using BSparse = std::is_same<TypeB, Eigen::Sparse>;
152 using MatrixB =
typename std::conditional<BSparse::value, SparseTypeB, DenseTypeB>::type;
154 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
155 using MapConstVec = Eigen::Map<const Vector>;
156 using MapVec = Eigen::Map<Vector>;
160 using DenseType =
typename std::conditional<ASparse::value, MatrixB, MatrixA>::type;
163 using ResType =
typename std::conditional<ASparse::value && BSparse::value, MatrixA, DenseType>::type;
167 using FacType =
typename std::conditional<
168 ASparse::value && BSparse::value,
169 Eigen::SparseLU<ResType>,
170 BKLDLT<Scalar>>::type;
172 using ConstGenericMatrixA =
const Eigen::Ref<const MatrixA>;
173 using ConstGenericMatrixB =
const Eigen::Ref<const MatrixB>;
175 ConstGenericMatrixA m_matA;
176 ConstGenericMatrixB m_matB;
190 template <
typename DerivedA,
typename DerivedB>
191 SymShiftInvert(
const Eigen::EigenBase<DerivedA>& A,
const Eigen::EigenBase<DerivedB>& B) :
192 m_matA(A.derived()), m_matB(B.derived()), m_n(A.
rows())
195 static_cast<int>(DerivedA::PlainObject::IsRowMajor) ==
static_cast<int>(MatrixA::IsRowMajor),
196 "SymShiftInvert: the \"FlagsA\" template parameter does not match the input matrix (Eigen::ColMajor/Eigen::RowMajor)");
199 static_cast<int>(DerivedB::PlainObject::IsRowMajor) ==
static_cast<int>(MatrixB::IsRowMajor),
200 "SymShiftInvert: the \"FlagsB\" template parameter does not match the input matrix (Eigen::ColMajor/Eigen::RowMajor)");
202 if (m_n != A.cols() || m_n != B.rows() || m_n != B.cols())
203 throw std::invalid_argument(
"SymShiftInvert: A and B must be square matrices of the same size");
209 Index
rows()
const {
return m_n; }
213 Index
cols()
const {
return m_n; }
220 constexpr
bool AIsSparse = ASparse::value;
221 constexpr
bool BIsSparse = BSparse::value;
222 using Helper = SymShiftInvertHelper<AIsSparse, BIsSparse, UploA, UploB>;
223 const bool success = Helper::factorize(m_solver, m_matA, m_matB, sigma);
225 throw std::invalid_argument(
"SymShiftInvert: factorization failed with the given shift");
237 MapConstVec x(x_in, m_n);
238 MapVec y(y_out, m_n);
239 y.noalias() = m_solver.solve(x);
// Member summary (documentation residue):
//   SymShiftInvert(const Eigen::EigenBase<DerivedA>& A, const Eigen::EigenBase<DerivedB>& B)
//   void set_shift(const Scalar& sigma)
//   void perform_op(const Scalar* x_in, Scalar* y_out) const
// CompInfo::Successful: computation was successful.