#ifndef DISCRETE_FUNCTION_DPK_FOR_TESTS_HPP
#define DISCRETE_FUNCTION_DPK_FOR_TESTS_HPP

#include <analysis/GaussQuadratureDescriptor.hpp>
#include <analysis/QuadratureFormula.hpp>
#include <analysis/QuadratureManager.hpp>
#include <geometry/CubeTransformation.hpp>
#include <geometry/LineTransformation.hpp>
#include <geometry/PrismTransformation.hpp>
#include <geometry/PyramidTransformation.hpp>
#include <geometry/SquareTransformation.hpp>
#include <geometry/TetrahedronTransformation.hpp>
#include <geometry/TriangleTransformation.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshDataManager.hpp>

#include <algorithm>
#include <array>
#include <cmath>
#include <functional>
#include <type_traits>

namespace test_only
{

// Cell-wise projection of exact_function onto piecewise constants: every cell receives the
// quadrature approximation of (1 / Vj) * integral over the cell of exact_function, Vj being the
// cell volume provided by the mesh data.
//
// Quadrature degree used: degree + 1 on lines, degree + 2 on triangles and quadrangles, and
// degree + 3 on tetrahedra, pyramids, prisms and hexahedra.
// Throws UnexpectedError for an unsupported cell type or mesh dimension.
template <MeshConcept MeshType, typename DataType>
DiscreteFunctionP0<std::remove_const_t<DataType>>
exact_projection(const MeshType& mesh,
                 size_t degree,
                 std::function<DataType(const TinyVector<MeshType::Dimension>&)> exact_function)
{
  using ValueType = std::remove_const_t<DataType>;

  DiscreteFunctionP0<ValueType> P0_function{mesh.meshVariant()};

  auto cell_volume         = MeshDataManager::instance().getMeshData(mesh).Vj();
  auto xr                  = mesh.xr();
  auto cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
  auto cell_type           = mesh.connectivity().cellType();

  // Quadrature approximation of the mean value of exact_function over the cell mapped by T.
  // The accumulator is initialised with the contribution of the first quadrature point.
  auto cell_mean_value = [&exact_function, &cell_volume](const CellId cell_id, const auto& T,
                                                         const auto& qf) -> ValueType {
    // weight(i) * jacobianDeterminant(point(i)) * exact_function(T(point(i)))
    auto weighted_value = [&](const size_t i_point) {
      const auto weight_times_jacobian = qf.weight(i_point) * T.jacobianDeterminant(qf.point(i_point));
      return weight_times_jacobian * exact_function(T(qf.point(i_point)));
    };

    ValueType integral = weighted_value(0);
    for (size_t i_point = 1; i_point < qf.numberOfPoints(); ++i_point) {
      integral += weighted_value(i_point);
    }
    return 1. / cell_volume[cell_id] * integral;
  };

  for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
    auto cell_nodes = cell_to_node_matrix[cell_id];

    if constexpr (MeshType::Dimension == 1) {
      const LineTransformation<MeshType::Dimension> T{xr[cell_nodes[0]], xr[cell_nodes[1]]};
      P0_function[cell_id] =
        cell_mean_value(cell_id, T,
                        QuadratureManager::instance().getLineFormula(GaussQuadratureDescriptor{degree + 1}));
    } else if constexpr (MeshType::Dimension == 2) {
      switch (cell_type[cell_id]) {
      case CellType::Triangle: {
        const TriangleTransformation<MeshType::Dimension> T{xr[cell_nodes[0]], xr[cell_nodes[1]],
                                                            xr[cell_nodes[2]]};
        P0_function[cell_id] =
          cell_mean_value(cell_id, T,
                          QuadratureManager::instance().getTriangleFormula(GaussQuadratureDescriptor{degree + 2}));
        break;
      }
      case CellType::Quadrangle: {
        const SquareTransformation<MeshType::Dimension> T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]],
                                                          xr[cell_nodes[3]]};
        P0_function[cell_id] =
          cell_mean_value(cell_id, T,
                          QuadratureManager::instance().getSquareFormula(GaussQuadratureDescriptor{degree + 2}));
        break;
      }
      default: {
        throw UnexpectedError("unexpected cell type");
      }
      }
    } else if constexpr (MeshType::Dimension == 3) {
      switch (cell_type[cell_id]) {
      case CellType::Tetrahedron: {
        const TetrahedronTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]],
                                          xr[cell_nodes[3]]};
        P0_function[cell_id] =
          cell_mean_value(cell_id, T,
                          QuadratureManager::instance().getTetrahedronFormula(GaussQuadratureDescriptor{degree + 3}));
        break;
      }
      case CellType::Pyramid: {
        const PyramidTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]], xr[cell_nodes[3]],
                                      xr[cell_nodes[4]]};
        P0_function[cell_id] =
          cell_mean_value(cell_id, T,
                          QuadratureManager::instance().getPyramidFormula(GaussQuadratureDescriptor{degree + 3}));
        break;
      }
      case CellType::Prism: {
        const PrismTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]],
                                    xr[cell_nodes[3]], xr[cell_nodes[4]], xr[cell_nodes[5]]};
        P0_function[cell_id] =
          cell_mean_value(cell_id, T,
                          QuadratureManager::instance().getPrismFormula(GaussQuadratureDescriptor{degree + 3}));
        break;
      }
      case CellType::Hexahedron: {
        const CubeTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]], xr[cell_nodes[3]],
                                   xr[cell_nodes[4]], xr[cell_nodes[5]], xr[cell_nodes[6]], xr[cell_nodes[7]]};
        P0_function[cell_id] =
          cell_mean_value(cell_id, T,
                          QuadratureManager::instance().getCubeFormula(GaussQuadratureDescriptor{degree + 3}));
        break;
      }
      default: {
        throw UnexpectedError("unexpected cell type");
      }
      }
    } else {
      throw UnexpectedError("invalid mesh dimension");
    }
  }

  return P0_function;
}

