#include <catch2/catch_approx.hpp>
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <language/ast/ASTBuilder.hpp>
#include <language/ast/ASTModulesImporter.hpp>
#include <language/ast/ASTNodeDataTypeBuilder.hpp>
#include <language/ast/ASTNodeExpressionBuilder.hpp>
#include <language/ast/ASTNodeFunctionEvaluationExpressionBuilder.hpp>
#include <language/ast/ASTNodeFunctionExpressionBuilder.hpp>
#include <language/ast/ASTNodeTypeCleaner.hpp>
#include <language/ast/ASTSymbolTableBuilder.hpp>
#include <language/utils/PugsFunctionAdapter.hpp>
#include <language/utils/SymbolTable.hpp>

#include <MeshDataBaseForTests.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/DualMeshManager.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/PolynomialMesh.hpp>
#include <mesh/PolynomialMeshBuilder.hpp>
#include <scheme/CellIntegrator.hpp>

#include <analysis/GaussLegendreQuadratureDescriptor.hpp>
#include <analysis/GaussLobattoQuadratureDescriptor.hpp>
#include <analysis/GaussQuadratureDescriptor.hpp>
#include <language/utils/IntegrateOnCells.hpp>

// clazy:excludeall=non-pod-global-static

// Checks IntegrateOnCells on cubic (degree 3) polynomial meshes.
// For every 2D test mesh, a PolynomialMesh of degree 3 is built from it with
// PolynomialMeshBuilder. Scalar, R^3 and R^2x2 functions defined in a small pugs script are then
// integrated on the cells of both the original mesh and the cubic mesh. The test requires the two
// sets of cell integrals to agree to within 1E-12.
// This comparison is repeated for Gauss, Gauss-Legendre and Gauss-Lobatto quadratures.
TEST_CASE("IntegrateOnCells_cubic", "[language]")
{
  // l1 distance between two scalar cell arrays: sum over cells of |f[i] - g[i]|.
  auto scalar_error = [](auto f, auto g) -> double {
    double error = 0;
    for (size_t i = 0; i < f.size(); ++i) {
      error += std::abs(f[i] - g[i]);
    }

    return error;
  };

  // l2 distance between two arrays of TinyVector or TinyMatrix values:
  // sqrt(sum over cells of dot(f[i] - g[i], f[i] - g[i])).
  // One helper serves both value kinds; the former vector/matrix variants were identical.
  auto tensor_error = [](auto f, auto g) -> double {
    double error = 0;
    for (size_t i = 0; i < f.size(); ++i) {
      const auto delta = f[i] - g[i];
      error += dot(delta, delta);
    }

    return std::sqrt(error);
  };

  // Script defining the functions that are integrated on both meshes.
  std::string_view data = R"(
import math;
let scalar_2d: R^2 -> R, x -> 2*x[0]*x[1]+3;
let R3_2d: R^2 -> R^3, x -> [2*x[0]*x[0] + x[1]*x[1] - 1, x[0]-2*x[1], 3];
let R2x2_2d: R^2 -> R^2x2, x -> [[2*x[0]*x[0]*x[1]+3, x[0]-2*x[1]], [3, x[0]*x[1]]];
)";

  // Runs the whole "integrate on all cells" / "2D" / <mesh name> / <function kind> section tree
  // with the given quadrature descriptor. It is shared by the three quadrature sections below,
  // which previously duplicated this code verbatim.
  auto check_integration_on_cubic_meshes = [&](auto& quadrature_descriptor) {
    SECTION("integrate on all cells")
    {
      SECTION("2D")
      {
        constexpr size_t Dimension = 2;

        std::array mesh_list = MeshDataBaseForTests::get().all2DMeshes();

        for (const auto& named_mesh : mesh_list) {
          SECTION(named_mesh.name())
          {
            auto mesh_2d_v = named_mesh.mesh();
            auto mesh_2d   = mesh_2d_v->get<Mesh<Dimension>>();

            // Degree-3 polynomial mesh built from the same underlying mesh.
            PolynomialMeshBuilder pb{mesh_2d_v, 3};

            auto cubic_mesh = pb.mesh()->get<PolynomialMesh<Dimension>>();

            TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};

            auto ast = ASTBuilder::build(input);

            ASTModulesImporter{*ast};
            ASTNodeTypeCleaner<language::import_instruction>{*ast};

            ASTSymbolTableBuilder{*ast};
            ASTNodeDataTypeBuilder{*ast};

            ASTNodeTypeCleaner<language::var_declaration>{*ast};
            ASTNodeTypeCleaner<language::fct_declaration>{*ast};
            ASTNodeExpressionBuilder{*ast};

            std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table;

            // ensure that variables are declared at this point
            TAO_PEGTL_NAMESPACE::position position{data.size(), 1, 1, "fixture"};

            // Looks up the script function `function_name` and integrates it as a
            // DataType-valued function on every cell of both meshes. DataType is the type of
            // `value_tag`, which is used only for its type. The check then requires
            // error(cubic, polygonal) < 1E-12.
            auto check_function = [&](const char* function_name, auto value_tag, auto error) {
              using DataType = decltype(value_tag);

              auto [i_symbol, found] = symbol_table->find(function_name, position);
              REQUIRE(found);
              REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);

              FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table);

              Array<DataType> integrate_value_polygonal(mesh_2d->numberOfCells());
              IntegrateOnCells<DataType(TinyVector<Dimension>)>::integrateTo(function_symbol_id, quadrature_descriptor,
                                                                             *mesh_2d, integrate_value_polygonal);

              Array<DataType> integrate_value_cubic(cubic_mesh->numberOfCells());
              IntegrateOnCells<DataType(TinyVector<Dimension>)>::integrateTo(function_symbol_id, quadrature_descriptor,
                                                                             *cubic_mesh, integrate_value_cubic);

              REQUIRE(error(integrate_value_cubic, integrate_value_polygonal) < 1E-12);
            };

            SECTION("scalar 2d")
            {
              check_function("scalar_2d", double{}, scalar_error);
            }

            SECTION("vector 2d")
            {
              check_function("R3_2d", TinyVector<3>{}, tensor_error);
            }

            SECTION("matrix 2d")
            {
              check_function("R2x2_2d", TinyMatrix<2>{}, tensor_error);
            }
          }
        }
      }
    }
  };

  SECTION("Gauss quadrature")
  {
    auto quadrature_descriptor = GaussQuadratureDescriptor(20);
    check_integration_on_cubic_meshes(quadrature_descriptor);
  }

  SECTION("Gauss-Legendre quadrature")
  {
    auto quadrature_descriptor = GaussLegendreQuadratureDescriptor(12);
    check_integration_on_cubic_meshes(quadrature_descriptor);
  }

  SECTION("Gauss-Lobatto quadrature")
  {
    auto quadrature_descriptor = GaussLobattoQuadratureDescriptor(12);
    check_integration_on_cubic_meshes(quadrature_descriptor);
  }
}