#include <algebra/LinearSolver.hpp>
#include <utils/pugs_config.hpp>

#include <algebra/BiCGStab.hpp>
#include <algebra/CG.hpp>

#ifdef PUGS_HAS_PETSC
#include <petsc.h>
#endif   // PUGS_HAS_PETSC

// Private implementation of LinearSolver. It gathers:
//  - the library availability checks,
//  - the method/preconditioner compatibility checks for each library,
//  - the local (single process) sparse solve for each library.
struct LinearSolver::Internals
{
  // Returns true when the linear system library can be used in this build:
  // builtin is always available, petsc only when pugs was compiled with
  // PUGS_HAS_PETSC.
  // Throws UnexpectedError for a library value not handled here.
  static bool
  hasLibrary(const LSLibrary library)
  {
    switch (library) {
    case LSLibrary::builtin: {
      return true;
    }
    case LSLibrary::petsc: {
#ifdef PUGS_HAS_PETSC
      return true;
#else
      return false;
#endif
    }
      // LCOV_EXCL_START
    default: {
      throw UnexpectedError("Linear system library (" + ::name(library) + ") was not set!");
    }
      // LCOV_EXCL_STOP
    }
  }

  // Throws NormalError if the requested library is not linked to pugs.
  static void
  checkHasLibrary(const LSLibrary library)
  {
    if (not hasLibrary(library)) {
      // LCOV_EXCL_START
      throw NormalError(::name(library) + " is not linked to pugs. Cannot use it!");
      // LCOV_EXCL_STOP
    }
  }

  // Throws NormalError unless the method is implemented by the builtin
  // solvers (cg or bicgstab).
  static void
  checkBuiltinMethod(const LSMethod method)
  {
    switch (method) {
    case LSMethod::cg:
    case LSMethod::bicgstab: {
      break;
    }
    default: {
      throw NormalError(name(method) + " is not a builtin linear solver!");
    }
    }
  }

  // Throws NormalError unless the method is one of those mapped to a PETSc
  // KSP/PC configuration in petscSolveLocalSystem.
  static void
  checkPETScMethod(const LSMethod method)
  {
    switch (method) {
    case LSMethod::cg:
    case LSMethod::bicgstab:
    case LSMethod::bicgstab2:
    case LSMethod::gmres:
    case LSMethod::lu:
    case LSMethod::choleski: {
      break;
    }
      // LCOV_EXCL_START
    default: {
      throw NormalError(name(method) + " is not a PETSc linear solver!");
    }
      // LCOV_EXCL_STOP
    }
  }

  // Throws NormalError unless the preconditioner is "none": the builtin
  // solvers do not support preconditioning.
  static void
  checkBuiltinPrecond(const LSPrecond precond)
  {
    switch (precond) {
    case LSPrecond::none: {
      break;
    }
    default: {
      throw NormalError(name(precond) + " is not a builtin preconditioner!");
    }
    }
  }

  // Throws NormalError unless the preconditioner is one of those mapped to a
  // PETSc PC type in petscSolveLocalSystem.
  static void
  checkPETScPrecond(const LSPrecond precond)
  {
    switch (precond) {
    case LSPrecond::none:
    case LSPrecond::amg:
    case LSPrecond::diagonal:
    case LSPrecond::incomplete_choleski:
    case LSPrecond::incomplete_LU: {
      break;
    }
      // LCOV_EXCL_START
    default: {
      throw NormalError(name(precond) + " is not a PETSc preconditioner!");
    }
      // LCOV_EXCL_STOP
    }
  }

  // Checks that the method and preconditioner are supported by the library
  // selected in options. Library availability is NOT checked here (see
  // checkHasLibrary).
  // Throws NormalError on incompatibility, UnexpectedError for an unhandled
  // library value.
  static void
  checkOptions(const LinearSolverOptions& options)
  {
    switch (options.library()) {
    case LSLibrary::builtin: {
      checkBuiltinMethod(options.method());
      checkBuiltinPrecond(options.precond());
      break;
    }
    case LSLibrary::petsc: {
      checkPETScMethod(options.method());
      checkPETScPrecond(options.precond());
      break;
    }
      // LCOV_EXCL_START
    default: {
      throw UnexpectedError("undefined options compatibility for this library (" + ::name(options.library()) + ")!");
    }
      // LCOV_EXCL_STOP
    }
  }

