diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index 2e56fe228bc4b5d0ae863ae177b3633efbfa6e90..31cde714627687bb0aee8f62ff122cf2e5ae0a1b 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -310,6 +310,64 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<const double>>); } + SECTION("(R^2) -> (R) BuiltinFunctionEmbedder") + { + std::function c = [](const std::vector<TinyVector<2>>& x) -> std::vector<double> { + std::vector<double> sum(x.size()); + for (size_t i = 0; i < x.size(); ++i) { + sum[i] = x[i][0] + x[i][1]; + } + return sum; + }; + + std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = + std::make_unique<BuiltinFunctionEmbedder<std::vector<double>(const std::vector<TinyVector<2>>&)>>(c); + + REQUIRE(i_embedded_c->numberOfParameters() == 1); + REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1); + REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t); + REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::tuple_t); + + std::vector<TinyVector<2>> x = {TinyVector<2>{1, 2}, TinyVector<2>{3, 4}, TinyVector<2>{-2, 4}}; + std::vector<double> result = c(x); + + const std::vector value = std::get<std::vector<double>>(i_embedded_c->apply(std::vector<DataVariant>{x})); + + for (size_t i = 0; i < result.size(); ++i) { + REQUIRE(value[i] == result[i]); + } + } + + SECTION("std::vector<EmbeddedData>(EmbeddedData, EmbeddedData) BuiltinFunctionEmbedder") + { + std::function sum = [&](const uint64_t& i, const uint64_t& j) -> std::vector<std::shared_ptr<const uint64_t>> { + std::vector<std::shared_ptr<const uint64_t>> x = {std::make_shared<const uint64_t>(i), + std::make_shared<const uint64_t>(j)}; + return x; + }; + + std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique< + BuiltinFunctionEmbedder<std::vector<std::shared_ptr<const uint64_t>>(const uint64_t&, const uint64_t&)>>(sum); + + REQUIRE(i_embedded_c->numberOfParameters() == 2); + REQUIRE(i_embedded_c->getParameterDataTypes().size() == 2); + REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::unsigned_int_t); + REQUIRE(i_embedded_c->getParameterDataTypes()[1] == ASTNodeDataType::unsigned_int_t); + REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::tuple_t); + + std::vector<uint64_t> values{3, 4}; + + std::vector<EmbeddedData> embedded_data_list = + std::get<std::vector<EmbeddedData>>(i_embedded_c->apply(std::vector<DataVariant>{values[0], values[1]})); + + for (size_t i = 0; i < embedded_data_list.size(); ++i) { + const EmbeddedData& embedded_data = embedded_data_list[i]; + const IDataHandler& handled_data = embedded_data.get(); + const DataHandler<const uint64_t>& data = dynamic_cast<const DataHandler<const uint64_t>&>(handled_data); + REQUIRE(*data.data_ptr() == values[i]); + } + } + SECTION("void -> N*R*shared_const_double BuiltinFunctionEmbedder") { std::function c = [](void) -> std::tuple<uint64_t, double, std::shared_ptr<const double>> {