#include <algebra/EigenvalueSolver.hpp>

#include <algorithm>
#include <cmath>
#include <stdexcept>

#include <utils/pugs_config.hpp>

#ifdef PUGS_HAS_SLEPC
#include <slepc.h>

struct EigenvalueSolver::Internals
{
  // Solves for one eigenpair at the end of the spectrum selected by `which`
  // and returns its eigenvalue.
  // Throws std::runtime_error when SLEPc reports no converged eigenpair:
  // EPSGetEigenpair would otherwise fail and leave the result unset.
  // NOTE(review): on throw, the caller's EPS object is not destroyed.
  static PetscReal
  computeExtremalRealEigenvalueOfSymmetricMatrix(EPS& eps, const EPSWhich which)
  {
    EPSSetWhichEigenpairs(eps, which);
    EPSSolve(eps);

    PetscInt nb_converged = 0;
    EPSGetConverged(eps, &nb_converged);
    if (nb_converged < 1) {
      throw std::runtime_error("cannot compute extremal eigenvalue: eigensolver did not converge");
    }

    PetscReal eigenvalue = 0;
    EPSGetEigenpair(eps, 0, &eigenvalue, nullptr, nullptr, nullptr);
    return eigenvalue;
  }

  // Returns the smallest (algebraic) eigenvalue of the operator attached to eps.
  static PetscReal
  computeSmallestRealEigenvalueOfSymmetricMatrix(EPS& eps)
  {
    return computeExtremalRealEigenvalueOfSymmetricMatrix(eps, EPS_SMALLEST_REAL);
  }

  // Returns the largest (algebraic) eigenvalue of the operator attached to eps.
  static PetscReal
  computeLargestRealEigenvalueOfSymmetricMatrix(EPS& eps)
  {
    return computeExtremalRealEigenvalueOfSymmetricMatrix(eps, EPS_LARGEST_REAL);
  }

  // Computes every eigenpair whose eigenvalue lies in [left_bound, right_bound]
  // using EPS_ALL with an interval. In SLEPc this is spectrum slicing, set up
  // here with shift-and-invert and a direct (PREONLY + Cholesky) solve.
  // The bounds are used as given. They should preferably not coincide with
  // eigenvalues, since the factorization at an endpoint would then be singular.
  static void
  computeAllEigenvaluesOfSymmetricMatrixInInterval(EPS& eps, const PetscReal left_bound, const PetscReal right_bound)
  {
    Assert(left_bound < right_bound);
    EPSSetWhichEigenpairs(eps, EPS_ALL);
    EPSSetInterval(eps, left_bound, right_bound);

    ST st;
    EPSGetST(eps, &st);
    STSetType(st, STSINVERT);

    KSP ksp;
    STGetKSP(st, &ksp);
    KSPSetType(ksp, KSPPREONLY);

    PC pc;
    KSPGetPC(ksp, &pc);
    PCSetType(pc, PCCHOLESKY);
    EPSSetFromOptions(eps);

    EPSSolve(eps);
  }

  // Computes all eigenpairs of the operator attached to eps. It first finds the
  // extreme eigenvalues, then slices an interval that strictly encloses them.
  static void
  computeAllEigenvaluesOfSymmetricMatrix(EPS& eps)
  {
    const PetscReal smallest_eigenvalue = computeSmallestRealEigenvalueOfSymmetricMatrix(eps);
    const PetscReal largest_eigenvalue  = computeLargestRealEigenvalueOfSymmetricMatrix(eps);

    // A margin relative to each endpoint (0.01*|bound|) vanishes when an extreme
    // eigenvalue is 0, which leaves the interval endpoint on an eigenvalue. It
    // also gives an empty interval for the zero matrix. Scaling by the spectral
    // radius keeps both endpoints strictly away from the spectrum. The fallback
    // of 1 covers the all-zero spectrum.
    const PetscReal spectral_radius = std::max(std::abs(smallest_eigenvalue), std::abs(largest_eigenvalue));
    const PetscReal margin          = (spectral_radius > 0) ? 0.01 * spectral_radius : 1.;

    computeAllEigenvaluesOfSymmetricMatrixInInterval(eps, smallest_eigenvalue - margin, largest_eigenvalue + margin);
  }
};