  // Solves A x = b with the builtin CG or BiCGStab solver. The solution is
  // written into x. Tolerance, iteration limit and verbosity come from options.
  static void
  builtinSolveLocalSystem(const CRSMatrix<double, size_t>& A,
                          Vector<double>& x,
                          const Vector<double>& b,
                          const LinearSolverOptions& options)
  {
    if (options.precond() != LSPrecond::none) {
      // LCOV_EXCL_START
      throw UnexpectedError("builtin linear solver do not allow any preconditioner!");
      // LCOV_EXCL_STOP
    }
    switch (options.method()) {
    case LSMethod::cg: {
      CG{A, x, b, options.epsilon(), options.maximumIteration(), options.verbose()};
      break;
    }
    case LSMethod::bicgstab: {
      BiCGStab{A, x, b, options.epsilon(), options.maximumIteration(), options.verbose()};
      break;
    }
      // LCOV_EXCL_START
    default: {
      throw NotImplementedError("undefined builtin method: " + name(options.method()));
    }
      // LCOV_EXCL_STOP
    }
  }

#ifdef PUGS_HAS_PETSC
  // KSP monitor used in verbose mode: prints the iteration number and the
  // residual norm reported by PETSc.
  static int
  petscMonitor(KSP, int i, double residu, void*)
  {
    // Restore the caller's formatting afterwards: std::scientific is sticky and
    // would otherwise change every later floating-point output on std::cout.
    const std::ios_base::fmtflags saved_flags = std::cout.flags();
    std::cout << "  - iteration: " << std::setw(6) << i << " residu: " << std::scientific << residu << '\n';
    std::cout.flags(saved_flags);
    return 0;
  }

  // Solves the local system A x = b with PETSc. The solution is written into
  // x, whose storage is wrapped by a PETSc Vec (no copy).
  //
  // The system lives on the calling process only, so every PETSc object is
  // created on PETSC_COMM_SELF. MatCreateSeqAIJWithArrays requires a
  // communicator of size 1, and the vectors use local size == global size.
  static void
  petscSolveLocalSystem(const CRSMatrix<double, size_t>& A,
                        Vector<double>& x,
                        const Vector<double>& b,
                        const LinearSolverOptions& options)
  {
    Assert(x.size() == b.size() and x.size() == A.numberOfRows());

    Vec petscB;
    VecCreateMPIWithArray(PETSC_COMM_SELF, 1, b.size(), b.size(), &b[0], &petscB);
    Vec petscX;
    VecCreateMPIWithArray(PETSC_COMM_SELF, 1, x.size(), x.size(), &x[0], &petscX);

    // PETSc needs non-const arrays of its own index/scalar types. They are not
    // copied by MatCreateSeqAIJWithArrays, so they must outlive petscMat.
    Array<PetscScalar> values = copy(A.values());

    const auto A_row_indices = A.rowIndices();
    Array<PetscInt> row_indices{A_row_indices.size()};
    for (size_t i = 0; i < row_indices.size(); ++i) {
      row_indices[i] = A_row_indices[i];
    }

    Array<PetscInt> column_indices{values.size()};
    size_t l = 0;
    for (size_t i = 0; i < A.numberOfRows(); ++i) {
      const auto row_i = A.row(i);
      for (size_t j = 0; j < row_i.length; ++j) {
        column_indices[l++] = row_i.colidx(j);
      }
    }

    Mat petscMat;
    MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, x.size(), x.size(), &row_indices[0], &column_indices[0], &values[0],
                              &petscMat);

    MatAssemblyBegin(petscMat, MAT_FINAL_ASSEMBLY);
    MatAssemblyEnd(petscMat, MAT_FINAL_ASSEMBLY);

    KSP ksp;
    KSPCreate(PETSC_COMM_SELF, &ksp);
    // relative tolerance from options; absolute (1E-100) and divergence (1E5)
    // tolerances are fixed here.
    KSPSetTolerances(ksp, options.epsilon(), 1E-100, 1E5, options.maximumIteration());

    KSPSetOperators(ksp, petscMat, petscMat);

    PC pc;
    KSPGetPC(ksp, &pc);

    // Direct methods configure the PC themselves; the preconditioner option is
    // then ignored.
    bool direct_solver = false;

