#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <language/utils/EmbeddedIDiscreteFunctionUtils.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionP0Vector.hpp>

#include <MeshDataBaseForTests.hpp>

// clazy:excludeall=non-pod-global-static

// Unit tests for EmbeddedIDiscreteFunctionUtils:
//  - getOperandTypeName(): checks the textual type name produced for plain
//    values, shared pointers and discrete functions;
//  - isSameDiscretization(): checks the comparison of two discrete functions,
//    passed either through shared pointers or by value.
//
// Every mesh-dependent check is repeated for each 1D mesh provided by
// MeshDataBaseForTests, in a sub-section named after the mesh.
TEST_CASE("EmbeddedIDiscreteFunctionUtils", "[language]")
{
  using R1 = TinyVector<1, double>;
  using R2 = TinyVector<2, double>;
  using R3 = TinyVector<3, double>;

  using R1x1 = TinyMatrix<1, 1, double>;
  using R2x2 = TinyMatrix<2, 2, double>;
  using R3x3 = TinyMatrix<3, 3, double>;

  SECTION("operand type name")
  {
    SECTION("basic types")
    {
      // A double and a shared pointer to a double are expected to give the same name
      REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(double{1}) == "R");
      REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(std::make_shared<double>(1)) == "R");
    }

    SECTION("discrete P0 function")
    {
      const std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

      // Iterate by const reference: the named mesh is only read, no copy is needed
      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          const auto mesh_1d = named_mesh.mesh();

          // scalar data
          REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, double>{mesh_1d}) ==
                  "Vh(P0:R)");

          // vector data
          REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R1>{mesh_1d}) ==
                  "Vh(P0:R^1)");
          REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R2>{mesh_1d}) ==
                  "Vh(P0:R^2)");
          REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R3>{mesh_1d}) ==
                  "Vh(P0:R^3)");

          // matrix data
          REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R1x1>{mesh_1d}) ==
                  "Vh(P0:R^1x1)");
          REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R2x2>{mesh_1d}) ==
                  "Vh(P0:R^2x2)");
          REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R3x3>{mesh_1d}) ==
                  "Vh(P0:R^3x3)");
        }
      }
    }

    SECTION("discrete P0Vector function")
    {
      const std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          const auto mesh_1d = named_mesh.mesh();

          REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0Vector<1, double>{mesh_1d, 2}) ==
                  "Vh(P0Vector:R)");
        }
      }
    }
  }

  SECTION("check if is same discretization")
  {
    SECTION("from shared_ptr")
    {
      const std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          const auto mesh_1d = named_mesh.mesh();

          // same discrete function type and same data type
          REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(std::make_shared<DiscreteFunctionP0<1, double>>(
                                                                         mesh_1d),
                                                                       std::make_shared<DiscreteFunctionP0<1, double>>(
                                                                         mesh_1d)));

          // P0 versus P0Vector: same data type but different discrete function type
          REQUIRE(not EmbeddedIDiscreteFunctionUtils::
                    isSameDiscretization(std::make_shared<DiscreteFunctionP0<1, double>>(mesh_1d),
                                         std::make_shared<DiscreteFunctionP0Vector<1, double>>(mesh_1d, 1)));
        }
      }
    }

    SECTION("from value")
    {
      const std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          const auto mesh_1d = named_mesh.mesh();

          // identical data types must be reported as the same discretization
          REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, double>{mesh_1d},
                                                                       DiscreteFunctionP0<1, double>{mesh_1d}));

          REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R1>{mesh_1d},
                                                                       DiscreteFunctionP0<1, R1>{mesh_1d}));

          REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R2>{mesh_1d},
                                                                       DiscreteFunctionP0<1, R2>{mesh_1d}));

          REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R3>{mesh_1d},
                                                                       DiscreteFunctionP0<1, R3>{mesh_1d}));

          REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R1x1>{mesh_1d},
                                                                       DiscreteFunctionP0<1, R1x1>{mesh_1d}));

          REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R2x2>{mesh_1d},
                                                                       DiscreteFunctionP0<1, R2x2>{mesh_1d}));

          REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R3x3>{mesh_1d},
                                                                       DiscreteFunctionP0<1, R3x3>{mesh_1d}));

          // differing data types (scalar/vector, vector/matrix, matrices of
          // different sizes) must not be reported as the same discretization
          REQUIRE(not EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, double>{mesh_1d},
                                                                           DiscreteFunctionP0<1, R1>{mesh_1d}));

          REQUIRE(not EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R2>{mesh_1d},
                                                                           DiscreteFunctionP0<1, R2x2>{mesh_1d}));

          REQUIRE(not EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R3x3>{mesh_1d},
                                                                           DiscreteFunctionP0<1, R2x2>{mesh_1d}));
        }
      }
    }

    SECTION("invalid data type")
    {
      const std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          const auto mesh_1d = named_mesh.mesh();

          // Comparing two int64_t-valued functions is expected to throw, even
          // though both operands have the same type
          REQUIRE_THROWS_WITH(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1,
                                                                                                      int64_t>{mesh_1d},
                                                                                   DiscreteFunctionP0<1, int64_t>{
                                                                                     mesh_1d}),
                              "unexpected error: invalid data type Vh(P0:Z)");
        }
      }
    }
  }

  // Compiled only when NDEBUG is not defined
  // NOTE(review): presumably the null-pointer check is a debug-only assertion -- confirm
#ifndef NDEBUG
  SECTION("errors")
  {
    // an empty shared_ptr is expected to be rejected
    REQUIRE_THROWS_WITH(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(std::shared_ptr<double>()),
                        "dangling shared_ptr");
  }

#endif   // NDEBUG
}
