#ifndef INTEGRATION_TOOLS_HPP
#define INTEGRATION_TOOLS_HPP

#include <language/utils/EvaluateAtPoints.hpp>
#include <utils/Exceptions.hpp>
#include <utils/PugsAssert.hpp>
#include <utils/PugsMacros.hpp>
#include <utils/Types.hpp>

#include <analysis/GaussLegendreQuadrature.hpp>

#include <cmath>

// Family of 1D quadrature rule built by IntegrationTools.
// QT__begin and QT__end bracket the valid values; QT__end itself is a sentinel
// that IntegrationTools' constructor rejects.
enum class QuadratureType : int8_t
{
  QT__begin = 0,   // first valid value (same value as gausslobatto)

  gausslobatto  = 0,   // Gauss-Lobatto rule (IntegrationMethodLobatto)
  gausslegendre = 1,   // Gauss-Legendre rule (IntegrationMethodLegendre)

  QT__end = 2   // one past the last valid value
};

// Abstract interface of a one-dimensional quadrature rule.
// Concrete rules (Gauss-Lobatto, Gauss-Legendre) are held by IntegrationTools
// through a std::shared_ptr<IntegrationMethod>, hence the virtual destructor.
class IntegrationMethod
{
 public:
  // Empty user-provided body rather than '=default', kept on purpose to avoid
  // (zero-initialization) performance issues
  PUGS_INLINE
  constexpr IntegrationMethod() noexcept {}

  PUGS_INLINE
  constexpr IntegrationMethod(const IntegrationMethod&) noexcept = default;

  PUGS_INLINE
  constexpr IntegrationMethod(IntegrationMethod&&) noexcept = default;

  // Points of the rule mapped onto the interval [a, b]
  PUGS_INLINE virtual SmallArray<TinyVector<1>> quadraturePoints(const double& a, const double& b) = 0;

  // Number of points of the rule
  PUGS_INLINE virtual size_t numberOfPoints() = 0;

  // Weights of the rule on its reference interval (IntegrationTools scales the
  // weighted sum by (b - a) / 2)
  PUGS_INLINE virtual SmallArray<double> weights() = 0;

  virtual PUGS_INLINE ~IntegrationMethod() noexcept = default;
};

// Gauss-Lobatto rule on the reference interval [-1, 1], selected by the
// polynomial order Order. Points and weights are tabulated for Order <= 11;
// larger orders make the constructor throw UnexpectedError.
template <size_t Order>
class IntegrationMethodLobatto : public IntegrationMethod
{
 private:
  // Point count for the requested order: an n-point Lobatto rule is exact up to
  // degree 2n - 3, hence 3 points up to order 3, 4 up to 5, ..., 7 up to 11.
  static constexpr size_t number_points = (Order < 4 ? 3 : (Order < 6 ? 4 : (Order < 8 ? 5 : Order < 10 ? 6 : 7)));
  SmallArray<double> m_weights;
  SmallArray<TinyVector<1>> m_quadrature_positions;

  // Maps the reference points xi in [-1, 1] onto [a, b]: x = ((b - a) * xi + a + b) / 2
  PUGS_INLINE SmallArray<TinyVector<1>>
  translateQuadraturePoints(const double& a, const double& b)
  {
    SmallArray<TinyVector<1>> points{number_points};
    double length = b - a;
    for (size_t i = 0; i < m_quadrature_positions.size(); i++) {
      TinyVector<1> qpos{m_quadrature_positions[i]};
      points[i] = 0.5 * (length * qpos + a + b);
    }
    return points;
  }

