#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <MeshDataBaseForTests.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ItemValue.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <utils/Messenger.hpp>

#include <limits>

// Explicit instantiation definition: forces the (non-template) members of
// ItemValue<int, ItemType::cell> to be instantiated in this translation unit,
// so that members not exercised by the tests below are still compiled and
// accounted for by the coverage tool ("full coverage").
template class ItemValue<int, ItemType::cell>;

// clazy:excludeall=non-pod-global-static

TEST_CASE("ItemValueUtils", "[mesh]")
{
  SECTION("Synchronize")
  {
    const Mesh<Connectivity<2>>& mesh_2d = MeshDataBaseForTests::get().cartesianMesh<2>();
    const Connectivity<2>& connectivity  = mesh_2d.connectivity();

    // Every process writes its own rank into all of its faces, owned or not.
    WeakFaceValue<int> weak_face_value{connectivity};

    weak_face_value.fill(parallel::rank());

    // The post-synchronization checks read through this const view, so they
    // rely on it sharing storage with weak_face_value.
    FaceValue<const int> face_value{weak_face_value};

    REQUIRE(face_value.connectivity_ptr() == weak_face_value.connectivity_ptr());

    // Ownership information does not depend on the values being synchronized:
    // fetch it once for both checks.
    auto face_owner    = connectivity.faceOwner();
    auto face_is_owned = connectivity.faceIsOwned();

    // Before synchronization: only owned faces hold their owner's rank, the
    // other faces still hold the local rank.
    for (FaceId i_face = 0; i_face < mesh_2d.numberOfFaces(); ++i_face) {
      if (face_is_owned[i_face]) {
        REQUIRE(face_owner[i_face] == face_value[i_face]);
      } else {
        REQUIRE(face_owner[i_face] != face_value[i_face]);
      }
    }

    synchronize(weak_face_value);

    // After synchronization: every face holds its owner's rank.
    for (FaceId i_face = 0; i_face < mesh_2d.numberOfFaces(); ++i_face) {
      REQUIRE(face_owner[i_face] == face_value[i_face]);
    }
  }

  // Owned cells get 10 + rank; non-owned cells keep the sentinel -1, which is
  // below every owned value. The expected minimum (10, i.e. rank 0's value) is
  // therefore only reached if non-owned cells do not take part in min().
  SECTION("min")
  {
    SECTION("1D")
    {
      const Mesh<Connectivity<1>>& mesh_1d = MeshDataBaseForTests::get().cartesianMesh<1>();
      const Connectivity<1>& connectivity  = mesh_1d.connectivity();

      CellValue<int> cell_value{connectivity};
      cell_value.fill(-1);

      auto cell_is_owned = connectivity.cellIsOwned();
      parallel_for(
        mesh_1d.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          if (cell_is_owned[cell_id]) {
            cell_value[cell_id] = 10 + parallel::rank();
          }
        });

      REQUIRE(min(cell_value) == 10);
    }

    SECTION("2D")
    {
      const Mesh<Connectivity<2>>& mesh_2d = MeshDataBaseForTests::get().cartesianMesh<2>();
      const Connectivity<2>& connectivity  = mesh_2d.connectivity();

      CellValue<int> cell_value{connectivity};
      cell_value.fill(-1);

      auto cell_is_owned = connectivity.cellIsOwned();
      parallel_for(
        mesh_2d.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          if (cell_is_owned[cell_id]) {
            cell_value[cell_id] = 10 + parallel::rank();
          }
        });

      REQUIRE(min(cell_value) == 10);
    }

    SECTION("3D")
    {
      const Mesh<Connectivity<3>>& mesh_3d = MeshDataBaseForTests::get().cartesianMesh<3>();
      const Connectivity<3>& connectivity  = mesh_3d.connectivity();

      CellValue<int> cell_value{connectivity};
      cell_value.fill(-1);

      auto cell_is_owned = connectivity.cellIsOwned();
      parallel_for(
        mesh_3d.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          if (cell_is_owned[cell_id]) {
            cell_value[cell_id] = 10 + parallel::rank();
          }
        });

      REQUIRE(min(cell_value) == 10);
    }
  }

  // Owned cells get rank + 1; non-owned cells keep the largest size_t as a
  // sentinel. The expected maximum, parallel::size() (the highest rank + 1),
  // is therefore only reached if non-owned cells do not take part in max().
  SECTION("max")
  {
    SECTION("1D")
    {
      const Mesh<Connectivity<1>>& mesh_1d = MeshDataBaseForTests::get().cartesianMesh<1>();
      const Connectivity<1>& connectivity  = mesh_1d.connectivity();

      CellValue<size_t> cell_value{connectivity};
      cell_value.fill(std::numeric_limits<size_t>::max());

      auto cell_is_owned = connectivity.cellIsOwned();
      parallel_for(
        mesh_1d.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          if (cell_is_owned[cell_id]) {
            cell_value[cell_id] = parallel::rank() + 1;
          }
        });

      REQUIRE(max(cell_value) == parallel::size());
    }

    SECTION("2D")
    {
      const Mesh<Connectivity<2>>& mesh_2d = MeshDataBaseForTests::get().cartesianMesh<2>();
      const Connectivity<2>& connectivity  = mesh_2d.connectivity();

      CellValue<size_t> cell_value{connectivity};
      cell_value.fill(std::numeric_limits<size_t>::max());

      auto cell_is_owned = connectivity.cellIsOwned();
      parallel_for(
        mesh_2d.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          if (cell_is_owned[cell_id]) {
            cell_value[cell_id] = parallel::rank() + 1;
          }
        });

      REQUIRE(max(cell_value) == parallel::size());
    }

    SECTION("3D")
    {
      const Mesh<Connectivity<3>>& mesh_3d = MeshDataBaseForTests::get().cartesianMesh<3>();
      const Connectivity<3>& connectivity  = mesh_3d.connectivity();

      CellValue<size_t> cell_value{connectivity};
      cell_value.fill(std::numeric_limits<size_t>::max());

      auto cell_is_owned = connectivity.cellIsOwned();
      parallel_for(
        mesh_3d.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          if (cell_is_owned[cell_id]) {
            cell_value[cell_id] = parallel::rank() + 1;
          }
        });

      REQUIRE(max(cell_value) == parallel::size());
    }
  }

  // Every item (owned or not) is filled with a constant c; sum() is expected
  // to equal c times the global number of owned items, obtained by summing the
  // local owned counts over all processes.
  SECTION("sum")
  {
    SECTION("1D")
    {
      const Mesh<Connectivity<1>>& mesh_1d = MeshDataBaseForTests::get().cartesianMesh<1>();
      const Connectivity<1>& connectivity  = mesh_1d.connectivity();

      CellValue<size_t> cell_value{connectivity};
      cell_value.fill(5);

      auto cell_is_owned = connectivity.cellIsOwned();

      const size_t global_number_of_cells = [&] {
        size_t number_of_cells = 0;
        for (CellId cell_id = 0; cell_id < cell_is_owned.numberOfItems(); ++cell_id) {
          number_of_cells += cell_is_owned[cell_id];
        }
        return parallel::allReduceSum(number_of_cells);
      }();

      REQUIRE(sum(cell_value) == 5 * global_number_of_cells);
    }

    SECTION("2D")
    {
      const Mesh<Connectivity<2>>& mesh_2d = MeshDataBaseForTests::get().cartesianMesh<2>();
      const Connectivity<2>& connectivity  = mesh_2d.connectivity();

      FaceValue<size_t> face_value{connectivity};
      face_value.fill(2);

      auto face_is_owned = connectivity.faceIsOwned();

      const size_t global_number_of_faces = [&] {
        size_t number_of_faces = 0;
        for (FaceId face_id = 0; face_id < face_is_owned.numberOfItems(); ++face_id) {
          number_of_faces += face_is_owned[face_id];
        }
        return parallel::allReduceSum(number_of_faces);
      }();

      REQUIRE(sum(face_value) == 2 * global_number_of_faces);
    }

    SECTION("3D")
    {
      const Mesh<Connectivity<3>>& mesh_3d = MeshDataBaseForTests::get().cartesianMesh<3>();
      const Connectivity<3>& connectivity  = mesh_3d.connectivity();

      NodeValue<size_t> node_value{connectivity};
      node_value.fill(3);

      auto node_is_owned = connectivity.nodeIsOwned();

      const size_t global_number_of_nodes = [&] {
        size_t number_of_nodes = 0;
        for (NodeId node_id = 0; node_id < node_is_owned.numberOfItems(); ++node_id) {
          number_of_nodes += node_is_owned[node_id];
        }
        return parallel::allReduceSum(number_of_nodes);
      }();

      REQUIRE(sum(node_value) == 3 * global_number_of_nodes);
    }
  }
}