diff --git a/src/language/utils/BuiltinFunctionEmbedder.hpp b/src/language/utils/BuiltinFunctionEmbedder.hpp index 4870ab0028bc1c1bc82f13b7b4b20d5d989ecce9..8073a47ef7b634a17d4498522d9ecf663cbcd716 100644 --- a/src/language/utils/BuiltinFunctionEmbedder.hpp +++ b/src/language/utils/BuiltinFunctionEmbedder.hpp @@ -122,20 +122,28 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder (_copyValue<I>(t, v), ...); } - template <size_t I> + template <typename T> PUGS_INLINE ASTNodeDataType - _getOneParameterDataType(ArgsTuple& t) const + _getDataType() const { - using ArgN_T = std::decay_t<decltype(std::get<I>(t))>; - if constexpr (is_data_variant_v<ArgN_T>) { - return ast_node_data_type_from<ArgN_T>; - } else if constexpr (std::is_same_v<void, ArgN_T>) { + if constexpr (is_data_variant_v<T>) { + return ast_node_data_type_from<T>; + } else if constexpr (std::is_same_v<void, T>) { return ASTNodeDataType::void_t; } else { - return ASTNodeDataType::type_id_t; + Assert(ast_node_data_type_from<T> != ASTNodeDataType::undefined_t); + return ast_node_data_type_from<T>; } } + template <size_t I> + PUGS_INLINE ASTNodeDataType + _getOneParameterDataType(ArgsTuple& t) const + { + using ArgN_T = std::decay_t<decltype(std::get<I>(t))>; + return _getDataType<ArgN_T>(); + } + template <size_t... I> PUGS_INLINE std::vector<ASTNodeDataType> _getParameterDataTypes(ArgsTuple t, std::index_sequence<I...>) const @@ -177,13 +185,7 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder PUGS_INLINE ASTNodeDataType getReturnDataType() const final { - if constexpr (is_data_variant_v<FX>) { - return ast_node_data_type_from<FX>; - } else if constexpr (std::is_same_v<void, FX>) { - return ASTNodeDataType::void_t; - } else { - return ASTNodeDataType::type_id_t; - } + return _getDataType<FX>(); } PUGS_INLINE std::vector<ASTNodeDataType> @@ -256,19 +258,27 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder return std::make_shared<DataHandler<T>>(data); } - public: + template <typename T> PUGS_INLINE ASTNodeDataType - getReturnDataType() const final + _getDataType() const { - if constexpr (is_data_variant_v<FX>) { - return ast_node_data_type_from<FX>; - } else if constexpr (std::is_same_v<void, FX>) { + if constexpr (is_data_variant_v<T>) { + return ast_node_data_type_from<T>; + } else if constexpr (std::is_same_v<void, T>) { return ASTNodeDataType::void_t; } else { - return ASTNodeDataType::type_id_t; + Assert(ast_node_data_type_from<T> != ASTNodeDataType::undefined_t); + return ast_node_data_type_from<T>; } } + public: + PUGS_INLINE ASTNodeDataType + getReturnDataType() const final + { + return this->_getDataType<FX>(); + } + PUGS_INLINE std::vector<ASTNodeDataType> getParameterDataTypes() const final { diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index 052ac6fe06b8ea4c117fcdd747678c386c56cf9f..bf0445d6aadf5cfdbefed9134dafe0e89dd8d290 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -4,6 +4,17 @@ // clazy:excludeall=non-pod-global-static +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> = {ASTNodeDataType::type_id_t, + "shared_const_double"}; + +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<double>> = {ASTNodeDataType::type_id_t, "shared_double"}; + +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<uint64_t>> = {ASTNodeDataType::type_id_t, + "shared_uint64_t"}; + TEST_CASE("BuiltinFunctionEmbedder", "[language]") { rang::setControlMode(rang::control::Off); @@ -95,10 +106,12 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") SECTION("EmbeddedData(double, double) BuiltinFunctionEmbedder") { - std::function sum = [&](double x, double y) -> std::shared_ptr<double> { return std::make_shared<double>(x + y); }; + std::function sum = [&](double x, double y) -> std::shared_ptr<const double> { + return std::make_shared<const double>(x + y); + }; std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = - std::make_unique<BuiltinFunctionEmbedder<std::shared_ptr<double>(double, double)>>(sum); + std::make_unique<BuiltinFunctionEmbedder<std::shared_ptr<const double>(double, double)>>(sum); // using 4ul enforces cast test REQUIRE(i_embedded_c->numberOfParameters() == 2); @@ -106,13 +119,13 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::double_t); REQUIRE(i_embedded_c->getParameterDataTypes()[1] == ASTNodeDataType::double_t); - REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::type_id_t); + REQUIRE(i_embedded_c->getReturnDataType() == ast_node_data_type_from<std::shared_ptr<const double>>); DataVariant result = i_embedded_c->apply(std::vector<DataVariant>{2.3, 4ul}); EmbeddedData embedded_data = std::get<EmbeddedData>(result); - const IDataHandler& handled_data = embedded_data.get(); - const DataHandler<double>& data = dynamic_cast<const DataHandler<double>&>(handled_data); + const IDataHandler& handled_data = embedded_data.get(); + const DataHandler<const double>& data = dynamic_cast<const DataHandler<const double>&>(handled_data); REQUIRE(*data.data_ptr() == (2.3 + 4ul)); } @@ -125,7 +138,7 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(i_embedded_c->numberOfParameters() == 1); REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1); - REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::type_id_t); + REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ast_node_data_type_from<std::shared_ptr<const double>>); REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t); EmbeddedData data(std::make_shared<DataHandler<const double>>(std::make_shared<const double>(-2.3))); @@ -189,7 +202,7 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(i_embedded_c->numberOfParameters() == 1); REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1); - REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::type_id_t); + REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t); REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::unsigned_int_t); std::vector<EmbeddedData> embedded_data; @@ -248,7 +261,7 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(i_embedded_c->numberOfParameters() == 0); REQUIRE(i_embedded_c->getParameterDataTypes().size() == 0); - REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::type_id_t); + REQUIRE(i_embedded_c->getReturnDataType() == ast_node_data_type_from<std::shared_ptr<double>>); const auto embedded_data = std::get<EmbeddedData>(i_embedded_c->apply(std::vector<DataVariant>{})); const IDataHandler& handled_data = embedded_data.get();