From a73b442ec146515d133ee9eef695ee99cb79f3f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com> Date: Thu, 22 Apr 2021 19:44:20 +0200 Subject: [PATCH] Add tests for InterpolateItemArray Also fix interpolation on item list --- src/language/utils/InterpolateItemArray.hpp | 15 +- tests/CMakeLists.txt | 1 + tests/test_InterpolateItemArray.cpp | 467 ++++++++++++++++++++ 3 files changed, 480 insertions(+), 3 deletions(-) create mode 100644 tests/test_InterpolateItemArray.cpp diff --git a/src/language/utils/InterpolateItemArray.hpp b/src/language/utils/InterpolateItemArray.hpp index fcef1db90..c507c527c 100644 --- a/src/language/utils/InterpolateItemArray.hpp +++ b/src/language/utils/InterpolateItemArray.hpp @@ -16,7 +16,7 @@ class InterpolateItemArray<OutputType(InputType)> : public PugsFunctionAdapter<O public: template <ItemType item_type> - static inline ItemArray<OutputType, item_type> + PUGS_INLINE static ItemArray<OutputType, item_type> interpolate(const std::vector<FunctionSymbolId>& function_symbol_id_list, const ItemValue<const InputType, item_type>& position) { @@ -35,9 +35,9 @@ class InterpolateItemArray<OutputType(InputType)> : public PugsFunctionAdapter<O } template <ItemType item_type> - static inline Table<OutputType> + PUGS_INLINE static Table<OutputType> interpolate(const std::vector<FunctionSymbolId>& function_symbol_id_list, - const ItemArray<const InputType, item_type>& position, + const ItemValue<const InputType, item_type>& position, const Array<const ItemIdT<item_type>>& list_of_items) { Table<OutputType> table{list_of_items.size(), function_symbol_id_list.size()}; @@ -53,6 +53,15 @@ class InterpolateItemArray<OutputType(InputType)> : public PugsFunctionAdapter<O return table; } + + template <ItemType item_type> + PUGS_INLINE static Table<OutputType> + interpolate(const std::vector<FunctionSymbolId>& function_symbol_id_list, + const ItemValue<const InputType, item_type>& position, + const Array<ItemIdT<item_type>>& list_of_items) + { + return interpolate(function_symbol_id_list, position, Array<const ItemIdT<item_type>>{list_of_items}); + } }; #endif // INTERPOLATE_ITEM_ARRAY_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 098e28a20..4bb227aac 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -103,6 +103,7 @@ add_executable (mpi_unit_tests mpi_test_main.cpp test_DiscreteFunctionP0.cpp test_DiscreteFunctionP0Vector.cpp + test_InterpolateItemArray.cpp test_InterpolateItemValue.cpp test_ItemArray.cpp test_ItemArrayUtils.cpp diff --git a/tests/test_InterpolateItemArray.cpp b/tests/test_InterpolateItemArray.cpp new file mode 100644 index 000000000..cf65afdf4 --- /dev/null +++ b/tests/test_InterpolateItemArray.cpp @@ -0,0 +1,467 @@ +#include <catch2/catch_test_macros.hpp> +#include <catch2/matchers/catch_matchers_all.hpp> + +#include <language/ast/ASTBuilder.hpp> +#include <language/ast/ASTModulesImporter.hpp> +#include <language/ast/ASTNodeDataTypeBuilder.hpp> +#include <language/ast/ASTNodeExpressionBuilder.hpp> +#include <language/ast/ASTNodeFunctionEvaluationExpressionBuilder.hpp> +#include <language/ast/ASTNodeFunctionExpressionBuilder.hpp> +#include <language/ast/ASTNodeTypeCleaner.hpp> +#include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/utils/PugsFunctionAdapter.hpp> +#include <language/utils/SymbolTable.hpp> + +#include <MeshDataBaseForTests.hpp> +#include <mesh/Connectivity.hpp> +#include <mesh/Mesh.hpp> +#include <mesh/MeshData.hpp> +#include <mesh/MeshDataManager.hpp> + +#include <language/utils/InterpolateItemArray.hpp> + +#include <pegtl/string_input.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("InterpolateItemArray", "[language]") +{ + SECTION("interpolate on all items") + { + auto same_cell_array = [](auto f, auto g) -> bool { + using ItemIdType = typename decltype(f)::index_type; + + for (ItemIdType item_id = 0; item_id < f.numberOfItems(); ++item_id) { + for (size_t i = 0; i < f.sizeOfArrays(); ++i) { + if (f[item_id][i] != g[item_id][i]) { + return false; + } + } + } + + return true; + }; + + SECTION("1D") + { + constexpr size_t Dimension = 1; + + const auto& mesh_1d = MeshDataBaseForTests::get().cartesianMesh1D(); + auto xj = MeshDataManager::instance().getMeshData(*mesh_1d).xj(); + + std::string_view data = R"( +import math; +let scalar_affine_1d: R^1 -> R, x -> 2*x[0] + 2; +let scalar_non_linear_1d: R^1 -> R, x -> 2 * exp(x[0]) + 3; +)"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + std::vector<FunctionSymbolId> function_symbol_id_list; + + { + auto [i_symbol, found] = symbol_table->find("scalar_affine_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + { + auto [i_symbol, found] = symbol_table->find("scalar_non_linear_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + CellArray<double> cell_array{mesh_1d->connectivity(), 2}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + cell_array[cell_id][0] = 2 * x[0] + 2; + cell_array[cell_id][1] = 2 * exp(x[0]) + 3; + }); + + CellArray<const double> interpolate_array = + InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj); + + REQUIRE(same_cell_array(cell_array, interpolate_array)); + } + + SECTION("2D") + { + constexpr size_t Dimension = 2; + + const auto& mesh_2d = MeshDataBaseForTests::get().cartesianMesh2D(); + auto xj = MeshDataManager::instance().getMeshData(*mesh_2d).xj(); + + std::string_view data = R"( +import math; +let scalar_affine_2d: R^2 -> R, x -> 2*x[0] + 3*x[1] + 2; +let scalar_non_linear_2d: R^2 -> R, x -> 2*exp(x[0])*sin(x[1])+3; +)"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + std::vector<FunctionSymbolId> function_symbol_id_list; + + { + auto [i_symbol, found] = symbol_table->find("scalar_affine_2d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + { + auto [i_symbol, found] = symbol_table->find("scalar_non_linear_2d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + CellArray<double> cell_array{mesh_2d->connectivity(), 2}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + cell_array[cell_id][0] = 2 * x[0] + 3 * x[1] + 2; + cell_array[cell_id][1] = 2 * exp(x[0]) * sin(x[1]) + 3; + }); + + CellArray<const double> interpolate_array = + InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj); + + REQUIRE(same_cell_array(cell_array, interpolate_array)); + } + + SECTION("3D") + { + constexpr size_t Dimension = 3; + + const auto& mesh_3d = MeshDataBaseForTests::get().cartesianMesh3D(); + auto xj = MeshDataManager::instance().getMeshData(*mesh_3d).xj(); + + std::string_view data = R"( +import math; +let scalar_affine_3d: R^3 -> R, x -> 2 * x[0] + 3 * x[1] + 2 * x[2] - 1; +let scalar_non_linear_3d: R^3 -> R, x -> 2 * exp(x[0]) * sin(x[1]) * x[2] + 3; +)"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + std::vector<FunctionSymbolId> function_symbol_id_list; + + { + auto [i_symbol, found] = symbol_table->find("scalar_affine_3d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + { + auto [i_symbol, found] = symbol_table->find("scalar_non_linear_3d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + CellArray<double> cell_array{mesh_3d->connectivity(), 2}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + cell_array[cell_id][0] = 2 * x[0] + 3 * x[1] + 2 * x[2] - 1; + cell_array[cell_id][1] = 2 * exp(x[0]) * sin(x[1]) * x[2] + 3; + }); + + CellArray<const double> interpolate_array = + InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj); + + REQUIRE(same_cell_array(cell_array, interpolate_array)); + } + } + + SECTION("interpolate on items list") + { + auto same_cell_value = [](auto interpolated, auto reference) -> bool { + for (size_t i = 0; i < interpolated.nbRows(); ++i) { + for (size_t j = 0; j < interpolated.nbColumns(); ++j) { + if (interpolated[i][j] != reference[i][j]) { + return false; + } + } + } + return true; + }; + + SECTION("1D") + { + constexpr size_t Dimension = 1; + + const auto& mesh_1d = MeshDataBaseForTests::get().cartesianMesh1D(); + auto xj = MeshDataManager::instance().getMeshData(*mesh_1d).xj(); + + Array<const CellId> cell_id_list = [&] { + Array<CellId> cell_ids{mesh_1d->numberOfCells() / 2}; + for (size_t i_cell = 0; i_cell < cell_ids.size(); ++i_cell) { + cell_ids[i_cell] = static_cast<CellId>(2 * i_cell); + } + return cell_ids; + }(); + + std::string_view data = R"( + import math; + let scalar_affine_1d: R^1 -> R, x -> 2*x[0] + 2; + let scalar_non_linear_1d: R^1 -> R, x -> 2 * exp(x[0]) + 3; + )"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + std::vector<FunctionSymbolId> function_symbol_id_list; + + { + auto [i_symbol, found] = symbol_table->find("scalar_affine_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + { + auto [i_symbol, found] = symbol_table->find("scalar_non_linear_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + Table<double> cell_array{cell_id_list.size(), 2}; + parallel_for( + cell_id_list.size(), PUGS_LAMBDA(const size_t i) { + const TinyVector<Dimension>& x = xj[cell_id_list[i]]; + cell_array[i][0] = 2 * x[0] + 2; + cell_array[i][1] = 2 * exp(x[0]) + 3; + }); + + Table<const double> interpolate_array = + InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); + + REQUIRE(same_cell_value(cell_array, interpolate_array)); + } + + SECTION("2D") + { + constexpr size_t Dimension = 2; + + const auto& mesh_2d = MeshDataBaseForTests::get().cartesianMesh2D(); + auto xj = MeshDataManager::instance().getMeshData(*mesh_2d).xj(); + + Array<CellId> cell_id_list{mesh_2d->numberOfCells() / 2}; + for (size_t i_cell = 0; i_cell < cell_id_list.size(); ++i_cell) { + cell_id_list[i_cell] = static_cast<CellId>(2 * i_cell); + } + + std::string_view data = R"( +import math; +let scalar_affine_2d: R^2 -> R, x -> 2*x[0] + 3*x[1] + 2; +let scalar_non_linear_2d: R^2 -> R, x -> 2*exp(x[0])*sin(x[1])+3; +)"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + std::vector<FunctionSymbolId> function_symbol_id_list; + + { + auto [i_symbol, found] = symbol_table->find("scalar_affine_2d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + { + auto [i_symbol, found] = symbol_table->find("scalar_non_linear_2d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + Table<double> cell_array{cell_id_list.size(), 2}; + parallel_for( + cell_id_list.size(), PUGS_LAMBDA(const size_t i) { + const TinyVector<Dimension>& x = xj[cell_id_list[i]]; + cell_array[i][0] = 2 * x[0] + 3 * x[1] + 2; + cell_array[i][1] = 2 * exp(x[0]) * sin(x[1]) + 3; + }); + + Table<const double> interpolate_array = + InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); + + REQUIRE(same_cell_value(cell_array, interpolate_array)); + } + + SECTION("3D") + { + constexpr size_t Dimension = 3; + + const auto& mesh_3d = MeshDataBaseForTests::get().cartesianMesh3D(); + auto xj = MeshDataManager::instance().getMeshData(*mesh_3d).xj(); + + Array<CellId> cell_id_list{mesh_3d->numberOfCells() / 2}; + for (size_t i_cell = 0; i_cell < cell_id_list.size(); ++i_cell) { + cell_id_list[i_cell] = static_cast<CellId>(2 * i_cell); + } + + std::string_view data = R"( +import math; +let scalar_affine_3d: R^3 -> R, x -> 2 * x[0] + 3 * x[1] + 2 * x[2] - 1; +let scalar_non_linear_3d: R^3 -> R, x -> 2 * exp(x[0]) * sin(x[1]) * x[2] + 3; +)"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + std::vector<FunctionSymbolId> function_symbol_id_list; + + { + auto [i_symbol, found] = symbol_table->find("scalar_affine_3d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + { + auto [i_symbol, found] = symbol_table->find("scalar_non_linear_3d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + Table<double> cell_array{cell_id_list.size(), 2}; + parallel_for( + cell_id_list.size(), PUGS_LAMBDA(const size_t i) { + const TinyVector<Dimension>& x = xj[cell_id_list[i]]; + cell_array[i][0] = 2 * x[0] + 3 * x[1] + 2 * x[2] - 1; + cell_array[i][1] = 2 * exp(x[0]) * sin(x[1]) * x[2] + 3; + }); + + Table<const double> interpolate_array = + InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); + + REQUIRE(same_cell_value(cell_array, interpolate_array)); + } + } +} -- GitLab