// Type aliases used throughout this class.
// Scalar is taken from the operator type; every Eigen alias below is
// dynamically sized and parameterised on it.
41 using Scalar =
typename OpType::Scalar;
42 using Index = Eigen::Index;
// Real-valued dense matrix, column vector and coefficient-wise array.
43 using Matrix = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>;
44 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
45 using Array = Eigen::Array<Scalar, Eigen::Dynamic, 1>;
// Boolean array; used for the per-Ritz-value convergence flags (m_ritz_conv).
46 using BoolArray = Eigen::Array<bool, Eigen::Dynamic, 1>;
// Eigen::Map views over existing storage, e.g. the raw pointer handed to init().
47 using MapMat = Eigen::Map<Matrix>;
48 using MapVec = Eigen::Map<Vector>;
49 using MapConstVec = Eigen::Map<const Vector>;
// Complex counterparts of Scalar, Matrix and Vector; Ritz values and vectors are stored in these.
51 using Complex = std::complex<Scalar>;
52 using ComplexMatrix = Eigen::Matrix<Complex, Eigen::Dynamic, Eigen::Dynamic>;
53 using ComplexVector = Eigen::Matrix<Complex, Eigen::Dynamic, 1>;
// Operator wrapper built from OpType and BOpType, and the Arnoldi factorization type that uses it.
55 using ArnoldiOpType = ArnoldiOp<Scalar, OpType, BOpType>;
56 using ArnoldiFac = Arnoldi<Scalar, ArnoldiOpType>;
// Ritz values; init() sizes this to m_ncv entries.
70 ComplexVector m_ritz_val;
// Ritz vectors; init() sizes this to m_ncv rows by m_nev columns.
71 ComplexMatrix m_ritz_vec;
// Ritz estimates; retrieve_ritzpair() fills each entry from row m_ncv - 1
// (the last row) of the eigenvector matrix of H.
72 ComplexVector m_ritz_est;
// One convergence flag per requested Ritz value (m_nev entries), written by num_converged().
75 BoolArray m_ritz_conv;
82 static bool is_complex(
const Complex& v) {
return v.imag() != Scalar(0); }
83 static bool is_conj(
const Complex& v1,
const Complex& v2) {
return v1 == Eigen::numext::conj(v2); }
// Restart step of the factorization.
// For each Ritz value with index in [k, m_ncv) a shifted QR decomposition of
// matrix_H() is computed and H is compressed with it; afterwards the
// factorization is extended from step k back to m_ncv and the Ritz pairs are
// recomputed.
//   k         - index of the first Ritz value used as a shift, and the step
//               from which factorize_from() resumes.
//   selection - sort rule forwarded to retrieve_ritzpair().
86 void restart(Index k,
SortRule selection)
// Work objects sized for the full m_ncv x m_ncv problem; Q starts as the identity.
93 DoubleShiftQR<Scalar> decomp_ds(m_ncv);
94 UpperHessenbergQR<Scalar> decomp_hb(m_ncv);
95 Matrix Q = Matrix::Identity(m_ncv, m_ncv);
97 for (Index i = k; i < m_ncv; i++)
// A complex value followed by its exact conjugate is handled as one pair.
// NOTE(review): when i == m_ncv - 1 the index i + 1 equals m_ncv, while
// init() sizes m_ritz_val to m_ncv entries, so this read looks out of range
// unless the ordering guarantees a conjugate partner always follows - confirm.
99 if (is_complex(m_ritz_val[i]) && is_conj(m_ritz_val[i], m_ritz_val[i + 1]))
// With mu = m_ritz_val[i]: s = 2 * Re(mu). t is presumably std::norm(mu) = |mu|^2
// (the using-declaration is not visible here); these are the coefficients of
// the real quadratic (x - mu)(x - conj(mu)) = x^2 - s*x + t.
107 const Scalar s = Scalar(2) * m_ritz_val[i].real();
108 const Scalar t = norm(m_ritz_val[i]);
110 decomp_ds.compute(m_fac.matrix_H(), s, t);
// Fold this step's transformation into Q (presumably Q <- Q * Q_i - confirm
// against apply_YQ), then update H in the factorization.
113 decomp_ds.apply_YQ(Q);
118 m_fac.compress_H(decomp_ds);
// Otherwise: decomposition using the real part of the Ritz value as the
// second argument (presumably the shift).
125 decomp_hb.compute(m_fac.matrix_H(), m_ritz_val[i].real());
128 decomp_hb.apply_YQ(Q);
130 m_fac.compress_H(decomp_hb);
// Extend the factorization from step k to m_ncv; m_nmatop is passed along
// (presumably updated with the number of operator applications).
135 m_fac.factorize_from(k, m_ncv, m_nmatop);
// Refresh Ritz values, estimates and vectors from the new H.
137 retrieve_ritzpair(selection);
// Tests the first m_nev Ritz values for convergence, stores the per-value
// result in m_ritz_conv and returns how many passed.
//   tol - relative tolerance multiplied into the threshold.
141 Index num_converged(
const Scalar& tol)
146 const Scalar eps = TypeTraits<Scalar>::epsilon();
// eps^(2/3): lower bound applied to |Ritz value| in the threshold below, so
// the threshold does not shrink to zero for Ritz values near zero.
149 const Scalar eps23 = pow(eps, Scalar(2) / 3);
// thresh[i] = tol * max(|ritz_val[i]|, eps23)
152 Array thresh = tol * m_ritz_val.head(m_nev).array().abs().max(eps23);
// resid[i] = |ritz_est[i]| * f_norm()
153 Array resid = m_ritz_est.head(m_nev).array().abs() * m_fac.f_norm();
// Coefficient-wise comparison: true where the residual is below the threshold.
155 m_ritz_conv = (resid < thresh);
157 return m_ritz_conv.count();
// Derives an adjusted count of Ritz values, starting from m_nev, for use as
// the first argument of restart().
//   nconv - number of converged Ritz values (the caller passes the result of
//           num_converged()).
// The statements executed inside each branch and the return statement are
// not visible in this excerpt; presumably nev_new is returned - confirm.
161 Index nev_adjusted(Index nconv)
// Threshold below which a Ritz estimate is treated as zero:
// 10 * TypeTraits<Scalar>::min().
167 const Scalar near_0 = TypeTraits<Scalar>::min() * Scalar(10);
169 Index nev_new = m_nev;
// Scan the Ritz estimates beyond the first m_nev for near-zero entries.
170 for (Index i = m_nev; i < m_ncv; i++)
171 if (abs(m_ritz_est[i]) < near_0)
// Grow by nconv, but by no more than half of the room left up to m_ncv.
175 nev_new += (std::min)(nconv, (m_ncv - nev_new) / 2);
// Special cases when only one value would be kept, depending on m_ncv.
176 if (nev_new == 1 && m_ncv >= 6)
178 else if (nev_new == 1 && m_ncv > 3)
// Upper-limit check against m_ncv - 2.
181 if (nev_new > m_ncv - 2)
// Detects the case where entries nev_new - 1 and nev_new form a complex
// conjugate pair, i.e. the cut would fall between the two members of a pair.
186 if (is_complex(m_ritz_val[nev_new - 1]) &&
187 is_conj(m_ritz_val[nev_new - 1], m_ritz_val[nev_new]))
// Computes the eigenpairs of the current matrix_H(), orders them by
// `selection` and stores them as Ritz values, Ritz estimates and Ritz vectors.
//   selection - sort rule choosing which SortEigenvalue instantiation is used.
// Throws std::invalid_argument("unsupported selection rule") for a rule that
// matches none of the cases.
196 void retrieve_ritzpair(
SortRule selection)
// Eigen-decomposition of H; evecs is taken as a copy, evals by const reference.
198 UpperHessenbergEigen<Scalar> decomp(m_fac.matrix_H());
199 const ComplexVector& evals = decomp.eigenvalues();
200 ComplexMatrix evecs = decomp.eigenvectors();
// Index permutation used below; presumably filled from the `sorting` object
// of the matching case - that statement is not visible in this excerpt.
203 std::vector<Index> ind;
// One SortEigenvalue instantiation per supported rule, each over all m_ncv eigenvalues.
208 SortEigenvalue<Complex, SortRule::LargestMagn> sorting(evals.data(), m_ncv);
214 SortEigenvalue<Complex, SortRule::LargestReal> sorting(evals.data(), m_ncv);
220 SortEigenvalue<Complex, SortRule::LargestImag> sorting(evals.data(), m_ncv);
226 SortEigenvalue<Complex, SortRule::SmallestMagn> sorting(evals.data(), m_ncv);
232 SortEigenvalue<Complex, SortRule::SmallestReal> sorting(evals.data(), m_ncv);
238 SortEigenvalue<Complex, SortRule::SmallestImag> sorting(evals.data(), m_ncv);
243 throw std::invalid_argument(
"unsupported selection rule");
// Store all m_ncv eigenvalues in permuted order; each Ritz estimate is the
// last-row (row m_ncv - 1) entry of the corresponding eigenvector.
247 for (Index i = 0; i < m_ncv; i++)
249 m_ritz_val[i] = evals[ind[i]];
250 m_ritz_est[i] = evecs(m_ncv - 1, ind[i]);
// Only the first m_nev eigenvectors are kept as Ritz vectors.
252 for (Index i = 0; i < m_nev; i++)
254 m_ritz_vec.col(i).noalias() = evecs.col(ind[i]);
// Reorders the first m_nev Ritz values, together with their Ritz vectors and
// convergence flags, according to sort_rule. Declared virtual, so a derived
// class may supply its own ordering.
//   sort_rule - sort rule choosing which SortEigenvalue instantiation is used.
// Throws std::invalid_argument("unsupported sorting rule") for a rule that
// matches none of the cases.
261 virtual void sort_ritzpair(
SortRule sort_rule)
// Index permutation used below; presumably filled from the `sorting` object
// of the matching case - that statement is not visible in this excerpt.
263 std::vector<Index> ind;
// One SortEigenvalue instantiation per supported rule, each over the first m_nev Ritz values.
268 SortEigenvalue<Complex, SortRule::LargestMagn> sorting(m_ritz_val.data(), m_nev);
274 SortEigenvalue<Complex, SortRule::LargestReal> sorting(m_ritz_val.data(), m_nev);
280 SortEigenvalue<Complex, SortRule::LargestImag> sorting(m_ritz_val.data(), m_nev);
286 SortEigenvalue<Complex, SortRule::SmallestMagn> sorting(m_ritz_val.data(), m_nev);
292 SortEigenvalue<Complex, SortRule::SmallestReal> sorting(m_ritz_val.data(), m_nev);
298 SortEigenvalue<Complex, SortRule::SmallestImag> sorting(m_ritz_val.data(), m_nev);
303 throw std::invalid_argument(
"unsupported sorting rule");
// Build the permuted copies in temporaries, then swap them into the members.
// NOTE(review): new_ritz_val has m_ncv entries but the loop below writes only
// the first m_nev; after the swap, entries [m_nev, m_ncv) of m_ritz_val are
// whatever the temporary was constructed with - confirm nothing reads them later.
306 ComplexVector new_ritz_val(m_ncv);
307 ComplexMatrix new_ritz_vec(m_ncv, m_nev);
308 BoolArray new_ritz_conv(m_nev);
310 for (Index i = 0; i < m_nev; i++)
312 new_ritz_val[i] = m_ritz_val[ind[i]];
313 new_ritz_vec.col(i).noalias() = m_ritz_vec.col(ind[i]);
314 new_ritz_conv[i] = m_ritz_conv[ind[i]];
317 m_ritz_val.swap(new_ritz_val);
318 m_ritz_vec.swap(new_ritz_vec);
319 m_ritz_conv.swap(new_ritz_conv);
// Constructor.
//   op, Bop - operator objects; wrapped together in an ArnoldiOpType that is
//             handed to the Arnoldi factorization.
//   nev     - must satisfy 1 <= nev <= n - 2.
//   ncv     - must satisfy nev + 2 <= ncv <= n.
// Here n is m_n, described in the error messages as the size of the matrix.
// Throws std::invalid_argument when either constraint is violated.
// Only part of the member initializer list is visible in this excerpt.
325 GenEigsBase(OpType& op,
const BOpType& Bop, Index nev, Index ncv) :
// m_ncv is capped at m_n, so the factorization below is never sized beyond
// the matrix dimension even when the ncv check further down will throw.
329 m_ncv(ncv > m_n ? m_n : ncv),
332 m_fac(ArnoldiOpType(op, Bop), m_ncv),
335 if (nev < 1 || nev > m_n - 2)
336 throw std::invalid_argument(
"nev must satisfy 1 <= nev <= n - 2, n is the size of matrix");
338 if (ncv < nev + 2 || ncv > m_n)
339 throw std::invalid_argument(
"ncv must satisfy nev + 2 <= ncv <= n, n is the size of matrix");
// Resets the Ritz storage and initializes the Arnoldi factorization from a
// caller-supplied starting vector.
//   init_resid - pointer to the starting vector; it is viewed (through an
//                Eigen::Map) as a vector of m_n Scalar values, so it must
//                point to at least m_n readable elements.
358 void init(
const Scalar* init_resid)
// Size the Ritz containers: m_ncv values/estimates, an m_ncv x m_nev vector
// matrix, and m_nev convergence flags.
361 m_ritz_val.resize(m_ncv);
362 m_ritz_vec.resize(m_ncv, m_nev);
363 m_ritz_est.resize(m_ncv);
364 m_ritz_conv.resize(m_nev);
// Clear them; for the BoolArray this sets every flag to false.
366 m_ritz_val.setZero();
367 m_ritz_vec.setZero();
368 m_ritz_est.setZero();
369 m_ritz_conv.setZero();
// Wrap the raw pointer as a const vector of length m_n and start the
// factorization with it; m_nmatop is passed along (presumably updated with
// the number of operator applications).
375 MapConstVec v0(init_resid, m_n);
376 m_fac.init(v0, m_nmatop);
// Fragment of a second initialization routine whose header is not visible in
// this excerpt. It draws a length-m_n vector from a SimpleRandom generator
// constructed with the fixed argument 0 (presumably the seed) and forwards
// its data pointer to init(const Scalar*).
388 SimpleRandom<Scalar> rng(0);
389 Vector init_resid = rng.random_vec(m_n);
390 init(init_resid.data());
// Fragment of the main solver routine; its header (which presumably declares
// selection, maxit, tol and sorting) is not visible in this excerpt.
// Build the factorization from step 1 up to m_ncv and compute the first Ritz pairs.
421 m_fac.factorize_from(1, m_ncv, m_nmatop);
422 retrieve_ritzpair(selection);
// Restart loop, bounded by maxit iterations.
424 Index i, nconv = 0, nev_adj;
425 for (i = 0; i < maxit; i++)
427 nconv = num_converged(tol);
// Any early-exit test between these statements is not visible here.
// Otherwise: pick the adjusted count and restart the factorization from it.
431 nev_adj = nev_adjusted(nconv);
432 restart(nev_adj, selection);
// Final ordering of the retained Ritz pairs.
435 sort_ritzpair(sorting);
// The reported count never exceeds the number of requested values m_nev.
440 return (std::min)(m_nev, nconv);
// Fragment of an accessor whose header is not visible in this excerpt.
// The result vector has one entry per true flag in m_ritz_conv.
468 const Index nconv = m_ritz_conv.cast<Index>().sum();
469 ComplexVector res(nconv);
// Walk the m_nev requested Ritz values; j indexes into res. The declaration
// of j and the test that guards the copy are not visible here - presumably
// only entries whose m_ritz_conv flag is set are copied.
475 for (Index i = 0; i < m_nev; i++)
479 res[j] = m_ritz_val[i];
// Fragment of an accessor whose header (which presumably declares nvec) is
// not visible in this excerpt.
// nvec is capped at the number of true flags in m_ritz_conv.
498 const Index nconv = m_ritz_conv.cast<Index>().sum();
499 nvec = (std::min)(nvec, nconv);
// Result: m_n rows, one column per returned vector.
500 ComplexMatrix res(m_n, nvec);
// Gather up to nvec Ritz vectors (m_ncv rows each) into a contiguous matrix.
// The declaration of j and the test that guards the copy are not visible
// here - presumably only converged columns are taken.
505 ComplexMatrix ritz_vec_conv(m_ncv, nvec);
507 for (Index i = 0; i < m_nev && j < nvec; i++)
511 ritz_vec_conv.col(j).noalias() = m_ritz_vec.col(i);
// Multiply by matrix_V() from the factorization to obtain the m_n-row result.
516 res.noalias() = m_fac.matrix_V() * ritz_vec_conv;