  // Writes the tabulated reference points and weights (both end-points -1 and 1
  // are always nodes). Throws UnexpectedError when Order > 11.
  PUGS_INLINE
  void
  fillArrayLobatto()
  {
    switch (Order) {
    case 0:
    case 1:
    case 2:
    case 3: {
      double oneov3             = 1. / 3.;
      m_quadrature_positions[0] = -1.;
      m_quadrature_positions[1] = 0.;
      m_quadrature_positions[2] = 1.;
      m_weights[0]              = oneov3;
      m_weights[1]              = 4 * oneov3;
      m_weights[2]              = oneov3;
      break;
    }
    case 4:
    case 5: {
      double oneov6             = 1. / 6.;
      double coef               = std::sqrt(1. / 5.);
      m_quadrature_positions[0] = -1;
      m_quadrature_positions[1] = -coef;
      m_quadrature_positions[2] = coef;
      m_quadrature_positions[3] = 1;
      m_weights[0]              = oneov6;
      m_weights[1]              = 5 * oneov6;
      m_weights[2]              = 5 * oneov6;
      m_weights[3]              = oneov6;
      break;
    }
    case 6:
    case 7: {
      double oneov90            = 1. / 90.;
      double oneov10            = 1. / 10.;
      double coef               = std::sqrt(3. / 7.);
      m_quadrature_positions[0] = -1;
      m_quadrature_positions[1] = -coef;
      m_quadrature_positions[2] = 0;
      m_quadrature_positions[3] = coef;
      m_quadrature_positions[4] = 1;
      m_weights[0]              = oneov10;
      m_weights[1]              = 49 * oneov90;
      m_weights[2]              = 64 * oneov90;
      m_weights[3]              = 49 * oneov90;
      m_weights[4]              = oneov10;
      break;
    }
    case 8:
    case 9: {
      double oneov30            = 1. / 30.;
      double coef1              = std::sqrt(1. / 3. - 2 * std::sqrt(7) / 21.);
      double coef2              = std::sqrt(1. / 3. + 2 * std::sqrt(7) / 21.);
      double coef3              = (14. + std::sqrt(7)) * oneov30;
      double coef4              = (14. - std::sqrt(7)) * oneov30;
      m_quadrature_positions[0] = -1;
      m_quadrature_positions[1] = -coef2;
      m_quadrature_positions[2] = -coef1;
      m_quadrature_positions[3] = coef1;
      m_quadrature_positions[4] = coef2;
      m_quadrature_positions[5] = 1;
      m_weights[0]              = 2. * oneov30;
      m_weights[1]              = coef4;
      m_weights[2]              = coef3;
      m_weights[3]              = coef3;
      m_weights[4]              = coef4;
      m_weights[5]              = 2. * oneov30;
      break;
    }
    case 10:
    case 11: {
      double oneov350           = 1. / 350.;
      double oneov21            = 1. / 21.;
      double coef0              = 256. / 525.;
      double coef1              = std::sqrt((5. - 2 * std::sqrt(5. / 3.)) / 11.);
      double coef2              = std::sqrt((5. + 2 * std::sqrt(5. / 3.)) / 11.);
      double coef3              = (124. + 7 * std::sqrt(15.)) * oneov350;
      double coef4              = (124. - 7 * std::sqrt(15.)) * oneov350;
      m_quadrature_positions[0] = -1;
      m_quadrature_positions[1] = -coef2;
      m_quadrature_positions[2] = -coef1;
      m_quadrature_positions[3] = 0;
      m_quadrature_positions[4] = coef1;
      m_quadrature_positions[5] = coef2;
      m_quadrature_positions[6] = 1;
      m_weights[0]              = oneov21;
      m_weights[1]              = coef4;
      m_weights[2]              = coef3;
      m_weights[3]              = coef0;
      m_weights[4]              = coef3;
      m_weights[5]              = coef4;
      m_weights[6]              = oneov21;
      break;
    }
    default: {
      throw UnexpectedError("Gauss-Lobatto quadratures handle orders up to 11.");
    }
    }
  }

 public:
  // Allocates, zero-fills and then tabulates the rule.
  // Deliberately NOT noexcept: fillArrayLobatto() throws for Order > 11, and an
  // exception leaving a noexcept function would call std::terminate instead of
  // reaching the caller.
  PUGS_INLINE
  IntegrationMethodLobatto() : m_weights{number_points}, m_quadrature_positions{number_points}
  {
    m_weights.fill(0.);
    m_quadrature_positions.fill(TinyVector<1, double>{0});
    fillArrayLobatto();
  }

