diff --git a/src/language/modules/MeshModule.cpp b/src/language/modules/MeshModule.cpp index cd81202905224b896081bffdf1baaa2ee64407cc..1f2d0ce4db829694dca9875593f168ebcac47b6a 100644 --- a/src/language/modules/MeshModule.cpp +++ b/src/language/modules/MeshModule.cpp @@ -5,6 +5,7 @@ #include <language/utils/BinaryOperatorProcessorBuilder.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> #include <language/utils/FunctionTable.hpp> +#include <language/utils/ItemArrayVariantFunctionInterpoler.hpp> #include <language/utils/ItemValueVariantFunctionInterpoler.hpp> #include <language/utils/OStream.hpp> #include <language/utils/OperatorRepository.hpp> @@ -129,6 +130,17 @@ MeshModule::MeshModule() )); + this->_addBuiltinFunction( + "interpolate_array", + std::function( + + [](std::shared_ptr<const IMesh> mesh, std::shared_ptr<const ItemType> item_type, + const std::vector<FunctionSymbolId>& function_id_list) -> std::shared_ptr<const ItemArrayVariant> { + return ItemArrayVariantFunctionInterpoler{mesh, *item_type, function_id_list}.interpolate(); + } + + )); + this->_addBuiltinFunction("transform", std::function( [](std::shared_ptr<const IMesh> p_mesh, diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index a81ffa8aecdf7af0085296b181f2b3ec4d63238f..db3e7856b758b6a4822d97ac0330958b3b367176 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -28,6 +28,7 @@ add_library(PugsLanguageUtils FunctionSymbolId.cpp IncDecOperatorRegisterForN.cpp IncDecOperatorRegisterForZ.cpp + ItemArrayVariantFunctionInterpoler.cpp ItemValueVariantFunctionInterpoler.cpp OFStream.cpp OperatorRepository.cpp diff --git a/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp b/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02d47d399ed9ec0fdb570fc3877b0b384f7a5429 --- /dev/null +++ b/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp @@ -0,0 +1,159 @@ +#include <language/utils/ItemArrayVariantFunctionInterpoler.hpp> + +#include <language/utils/InterpolateItemArray.hpp> +#include <mesh/Connectivity.hpp> +#include <mesh/ItemArrayVariant.hpp> +#include <mesh/Mesh.hpp> +#include <mesh/MeshData.hpp> +#include <mesh/MeshDataManager.hpp> +#include <utils/Exceptions.hpp> + +#include <memory> + +template <size_t Dimension, typename DataType, typename ArrayType> +std::shared_ptr<ItemArrayVariant> +ItemArrayVariantFunctionInterpoler::_interpolate() const +{ + std::shared_ptr p_mesh = std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(m_mesh); + using MeshDataType = MeshData<Dimension>; + + switch (m_item_type) { + case ItemType::cell: { + MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh); + return std::make_shared<ItemArrayVariant>( + InterpolateItemArray<DataType(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(m_function_id_list, + mesh_data.xj())); + } + case ItemType::face: { + MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh); + return std::make_shared<ItemArrayVariant>( + InterpolateItemArray<DataType(TinyVector<Dimension>)>::template interpolate<ItemType::face>(m_function_id_list, + mesh_data.xl())); + } + case ItemType::edge: { + MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh); + return std::make_shared<ItemArrayVariant>( + InterpolateItemArray<DataType(TinyVector<Dimension>)>::template interpolate<ItemType::edge>(m_function_id_list, + mesh_data.xe())); + } + case ItemType::node: { + return std::make_shared<ItemArrayVariant>( + InterpolateItemArray<DataType(TinyVector<Dimension>)>::template interpolate<ItemType::node>(m_function_id_list, + p_mesh->xr())); + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("invalid item type"); + } + // LCOV_EXCL_STOP + } +} + +template <size_t Dimension> +std::shared_ptr<ItemArrayVariant> +ItemArrayVariantFunctionInterpoler::_interpolate() const +{ + const ASTNodeDataType data_type = [&] { + const auto& function0_descriptor = m_function_id_list[0].descriptor(); + Assert(function0_descriptor.domainMappingNode().children[1]->m_data_type == ASTNodeDataType::typename_t); + + ASTNodeDataType data_type = function0_descriptor.domainMappingNode().children[1]->m_data_type.contentType(); + + for (size_t i = 1; i < m_function_id_list.size(); ++i) { + const auto& function_descriptor = m_function_id_list[i].descriptor(); + Assert(function_descriptor.domainMappingNode().children[1]->m_data_type == ASTNodeDataType::typename_t); + if (data_type != function_descriptor.domainMappingNode().children[1]->m_data_type.contentType()) { + throw NormalError("functions must have the same type"); + } + } + + return data_type; + }(); + + switch (data_type) { + case ASTNodeDataType::bool_t: { + return this->_interpolate<Dimension, bool>(); + } + case ASTNodeDataType::unsigned_int_t: { + return this->_interpolate<Dimension, uint64_t>(); + } + case ASTNodeDataType::int_t: { + return this->_interpolate<Dimension, int64_t>(); + } + case ASTNodeDataType::double_t: { + return this->_interpolate<Dimension, double>(); + } + case ASTNodeDataType::vector_t: { + switch (data_type.dimension()) { + case 1: { + return this->_interpolate<Dimension, TinyVector<1>>(); + } + case 2: { + return this->_interpolate<Dimension, TinyVector<2>>(); + } + case 3: { + return this->_interpolate<Dimension, TinyVector<3>>(); + } + // LCOV_EXCL_START + default: { + std::ostringstream os; + os << "invalid vector dimension " << rang::fgB::red << data_type.dimension() << rang::style::reset; + + throw UnexpectedError(os.str()); + } + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::matrix_t: { + Assert(data_type.numberOfColumns() == data_type.numberOfRows(), "undefined matrix type"); + switch (data_type.numberOfColumns()) { + case 1: { + return this->_interpolate<Dimension, TinyMatrix<1>>(); + } + case 2: { + return this->_interpolate<Dimension, TinyMatrix<2>>(); + } + case 3: { + return this->_interpolate<Dimension, TinyMatrix<3>>(); + } + // LCOV_EXCL_START + default: { + std::ostringstream os; + os << "invalid vector dimension " << rang::fgB::red << data_type.dimension() << rang::style::reset; + + throw UnexpectedError(os.str()); + } + // LCOV_EXCL_STOP + } + } + // LCOV_EXCL_START + default: { + std::ostringstream os; + os << "invalid interpolation array type: " << rang::fgB::red << dataTypeName(data_type) << rang::style::reset; + + throw UnexpectedError(os.str()); + } + // LCOV_EXCL_STOP + } +} + +std::shared_ptr<ItemArrayVariant> +ItemArrayVariantFunctionInterpoler::interpolate() const +{ + switch (m_mesh->dimension()) { + case 1: { + return this->_interpolate<1>(); + } + case 2: { + return this->_interpolate<2>(); + } + case 3: { + return this->_interpolate<3>(); + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("invalid dimension"); + } + // LCOV_EXCL_STOP + } +} diff --git a/src/language/utils/ItemArrayVariantFunctionInterpoler.hpp b/src/language/utils/ItemArrayVariantFunctionInterpoler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eac61b5fb067e83c2cf8f132060e68884316cad4 --- /dev/null +++ b/src/language/utils/ItemArrayVariantFunctionInterpoler.hpp @@ -0,0 +1,38 @@ +#ifndef ITEM_ARRAY_VARIANT_FUNCTION_INTERPOLER_HPP +#define ITEM_ARRAY_VARIANT_FUNCTION_INTERPOLER_HPP + +#include <language/utils/FunctionSymbolId.hpp> +#include <mesh/IMesh.hpp> +#include <mesh/IZoneDescriptor.hpp> +#include <mesh/ItemArrayVariant.hpp> +#include <mesh/ItemType.hpp> + +class ItemArrayVariantFunctionInterpoler +{ + private: + std::shared_ptr<const IMesh> m_mesh; + const ItemType m_item_type; + const std::vector<FunctionSymbolId> m_function_id_list; + + template <size_t Dimension, typename DataType, typename ArrayType = DataType> + std::shared_ptr<ItemArrayVariant> _interpolate() const; + + template <size_t Dimension> + std::shared_ptr<ItemArrayVariant> _interpolate() const; + + public: + std::shared_ptr<ItemArrayVariant> interpolate() const; + + ItemArrayVariantFunctionInterpoler(const std::shared_ptr<const IMesh>& mesh, + const ItemType& item_type, + const std::vector<FunctionSymbolId>& function_id_list) + : m_mesh{mesh}, m_item_type{item_type}, m_function_id_list{function_id_list} + {} + + ItemArrayVariantFunctionInterpoler(const ItemArrayVariantFunctionInterpoler&) = delete; + ItemArrayVariantFunctionInterpoler(ItemArrayVariantFunctionInterpoler&&) = delete; + + ~ItemArrayVariantFunctionInterpoler() = default; +}; + +#endif // ITEM_ARRAY_VARIANT_FUNCTION_INTERPOLER_HPP diff --git a/src/language/utils/ItemValueVariantFunctionInterpoler.cpp b/src/language/utils/ItemValueVariantFunctionInterpoler.cpp index e10e834e55a126823414316ce5d527716befd761..98dbe1658b45b8b8613892d0cfa72e67ccc67700 100644 --- a/src/language/utils/ItemValueVariantFunctionInterpoler.cpp +++ b/src/language/utils/ItemValueVariantFunctionInterpoler.cpp @@ -60,13 +60,13 @@ ItemValueVariantFunctionInterpoler::_interpolate() const switch (data_type) { case ASTNodeDataType::bool_t: { - return this->_interpolate<Dimension, bool, double>(); + return this->_interpolate<Dimension, bool>(); } case ASTNodeDataType::unsigned_int_t: { - return this->_interpolate<Dimension, uint64_t, double>(); + return this->_interpolate<Dimension, uint64_t>(); } case ASTNodeDataType::int_t: { - return this->_interpolate<Dimension, int64_t, double>(); + return this->_interpolate<Dimension, int64_t>(); } case ASTNodeDataType::double_t: { return this->_interpolate<Dimension, double>(); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6a3403df2dd5311ba94aed1c622af194f9b7db23..9f1f7d9005d6aa8a583283bf2bfda8dc6e23eedd 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -179,6 +179,7 @@ add_executable (mpi_unit_tests test_ItemArray.cpp test_ItemArrayUtils.cpp test_ItemArrayVariant.cpp + test_ItemArrayVariantFunctionInterpoler.cpp test_ItemValue.cpp test_ItemValueUtils.cpp test_ItemValueVariant.cpp diff --git a/tests/test_ItemArrayVariantFunctionInterpoler.cpp b/tests/test_ItemArrayVariantFunctionInterpoler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..640e3e77aa5e6ede05fe0f217d7587f7839e4d56 --- /dev/null +++ b/tests/test_ItemArrayVariantFunctionInterpoler.cpp @@ -0,0 +1,648 @@ +#include <catch2/catch_test_macros.hpp> +#include <catch2/matchers/catch_matchers_all.hpp> + +#include <language/ast/ASTBuilder.hpp> +#include <language/ast/ASTModulesImporter.hpp> +#include <language/ast/ASTNodeDataTypeBuilder.hpp> +#include <language/ast/ASTNodeExpressionBuilder.hpp> +#include <language/ast/ASTNodeFunctionEvaluationExpressionBuilder.hpp> +#include <language/ast/ASTNodeFunctionExpressionBuilder.hpp> +#include <language/ast/ASTNodeTypeCleaner.hpp> +#include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/utils/PugsFunctionAdapter.hpp> +#include <language/utils/SymbolTable.hpp> + +#include <MeshDataBaseForTests.hpp> +#include <mesh/Connectivity.hpp> +#include <mesh/Mesh.hpp> +#include <mesh/MeshData.hpp> +#include <mesh/MeshDataManager.hpp> + +#include <language/utils/ItemArrayVariantFunctionInterpoler.hpp> + +#include <pegtl/string_input.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("ItemArrayVariantFunctionInterpoler", "[scheme]") +{ + auto same_item_array = [](auto f, auto g) -> bool { + using ItemIdType = typename decltype(f)::index_type; + if (f.sizeOfArrays() != g.sizeOfArrays()) { + return false; + } + + for (ItemIdType item_id = 0; item_id < f.numberOfItems(); ++item_id) { + for (size_t i = 0; i < f.sizeOfArrays(); ++i) { + if (f[item_id][i] != g[item_id][i]) { + return false; + } + } + } + + return true; + }; + + SECTION("1D") + { + constexpr size_t Dimension = 1; + + std::array mesh_list = MeshDataBaseForTests::get().all1DMeshes(); + + for (const auto& named_mesh : mesh_list) { + SECTION(named_mesh.name()) + { + auto mesh_1d = named_mesh.mesh(); + + auto xj = MeshDataManager::instance().getMeshData(*mesh_1d).xj(); + auto xr = mesh_1d->xr(); + + std::string_view data = R"( +import math; +let B_scalar_non_linear1_1d: R^1 -> B, x -> (exp(2 * x[0]) + 3 > 4); +let B_scalar_non_linear2_1d: R^1 -> B, x -> (exp(2 * x[0]) + 3 < 4); + +let N_scalar_non_linear1_1d: R^1 -> N, x -> floor(3 * x[0] * x[0] + 2); +let N_scalar_non_linear2_1d: R^1 -> N, x -> floor(2 * x[0] * x[0]); + +let Z_scalar_non_linear1_1d: R^1 -> Z, x -> floor(exp(2 * x[0]) - 1); +let Z_scalar_non_linear2_1d: R^1 -> Z, x -> floor(cos(2 * x[0]) + 0.5); + +let R_scalar_non_linear1_1d: R^1 -> R, x -> 2 * exp(x[0]) + 3; +let R_scalar_non_linear2_1d: R^1 -> R, x -> 2 * sin(x[0]) + 1; +let R_scalar_non_linear3_1d: R^1 -> R, x -> x[0] * sin(x[0]); + +let R1_non_linear1_1d: R^1 -> R^1, x -> 2 * exp(x[0]); +let R1_non_linear2_1d: R^1 -> R^1, x -> 2 * exp(x[0])*x[0]; + +let R2_non_linear_1d: R^1 -> R^2, x -> [2 * exp(x[0]), -3*x[0]]; + +let R3_non_linear_1d: R^1 -> R^3, x -> [2 * exp(x[0]) + 3, x[0] - 2, 3]; + +let R1x1_non_linear_1d: R^1 -> R^1x1, x -> (2 * exp(x[0]) * sin(x[0]) + 3); + +let R2x2_non_linear_1d: R^1 -> R^2x2, x -> [[2 * exp(x[0]) * sin(x[0]) + 3, sin(x[0] - 2 * x[0])], [3, x[0] * x[0]]]; + +let R3x3_non_linear_1d: R^1 -> R^3x3, x -> [[2 * exp(x[0]) * sin(x[0]) + 3, sin(x[0] - 2 * x[0]), 3], [x[0] * x[0], -4*x[0], 2*x[0]+1], [3, -6*x[0], exp(x[0])]]; +)"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + SECTION("B_scalar_non_linear_1d") + { + auto [i_symbol_f1, found_f1] = symbol_table->find("B_scalar_non_linear1_1d", position); + REQUIRE(found_f1); + REQUIRE(i_symbol_f1->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function1_symbol_id(std::get<uint64_t>(i_symbol_f1->attributes().value()), symbol_table); + + auto [i_symbol_f2, found_f2] = symbol_table->find("B_scalar_non_linear2_1d", position); + REQUIRE(found_f2); + REQUIRE(i_symbol_f2->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function2_symbol_id(std::get<uint64_t>(i_symbol_f2->attributes().value()), symbol_table); + + CellArray<bool> cell_array{mesh_1d->connectivity(), 2}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + cell_array[cell_id][0] = std::exp(2 * x[0]) + 3 > 4; + cell_array[cell_id][1] = std::exp(2 * x[0]) + 3 < 4; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, + {function1_symbol_id, function2_symbol_id}); + std::shared_ptr item_data_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(cell_array, item_data_variant->get<CellArray<bool>>())); + } + + SECTION("N_scalar_non_linear_1d") + { + auto [i_symbol_f1, found_f1] = symbol_table->find("N_scalar_non_linear1_1d", position); + REQUIRE(found_f1); + REQUIRE(i_symbol_f1->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function1_symbol_id(std::get<uint64_t>(i_symbol_f1->attributes().value()), symbol_table); + + auto [i_symbol_f2, found_f2] = symbol_table->find("N_scalar_non_linear2_1d", position); + REQUIRE(found_f2); + REQUIRE(i_symbol_f2->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function2_symbol_id(std::get<uint64_t>(i_symbol_f2->attributes().value()), symbol_table); + + NodeArray<uint64_t> node_array{mesh_1d->connectivity(), 2}; + parallel_for( + node_array.numberOfItems(), PUGS_LAMBDA(const NodeId node_id) { + const TinyVector<Dimension>& x = xr[node_id]; + node_array[node_id][0] = std::floor(3 * x[0] * x[0] + 2); + node_array[node_id][1] = std::floor(2 * x[0] * x[0]); + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::node, + {function1_symbol_id, function2_symbol_id}); + std::shared_ptr item_data_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(node_array, item_data_variant->get<NodeArray<uint64_t>>())); + } + + SECTION("Z_scalar_non_linear_1d") + { + auto [i_symbol_f1, found_f1] = symbol_table->find("Z_scalar_non_linear1_1d", position); + REQUIRE(found_f1); + REQUIRE(i_symbol_f1->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function1_symbol_id(std::get<uint64_t>(i_symbol_f1->attributes().value()), symbol_table); + + auto [i_symbol_f2, found_f2] = symbol_table->find("Z_scalar_non_linear2_1d", position); + REQUIRE(found_f2); + REQUIRE(i_symbol_f2->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function2_symbol_id(std::get<uint64_t>(i_symbol_f2->attributes().value()), symbol_table); + + CellArray<int64_t> cell_array{mesh_1d->connectivity(), 2}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + cell_array[cell_id][0] = std::floor(std::exp(2 * x[0]) - 1); + cell_array[cell_id][1] = std::floor(cos(2 * x[0]) + 0.5); + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, + {function1_symbol_id, function2_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(cell_array, item_array_variant->get<CellArray<int64_t>>())); + } + + SECTION("R_scalar_non_linear_1d") + { + auto [i_symbol_f1, found_f1] = symbol_table->find("R_scalar_non_linear1_1d", position); + REQUIRE(found_f1); + REQUIRE(i_symbol_f1->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function1_symbol_id(std::get<uint64_t>(i_symbol_f1->attributes().value()), symbol_table); + + auto [i_symbol_f2, found_f2] = symbol_table->find("R_scalar_non_linear2_1d", position); + REQUIRE(found_f2); + REQUIRE(i_symbol_f2->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function2_symbol_id(std::get<uint64_t>(i_symbol_f2->attributes().value()), symbol_table); + + auto [i_symbol_f3, found_f3] = symbol_table->find("R_scalar_non_linear3_1d", position); + REQUIRE(found_f3); + REQUIRE(i_symbol_f3->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function3_symbol_id(std::get<uint64_t>(i_symbol_f3->attributes().value()), symbol_table); + + CellArray<double> cell_array{mesh_1d->connectivity(), 3}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + + cell_array[cell_id][0] = 2 * std::exp(x[0]) + 3; + cell_array[cell_id][1] = 2 * std::sin(x[0]) + 1; + cell_array[cell_id][2] = x[0] * std::sin(x[0]); + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, + {function1_symbol_id, function2_symbol_id, + function3_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(cell_array, item_array_variant->get<CellArray<double>>())); + } + + SECTION("R1_non_linear_1d") + { + using DataType = TinyVector<1>; + + auto [i_symbol_f1, found_f1] = symbol_table->find("R1_non_linear1_1d", position); + REQUIRE(found_f1); + REQUIRE(i_symbol_f1->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function1_symbol_id(std::get<uint64_t>(i_symbol_f1->attributes().value()), symbol_table); + + auto [i_symbol_f2, found_f2] = symbol_table->find("R1_non_linear2_1d", position); + REQUIRE(found_f2); + REQUIRE(i_symbol_f2->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function2_symbol_id(std::get<uint64_t>(i_symbol_f2->attributes().value()), symbol_table); + + CellArray<DataType> cell_array{mesh_1d->connectivity(), 2}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + + cell_array[cell_id][0] = DataType{2 * std::exp(x[0])}; + cell_array[cell_id][1] = DataType{2 * std::exp(x[0]) * x[0]}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, + {function1_symbol_id, function2_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(cell_array, item_array_variant->get<CellArray<DataType>>())); + } + + SECTION("R2_non_linear_1d") + { + using DataType = TinyVector<2>; + + auto [i_symbol, found] = symbol_table->find("R2_non_linear_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + CellArray<DataType> cell_array{mesh_1d->connectivity(), 1}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + + cell_array[cell_id][0] = DataType{2 * std::exp(x[0]), -3 * x[0]}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, {function_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(cell_array, item_array_variant->get<CellArray<DataType>>())); + } + + SECTION("R3_non_linear_1d") + { + using DataType = TinyVector<3>; + + auto [i_symbol, found] = symbol_table->find("R3_non_linear_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + CellArray<DataType> cell_array{mesh_1d->connectivity(), 1}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + cell_array[cell_id][0] = DataType{2 * std::exp(x[0]) + 3, x[0] - 2, 3}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, {function_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(cell_array, item_array_variant->get<CellArray<DataType>>())); + } + + SECTION("R1x1_non_linear_1d") + { + using DataType = TinyMatrix<1>; + + auto [i_symbol, found] = symbol_table->find("R1x1_non_linear_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + CellArray<DataType> cell_array{mesh_1d->connectivity(), 1}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + cell_array[cell_id][0] = DataType{2 * std::exp(x[0]) * std::sin(x[0]) + 3}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, {function_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(cell_array, item_array_variant->get<CellArray<DataType>>())); + } + + SECTION("R2x2_non_linear_1d") + { + using DataType = TinyMatrix<2>; + + auto [i_symbol, found] = symbol_table->find("R2x2_non_linear_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + CellArray<DataType> cell_array{mesh_1d->connectivity(), 1}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + cell_array[cell_id][0] = + DataType{2 * std::exp(x[0]) * std::sin(x[0]) + 3, std::sin(x[0] - 2 * x[0]), 3, x[0] * x[0]}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, {function_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(cell_array, item_array_variant->get<CellArray<DataType>>())); + } + + SECTION("R3x3_non_linear_1d") + { + using DataType = TinyMatrix<3>; + + auto [i_symbol, found] = symbol_table->find("R3x3_non_linear_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + CellArray<DataType> cell_array{mesh_1d->connectivity(), 1}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + + cell_array[cell_id][0] = DataType{2 * exp(x[0]) * std::sin(x[0]) + 3, + std::sin(x[0] - 2 * x[0]), + 3, + x[0] * x[0], + -4 * x[0], + 2 * x[0] + 1, + 3, + -6 * x[0], + std::exp(x[0])}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, {function_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(cell_array, item_array_variant->get<CellArray<DataType>>())); + } + } + } + } + + SECTION("2D") + { + constexpr size_t Dimension = 2; + + std::array mesh_list = MeshDataBaseForTests::get().all2DMeshes(); + + for (const auto& named_mesh : mesh_list) { + SECTION(named_mesh.name()) + { + auto mesh_2d = named_mesh.mesh(); + + auto xl = MeshDataManager::instance().getMeshData(*mesh_2d).xl(); + + std::string_view data = R"( +import math; +let B_scalar_non_linear1_2d: R^2 -> B, x -> (exp(2 * x[0])< 2*x[1]); +let B_scalar_non_linear2_2d: R^2 -> B, x -> (sin(2 * x[0])< x[1]); + +let R2_non_linear_2d: R^2 -> R^2, x -> [2 * exp(x[0]), -3*x[1]]; + +let R3x3_non_linear_2d: R^2 -> R^3x3, x -> [[2 * exp(x[0]) * sin(x[1]) + 3, sin(x[1] - 2 * x[0]), 3], + [x[1] * x[0], -4*x[1], 2*x[0]+1], + [3, -6*x[0], exp(x[1])]]; + )"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + SECTION("B_scalar_non_linear_2d") + { + using DataType = bool; + + auto [i_symbol_f1, found_f1] = symbol_table->find("B_scalar_non_linear1_2d", position); + REQUIRE(found_f1); + REQUIRE(i_symbol_f1->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function1_symbol_id(std::get<uint64_t>(i_symbol_f1->attributes().value()), symbol_table); + + auto [i_symbol_f2, found_f2] = symbol_table->find("B_scalar_non_linear2_2d", position); + REQUIRE(found_f2); + REQUIRE(i_symbol_f2->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function2_symbol_id(std::get<uint64_t>(i_symbol_f2->attributes().value()), symbol_table); + + FaceArray<DataType> face_array{mesh_2d->connectivity(), 2}; + parallel_for( + face_array.numberOfItems(), PUGS_LAMBDA(const FaceId face_id) { + const TinyVector<Dimension>& x = xl[face_id]; + face_array[face_id][0] = std::exp(2 * x[0]) < 2 * x[1]; + face_array[face_id][1] = std::sin(2 * x[0]) < x[1]; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_2d, ItemType::face, + {function1_symbol_id, function2_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(face_array, item_array_variant->get<FaceArray<DataType>>())); + } + + SECTION("R2_non_linear_2d") + { + using DataType = TinyVector<2>; + + auto [i_symbol, found] = symbol_table->find("R2_non_linear_2d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + FaceArray<DataType> face_array{mesh_2d->connectivity(), 1}; + parallel_for( + face_array.numberOfItems(), PUGS_LAMBDA(const FaceId face_id) { + const TinyVector<Dimension>& x = xl[face_id]; + + face_array[face_id][0] = DataType{2 * std::exp(x[0]), -3 * x[1]}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_2d, ItemType::face, {function_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(face_array, item_array_variant->get<FaceArray<DataType>>())); + } + + SECTION("R3x3_non_linear_2d") + { + using DataType = TinyMatrix<3>; + + auto [i_symbol, found] = symbol_table->find("R3x3_non_linear_2d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + FaceArray<DataType> face_array{mesh_2d->connectivity(), 1}; + parallel_for( + face_array.numberOfItems(), PUGS_LAMBDA(const FaceId face_id) { + const TinyVector<Dimension>& x = xl[face_id]; + + face_array[face_id][0] = DataType{2 * std::exp(x[0]) * std::sin(x[1]) + 3, + std::sin(x[1] - 2 * x[0]), + 3, + x[1] * x[0], + -4 * x[1], + 2 * x[0] + 1, + 3, + -6 * x[0], + std::exp(x[1])}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_2d, ItemType::face, {function_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(face_array, item_array_variant->get<FaceArray<DataType>>())); + } + } + } + } + + SECTION("3D") + { + constexpr size_t Dimension = 3; + + std::array mesh_list = MeshDataBaseForTests::get().all3DMeshes(); + + for (const auto& named_mesh : mesh_list) { + SECTION(named_mesh.name()) + { + auto mesh_3d = named_mesh.mesh(); + + auto xe = MeshDataManager::instance().getMeshData(*mesh_3d).xe(); + + std::string_view data = R"( +import math; +let R_scalar_non_linear1_3d: R^3 -> R, x -> 2 * exp(x[0]+x[2]) + 3 * x[1]; +let R_scalar_non_linear2_3d: R^3 -> R, x -> 3 * sin(x[0]+x[2]) + 2 * x[1]; + +let R3_non_linear_3d: R^3 -> R^3, x -> [2 * exp(x[0]) + 3, x[1] - 2, 3 * x[2]]; +let R2x2_non_linear_3d: R^3 -> R^2x2, + x -> [[2 * exp(x[0]) * sin(x[1]) + 3, sin(x[2] - 2 * x[0])], + [3, x[1] * x[0] - x[2]]]; + )"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + SECTION("R_scalar_non_linear_3d") + { + auto [i_symbol_f1, found_f1] = symbol_table->find("R_scalar_non_linear1_3d", position); + REQUIRE(found_f1); + REQUIRE(i_symbol_f1->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function1_symbol_id(std::get<uint64_t>(i_symbol_f1->attributes().value()), symbol_table); + + auto [i_symbol_f2, found_f2] = symbol_table->find("R_scalar_non_linear2_3d", position); + REQUIRE(found_f2); + REQUIRE(i_symbol_f2->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function2_symbol_id(std::get<uint64_t>(i_symbol_f2->attributes().value()), symbol_table); + + EdgeArray<double> edge_array{mesh_3d->connectivity(), 2}; + parallel_for( + edge_array.numberOfItems(), PUGS_LAMBDA(const EdgeId edge_id) { + const TinyVector<Dimension>& x = xe[edge_id]; + + edge_array[edge_id][0] = 2 * std::exp(x[0] + x[2]) + 3 * x[1]; + edge_array[edge_id][1] = 3 * std::sin(x[0] + x[2]) + 2 * x[1]; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_3d, ItemType::edge, + {function1_symbol_id, function2_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(edge_array, item_array_variant->get<EdgeArray<double>>())); + } + + SECTION("R3_non_linear_3d") + { + using DataType = TinyVector<3>; + + auto [i_symbol, found] = symbol_table->find("R3_non_linear_3d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + EdgeArray<DataType> edge_array{mesh_3d->connectivity(), 1}; + parallel_for( + edge_array.numberOfItems(), PUGS_LAMBDA(const EdgeId edge_id) { + const TinyVector<Dimension>& x = xe[edge_id]; + + edge_array[edge_id][0] = DataType{2 * std::exp(x[0]) + 3, x[1] - 2, 3 * x[2]}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_3d, ItemType::edge, {function_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(edge_array, item_array_variant->get<EdgeArray<DataType>>())); + } + + SECTION("R2x2_non_linear_3d") + { + using DataType = TinyMatrix<2>; + + auto [i_symbol, found] = symbol_table->find("R2x2_non_linear_3d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + EdgeArray<DataType> edge_array{mesh_3d->connectivity(), 1}; + parallel_for( + edge_array.numberOfItems(), PUGS_LAMBDA(const EdgeId edge_id) { + const TinyVector<Dimension>& x = xe[edge_id]; + edge_array[edge_id][0] = + DataType{2 * std::exp(x[0]) * std::sin(x[1]) + 3, std::sin(x[2] - 2 * x[0]), 3, x[1] * x[0] - x[2]}; + }); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_3d, ItemType::edge, {function_symbol_id}); + std::shared_ptr item_array_variant = interpoler.interpolate(); + + REQUIRE(same_item_array(edge_array, item_array_variant->get<EdgeArray<DataType>>())); + } + } + } + } +}