#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <MeshDataBaseForTests.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ItemValue.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <utils/Messenger.hpp>

// Explicit instantiation definition: makes the compiler instantiate the members
// of ItemValue<int, ItemType::cell>, not only those called by the tests below,
// so that coverage reports account for the whole class template.
template class ItemValue<int, ItemType::cell>;

// clazy:excludeall=non-pod-global-static

// Tests for synchronize(), min(), max() and sum() applied to item values.
//
// Every check is repeated for each mesh of the relevant dimension returned by
// MeshDataBaseForTests. Non-owned items are deliberately given values that would
// change the expected result if an operation wrongly took them into account.
TEST_CASE("ItemValueUtils", "[mesh]")
{
  SECTION("Synchronize")
  {
    std::array mesh_list = MeshDataBaseForTests::get().all2DMeshes();

    for (const auto& named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh_2d = named_mesh.mesh();

        const Connectivity<2>& connectivity = mesh_2d->connectivity();

        // Every local face, owned or not, is tagged with the local rank.
        WeakFaceValue<int> weak_face_value{connectivity};

        weak_face_value.fill(parallel::rank());

        // face_value is built from weak_face_value. The checks below only read
        // face_value while synchronize() is applied to weak_face_value, so the
        // "after synchronization" loop also checks that both views share data.
        FaceValue<const int> face_value{weak_face_value};

        REQUIRE(face_value.connectivity_ptr() == weak_face_value.connectivity_ptr());

        {   // before synchronization
          auto face_owner    = connectivity.faceOwner();
          auto face_is_owned = connectivity.faceIsOwned();

          // Owned faces already hold their owner's rank; non-owned faces hold the
          // local rank, which differs from the rank of their owner.
          for (FaceId i_face = 0; i_face < mesh_2d->numberOfFaces(); ++i_face) {
            if (face_is_owned[i_face]) {
              REQUIRE(face_owner[i_face] == face_value[i_face]);
            } else {
              REQUIRE(face_owner[i_face] != face_value[i_face]);
            }
          }
        }

        synchronize(weak_face_value);

        {   // after synchronization
          auto face_owner = connectivity.faceOwner();

          // Every face must now hold the rank of its owner.
          for (FaceId i_face = 0; i_face < mesh_2d->numberOfFaces(); ++i_face) {
            REQUIRE(face_owner[i_face] == face_value[i_face]);
          }
        }
      }
    }
  }

  // Owned cells get 10 + rank while non-owned cells keep the sentinel -1:
  // if min() looked at non-owned cells it would return -1 whenever such cells
  // exist. NOTE(review): expecting 10 assumes rank 0 owns at least one cell.
  SECTION("min")
  {
    SECTION("1D")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_1d = named_mesh.mesh();

          const Connectivity<1>& connectivity = mesh_1d->connectivity();

          CellValue<int> cell_value{connectivity};
          cell_value.fill(-1);

          auto cell_is_owned = connectivity.cellIsOwned();
          parallel_for(
            mesh_1d->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
              if (cell_is_owned[cell_id]) {
                cell_value[cell_id] = 10 + parallel::rank();
              }
            });

          REQUIRE(min(cell_value) == 10);
        }
      }
    }

    SECTION("2D")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all2DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_2d = named_mesh.mesh();

          const Connectivity<2>& connectivity = mesh_2d->connectivity();

          CellValue<int> cell_value{connectivity};
          cell_value.fill(-1);

          auto cell_is_owned = connectivity.cellIsOwned();
          parallel_for(
            mesh_2d->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
              if (cell_is_owned[cell_id]) {
                cell_value[cell_id] = 10 + parallel::rank();
              }
            });

          REQUIRE(min(cell_value) == 10);
        }
      }
    }

    SECTION("3D")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_3d = named_mesh.mesh();

          const Connectivity<3>& connectivity = mesh_3d->connectivity();

          CellValue<int> cell_value{connectivity};
          cell_value.fill(-1);

          auto cell_is_owned = connectivity.cellIsOwned();
          parallel_for(
            mesh_3d->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
              if (cell_is_owned[cell_id]) {
                cell_value[cell_id] = 10 + parallel::rank();
              }
            });

          REQUIRE(min(cell_value) == 10);
        }
      }
    }
  }

  // Owned cells get rank + 1 while non-owned cells keep the sentinel
  // std::numeric_limits<size_t>::max(): if max() looked at non-owned cells it
  // would return that sentinel whenever such cells exist.
  // NOTE(review): expecting parallel::size() assumes the last rank owns at least one cell.
  SECTION("max")
  {
    SECTION("1D")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_1d = named_mesh.mesh();

          const Connectivity<1>& connectivity = mesh_1d->connectivity();

          CellValue<size_t> cell_value{connectivity};
          cell_value.fill(std::numeric_limits<size_t>::max());

          auto cell_is_owned = connectivity.cellIsOwned();
          parallel_for(
            mesh_1d->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
              if (cell_is_owned[cell_id]) {
                cell_value[cell_id] = parallel::rank() + 1;
              }
            });

          REQUIRE(max(cell_value) == parallel::size());
        }
      }
    }

    SECTION("2D")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all2DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_2d = named_mesh.mesh();

          const Connectivity<2>& connectivity = mesh_2d->connectivity();

          CellValue<size_t> cell_value{connectivity};
          cell_value.fill(std::numeric_limits<size_t>::max());

          auto cell_is_owned = connectivity.cellIsOwned();
          parallel_for(
            mesh_2d->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
              if (cell_is_owned[cell_id]) {
                cell_value[cell_id] = parallel::rank() + 1;
              }
            });

          REQUIRE(max(cell_value) == parallel::size());
        }
      }
    }

    SECTION("3D")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_3d = named_mesh.mesh();

          const Connectivity<3>& connectivity = mesh_3d->connectivity();

          CellValue<size_t> cell_value{connectivity};
          cell_value.fill(std::numeric_limits<size_t>::max());

          auto cell_is_owned = connectivity.cellIsOwned();
          parallel_for(
            mesh_3d->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
              if (cell_is_owned[cell_id]) {
                cell_value[cell_id] = parallel::rank() + 1;
              }
            });

          REQUIRE(max(cell_value) == parallel::size());
        }
      }
    }
  }

  // Every local item, owned or not, holds the same constant. The expected sum is
  // that constant times the number of owned items, counted locally and reduced
  // over all ranks, so non-owned copies must not be added to the sum.
  SECTION("sum")
  {
    SECTION("1D")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_1d = named_mesh.mesh();

          const Connectivity<1>& connectivity = mesh_1d->connectivity();

          CellValue<size_t> cell_value{connectivity};
          cell_value.fill(5);

          auto cell_is_owned = connectivity.cellIsOwned();

          const size_t global_number_of_cells = [&] {
            size_t number_of_cells = 0;
            for (CellId cell_id = 0; cell_id < cell_is_owned.numberOfItems(); ++cell_id) {
              number_of_cells += cell_is_owned[cell_id];
            }
            return parallel::allReduceSum(number_of_cells);
          }();

          REQUIRE(sum(cell_value) == 5 * global_number_of_cells);
        }
      }
    }

    SECTION("2D")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all2DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_2d = named_mesh.mesh();

          const Connectivity<2>& connectivity = mesh_2d->connectivity();

          FaceValue<size_t> face_value{connectivity};
          face_value.fill(2);

          auto face_is_owned = connectivity.faceIsOwned();

          const size_t global_number_of_faces = [&] {
            size_t number_of_faces = 0;
            for (FaceId face_id = 0; face_id < face_is_owned.numberOfItems(); ++face_id) {
              number_of_faces += face_is_owned[face_id];
            }
            return parallel::allReduceSum(number_of_faces);
          }();

          REQUIRE(sum(face_value) == 2 * global_number_of_faces);
        }
      }
    }

    SECTION("3D")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_3d = named_mesh.mesh();

          const Connectivity<3>& connectivity = mesh_3d->connectivity();

          NodeValue<size_t> node_value{connectivity};
          node_value.fill(3);

          auto node_is_owned = connectivity.nodeIsOwned();

          const size_t global_number_of_nodes = [&] {
            size_t number_of_nodes = 0;
            for (NodeId node_id = 0; node_id < node_is_owned.numberOfItems(); ++node_id) {
              number_of_nodes += node_is_owned[node_id];
            }
            return parallel::allReduceSum(number_of_nodes);
          }();

          REQUIRE(sum(node_value) == 3 * global_number_of_nodes);
        }
      }
    }
  }
}