// Component-wise projection onto piecewise constants: component i_component of the returned P0
// vector function is the scalar exact_projection of vector_exact[i_component] with the same
// quadrature degree.
template <MeshConcept MeshType, typename DataType, size_t NbComponents>
DiscreteFunctionP0Vector<std::remove_const_t<DataType>>
exact_projection(
  const MeshType& mesh,
  size_t degree,
  const std::array<std::function<DataType(const TinyVector<MeshType::Dimension>&)>, NbComponents>& vector_exact)
{
  DiscreteFunctionP0Vector<std::remove_const_t<DataType>> P0_function_vector{mesh.meshVariant(), vector_exact.size()};

  for (size_t i_component = 0; i_component < vector_exact.size(); ++i_component) {
    // Project the i-th component with the scalar overload, then scatter the cell values into the
    // i-th entry of every cell of the vector function.
    DiscreteFunctionP0 P0_function = exact_projection(mesh, degree, vector_exact[i_component]);

    parallel_for(
      mesh.numberOfCells(),
      PUGS_LAMBDA(const CellId cell_id) { P0_function_vector[cell_id][i_component] = P0_function[cell_id]; });
  }

  return P0_function_vector;
}

// Pointwise distance between two values of the same type: frobeniusNorm of the difference for
// TinyMatrix types, l2Norm of the difference for TinyVector types, std::abs of the difference for
// arithmetic types. Any other type is rejected at compile time.
template <typename DataType>
PUGS_INLINE double
get_max_error(const DataType& x, const DataType& y)
{
  constexpr bool is_matrix = is_tiny_matrix_v<DataType>;
  constexpr bool is_vector = is_tiny_vector_v<DataType>;
  static_assert(is_matrix || is_vector || std::is_arithmetic_v<DataType>, "expecting arithmetic type");

  if constexpr (is_vector) {
    return l2Norm(x - y);
  } else if constexpr (is_matrix) {
    return frobeniusNorm(x - y);
  } else {
    return std::abs(x - y);
  }
}

