#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <MeshDataBaseForTests.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionP0Vector.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>

#include <mesh/CartesianMeshBuilder.hpp>

// clazy:excludeall=non-pod-global-static

TEST_CASE("DiscreteFunctionUtils", "[scheme]")
{
  SECTION("1D")
  {
    constexpr size_t Dimension = 1;

    std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

    for (auto named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh = named_mesh.mesh();

        std::shared_ptr mesh_copy =
          std::make_shared<std::decay_t<decltype(*mesh)>>(mesh->shared_connectivity(), mesh->xr());

        SECTION("common mesh")
        {
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr vh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr wh = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<2>>>(mesh);

          std::shared_ptr qh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh_copy);

          REQUIRE(getCommonMesh({uh, vh, wh}).get() == mesh.get());
          REQUIRE(getCommonMesh({uh, vh, wh, qh}).use_count() == 0);
        }

        SECTION("check discretization type")
        {
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr vh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);

          std::shared_ptr qh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh_copy);

          std::shared_ptr Uh = std::make_shared<DiscreteFunctionP0Vector<Dimension, double>>(mesh, 3);
          std::shared_ptr Vh = std::make_shared<DiscreteFunctionP0Vector<Dimension, double>>(mesh, 3);

          REQUIRE(checkDiscretizationType({uh}, DiscreteFunctionType::P0));
          REQUIRE(checkDiscretizationType({uh, vh, qh}, DiscreteFunctionType::P0));
          REQUIRE(not checkDiscretizationType({uh}, DiscreteFunctionType::P0Vector));
          REQUIRE(not checkDiscretizationType({uh, vh, qh}, DiscreteFunctionType::P0Vector));
          REQUIRE(checkDiscretizationType({Uh}, DiscreteFunctionType::P0Vector));
          REQUIRE(checkDiscretizationType({Uh, Vh}, DiscreteFunctionType::P0Vector));
          REQUIRE(not checkDiscretizationType({Uh, Vh}, DiscreteFunctionType::P0));
          REQUIRE(not checkDiscretizationType({Uh}, DiscreteFunctionType::P0));
        }

        SECTION("scalar function shallow copy")
        {
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^1 function shallow copy")
        {
          using DataType     = TinyVector<1>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^2 function shallow copy")
        {
          using DataType     = TinyVector<2>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^3 function shallow copy")
        {
          using DataType     = TinyVector<3>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^1x1 function shallow copy")
        {
          using DataType     = TinyMatrix<1>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^2x2 function shallow copy")
        {
          using DataType     = TinyMatrix<2>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^3x3 function shallow copy")
        {
          using DataType     = TinyMatrix<3>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }
      }
    }
  }

  SECTION("2D")
  {
    constexpr size_t Dimension = 2;

    std::array mesh_list = MeshDataBaseForTests::get().all2DMeshes();

    for (auto named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh = named_mesh.mesh();

        std::shared_ptr mesh_copy =
          std::make_shared<std::decay_t<decltype(*mesh)>>(mesh->shared_connectivity(), mesh->xr());

        SECTION("common mesh")
        {
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr vh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr wh = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<2>>>(mesh);

          std::shared_ptr qh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh_copy);

          REQUIRE(getCommonMesh({uh, vh, wh}).get() == mesh.get());
          REQUIRE(getCommonMesh({uh, vh, wh, qh}).use_count() == 0);
        }

        SECTION("check discretization type")
        {
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr vh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);

          std::shared_ptr qh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh_copy);

          std::shared_ptr Uh = std::make_shared<DiscreteFunctionP0Vector<Dimension, double>>(mesh, 3);
          std::shared_ptr Vh = std::make_shared<DiscreteFunctionP0Vector<Dimension, double>>(mesh, 3);

          REQUIRE(checkDiscretizationType({uh}, DiscreteFunctionType::P0));
          REQUIRE(checkDiscretizationType({uh, vh, qh}, DiscreteFunctionType::P0));
          REQUIRE(not checkDiscretizationType({uh}, DiscreteFunctionType::P0Vector));
          REQUIRE(not checkDiscretizationType({uh, vh, qh}, DiscreteFunctionType::P0Vector));
          REQUIRE(checkDiscretizationType({Uh}, DiscreteFunctionType::P0Vector));
          REQUIRE(checkDiscretizationType({Uh, Vh}, DiscreteFunctionType::P0Vector));
          REQUIRE(not checkDiscretizationType({Uh, Vh}, DiscreteFunctionType::P0));
          REQUIRE(not checkDiscretizationType({Uh}, DiscreteFunctionType::P0));
        }

        SECTION("scalar function shallow copy")
        {
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^1 function shallow copy")
        {
          using DataType     = TinyVector<1>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^2 function shallow copy")
        {
          using DataType     = TinyVector<2>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^3 function shallow copy")
        {
          using DataType     = TinyVector<3>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^1x1 function shallow copy")
        {
          using DataType     = TinyMatrix<1>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^2x2 function shallow copy")
        {
          using DataType     = TinyMatrix<2>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^3x3 function shallow copy")
        {
          using DataType     = TinyMatrix<3>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }
      }
    }
  }

  SECTION("3D")
  {
    constexpr size_t Dimension = 3;

    std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes();

    for (auto named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh = named_mesh.mesh();

        std::shared_ptr mesh_copy =
          std::make_shared<std::decay_t<decltype(*mesh)>>(mesh->shared_connectivity(), mesh->xr());

        SECTION("common mesh")
        {
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr vh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr wh = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<2>>>(mesh);

          std::shared_ptr qh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh_copy);

          REQUIRE(getCommonMesh({uh, vh, wh}).get() == mesh.get());
          REQUIRE(getCommonMesh({uh, vh, wh, qh}).use_count() == 0);
        }

        SECTION("check discretization type")
        {
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr vh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);

          std::shared_ptr qh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh_copy);

          std::shared_ptr Uh = std::make_shared<DiscreteFunctionP0Vector<Dimension, double>>(mesh, 3);
          std::shared_ptr Vh = std::make_shared<DiscreteFunctionP0Vector<Dimension, double>>(mesh, 3);

          REQUIRE(checkDiscretizationType({uh}, DiscreteFunctionType::P0));
          REQUIRE(checkDiscretizationType({uh, vh, qh}, DiscreteFunctionType::P0));
          REQUIRE(not checkDiscretizationType({uh}, DiscreteFunctionType::P0Vector));
          REQUIRE(not checkDiscretizationType({uh, vh, qh}, DiscreteFunctionType::P0Vector));
          REQUIRE(checkDiscretizationType({Uh}, DiscreteFunctionType::P0Vector));
          REQUIRE(checkDiscretizationType({Uh, Vh}, DiscreteFunctionType::P0Vector));
          REQUIRE(not checkDiscretizationType({Uh, Vh}, DiscreteFunctionType::P0));
          REQUIRE(not checkDiscretizationType({Uh}, DiscreteFunctionType::P0));
        }

        SECTION("scalar function shallow copy")
        {
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^1 function shallow copy")
        {
          using DataType     = TinyVector<1>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^2 function shallow copy")
        {
          using DataType     = TinyVector<2>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^3 function shallow copy")
        {
          using DataType     = TinyVector<3>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^1x1 function shallow copy")
        {
          using DataType     = TinyMatrix<1>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^2x2 function shallow copy")
        {
          using DataType     = TinyMatrix<2>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }

        SECTION("R^3x3 function shallow copy")
        {
          using DataType     = TinyMatrix<3>;
          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh);
          std::shared_ptr vh = shallowCopy(mesh, uh);

          REQUIRE(uh == vh);

          std::shared_ptr wh = shallowCopy(mesh_copy, uh);

          REQUIRE(uh != wh);
          REQUIRE(&(uh->cellValues()[CellId{0}]) ==
                  &(dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(*wh).cellValues()[CellId{0}]));
        }
      }
    }
  }

  SECTION("errors")
  {
    SECTION("different connectivities")
    {
      constexpr size_t Dimension = 1;

      std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

      for (auto named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh = named_mesh.mesh();

          std::shared_ptr other_mesh =
            CartesianMeshBuilder{TinyVector<1>{-1}, TinyVector<1>{3}, TinyVector<1, size_t>{19}}.mesh();

          std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);

          REQUIRE_THROWS_WITH(shallowCopy(other_mesh, uh), "error: cannot shallow copy when connectivity changes");
        }
      }
    }

    SECTION("incompatible mesh dimension")
    {
      constexpr size_t Dimension = 1;

      std::shared_ptr mesh_1d = MeshDataBaseForTests::get().cartesian1DMesh();
      std::shared_ptr mesh_2d = MeshDataBaseForTests::get().cartesian2DMesh();

      std::shared_ptr uh = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh_1d);

      REQUIRE_THROWS_WITH(shallowCopy(mesh_2d, uh), "error: incompatible mesh dimensions");
    }
  }
}
