#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <MeshDataBaseForTests.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ItemArray.hpp>
#include <mesh/Mesh.hpp>
#include <utils/Messenger.hpp>

// Explicit instantiation: forces every non-template member of
// ItemArray<int, ItemType::node> to be compiled in this test, so that members
// never called by the test cases below are still type-checked at build time.
template class ItemArray<int, ItemType::node>;

// clazy:excludeall=non-pod-global-static

// Unit tests for ItemArray and its item-specific aliases (NodeArray, EdgeArray,
// FaceArray, CellArray, WeakFaceArray): construction, item/array sizes,
// assignment from a Table, copy semantics and, in debug builds only, the
// assertions that guard against misuse.
TEST_CASE("ItemArray", "[mesh]")
{
  SECTION("default constructors")
  {
    REQUIRE_NOTHROW(NodeArray<int>{});
    REQUIRE_NOTHROW(EdgeArray<int>{});
    REQUIRE_NOTHROW(FaceArray<int>{});
    REQUIRE_NOTHROW(CellArray<int>{});

    // A default-constructed array is not attached to any connectivity.
    REQUIRE(not NodeArray<int>{}.isBuilt());
    REQUIRE(not EdgeArray<int>{}.isBuilt());
    REQUIRE(not FaceArray<int>{}.isBuilt());
    REQUIRE(not CellArray<int>{}.isBuilt());
  }

  SECTION("1D")
  {
    std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes();

    // Entries are only read: bind by const reference to avoid copying them.
    for (const auto& named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh_1d = named_mesh.mesh();

        const Connectivity<1>& connectivity = mesh_1d->connectivity();

        REQUIRE_NOTHROW(NodeArray<int>{connectivity, 3});
        REQUIRE_NOTHROW(EdgeArray<int>{connectivity, 3});
        REQUIRE_NOTHROW(FaceArray<int>{connectivity, 3});
        REQUIRE_NOTHROW(CellArray<int>{connectivity, 3});

        REQUIRE(NodeArray<int>{connectivity, 3}.isBuilt());
        REQUIRE(EdgeArray<int>{connectivity, 3}.isBuilt());
        REQUIRE(FaceArray<int>{connectivity, 3}.isBuilt());
        REQUIRE(CellArray<int>{connectivity, 3}.isBuilt());

        NodeArray<int> node_value{connectivity, 3};
        EdgeArray<int> edge_value{connectivity, 3};
        FaceArray<int> face_value{connectivity, 3};
        CellArray<int> cell_value{connectivity, 3};

        // In 1D, edge and face counts equal the node count, and there is
        // exactly one more node than cells.
        REQUIRE(edge_value.numberOfItems() == node_value.numberOfItems());
        REQUIRE(face_value.numberOfItems() == node_value.numberOfItems());
        REQUIRE(cell_value.numberOfItems() + 1 == node_value.numberOfItems());

        REQUIRE(node_value.sizeOfArrays() == 3);
        REQUIRE(edge_value.sizeOfArrays() == 3);
        REQUIRE(face_value.sizeOfArrays() == 3);
        REQUIRE(cell_value.sizeOfArrays() == 3);
      }
    }
  }

  SECTION("2D")
  {
    std::array mesh_list = MeshDataBaseForTests::get().all2DMeshes();

    for (const auto& named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh_2d = named_mesh.mesh();

        const Connectivity<2>& connectivity = mesh_2d->connectivity();

        REQUIRE_NOTHROW(NodeArray<int>{connectivity, 2});
        REQUIRE_NOTHROW(EdgeArray<int>{connectivity, 2});
        REQUIRE_NOTHROW(FaceArray<int>{connectivity, 2});
        REQUIRE_NOTHROW(CellArray<int>{connectivity, 2});

        REQUIRE(NodeArray<int>{connectivity, 2}.isBuilt());
        REQUIRE(EdgeArray<int>{connectivity, 2}.isBuilt());
        REQUIRE(FaceArray<int>{connectivity, 2}.isBuilt());
        REQUIRE(CellArray<int>{connectivity, 2}.isBuilt());

        NodeArray<int> node_value{connectivity, 2};
        EdgeArray<int> edge_value{connectivity, 2};
        FaceArray<int> face_value{connectivity, 2};
        CellArray<int> cell_value{connectivity, 2};

        // In 2D, edges and faces have the same number of items.
        REQUIRE(edge_value.numberOfItems() == face_value.numberOfItems());

        REQUIRE(node_value.sizeOfArrays() == 2);
        REQUIRE(edge_value.sizeOfArrays() == 2);
        REQUIRE(face_value.sizeOfArrays() == 2);
        REQUIRE(cell_value.sizeOfArrays() == 2);
      }
    }
  }

  SECTION("3D")
  {
    std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes();

    for (const auto& named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh_3d = named_mesh.mesh();

        const Connectivity<3>& connectivity = mesh_3d->connectivity();

        REQUIRE_NOTHROW(NodeArray<int>{connectivity, 3});
        REQUIRE_NOTHROW(EdgeArray<int>{connectivity, 3});
        REQUIRE_NOTHROW(FaceArray<int>{connectivity, 3});
        REQUIRE_NOTHROW(CellArray<int>{connectivity, 3});

        REQUIRE(NodeArray<int>{connectivity, 3}.isBuilt());
        REQUIRE(EdgeArray<int>{connectivity, 3}.isBuilt());
        REQUIRE(FaceArray<int>{connectivity, 3}.isBuilt());
        REQUIRE(CellArray<int>{connectivity, 3}.isBuilt());

        NodeArray<int> node_value{connectivity, 3};
        EdgeArray<int> edge_value{connectivity, 3};
        FaceArray<int> face_value{connectivity, 3};
        CellArray<int> cell_value{connectivity, 3};

        REQUIRE(node_value.sizeOfArrays() == 3);
        REQUIRE(edge_value.sizeOfArrays() == 3);
        REQUIRE(face_value.sizeOfArrays() == 3);
        REQUIRE(cell_value.sizeOfArrays() == 3);
      }
    }
  }

  SECTION("set values from array")
  {
    std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes();

    for (const auto& named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh_3d = named_mesh.mesh();

        const Connectivity<3>& connectivity = mesh_3d->connectivity();

        CellArray<size_t> cell_array{connectivity, 3};

        // Table with one row per cell and one column per array entry, filled
        // with distinct consecutive values so any misplaced entry is detected.
        Table<size_t> table{cell_array.numberOfItems(), cell_array.sizeOfArrays()};
        {
          size_t k = 0;
          for (size_t i = 0; i < table.numberOfRows(); ++i) {
            for (size_t j = 0; j < table.numberOfColumns(); ++j) {
              table(i, j) = k++;
            }
          }
        }
        cell_array = table;

        // True when entry i of cell `cell_id` equals expected(cell_id, i) for
        // every cell and entry. Parameter names differ from the enclosing
        // locals to avoid shadowing.
        auto is_same = [](const CellArray<size_t>& item_array, const Table<size_t>& expected) {
          bool all_equal = true;
          for (CellId cell_id = 0; cell_id < item_array.numberOfItems(); ++cell_id) {
            Array sub_array = item_array[cell_id];
            for (size_t i = 0; i < sub_array.size(); ++i) {
              all_equal &= (sub_array[i] == expected(cell_id, i));
            }
          }
          return all_equal;
        };

        REQUIRE(is_same(cell_array, table));
      }
    }
  }

  SECTION("copy")
  {
    // True when every entry of every cell sub-array equals `value`.
    auto is_same = [](const auto& item_array, int value) {
      bool all_equal = true;
      for (CellId cell_id = 0; cell_id < item_array.numberOfItems(); ++cell_id) {
        Array sub_array = item_array[cell_id];
        for (size_t i = 0; i < sub_array.size(); ++i) {
          all_equal &= (sub_array[i] == value);
        }
      }
      return all_equal;
    };

    std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes();

    for (const auto& named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh_3d = named_mesh.mesh();

        const Connectivity<3>& connectivity = mesh_3d->connectivity();

        // Converted once to the array's value type. It is used both to fill
        // and to check, so no widening/narrowing round trip is needed.
        const int rank_value = static_cast<int>(parallel::rank());

        CellArray<int> cell_array{connectivity, 4};
        cell_array.fill(rank_value);

        // Constructing a const view from cell_array: later writes to
        // cell_array must be visible through it (checked below).
        CellArray<const int> cell_array_const_view{cell_array};
        REQUIRE(cell_array.numberOfItems() == cell_array_const_view.numberOfItems());
        REQUIRE(cell_array.sizeOfArrays() == cell_array_const_view.sizeOfArrays());
        REQUIRE(is_same(cell_array_const_view, rank_value));

        // copy() and copy_to() must produce data independent of cell_array.
        CellArray<const int> const_cell_array;
        const_cell_array = copy(cell_array);

        CellArray<int> duplicated_cell_array{connectivity, cell_array.sizeOfArrays()};
        copy_to(const_cell_array, duplicated_cell_array);

        cell_array.fill(0);

        REQUIRE(is_same(cell_array, 0));
        REQUIRE(is_same(cell_array_const_view, 0));
        REQUIRE(is_same(const_cell_array, rank_value));
        REQUIRE(is_same(duplicated_cell_array, rank_value));
      }
    }
  }

  SECTION("WeakItemArray")
  {
    std::array mesh_list = MeshDataBaseForTests::get().all2DMeshes();

    for (const auto& named_mesh : mesh_list) {
      SECTION(named_mesh.name())
      {
        auto mesh_2d = named_mesh.mesh();

        const Connectivity<2>& connectivity = mesh_2d->connectivity();

        const int rank_value = static_cast<int>(parallel::rank());

        WeakFaceArray<int> weak_face_array{connectivity, 5};

        weak_face_array.fill(rank_value);

        FaceArray<const int> face_array{weak_face_array};

        // The FaceArray built from the weak one must refer to the same
        // connectivity, have the same shape, and expose the filled values.
        REQUIRE(face_array.connectivity_ptr() == weak_face_array.connectivity_ptr());
        REQUIRE(face_array.numberOfItems() == weak_face_array.numberOfItems());
        REQUIRE(face_array.sizeOfArrays() == weak_face_array.sizeOfArrays());

        bool all_equal = true;
        for (FaceId face_id = 0; face_id < face_array.numberOfItems(); ++face_id) {
          Array sub_array = face_array[face_id];
          for (size_t i = 0; i < sub_array.size(); ++i) {
            all_equal &= (sub_array[i] == rank_value);
          }
        }
        REQUIRE(all_equal);
      }
    }
  }