void
EigenvalueSolver::computeForSymmetricMatrix(const PETScAijMatrixEmbedder& A, Array<double>& eigenvalues)
{
  EPS eps;

  EPSCreate(PETSC_COMM_SELF, &eps);
  EPSSetOperators(eps, A, nullptr);
  EPSSetProblemType(eps, EPS_HEP);

  Internals::computeAllEigenvaluesOfSymmetricMatrix(eps);

  PetscInt nb_eigenvalues;
  EPSGetDimensions(eps, &nb_eigenvalues, nullptr, nullptr);

  eigenvalues = Array<double>(nb_eigenvalues);
  for (PetscInt i = 0; i < nb_eigenvalues; ++i) {
    EPSGetEigenpair(eps, i, &(eigenvalues[i]), nullptr, nullptr, nullptr);
  }

  EPSDestroy(&eps);
}

void
EigenvalueSolver::computeForSymmetricMatrix(const PETScAijMatrixEmbedder& A,
                                            Array<double>& eigenvalues,
                                            std::vector<Vector<double>>& eigenvector_list)
{
  // Symmetric (EPS_HEP) eigenproblem for A on the local communicator.
  EPS eps;

  EPSCreate(PETSC_COMM_SELF, &eps);
  EPSSetOperators(eps, A, nullptr);
  EPSSetProblemType(eps, EPS_HEP);

  Internals::computeAllEigenvaluesOfSymmetricMatrix(eps);

  PetscInt nb_eigenvalues;
  EPSGetDimensions(eps, &nb_eigenvalues, nullptr, nullptr);

  eigenvalues = Array<double>(nb_eigenvalues);

  // Replace (rather than append to) any previous content so that
  // eigenvector_list[i] always pairs with eigenvalues[i].
  eigenvector_list.clear();
  eigenvector_list.reserve(nb_eigenvalues);
  for (PetscInt i = 0; i < nb_eigenvalues; ++i) {
    // Each eigenvector owns its storage. The PETSc Vec only wraps that buffer
    // (no copy) while EPSGetEigenpair writes into it.
    Vector<double> eigenvector{A.numberOfRows()};
    Vec Vr;
    VecCreateSeqWithArray(PETSC_COMM_SELF, 1, A.numberOfRows(), &(eigenvector[0]), &Vr);
    EPSGetEigenpair(eps, i, &(eigenvalues[i]), nullptr, Vr, nullptr);
    VecDestroy(&Vr);
    eigenvector_list.push_back(eigenvector);
  }

  EPSDestroy(&eps);
}

void
EigenvalueSolver::computeForSymmetricMatrix(const PETScAijMatrixEmbedder& A,
                                            Array<double>& eigenvalues,
                                            DenseMatrix<double>& P)
{
  // Symmetric (EPS_HEP) eigenproblem for A on the local communicator.
  EPS eps;

  EPSCreate(PETSC_COMM_SELF, &eps);
  EPSSetOperators(eps, A, nullptr);
  EPSSetProblemType(eps, EPS_HEP);

  Internals::computeAllEigenvaluesOfSymmetricMatrix(eps);

  PetscInt nb_eigenvalues;
  EPSGetDimensions(eps, &nb_eigenvalues, nullptr, nullptr);

  // Eigenvectors have one entry per row of A, independently of how many
  // eigenpairs were actually found.
  const size_t nb_rows = A.numberOfRows();

  eigenvalues = Array<double>(nb_eigenvalues);
  // Column i of P holds the eigenvector associated with eigenvalues[i].
  P = DenseMatrix<double>(nb_rows, nb_eigenvalues);

  // The PETSc Vec below wraps this buffer directly (no copy) with length
  // A.numberOfRows(). The buffer must therefore have that length, not
  // nb_eigenvalues, or SLEPc writes past its end when fewer pairs are found.
  Array<double> eigenvector(nb_rows);
  for (PetscInt i = 0; i < nb_eigenvalues; ++i) {
    Vec Vr;
    VecCreateSeqWithArray(PETSC_COMM_SELF, 1, A.numberOfRows(), &(eigenvector[0]), &Vr);
    EPSGetEigenpair(eps, i, &(eigenvalues[i]), nullptr, Vr, nullptr);
    VecDestroy(&Vr);
    for (size_t j = 0; j < nb_rows; ++j) {
      P(j, i) = eigenvector[j];
    }
  }

  EPSDestroy(&eps);
}

#endif   // PUGS_HAS_SLEPC

// Default constructor: performs no work.
EigenvalueSolver::EigenvalueSolver() = default;