  // Points of the rule mapped onto [a, b]
  PUGS_INLINE
  SmallArray<TinyVector<1>>
  quadraturePoints(const double& a, const double& b) override
  {
    return translateQuadraturePoints(a, b);
  }

  // Number of points of the rule
  PUGS_INLINE size_t
  numberOfPoints() override
  {
    return number_points;
  }

  // Weights on the reference interval [-1, 1] (they sum to 2)
  PUGS_INLINE
  SmallArray<double>
  weights() override
  {
    return m_weights;
  }

  PUGS_INLINE
  constexpr IntegrationMethodLobatto(const IntegrationMethodLobatto&) noexcept = default;

  PUGS_INLINE
  constexpr IntegrationMethodLobatto(IntegrationMethodLobatto&& v) noexcept = default;

  PUGS_INLINE
  ~IntegrationMethodLobatto() noexcept = default;
};

// Gauss-Legendre rule on the reference interval [-1, 1] for the polynomial
// order Order. Points and weights are taken from GaussLegendreQuadrature<1>.
template <size_t Order>
class IntegrationMethodLegendre : public IntegrationMethod
{
 private:
  SmallArray<double> m_weights;
  SmallArray<TinyVector<1>> m_quadrature_positions;

  // Builds the rule from an already computed GaussLegendreQuadrature.
  // Both arrays are sized from the quadrature itself rather than from a
  // separately hard-coded point count: such a count cannot follow every Order
  // supported by GaussLegendreQuadrature, and any mismatch would make copy_to
  // read or write outside one of the arrays.
  PUGS_INLINE
  explicit IntegrationMethodLegendre(GaussLegendreQuadrature<1>&& quadrature)
    : m_weights(quadrature.weightList().size()), m_quadrature_positions(quadrature.pointList().size())
  {
    copy_to(quadrature.pointList(), m_quadrature_positions);
    copy_to(quadrature.weightList(), m_weights);
  }

  // Maps the reference points xi in [-1, 1] onto [a, b]: x = ((b - a) * xi + a + b) / 2
  PUGS_INLINE SmallArray<TinyVector<1>>
  translateQuadraturePoints(const double& a, const double& b)
  {
    SmallArray<TinyVector<1>> points(m_quadrature_positions.size());
    double length = b - a;
    for (size_t i = 0; i < m_quadrature_positions.size(); i++) {
      TinyVector<1> qpos{m_quadrature_positions[i]};
      points[i] = 0.5 * (length * qpos + a + b);
    }
    return points;
  }

 public:
  // Not noexcept: building the GaussLegendreQuadrature may fail (e.g. for an
  // unsupported order — NOTE(review): confirm its limits), and that error must
  // be able to reach the caller instead of calling std::terminate.
  PUGS_INLINE
  IntegrationMethodLegendre() : IntegrationMethodLegendre(GaussLegendreQuadrature<1>(Order)) {}

  // Points of the rule mapped onto [a, b]
  PUGS_INLINE
  SmallArray<TinyVector<1>>
  quadraturePoints(const double& a, const double& b) override
  {
    return translateQuadraturePoints(a, b);
  }

  // Number of points of the rule, as provided by GaussLegendreQuadrature<1>
  PUGS_INLINE size_t
  numberOfPoints() override
  {
    return m_quadrature_positions.size();
  }

  // Weights on the reference interval
  PUGS_INLINE
  SmallArray<double>
  weights() override
  {
    return m_weights;
  }

  PUGS_INLINE
  constexpr IntegrationMethodLegendre(const IntegrationMethodLegendre&) noexcept = default;

  PUGS_INLINE
  constexpr IntegrationMethodLegendre(IntegrationMethodLegendre&& v) noexcept = default;

  PUGS_INLINE
  ~IntegrationMethodLegendre() noexcept = default;
};

// One-dimensional numerical integration helper.
//
// Wraps a Gauss-Lobatto or Gauss-Legendre rule of order Order (chosen at
// construction) and uses it to integrate a language function, given by its
// FunctionSymbolId, either over one interval or over every cell of a 1D
// partition described by its vertices.
// T is the type of the integrated values; the evaluation goes through a
// double-valued EvaluateAtPoints, so T is presumably double — TODO confirm.
template <size_t Order, typename T = double>
class IntegrationTools
{
 public:
  using data_type = T;