template <MeshConcept MeshType, typename DataType>
double
max_reconstruction_error(const MeshType& mesh,
                         DiscreteFunctionDPk<MeshType::Dimension, const DataType> dpk_f,
                         std::function<DataType(const TinyVector<MeshType::Dimension>&)> exact)
{
  auto xr                  = mesh.xr();
  auto cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
  auto cell_type           = mesh.connectivity().cellType();

  double max_error = 0;
  for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
    auto cell_nodes = cell_to_node_matrix[cell_id];
    if constexpr (MeshType::Dimension == 1) {
      Assert(cell_type[cell_id] == CellType::Line);
      LineTransformation<MeshType::Dimension> T{xr[cell_nodes[0]], xr[cell_nodes[1]]};
      auto qf = QuadratureManager::instance().getLineFormula(GaussQuadratureDescriptor{dpk_f.degree() + 1});
      for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
        auto x    = T(qf.point(i_quadrarture));
        max_error = std::max(max_error, get_max_error(dpk_f[cell_id](x), exact(x)));
      }
    } else if constexpr (MeshType::Dimension == 2) {
      switch (cell_type[cell_id]) {
      case CellType::Triangle: {
        TriangleTransformation<MeshType::Dimension> T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]]};
        auto qf = QuadratureManager::instance().getTriangleFormula(GaussQuadratureDescriptor{dpk_f.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x    = T(qf.point(i_quadrarture));
          max_error = std::max(max_error, get_max_error(dpk_f[cell_id](x), exact(x)));
        }
        break;
      }
      case CellType::Quadrangle: {
        SquareTransformation<MeshType::Dimension> T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]],
                                                    xr[cell_nodes[3]]};
        auto qf = QuadratureManager::instance().getSquareFormula(GaussQuadratureDescriptor{dpk_f.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x    = T(qf.point(i_quadrarture));
          max_error = std::max(max_error, get_max_error(dpk_f[cell_id](x), exact(x)));
        }
        break;
      }
      default: {
        throw UnexpectedError("unexpected cell type");
      }
      }
    } else if constexpr (MeshType::Dimension == 3) {
      switch (cell_type[cell_id]) {
      case CellType::Tetrahedron: {
        TetrahedronTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]], xr[cell_nodes[3]]};
        auto qf = QuadratureManager::instance().getTetrahedronFormula(GaussQuadratureDescriptor{dpk_f.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x    = T(qf.point(i_quadrarture));
          max_error = std::max(max_error, get_max_error(dpk_f[cell_id](x), exact(x)));
        }
        break;
      }
      case CellType::Pyramid: {
        PyramidTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]], xr[cell_nodes[3]],
                                xr[cell_nodes[4]]};
        auto qf = QuadratureManager::instance().getPyramidFormula(GaussQuadratureDescriptor{dpk_f.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x    = T(qf.point(i_quadrarture));
          max_error = std::max(max_error, get_max_error(dpk_f[cell_id](x), exact(x)));
        }
        break;
      }
      case CellType::Prism: {
        PrismTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]],
                              xr[cell_nodes[3]], xr[cell_nodes[4]], xr[cell_nodes[5]]};
        auto qf = QuadratureManager::instance().getPrismFormula(GaussQuadratureDescriptor{dpk_f.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x    = T(qf.point(i_quadrarture));
          max_error = std::max(max_error, get_max_error(dpk_f[cell_id](x), exact(x)));
        }
        break;
      }
      case CellType::Hexahedron: {
        CubeTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]], xr[cell_nodes[3]],
                             xr[cell_nodes[4]], xr[cell_nodes[5]], xr[cell_nodes[6]], xr[cell_nodes[7]]};
        auto qf = QuadratureManager::instance().getCubeFormula(GaussQuadratureDescriptor{dpk_f.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x    = T(qf.point(i_quadrarture));
          max_error = std::max(max_error, get_max_error(dpk_f[cell_id](x), exact(x)));
        }
        break;
      }
      default: {
        throw UnexpectedError("unexpected cell type");
      }
      }
    }
  }
  return max_error;
}