    switch (options.method()) {
    case LSMethod::bicgstab: {
      KSPSetType(ksp, KSPBCGS);
      break;
    }
    case LSMethod::bicgstab2: {
      KSPSetType(ksp, KSPBCGSL);
      KSPBCGSLSetEll(ksp, 2);
      break;
    }
    case LSMethod::cg: {
      KSPSetType(ksp, KSPCG);
      break;
    }
    case LSMethod::gmres: {
      KSPSetType(ksp, KSPGMRES);

      break;
    }
    case LSMethod::lu: {
      KSPSetType(ksp, KSPPREONLY);
      PCSetType(pc, PCLU);
      PCFactorSetShiftType(pc, MAT_SHIFT_NONZERO);
      direct_solver = true;
      break;
    }
    case LSMethod::choleski: {
      KSPSetType(ksp, KSPPREONLY);
      PCSetType(pc, PCCHOLESKY);
      direct_solver = true;
      break;
    }
      // LCOV_EXCL_START
    default: {
      throw UnexpectedError("unexpected method: " + name(options.method()));
    }
      // LCOV_EXCL_STOP
    }

    if (not direct_solver) {
      switch (options.precond()) {
      case LSPrecond::amg: {
        PCSetType(pc, PCGAMG);
        break;
      }
      case LSPrecond::diagonal: {
        PCSetType(pc, PCJACOBI);
        break;
      }
      case LSPrecond::incomplete_LU: {
        PCSetType(pc, PCILU);
        break;
      }
      case LSPrecond::incomplete_choleski: {
        PCSetType(pc, PCICC);
        break;
      }
      case LSPrecond::none: {
        PCSetType(pc, PCNONE);
        break;
      }
        // LCOV_EXCL_START
      default: {
        throw UnexpectedError("unexpected preconditioner: " + name(options.precond()));
      }
        // LCOV_EXCL_STOP
      }
    }
    if (options.verbose()) {
      KSPMonitorSet(ksp, petscMonitor, 0, 0);
    }

    KSPSolve(ksp, petscB, petscX);

    // free used memory
    MatDestroy(&petscMat);
    VecDestroy(&petscB);
    VecDestroy(&petscX);
    KSPDestroy(&ksp);
  }

#else   // PUGS_HAS_PETSC

  // Fallback when PETSc is not linked. checkHasLibrary throws here, and the
  // LinearSolver constructor already prevents reaching this point.
  // LCOV_EXCL_START
  static void
  petscSolveLocalSystem(const CRSMatrix<double, size_t>&,
                        Vector<double>&,
                        const Vector<double>&,
                        const LinearSolverOptions&)
  {
    checkHasLibrary(LSLibrary::petsc);
    throw UnexpectedError("unexpected situation should not reach this point!");
  }
  // LCOV_EXCL_STOP

#endif   // PUGS_HAS_PETSC
};

// Tells whether the given linear system library is available in this build
// (forwards to Internals::hasLibrary).
bool
LinearSolver::hasLibrary(LSLibrary library) const
{
  const bool is_available = Internals::hasLibrary(library);
  return is_available;
}

// Checks that the method and preconditioner in options are supported by the
// library they select; throws on incompatibility. Library availability itself
// is not checked here.
void
LinearSolver::checkOptions(const LinearSolverOptions& options) const
{
  LinearSolver::Internals::checkOptions(options);
}

// Solves the local sparse system A x = b with the library selected in the
// solver options. The solution is written into x.
// Throws UnexpectedError if the selected library cannot solve local sparse
// systems.
void
LinearSolver::solveLocalSystem(const CRSMatrix<double, size_t>& A, Vector<double>& x, const Vector<double>& b)
{
  const auto library = m_options.library();

  if (library == LSLibrary::builtin) {
    Internals::builtinSolveLocalSystem(A, x, b, m_options);
    return;
  }

  // LCOV_EXCL_START
  if (library == LSLibrary::petsc) {
    // Not covered: without PETSc this point cannot be reached, because the
    // LinearSolver constructor throws for an unavailable library.
    Internals::petscSolveLocalSystem(A, x, b, m_options);
    return;
  }

  throw UnexpectedError(::name(library) + " cannot solve local systems for sparse matrices");
  // LCOV_EXCL_STOP
}

// Builds a solver from the given options and validates them eagerly. It first
// checks that the library is linked, then that the method and preconditioner
// are compatible with it. Throws NormalError on failure (UnexpectedError for
// unhandled enum values).
LinearSolver::LinearSolver(const LinearSolverOptions& options) : m_options{options}
{
  const LinearSolverOptions& stored_options = m_options;
  Internals::checkHasLibrary(stored_options.library());
  Internals::checkOptions(stored_options);
}

// Builds a solver from LinearSolverOptions::default_options (delegates to the
// validating constructor).
LinearSolver::LinearSolver() : LinearSolver(LinearSolverOptions::default_options) {}