From a73b442ec146515d133ee9eef695ee99cb79f3f8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Thu, 22 Apr 2021 19:44:20 +0200
Subject: [PATCH] Add tests for InterpolateItemArray

Also fix interpolation on item list
---
 src/language/utils/InterpolateItemArray.hpp |  15 +-
 tests/CMakeLists.txt                        |   1 +
 tests/test_InterpolateItemArray.cpp         | 467 ++++++++++++++++++++
 3 files changed, 480 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_InterpolateItemArray.cpp

diff --git a/src/language/utils/InterpolateItemArray.hpp b/src/language/utils/InterpolateItemArray.hpp
index fcef1db90..c507c527c 100644
--- a/src/language/utils/InterpolateItemArray.hpp
+++ b/src/language/utils/InterpolateItemArray.hpp
@@ -16,7 +16,7 @@ class InterpolateItemArray<OutputType(InputType)> : public PugsFunctionAdapter<O
 
  public:
   template <ItemType item_type>
-  static inline ItemArray<OutputType, item_type>
+  PUGS_INLINE static ItemArray<OutputType, item_type>
   interpolate(const std::vector<FunctionSymbolId>& function_symbol_id_list,
               const ItemValue<const InputType, item_type>& position)
   {
@@ -35,9 +35,9 @@ class InterpolateItemArray<OutputType(InputType)> : public PugsFunctionAdapter<O
   }
 
   template <ItemType item_type>
-  static inline Table<OutputType>
+  PUGS_INLINE static Table<OutputType>
   interpolate(const std::vector<FunctionSymbolId>& function_symbol_id_list,
-              const ItemArray<const InputType, item_type>& position,
+              const ItemValue<const InputType, item_type>& position,
               const Array<const ItemIdT<item_type>>& list_of_items)
   {
     Table<OutputType> table{list_of_items.size(), function_symbol_id_list.size()};
@@ -53,6 +53,15 @@ class InterpolateItemArray<OutputType(InputType)> : public PugsFunctionAdapter<O
 
     return table;
   }
