#ifndef BICG_STAB_HPP
#define BICG_STAB_HPP

#include <cmath>
#include <iomanip>
#include <iostream>

#include <rang.hpp>

/// Bi-Conjugate Gradient Stabilized (BiCGStab) iterative solver for A x = b.
///
/// The whole solve is performed by the constructor; the object keeps no state.
///
/// @tparam verbose when true, the solver parameters and the residual of every
///                 iteration are printed to std::cout.
template <bool verbose = true>
struct BiCGStab
{
  /// Solves A x = b in place, starting from the initial guess stored in x.
  ///
  /// Requirements on the types, as used below (NOTE(review): inferred from usage,
  /// confirm against the project's vector/matrix classes):
  ///   - `b - A * x`, `A * v` yield something assignable to VectorType;
  ///   - `(u, v)` evaluates to the inner product of u and v;
  ///   - `copy(v)` returns an independent copy of v (presumably because plain
  ///     assignment may share storage — TODO confirm);
  ///   - VectorType supports `+=`, `-=` with vector expressions and `*=` with a scalar.
  ///
  /// @param b        right-hand side
  /// @param A        system matrix
  /// @param x        on input the initial guess, on output the computed approximation
  /// @param max_iter maximum number of iterations
  /// @param epsilon  tolerance on the relative residual ||r_k|| / ||r_0||
  ///
  /// If the tolerance is not reached within max_iter iterations, or the method
  /// breaks down (non-finite residual), a "*NOT CONVERGED*" diagnostic is printed
  /// to std::cout regardless of `verbose`.
  template <typename VectorType, typename MatrixType>
  BiCGStab(const VectorType& b, const MatrixType& A, VectorType& x, const size_t max_iter, const double epsilon = 1e-6)
  {
    if constexpr (verbose) {
      std::cout << "- bi-conjugate gradient stabilized\n";
      std::cout << "  epsilon = " << epsilon << '\n';
      std::cout << "  maximum number of iterations: " << max_iter << '\n';
    }

    VectorType r_k_1{b.size()};

    r_k_1 = b - A * x;

    double residu = std::sqrt((r_k_1, r_k_1));   // Norm(r_k_1);

    // A zero initial residual means x already solves the system exactly.
    if (residu != 0) {
      const double resid0 = residu;

      VectorType rTilda_0 = copy(r_k_1);   // fixed shadow residual
      VectorType p_k      = copy(r_k_1);   // search direction

      VectorType s_k{x.size()};

      VectorType Ap_k{x.size()};
      VectorType As_k{x.size()};

      VectorType r_k{x.size()};

      // rho_k = (r_{k-1}, rTilda_0) is needed both for alpha_k and for beta_k;
      // it is carried across iterations instead of being recomputed.
      double rho_k = (r_k_1, rTilda_0);

      bool converged = false;

      if constexpr (verbose) {
        std::cout << "   initial residu: " << resid0 << '\n';
      }
      for (size_t i = 1; i <= max_iter; ++i) {
        if constexpr (verbose) {
          std::cout << "  - iteration: " << std::setw(6) << i << "\tresidu: " << residu / resid0
                    << "\tabsolute: " << residu << '\n';
        }

        Ap_k = A * p_k;

        const double alpha_k = rho_k / (Ap_k, rTilda_0);

        s_k = r_k_1 - alpha_k * Ap_k;

        // Early exit on the intermediate residual: if s_k is already small enough,
        // x + alpha_k p_k is an acceptable solution. Without this check an exactly
        // vanishing s_k (e.g. A = identity) makes w_k = 0/0 = NaN below, which
        // would then poison x.
        const double s_norm = std::sqrt((s_k, s_k));
        if (s_norm / resid0 < epsilon) {
          x += alpha_k * p_k;
          residu    = s_norm;
          converged = true;
          break;
        }

        As_k = A * s_k;

        const double w_k = (As_k, s_k) / (As_k, As_k);

        x += alpha_k * p_k + w_k * s_k;
        r_k = s_k - w_k * As_k;

        residu = std::sqrt((r_k, r_k));
        if (residu / resid0 < epsilon) {
          converged = true;
          break;
        }

        // A non-finite residual signals a breakdown (zero denominator in alpha_k
        // or w_k); further iterations could only propagate NaN/Inf.
        if (!std::isfinite(residu)) {
          break;
        }

        const double rho_k1 = (r_k, rTilda_0);
        const double beta_k = rho_k1 / rho_k * (alpha_k / w_k);

        // p_{k+1} = r_k + beta_k * (p_k - w_k * A p_k), computed in place.
        p_k -= w_k * Ap_k;
        p_k *= beta_k;
        p_k += r_k;

        r_k_1 = r_k;
        rho_k = rho_k1;
      }

      // Use the explicit flag rather than re-testing residu: a NaN residual
      // compares false against epsilon and would otherwise pass silently.
      if (!converged) {
        std::cout << "  bi-conjugate gradient stabilized: " << rang::fgB::red << "*NOT CONVERGED*" << rang::style::reset
                  << '\n';
        std::cout << "  - epsilon:          " << epsilon << '\n';
        std::cout << "  - relative residu : " << residu / resid0 << '\n';
        std::cout << "  - absolute residu : " << residu << '\n';
      }
    }
  }
};

#endif   // BICG_STAB_HPP