#ifndef PCG_HPP
#define PCG_HPP

#include <iomanip>
#include <iostream>

#include <rang.hpp>

/// Conjugate gradient solver for A x = f, laid out as a preconditioned CG
/// (PCG). The preconditioner hooks are not implemented yet: `C` is accepted
/// but ignored (see the TODOs), so the iteration currently runs plain CG.
///
/// The whole solve runs inside the constructor; the object holds no state.
///
/// @tparam verbose  when true, prints the solver parameters, the initial real
///                  residual and the per-iteration residuals to std::cout.
template <bool verbose = true>
struct PCG
{
  /// @param f        right-hand side vector.
  /// @param A        system matrix; must support `A * v` yielding a VectorType.
  ///                 CG presumes A is symmetric positive definite (not checked).
  /// @param C        preconditioner — currently unused.
  /// @param x        in: initial guess; out: approximate solution.
  /// @param maxiter  maximum number of iterations; 0 leaves x untouched.
  /// @param epsilon  tolerance; iteration stops once (g, cg) <= epsilon times
  ///                 the (g, cg) value measured after the first iteration.
  ///
  /// `(u, v)` below is VectorType's overloaded comma operator, used as an
  /// inner product returning a value convertible to double.
  ///
  /// If the loop ends without meeting the tolerance (including when the
  /// residual measure becomes NaN), a "NOT CONVERGED" report is printed to
  /// std::cout regardless of `verbose`.
  template <typename VectorType, typename MatrixType, typename PreconditionerType>
  PCG(const VectorType& f,
      const MatrixType& A,
      [[maybe_unused]] const PreconditionerType& C,
      VectorType& x,
      const size_t maxiter,
      const double epsilon = 1e-6)
  {
    if constexpr (verbose) {
      std::cout << "- conjugate gradient\n";
      std::cout << "  epsilon = " << epsilon << '\n';
      std::cout << "  maximum number of iterations: " << maxiter << '\n';
    }

    VectorType h{f.size()};
    VectorType b = copy(f);

    if constexpr (verbose) {
      // Report (Ax - f, Ax - f) for the initial guess, before any iteration.
      h = A * x;
      h -= f;
      std::cout << "- initial *real* residu :   " << (h, h) << '\n';
    }

    // Roles of the work vectors (with the identity "preconditioner"):
    //   cg : residual r = f - A x
    //   g  : preconditioned residual z (equal to r while C is unused)
    //   h  : search direction p
    //   b  : after setup, scratch storage for A p
    VectorType g{b.size()};
    VectorType cg = copy(b);

    double gcg  = 0;   // current (g, cg)
    double gcg0 = 1;   // reference (g, cg) used for the printed relative residual

    double relativeEpsilon = epsilon;

    // Single stopping predicate, used both to break and (negated) to report
    // failure. `<=` lets an exactly-zero residual terminate even when
    // epsilon == 0 (otherwise the next gamma would be 0/0 and poison x with
    // NaN), and a NaN residual never counts as converged.
    const auto converged = [&]() { return gcg <= relativeEpsilon; };

    for (size_t i = 1; i <= maxiter; ++i) {
      if (i == 1) {
        // Setup: r = f - A x, z = r, p = z.
        h = A * x;

        cg -= h;

        g = copy(cg);   // TODO: precond: g = g/C

        gcg = (g, cg);

        h = copy(g);
      }

      b = A * h;

      double hAh = (h, b);

      if (hAh == 0) {
        // Guard against division by zero; ro then equals gcg.
        hAh = 1.;
      }
      const double ro = gcg / hAh;   // step length along h
      cg -= ro * b;

      // TODO: precond: b <- b/C

      x += ro * h;
      g -= ro * b;

      const double gcgPrevious = gcg;
      gcg                      = (g, cg);

      if ((i == 1) && (gcg != 0)) {
        // NOTE(review): the tolerance is relative to (g, cg) measured AFTER the
        // first step, not to the initial residual; kept as-is to preserve the
        // existing convergence behaviour.
        relativeEpsilon = epsilon * gcg;
        gcg0            = gcg;
        if constexpr (verbose) {
          std::cout << "  initial residu: " << gcg << '\n';
        }
      }
      if constexpr (verbose) {
        std::cout << "  - iteration " << std::setw(6) << i << "\tresidu: " << gcg / gcg0;
        std::cout << "\tabsolute: " << gcg << '\n';
      }

      if (converged()) {
        break;
      }

      // New direction: h = g + (gcg_new / gcg_old) * h.
      const double gamma = gcg / gcgPrevious;

      h *= gamma;
      h += g;
    }

    if (not converged()) {
      std::cout << "  conjugate gradient: " << rang::fgB::red << "*NOT CONVERGED*" << rang::style::reset << '\n';
      std::cout << "  - epsilon:          " << epsilon << '\n';
      std::cout << "  - relative residu : " << gcg / gcg0 << '\n';
      std::cout << "  - absolute residu : " << gcg << '\n';
    }
  }
};

#endif   // PCG_HPP