+
+  template <ItemType item_type>
+  PUGS_INLINE static Table<OutputType>
+  interpolate(const std::vector<FunctionSymbolId>& function_symbol_id_list,
+              const ItemValue<const InputType, item_type>& position,
+              const Array<ItemIdT<item_type>>& list_of_items)
+  {
+    return interpolate(function_symbol_id_list, position, Array<const ItemIdT<item_type>>{list_of_items});
+  }
 };
 
 #endif   // INTERPOLATE_ITEM_ARRAY_HPP
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 098e28a20..4bb227aac 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -103,6 +103,7 @@ add_executable (mpi_unit_tests
   mpi_test_main.cpp
   test_DiscreteFunctionP0.cpp
   test_DiscreteFunctionP0Vector.cpp
+  test_InterpolateItemArray.cpp
   test_InterpolateItemValue.cpp
   test_ItemArray.cpp
   test_ItemArrayUtils.cpp
diff --git a/tests/test_InterpolateItemArray.cpp b/tests/test_InterpolateItemArray.cpp
new file mode 100644
index 000000000..cf65afdf4
--- /dev/null
+++ b/tests/test_InterpolateItemArray.cpp
@@ -0,0 +1,467 @@
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_all.hpp>
+
+#include <language/ast/ASTBuilder.hpp>
+#include <language/ast/ASTModulesImporter.hpp>
+#include <language/ast/ASTNodeDataTypeBuilder.hpp>
+#include <language/ast/ASTNodeExpressionBuilder.hpp>
+#include <language/ast/ASTNodeFunctionEvaluationExpressionBuilder.hpp>
+#include <language/ast/ASTNodeFunctionExpressionBuilder.hpp>
+#include <language/ast/ASTNodeTypeCleaner.hpp>
+#include <language/ast/ASTSymbolTableBuilder.hpp>
+#include <language/utils/PugsFunctionAdapter.hpp>
+#include <language/utils/SymbolTable.hpp>
+
+#include <MeshDataBaseForTests.hpp>
+#include <mesh/Connectivity.hpp>
+#include <mesh/Mesh.hpp>
+#include <mesh/MeshData.hpp>
+#include <mesh/MeshDataManager.hpp>
+
+#include <language/utils/InterpolateItemArray.hpp>
+
+#include <pegtl/string_input.hpp>
+
+// clazy:excludeall=non-pod-global-static
+
+TEST_CASE("InterpolateItemArray", "[language]")
+{
+  SECTION("interpolate on all items")
+  {
+    auto same_cell_array = [](auto f, auto g) -> bool {
+      using ItemIdType = typename decltype(f)::index_type;
+
+      for (ItemIdType item_id = 0; item_id < f.numberOfItems(); ++item_id) {
+        for (size_t i = 0; i < f.sizeOfArrays(); ++i) {
+          if (f[item_id][i] != g[item_id][i]) {
+            return false;
+          }
+        }
+      }
+
+      return true;
+    };
+
+    SECTION("1D")
+    {
+      constexpr size_t Dimension = 1;
+
+      const auto& mesh_1d = MeshDataBaseForTests::get().cartesianMesh1D();
+      auto xj             = MeshDataManager::instance().getMeshData(*mesh_1d).xj();
+
+      std::string_view data = R"(
+import math;
+let scalar_affine_1d: R^1 -> R, x -> 2*x[0] + 2;
+let scalar_non_linear_1d: R^1 -> R, x -> 2 * exp(x[0]) + 3;
+)";
+      TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};
+
+      auto ast = ASTBuilder::build(input);
+
+      ASTModulesImporter{*ast};
+      ASTNodeTypeCleaner<language::import_instruction>{*ast};
+
+      ASTSymbolTableBuilder{*ast};
+      ASTNodeDataTypeBuilder{*ast};
+
+      ASTNodeTypeCleaner<language::var_declaration>{*ast};
+      ASTNodeTypeCleaner<language::fct_declaration>{*ast};
+      ASTNodeExpressionBuilder{*ast};
+
+      std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table;
+
+      TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"};
+      position.byte = data.size();   // ensure that variables are declared at this point
+
+      std::vector<FunctionSymbolId> function_symbol_id_list;
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_affine_1d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_non_linear_1d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      CellArray<double> cell_array{mesh_1d->connectivity(), 2};
+      parallel_for(
+        cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) {
+          const TinyVector<Dimension>& x = xj[cell_id];
+          cell_array[cell_id][0]         = 2 * x[0] + 2;
+          cell_array[cell_id][1]         = 2 * exp(x[0]) + 3;
+        });
+
+      CellArray<const double> interpolate_array =
+        InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj);
+
+      REQUIRE(same_cell_array(cell_array, interpolate_array));
+    }
+
+    SECTION("2D")
+    {
+      constexpr size_t Dimension = 2;
+
+      const auto& mesh_2d = MeshDataBaseForTests::get().cartesianMesh2D();
+      auto xj             = MeshDataManager::instance().getMeshData(*mesh_2d).xj();
+
+      std::string_view data = R"(
+import math;
+let scalar_affine_2d: R^2 -> R, x -> 2*x[0] + 3*x[1] + 2;
+let scalar_non_linear_2d: R^2 -> R, x -> 2*exp(x[0])*sin(x[1])+3;
+)";
+      TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};
+
+      auto ast = ASTBuilder::build(input);
+
+      ASTModulesImporter{*ast};
+      ASTNodeTypeCleaner<language::import_instruction>{*ast};
+
+      ASTSymbolTableBuilder{*ast};
+      ASTNodeDataTypeBuilder{*ast};
+
+      ASTNodeTypeCleaner<language::var_declaration>{*ast};
+      ASTNodeTypeCleaner<language::fct_declaration>{*ast};
+      ASTNodeExpressionBuilder{*ast};
+
+      std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table;
+
+      TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"};
+      position.byte = data.size();   // ensure that variables are declared at this point
+
+      std::vector<FunctionSymbolId> function_symbol_id_list;
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_affine_2d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_non_linear_2d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      CellArray<double> cell_array{mesh_2d->connectivity(), 2};
+      parallel_for(
+        cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) {
+          const TinyVector<Dimension>& x = xj[cell_id];
+          cell_array[cell_id][0]         = 2 * x[0] + 3 * x[1] + 2;
+          cell_array[cell_id][1]         = 2 * exp(x[0]) * sin(x[1]) + 3;
+        });
+
+      CellArray<const double> interpolate_array =
+        InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj);
+
+      REQUIRE(same_cell_array(cell_array, interpolate_array));
+    }
+
+    SECTION("3D")
+    {
+      constexpr size_t Dimension = 3;
+
+      const auto& mesh_3d = MeshDataBaseForTests::get().cartesianMesh3D();
+      auto xj             = MeshDataManager::instance().getMeshData(*mesh_3d).xj();
+
+      std::string_view data = R"(
+import math;
+let scalar_affine_3d: R^3 -> R, x -> 2 * x[0] + 3 * x[1] + 2 * x[2] - 1;
+let scalar_non_linear_3d: R^3 -> R, x -> 2 * exp(x[0]) * sin(x[1]) * x[2] + 3;
+)";
+      TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};
+
+      auto ast = ASTBuilder::build(input);
+
+      ASTModulesImporter{*ast};
+      ASTNodeTypeCleaner<language::import_instruction>{*ast};
+
+      ASTSymbolTableBuilder{*ast};
+      ASTNodeDataTypeBuilder{*ast};
+
+      ASTNodeTypeCleaner<language::var_declaration>{*ast};
+      ASTNodeTypeCleaner<language::fct_declaration>{*ast};
+      ASTNodeExpressionBuilder{*ast};
+
+      std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table;
+
+      TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"};
+      position.byte = data.size();   // ensure that variables are declared at this point
+
+      std::vector<FunctionSymbolId> function_symbol_id_list;
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_affine_3d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_non_linear_3d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      CellArray<double> cell_array{mesh_3d->connectivity(), 2};
+      parallel_for(
+        cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) {
+          const TinyVector<Dimension>& x = xj[cell_id];
+          cell_array[cell_id][0]         = 2 * x[0] + 3 * x[1] + 2 * x[2] - 1;
+          cell_array[cell_id][1]         = 2 * exp(x[0]) * sin(x[1]) * x[2] + 3;
+        });
+
+      CellArray<const double> interpolate_array =
+        InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj);
+
+      REQUIRE(same_cell_array(cell_array, interpolate_array));
+    }
+  }
+
+  SECTION("interpolate on items list")
+  {
+    auto same_cell_value = [](auto interpolated, auto reference) -> bool {
+      for (size_t i = 0; i < interpolated.nbRows(); ++i) {
+        for (size_t j = 0; j < interpolated.nbColumns(); ++j) {
+          if (interpolated[i][j] != reference[i][j]) {
+            return false;
+          }
+        }
+      }
+      return true;
+    };
+
+    SECTION("1D")
+    {
+      constexpr size_t Dimension = 1;
+
+      const auto& mesh_1d = MeshDataBaseForTests::get().cartesianMesh1D();
+      auto xj             = MeshDataManager::instance().getMeshData(*mesh_1d).xj();
+
+      Array<const CellId> cell_id_list = [&] {
+        Array<CellId> cell_ids{mesh_1d->numberOfCells() / 2};
+        for (size_t i_cell = 0; i_cell < cell_ids.size(); ++i_cell) {
+          cell_ids[i_cell] = static_cast<CellId>(2 * i_cell);
+        }
+        return cell_ids;
+      }();
+
+      std::string_view data = R"(
+  import math;
+  let scalar_affine_1d: R^1 -> R, x -> 2*x[0] + 2;
+  let scalar_non_linear_1d: R^1 -> R, x -> 2 * exp(x[0]) + 3;
+  )";
+      TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};
+
+      auto ast = ASTBuilder::build(input);
+
+      ASTModulesImporter{*ast};
+      ASTNodeTypeCleaner<language::import_instruction>{*ast};
+
+      ASTSymbolTableBuilder{*ast};
+      ASTNodeDataTypeBuilder{*ast};
+
+      ASTNodeTypeCleaner<language::var_declaration>{*ast};
+      ASTNodeTypeCleaner<language::fct_declaration>{*ast};
+      ASTNodeExpressionBuilder{*ast};
+
+      std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table;
+
+      TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"};
+      position.byte = data.size();   // ensure that variables are declared at this point
+
+      std::vector<FunctionSymbolId> function_symbol_id_list;
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_affine_1d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_non_linear_1d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      Table<double> cell_array{cell_id_list.size(), 2};
+      parallel_for(
+        cell_id_list.size(), PUGS_LAMBDA(const size_t i) {
+          const TinyVector<Dimension>& x = xj[cell_id_list[i]];
+          cell_array[i][0]               = 2 * x[0] + 2;
+          cell_array[i][1]               = 2 * exp(x[0]) + 3;
+        });
+
+      Table<const double> interpolate_array =
+        InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list);
+
+      REQUIRE(same_cell_value(cell_array, interpolate_array));
+    }
+
+    SECTION("2D")
+    {
+      constexpr size_t Dimension = 2;
+
+      const auto& mesh_2d = MeshDataBaseForTests::get().cartesianMesh2D();
+      auto xj             = MeshDataManager::instance().getMeshData(*mesh_2d).xj();
+
+      Array<CellId> cell_id_list{mesh_2d->numberOfCells() / 2};
+      for (size_t i_cell = 0; i_cell < cell_id_list.size(); ++i_cell) {
+        cell_id_list[i_cell] = static_cast<CellId>(2 * i_cell);
+      }
+
+      std::string_view data = R"(
+import math;
+let scalar_affine_2d: R^2 -> R, x -> 2*x[0] + 3*x[1] + 2;
+let scalar_non_linear_2d: R^2 -> R, x -> 2*exp(x[0])*sin(x[1])+3;
+)";
+      TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};
+
+      auto ast = ASTBuilder::build(input);
+
+      ASTModulesImporter{*ast};
+      ASTNodeTypeCleaner<language::import_instruction>{*ast};
+
+      ASTSymbolTableBuilder{*ast};
+      ASTNodeDataTypeBuilder{*ast};
+
+      ASTNodeTypeCleaner<language::var_declaration>{*ast};
+      ASTNodeTypeCleaner<language::fct_declaration>{*ast};
+      ASTNodeExpressionBuilder{*ast};
+
+      std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table;
+
+      TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"};
+      position.byte = data.size();   // ensure that variables are declared at this point
+
+      std::vector<FunctionSymbolId> function_symbol_id_list;
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_affine_2d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_non_linear_2d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      Table<double> cell_array{cell_id_list.size(), 2};
+      parallel_for(
+        cell_id_list.size(), PUGS_LAMBDA(const size_t i) {
+          const TinyVector<Dimension>& x = xj[cell_id_list[i]];
+          cell_array[i][0]               = 2 * x[0] + 3 * x[1] + 2;
+          cell_array[i][1]               = 2 * exp(x[0]) * sin(x[1]) + 3;
+        });
+
+      Table<const double> interpolate_array =
+        InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list);
+
+      REQUIRE(same_cell_value(cell_array, interpolate_array));
+    }
+
+    SECTION("3D")
+    {
+      constexpr size_t Dimension = 3;
+
+      const auto& mesh_3d = MeshDataBaseForTests::get().cartesianMesh3D();
+      auto xj             = MeshDataManager::instance().getMeshData(*mesh_3d).xj();
+
+      Array<CellId> cell_id_list{mesh_3d->numberOfCells() / 2};
+      for (size_t i_cell = 0; i_cell < cell_id_list.size(); ++i_cell) {
+        cell_id_list[i_cell] = static_cast<CellId>(2 * i_cell);
+      }
+
+      std::string_view data = R"(
+import math;
+let scalar_affine_3d: R^3 -> R, x -> 2 * x[0] + 3 * x[1] + 2 * x[2] - 1;
+let scalar_non_linear_3d: R^3 -> R, x -> 2 * exp(x[0]) * sin(x[1]) * x[2] + 3;
+)";
+      TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};
+
+      auto ast = ASTBuilder::build(input);
+
+      ASTModulesImporter{*ast};
+      ASTNodeTypeCleaner<language::import_instruction>{*ast};
+
+      ASTSymbolTableBuilder{*ast};
+      ASTNodeDataTypeBuilder{*ast};
+
+      ASTNodeTypeCleaner<language::var_declaration>{*ast};
+      ASTNodeTypeCleaner<language::fct_declaration>{*ast};
+      ASTNodeExpressionBuilder{*ast};
+
+      std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table;
+
+      TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"};
+      position.byte = data.size();   // ensure that variables are declared at this point
+
+      std::vector<FunctionSymbolId> function_symbol_id_list;
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_affine_3d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      {
+        auto [i_symbol, found] = symbol_table->find("scalar_non_linear_3d", position);
+        REQUIRE(found);
+        REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t);
+
+        function_symbol_id_list.push_back(
+          FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table));
+      }
+
+      Table<double> cell_array{cell_id_list.size(), 2};
+      parallel_for(
+        cell_id_list.size(), PUGS_LAMBDA(const size_t i) {
+          const TinyVector<Dimension>& x = xj[cell_id_list[i]];
+          cell_array[i][0]               = 2 * x[0] + 3 * x[1] + 2 * x[2] - 1;
+          cell_array[i][1]               = 2 * exp(x[0]) * sin(x[1]) * x[2] + 3;
+        });
+
+      Table<const double> interpolate_array =
+        InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list);
+
+      REQUIRE(same_cell_value(cell_array, interpolate_array));
+    }
+  }
+}
-- 
GitLab