 private:
  // Rule in use; nullptr for a default-constructed object
  std::shared_ptr<IntegrationMethod> m_integration_method;

  // Returns an array of `size` values of type T set to zero. Accumulators are
  // cleared explicitly: this file never relies on SmallArray(size) zeroing its
  // storage (every other array here is filled explicitly as well).
  PUGS_INLINE static SmallArray<T>
  _zeroArray(const size_t size)
  {
    SmallArray<T> array(size);
    if constexpr (std::is_arithmetic_v<T>) {
      array.fill(0);
    } else if constexpr (is_tiny_vector_v<T> or is_tiny_matrix_v<T>) {
      array.fill(zero);
    } else {
      static_assert(std::is_arithmetic_v<T>, "incorrect data type");
    }
    return array;
  }

  // Number of cells of a 1D partition given by its vertices; 0 for an empty
  // vertex list (avoids the size() - 1 unsigned underflow)
  PUGS_INLINE static size_t
  _numberOfCells(const SmallArray<TinyVector<1>>& vertices)
  {
    return (vertices.size() > 0) ? vertices.size() - 1 : 0;
  }

 public:
  // Integral of `function` over [a, b].
  // The weighted sum is taken with the reference-interval weights and scaled by
  // (b - a) / 2, the Jacobian of the affine map used by quadraturePoints(a, b).
  PUGS_INLINE T
  integrate(const FunctionSymbolId& function, const double& a, const double& b) const
  {
    SmallArray<const TinyVector<1>> positions = quadraturePoints(a, b);
    SmallArray<double> weights                = m_integration_method->weights();
    Assert(positions.size() == weights.size(), "Wrong number of quadrature points and/or weights in Gauss quadrature");

    SmallArray<T> values(positions.size());
    EvaluateAtPoints<double(const TinyVector<1>)>::evaluateTo(function, positions, values);

    T result = 0;
    for (size_t i = 0; i < values.size(); ++i) {
      result += weights[i] * values[i];
    }
    return (b - a) / 2 * result;
  }

  // Integral of `function` over each cell [vertices[i], vertices[i + 1]].
  // Returns one value per cell (an empty array when fewer than two vertices).
  PUGS_INLINE SmallArray<T>
  integrateFunction(const FunctionSymbolId& function, const SmallArray<TinyVector<1>>& vertices) const
  {
    const size_t number_of_cells = _numberOfCells(vertices);

    SmallArray<const TinyVector<1>> positions = quadraturePoints(vertices);
    SmallArray<double> weights                = m_integration_method->weights();
    SmallArray<T> values(positions.size());
    EvaluateAtPoints<double(const TinyVector<1>)>::evaluateTo(function, positions, values);

    // The accumulator must start from zero before the += below
    SmallArray<T> result = _zeroArray(number_of_cells);
    for (size_t i = 0; i < number_of_cells; ++i) {
      const double half_length = 0.5 * (vertices[i + 1][0] - vertices[i][0]);
      // values are stored cell by cell, weights.size() points per cell
      for (size_t j = 0; j < weights.size(); ++j) {
        result[i] += half_length * weights[j] * values[i * weights.size() + j];
      }
    }
    return result;
  }

  // Self-check helper: integral of x^3 over [a, b] using the current rule
  PUGS_INLINE T
  testIntegrate(const double& a, const double& b) const
  {
    SmallArray<TinyVector<1>> positions = quadraturePoints(a, b);
    SmallArray<double> weights          = m_integration_method->weights();
    Assert(positions.size() == weights.size(), "Wrong number of quadrature points and/or weights in Gauss quadrature");

    SmallArray<T> values(weights.size());
    for (size_t j = 0; j < weights.size(); ++j) {
      values[j] = std::pow(positions[j][0], 3.);
    }

    T result = 0;
    for (size_t i = 0; i < values.size(); ++i) {
      result += weights[i] * values[i];
    }
    return 0.5 * (b - a) * result;
  }