#ifndef NDEBUG
  // These checks rely on debug-only assertions (AssertError), which are
  // compiled out when NDEBUG is defined.
  SECTION("error")
  {
    SECTION("checking for build ItemArray")
    {
      // Indexing an array that has no connectivity must assert.
      CellArray<int> cell_array;
      REQUIRE_THROWS_AS(cell_array[CellId{0}], AssertError);

      FaceArray<int> face_array;
      REQUIRE_THROWS_AS(face_array[FaceId{0}], AssertError);

      EdgeArray<int> edge_array;
      REQUIRE_THROWS_AS(edge_array[EdgeId{0}], AssertError);

      NodeArray<int> node_array;
      REQUIRE_THROWS_AS(node_array[NodeId{0}], AssertError);
    }

    SECTION("checking for bounds violation")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_3d = named_mesh.mesh();

          const Connectivity<3>& connectivity = mesh_3d->connectivity();

          // The item id equal to the item count is one past the last valid id.
          CellArray<int> cell_array{connectivity, 1};
          CellId invalid_cell_id = connectivity.numberOfCells();
          REQUIRE_THROWS_AS(cell_array[invalid_cell_id], AssertError);

          FaceArray<int> face_array{connectivity, 2};
          FaceId invalid_face_id = connectivity.numberOfFaces();
          REQUIRE_THROWS_AS(face_array[invalid_face_id], AssertError);

          EdgeArray<int> edge_array{connectivity, 1};
          EdgeId invalid_edge_id = connectivity.numberOfEdges();
          REQUIRE_THROWS_AS(edge_array[invalid_edge_id], AssertError);

          NodeArray<int> node_array{connectivity, 0};
          NodeId invalid_node_id = connectivity.numberOfNodes();
          REQUIRE_THROWS_AS(node_array[invalid_node_id], AssertError);
        }
      }
    }

    SECTION("set values from invalid array size")
    {
      std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes();

      for (const auto& named_mesh : mesh_list) {
        SECTION(named_mesh.name())
        {
          auto mesh_3d = named_mesh.mesh();

          const Connectivity<3>& connectivity = mesh_3d->connectivity();

          CellArray<size_t> cell_array{connectivity, 2};

          // Table dimensions deliberately differ from the (cells x 2) shape
          // of cell_array, so the assignment must assert.
          Table<size_t> values{3, connectivity.numberOfCells() + 3};
          REQUIRE_THROWS_AS(cell_array = values, AssertError);
        }
      }
    }
  }
#endif   // NDEBUG
}
