diff --git a/src/language/node_processor/FunctionArgumentConverter.hpp b/src/language/node_processor/FunctionArgumentConverter.hpp index 4bf7e56eebfd93b244d13a774e2e69aac2834794..9e7eec757fb2c396125d6d299f5a7529ca420009 100644 --- a/src/language/node_processor/FunctionArgumentConverter.hpp +++ b/src/language/node_processor/FunctionArgumentConverter.hpp @@ -31,9 +31,20 @@ class FunctionArgumentToStringConverter final : public IFunctionArgumentConverte DataVariant convert(ExecutionPolicy& exec_policy, DataVariant&& value) { - std::ostringstream sout; - sout << value; - exec_policy.currentContext()[m_argument_id] = sout.str(); + std::visit( + [&](auto&& v) { + using T = std::decay_t<decltype(v)>; + if constexpr (std::is_arithmetic_v<T>) { + exec_policy.currentContext()[m_argument_id] = std::to_string(v); + } else if constexpr (std::is_same_v<T, std::string>) { + exec_policy.currentContext()[m_argument_id] = v; + } else { + std::ostringstream sout; + sout << value << std::ends; + exec_policy.currentContext()[m_argument_id] = sout.str(); + } + }, + value); return {}; } @@ -100,6 +111,7 @@ class FunctionTinyVectorArgumentConverter final : public IFunctionArgumentConver } else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) { exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero}; } else { + static_assert(std::is_same_v<ExpectedValueType, TinyVector<1>>); exec_policy.currentContext()[m_argument_id] = std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); } @@ -132,19 +144,36 @@ class FunctionTupleArgumentConverter final : public IFunctionArgumentConverter TupleType list_value; list_value.reserve(v.size()); for (size_t i = 0; i < v.size(); ++i) { - list_value.emplace_back(v[i]); + list_value.emplace_back(std::move(v[i])); } exec_policy.currentContext()[m_argument_id] = std::move(list_value); + } else if constexpr ((std::is_convertible_v<ContentT, ContentType>)and not is_tiny_vector_v<ContentType>) { + TupleType list_value; + list_value.reserve(v.size()); + for (size_t i = 0; i < v.size(); ++i) { + list_value.push_back(static_cast<ContentType>(v[i])); + } + exec_policy.currentContext()[m_argument_id] = std::move(list_value); + } else { + // LCOV_EXCL_START + throw UnexpectedError(std::string{"cannot convert '"} + demangle<ValueT>() + "' to '" + + demangle<ContentType>() + "'"); + // LCOV_EXCL_STOP } } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ContentType>) { exec_policy.currentContext()[m_argument_id] = std::move(TupleType{static_cast<ContentType>(v)}); } else { - throw UnexpectedError(demangle<ValueT>() + " unexpected value type"); + throw UnexpectedError(std::string{"cannot convert '"} + demangle<ValueT>() + "' to '" + + demangle<ContentType>() + "'"); } }, value); + } else { - throw UnexpectedError(demangle<std::decay_t<decltype(*this)>>() + ": did nothing!"); + // LCOV_EXCL_START + throw UnexpectedError(std::string{"cannot convert '"} + demangle<ProvidedValueType>() + "' to '" + + demangle<ContentType>() + "'"); + // LCOV_EXCL_STOP } return {}; } @@ -174,31 +203,43 @@ class FunctionListArgumentConverter final : public IFunctionArgumentConverter std::visit( [&](auto&& vi) { using Vi_T = std::decay_t<decltype(vi)>; - if constexpr (is_tiny_vector_v<ContentType>) { - throw NotImplementedError("TinyVector case"); + if constexpr (std::is_same_v<Vi_T, ContentType>) { + list_value.emplace_back(vi); + } else if constexpr (is_tiny_vector_v<ContentType>) { + // LCOV_EXCL_START + throw UnexpectedError(std::string{"invalid conversion of '"} + demangle<Vi_T>() + "' to '" + + demangle<ContentType>() + "'"); + // LCOV_EXCL_STOP } else if constexpr (std::is_convertible_v<Vi_T, ContentType>) { list_value.emplace_back(vi); } else { + // LCOV_EXCL_START throw UnexpectedError("unexpected types"); + // LCOV_EXCL_STOP } }, (v[i])); } exec_policy.currentContext()[m_argument_id] = std::move(list_value); - } else if constexpr (std::is_same_v<ValueT, ContentType>) { - exec_policy.currentContext()[m_argument_id] = std::move(v); } else if constexpr (is_std_vector_v<ValueT>) { using ContentT = typename ValueT::value_type; if constexpr (std::is_same_v<ContentT, ContentType>) { - TupleType list_value; - list_value.reserve(v.size()); - for (size_t i = 0; i < v.size(); ++i) { - list_value.emplace_back(v[i]); - } - exec_policy.currentContext()[m_argument_id] = std::move(list_value); + exec_policy.currentContext()[m_argument_id] = v; + } else { + // LCOV_EXCL_START + throw UnexpectedError(std::string{"invalid conversion of '"} + demangle<ContentT>() + "' to '" + + demangle<ContentType>() + "'"); + // LCOV_EXCL_STOP } + } else if constexpr (std::is_same_v<ValueT, ContentType>) { + exec_policy.currentContext()[m_argument_id] = std::move(TupleType{v}); + } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ValueT> and + not is_tiny_vector_v<ContentType>) { + exec_policy.currentContext()[m_argument_id] = std::move(TupleType{static_cast<ContentType>(v)}); } else { + // LCOV_EXCL_START throw UnexpectedError(demangle<ValueT>() + " unexpected value type"); + // LCOV_EXCL_STOP } }, value); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6796d69e488ff3a1d01693c4a17bdb364ef99933..094efbd440a452bcc7a323299f39eb52dcd3046b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -58,6 +58,7 @@ add_executable (unit_tests test_ExecutionPolicy.cpp test_FakeProcessor.cpp test_ForProcessor.cpp + test_FunctionArgumentConverter.cpp test_FunctionProcessor.cpp test_FunctionSymbolId.cpp test_FunctionTable.cpp diff --git a/tests/test_FunctionArgumentConverter.cpp b/tests/test_FunctionArgumentConverter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3d2ee0939c306fb971274c782268b161125bee6 --- /dev/null +++ b/tests/test_FunctionArgumentConverter.cpp @@ -0,0 +1,160 @@ +#include <catch2/catch.hpp> + +#include <language/node_processor/FunctionArgumentConverter.hpp> +#include <language/utils/SymbolTable.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("FunctionArgumentConverter", "[language]") +{ + ExecutionPolicy::Context context(0, std::make_shared<ExecutionPolicy::Context::Values>(3)); + ExecutionPolicy execution_policy(ExecutionPolicy{}, context); + + SECTION("FunctionArgumentToStringConverter") + { + const std::string s{"foo"}; + FunctionArgumentToStringConverter converter0{0}; + converter0.convert(execution_policy, s); + + const TinyVector<3> X{1, 3.2, 4}; + FunctionArgumentToStringConverter converter1{1}; + converter1.convert(execution_policy, X); + std::ostringstream os_X; + os_X << X << std::ends; + + const double x = 3.2; + FunctionArgumentToStringConverter converter2{2}; + converter2.convert(execution_policy, x); + + REQUIRE(std::get<std::string>(execution_policy.currentContext()[0]) == s); + REQUIRE(std::get<std::string>(execution_policy.currentContext()[1]) == os_X.str()); + REQUIRE(std::get<std::string>(execution_policy.currentContext()[2]) == std::to_string(x)); + } + + SECTION("FunctionArgumentConverter") + { + const double double_value = 1.7; + FunctionArgumentConverter<double, double> converter0{0}; + converter0.convert(execution_policy, double{double_value}); + + const uint64_t uint64_value = 3; + FunctionArgumentConverter<double, uint64_t> converter1{1}; + converter1.convert(execution_policy, uint64_value); + + const bool bool_value = false; + FunctionArgumentConverter<uint64_t, bool> converter2{2}; + converter2.convert(execution_policy, bool_value); + + REQUIRE(std::get<double>(execution_policy.currentContext()[0]) == double_value); + REQUIRE(std::get<double>(execution_policy.currentContext()[1]) == static_cast<double>(uint64_value)); + REQUIRE(std::get<uint64_t>(execution_policy.currentContext()[2]) == static_cast<uint64_t>(bool_value)); + } + + SECTION("FunctionTinyVectorArgumentConverter") + { + const TinyVector<3> x3{1.7, 2.9, -3}; + FunctionTinyVectorArgumentConverter<TinyVector<3>, TinyVector<3>> converter0{0}; + converter0.convert(execution_policy, TinyVector{x3}); + + const double x1 = 6.3; + FunctionTinyVectorArgumentConverter<TinyVector<1>, double> converter1{1}; + converter1.convert(execution_policy, double{x1}); + + AggregateDataVariant values{std::vector<DataVariant>{6.3, 3.2, 4ul}}; + FunctionTinyVectorArgumentConverter<TinyVector<3>, TinyVector<3>> converter2{2}; + converter2.convert(execution_policy, values); + + REQUIRE(std::get<TinyVector<3>>(execution_policy.currentContext()[0]) == x3); + REQUIRE(std::get<TinyVector<1>>(execution_policy.currentContext()[1]) == TinyVector<1>{x1}); + REQUIRE(std::get<TinyVector<3>>(execution_policy.currentContext()[2]) == TinyVector<3>{6.3, 3.2, 4ul}); + + AggregateDataVariant bad_values{std::vector<DataVariant>{6.3, 3.2, std::string{"bar"}}}; + + REQUIRE_THROWS_WITH(converter2.convert(execution_policy, bad_values), std::string{"unexpected error: "} + + demangle<std::string>() + + " unexpected aggregate value type"); + } + + SECTION("FunctionTupleArgumentConverter") + { + const TinyVector<3> x3{1.7, 2.9, -3}; + FunctionTupleArgumentConverter<TinyVector<3>, TinyVector<3>> converter0{0}; + converter0.convert(execution_policy, TinyVector{x3}); + + const double a = 1.2; + const double b = -3.5; + const double c = 2.6; + FunctionTupleArgumentConverter<double, double> converter1{1}; + converter1.convert(execution_policy, std::vector{a, b, c}); + + const uint64_t i = 1; + const uint64_t j = 3; + const uint64_t k = 6; + FunctionTupleArgumentConverter<double, uint64_t> converter2{2}; + converter2.convert(execution_policy, std::vector<uint64_t>{i, j, k}); + + REQUIRE(std::get<std::vector<TinyVector<3>>>(execution_policy.currentContext()[0]) == + std::vector<TinyVector<3>>{x3}); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[1]) == std::vector<double>{a, b, c}); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[2]) == std::vector<double>{i, j, k}); + + converter1.convert(execution_policy, a); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[1]) == std::vector<double>{a}); + + converter1.convert(execution_policy, j); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[1]) == std::vector<double>{j}); + + // Errors + REQUIRE_THROWS_WITH(converter0.convert(execution_policy, j), + "unexpected error: cannot convert 'unsigned long' to 'TinyVector<3ul, double>'"); + } + + SECTION("FunctionListArgumentConverter") + { + const uint64_t i = 3; + FunctionListArgumentConverter<double, double> converter0{0}; + converter0.convert(execution_policy, i); + + const double a = 6.3; + const double b = -1.3; + const double c = 3.6; + FunctionListArgumentConverter<double, double> converter1{1}; + converter1.convert(execution_policy, std::vector<double>{a, b, c}); + + AggregateDataVariant v{std::vector<DataVariant>{1ul, 2.3, -3l}}; + FunctionListArgumentConverter<double, double> converter2{2}; + converter2.convert(execution_policy, v); + + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[0]) == std::vector<double>{i}); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[1]) == std::vector<double>{a, b, c}); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[2]) == std::vector<double>{1ul, 2.3, -3l}); + + FunctionListArgumentConverter<TinyVector<2>, TinyVector<2>> converterR2_0{0}; + converterR2_0.convert(execution_policy, TinyVector<2>{1, 3.2}); + + FunctionListArgumentConverter<TinyVector<2>, TinyVector<2>> converterR2_1{1}; + converterR2_1.convert(execution_policy, std::vector{TinyVector<2>{1, 3.2}, TinyVector<2>{-1, 0.2}}); + + AggregateDataVariant v_R2{std::vector<DataVariant>{TinyVector<2>{-3, 12.2}, TinyVector<2>{2, 1.2}}}; + FunctionListArgumentConverter<TinyVector<2>, TinyVector<2>> converterR2_2{2}; + converterR2_2.convert(execution_policy, v_R2); + + REQUIRE(std::get<std::vector<TinyVector<2>>>(execution_policy.currentContext()[0]) == + std::vector<TinyVector<2>>{TinyVector<2>{1, 3.2}}); + REQUIRE(std::get<std::vector<TinyVector<2>>>(execution_policy.currentContext()[1]) == + std::vector<TinyVector<2>>{TinyVector<2>{1, 3.2}, TinyVector<2>{-1, 0.2}}); + REQUIRE(std::get<std::vector<TinyVector<2>>>(execution_policy.currentContext()[2]) == + std::vector<TinyVector<2>>{TinyVector<2>{-3, 12.2}, TinyVector<2>{2, 1.2}}); + } + + SECTION("FunctionArgumentToFunctionSymbolIdConverter") + { + std::shared_ptr symbol_table = std::make_shared<SymbolTable>(); + + const uint64_t f_id = 3; + FunctionArgumentToFunctionSymbolIdConverter converter0{0, symbol_table}; + converter0.convert(execution_policy, f_id); + + REQUIRE(std::get<FunctionSymbolId>(execution_policy.currentContext()[0]).id() == f_id); + } +}