#ifndef PCG_HPP
#define PCG_HPP

#include <iomanip>
#include <iostream>

#include <rang.hpp>

/// Conjugate gradient solver for the linear system A x = f.
///
/// The whole solve runs inside the constructor; the object itself carries no
/// state. No preconditioner is applied yet (see the "TODO: precond" notes), so
/// the vectors `g` and `cg` below stay equal to the residual f - A x.
///
/// NOTE: CG is only guaranteed to converge for symmetric positive definite A;
/// this property is not checked here.
struct CG
{
  /// @param A                 system matrix; must support `A * v` for vectors v
  /// @param x                 initial guess on entry, approximate solution on exit
  /// @param f                 right-hand side
  /// @param epsilon           relative tolerance on the squared residual norm
  /// @param maximum_iteration maximum number of CG iterations (0 leaves x untouched)
  /// @param verbose           print setup information and per-iteration residuals
  ///
  /// Iteration stops when dot(r, r) < epsilon * dot(r1, r1), where r1 is the
  /// residual after the first iteration, or when the residual is exactly zero.
  /// Otherwise a "*NOT CONVERGED*" report is written to std::cout, even when
  /// verbose is false.
  template <typename MatrixType, typename VectorType, typename RHSVectorType>
  CG(const MatrixType& A,
     VectorType& x,
     const RHSVectorType& f,
     const double epsilon,
     const size_t maximum_iteration,
     const bool verbose = false)
  {
    // std::scientific (used below) is sticky on std::cout. Restore the caller's
    // format flags on every exit path, including exceptions thrown by the
    // vector/matrix operations, so that later output is not silently reformatted.
    struct CoutFormatGuard
    {
      std::ios_base::fmtflags saved_flags = std::cout.flags();
      ~CoutFormatGuard()
      {
        std::cout.flags(saved_flags);
      }
    } cout_format_guard;

    if (verbose) {
      std::cout << "- conjugate gradient\n";
      std::cout << "  epsilon = " << epsilon << '\n';
      std::cout << "  maximum number of iterations: " << maximum_iteration << '\n';
    }

    VectorType h{f.size()};
    VectorType b = copy(f);

    if (verbose) {
      // Squared norm of the true residual A x - f for the initial guess.
      h = A * x;
      h -= f;
      std::cout << "- initial " << rang::style::bold << "real" << rang::style::reset << " residu :   " << dot(h, h)
                << '\n';
    }

    VectorType g{b.size()};
    VectorType cg = copy(b);

    double gcg  = 0;   // squared residual norm of the current iterate
    double gcg0 = 1;   // reference residual used for the relative report

    double relativeEpsilon = epsilon;

    for (size_t i = 1; i <= maximum_iteration; ++i) {
      if (i == 1) {
        h = A * x;

        cg -= h;   // cg = f - A x: initial residual

        g = copy(cg);   // TODO: precond: g = g/C

        gcg = dot(g, cg);

        h = copy(g);   // first search direction
      }

      b = A * h;

      double hAh = dot(h, b);

      // Guard against division by zero. For SPD A this only happens when h == 0,
      // in which case gcg == 0 and the step length ro below is 0.
      if (hAh == 0) {
        hAh = 1.;
      }
      double ro = gcg / hAh;
      cg -= ro * b;

      // TODO: precond: b <- b/C

      x += ro * h;
      g -= ro * b;

      double gamma = gcg;
      gcg          = dot(g, cg);

      // The stopping threshold is relative to the residual after the first step.
      if ((i == 1) && (gcg != 0)) {
        relativeEpsilon = epsilon * gcg;
        gcg0            = gcg;
        if (verbose) {
          std::cout << "  initial residu: " << gcg << '\n';
        }
      }
      if (verbose) {
        std::cout << "  - iteration " << std::setw(6) << i << std::scientific << " residu: " << gcg / gcg0;
        std::cout << " absolute: " << std::scientific << gcg << '\n';
      }

      // An exactly zero residual is an exact solution. Continuing would compute
      // gamma = 0 / 0 below and fill h (and then x) with NaN whenever
      // relativeEpsilon <= 0.
      if ((gcg < relativeEpsilon) || (gcg == 0)) {
        break;
      }

      gamma = gcg / gamma;

      h *= gamma;
      h += g;
    }

    // Mirror the loop's stopping test exactly. A NaN residual (breakdown) and
    // the boundary case gcg == relativeEpsilon are therefore reported as not
    // converged instead of being silently accepted.
    const bool converged = (gcg < relativeEpsilon) || (gcg == 0);
    if (!converged) {
      std::cout << "  conjugate gradient: " << rang::fgB::red << "*NOT CONVERGED*" << rang::style::reset << '\n';
      std::cout << "  - epsilon:          " << epsilon << '\n';
      std::cout << "  - relative residu : " << std::scientific << gcg / gcg0 << '\n';
      std::cout << "  - absolute residu : " << std::scientific << gcg << '\n';
    }
  }
};

#endif   // PCG_HPP