  // Self-check helper: integral of x^3 over each cell of the partition
  PUGS_INLINE SmallArray<T>
  testIntegrateFunction(const SmallArray<TinyVector<1>>& vertices) const
  {
    const size_t number_of_cells = _numberOfCells(vertices);

    SmallArray<TinyVector<1>> positions = quadraturePoints(vertices);
    SmallArray<double> weights          = m_integration_method->weights();
    SmallArray<T> values(positions.size());
    for (size_t j = 0; j < positions.size(); ++j) {
      values[j] = std::pow(positions[j][0], 3);
    }

    SmallArray<T> result = _zeroArray(number_of_cells);
    for (size_t i = 0; i < number_of_cells; ++i) {
      const double half_length = 0.5 * (vertices[i + 1][0] - vertices[i][0]);
      for (size_t j = 0; j < weights.size(); ++j) {
        result[i] += half_length * weights[j] * values[i * weights.size() + j];
      }
    }
    return result;
  }

  // Points of the rule mapped onto [a, b]
  PUGS_INLINE SmallArray<TinyVector<1>>
  quadraturePoints(const double& a, const double& b) const
  {
    return m_integration_method->quadraturePoints(a, b);
  }

  // Points of the rule mapped onto every cell of the partition, stored cell by
  // cell (numberOfPoints() consecutive entries per cell). Empty when fewer than
  // two vertices are given.
  PUGS_INLINE
  SmallArray<TinyVector<1>>
  quadraturePoints(const SmallArray<TinyVector<1>>& positions) const
  {
    const size_t number_of_intervals = _numberOfCells(positions);
    const size_t quadrature_size     = m_integration_method->numberOfPoints();
    SmallArray<TinyVector<1>> quadratures(quadrature_size * number_of_intervals);
    for (size_t j = 0; j < number_of_intervals; ++j) {
      auto intermediate_positions = m_integration_method->quadraturePoints(positions[j][0], positions[j + 1][0]);
      for (size_t k = 0; k < quadrature_size; ++k) {
        quadratures[j * quadrature_size + k] = intermediate_positions[k];
      }
    }
    return quadratures;
  }

  // Weights of the rule on its reference interval
  PUGS_INLINE SmallArray<double>
  weights() const
  {
    return m_integration_method->weights();
  }

  // Requested order (the Order template parameter)
  PUGS_INLINE
  constexpr size_t
  order() const
  {
    return Order;
  }

  // Builds the rule of the requested family; throws NormalError for QT__end or
  // any value outside the enumeration (which would otherwise leave no rule).
  PUGS_INLINE constexpr IntegrationTools(QuadratureType quadrature)
  {
    switch (quadrature) {
    case QuadratureType::gausslobatto: {
      m_integration_method = std::make_shared<IntegrationMethodLobatto<Order>>();
      break;
    }
    case QuadratureType::gausslegendre: {
      m_integration_method = std::make_shared<IntegrationMethodLegendre<Order>>();
      break;
    }
    case QuadratureType::QT__end:
    default: {
      throw NormalError("QuadratureType is not defined!");
    }
    }
  }

  // Leaves no rule attached: the object must be assigned a constructed
  // IntegrationTools before any integration method is called
  PUGS_INLINE
  constexpr IntegrationTools() noexcept {}

  PUGS_INLINE
  constexpr IntegrationTools(const IntegrationTools&) noexcept = default;

  PUGS_INLINE
  constexpr IntegrationTools(IntegrationTools&& v) noexcept = default;

  // Declared explicitly: the user-declared move constructor would otherwise make
  // copy assignment deleted and leave move assignment undeclared
  PUGS_INLINE
  IntegrationTools& operator=(const IntegrationTools&) noexcept = default;

  PUGS_INLINE
  IntegrationTools& operator=(IntegrationTools&&) noexcept = default;

  PUGS_INLINE
  ~IntegrationTools() noexcept = default;
};

#endif   // INTEGRATION_TOOLS_HPP