template <MeshConcept MeshType, typename DataType, size_t NbComponents>
double
max_reconstruction_error(
  const MeshType& mesh,
  DiscreteFunctionDPkVector<MeshType::Dimension, const DataType> dpk_v,
  const std::array<std::function<DataType(const TinyVector<MeshType::Dimension>&)>, NbComponents>& vector_exact)
{
  auto xr                  = mesh.xr();
  auto cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
  double max_error         = 0;
  auto cell_type           = mesh.connectivity().cellType();

  REQUIRE(NbComponents == dpk_v.numberOfComponents());

  for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
    auto cell_nodes = cell_to_node_matrix[cell_id];
    if constexpr (MeshType::Dimension == 1) {
      Assert(cell_type[cell_id] == CellType::Line);
      LineTransformation<MeshType::Dimension> T{xr[cell_nodes[0]], xr[cell_nodes[1]]};
      auto qf = QuadratureManager::instance().getLineFormula(GaussQuadratureDescriptor{dpk_v.degree() + 1});
      for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
        auto x = T(qf.point(i_quadrarture));
        for (size_t i_component = 0; i_component < NbComponents; ++i_component) {
          max_error = std::max(max_error, get_max_error(dpk_v(cell_id, i_component)(x), vector_exact[i_component](x)));
        }
      }
    } else if constexpr (MeshType::Dimension == 2) {
      switch (cell_type[cell_id]) {
      case CellType::Triangle: {
        TriangleTransformation<MeshType::Dimension> T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]]};
        auto qf = QuadratureManager::instance().getTriangleFormula(GaussQuadratureDescriptor{dpk_v.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x = T(qf.point(i_quadrarture));
          for (size_t i_component = 0; i_component < NbComponents; ++i_component) {
            max_error =
              std::max(max_error, get_max_error(dpk_v(cell_id, i_component)(x), vector_exact[i_component](x)));
          }
        }
        break;
      }
      case CellType::Quadrangle: {
        SquareTransformation<MeshType::Dimension> T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]],
                                                    xr[cell_nodes[3]]};
        auto qf = QuadratureManager::instance().getSquareFormula(GaussQuadratureDescriptor{dpk_v.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x = T(qf.point(i_quadrarture));
          for (size_t i_component = 0; i_component < NbComponents; ++i_component) {
            max_error =
              std::max(max_error, get_max_error(dpk_v(cell_id, i_component)(x), vector_exact[i_component](x)));
          }
        }
        break;
      }
      default: {
        throw UnexpectedError("unexpected cell type");
      }
      }
    } else if constexpr (MeshType::Dimension == 3) {
      switch (cell_type[cell_id]) {
      case CellType::Tetrahedron: {
        TetrahedronTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]], xr[cell_nodes[3]]};
        auto qf = QuadratureManager::instance().getTetrahedronFormula(GaussQuadratureDescriptor{dpk_v.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x = T(qf.point(i_quadrarture));
          for (size_t i_component = 0; i_component < NbComponents; ++i_component) {
            max_error =
              std::max(max_error, get_max_error(dpk_v(cell_id, i_component)(x), vector_exact[i_component](x)));
          }
        }
        break;
      }
      case CellType::Pyramid: {
        PyramidTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]], xr[cell_nodes[3]],
                                xr[cell_nodes[4]]};
        auto qf = QuadratureManager::instance().getPyramidFormula(GaussQuadratureDescriptor{dpk_v.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x = T(qf.point(i_quadrarture));
          for (size_t i_component = 0; i_component < NbComponents; ++i_component) {
            max_error =
              std::max(max_error, get_max_error(dpk_v(cell_id, i_component)(x), vector_exact[i_component](x)));
          }
        }
        break;
      }
      case CellType::Prism: {
        PrismTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]],
                              xr[cell_nodes[3]], xr[cell_nodes[4]], xr[cell_nodes[5]]};
        auto qf = QuadratureManager::instance().getPrismFormula(GaussQuadratureDescriptor{dpk_v.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x = T(qf.point(i_quadrarture));
          for (size_t i_component = 0; i_component < NbComponents; ++i_component) {
            max_error =
              std::max(max_error, get_max_error(dpk_v(cell_id, i_component)(x), vector_exact[i_component](x)));
          }
        }
        break;
      }
      case CellType::Hexahedron: {
        CubeTransformation T{xr[cell_nodes[0]], xr[cell_nodes[1]], xr[cell_nodes[2]], xr[cell_nodes[3]],
                             xr[cell_nodes[4]], xr[cell_nodes[5]], xr[cell_nodes[6]], xr[cell_nodes[7]]};
        auto qf = QuadratureManager::instance().getCubeFormula(GaussQuadratureDescriptor{dpk_v.degree() + 1});
        for (size_t i_quadrarture = 0; i_quadrarture < qf.numberOfPoints(); ++i_quadrarture) {
          auto x = T(qf.point(i_quadrarture));
          for (size_t i_component = 0; i_component < NbComponents; ++i_component) {
            max_error =
              std::max(max_error, get_max_error(dpk_v(cell_id, i_component)(x), vector_exact[i_component](x)));
          }
        }
        break;
      }
      default: {
        throw UnexpectedError("unexpected cell type");
      }
      }
    }
  }
  return max_error;
}

}   // namespace test_only

#endif   // DISCRETE_FUNCTION_DPK_FOR_TESTS_HPP