diff --git a/src/language/ast/ASTNodeDataTypeBuilder.cpp b/src/language/ast/ASTNodeDataTypeBuilder.cpp index 7433ad9f72e7e7314705ce36dfec44114172955e..2473205ae73d8caf552ac561c14907159d192622 100644 --- a/src/language/ast/ASTNodeDataTypeBuilder.cpp +++ b/src/language/ast/ASTNodeDataTypeBuilder.cpp @@ -45,45 +45,7 @@ ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNo } else if (type_node.is_type<language::matrix_type>()) { data_type = getMatrixDataType(type_node); } else if (type_node.is_type<language::tuple_type_specifier>()) { - const auto& content_node = type_node.children[0]; - - if (content_node->is_type<language::type_name_id>()) { - const std::string& type_name_id = content_node->string(); - - auto& symbol_table = *type_node.m_symbol_table; - - const auto [i_type_symbol, found] = symbol_table.find(type_name_id, content_node->begin()); - if (not found) { - throw ParseError("undefined type identifier", std::vector{content_node->begin()}); - } else if (i_type_symbol->attributes().dataType() != ASTNodeDataType::type_name_id_t) { - std::ostringstream os; - os << "invalid type identifier, '" << type_name_id << "' was previously defined as a '" - << dataTypeName(i_type_symbol->attributes().dataType()) << '\''; - throw ParseError(os.str(), std::vector{content_node->begin()}); - } - - content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::type_id_t>(type_name_id); - } else if (content_node->is_type<language::B_set>()) { - content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); - } else if (content_node->is_type<language::Z_set>()) { - content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - } else if (content_node->is_type<language::N_set>()) { - content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); - } else if (content_node->is_type<language::R_set>()) { - content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); - } else if (content_node->is_type<language::vector_type>()) { - content_node->m_data_type = getVectorDataType(*type_node.children[0]); - } else if (content_node->is_type<language::matrix_type>()) { - content_node->m_data_type = getMatrixDataType(*type_node.children[0]); - } else if (content_node->is_type<language::string_type>()) { - content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - } else { - // LCOV_EXCL_START - throw UnexpectedError("unexpected content type in tuple"); - // LCOV_EXCL_STOP - } - - data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(content_node->m_data_type); + data_type = getTupleDataType(type_node); } else if (type_node.is_type<language::string_type>()) { data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } else if (type_node.is_type<language::type_name_id>()) { @@ -372,6 +334,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const value_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } else if (image_node.is_type<language::type_name_id>()) { value_type = ASTNodeDataType::build<ASTNodeDataType::type_id_t>(image_node.m_data_type.nameOfTypeId()); + } else if (image_node.is_type<language::tuple_type_specifier>()) { + value_type = getTupleDataType(image_node); } // LCOV_EXCL_START diff --git a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp index edec0c149615f28a35256afc62f3ceef3182c3e1..76b34ae5374adf13cfa82a513ca3e2fc13401400 100644 --- a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp @@ -118,6 +118,145 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } }; + auto add_affectation_processor_for_tuple_data = [&](const ASTNodeDataType& tuple_type, + const ASTNodeSubDataType& node_sub_data_type) { + if constexpr (std::is_same_v<OperatorT, language::eq_op>) { + if (node_sub_data_type.m_data_type == ASTNodeDataType::tuple_t) { + const ASTNodeDataType& rhs_tuple_content = node_sub_data_type.m_data_type.contentType(); + switch (tuple_type.contentType()) { + case ASTNodeDataType::bool_t: { + if (rhs_tuple_content == ASTNodeDataType::bool_t) { + list_affectation_processor->template add<std::vector<bool>, std::vector<bool>>(value_node); + } else { + // LCOV_EXCL_START + throw UnexpectedError("incompatible tuple types in affectation"); + // LCOV_EXCL_STOP + } + break; + } + case ASTNodeDataType::unsigned_int_t: { + if (rhs_tuple_content == ASTNodeDataType::bool_t) { + list_affectation_processor->template add<std::vector<uint64_t>, std::vector<bool>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::unsigned_int_t) { + list_affectation_processor->template add<std::vector<uint64_t>, std::vector<uint64_t>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::int_t) { + list_affectation_processor->template add<std::vector<uint64_t>, std::vector<int64_t>>(value_node); + } else { + // LCOV_EXCL_START + throw UnexpectedError("incompatible tuple types in affectation"); + // LCOV_EXCL_STOP + } + break; + } + case ASTNodeDataType::int_t: { + if (rhs_tuple_content == ASTNodeDataType::bool_t) { + list_affectation_processor->template add<std::vector<int64_t>, std::vector<bool>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::unsigned_int_t) { + list_affectation_processor->template add<std::vector<int64_t>, std::vector<uint64_t>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::int_t) { + list_affectation_processor->template add<std::vector<int64_t>, std::vector<int64_t>>(value_node); + } else { + // LCOV_EXCL_START + throw UnexpectedError("incompatible tuple types in affectation"); + // LCOV_EXCL_STOP + } + break; + } + case ASTNodeDataType::double_t: { + if (rhs_tuple_content == ASTNodeDataType::bool_t) { + list_affectation_processor->template add<std::vector<double>, std::vector<bool>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::unsigned_int_t) { + list_affectation_processor->template add<std::vector<double>, std::vector<uint64_t>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::int_t) { + list_affectation_processor->template add<std::vector<double>, std::vector<int64_t>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::double_t) { + list_affectation_processor->template add<std::vector<double>, std::vector<double>>(value_node); + } else { + // LCOV_EXCL_START + throw UnexpectedError("incompatible tuple types in affectation"); + // LCOV_EXCL_STOP + } + break; + } + case ASTNodeDataType::string_t: { + if (rhs_tuple_content == ASTNodeDataType::bool_t) { + list_affectation_processor->template add<std::vector<std::string>, std::vector<bool>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::unsigned_int_t) { + list_affectation_processor->template add<std::vector<std::string>, std::vector<uint64_t>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::int_t) { + list_affectation_processor->template add<std::vector<std::string>, std::vector<int64_t>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::double_t) { + list_affectation_processor->template add<std::vector<std::string>, std::vector<double>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::string_t) { + list_affectation_processor->template add<std::vector<std::string>, std::vector<std::string>>(value_node); + } else if (rhs_tuple_content == ASTNodeDataType::vector_t) { + switch (rhs_tuple_content.dimension()) { + case 1: { + list_affectation_processor->template add<std::vector<std::string>, std::vector<TinyVector<1>>>( + value_node); + break; + } + case 2: { + list_affectation_processor->template add<std::vector<std::string>, std::vector<TinyVector<2>>>( + value_node); + break; + } + case 3: { + list_affectation_processor->template add<std::vector<std::string>, std::vector<TinyVector<3>>>( + value_node); + break; + } + } + break; + } else if (rhs_tuple_content == ASTNodeDataType::matrix_t) { + Assert(rhs_tuple_content.numberOfRows() == rhs_tuple_content.numberOfColumns()); + switch (rhs_tuple_content.numberOfRows()) { + case 1: { + list_affectation_processor->template add<std::vector<std::string>, std::vector<TinyMatrix<1>>>( + value_node); + break; + } + case 2: { + list_affectation_processor->template add<std::vector<std::string>, std::vector<TinyMatrix<2>>>( + value_node); + break; + } + case 3: { + list_affectation_processor->template add<std::vector<std::string>, std::vector<TinyMatrix<3>>>( + value_node); + break; + } + } + break; + } else { + // LCOV_EXCL_START + throw UnexpectedError("incompatible tuple types in affectation"); + // LCOV_EXCL_STOP + } + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error:invalid operand type for tuple affectation", + std::vector{node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } else if (node_sub_data_type.m_data_type == ASTNodeDataType::list_t) { + list_affectation_processor->template add<std::vector<std::string>, AggregateDataVariant>(value_node); + // throw NotImplementedError("here"); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error:invalid operand type for tuple affectation", std::vector{m_node.begin()}); + // LCOV_EXCL_STOP + } + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error:invalid operand type for tuple affectation", std::vector{m_node.begin()}); + // LCOV_EXCL_STOP + } + }; + auto add_affectation_processor_for_embedded_data = [&](const ASTNodeSubDataType& node_sub_data_type) { if constexpr (std::is_same_v<OperatorT, language::eq_op>) { switch (node_sub_data_type.m_data_type) { @@ -184,6 +323,31 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } break; } + case ASTNodeDataType::matrix_t: { + Assert(node_sub_data_type.m_data_type.numberOfRows() == node_sub_data_type.m_data_type.numberOfColumns()); + switch (node_sub_data_type.m_data_type.numberOfRows()) { + case 1: { + list_affectation_processor->template add<std::string, TinyMatrix<1>>(value_node); + break; + } + case 2: { + list_affectation_processor->template add<std::string, TinyMatrix<2>>(value_node); + break; + } + case 3: { + list_affectation_processor->template add<std::string, TinyMatrix<3>>(value_node); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid vector dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + break; + } + // LCOV_EXCL_START default: { throw ParseError("unexpected error:invalid operand type for string affectation", @@ -267,6 +431,10 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( case ASTNodeDataType::string_t: { add_affectation_processor_for_string_data(node_sub_data_type); break; + } + case ASTNodeDataType::tuple_t: { + add_affectation_processor_for_tuple_data(value_type, node_sub_data_type); + break; } // LCOV_EXCL_START default: { diff --git a/src/language/node_processor/AffectationProcessor.hpp b/src/language/node_processor/AffectationProcessor.hpp index e12d985add1cc54dda9ff6bb090146bef692b434..933afbff4f297df12c7aa4751c20b62f69e9af77 100644 --- a/src/language/node_processor/AffectationProcessor.hpp +++ b/src/language/node_processor/AffectationProcessor.hpp @@ -105,7 +105,7 @@ class AffectationExecutor final : public IAffectationExecutor m_lhs = std::to_string(std::get<DataT>(rhs)); } else { std::ostringstream os; - os << std::get<DataT>(rhs); + os << std::boolalpha << std::get<DataT>(rhs); m_lhs = os.str(); } } else { @@ -115,7 +115,7 @@ class AffectationExecutor final : public IAffectationExecutor m_lhs += std::to_string(std::get<DataT>(rhs)); } else { std::ostringstream os; - os << std::get<DataT>(rhs); + os << std::boolalpha << std::get<DataT>(rhs); m_lhs += os.str(); } } @@ -149,8 +149,125 @@ class AffectationExecutor final : public IAffectationExecutor } }, rhs); + } else if constexpr (is_std_vector_v<ValueT> and is_std_vector_v<DataT>) { + using ValueContentT = typename ValueT::value_type; + using DataContentT = typename DataT::value_type; + + if constexpr (std::is_same_v<ValueContentT, DataContentT>) { + m_lhs = std::move(std::get<DataT>(rhs)); + } else if constexpr (std::is_convertible_v<DataContentT, ValueContentT>) { + m_lhs.resize(std::get<DataT>(rhs).size()); + std::visit( + [&](auto&& v) { + using Vi_T = std::decay_t<decltype(v)>; + if constexpr (is_std_vector_v<Vi_T>) { + if constexpr (std::is_arithmetic_v<typename Vi_T::value_type>) { + for (size_t i = 0; i < v.size(); ++i) { + m_lhs[i] = v[i]; + } + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + rhs); + } else if constexpr (std::is_same_v<ValueContentT, std::string>) { + m_lhs.resize(std::get<DataT>(rhs).size()); + + std::visit( + [&](auto&& v) { + if constexpr (is_std_vector_v<std::decay_t<decltype(v)>>) { + using Vi_T = typename std::decay_t<decltype(v)>::value_type; + for (size_t i = 0; i < v.size(); ++i) { + if constexpr (std::is_arithmetic_v<Vi_T>) { + m_lhs[i] = std::move(std::to_string(v[i])); + } else { + std::ostringstream os; + os << std::boolalpha << v[i]; + m_lhs[i] = std::move(os.str()); + } + } + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + rhs); + } else { + // LCOV_EXCL_START + throw UnexpectedError("invalid value type"); + // LCOV_EXCL_STOP + } + } else if constexpr (is_std_vector_v<ValueT> and std::is_same_v<DataT, AggregateDataVariant>) { + using ValueContentT = typename ValueT::value_type; + const AggregateDataVariant& children_values = std::get<AggregateDataVariant>(rhs); + m_lhs.resize(children_values.size()); + auto& tuple_value = m_lhs; + for (size_t i = 0; i < children_values.size(); ++i) { + std::visit( + [&](auto&& child_value) { + using T = std::decay_t<decltype(child_value)>; + if constexpr (std::is_same_v<T, ValueContentT>) { + tuple_value[i] = child_value; + } else if constexpr (std::is_arithmetic_v<ValueContentT> and + std::is_convertible_v<T, ValueContentT>) { + tuple_value[i] = static_cast<ValueContentT>(child_value); + } else if constexpr (std::is_same_v<std::string, ValueContentT>) { + if constexpr (std::is_arithmetic_v<T>) { + tuple_value[i] = std::to_string(child_value); + } else { + std::ostringstream os; + os << std::boolalpha << child_value; + tuple_value[i] = os.str(); + } + } else if constexpr (is_tiny_vector_v<ValueContentT>) { + if constexpr (std::is_arithmetic_v<T>) { + if constexpr (std::is_same_v<ValueContentT, TinyVector<1>>) { + tuple_value[i][0] = child_value; + } else { + // in this case a 0 is given + Assert(child_value == 0); + tuple_value[i] = ZeroType{}; + } + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected error: unexpected right hand side type in affectation"); + // LCOV_EXCL_STOP + } + } else if constexpr (is_tiny_matrix_v<ValueContentT>) { + if constexpr (std::is_arithmetic_v<T>) { + if constexpr (std::is_same_v<ValueContentT, TinyMatrix<1>>) { + tuple_value[i](0, 0) = child_value; + } else { + // in this case a 0 is given + Assert(child_value == 0); + tuple_value[i] = ZeroType{}; + } + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected error: unexpected right hand side type in affectation"); + // LCOV_EXCL_STOP + } + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected error: unexpected right hand side type in affectation"); + // LCOV_EXCL_STOP + } + }, + children_values[i]); + } + // throw NotImplementedError("list -> tuple"); } else { + // LCOV_EXCL_START throw UnexpectedError("invalid value type"); + // LCOV_EXCL_STOP } } else { AffOp<OperatorT>().eval(m_lhs, std::get<DataT>(rhs)); @@ -251,7 +368,7 @@ class AffectationToTupleProcessor final : public AffectationToDataVariantProcess *m_lhs = std::vector{std::move(std::to_string(v))}; } else { std::ostringstream os; - os << v; + os << std::boolalpha << v; *m_lhs = std::vector<std::string>{os.str()}; } } else if constexpr (is_tiny_vector_v<ValueT> or is_tiny_matrix_v<ValueT>) { @@ -306,7 +423,7 @@ class AffectationToTupleFromListProcessor final : public AffectationToDataVarian tuple_value[i] = std::to_string(child_value); } else { std::ostringstream os; - os << child_value; + os << std::boolalpha << child_value; tuple_value[i] = os.str(); } } else if constexpr (is_tiny_vector_v<ValueT>) { diff --git a/src/language/node_processor/FunctionProcessor.hpp b/src/language/node_processor/FunctionProcessor.hpp index c5313a0ea3567f85d375ad107c5a9a183ac6cf23..9dd9203275ef91c08742f9f58769e9023e24b852 100644 --- a/src/language/node_processor/FunctionProcessor.hpp +++ b/src/language/node_processor/FunctionProcessor.hpp @@ -31,7 +31,9 @@ class FunctionExpressionProcessor final : public INodeProcessor static_assert(ReturnType::Dimension == 1, "invalid conversion"); return ReturnType(std::get<ExpressionValueType>(m_function_expression.execute(exec_policy))); } else { + // LCOV_EXCL_START throw UnexpectedError("invalid conversion"); + // LCOV_EXCL_STOP } } diff --git a/src/language/utils/ASTNodeDataType.cpp b/src/language/utils/ASTNodeDataType.cpp index 462e1e26532e7a6f2ab73369b8a0e462a9eaa6e9..3c6d4af0f8be9f4002834fb15548316f301b0159 100644 --- a/src/language/utils/ASTNodeDataType.cpp +++ b/src/language/utils/ASTNodeDataType.cpp @@ -4,6 +4,7 @@ #include <language/ast/ASTNode.hpp> #include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/ParseError.hpp> +#include <language/utils/SymbolTable.hpp> #include <utils/PugsAssert.hpp> ASTNodeDataType @@ -78,7 +79,8 @@ getMatrixDataType(const ASTNode& type_node) ASTNodeDataType getMatrixExpressionType(const ASTNode& matrix_expression_node) { - if (not matrix_expression_node.is_type<language::matrix_expression>()) { + if (not(matrix_expression_node.is_type<language::matrix_expression>() and + matrix_expression_node.children.size() > 0)) { throw ParseError("unexpected node type", matrix_expression_node.begin()); } @@ -113,6 +115,50 @@ getMatrixExpressionType(const ASTNode& matrix_expression_node) return ASTNodeDataType::build<ASTNodeDataType::matrix_t>(dimension0, dimension1); } +ASTNodeDataType +getTupleDataType(const ASTNode& type_node) +{ + const auto& content_node = type_node.children[0]; + + if (content_node->is_type<language::type_name_id>()) { + const std::string& type_name_id = content_node->string(); + + auto& symbol_table = *type_node.m_symbol_table; + + const auto [i_type_symbol, found] = symbol_table.find(type_name_id, content_node->begin()); + if (not found) { + throw ParseError("undefined type identifier", std::vector{content_node->begin()}); + } else if (i_type_symbol->attributes().dataType() != ASTNodeDataType::type_name_id_t) { + std::ostringstream os; + os << "invalid type identifier, '" << type_name_id << "' was previously defined as a '" + << dataTypeName(i_type_symbol->attributes().dataType()) << '\''; + throw ParseError(os.str(), std::vector{content_node->begin()}); + } + + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::type_id_t>(type_name_id); + } else if (content_node->is_type<language::B_set>()) { + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); + } else if (content_node->is_type<language::Z_set>()) { + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + } else if (content_node->is_type<language::N_set>()) { + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + } else if (content_node->is_type<language::R_set>()) { + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + } else if (content_node->is_type<language::vector_type>()) { + content_node->m_data_type = getVectorDataType(*type_node.children[0]); + } else if (content_node->is_type<language::matrix_type>()) { + content_node->m_data_type = getMatrixDataType(*type_node.children[0]); + } else if (content_node->is_type<language::string_type>()) { + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected content type in tuple"); + // LCOV_EXCL_STOP + } + + return ASTNodeDataType::build<ASTNodeDataType::tuple_t>(content_node->m_data_type); +} + std::string dataTypeName(const ASTNodeDataType& data_type) { diff --git a/src/language/utils/ASTNodeDataType.hpp b/src/language/utils/ASTNodeDataType.hpp index 27f08bf5a003727b4dea779ce7156870d29af448..b7cc48fefaa6d849cf970641b37279a5a85ca15e 100644 --- a/src/language/utils/ASTNodeDataType.hpp +++ b/src/language/utils/ASTNodeDataType.hpp @@ -18,6 +18,8 @@ ASTNodeDataType getVectorExpressionType(const ASTNode& vector_expression_node); ASTNodeDataType getMatrixDataType(const ASTNode& type_node); ASTNodeDataType getMatrixExpressionType(const ASTNode& matrix_expression_node); +ASTNodeDataType getTupleDataType(const ASTNode& type_node); + std::string dataTypeName(const std::vector<ASTNodeDataType>& data_type_vector); std::string dataTypeName(const ASTNodeDataType& data_type); diff --git a/src/language/utils/ASTNodeNaturalConversionChecker.cpp b/src/language/utils/ASTNodeNaturalConversionChecker.cpp index 44182f0424f8c088b3d555ef467f1f4fede7d4f3..b3ad34eeabdcaa942590fa34a7b1940f7a552834 100644 --- a/src/language/utils/ASTNodeNaturalConversionChecker.cpp +++ b/src/language/utils/ASTNodeNaturalConversionChecker.cpp @@ -108,7 +108,7 @@ ASTNodeNaturalConversionChecker<RToR1Conversion>::_checkIsNaturalExpressionConve } } else if (target_data_type == ASTNodeDataType::tuple_t) { const ASTNodeDataType& target_content_type = target_data_type.contentType(); - if (node.m_data_type == ASTNodeDataType::tuple_t) { + if ((data_type == ASTNodeDataType::tuple_t) or (node.m_data_type == ASTNodeDataType::tuple_t)) { this->_checkIsNaturalExpressionConversion(node, data_type.contentType(), target_content_type); } else if (node.m_data_type == ASTNodeDataType::list_t) { for (const auto& child : node.children) { diff --git a/src/language/utils/BuiltinFunctionEmbedderUtils.cpp b/src/language/utils/BuiltinFunctionEmbedderUtils.cpp index ba6658ff7b5dc901182c62f7a07c0018ed94bd5a..be3474ef7ae4e657d7553715c8414180ddc15d52 100644 --- a/src/language/utils/BuiltinFunctionEmbedderUtils.cpp +++ b/src/language/utils/BuiltinFunctionEmbedderUtils.cpp @@ -82,24 +82,7 @@ getBuiltinFunctionEmbedder(ASTNode& n) } bool is_castable = true; if (target_type.dimension() > 1) { - switch (arg_type) { - case ASTNodeDataType::int_t: { - break; - } - case ASTNodeDataType::list_t: { - if (arg_type.contentTypeList().size() != target_type.dimension()) { - is_castable = false; - break; - } - for (auto list_arg : arg_type.contentTypeList()) { - is_castable &= isNaturalConversion(*list_arg, ASTNodeDataType::build<ASTNodeDataType::double_t>()); - } - break; - } - default: { - is_castable &= false; - } - } + return (arg_type == ASTNodeDataType::int_t); } else { is_castable &= isNaturalConversion(arg_type, ASTNodeDataType::build<ASTNodeDataType::double_t>()); } @@ -113,24 +96,7 @@ getBuiltinFunctionEmbedder(ASTNode& n) bool is_castable = true; if (target_type.numberOfRows() > 1) { - switch (arg_type) { - case ASTNodeDataType::int_t: { - break; - } - case ASTNodeDataType::list_t: { - if (arg_type.contentTypeList().size() != target_type.numberOfRows() * target_type.numberOfColumns()) { - is_castable = false; - break; - } - for (auto list_arg : arg_type.contentTypeList()) { - is_castable &= isNaturalConversion(*list_arg, ASTNodeDataType::build<ASTNodeDataType::double_t>()); - } - break; - } - default: { - is_castable &= false; - } - } + return (arg_type == ASTNodeDataType::int_t); } else { is_castable &= isNaturalConversion(arg_type, ASTNodeDataType::build<ASTNodeDataType::double_t>()); } diff --git a/src/language/utils/PugsFunctionAdapter.hpp b/src/language/utils/PugsFunctionAdapter.hpp index d39e0f4233df70709d3d765146447727f7541ae4..de5e884bf63bc0ac3726864c4c4571c0a0fb9f55 100644 --- a/src/language/utils/PugsFunctionAdapter.hpp +++ b/src/language/utils/PugsFunctionAdapter.hpp @@ -154,28 +154,6 @@ class PugsFunctionAdapter<OutputType(InputType...)> { if constexpr (is_tiny_vector_v<OutputType>) { switch (data_type) { - case ASTNodeDataType::list_t: { - return [](DataVariant&& result) -> OutputType { - AggregateDataVariant& v = std::get<AggregateDataVariant>(result); - OutputType x; - - for (size_t i = 0; i < x.dimension(); ++i) { - std::visit( - [&](auto&& vi) { - using Vi_T = std::decay_t<decltype(vi)>; - if constexpr (std::is_arithmetic_v<Vi_T>) { - x[i] = vi; - } else { - // LCOV_EXCL_START - throw UnexpectedError("expecting arithmetic value"); - // LCOV_EXCL_STOP - } - }, - v[i]); - } - return x; - }; - } case ASTNodeDataType::vector_t: { return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); }; } @@ -231,30 +209,6 @@ class PugsFunctionAdapter<OutputType(InputType...)> } } else if constexpr (is_tiny_matrix_v<OutputType>) { switch (data_type) { - case ASTNodeDataType::list_t: { - return [](DataVariant&& result) -> OutputType { - AggregateDataVariant& v = std::get<AggregateDataVariant>(result); - OutputType x; - - for (size_t i = 0, l = 0; i < x.numberOfRows(); ++i) { - for (size_t j = 0; j < x.numberOfColumns(); ++j, ++l) { - std::visit( - [&](auto&& Aij) { - using Aij_T = std::decay_t<decltype(Aij)>; - if constexpr (std::is_arithmetic_v<Aij_T>) { - x(i, j) = Aij; - } else { - // LCOV_EXCL_START - throw UnexpectedError("expecting arithmetic value"); - // LCOV_EXCL_STOP - } - }, - v[l]); - } - } - return x; - }; - } case ASTNodeDataType::matrix_t: { return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); }; } diff --git a/tests/test_ASTNodeDataType.cpp b/tests/test_ASTNodeDataType.cpp index 023ec9963bcd66286c9ffcfcadbc4fb3b2abdbee..8c6219e2d9975a0f2e007b93b9e3beaa6285e955 100644 --- a/tests/test_ASTNodeDataType.cpp +++ b/tests/test_ASTNodeDataType.cpp @@ -182,6 +182,68 @@ TEST_CASE("ASTNodeDataType", "[language]") } } + SECTION("getVectorExpressionType") + { + std::unique_ptr vector_expression_node = std::make_unique<ASTNode>(); + vector_expression_node->set_type<language::vector_expression>(); + vector_expression_node->emplace_back(std::make_unique<ASTNode>()); + + SECTION("good nodes") + { + vector_expression_node->children.resize(1); + for (size_t i = 0; i < vector_expression_node->children.size(); ++i) { + vector_expression_node->children[i] = std::make_unique<ASTNode>(); + vector_expression_node->children[i]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); + } + REQUIRE(getVectorExpressionType(*vector_expression_node) == ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)); + REQUIRE(getVectorExpressionType(*vector_expression_node).dimension() == 1); + + vector_expression_node->children.resize(2); + for (size_t i = 0; i < vector_expression_node->children.size(); ++i) { + vector_expression_node->children[i] = std::make_unique<ASTNode>(); + vector_expression_node->children[i]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + } + REQUIRE(getVectorExpressionType(*vector_expression_node) == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + REQUIRE(getVectorExpressionType(*vector_expression_node).dimension() == 2); + + vector_expression_node->children.resize(3); + for (size_t i = 0; i < vector_expression_node->children.size(); ++i) { + vector_expression_node->children[i] = std::make_unique<ASTNode>(); + vector_expression_node->children[i]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + } + REQUIRE(getVectorExpressionType(*vector_expression_node) == ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)); + REQUIRE(getVectorExpressionType(*vector_expression_node).dimension() == 3); + } + + SECTION("bad content type") + { + vector_expression_node->children.resize(3); + for (size_t i = 0; i < vector_expression_node->children.size(); ++i) { + vector_expression_node->children[i] = std::make_unique<ASTNode>(); + } + REQUIRE_THROWS_WITH(getVectorExpressionType(*vector_expression_node), + "unexpected error: invalid implicit conversion: undefined -> R"); + } + + SECTION("bad node type") + { + vector_expression_node->set_type<language::real>(); + REQUIRE_THROWS_WITH(getVectorExpressionType(*vector_expression_node), "unexpected node type"); + } + + SECTION("bad children size 2") + { + vector_expression_node->children.resize(4); + REQUIRE_THROWS_WITH(getVectorExpressionType(*vector_expression_node), "invalid dimension (must be 1, 2 or 3)"); + } + + SECTION("bad children size 2") + { + vector_expression_node->children.clear(); + REQUIRE_THROWS_WITH(getVectorExpressionType(*vector_expression_node), "unexpected node type"); + } + } + SECTION("getMatrixDataType") { std::unique_ptr type_node = std::make_unique<ASTNode>(); @@ -229,7 +291,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); } - SECTION("bad children size 1") + SECTION("bad children size 2") { type_node->children.emplace_back(std::unique_ptr<ASTNode>()); REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); @@ -298,6 +360,127 @@ TEST_CASE("ASTNodeDataType", "[language]") } } + SECTION("getMatrixExpressionType") + { + std::unique_ptr matrix_expression_node = std::make_unique<ASTNode>(); + matrix_expression_node->set_type<language::matrix_expression>(); + matrix_expression_node->emplace_back(std::make_unique<ASTNode>()); + + SECTION("good nodes") + { + { + const size_t dimension = 1; + matrix_expression_node->children.clear(); + for (size_t i = 0; i < dimension; ++i) { + matrix_expression_node->children.emplace_back(std::make_unique<ASTNode>()); + matrix_expression_node->children[i]->set_type<language::row_expression>(); + for (size_t j = 0; j < dimension; ++j) { + matrix_expression_node->children[i]->children.emplace_back(std::make_unique<ASTNode>()); + matrix_expression_node->children[i]->children[j]->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::int_t>(); + } + } + REQUIRE(getMatrixExpressionType(*matrix_expression_node) == + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(dimension, dimension)); + REQUIRE(getMatrixExpressionType(*matrix_expression_node).numberOfRows() == dimension); + REQUIRE(getMatrixExpressionType(*matrix_expression_node).numberOfColumns() == dimension); + } + + { + const size_t dimension = 2; + matrix_expression_node->children.clear(); + for (size_t i = 0; i < dimension; ++i) { + matrix_expression_node->children.emplace_back(std::make_unique<ASTNode>()); + matrix_expression_node->children[i]->set_type<language::row_expression>(); + for (size_t j = 0; j < dimension; ++j) { + matrix_expression_node->children[i]->children.emplace_back(std::make_unique<ASTNode>()); + matrix_expression_node->children[i]->children[j]->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::int_t>(); + } + } + REQUIRE(getMatrixExpressionType(*matrix_expression_node) == + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(dimension, dimension)); + REQUIRE(getMatrixExpressionType(*matrix_expression_node).numberOfRows() == dimension); + REQUIRE(getMatrixExpressionType(*matrix_expression_node).numberOfColumns() == dimension); + } + + { + const size_t dimension = 3; + matrix_expression_node->children.clear(); + for (size_t i = 0; i < dimension; ++i) { + matrix_expression_node->children.emplace_back(std::make_unique<ASTNode>()); + matrix_expression_node->children[i]->set_type<language::row_expression>(); + for (size_t j = 0; j < dimension; ++j) { + matrix_expression_node->children[i]->children.emplace_back(std::make_unique<ASTNode>()); + matrix_expression_node->children[i]->children[j]->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::int_t>(); + } + } + REQUIRE(getMatrixExpressionType(*matrix_expression_node) == + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(dimension, dimension)); + REQUIRE(getMatrixExpressionType(*matrix_expression_node).numberOfRows() == dimension); + REQUIRE(getMatrixExpressionType(*matrix_expression_node).numberOfColumns() == dimension); + } + } + + SECTION("bad content type") + { + const size_t dimension = 3; + matrix_expression_node->children.clear(); + for (size_t i = 0; i < dimension; ++i) { + matrix_expression_node->children.emplace_back(std::make_unique<ASTNode>()); + matrix_expression_node->children[i]->set_type<language::row_expression>(); + for (size_t j = 0; j < dimension; ++j) { + matrix_expression_node->children[i]->children.emplace_back(std::make_unique<ASTNode>()); + } + } + REQUIRE_THROWS_WITH(getMatrixExpressionType(*matrix_expression_node), + "unexpected error: invalid implicit conversion: undefined -> R"); + } + + SECTION("bad node type") + { + matrix_expression_node->set_type<language::real>(); + REQUIRE_THROWS_WITH(getMatrixExpressionType(*matrix_expression_node), "unexpected node type"); + } + + SECTION("bad children size 1") + { + matrix_expression_node->children.resize(4); + REQUIRE_THROWS_WITH(getMatrixExpressionType(*matrix_expression_node), "invalid dimension (must be 1, 2 or 3)"); + } + + SECTION("bad children size 2") + { + const size_t dimension = 2; + matrix_expression_node->children.clear(); + for (size_t i = 0; i < dimension; ++i) { + matrix_expression_node->children.emplace_back(std::make_unique<ASTNode>()); + matrix_expression_node->children[i]->set_type<language::row_expression>(); + for (size_t j = 0; j < dimension + 1; ++j) { + matrix_expression_node->children[i]->children.emplace_back(std::make_unique<ASTNode>()); + } + } + REQUIRE_THROWS_WITH(getMatrixExpressionType(*matrix_expression_node), "only square matrices are supported"); + } + + SECTION("bad children size 3") + { + const size_t dimension = 2; + matrix_expression_node->children.clear(); + for (size_t i = 0; i < dimension; ++i) { + matrix_expression_node->children.emplace_back(std::make_unique<ASTNode>()); + matrix_expression_node->children[i]->set_type<language::row_expression>(); + for (size_t j = 0; j < dimension; ++j) { + matrix_expression_node->children[i]->children.emplace_back(std::make_unique<ASTNode>()); + } + } + matrix_expression_node->children[1]->children.emplace_back(std::make_unique<ASTNode>()); + + REQUIRE_THROWS_WITH(getMatrixExpressionType(*matrix_expression_node), "row must have same sizes"); + } + } + SECTION("isNaturalConversion") { SECTION("-> B") diff --git a/tests/test_ASTNodeDataTypeBuilder.cpp b/tests/test_ASTNodeDataTypeBuilder.cpp index d8921fe051231713af2c2e8e258348098a1db6bb..9ec4f8dabbf448916ca491cc262be1edf5fa412c 100644 --- a/tests/test_ASTNodeDataTypeBuilder.cpp +++ b/tests/test_ASTNodeDataTypeBuilder.cpp @@ -520,9 +520,10 @@ let t : (R), t = (2, 3.1, 5); SECTION("R^d tuples") { std::string_view data = R"( -let a : R^2, a = (2,3.1); -let t1 : (R^2), t1 = (a, (1,2), 0); +let a : R^2, a = [2,3.1]; +let t1 : (R^2), t1 = (a, [1,2], 0); let t2 : (R^3), t2 = (0, 0); +let t3 : (R^2), t3 = ([1,2], a, 0); )"; std::string_view result = R"( @@ -533,7 +534,7 @@ let t2 : (R^3), t2 = (0, 0); | | +-(language::R_set:R) | | `-(language::integer:2:Z) | +-(language::name:a:R^2) - | `-(language::expression_list:Z*R) + | `-(language::vector_expression:R^2) | +-(language::integer:2:Z) | `-(language::real:3.1:R) +-(language::var_declaration:void) @@ -543,21 +544,34 @@ let t2 : (R^3), t2 = (0, 0); | | +-(language::R_set:R) | | `-(language::integer:2:Z) | +-(language::name:t1:(R^2...)) - | `-(language::expression_list:R^2*(Z*Z)*Z) + | `-(language::expression_list:R^2*R^2*Z) | +-(language::name:a:R^2) - | +-(language::tuple_expression:Z*Z) + | +-(language::vector_expression:R^2) | | +-(language::integer:1:Z) | | `-(language::integer:2:Z) | `-(language::integer:0:Z) + +-(language::var_declaration:void) + | +-(language::name:t2:(R^3...)) + | +-(language::tuple_type_specifier:(R^3...)) + | | `-(language::vector_type:R^3) + | | +-(language::R_set:R) + | | `-(language::integer:3:Z) + | +-(language::name:t2:(R^3...)) + | `-(language::expression_list:Z*Z) + | +-(language::integer:0:Z) + | `-(language::integer:0:Z) `-(language::var_declaration:void) - +-(language::name:t2:(R^3...)) - +-(language::tuple_type_specifier:(R^3...)) - | `-(language::vector_type:R^3) + +-(language::name:t3:(R^2...)) + +-(language::tuple_type_specifier:(R^2...)) + | `-(language::vector_type:R^2) | +-(language::R_set:R) - | `-(language::integer:3:Z) - +-(language::name:t2:(R^3...)) - `-(language::expression_list:Z*Z) - +-(language::integer:0:Z) + | `-(language::integer:2:Z) + +-(language::name:t3:(R^2...)) + `-(language::expression_list:R^2*R^2*Z) + +-(language::vector_expression:R^2) + | +-(language::integer:1:Z) + | `-(language::integer:2:Z) + +-(language::name:a:R^2) `-(language::integer:0:Z) )"; @@ -567,9 +581,10 @@ let t2 : (R^3), t2 = (0, 0); SECTION("R^dxd tuples") { std::string_view data = R"( -let a : R^2x2, a = (2, 3.1, -1.2, 4); -let t1 : (R^2x2), t1 = (a, (1,2,1,3), 0); +let a : R^2x2, a = [[2, 3.1], [-1.2, 4]]; +let t1 : (R^2x2), t1 = (a, [[1,2],[1,3]], 0); let t2 : (R^3x3), t2 = (0, 0); +let t3 : (R^2x2), t3 = ([[1,2],[1,3]], a, 0); )"; std::string_view result = R"( @@ -581,12 +596,14 @@ let t2 : (R^3x3), t2 = (0, 0); | | +-(language::integer:2:Z) | | `-(language::integer:2:Z) | +-(language::name:a:R^2x2) - | `-(language::expression_list:Z*R*R*Z) - | +-(language::integer:2:Z) - | +-(language::real:3.1:R) - | +-(language::unary_minus:R) - | | `-(language::real:1.2:R) - | `-(language::integer:4:Z) + | `-(language::matrix_expression:R^2x2) + | +-(language::row_expression:void) + | | +-(language::integer:2:Z) + | | `-(language::real:3.1:R) + | `-(language::row_expression:void) + | +-(language::unary_minus:R) + | | `-(language::real:1.2:R) + | `-(language::integer:4:Z) +-(language::var_declaration:void) | +-(language::name:t1:(R^2x2...)) | +-(language::tuple_type_specifier:(R^2x2...)) @@ -595,24 +612,44 @@ let t2 : (R^3x3), t2 = (0, 0); | | +-(language::integer:2:Z) | | `-(language::integer:2:Z) | +-(language::name:t1:(R^2x2...)) - | `-(language::expression_list:R^2x2*(Z*Z*Z*Z)*Z) + | `-(language::expression_list:R^2x2*R^2x2*Z) | +-(language::name:a:R^2x2) - | +-(language::tuple_expression:Z*Z*Z*Z) - | | +-(language::integer:1:Z) - | | +-(language::integer:2:Z) - | | +-(language::integer:1:Z) - | | `-(language::integer:3:Z) + | +-(language::matrix_expression:R^2x2) + | | +-(language::row_expression:void) + | | | +-(language::integer:1:Z) + | | | `-(language::integer:2:Z) + | | `-(language::row_expression:void) + | | +-(language::integer:1:Z) + | | `-(language::integer:3:Z) + | `-(language::integer:0:Z) + +-(language::var_declaration:void) + | +-(language::name:t2:(R^3x3...)) + | +-(language::tuple_type_specifier:(R^3x3...)) + | | `-(language::matrix_type:R^3x3) + | | +-(language::R_set:R) + | | +-(language::integer:3:Z) + | | `-(language::integer:3:Z) + | +-(language::name:t2:(R^3x3...)) + | `-(language::expression_list:Z*Z) + | +-(language::integer:0:Z) | `-(language::integer:0:Z) `-(language::var_declaration:void) - +-(language::name:t2:(R^3x3...)) - +-(language::tuple_type_specifier:(R^3x3...)) - | `-(language::matrix_type:R^3x3) + +-(language::name:t3:(R^2x2...)) + +-(language::tuple_type_specifier:(R^2x2...)) + | `-(language::matrix_type:R^2x2) | +-(language::R_set:R) - | +-(language::integer:3:Z) - | `-(language::integer:3:Z) - +-(language::name:t2:(R^3x3...)) - `-(language::expression_list:Z*Z) - +-(language::integer:0:Z) + | +-(language::integer:2:Z) + | `-(language::integer:2:Z) + +-(language::name:t3:(R^2x2...)) + `-(language::expression_list:R^2x2*R^2x2*Z) + +-(language::matrix_expression:R^2x2) + | +-(language::row_expression:void) + | | +-(language::integer:1:Z) + | | `-(language::integer:2:Z) + | `-(language::row_expression:void) + | +-(language::integer:1:Z) + | `-(language::integer:3:Z) + +-(language::name:a:R^2x2) `-(language::integer:0:Z) )"; diff --git a/tests/test_BuiltinFunctionEmbedderUtils.cpp b/tests/test_BuiltinFunctionEmbedderUtils.cpp index da2fbdd1dc8466377621b077af3ef790ec15e24e..ebb0195dc01ed45a25085b8f3150d71de08005ac 100644 --- a/tests/test_BuiltinFunctionEmbedderUtils.cpp +++ b/tests/test_BuiltinFunctionEmbedderUtils.cpp @@ -349,7 +349,7 @@ foo(3,0); SECTION("builtin function R*R^2 -> R^2 (R^2 from list)") { std::string_view data = R"( -foo(3.1,(1,2.3)); +foo(3.1,[1,2.3]); )"; TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; @@ -431,7 +431,7 @@ foo(3,0); SECTION("builtin function R*R^3 -> R^2 (R^3 from list)") { std::string_view data = R"( -foo(3.1,(1,2.3,4)); +foo(3.1,[1,2.3,4]); )"; TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; @@ -621,7 +621,7 @@ foo(3,0); SECTION("builtin function R*R^2x2 -> R^2x2 (R^2x2 from list)") { std::string_view data = R"( -foo(3.1,(1,2.3,0,3)); +foo(3.1,[[1,2.3],[0,3]]); )"; TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; @@ -704,7 +704,7 @@ foo(3,0); SECTION("builtin function R*R^3x3 -> R^2x2 (R^3x3 from list)") { std::string_view data = R"( -foo(3.1,(1, 2.3, 4, 0.3, 2.5, 4.6, 2.7, 8.1, -9)); +foo(3.1,[[1, 2.3, 4], [0.3, 2.5, 4.6], [2.7, 8.1, -9]]); )"; TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; @@ -1009,29 +1009,6 @@ foo(0); REQUIRE(function_embedder->getParameterDataTypes()[0].contentType().dimension() == 2); } - SECTION("builtin function (R^2x2...) -> N (from castable list)") - { - std::string_view data = R"( -foo((1,2,3)); -)"; - - TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; - auto root_node = ASTBuilder::build(input); - register_functions(root_node, - FunctionList{std::make_pair("foo", std::make_shared<BuiltinFunctionEmbedder<uint64_t( - const std::vector<TinyMatrix<2>>&)>>( - [](const std::vector<TinyMatrix<2>>& x) -> uint64_t { - return x.size(); - }))}); - - auto function_embedder = getBuiltinFunctionEmbedder(*root_node->children[0]); - REQUIRE(function_embedder->getReturnDataType() == ASTNodeDataType::unsigned_int_t); - REQUIRE(function_embedder->getParameterDataTypes().size() == 1); - REQUIRE(function_embedder->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t); - REQUIRE(function_embedder->getParameterDataTypes()[0].contentType() == ASTNodeDataType::matrix_t); - REQUIRE(function_embedder->getParameterDataTypes()[0].contentType().numberOfRows() == 2); - } - SECTION("builtin function (R^2x2...) -> N (from 0)") { std::string_view data = R"( @@ -1126,10 +1103,10 @@ foo(0); REQUIRE(function_embedder->getParameterDataTypes()[0].contentType().dimension() == 3); } - SECTION("builtin function (R^3x3...) -> N (from castable list)") + SECTION("builtin function (R^3x3...) -> N (from 0)") { std::string_view data = R"( -foo((1,2,3)); +foo(0); )"; TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; @@ -1149,10 +1126,11 @@ foo((1,2,3)); REQUIRE(function_embedder->getParameterDataTypes()[0].contentType().numberOfRows() == 3); } - SECTION("builtin function (R^3x3...) -> N (from 0)") + SECTION("builtin function (R^3x3...) -> N (from list)") { std::string_view data = R"( -foo(0); +let x:R^3x3; +foo((x,2*x)); )"; TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; @@ -1170,68 +1148,89 @@ foo(0); REQUIRE(function_embedder->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t); REQUIRE(function_embedder->getParameterDataTypes()[0].contentType() == ASTNodeDataType::matrix_t); REQUIRE(function_embedder->getParameterDataTypes()[0].contentType().numberOfRows() == 3); + REQUIRE(function_embedder->getParameterDataTypes()[0].contentType().numberOfColumns() == 3); } + } - SECTION("builtin function (R^3x3...) -> N (from list)") + SECTION("complete case") + { + SECTION("tuple first") { std::string_view data = R"( let x:R^3x3; -foo((x,2*x)); +foo((x,2*x), 1, "bar", [2,3]); )"; TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; auto root_node = ASTBuilder::build(input); register_functions(root_node, - FunctionList{std::make_pair("foo", std::make_shared<BuiltinFunctionEmbedder<uint64_t( - const std::vector<TinyMatrix<3>>&)>>( - [](const std::vector<TinyMatrix<3>>& x) -> uint64_t { - return x.size(); - }))}); + FunctionList{ + std::make_pair("foo", + std::make_shared<BuiltinFunctionEmbedder< + std::tuple<uint64_t, double, std::string>(const std::vector<TinyMatrix<3>>&, + const double&, const std::string&, + const TinyVector<2>&)>>( + [](const std::vector<TinyMatrix<3>>& x, const double& a, + const std::string& s, + const TinyVector<2>& y) -> std::tuple<uint64_t, double, std::string> { + return std::make_tuple(x.size(), a * y[0] + y[1], s + "_foo"); + }))}); auto function_embedder = getBuiltinFunctionEmbedder(*root_node->children[0]); - REQUIRE(function_embedder->getReturnDataType() == ASTNodeDataType::unsigned_int_t); - REQUIRE(function_embedder->getParameterDataTypes().size() == 1); + REQUIRE(function_embedder->getReturnDataType() == ASTNodeDataType::list_t); + REQUIRE(function_embedder->getReturnDataType().contentTypeList().size() == 3); + REQUIRE(*function_embedder->getReturnDataType().contentTypeList()[0] == ASTNodeDataType::unsigned_int_t); + REQUIRE(*function_embedder->getReturnDataType().contentTypeList()[1] == ASTNodeDataType::double_t); + REQUIRE(*function_embedder->getReturnDataType().contentTypeList()[2] == ASTNodeDataType::string_t); + REQUIRE(function_embedder->getParameterDataTypes().size() == 4); REQUIRE(function_embedder->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t); REQUIRE(function_embedder->getParameterDataTypes()[0].contentType() == ASTNodeDataType::matrix_t); REQUIRE(function_embedder->getParameterDataTypes()[0].contentType().numberOfRows() == 3); + REQUIRE(function_embedder->getParameterDataTypes()[0].contentType().numberOfColumns() == 3); + REQUIRE(function_embedder->getParameterDataTypes()[1] == ASTNodeDataType::double_t); + REQUIRE(function_embedder->getParameterDataTypes()[2] == ASTNodeDataType::string_t); + REQUIRE(function_embedder->getParameterDataTypes()[3] == ASTNodeDataType::vector_t); + REQUIRE(function_embedder->getParameterDataTypes()[3].dimension() == 2); } - } - SECTION("complete case") - { - std::string_view data = R"( + SECTION("tuple not first") + { + std::string_view data = R"( let x:R^3x3; -foo(1, "bar", (x,2*x), (2,3)); +foo(1, "bar", (x,2*x), [2,3]); )"; - TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; - auto root_node = ASTBuilder::build(input); - register_functions(root_node, - FunctionList{ - std::make_pair("foo", - std::make_shared<BuiltinFunctionEmbedder< - std::tuple<uint64_t, double, std::string>(const double&, const std::string&, - const std::vector<TinyMatrix<3>>&, - const TinyVector<2>&)>>( - [](const double& a, const std::string& s, const std::vector<TinyMatrix<3>>& x, - const TinyVector<2>& y) -> std::tuple<uint64_t, double, std::string> { - return std::make_tuple(x.size(), a * y[0] + y[1], s + "_foo"); - }))}); - - auto function_embedder = getBuiltinFunctionEmbedder(*root_node->children[0]); - REQUIRE(function_embedder->getReturnDataType() == ASTNodeDataType::list_t); - REQUIRE(function_embedder->getReturnDataType().contentTypeList().size() == 3); - REQUIRE(*function_embedder->getReturnDataType().contentTypeList()[0] == ASTNodeDataType::unsigned_int_t); - REQUIRE(*function_embedder->getReturnDataType().contentTypeList()[1] == ASTNodeDataType::double_t); - REQUIRE(*function_embedder->getReturnDataType().contentTypeList()[2] == ASTNodeDataType::string_t); - REQUIRE(function_embedder->getParameterDataTypes().size() == 4); - REQUIRE(function_embedder->getParameterDataTypes()[0] == ASTNodeDataType::double_t); - REQUIRE(function_embedder->getParameterDataTypes()[1] == ASTNodeDataType::string_t); - REQUIRE(function_embedder->getParameterDataTypes()[2] == ASTNodeDataType::tuple_t); - REQUIRE(function_embedder->getParameterDataTypes()[2].contentType() == ASTNodeDataType::matrix_t); - REQUIRE(function_embedder->getParameterDataTypes()[2].contentType().numberOfRows() == 3); - REQUIRE(function_embedder->getParameterDataTypes()[3] == ASTNodeDataType::vector_t); - REQUIRE(function_embedder->getParameterDataTypes()[3].dimension() == 2); + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; + auto root_node = ASTBuilder::build(input); + register_functions(root_node, + FunctionList{ + std::make_pair("foo", + std::make_shared<BuiltinFunctionEmbedder< + std::tuple<uint64_t, double, std::string>(const double&, const std::string&, + const std::vector<TinyMatrix<3>>&, + const TinyVector<2>&)>>( + [](const double& a, const std::string& s, + const std::vector<TinyMatrix<3>>& x, + const TinyVector<2>& y) -> std::tuple<uint64_t, double, std::string> { + return std::make_tuple(x.size(), a * y[0] + y[1], s + "_foo"); + }))}); + + auto function_embedder = getBuiltinFunctionEmbedder(*root_node->children[0]); + REQUIRE(function_embedder->getReturnDataType() == ASTNodeDataType::list_t); + REQUIRE(function_embedder->getReturnDataType().contentTypeList().size() == 3); + REQUIRE(*function_embedder->getReturnDataType().contentTypeList()[0] == ASTNodeDataType::unsigned_int_t); + REQUIRE(*function_embedder->getReturnDataType().contentTypeList()[1] == ASTNodeDataType::double_t); + REQUIRE(*function_embedder->getReturnDataType().contentTypeList()[2] == ASTNodeDataType::string_t); + REQUIRE(function_embedder->getParameterDataTypes().size() == 4); + REQUIRE(function_embedder->getParameterDataTypes()[0] == ASTNodeDataType::double_t); + REQUIRE(function_embedder->getParameterDataTypes()[1] == ASTNodeDataType::string_t); + REQUIRE(function_embedder->getParameterDataTypes()[2] == ASTNodeDataType::tuple_t); + REQUIRE(function_embedder->getParameterDataTypes()[2].contentType() == ASTNodeDataType::matrix_t); + REQUIRE(function_embedder->getParameterDataTypes()[2].contentType().numberOfRows() == 3); + REQUIRE(function_embedder->getParameterDataTypes()[2].contentType().numberOfColumns() == 3); + REQUIRE(function_embedder->getParameterDataTypes()[3] == ASTNodeDataType::vector_t); + REQUIRE(function_embedder->getParameterDataTypes()[3].dimension() == 2); + } } SECTION("errors") @@ -1332,7 +1331,7 @@ foo(x); SECTION("R^2: invalid argument list size") { std::string_view data = R"( -foo((1,2,3,4)); +foo(1,2,3,4); )"; std::string error_msg = "no matching function to call foo: Z*Z*Z*Z\n" @@ -1362,7 +1361,7 @@ foo((1,2,3,4)); SECTION("R^3: invalid argument list size") { std::string_view data = R"( -foo((1,2,3,4)); +foo(1,2,3,4); )"; std::string error_msg = "no matching function to call foo: Z*Z*Z*Z\n" @@ -1485,7 +1484,7 @@ foo(x); SECTION("R^2x2: invalid argument list size") { std::string_view data = R"( -foo((1,2,3)); +foo(1,2,3); )"; std::string error_msg = "no matching function to call foo: Z*Z*Z\n" @@ -1515,7 +1514,7 @@ foo((1,2,3)); SECTION("R^3x3: invalid argument list size") { std::string_view data = R"( -foo((1,2,3,4)); +foo(1,2,3,4); )"; std::string error_msg = "no matching function to call foo: Z*Z*Z*Z\n" diff --git a/tests/test_BuiltinFunctionProcessor.cpp b/tests/test_BuiltinFunctionProcessor.cpp index 95d743f1a79ea55ea482b14d63fdb1f7fa49e664..57f5d65ac3b745e65b4d40047db87444b9550b55 100644 --- a/tests/test_BuiltinFunctionProcessor.cpp +++ b/tests/test_BuiltinFunctionProcessor.cpp @@ -376,10 +376,17 @@ runtimeError(); { { std::string_view data = R"( -let x:R, x = tuple_ZtoR((1,2,3-4)); +let x:R, x = tuple_ZtoR((1,2,3,-4)); )"; CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "x", 0.5 * (1 + 2 + 3 - 4)); } + + { + std::string_view data = R"( +let (X,x):(R)*R, (X,x) = R22ToTupleRxR([[1,1], [2,3]]); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "X", (std::vector<double>{1, 1, 2, 3})); + } } SECTION("simple N to tuple args evalation") diff --git a/tests/test_BuiltinFunctionRegister.hpp b/tests/test_BuiltinFunctionRegister.hpp index fc4e6c0d40db70eeb6e9774d2147e121e6b0a89d..bcabd3ad49c5933b938eea5bb6f97b46a6bd4cc4 100644 --- a/tests/test_BuiltinFunctionRegister.hpp +++ b/tests/test_BuiltinFunctionRegister.hpp @@ -168,6 +168,14 @@ class test_BuiltinFunctionRegister std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<2>>&)>>( [](const std::vector<TinyMatrix<2>>&) -> double { return 1; }))); + m_name_builtin_function_map.insert( + std::make_pair("R22ToTupleRxR:(R^2x2)", + std::make_shared< + BuiltinFunctionEmbedder<std::tuple<std::vector<double>, double>(const TinyMatrix<2>&)>>( + [](const TinyMatrix<2>& A) -> std::tuple<std::vector<double>, double> { + return std::make_tuple(std::vector<double>{A(0, 0), A(0, 1), A(1, 0), A(1, 1)}, A(0, 0)); + }))); + m_name_builtin_function_map.insert( std::make_pair("tuple_R33ToR:(R^3x3...)", std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<3>>)>>( diff --git a/tests/test_FunctionArgumentConverter.cpp b/tests/test_FunctionArgumentConverter.cpp index 643fbba0b092d337743ea5aeaedb25eb86a5b445..909353f79e3b7a1dcada5bfcaba92d80967420f6 100644 --- a/tests/test_FunctionArgumentConverter.cpp +++ b/tests/test_FunctionArgumentConverter.cpp @@ -136,6 +136,34 @@ TEST_CASE("FunctionArgumentConverter", "[language]") "unexpected error: cannot convert 'unsigned long' to 'TinyVector<3ul, double>'"); } + SECTION("FunctionTupleArgumentConverter (string tuple)") + { + const TinyVector<3> x3{1.7, 2.9, -3}; + FunctionTupleArgumentConverter<std::string, TinyVector<3>> converter0{0}; + converter0.convert(execution_policy, TinyVector{x3}); + + FunctionTupleArgumentConverter<std::string, std::vector<std::string>> converter1{1}; + converter1.convert(execution_policy, std::vector<std::string>{"foo"}); + + FunctionTupleArgumentConverter<std::string, std::vector<TinyVector<3>>> converter2{2}; + converter2.convert(execution_policy, std::vector<TinyVector<3>>{TinyVector<3>{1, 2, 3}}); + + REQUIRE(std::get<std::vector<std::string>>(execution_policy.currentContext()[0]) == + std::vector<std::string>{[](auto x) { + std::ostringstream os; + os << x; + return os.str(); + }(x3)}); + REQUIRE(std::get<std::vector<std::string>>(execution_policy.currentContext()[1]) == + std::vector<std::string>{"foo"}); + + REQUIRE(std::get<std::vector<std::string>>(execution_policy.currentContext()[2]) == std::vector<std::string>{[]() { + std::ostringstream os; + os << TinyVector<3>{1, 2, 3}; + return os.str(); + }()}); + } + SECTION("FunctionListArgumentConverter") { const uint64_t i = 3;