diff --git a/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp b/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp index 02d47d399ed9ec0fdb570fc3877b0b384f7a5429..48b51efba5aa8a9a9905247629e9e398a70ab17a 100644 --- a/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp +++ b/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp @@ -55,7 +55,8 @@ ItemArrayVariantFunctionInterpoler::_interpolate() const { const ASTNodeDataType data_type = [&] { const auto& function0_descriptor = m_function_id_list[0].descriptor(); - Assert(function0_descriptor.domainMappingNode().children[1]->m_data_type == ASTNodeDataType::typename_t); + Assert(function0_descriptor.domainMappingNode().children[1]->m_data_type == ASTNodeDataType::typename_t or + function0_descriptor.domainMappingNode().children[1]->m_data_type == ASTNodeDataType::tuple_t); ASTNodeDataType data_type = function0_descriptor.domainMappingNode().children[1]->m_data_type.contentType(); diff --git a/src/language/utils/PugsFunctionAdapter.hpp b/src/language/utils/PugsFunctionAdapter.hpp index 1a4beb32e2889467a4685c79f6afa25e1aef4b68..a4d1f76bd45285294008bfbaa24b6172a2b0b774 100644 --- a/src/language/utils/PugsFunctionAdapter.hpp +++ b/src/language/utils/PugsFunctionAdapter.hpp @@ -308,7 +308,9 @@ class PugsFunctionAdapter<OutputType(InputType...)> if constexpr (std::is_arithmetic_v<Value_I_Type>) { return value_i; } else { + // LCOV_EXCL_START throw UnexpectedError("expecting arithmetic type"); + // LCOV_EXCL_STOP } }, value[i]); @@ -316,13 +318,59 @@ class PugsFunctionAdapter<OutputType(InputType...)> return array; } else { + // LCOV_EXCL_START throw UnexpectedError("invalid DataVariant"); + // LCOV_EXCL_STOP + } + }, + result); + }; + } else if constexpr (is_tiny_vector_v<OutputType> or (is_tiny_matrix_v<OutputType>)) { + return [&](DataVariant&& result) -> std::vector<OutputType> { + return std::visit( + [&](auto&& value) -> std::vector<OutputType> { + using ValueType = std::decay_t<decltype(value)>; + if constexpr (std::is_same_v<ValueType, AggregateDataVariant>) { + std::vector<OutputType> array(value.size()); + + for (size_t i = 0; i < value.size(); ++i) { + array[i] = std::visit( + [&](auto&& value_i) -> OutputType { + using Value_I_Type = std::decay_t<decltype(value_i)>; + if constexpr (std::is_same_v<Value_I_Type, OutputType>) { + return value_i; + } else if constexpr (OutputType::Dimension == 1) { + if constexpr (std::is_arithmetic_v<Value_I_Type>) { + return OutputType(value_i); + } else { + // LCOV_EXCL_START + throw UnexpectedError("expecting arithmetic type"); + // LCOV_EXCL_STOP + } + } else if constexpr (std::is_integral_v<Value_I_Type>) { + // reaching this point, it must be a null vector + // or a null matrix + return OutputType{zero}; + } else { + // LCOV_EXCL_START + throw UnexpectedError("expecting arithmetic type"); + // LCOV_EXCL_STOP + } + }, + value[i]); + } + + return array; + } else { + // LCOV_EXCL_START + throw UnexpectedError("invalid DataVariant"); + // LCOV_EXCL_STOP } }, result); }; } else { - throw NotImplementedError("non-arithmetic tuple type"); + throw UnexpectedError("non-arithmetic tuple type"); } } diff --git a/tests/test_InterpolateItemArray.cpp b/tests/test_InterpolateItemArray.cpp index 878503f1156e167796ebd8d3dec4019bc76e59f7..6d020282ad0d0a953ca7e815357d5461f5dd3a76 100644 --- a/tests/test_InterpolateItemArray.cpp +++ b/tests/test_InterpolateItemArray.cpp @@ -175,6 +175,67 @@ let f_1d: R^1 -> (R), x -> (2*x[0] + 2, 2 * exp(x[0]) + 3); } } } + + SECTION("from -> (R^1)") + { + for (const auto& named_mesh : mesh_list) { + SECTION(named_mesh.name()) + { + auto mesh_1d = named_mesh.mesh(); + + auto xj = MeshDataManager::instance().getMeshData(*mesh_1d).xj(); + + std::string_view data = R"( +import math; +let f_1d: R^1 -> (R^1), x -> (2*x[0] + 2, [2 * exp(x[0]) + 3]); +)"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + std::vector<FunctionSymbolId> function_symbol_id_list; + + { + auto [i_symbol, found] = symbol_table->find("f_1d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + using R1 = TinyVector<1>; + CellArray<R1> cell_array{mesh_1d->connectivity(), 2}; + parallel_for( + cell_array.numberOfItems(), PUGS_LAMBDA(const CellId cell_id) { + const TinyVector<Dimension>& x = xj[cell_id]; + + cell_array[cell_id][0] = R1{2 * x[0] + 2}; + cell_array[cell_id][1] = R1{2 * exp(x[0]) + 3}; + }); + + CellArray<const R1> interpolate_array = + InterpolateItemArray<R1(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj); + + REQUIRE(same_cell_array(cell_array, interpolate_array)); + } + } + } } SECTION("2D") @@ -450,7 +511,7 @@ let f_3d: R^3 -> (R), x -> (2 * x[0] + 3 * x[1] + 2 * x[2] - 1, 2 * exp(x[0]) * SECTION("interpolate on items list") { - auto same_cell_value = [](auto interpolated, auto reference) -> bool { + auto same_cell_array = [](auto interpolated, auto reference) -> bool { for (size_t i = 0; i < interpolated.numberOfRows(); ++i) { for (size_t j = 0; j < interpolated.numberOfColumns(); ++j) { if (interpolated[i][j] != reference[i][j]) { @@ -540,7 +601,7 @@ let scalar_non_linear_1d: R^1 -> R, x -> 2 * exp(x[0]) + 3; InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); - REQUIRE(same_cell_value(cell_array, interpolate_array)); + REQUIRE(same_cell_array(cell_array, interpolate_array)); } } } @@ -608,7 +669,7 @@ let f_1d: R^1 -> (R), x -> (2*x[0] + 2, 2 * exp(x[0]) + 3); InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); - REQUIRE(same_cell_value(cell_array, interpolate_array)); + REQUIRE(same_cell_array(cell_array, interpolate_array)); } } } @@ -690,7 +751,7 @@ let scalar_non_linear_2d: R^2 -> R, x -> 2*exp(x[0])*sin(x[1])+3; InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); - REQUIRE(same_cell_value(cell_array, interpolate_array)); + REQUIRE(same_cell_array(cell_array, interpolate_array)); } } } @@ -755,7 +816,77 @@ let f_2d: R^2 -> (R), x -> (2*x[0] + 3*x[1] + 2, 2*exp(x[0])*sin(x[1])+3); InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); - REQUIRE(same_cell_value(cell_array, interpolate_array)); + REQUIRE(same_cell_array(cell_array, interpolate_array)); + } + } + } + + SECTION("from -> (R^2x2)") + { + for (const auto& named_mesh : mesh_list) { + SECTION(named_mesh.name()) + { + auto mesh_2d = named_mesh.mesh(); + + auto xj = MeshDataManager::instance().getMeshData(*mesh_2d).xj(); + + Array<CellId> cell_id_list{mesh_2d->numberOfCells() / 2}; + for (size_t i_cell = 0; i_cell < cell_id_list.size(); ++i_cell) { + cell_id_list[i_cell] = static_cast<CellId>(2 * i_cell); + } + + std::string_view data = R"( +import math; +let f_2d: R^2 -> (R^2x2), x -> ([[x[0],0],[2-x[1], x[0]*x[1]]], [[2*x[0], x[1]],[2, -x[1]]]); +)"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + std::vector<FunctionSymbolId> function_symbol_id_list; + + { + auto [i_symbol, found] = symbol_table->find("f_2d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + using R2x2 = TinyMatrix<2>; + + Table<R2x2> cell_array{cell_id_list.size(), 2}; + parallel_for( + cell_id_list.size(), PUGS_LAMBDA(const size_t i) { + const TinyVector<Dimension>& x = xj[cell_id_list[i]]; + + cell_array[i][0] = R2x2{x[0], 0, // + 2 - x[1], x[0] * x[1]}; + + cell_array[i][1] = R2x2{2 * x[0], x[1], // + 2, -x[1]}; + }); + + Table<const R2x2> interpolate_array = + InterpolateItemArray<R2x2(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); + + REQUIRE(same_cell_array(cell_array, interpolate_array)); } } } @@ -837,7 +968,7 @@ let scalar_non_linear_3d: R^3 -> R, x -> 2 * exp(x[0]) * sin(x[1]) * x[2] + 3; InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); - REQUIRE(same_cell_value(cell_array, interpolate_array)); + REQUIRE(same_cell_array(cell_array, interpolate_array)); } } } @@ -902,7 +1033,74 @@ let f_3d: R^3 -> (R), x -> (2 * x[0] + 3 * x[1] + 2 * x[2] - 1, 2 * exp(x[0]) * InterpolateItemArray<double(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); - REQUIRE(same_cell_value(cell_array, interpolate_array)); + REQUIRE(same_cell_array(cell_array, interpolate_array)); + } + } + } + + SECTION("from -> (R^3)") + { + for (const auto& named_mesh : mesh_list) { + SECTION(named_mesh.name()) + { + auto mesh_3d = named_mesh.mesh(); + + auto xj = MeshDataManager::instance().getMeshData(*mesh_3d).xj(); + + Array<CellId> cell_id_list{mesh_3d->numberOfCells() / 2}; + for (size_t i_cell = 0; i_cell < cell_id_list.size(); ++i_cell) { + cell_id_list[i_cell] = static_cast<CellId>(2 * i_cell); + } + + std::string_view data = R"( +import math; +let f_3d: R^3 -> (R^3), x -> (2*x, [2*x[0]-x[1], 3*x[2]-x[0], x[1]+x[2]], 0); +)"; + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + TAO_PEGTL_NAMESPACE::position position{TAO_PEGTL_NAMESPACE::internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this point + + std::vector<FunctionSymbolId> function_symbol_id_list; + + { + auto [i_symbol, found] = symbol_table->find("f_3d", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + function_symbol_id_list.push_back( + FunctionSymbolId(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table)); + } + + using R3 = TinyVector<3>; + Table<R3> cell_array{cell_id_list.size(), 3}; + parallel_for( + cell_id_list.size(), PUGS_LAMBDA(const size_t i) { + const R3& x = xj[cell_id_list[i]]; + + cell_array[i][0] = 2 * x; + cell_array[i][1] = R3{2 * x[0] - x[1], 3 * x[2] - x[0], x[1] + x[2]}; + cell_array[i][2] = zero; + }); + + Table<const R3> interpolate_array = + InterpolateItemArray<R3(TinyVector<Dimension>)>::interpolate(function_symbol_id_list, xj, cell_id_list); + + REQUIRE(same_cell_array(cell_array, interpolate_array)); } } } diff --git a/tests/test_ItemArrayVariantFunctionInterpoler.cpp b/tests/test_ItemArrayVariantFunctionInterpoler.cpp index 640e3e77aa5e6ede05fe0f217d7587f7839e4d56..7ceaf2207ef53526f19511aa31415ce8efcd577a 100644 --- a/tests/test_ItemArrayVariantFunctionInterpoler.cpp +++ b/tests/test_ItemArrayVariantFunctionInterpoler.cpp @@ -386,6 +386,33 @@ let R3x3_non_linear_1d: R^1 -> R^3x3, x -> [[2 * exp(x[0]) * sin(x[0]) + 3, sin( REQUIRE(same_item_array(cell_array, item_array_variant->get<CellArray<DataType>>())); } + + SECTION("different types") + { + auto [i_symbol_f1, found_f1] = symbol_table->find("R_scalar_non_linear1_1d", position); + REQUIRE(found_f1); + REQUIRE(i_symbol_f1->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function1_symbol_id(std::get<uint64_t>(i_symbol_f1->attributes().value()), symbol_table); + + auto [i_symbol_f2, found_f2] = symbol_table->find("Z_scalar_non_linear2_1d", position); + REQUIRE(found_f2); + REQUIRE(i_symbol_f2->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function2_symbol_id(std::get<uint64_t>(i_symbol_f2->attributes().value()), symbol_table); + + auto [i_symbol_f3, found_f3] = symbol_table->find("R_scalar_non_linear3_1d", position); + REQUIRE(found_f3); + REQUIRE(i_symbol_f3->attributes().dataType() == ASTNodeDataType::function_t); + + FunctionSymbolId function3_symbol_id(std::get<uint64_t>(i_symbol_f3->attributes().value()), symbol_table); + + ItemArrayVariantFunctionInterpoler interpoler(mesh_1d, ItemType::cell, + {function1_symbol_id, function2_symbol_id, + function3_symbol_id}); + + REQUIRE_THROWS_WITH(interpoler.interpolate(), "error: functions must have the same type"); + } } } }