diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f62e1b88ed577fa73bbfddc6ff2263c8925ec0e..696c96578a15c9303d9135ec8a56089062c2cc30 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -445,7 +445,7 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Coverage") add_custom_target(coverage ALL # in coverage mode we do coverage! - COMMAND ${FASTCOV} -q --gcov "${GCOV_BIN}" + COMMAND ${FASTCOV} --gcov "${GCOV_BIN}" --include "${PUGS_SOURCE_DIR}/src" --exclude "${PUGS_SOURCE_DIR}/src/main.cpp" "${PUGS_SOURCE_DIR}/src/utils/BacktraceManager.*" "${PUGS_SOURCE_DIR}/src/utils/FPEManager.*" "${PUGS_SOURCE_DIR}/src/utils/SignalManager.*" --lcov -o coverage.info -n diff --git a/src/algebra/TinyMatrix.hpp b/src/algebra/TinyMatrix.hpp index 0635f1b32b12d29bbbaa85273d6c49d92aba12ca..bc64d550979c76055c2669b0555679daa5d351b5 100644 --- a/src/algebra/TinyMatrix.hpp +++ b/src/algebra/TinyMatrix.hpp @@ -36,6 +36,24 @@ class [[nodiscard]] TinyMatrix } public: + PUGS_INLINE + constexpr size_t dimension() const + { + return N * N; + } + + PUGS_INLINE + constexpr size_t nbRows() const + { + return N; + } + + PUGS_INLINE + constexpr size_t nbColumns() const + { + return N; + } + PUGS_INLINE constexpr TinyMatrix operator-() const { @@ -103,19 +121,16 @@ class [[nodiscard]] TinyMatrix PUGS_INLINE constexpr friend std::ostream& operator<<(std::ostream& os, const TinyMatrix& A) { - if constexpr (N == 1) { - os << A(0, 0); - } else { - os << '['; - for (size_t i = 0; i < N; ++i) { - os << '(' << A(i, 0); - for (size_t j = 1; j < N; ++j) { - os << ',' << A(i, j); - } - os << ')'; + os << '['; + for (size_t i = 0; i < N; ++i) { + os << '(' << A(i, 0); + for (size_t j = 1; j < N; ++j) { + os << ',' << A(i, j); } - os << ']'; + os << ')'; } + os << ']'; + return os; } diff --git a/src/language/PEGGrammar.hpp b/src/language/PEGGrammar.hpp index 553a2797c8cf227ab88d4b977f9f727b671a4c7a..beae437f0e5d0f9231b51afc0671e29e7ecfef30 100644 --- a/src/language/PEGGrammar.hpp +++ b/src/language/PEGGrammar.hpp @@ -75,11 +75,12 @@ struct string_type : TAO_PEGTL_KEYWORD("string") {}; struct scalar_type : sor< B_set, R_set, Z_set, N_set >{}; struct vector_type : seq< R_set, ignored, one< '^' >, ignored, integer >{}; +struct matrix_type : seq< R_set, ignored, one< '^' >, ignored, integer, ignored, one< 'x' >, ignored, integer >{}; struct basic_type : sor< scalar_type, string_type >{}; struct type_name_id; -struct simple_type_specifier : sor< vector_type, basic_type, type_name_id >{}; +struct simple_type_specifier : sor< matrix_type, vector_type, basic_type, type_name_id >{}; struct tuple_type_specifier : sor<try_catch< open_parent, simple_type_specifier, ignored, close_parent >, // non matching braces management @@ -180,7 +181,7 @@ struct postfix_operator : seq< sor< post_plusplus, post_minusminus>, ignored > { struct open_bracket : seq< one< '[' >, ignored > {}; struct close_bracket : seq< one< ']' >, ignored > {}; -struct subscript_expression : if_must< open_bracket, expression, close_bracket >{}; +struct subscript_expression : if_must< open_bracket, list_must<expression, COMMA>, close_bracket >{}; struct postfix_expression : seq< primary_expression, star< sor< subscript_expression , postfix_operator> > >{}; diff --git a/src/language/ast/ASTBuilder.cpp b/src/language/ast/ASTBuilder.cpp index 35d787354c37e5b1dbc3f176548fb4725d924cbb..ccf0f42db5c4144ef91475d3fa9c3c24520d1e35 100644 --- a/src/language/ast/ASTBuilder.cpp +++ b/src/language/ast/ASTBuilder.cpp @@ -108,21 +108,19 @@ struct ASTBuilder::simplify_unary : parse_tree::apply<ASTBuilder::simplify_unary } if (n->is_type<language::unary_expression>() or n->is_type<language::name_subscript_expression>()) { - const size_t child_nb = n->children.size(); - if (child_nb > 1) { + if (n->children.size() > 1) { if (n->children[1]->is_type<language::subscript_expression>()) { - auto expression = std::move(n->children[0]); - for (size_t i = 0; i < child_nb - 1; ++i) { - n->children[i] = std::move(n->children[i + 1]); - } + std::swap(n->children[0], n->children[1]); - auto& array_subscript_expression = n->children[0]; + n->children[0]->emplace_back(std::move(n->children[1])); n->children.pop_back(); - array_subscript_expression->children.emplace_back(std::move(expression)); - std::swap(array_subscript_expression->children[0], array_subscript_expression->children[1]); - - array_subscript_expression->m_begin = array_subscript_expression->children[0]->m_begin; + auto& array_subscript_expression = n->children[0]; + const size_t child_nb = array_subscript_expression->children.size(); + for (size_t i = 1; i < array_subscript_expression->children.size(); ++i) { + std::swap(array_subscript_expression->children[child_nb - i], + array_subscript_expression->children[child_nb - i - 1]); + } transform(n, st...); } @@ -247,6 +245,7 @@ using selector = parse_tree::selector< language::type_name_id, language::tuple_expression, language::vector_type, + language::matrix_type, language::string_type, language::cout_kw, language::cerr_kw, diff --git a/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp b/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp index 6e712e3d84c0eb6d52064fdfecabd9dd6f9016db..884711c081ded349d32663f80180c06c4911bfba 100644 --- a/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp @@ -2,8 +2,8 @@ #include <algebra/TinyVector.hpp> #include <language/PEGGrammar.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/AffectationMangler.hpp> #include <language/utils/OperatorRepository.hpp> #include <language/utils/ParseError.hpp> @@ -30,17 +30,6 @@ ASTNodeAffectationExpressionBuilder::ASTNodeAffectationExpressionBuilder(ASTNode } }(); - // Special treatment dedicated to R^1 to be able to initialize them - if (((target_data_type != source_data_type) and (target_data_type == ASTNodeDataType::vector_t) and - (target_data_type.dimension() == 1)) or - // Special treatment for R^d vectors and operator *= - ((target_data_type == ASTNodeDataType::vector_t) and (source_data_type != ASTNodeDataType::vector_t) and - n.is_type<language::multiplyeq_op>())) { - ASTNodeNaturalConversionChecker{*n.children[1], ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - } else { - ASTNodeNaturalConversionChecker{*n.children[1], target_data_type}; - } - const auto& optional_processor_builder = OperatorRepository::instance().getAffectationProcessorBuilder(affectation_name); diff --git a/src/language/ast/ASTNodeArraySubscriptExpressionBuilder.cpp b/src/language/ast/ASTNodeArraySubscriptExpressionBuilder.cpp index 350a4070e9a6bc65906627d54b2e2ed87c403cc6..e717044d1dea5fe6b85aec6c41941699c1e5ef65 100644 --- a/src/language/ast/ASTNodeArraySubscriptExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeArraySubscriptExpressionBuilder.cpp @@ -1,5 +1,6 @@ #include <language/ast/ASTNodeArraySubscriptExpressionBuilder.hpp> +#include <algebra/TinyMatrix.hpp> #include <algebra/TinyVector.hpp> #include <language/node_processor/ArraySubscriptProcessor.hpp> #include <language/utils/ParseError.hpp> @@ -27,6 +28,27 @@ ASTNodeArraySubscriptExpressionBuilder::ASTNodeArraySubscriptExpressionBuilder(A break; } } + } else if (array_expression.m_data_type == ASTNodeDataType::matrix_t) { + Assert(array_expression.m_data_type.nbRows() == array_expression.m_data_type.nbColumns()); + + switch (array_expression.m_data_type.nbRows()) { + case 1: { + node.m_node_processor = std::make_unique<ArraySubscriptProcessor<TinyMatrix<1>>>(node); + break; + } + case 2: { + node.m_node_processor = std::make_unique<ArraySubscriptProcessor<TinyMatrix<2>>>(node); + break; + } + case 3: { + node.m_node_processor = std::make_unique<ArraySubscriptProcessor<TinyMatrix<3>>>(node); + break; + } + default: { + throw ParseError("unexpected error: invalid array dimension", array_expression.begin()); + break; + } + } } else { throw ParseError("unexpected error: invalid array type", array_expression.begin()); } diff --git a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp index 6b1fe91ec1006600bf601c737719869497041e2b..248511e6c91c2eec08a94e87faa8179c06a0bf80 100644 --- a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -2,8 +2,8 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNodeDataTypeFlattener.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> #include <language/node_processor/BuiltinFunctionProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/ParseError.hpp> #include <language/utils/SymbolTable.hpp> @@ -112,6 +112,84 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } }; + auto get_function_argument_converter_for_matrix = + [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> { + using ParameterT = std::decay_t<decltype(parameter_v)>; + + if constexpr (std::is_same_v<ParameterT, TinyMatrix<1>>) { + switch (argument_node_sub_data_type.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((argument_node_sub_data_type.m_data_type.nbRows() == 1) and + (argument_node_sub_data_type.m_data_type.nbColumns() == 1)) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(argument_number); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::bool_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, bool>>(argument_number); + } + case ASTNodeDataType::int_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, int64_t>>(argument_number); + } + case ASTNodeDataType::unsigned_int_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, uint64_t>>(argument_number); + } + case ASTNodeDataType::double_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, double>>(argument_number); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid argument type", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } else { + switch (argument_node_sub_data_type.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((argument_node_sub_data_type.m_data_type.nbRows() == parameter_v.nbRows()) and + (argument_node_sub_data_type.m_data_type.nbColumns() == parameter_v.nbColumns())) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(argument_number); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::list_t: { + if (argument_node_sub_data_type.m_parent_node.children.size() == + (parameter_v.nbRows() * parameter_v.nbColumns())) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(argument_number); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if (argument_node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if (std::stoi(argument_node_sub_data_type.m_parent_node.string()) == 0) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ZeroType>>(argument_number); + } + } + [[fallthrough]]; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid argument type", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } + }; + auto get_function_argument_to_string_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> { return std::make_unique<FunctionArgumentToStringConverter>(argument_number); }; @@ -195,6 +273,25 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } // LCOV_EXCL_STOP } + } + case ASTNodeDataType::matrix_t: { + Assert(arg_data_type.nbRows() == arg_data_type.nbColumns()); + switch (arg_data_type.nbRows()) { + case 1: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyMatrix<1>>>(argument_number); + } + case 2: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyMatrix<2>>>(argument_number); + } + case 3: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyMatrix<3>>>(argument_number); + } + // LCOV_EXCL_START + default: { + throw UnexpectedError(dataTypeName(arg_data_type) + " unexpected dimension of vector"); + } + // LCOV_EXCL_STOP + } } // LCOV_EXCL_START default: { @@ -254,6 +351,26 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData // LCOV_EXCL_STOP } } + case ASTNodeDataType::matrix_t: { + Assert(parameter_type.nbRows() == parameter_type.nbColumns()); + switch (parameter_type.nbRows()) { + case 1: { + return get_function_argument_converter_for_matrix(TinyMatrix<1>{}); + } + case 2: { + return get_function_argument_converter_for_matrix(TinyMatrix<2>{}); + } + case 3: { + return get_function_argument_converter_for_matrix(TinyMatrix<3>{}); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: undefined parameter type for function", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::string_t: { return get_function_argument_to_string_converter(); } @@ -300,6 +417,27 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData // LCOV_EXCL_STOP } } + case ASTNodeDataType::matrix_t: { + Assert(parameter_type.contentType().nbRows() == parameter_type.contentType().nbColumns()); + switch (parameter_type.contentType().nbRows()) { + case 1: { + return get_function_argument_to_tuple_converter(TinyMatrix<1>{}); + } + case 2: { + return get_function_argument_to_tuple_converter(TinyMatrix<2>{}); + } + case 3: { + return get_function_argument_to_tuple_converter(TinyMatrix<3>{}); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: unexpected tuple content for function: '" + dataTypeName(parameter_type) + + "'", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::string_t: { return get_function_argument_to_string_converter(); } @@ -320,13 +458,7 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } }; - if (parameter_type == ASTNodeDataType::vector_t and parameter_type.dimension() == 1) { - if (not isNaturalConversion(argument_node_sub_data_type.m_data_type, parameter_type)) { - ASTNodeNaturalConversionChecker{argument_node_sub_data_type, ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - } - } else { - ASTNodeNaturalConversionChecker{argument_node_sub_data_type, parameter_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{argument_node_sub_data_type, parameter_type}; return get_function_argument_converter_for_argument_type(); } diff --git a/src/language/ast/ASTNodeDataTypeBuilder.cpp b/src/language/ast/ASTNodeDataTypeBuilder.cpp index 7c93b50527bbae5c352d62159e40d0eba3bb7aad..018bc94f74160bfcd680df08ac9b2a08c54ba492 100644 --- a/src/language/ast/ASTNodeDataTypeBuilder.cpp +++ b/src/language/ast/ASTNodeDataTypeBuilder.cpp @@ -1,7 +1,7 @@ #include <language/ast/ASTNodeDataTypeBuilder.hpp> #include <language/PEGGrammar.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> #include <language/utils/OperatorRepository.hpp> #include <language/utils/ParseError.hpp> @@ -41,6 +41,8 @@ ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNo data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } else if (type_node.is_type<language::vector_type>()) { data_type = getVectorDataType(type_node); + } else if (type_node.is_type<language::matrix_type>()) { + data_type = getMatrixDataType(type_node); } else if (type_node.is_type<language::tuple_type_specifier>()) { const auto& content_node = type_node.children[0]; @@ -70,6 +72,8 @@ ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNo content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } else if (content_node->is_type<language::vector_type>()) { content_node->m_data_type = getVectorDataType(*type_node.children[0]); + } else if (content_node->is_type<language::matrix_type>()) { + content_node->m_data_type = getMatrixDataType(*type_node.children[0]); } else if (content_node->is_type<language::string_type>()) { content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } else { @@ -151,6 +155,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); } else if (n.is_type<language::vector_type>()) { n.m_data_type = getVectorDataType(n); + } else if (n.is_type<language::matrix_type>()) { + n.m_data_type = getMatrixDataType(n); } else if (n.is_type<language::literal>()) { n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); @@ -250,6 +256,14 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const << dataTypeName(image_type) << ", found " << nb_image_expressions << " scalar expressions"; throw ParseError(message.str(), image_domain_node.begin()); } + } else if (image_domain_node.is_type<language::matrix_type>()) { + ASTNodeDataType image_type = getMatrixDataType(image_domain_node); + if (image_type.nbRows() * image_type.nbColumns() != nb_image_expressions) { + std::ostringstream message; + message << "expecting " << image_type.nbRows() * image_type.nbColumns() << " scalar expressions or an " + << dataTypeName(image_type) << ", found " << nb_image_expressions << " scalar expressions"; + throw ParseError(message.str(), image_domain_node.begin()); + } } else { std::ostringstream message; message << "number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow @@ -315,6 +329,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const value_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } else if (image_node.is_type<language::vector_type>()) { value_type = getVectorDataType(image_node); + } else if (image_node.is_type<language::matrix_type>()) { + value_type = getMatrixDataType(image_node); } else if (image_node.is_type<language::string_type>()) { value_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } @@ -515,20 +531,37 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const throw ParseError(message.str(), n.begin()); } } else if (n.is_type<language::subscript_expression>()) { - Assert(n.children.size() == 2, "invalid number of sub-expressions in array subscript expression"); auto& array_expression = *n.children[0]; - auto& index_expression = *n.children[1]; - ASTNodeNaturalConversionChecker{index_expression, ASTNodeDataType::build<ASTNodeDataType::int_t>()}; - if (array_expression.m_data_type != ASTNodeDataType::vector_t) { + if (array_expression.m_data_type == ASTNodeDataType::vector_t) { + auto& index_expression = *n.children[1]; + ASTNodeNaturalConversionChecker{index_expression, ASTNodeDataType::build<ASTNodeDataType::int_t>()}; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + if (n.children.size() != 2) { + std::ostringstream message; + message << "invalid index type: " << rang::fgB::yellow << dataTypeName(array_expression.m_data_type) + << rang::style::reset << " requires a single integer"; + throw ParseError(message.str(), index_expression.begin()); + } + } else if (array_expression.m_data_type == ASTNodeDataType::matrix_t) { + for (size_t i = 1; i < n.children.size(); ++i) { + ASTNodeNaturalConversionChecker{*n.children[i], ASTNodeDataType::build<ASTNodeDataType::int_t>()}; + } + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + if (n.children.size() != 3) { + std::ostringstream message; + message << "invalid index type: " << rang::fgB::yellow << dataTypeName(n.children[0]->m_data_type) + << rang::style::reset << " requires two integers"; + throw ParseError(message.str(), n.children[1]->begin()); + } + + } else { std::ostringstream message; - message << "invalid types '" << rang::fgB::yellow << dataTypeName(array_expression.m_data_type) - << rang::style::reset << '[' << dataTypeName(index_expression.m_data_type) << ']' - << "' for array subscript"; + message << "invalid subscript expression: " << rang::fgB::yellow << dataTypeName(array_expression.m_data_type) + << rang::style::reset << " cannot be indexed"; throw ParseError(message.str(), n.begin()); - } else { - n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } } else if (n.is_type<language::B_set>()) { n.m_data_type = @@ -547,6 +580,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ASTNodeDataType::build<ASTNodeDataType::typename_t>(ASTNodeDataType::build<ASTNodeDataType::double_t>()); } else if (n.is_type<language::vector_type>()) { n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::typename_t>(getVectorDataType(n)); + } else if (n.is_type<language::matrix_type>()) { + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::typename_t>(getMatrixDataType(n)); } else if (n.is_type<language::name_list>() or n.is_type<language::lvalue_list>() or n.is_type<language::function_argument_list>() or n.is_type<language::expression_list>()) { std::vector<std::shared_ptr<const ASTNodeDataType>> sub_data_type_list; diff --git a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp index af8b1cbc0b4ca138f86601755493e90a089ff0c2..6484b85b0fed5626df7f46ade79c479cfe5f23ae 100644 --- a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp @@ -2,11 +2,13 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNodeDataTypeFlattener.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> #include <language/node_processor/FunctionProcessor.hpp> +#include <language/node_processor/TupleToTinyMatrixProcessor.hpp> #include <language/node_processor/TupleToTinyVectorProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/FunctionTable.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/Exceptions.hpp> template <typename SymbolType> std::unique_ptr<IFunctionArgumentConverter> @@ -83,6 +85,48 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy } }; + auto get_function_argument_converter_for_matrix = + [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> { + using ParameterT = std::decay_t<decltype(parameter_v)>; + switch (node_sub_data_type.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((node_sub_data_type.m_data_type.nbRows() == parameter_v.nbRows()) and + (node_sub_data_type.m_data_type.nbColumns() == parameter_v.nbColumns())) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::list_t: { + if (node_sub_data_type.m_parent_node.children.size() == parameter_v.dimension()) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ZeroType>>(parameter_id); + } + } + [[fallthrough]]; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid argument type", + std::vector{node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + }; + auto get_function_argument_converter_for_string = [&]() -> std::unique_ptr<IFunctionArgumentConverter> { return std::make_unique<FunctionArgumentToStringConverter>(parameter_id); }; @@ -116,9 +160,28 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy return get_function_argument_converter_for_vector(TinyVector<3>{}); } } - [[fallthrough]]; + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined parameter type", std::vector{m_node.begin()}); + // LCOV_EXCL_STOP } + case ASTNodeDataType::matrix_t: { + Assert(parameter_symbol.attributes().dataType().nbRows() == parameter_symbol.attributes().dataType().nbColumns()); + switch (parameter_symbol.attributes().dataType().nbRows()) { + case 1: { + return get_function_argument_converter_for_matrix(TinyMatrix<1>{}); + } + case 2: { + return get_function_argument_converter_for_matrix(TinyMatrix<2>{}); + } + case 3: { + return get_function_argument_converter_for_matrix(TinyMatrix<3>{}); + } + } + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined parameter type", std::vector{m_node.begin()}); + // LCOV_EXCL_STOP + } // LCOV_EXCL_START default: { throw ParseError("unexpected error: undefined parameter type", std::vector{m_node.begin()}); @@ -268,6 +331,51 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType& r } }; + auto get_function_processor_for_expression_matrix = [&](const auto& return_v) -> std::unique_ptr<INodeProcessor> { + using ReturnT = std::decay_t<decltype(return_v)>; + switch (function_component_expression.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((function_component_expression.m_data_type.nbRows() == return_v.nbRows()) and + (function_component_expression.m_data_type.nbColumns() == return_v.nbColumns())) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, ReturnT>>(function_component_expression); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid dimension for returned vector", + std::vector{function_component_expression.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::list_t: { + if (function_component_expression.children.size() == return_v.dimension()) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, AggregateDataVariant>>( + function_component_expression); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid dimension for returned vector", + std::vector{function_component_expression.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if (function_component_expression.is_type<language::integer>()) { + if (std::stoi(function_component_expression.string()) == 0) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, ZeroType>>(function_component_expression); + } + } + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined expression value type for function", + std::vector{function_component_expression.begin()}); + // LCOV_EXCL_STOP + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: undefined expression value type for function", + std::vector{function_component_expression.begin()}); + } + // LCOV_EXCL_STOP + } + }; + auto get_function_processor_for_value = [&]() { switch (return_value_type) { case ASTNodeDataType::bool_t: { @@ -304,6 +412,30 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType& r // LCOV_EXCL_STOP } } + case ASTNodeDataType::matrix_t: { + Assert(return_value_type.nbRows() == return_value_type.nbColumns()); + + switch (return_value_type.nbRows()) { + case 1: { + if (function_component_expression.m_data_type == ASTNodeDataType::matrix_t) { + return get_function_processor_for_expression_matrix(TinyMatrix<1>{}); + } else { + return get_function_processor_for_expression_value(TinyMatrix<1>{}); + } + } + case 2: { + return get_function_processor_for_expression_matrix(TinyMatrix<2>{}); + } + case 3: { + return get_function_processor_for_expression_matrix(TinyMatrix<3>{}); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid dimension in returned type", std::vector{node.begin()}); + } + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::string_t: { return get_function_processor_for_expression_value(std::string{}); } @@ -335,11 +467,7 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node const ASTNodeDataType return_value_type = image_domain_node.m_data_type.contentType(); - if ((return_value_type == ASTNodeDataType::vector_t) and (return_value_type.dimension() == 1)) { - ASTNodeNaturalConversionChecker{expression_node, ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - } else { - ASTNodeNaturalConversionChecker{expression_node, return_value_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{expression_node, return_value_type}; function_processor->addFunctionExpressionProcessor( this->_getFunctionProcessor(return_value_type, node, expression_node)); @@ -354,11 +482,8 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node if (function_image_domain.is_type<language::vector_type>()) { ASTNodeDataType vector_type = getVectorDataType(function_image_domain); - if ((vector_type.dimension() == 1) and (function_expression.m_data_type != ASTNodeDataType::vector_t)) { - ASTNodeNaturalConversionChecker{function_expression, ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - } else { - ASTNodeNaturalConversionChecker{function_expression, vector_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{function_expression, vector_type}; + if (function_expression.is_type<language::expression_list>()) { Assert(vector_type.dimension() == function_expression.children.size()); @@ -421,6 +546,73 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node node.m_node_processor = std::move(function_processor); } + } else if (function_image_domain.is_type<language::matrix_type>()) { + ASTNodeDataType matrix_type = getMatrixDataType(function_image_domain); + + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{function_expression, matrix_type}; + + if (function_expression.is_type<language::expression_list>()) { + Assert(matrix_type.nbRows() * matrix_type.nbColumns() == function_expression.children.size()); + + for (size_t i = 0; i < matrix_type.nbRows() * matrix_type.nbColumns(); ++i) { + function_processor->addFunctionExpressionProcessor( + this->_getFunctionProcessor(ASTNodeDataType::build<ASTNodeDataType::double_t>(), node, + *function_expression.children[i])); + } + + switch (matrix_type.nbRows()) { + case 2: { + node.m_node_processor = + std::make_unique<TupleToTinyMatrixProcessor<FunctionProcessor, 2>>(node, std::move(function_processor)); + break; + } + case 3: { + node.m_node_processor = + std::make_unique<TupleToTinyMatrixProcessor<FunctionProcessor, 3>>(node, std::move(function_processor)); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid matrix_t dimensions", std::vector{node.begin()}); + } + // LCOV_EXCL_STOP + } + } else if (function_expression.is_type<language::integer>()) { + if (std::stoi(function_expression.string()) == 0) { + switch (matrix_type.nbRows()) { + case 1: { + node.m_node_processor = + std::make_unique<FunctionExpressionProcessor<TinyMatrix<1>, ZeroType>>(function_expression); + break; + } + case 2: { + node.m_node_processor = + std::make_unique<FunctionExpressionProcessor<TinyMatrix<2>, ZeroType>>(function_expression); + break; + } + case 3: { + node.m_node_processor = + std::make_unique<FunctionExpressionProcessor<TinyMatrix<3>, ZeroType>>(function_expression); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("invalid matrix dimensions"); + } + // LCOV_EXCL_STOP + } + } else { + // LCOV_EXCL_START + throw UnexpectedError("expecting 0"); + // LCOV_EXCL_STOP + } + } else { + function_processor->addFunctionExpressionProcessor( + this->_getFunctionProcessor(matrix_type, node, function_expression)); + + node.m_node_processor = std::move(function_processor); + } + } else { if (function_expression.is_type<language::expression_list>()) { ASTNode& image_domain_node = function_image_domain; diff --git a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp index 9da8319fa5119fa22b3d6ce95e6e1e1084b401f2..b135bbcd3f67127c23c044880d96ad28ee6527c5 100644 --- a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp @@ -2,8 +2,8 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNodeDataTypeFlattener.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> #include <language/node_processor/AffectationProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/ParseError.hpp> template <typename OperatorT> @@ -77,6 +77,44 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } }; + auto add_affectation_processor_for_matrix_data = [&](const auto& value, + const ASTNodeSubDataType& node_sub_data_type) { + using ValueT = std::decay_t<decltype(value)>; + if constexpr (std::is_same_v<ValueT, TinyMatrix<1>>) { + if ((node_sub_data_type.m_data_type == ASTNodeDataType::matrix_t) and + (node_sub_data_type.m_data_type.nbRows() == value.nbRows()) and + (node_sub_data_type.m_data_type.nbColumns() == value.nbColumns())) { + list_affectation_processor->template add<ValueT, ValueT>(value_node); + } else { + add_affectation_processor_for_data(value, node_sub_data_type); + } + } else if constexpr (std::is_same_v<ValueT, TinyMatrix<2>> or std::is_same_v<ValueT, TinyMatrix<3>>) { + if ((node_sub_data_type.m_data_type == ASTNodeDataType::matrix_t) and + (node_sub_data_type.m_data_type.nbRows() == value.nbRows()) and + (node_sub_data_type.m_data_type.nbColumns() == value.nbColumns())) { + list_affectation_processor->template add<ValueT, ValueT>(value_node); + } else if ((node_sub_data_type.m_data_type == ASTNodeDataType::list_t) and + (node_sub_data_type.m_parent_node.children.size() == value.nbRows() * value.nbColumns())) { + list_affectation_processor->template add<ValueT, AggregateDataVariant>(value_node); + } else if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { + list_affectation_processor->template add<ValueT, ZeroType>(value_node); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid operand value", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid dimension", std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } else { + throw ParseError("unexpected error: invalid value type", std::vector{node_sub_data_type.m_parent_node.begin()}); + } + }; + auto add_affectation_processor_for_string_data = [&](const ASTNodeSubDataType& node_sub_data_type) { if constexpr (std::is_same_v<OperatorT, language::eq_op>) { switch (node_sub_data_type.m_data_type) { @@ -176,6 +214,29 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } break; } + case ASTNodeDataType::matrix_t: { + Assert(value_type.nbRows() == value_type.nbColumns()); + switch (value_type.nbRows()) { + case 1: { + add_affectation_processor_for_matrix_data(TinyMatrix<1>{}, node_sub_data_type); + break; + } + case 2: { + add_affectation_processor_for_matrix_data(TinyMatrix<2>{}, node_sub_data_type); + break; + } + case 3: { + add_affectation_processor_for_matrix_data(TinyMatrix<3>{}, node_sub_data_type); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("invalid dimension", std::vector{value_node.begin()}); + } + // LCOV_EXCL_STOP + } + break; + } case ASTNodeDataType::string_t: { add_affectation_processor_for_string_data(node_sub_data_type); break; @@ -188,12 +249,7 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } }; - if ((value_node.m_data_type != rhs_node_sub_data_type.m_data_type) and - (value_node.m_data_type == ASTNodeDataType::vector_t) and (value_node.m_data_type.dimension() == 1)) { - ASTNodeNaturalConversionChecker{rhs_node_sub_data_type, ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - } else { - ASTNodeNaturalConversionChecker{rhs_node_sub_data_type, value_node.m_data_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>(rhs_node_sub_data_type, value_node.m_data_type); add_affectation_processor_for_value(value_node.m_data_type, rhs_node_sub_data_type); } diff --git a/src/language/ast/CMakeLists.txt b/src/language/ast/CMakeLists.txt index 458690fbd4fa341b257fcfe997e489048d1f961c..a901d7c09679c37d54420f51d347537fe5e18aef 100644 --- a/src/language/ast/CMakeLists.txt +++ b/src/language/ast/CMakeLists.txt @@ -9,7 +9,6 @@ add_library(PugsLanguageAST ASTNodeBuiltinFunctionExpressionBuilder.cpp ASTNodeDataTypeBuilder.cpp ASTNodeDataTypeChecker.cpp -# ASTNodeDataType.cpp ASTNodeDataTypeFlattener.cpp ASTNodeDeclarationToAffectationConverter.cpp ASTNodeEmptyBlockCleaner.cpp @@ -19,7 +18,6 @@ add_library(PugsLanguageAST ASTNodeIncDecExpressionBuilder.cpp ASTNodeJumpPlacementChecker.cpp ASTNodeListAffectationExpressionBuilder.cpp - ASTNodeNaturalConversionChecker.cpp ASTNodeUnaryOperatorExpressionBuilder.cpp ASTSymbolInitializationChecker.cpp ASTSymbolTableBuilder.cpp diff --git a/src/language/node_processor/AffectationProcessor.hpp b/src/language/node_processor/AffectationProcessor.hpp index 88d4488365427aecbffd48abf9e37a3b08e55707..959578619b023bc28236b1ae52141b018aad8c69 100644 --- a/src/language/node_processor/AffectationProcessor.hpp +++ b/src/language/node_processor/AffectationProcessor.hpp @@ -125,20 +125,40 @@ class AffectationExecutor final : public IAffectationExecutor m_lhs = std::get<DataT>(rhs); } else if constexpr (std::is_same_v<DataT, AggregateDataVariant>) { const AggregateDataVariant& v = std::get<AggregateDataVariant>(rhs); - static_assert(is_tiny_vector_v<ValueT>, "expecting lhs TinyVector"); - for (size_t i = 0; i < m_lhs.dimension(); ++i) { - std::visit( - [&](auto&& vi) { - using Vi_T = std::decay_t<decltype(vi)>; - if constexpr (std::is_convertible_v<Vi_T, double>) { - m_lhs[i] = vi; - } else { - // LCOV_EXCL_START - throw UnexpectedError("unexpected rhs type in affectation"); - // LCOV_EXCL_STOP - } - }, - v[i]); + if constexpr (is_tiny_vector_v<ValueT>) { + for (size_t i = 0; i < m_lhs.dimension(); ++i) { + std::visit( + [&](auto&& vi) { + using Vi_T = std::decay_t<decltype(vi)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + m_lhs[i] = vi; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + v[i]); + } + } else if constexpr (is_tiny_matrix_v<ValueT>) { + for (size_t i = 0, l = 0; i < m_lhs.nbRows(); ++i) { + for (size_t j = 0; j < m_lhs.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using Aij_T = std::decay_t<decltype(Aij)>; + if constexpr (std::is_convertible_v<Aij_T, double>) { + m_lhs(i, j) = Aij; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + v[l]); + } + } + } else { + static_assert(is_tiny_matrix_v<ValueT> or is_tiny_vector_v<ValueT>, "invalid rhs type"); } } else if constexpr (std::is_same_v<TinyVector<1>, ValueT>) { std::visit( @@ -153,6 +173,21 @@ class AffectationExecutor final : public IAffectationExecutor } }, rhs); + } else if constexpr (std::is_same_v<TinyMatrix<1>, ValueT>) { + std::visit( + [&](auto&& v) { + using Vi_T = std::decay_t<decltype(v)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + m_lhs = v; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + rhs); + } else { + throw UnexpectedError("invalid value type"); } } else { AffOp<OperatorT>().eval(m_lhs, std::get<DataT>(rhs)); @@ -168,7 +203,93 @@ class AffectationExecutor final : public IAffectationExecutor }; template <typename OperatorT, typename ArrayT, typename ValueT, typename DataT> -class ComponentAffectationExecutor final : public IAffectationExecutor +class MatrixComponentAffectationExecutor final : public IAffectationExecutor +{ + private: + ArrayT& m_lhs_array; + ASTNode& m_index0_expression; + ASTNode& m_index1_expression; + + static inline const bool m_is_defined{[] { + if constexpr (not std::is_same_v<typename ArrayT::data_type, ValueT>) { + return false; + } else if constexpr (std::is_same_v<std::decay_t<ValueT>, bool>) { + if constexpr (not std::is_same_v<OperatorT, language::eq_op>) { + return false; + } + } + return true; + }()}; + + public: + MatrixComponentAffectationExecutor(ASTNode& node, + ArrayT& lhs_array, + ASTNode& index0_expression, + ASTNode& index1_expression) + : m_lhs_array{lhs_array}, m_index0_expression{index0_expression}, m_index1_expression{index1_expression} + { + // LCOV_EXCL_START + if constexpr (not m_is_defined) { + throw ParseError("unexpected error: invalid operands to affectation expression", std::vector{node.begin()}); + } + // LCOV_EXCL_STOP + } + + PUGS_INLINE void + affect(ExecutionPolicy& exec_policy, DataVariant&& rhs) + { + if constexpr (m_is_defined) { + auto get_index_value = [&](DataVariant&& value_variant) -> int64_t { + int64_t index_value = 0; + std::visit( + [&](auto&& value) { + using IndexValueT = std::decay_t<decltype(value)>; + if constexpr (std::is_integral_v<IndexValueT>) { + index_value = value; + } else { + // LCOV_EXCL_START + throw UnexpectedError("invalid index type"); + // LCOV_EXCL_STOP + } + }, + value_variant); + return index_value; + }; + + const int64_t index0_value = get_index_value(m_index0_expression.execute(exec_policy)); + const int64_t index1_value = get_index_value(m_index1_expression.execute(exec_policy)); + + if constexpr (std::is_same_v<ValueT, std::string>) { + if constexpr (std::is_same_v<OperatorT, language::eq_op>) { + if constexpr (std::is_same_v<std::string, DataT>) { + m_lhs_array(index0_value, index1_value) = std::get<DataT>(rhs); + } else { + m_lhs_array(index0_value, index1_value) = std::to_string(std::get<DataT>(rhs)); + } + } else { + if constexpr (std::is_same_v<std::string, DataT>) { + m_lhs_array(index0_value, index1_value) += std::get<std::string>(rhs); + } else { + m_lhs_array(index0_value, index1_value) += std::to_string(std::get<DataT>(rhs)); + } + } + } else { + if constexpr (std::is_same_v<OperatorT, language::eq_op>) { + if constexpr (std::is_same_v<ValueT, DataT>) { + m_lhs_array(index0_value, index1_value) = std::get<DataT>(rhs); + } else { + m_lhs_array(index0_value, index1_value) = static_cast<ValueT>(std::get<DataT>(rhs)); + } + } else { + AffOp<OperatorT>().eval(m_lhs_array(index0_value, index1_value), std::get<DataT>(rhs)); + } + } + } + } +}; + +template <typename OperatorT, typename ArrayT, typename ValueT, typename DataT> +class VectorComponentAffectationExecutor final : public IAffectationExecutor { private: ArrayT& m_lhs_array; @@ -186,7 +307,7 @@ class ComponentAffectationExecutor final : public IAffectationExecutor }()}; public: - ComponentAffectationExecutor(ASTNode& node, ArrayT& lhs_array, ASTNode& index_expression) + VectorComponentAffectationExecutor(ASTNode& node, ArrayT& lhs_array, ASTNode& index_expression) : m_lhs_array{lhs_array}, m_index_expression{index_expression} { // LCOV_EXCL_START @@ -289,53 +410,99 @@ class AffectationProcessor final : public INodeProcessor Assert(found); DataVariant& value = i_symbol->attributes().value(); - // LCOV_EXCL_START - if (array_expression.m_data_type != ASTNodeDataType::vector_t) { - throw ParseError("unexpected error: invalid lhs (expecting R^d)", - std::vector{array_subscript_expression.begin()}); - } - // LCOV_EXCL_STOP + if (array_expression.m_data_type == ASTNodeDataType::vector_t) { + Assert(array_subscript_expression.children.size() == 2); - auto& index_expression = *array_subscript_expression.children[1]; + auto& index_expression = *array_subscript_expression.children[1]; - switch (array_expression.m_data_type.dimension()) { - case 1: { - using ArrayTypeT = TinyVector<1>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + switch (array_expression.m_data_type.dimension()) { + case 1: { + using ArrayTypeT = TinyVector<1>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = + std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor = - std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); - break; - } - case 2: { - using ArrayTypeT = TinyVector<2>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + case 2: { + using ArrayTypeT = TinyVector<2>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = + std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor = - std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); - break; - } - case 3: { - using ArrayTypeT = TinyVector<3>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + case 3: { + using ArrayTypeT = TinyVector<3>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = + std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor = - std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); - break; - } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid vector dimension", + std::vector{array_subscript_expression.begin()}); + } + // LCOV_EXCL_STOP + } + } else if (array_expression.m_data_type == ASTNodeDataType::matrix_t) { + Assert(array_subscript_expression.children.size() == 3); + Assert(array_expression.m_data_type.nbRows() == array_expression.m_data_type.nbColumns()); + + auto& index0_expression = *array_subscript_expression.children[1]; + auto& index1_expression = *array_subscript_expression.children[2]; + + switch (array_expression.m_data_type.nbRows()) { + case 1: { + using ArrayTypeT = TinyMatrix<1>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), + index0_expression, index1_expression); + break; + } + case 2: { + using ArrayTypeT = TinyMatrix<2>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), + index0_expression, index1_expression); + break; + } + case 3: { + using ArrayTypeT = TinyMatrix<3>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), + index0_expression, index1_expression); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid vector dimension", + std::vector{array_subscript_expression.begin()}); + } + // LCOV_EXCL_STOP + } + } else { // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid vector dimension", std::vector{array_subscript_expression.begin()}); - } + throw UnexpectedError("invalid subscript expression"); // LCOV_EXCL_STOP } - } else { // LCOV_EXCL_START throw ParseError("unexpected error: invalid lhs", std::vector{node.children[0]->begin()}); @@ -391,6 +558,55 @@ class AffectationToTinyVectorFromListProcessor final : public INodeProcessor } }; +template <typename OperatorT, typename ValueT> +class AffectationToTinyMatrixFromListProcessor final : public INodeProcessor +{ + private: + ASTNode& m_node; + + DataVariant* m_lhs; + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + AggregateDataVariant children_values = std::get<AggregateDataVariant>(m_node.children[1]->execute(exec_policy)); + + static_assert(std::is_same_v<OperatorT, language::eq_op>, "forbidden affection operator for list to vectors"); + + ValueT v; + for (size_t i = 0, l = 0; i < v.nbRows(); ++i) { + for (size_t j = 0; j < v.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& child_value) { + using T = std::decay_t<decltype(child_value)>; + if constexpr (std::is_same_v<T, bool> or std::is_same_v<T, uint64_t> or std::is_same_v<T, int64_t> or + std::is_same_v<T, double>) { + v(i, j) = child_value; + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", m_node.begin()); + // LCOV_EXCL_STOP + } + }, + children_values[l]); + } + } + + *m_lhs = v; + return {}; + } + + AffectationToTinyMatrixFromListProcessor(ASTNode& node) : m_node{node} + { + const std::string& symbol = m_node.children[0]->string(); + auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, m_node.children[0]->begin()); + Assert(found); + + m_lhs = &i_symbol->attributes().value(); + } +}; + template <typename ValueT> class AffectationToTupleProcessor final : public INodeProcessor { @@ -420,14 +636,19 @@ class AffectationToTupleProcessor final : public INodeProcessor os << v; *m_lhs = std::vector<std::string>{os.str()}; } - } else if constexpr (std::is_same_v<ValueT, TinyVector<1>> and std::is_arithmetic_v<T>) { - *m_lhs = std::vector<TinyVector<1>>{TinyVector<1>{static_cast<double>(v)}}; - } else if constexpr (std::is_same_v<ValueT, TinyVector<2>> and std::is_same_v<T, int64_t>) { - Assert(v == 0); - *m_lhs = std::vector<TinyVector<2>>{TinyVector<2>{zero}}; - } else if constexpr (std::is_same_v<ValueT, TinyVector<3>> and std::is_same_v<T, int64_t>) { - Assert(v == 0); - *m_lhs = std::vector<TinyVector<3>>{TinyVector<3>{zero}}; + } else if constexpr (is_tiny_vector_v<ValueT> or is_tiny_matrix_v<ValueT>) { + if constexpr (std::is_same_v<ValueT, TinyVector<1>> and std::is_arithmetic_v<T>) { + *m_lhs = std::vector<TinyVector<1>>{TinyVector<1>{static_cast<double>(v)}}; + } else if constexpr (std::is_same_v<ValueT, TinyMatrix<1>> and std::is_arithmetic_v<T>) { + *m_lhs = std::vector<TinyMatrix<1>>{TinyMatrix<1>{static_cast<double>(v)}}; + } else if constexpr (std::is_same_v<T, int64_t>) { + Assert(v == 0); + *m_lhs = std::vector<ValueT>{ValueT{zero}}; + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", m_node.begin()); + // LCOV_EXCL_STOP + } } else { // LCOV_EXCL_START throw ParseError("unexpected error: unexpected right hand side type in affectation", m_node.begin()); @@ -496,7 +717,7 @@ class AffectationToTupleFromListProcessor final : public INodeProcessor }, child_value[j]); } - } else if constexpr (std::is_same_v<T, int64_t>) { + } else if constexpr (std::is_arithmetic_v<T>) { if constexpr (std::is_same_v<ValueT, TinyVector<1>>) { tuple_value[i][0] = child_value; } else { @@ -510,6 +731,41 @@ class AffectationToTupleFromListProcessor final : public INodeProcessor m_node.children[1]->children[i]->begin()); // LCOV_EXCL_STOP } + } else if constexpr (is_tiny_matrix_v<ValueT>) { + if constexpr (std::is_same_v<T, AggregateDataVariant>) { + ValueT& A = tuple_value[i]; + Assert(A.nbRows() * A.nbColumns() == child_value.size()); + for (size_t j = 0, l = 0; j < A.nbRows(); ++j) { + for (size_t k = 0; k < A.nbColumns(); ++k, ++l) { + std::visit( + [&](auto&& Ajk) { + using Ti = std::decay_t<decltype(Ajk)>; + if constexpr (std::is_convertible_v<Ti, typename ValueT::data_type>) { + A(j, k) = Ajk; + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", + m_node.children[1]->children[i]->begin()); + // LCOV_EXCL_STOP + } + }, + child_value[l]); + } + } + } else if constexpr (std::is_arithmetic_v<T>) { + if constexpr (std::is_same_v<ValueT, TinyMatrix<1>>) { + tuple_value[i](0, 0) = child_value; + } else { + // in this case a 0 is given + Assert(child_value == 0); + tuple_value[i] = ZeroType{}; + } + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", + m_node.children[1]->children[i]->begin()); + // LCOV_EXCL_STOP + } } else { // LCOV_EXCL_START throw ParseError("unexpected error: unexpected right hand side type in affectation", @@ -654,51 +910,98 @@ class ListAffectationProcessor final : public INodeProcessor Assert(found); DataVariant& value = i_symbol->attributes().value(); - if (array_expression.m_data_type != ASTNodeDataType::vector_t) { - // LCOV_EXCL_START - throw ParseError("unexpected error: invalid lhs (expecting R^d)", - std::vector{array_subscript_expression.begin()}); - // LCOV_EXCL_STOP - } + if (array_expression.m_data_type == ASTNodeDataType::vector_t) { + Assert(array_subscript_expression.children.size() == 2); - auto& index_expression = *array_subscript_expression.children[1]; + auto& index_expression = *array_subscript_expression.children[1]; - switch (array_expression.m_data_type.dimension()) { - case 1: { - using ArrayTypeT = TinyVector<1>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + switch (array_expression.m_data_type.dimension()) { + case 1: { + using ArrayTypeT = TinyVector<1>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor_list.emplace_back( - std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); - break; - } - case 2: { - using ArrayTypeT = TinyVector<2>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + case 2: { + using ArrayTypeT = TinyVector<2>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor_list.emplace_back( - std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); - break; - } - case 3: { - using ArrayTypeT = TinyVector<3>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + case 3: { + using ArrayTypeT = TinyVector<3>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid vector dimension", + std::vector{array_subscript_expression.begin()}); + } + // LCOV_EXCL_STOP + } + } else if (array_expression.m_data_type == ASTNodeDataType::matrix_t) { + Assert(array_subscript_expression.children.size() == 3); + + auto& index0_expression = *array_subscript_expression.children[1]; + auto& index1_expression = *array_subscript_expression.children[2]; + + Assert(array_expression.m_data_type.nbRows() == array_expression.m_data_type.nbColumns()); + + switch (array_expression.m_data_type.nbRows()) { + case 1: { + using ArrayTypeT = TinyMatrix<1>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index0_expression, + index1_expression)); + break; + } + case 2: { + using ArrayTypeT = TinyMatrix<2>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index0_expression, + index1_expression)); + break; + } + case 3: { + using ArrayTypeT = TinyMatrix<3>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index0_expression, + index1_expression)); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid vector dimension", + std::vector{array_subscript_expression.begin()}); + } + // LCOV_EXCL_STOP } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor_list.emplace_back( - std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid vector dimension", std::vector{array_subscript_expression.begin()}); - } - // LCOV_EXCL_STOP } } else { // LCOV_EXCL_START diff --git a/src/language/node_processor/ArraySubscriptProcessor.hpp b/src/language/node_processor/ArraySubscriptProcessor.hpp index 3a71c01846b9e17ce28818d732716aea764d2496..43eb962b9d724b19b4cfe7ab808f52dd73d1fce2 100644 --- a/src/language/node_processor/ArraySubscriptProcessor.hpp +++ b/src/language/node_processor/ArraySubscriptProcessor.hpp @@ -15,9 +15,7 @@ class ArraySubscriptProcessor : public INodeProcessor DataVariant execute(ExecutionPolicy& exec_policy) { - auto& index_expression = *m_array_subscript_expression.children[1]; - - const int64_t index_value = [&](DataVariant&& value_variant) -> int64_t { + auto get_index_value = [&](DataVariant&& value_variant) -> int64_t { int64_t index_value = 0; std::visit( [&](auto&& value) { @@ -26,20 +24,39 @@ class ArraySubscriptProcessor : public INodeProcessor index_value = value; } else { // LCOV_EXCL_START - throw ParseError("unexpected error: invalid index type", std::vector{index_expression.begin()}); + throw UnexpectedError("invalid index type"); // LCOV_EXCL_STOP } }, value_variant); return index_value; - }(index_expression.execute(exec_policy)); + }; + + if constexpr (is_tiny_vector_v<ArrayTypeT>) { + auto& index_expression = *m_array_subscript_expression.children[1]; + + const int64_t index_value = get_index_value(index_expression.execute(exec_policy)); + + auto& array_expression = *m_array_subscript_expression.children[0]; + + auto&& array_value = array_expression.execute(exec_policy); + ArrayTypeT& array = std::get<ArrayTypeT>(array_value); + + return array[index_value]; + } else if constexpr (is_tiny_matrix_v<ArrayTypeT>) { + auto& index0_expression = *m_array_subscript_expression.children[1]; + auto& index1_expression = *m_array_subscript_expression.children[2]; + + const int64_t index0_value = get_index_value(index0_expression.execute(exec_policy)); + const int64_t index1_value = get_index_value(index1_expression.execute(exec_policy)); - auto& array_expression = *m_array_subscript_expression.children[0]; + auto& array_expression = *m_array_subscript_expression.children[0]; - auto&& array_value = array_expression.execute(exec_policy); - ArrayTypeT& array = std::get<ArrayTypeT>(array_value); + auto&& array_value = array_expression.execute(exec_policy); + ArrayTypeT& array = std::get<ArrayTypeT>(array_value); - return array[index_value]; + return array(index0_value, index1_value); + } } ArraySubscriptProcessor(ASTNode& array_subscript_expression) diff --git a/src/language/node_processor/FunctionArgumentConverter.hpp b/src/language/node_processor/FunctionArgumentConverter.hpp index e226e5336f4add578c5df2abe102121d067f6402..af08e7806a4fd5daa01a7aecc98a84e52291040f 100644 --- a/src/language/node_processor/FunctionArgumentConverter.hpp +++ b/src/language/node_processor/FunctionArgumentConverter.hpp @@ -121,6 +121,55 @@ class FunctionTinyVectorArgumentConverter final : public IFunctionArgumentConver FunctionTinyVectorArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {} }; +template <typename ExpectedValueType, typename ProvidedValueType> +class FunctionTinyMatrixArgumentConverter final : public IFunctionArgumentConverter +{ + private: + size_t m_argument_id; + + public: + DataVariant + convert(ExecutionPolicy& exec_policy, DataVariant&& value) + { + if constexpr (std::is_same_v<ExpectedValueType, ProvidedValueType>) { + std::visit( + [&](auto&& v) { + using ValueT = std::decay_t<decltype(v)>; + if constexpr (std::is_same_v<ValueT, ExpectedValueType>) { + exec_policy.currentContext()[m_argument_id] = std::move(value); + } else if constexpr (std::is_same_v<ValueT, AggregateDataVariant>) { + ExpectedValueType matrix_value{}; + for (size_t i = 0, l = 0; i < matrix_value.nbRows(); ++i) { + for (size_t j = 0; j < matrix_value.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& A_ij) { + using Vi_T = std::decay_t<decltype(A_ij)>; + if constexpr (std::is_arithmetic_v<Vi_T>) { + matrix_value(i, j) = A_ij; + } else { + throw UnexpectedError(demangle<Vi_T>() + " unexpected aggregate value type"); + } + }, + v[l]); + } + } + exec_policy.currentContext()[m_argument_id] = std::move(matrix_value); + } + }, + value); + } else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero}; + } else { + static_assert(std::is_same_v<ExpectedValueType, TinyMatrix<1>>); + exec_policy.currentContext()[m_argument_id] = + std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); + } + return {}; + } + + FunctionTinyMatrixArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {} +}; + template <typename ContentType, typename ProvidedValueType> class FunctionTupleArgumentConverter final : public IFunctionArgumentConverter { @@ -147,7 +196,8 @@ class FunctionTupleArgumentConverter final : public IFunctionArgumentConverter list_value.emplace_back(std::move(v[i])); } exec_policy.currentContext()[m_argument_id] = std::move(list_value); - } else if constexpr ((std::is_convertible_v<ContentT, ContentType>)and not is_tiny_vector_v<ContentType>) { + } else if constexpr ((std::is_convertible_v<ContentT, ContentType>)and not is_tiny_vector_v<ContentType> and + not is_tiny_matrix_v<ContentType>) { TupleType list_value; list_value.reserve(v.size()); for (size_t i = 0; i < v.size(); ++i) { @@ -160,7 +210,8 @@ class FunctionTupleArgumentConverter final : public IFunctionArgumentConverter demangle<ContentType>() + "'"); // LCOV_EXCL_STOP } - } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ContentType>) { + } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ContentType> and + not is_tiny_matrix_v<ContentType>) { exec_policy.currentContext()[m_argument_id] = std::move(TupleType{static_cast<ContentType>(v)}); } else { throw UnexpectedError(std::string{"cannot convert '"} + demangle<ValueT>() + "' to '" + @@ -205,7 +256,7 @@ class FunctionListArgumentConverter final : public IFunctionArgumentConverter using Vi_T = std::decay_t<decltype(vi)>; if constexpr (std::is_same_v<Vi_T, ContentType>) { list_value.emplace_back(vi); - } else if constexpr (is_tiny_vector_v<ContentType>) { + } else if constexpr (is_tiny_vector_v<ContentType> or is_tiny_matrix_v<ContentType>) { // LCOV_EXCL_START throw UnexpectedError(std::string{"invalid conversion of '"} + demangle<Vi_T>() + "' to '" + demangle<ContentType>() + "'"); @@ -234,7 +285,8 @@ class FunctionListArgumentConverter final : public IFunctionArgumentConverter } else if constexpr (std::is_same_v<ValueT, ContentType>) { exec_policy.currentContext()[m_argument_id] = std::move(TupleType{v}); } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ValueT> and - not is_tiny_vector_v<ContentType>) { + not is_tiny_vector_v<ContentType> and not is_tiny_matrix_v<ValueT> and + not is_tiny_matrix_v<ContentType>) { exec_policy.currentContext()[m_argument_id] = std::move(TupleType{static_cast<ContentType>(v)}); } else { // LCOV_EXCL_START diff --git a/src/language/node_processor/FunctionProcessor.hpp b/src/language/node_processor/FunctionProcessor.hpp index b28c2ce499538818ce60b1ac36a1d158114409de..f8c86ea9e5cb51e169e48ecadc0b4770ef873cc8 100644 --- a/src/language/node_processor/FunctionProcessor.hpp +++ b/src/language/node_processor/FunctionProcessor.hpp @@ -21,18 +21,35 @@ class FunctionExpressionProcessor final : public INodeProcessor if constexpr (std::is_same_v<ReturnType, ExpressionValueType>) { return m_function_expression.execute(exec_policy); } else if constexpr (std::is_same_v<AggregateDataVariant, ExpressionValueType>) { - static_assert(is_tiny_vector_v<ReturnType>, "unexpected return type"); + static_assert(is_tiny_vector_v<ReturnType> or is_tiny_matrix_v<ReturnType>, "unexpected return type"); ReturnType return_value{}; auto value = std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)); - for (size_t i = 0; i < ReturnType::Dimension; ++i) { - std::visit( - [&](auto&& vi) { - using Vi_T = std::decay_t<decltype(vi)>; - if constexpr (std::is_convertible_v<Vi_T, double>) { - return_value[i] = vi; - } - }, - value[i]); + if constexpr (is_tiny_vector_v<ReturnType>) { + for (size_t i = 0; i < ReturnType::Dimension; ++i) { + std::visit( + [&](auto&& vi) { + using Vi_T = std::decay_t<decltype(vi)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + return_value[i] = vi; + } + }, + value[i]); + } + } else { + static_assert(is_tiny_matrix_v<ReturnType>); + + for (size_t i = 0, l = 0; i < return_value.nbRows(); ++i) { + for (size_t j = 0; j < return_value.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using Vi_T = std::decay_t<decltype(Aij)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + return_value(i, j) = Aij; + } + }, + value[l]); + } + } } return return_value; } else if constexpr (std::is_same_v<ReturnType, std::string>) { diff --git a/src/language/node_processor/TupleToTinyMatrixProcessor.hpp b/src/language/node_processor/TupleToTinyMatrixProcessor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..77bc45e8419f7d2e58447247aef3975c337813c5 --- /dev/null +++ b/src/language/node_processor/TupleToTinyMatrixProcessor.hpp @@ -0,0 +1,53 @@ +#ifndef TUPLE_TO_TINY_MATRIX_PROCESSOR_HPP +#define TUPLE_TO_TINY_MATRIX_PROCESSOR_HPP + +#include <language/ast/ASTNode.hpp> +#include <language/node_processor/INodeProcessor.hpp> + +template <typename TupleProcessorT, size_t N> +class TupleToTinyMatrixProcessor final : public INodeProcessor +{ + private: + ASTNode& m_node; + + std::unique_ptr<TupleProcessorT> m_tuple_processor; + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + AggregateDataVariant v = std::get<AggregateDataVariant>(m_tuple_processor->execute(exec_policy)); + + Assert(v.size() == N * N); + + TinyMatrix<N> A; + + for (size_t i = 0, l = 0; i < N; ++i) { + for (size_t j = 0; j < N; ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using ValueT = std::decay_t<decltype(Aij)>; + if constexpr (std::is_arithmetic_v<ValueT>) { + A(i, j) = Aij; + } else { + // LCOV_EXCL_START + Assert(false, "unexpected value type"); + // LCOV_EXCL_STOP + } + }, + v[l]); + } + } + + return DataVariant{std::move(A)}; + } + + TupleToTinyMatrixProcessor(ASTNode& node) : m_node{node}, m_tuple_processor{std::make_unique<TupleProcessorT>(node)} + {} + + TupleToTinyMatrixProcessor(ASTNode& node, std::unique_ptr<TupleProcessorT>&& tuple_processor) + : m_node{node}, m_tuple_processor{std::move(tuple_processor)} + {} +}; + +#endif // TUPLE_TO_TINY_MATRIX_PROCESSOR_HPP diff --git a/src/language/utils/ASTNodeDataType.cpp b/src/language/utils/ASTNodeDataType.cpp index 02a352ac3d0f7435c89445d4113b037bd0f54c2c..81f3fc81caba045ca4ae719ddfa21c09c1c8fc98 100644 --- a/src/language/utils/ASTNodeDataType.cpp +++ b/src/language/utils/ASTNodeDataType.cpp @@ -16,9 +16,42 @@ getVectorDataType(const ASTNode& type_node) throw ParseError("unexpected non integer constant dimension", dimension_node.begin()); } const size_t dimension = std::stol(dimension_node.string()); + if (not(dimension > 0 and dimension <= 3)) { + throw ParseError("invalid dimension (must be 1, 2 or 3)", dimension_node.begin()); + } return ASTNodeDataType::build<ASTNodeDataType::vector_t>(dimension); } +ASTNodeDataType +getMatrixDataType(const ASTNode& type_node) +{ + if (not(type_node.is_type<language::matrix_type>() and (type_node.children.size() == 3))) { + throw ParseError("unexpected node type", type_node.begin()); + } + + ASTNode& dimension0_node = *type_node.children[1]; + if (not dimension0_node.is_type<language::integer>()) { + throw ParseError("unexpected non integer constant dimension", dimension0_node.begin()); + } + const size_t dimension0 = std::stol(dimension0_node.string()); + + ASTNode& dimension1_node = *type_node.children[2]; + if (not dimension1_node.is_type<language::integer>()) { + throw ParseError("unexpected non integer constant dimension", dimension1_node.begin()); + } + const size_t dimension1 = std::stol(dimension1_node.string()); + + if (dimension0 != dimension1) { + throw ParseError("only square matrices are supported", type_node.begin()); + } + + if (not(dimension0 > 0 and dimension0 <= 3)) { + throw ParseError("invalid dimension (must be 1, 2 or 3)", dimension0_node.begin()); + } + + return ASTNodeDataType::build<ASTNodeDataType::matrix_t>(dimension0, dimension1); +} + std::string dataTypeName(const ASTNodeDataType& data_type) { @@ -42,6 +75,9 @@ dataTypeName(const ASTNodeDataType& data_type) case ASTNodeDataType::vector_t: name = "R^" + std::to_string(data_type.dimension()); break; + case ASTNodeDataType::matrix_t: + name = "R^" + std::to_string(data_type.nbRows()) + "x" + std::to_string(data_type.nbColumns()); + break; case ASTNodeDataType::tuple_t: name = "tuple(" + dataTypeName(data_type.contentType()) + ')'; break; @@ -129,6 +165,9 @@ isNaturalConversion(const ASTNodeDataType& data_type, const ASTNodeDataType& tar return (data_type.nameOfTypeId() == target_data_type.nameOfTypeId()); } else if (data_type == ASTNodeDataType::vector_t) { return (data_type.dimension() == target_data_type.dimension()); + } else if (data_type == ASTNodeDataType::matrix_t) { + return ((data_type.nbRows() == target_data_type.nbRows()) and + (data_type.nbColumns() == target_data_type.nbColumns())); } else { return true; } diff --git a/src/language/utils/ASTNodeDataType.hpp b/src/language/utils/ASTNodeDataType.hpp index 951aff6e4d070e4747f5eb3f5a5fcc93f8ba5b8e..c31c9e15b6e8df0055aa4854bb7ede0a73318abd 100644 --- a/src/language/utils/ASTNodeDataType.hpp +++ b/src/language/utils/ASTNodeDataType.hpp @@ -14,6 +14,8 @@ class ASTNodeDataType; ASTNodeDataType getVectorDataType(const ASTNode& type_node); +ASTNodeDataType getMatrixDataType(const ASTNode& type_node); + std::string dataTypeName(const std::vector<ASTNodeDataType>& data_type_vector); std::string dataTypeName(const ASTNodeDataType& data_type); @@ -33,9 +35,10 @@ class ASTNodeDataType unsigned_int_t = 2, double_t = 3, vector_t = 4, - tuple_t = 5, - list_t = 6, - string_t = 7, + matrix_t = 5, + tuple_t = 6, + list_t = 7, + string_t = 8, typename_t = 10, type_name_id_t = 11, type_id_t = 21, @@ -49,6 +52,7 @@ class ASTNodeDataType using DataTypeDetails = std::variant<std::monostate, size_t, + std::array<size_t, 2>, std::string, std::shared_ptr<const ASTNodeDataType>, std::vector<std::shared_ptr<const ASTNodeDataType>>>; @@ -64,6 +68,22 @@ class ASTNodeDataType return std::get<size_t>(m_details); } + PUGS_INLINE + size_t + nbRows() const + { + Assert(std::holds_alternative<std::array<size_t, 2>>(m_details)); + return std::get<std::array<size_t, 2>>(m_details)[0]; + } + + PUGS_INLINE + size_t + nbColumns() const + { + Assert(std::holds_alternative<std::array<size_t, 2>>(m_details)); + return std::get<std::array<size_t, 2>>(m_details)[1]; + } + PUGS_INLINE const std::string& nameOfTypeId() const @@ -129,6 +149,14 @@ class ASTNodeDataType return ASTNodeDataType{data_type, dimension}; } + template <DataType data_type> + [[nodiscard]] static ASTNodeDataType + build(const size_t nb_rows, const size_t nb_columns) + { + static_assert((data_type == matrix_t), "incorrect data_type construction: cannot have dimension"); + return ASTNodeDataType{data_type, nb_rows, nb_columns}; + } + template <DataType data_type> [[nodiscard]] static ASTNodeDataType build(const std::string& type_name) @@ -171,6 +199,10 @@ class ASTNodeDataType explicit ASTNodeDataType(DataType data_type, const size_t dimension) : m_data_type{data_type}, m_details{dimension} {} + explicit ASTNodeDataType(DataType data_type, const size_t nb_rows, const size_t nb_columns) + : m_data_type{data_type}, m_details{std::array{nb_rows, nb_columns}} + {} + explicit ASTNodeDataType(DataType data_type, const std::string& type_name) : m_data_type{data_type}, m_details{type_name} {} diff --git a/src/language/utils/ASTNodeDataTypeTraits.hpp b/src/language/utils/ASTNodeDataTypeTraits.hpp index ed668e5e40e1ff4920602014367f379aa384adb9..2ba019646dac53291dc73169fadc47e737f68bad 100644 --- a/src/language/utils/ASTNodeDataTypeTraits.hpp +++ b/src/language/utils/ASTNodeDataTypeTraits.hpp @@ -1,6 +1,7 @@ #ifndef AST_NODE_DATA_TYPE_TRAITS_HPP #define AST_NODE_DATA_TYPE_TRAITS_HPP +#include <algebra/TinyMatrix.hpp> #include <algebra/TinyVector.hpp> #include <language/utils/ASTNodeDataType.hpp> #include <language/utils/FunctionSymbolId.hpp> @@ -27,6 +28,8 @@ inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = ASTNodeDataType::build<ASTNodeDataType::function_t>(); template <size_t N> inline ASTNodeDataType ast_node_data_type_from<TinyVector<N>> = ASTNodeDataType::build<ASTNodeDataType::vector_t>(N); +template <size_t N> +inline ASTNodeDataType ast_node_data_type_from<TinyMatrix<N>> = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(N, N); template <typename T> inline ASTNodeDataType ast_node_data_type_from<std::vector<T>> = diff --git a/src/language/ast/ASTNodeNaturalConversionChecker.cpp b/src/language/utils/ASTNodeNaturalConversionChecker.cpp similarity index 61% rename from src/language/ast/ASTNodeNaturalConversionChecker.cpp rename to src/language/utils/ASTNodeNaturalConversionChecker.cpp index cf1ccdabf1cffa0e2256ed17e13dca6717229fc9..81a277574753dc3e8cc75871a07d8d8a3b1a0cc6 100644 --- a/src/language/ast/ASTNodeNaturalConversionChecker.cpp +++ b/src/language/utils/ASTNodeNaturalConversionChecker.cpp @@ -1,4 +1,4 @@ -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/PEGGrammar.hpp> #include <language/utils/ParseError.hpp> @@ -13,7 +13,9 @@ ASTNodeNaturalConversionChecker<RToR1Conversion>::_checkIsNaturalTypeConversion( { if (not isNaturalConversion(data_type, target_data_type)) { if constexpr (std::is_same_v<RToR1ConversionStrategy, AllowRToR1Conversion>) { - if ((target_data_type == ASTNodeDataType::vector_t) and (target_data_type.dimension() == 1)) { + if (((target_data_type == ASTNodeDataType::vector_t) and (target_data_type.dimension() == 1)) or + ((target_data_type == ASTNodeDataType::matrix_t) and (target_data_type.nbRows() == 1) and + (target_data_type.nbColumns() == 1))) { if (isNaturalConversion(data_type, ASTNodeDataType::build<ASTNodeDataType::double_t>())) { return; } @@ -25,7 +27,9 @@ ASTNodeNaturalConversionChecker<RToR1Conversion>::_checkIsNaturalTypeConversion( << rang::fg::reset; if ((data_type == ASTNodeDataType::undefined_t) or (target_data_type == ASTNodeDataType::undefined_t)) { + // LCOV_EXCL_START throw UnexpectedError(error_message.str()); + // LCOV_EXCL_STOP } else { throw ParseError(error_message.str(), node.begin()); } @@ -46,7 +50,10 @@ ASTNodeNaturalConversionChecker<RToR1Conversion>::_checkIsNaturalExpressionConve case ASTNodeDataType::list_t: { const auto& content_type_list = data_type.contentTypeList(); if (content_type_list.size() != target_data_type.dimension()) { - throw ParseError("incompatible dimensions in affectation", std::vector{node.begin()}); + std::ostringstream os; + os << "incompatible dimensions in affectation: expecting " << target_data_type.dimension() << ", but provided " + << content_type_list.size(); + throw ParseError(os.str(), std::vector{node.begin()}); } Assert(content_type_list.size() == node.children.size()); @@ -82,34 +89,61 @@ ASTNodeNaturalConversionChecker<RToR1Conversion>::_checkIsNaturalExpressionConve this->_checkIsNaturalTypeConversion(node, data_type, target_data_type); } } + } else if (target_data_type == ASTNodeDataType::matrix_t) { + switch (data_type) { + case ASTNodeDataType::list_t: { + const auto& content_type_list = data_type.contentTypeList(); + if (content_type_list.size() != (target_data_type.nbRows() * target_data_type.nbColumns())) { + std::ostringstream os; + os << "incompatible dimensions in affectation: expecting " + << target_data_type.nbRows() * target_data_type.nbColumns() << ", but provided " << content_type_list.size(); + throw ParseError(os.str(), std::vector{node.begin()}); + } + + Assert(content_type_list.size() == node.children.size()); + for (size_t i = 0; i < content_type_list.size(); ++i) { + const auto& child_type = *content_type_list[i]; + const auto& child_node = *node.children[i]; + Assert(child_type == child_node.m_data_type); + this->_checkIsNaturalExpressionConversion(child_node, child_type, + ASTNodeDataType::build<ASTNodeDataType::double_t>()); + } + + break; + } + case ASTNodeDataType::matrix_t: { + if ((data_type.nbRows() != target_data_type.nbRows()) or + (data_type.nbColumns() != target_data_type.nbColumns())) { + std::ostringstream error_message; + error_message << "invalid implicit conversion: "; + error_message << rang::fgB::red << dataTypeName(data_type) << " -> " << dataTypeName(target_data_type) + << rang::fg::reset; + throw ParseError(error_message.str(), std::vector{node.begin()}); + } + break; + } + case ASTNodeDataType::int_t: { + if (node.is_type<language::integer>()) { + if (std::stoi(node.string()) == 0) { + break; + } + } + [[fallthrough]]; + } + default: { + this->_checkIsNaturalTypeConversion(node, data_type, target_data_type); + } + } } else if (target_data_type == ASTNodeDataType::tuple_t) { const ASTNodeDataType& target_content_type = target_data_type.contentType(); if (node.m_data_type == ASTNodeDataType::tuple_t) { this->_checkIsNaturalExpressionConversion(node, data_type.contentType(), target_content_type); } else if (node.m_data_type == ASTNodeDataType::list_t) { - if ((target_data_type.contentType() == ASTNodeDataType::vector_t) and - (target_data_type.contentType().dimension() == 1)) { - for (const auto& child : node.children) { - if (not isNaturalConversion(child->m_data_type, target_data_type)) { - this->_checkIsNaturalExpressionConversion(*child, child->m_data_type, - ASTNodeDataType::build<ASTNodeDataType::double_t>()); - } - } - } else { - for (const auto& child : node.children) { - this->_checkIsNaturalExpressionConversion(*child, child->m_data_type, target_content_type); - } + for (const auto& child : node.children) { + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>(*child, target_data_type.contentType()); } } else { - if ((target_data_type.contentType() == ASTNodeDataType::vector_t) and - (target_data_type.contentType().dimension() == 1)) { - if (not isNaturalConversion(data_type, target_data_type)) { - this->_checkIsNaturalExpressionConversion(node, data_type, - ASTNodeDataType::build<ASTNodeDataType::double_t>()); - } - } else { - this->_checkIsNaturalExpressionConversion(node, data_type, target_content_type); - } + this->_checkIsNaturalExpressionConversion(node, data_type, target_content_type); } } else { this->_checkIsNaturalTypeConversion(node, data_type, target_data_type); diff --git a/src/language/ast/ASTNodeNaturalConversionChecker.hpp b/src/language/utils/ASTNodeNaturalConversionChecker.hpp similarity index 100% rename from src/language/ast/ASTNodeNaturalConversionChecker.hpp rename to src/language/utils/ASTNodeNaturalConversionChecker.hpp diff --git a/src/language/utils/AffectationProcessorBuilder.hpp b/src/language/utils/AffectationProcessorBuilder.hpp index 02042ad8825c7c890727edba6dd5dc41691a7d9e..8d11599a01c67566e0f88fcb34cad581bc76de45 100644 --- a/src/language/utils/AffectationProcessorBuilder.hpp +++ b/src/language/utils/AffectationProcessorBuilder.hpp @@ -4,6 +4,7 @@ #include <algebra/TinyVector.hpp> #include <language/PEGGrammar.hpp> #include <language/node_processor/AffectationProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/IAffectationProcessorBuilder.hpp> #include <type_traits> @@ -50,6 +51,7 @@ class AffectationToTupleFromListProcessorBuilder final : public IAffectationProc std::unique_ptr<INodeProcessor> getNodeProcessor(ASTNode& node) const { + ASTNodeNaturalConversionChecker(*node.children[1], node.children[0]->m_data_type); return std::make_unique<AffectationToTupleFromListProcessor<ValueT>>(node); } }; @@ -66,6 +68,18 @@ class AffectationToTinyVectorFromListProcessorBuilder final : public IAffectatio } }; +template <typename OperatorT, typename ValueT> +class AffectationToTinyMatrixFromListProcessorBuilder final : public IAffectationProcessorBuilder +{ + public: + AffectationToTinyMatrixFromListProcessorBuilder() = default; + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + return std::make_unique<AffectationToTinyMatrixFromListProcessor<OperatorT, ValueT>>(node); + } +}; + template <typename OperatorT, typename ValueT> class AffectationFromZeroProcessorBuilder final : public IAffectationProcessorBuilder { @@ -76,10 +90,9 @@ class AffectationFromZeroProcessorBuilder final : public IAffectationProcessorBu { if (std::stoi(node.children[1]->string()) == 0) { return std::make_unique<AffectationFromZeroProcessor<ValueT>>(node); + } else { + throw ParseError("invalid integral value (0 is the solely valid value)", std::vector{node.children[1]->begin()}); } - // LCOV_EXCL_START - throw ParseError("unexpected error: invalid integral value", std::vector{node.children[1]->begin()}); - // LCOV_EXCL_STOP } }; diff --git a/src/language/utils/AffectationRegisterForRnxn.cpp b/src/language/utils/AffectationRegisterForRnxn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..efa95d92bbcfd32af37f6c3f69477d7ee985a560 --- /dev/null +++ b/src/language/utils/AffectationRegisterForRnxn.cpp @@ -0,0 +1,143 @@ +#include <language/utils/AffectationRegisterForRnxn.hpp> + +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> +#include <language/utils/OperatorRepository.hpp> + +template <size_t Dimension> +void +AffectationRegisterForRnxn<Dimension>::_register_eq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationFromZeroProcessorBuilder<language::eq_op, TinyMatrix<Dimension>>>()); + + repository.addAffectation<language::eq_op>(Rnxn, + ASTNodeDataType::build<ASTNodeDataType::list_t>( + std::vector<std::shared_ptr<const ASTNodeDataType>>{}), + std::make_shared<AffectationToTinyMatrixFromListProcessorBuilder< + language::eq_op, TinyMatrix<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); +} + +template <> +void +AffectationRegisterForRnxn<1>::_register_eq_op() +{ + constexpr size_t Dimension = 1; + + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyMatrix<Dimension>, bool>>()); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyMatrix<Dimension>, uint64_t>>()); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyMatrix<Dimension>, int64_t>>()); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyMatrix<Dimension>, double>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); +} + +template <size_t Dimension> +void +AffectationRegisterForRnxn<Dimension>::_register_pluseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository + .addAffectation<language::pluseq_op>(Rnxn, Rnxn, + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, TinyMatrix<Dimension>, TinyMatrix<Dimension>>>()); +} + +template <size_t Dimension> +void +AffectationRegisterForRnxn<Dimension>::_register_minuseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository + .addAffectation<language::minuseq_op>(Rnxn, Rnxn, + std::make_shared<AffectationProcessorBuilder< + language::minuseq_op, TinyMatrix<Dimension>, TinyMatrix<Dimension>>>()); +} + +template <size_t Dimension> +void +AffectationRegisterForRnxn<Dimension>::_register_multiplyeq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository.addAffectation<language::multiplyeq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyMatrix<Dimension>, bool>>()); + + repository.addAffectation<language::multiplyeq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyMatrix<Dimension>, uint64_t>>()); + + repository.addAffectation<language::multiplyeq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyMatrix<Dimension>, int64_t>>()); + + repository.addAffectation<language::multiplyeq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyMatrix<Dimension>, double>>()); +} + +template <size_t Dimension> +AffectationRegisterForRnxn<Dimension>::AffectationRegisterForRnxn() +{ + BasicAffectationRegisterFor<TinyMatrix<Dimension>>{}; + this->_register_eq_op(); + this->_register_pluseq_op(); + this->_register_minuseq_op(); + this->_register_multiplyeq_op(); +} + +template class AffectationRegisterForRnxn<1>; +template class AffectationRegisterForRnxn<2>; +template class AffectationRegisterForRnxn<3>; diff --git a/src/language/utils/AffectationRegisterForRnxn.hpp b/src/language/utils/AffectationRegisterForRnxn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..33f2ad07d3911f2b7488cd10b321be5e4f59f5e4 --- /dev/null +++ b/src/language/utils/AffectationRegisterForRnxn.hpp @@ -0,0 +1,19 @@ +#ifndef AFFECTATION_REGISTER_FOR_RNXN_HPP +#define AFFECTATION_REGISTER_FOR_RNXN_HPP + +#include <cstdlib> + +template <size_t Dimension> +class AffectationRegisterForRnxn +{ + private: + void _register_eq_op(); + void _register_pluseq_op(); + void _register_minuseq_op(); + void _register_multiplyeq_op(); + + public: + AffectationRegisterForRnxn(); +}; + +#endif // AFFECTATION_REGISTER_FOR_RNXN_HPP diff --git a/src/language/utils/AffectationRegisterForString.cpp b/src/language/utils/AffectationRegisterForString.cpp index 37446e243d6fcb14fcd6c3e97a845d17b831c2e6..7946aeb45fc72e7569e5d8299e184c220baa7106 100644 --- a/src/language/utils/AffectationRegisterForString.cpp +++ b/src/language/utils/AffectationRegisterForString.cpp @@ -39,6 +39,18 @@ AffectationRegisterForString::_register_eq_op() language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyVector<3>>>()); + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyMatrix<1>>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyMatrix<2>>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyMatrix<3>>>()); + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), ASTNodeDataType::build<ASTNodeDataType::bool_t>(), std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); @@ -66,6 +78,18 @@ AffectationRegisterForString::_register_eq_op() repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); } void @@ -107,29 +131,17 @@ AffectationRegisterForString::_register_pluseq_op() std::make_shared<AffectationProcessorBuilder< language::pluseq_op, std::string, TinyVector<3>>>()); - // this->_addAffectation("string += string", - // std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, - // std::string>>()); - // this->_addAffectation("string += B", - // std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, bool>>()); - // this->_addAffectation("string += N", - // std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, uint64_t>>()); - // this->_addAffectation("string += Z", - // std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, int64_t>>()); - // this->_addAffectation("string += R", - // std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, double>>()); - // this - // ->_addAffectation("string += R^1", - // std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, - // TinyVector<1>>>()); - // this - // ->_addAffectation("string += R^2", - // std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, - // TinyVector<2>>>()); - // this - // ->_addAffectation("string += R^3", - // std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, - // TinyVector<3>>>()); + repository.addAffectation<language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, TinyMatrix<1>>>()); + + repository.addAffectation<language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, TinyMatrix<2>>>()); + + repository.addAffectation<language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, TinyMatrix<3>>>()); } AffectationRegisterForString::AffectationRegisterForString() diff --git a/src/language/utils/BinaryOperatorRegisterForRnxn.cpp b/src/language/utils/BinaryOperatorRegisterForRnxn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d8850e79a4129c9a2098baa14b869f48bb8d1ce --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForRnxn.cpp @@ -0,0 +1,83 @@ +#include <language/utils/BinaryOperatorRegisterForRnxn.hpp> + +#include <language/utils/BinaryOperatorProcessorBuilder.hpp> +#include <language/utils/OperatorRepository.hpp> + +template <size_t Dimension> +void +BinaryOperatorRegisterForRnxn<Dimension>::_register_comparisons() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rnxn = TinyMatrix<Dimension>; + + repository.addBinaryOperator<language::eqeq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::eqeq_op, bool, Rnxn, Rnxn>>()); + + repository.addBinaryOperator<language::not_eq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::not_eq_op, bool, Rnxn, Rnxn>>()); +} + +template <size_t Dimension> +void +BinaryOperatorRegisterForRnxn<Dimension>::_register_product_by_a_scalar() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rnxn = TinyMatrix<Dimension>; + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, bool, Rnxn>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, uint64_t, Rnxn>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, int64_t, Rnxn>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, double, Rnxn>>()); +} + +template <size_t Dimension> +void +BinaryOperatorRegisterForRnxn<Dimension>::_register_product_by_a_vector() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rnxn = TinyMatrix<Dimension>; + using Rn = TinyVector<Dimension>; + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rn, Rnxn, Rn>>()); +} + +template <size_t Dimension> +template <typename OperatorT> +void +BinaryOperatorRegisterForRnxn<Dimension>::_register_arithmetic() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rnxn = TinyMatrix<Dimension>; + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, Rnxn, Rnxn, Rnxn>>()); +} + +template <size_t Dimension> +BinaryOperatorRegisterForRnxn<Dimension>::BinaryOperatorRegisterForRnxn() +{ + this->_register_comparisons(); + + this->_register_product_by_a_scalar(); + this->_register_product_by_a_vector(); + + this->_register_arithmetic<language::plus_op>(); + this->_register_arithmetic<language::minus_op>(); + this->_register_arithmetic<language::multiply_op>(); +} + +template class BinaryOperatorRegisterForRnxn<1>; +template class BinaryOperatorRegisterForRnxn<2>; +template class BinaryOperatorRegisterForRnxn<3>; diff --git a/src/language/utils/BinaryOperatorRegisterForRnxn.hpp b/src/language/utils/BinaryOperatorRegisterForRnxn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..594740b629b0a7201d64139fd8e2ce4a12b38cd2 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForRnxn.hpp @@ -0,0 +1,22 @@ +#ifndef BINARY_OPERATOR_REGISTER_FOR_RNXN_HPP +#define BINARY_OPERATOR_REGISTER_FOR_RNXN_HPP + +#include <cstdlib> + +template <size_t Dimension> +class BinaryOperatorRegisterForRnxn +{ + private: + void _register_comparisons(); + + void _register_product_by_a_scalar(); + void _register_product_by_a_vector(); + + template <typename OperatorT> + void _register_arithmetic(); + + public: + BinaryOperatorRegisterForRnxn(); +}; + +#endif // BINARY_OPERATOR_REGISTER_FOR_RNXN_HPP diff --git a/src/language/utils/BinaryOperatorRegisterForString.cpp b/src/language/utils/BinaryOperatorRegisterForString.cpp index 1323e21ea0cd17cca712a286c6cf4192447e88e4..0e89cbe00502938d817405c3f3f3d6d991f0433e 100644 --- a/src/language/utils/BinaryOperatorRegisterForString.cpp +++ b/src/language/utils/BinaryOperatorRegisterForString.cpp @@ -36,5 +36,8 @@ BinaryOperatorRegisterForString::BinaryOperatorRegisterForString() this->_register_concat<TinyVector<1>>(); this->_register_concat<TinyVector<2>>(); this->_register_concat<TinyVector<3>>(); + this->_register_concat<TinyMatrix<1>>(); + this->_register_concat<TinyMatrix<2>>(); + this->_register_concat<TinyMatrix<3>>(); this->_register_concat<std::string>(); } diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index ed105ca479f5b17c3366d44320505dba5d2eacad..22c0ae52f184a7f18d852057cbe9b0616f1905e3 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -5,16 +5,19 @@ add_library(PugsLanguageUtils AffectationRegisterForN.cpp AffectationRegisterForR.cpp AffectationRegisterForRn.cpp + AffectationRegisterForRnxn.cpp AffectationRegisterForString.cpp AffectationRegisterForZ.cpp ASTDotPrinter.cpp ASTExecutionInfo.cpp ASTNodeDataType.cpp + ASTNodeNaturalConversionChecker.cpp ASTPrinter.cpp BinaryOperatorRegisterForB.cpp BinaryOperatorRegisterForN.cpp BinaryOperatorRegisterForR.cpp BinaryOperatorRegisterForRn.cpp + BinaryOperatorRegisterForRnxn.cpp BinaryOperatorRegisterForString.cpp BinaryOperatorRegisterForZ.cpp DataVariant.cpp @@ -27,6 +30,7 @@ add_library(PugsLanguageUtils UnaryOperatorRegisterForN.cpp UnaryOperatorRegisterForR.cpp UnaryOperatorRegisterForRn.cpp + UnaryOperatorRegisterForRnxn.cpp UnaryOperatorRegisterForZ.cpp ) diff --git a/src/language/utils/DataVariant.hpp b/src/language/utils/DataVariant.hpp index 55e537f0922a53020797b360d4c235f358bcbc8c..6add9e52853bcbf58ee42e4f2869649dab43d37d 100644 --- a/src/language/utils/DataVariant.hpp +++ b/src/language/utils/DataVariant.hpp @@ -1,6 +1,7 @@ #ifndef DATA_VARIANT_HPP #define DATA_VARIANT_HPP +#include <algebra/TinyMatrix.hpp> #include <algebra/TinyVector.hpp> #include <language/utils/EmbeddedData.hpp> #include <language/utils/FunctionSymbolId.hpp> @@ -22,6 +23,9 @@ using DataVariant = std::variant<std::monostate, TinyVector<1>, TinyVector<2>, TinyVector<3>, + TinyMatrix<1>, + TinyMatrix<2>, + TinyMatrix<3>, EmbeddedData, std::vector<bool>, std::vector<uint64_t>, @@ -31,6 +35,9 @@ using DataVariant = std::variant<std::monostate, std::vector<TinyVector<1>>, std::vector<TinyVector<2>>, std::vector<TinyVector<3>>, + std::vector<TinyMatrix<1>>, + std::vector<TinyMatrix<2>>, + std::vector<TinyMatrix<3>>, std::vector<EmbeddedData>, AggregateDataVariant, FunctionSymbolId>; diff --git a/src/language/utils/OperatorRepository.cpp b/src/language/utils/OperatorRepository.cpp index 93856da8c33793cad9189a25285b612b550f8d61..d34344b967f41c72ae09849a554958b8ebfcc1f3 100644 --- a/src/language/utils/OperatorRepository.cpp +++ b/src/language/utils/OperatorRepository.cpp @@ -5,6 +5,7 @@ #include <language/utils/AffectationRegisterForN.hpp> #include <language/utils/AffectationRegisterForR.hpp> #include <language/utils/AffectationRegisterForRn.hpp> +#include <language/utils/AffectationRegisterForRnxn.hpp> #include <language/utils/AffectationRegisterForString.hpp> #include <language/utils/AffectationRegisterForZ.hpp> @@ -12,6 +13,7 @@ #include <language/utils/BinaryOperatorRegisterForN.hpp> #include <language/utils/BinaryOperatorRegisterForR.hpp> #include <language/utils/BinaryOperatorRegisterForRn.hpp> +#include <language/utils/BinaryOperatorRegisterForRnxn.hpp> #include <language/utils/BinaryOperatorRegisterForString.hpp> #include <language/utils/BinaryOperatorRegisterForZ.hpp> @@ -23,6 +25,7 @@ #include <language/utils/UnaryOperatorRegisterForN.hpp> #include <language/utils/UnaryOperatorRegisterForR.hpp> #include <language/utils/UnaryOperatorRegisterForRn.hpp> +#include <language/utils/UnaryOperatorRegisterForRnxn.hpp> #include <language/utils/UnaryOperatorRegisterForZ.hpp> #include <utils/PugsAssert.hpp> @@ -65,6 +68,9 @@ OperatorRepository::_initialize() AffectationRegisterForRn<1>{}; AffectationRegisterForRn<2>{}; AffectationRegisterForRn<3>{}; + AffectationRegisterForRnxn<1>{}; + AffectationRegisterForRnxn<2>{}; + AffectationRegisterForRnxn<3>{}; AffectationRegisterForString{}; BinaryOperatorRegisterForB{}; @@ -74,6 +80,9 @@ OperatorRepository::_initialize() BinaryOperatorRegisterForRn<1>{}; BinaryOperatorRegisterForRn<2>{}; BinaryOperatorRegisterForRn<3>{}; + BinaryOperatorRegisterForRnxn<1>{}; + BinaryOperatorRegisterForRnxn<2>{}; + BinaryOperatorRegisterForRnxn<3>{}; BinaryOperatorRegisterForString{}; IncDecOperatorRegisterForN{}; @@ -87,4 +96,7 @@ OperatorRepository::_initialize() UnaryOperatorRegisterForRn<1>{}; UnaryOperatorRegisterForRn<2>{}; UnaryOperatorRegisterForRn<3>{}; + UnaryOperatorRegisterForRnxn<1>{}; + UnaryOperatorRegisterForRnxn<2>{}; + UnaryOperatorRegisterForRnxn<3>{}; } diff --git a/src/language/utils/PugsFunctionAdapter.hpp b/src/language/utils/PugsFunctionAdapter.hpp index 278c7274ff702cfc5e97f7b6075b5487e440580b..9b57a908e6341291d14e5e8f9c80aa8a87275dcc 100644 --- a/src/language/utils/PugsFunctionAdapter.hpp +++ b/src/language/utils/PugsFunctionAdapter.hpp @@ -229,6 +229,85 @@ class PugsFunctionAdapter<OutputType(InputType...)> } // LCOV_EXCL_STOP } + } else if constexpr (is_tiny_matrix_v<OutputType>) { + switch (data_type) { + case ASTNodeDataType::list_t: { + return [](DataVariant&& result) -> OutputType { + AggregateDataVariant& v = std::get<AggregateDataVariant>(result); + OutputType x; + + for (size_t i = 0, l = 0; i < x.dimension(); ++i) { + for (size_t j = 0; j < x.dimension(); ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using Aij_T = std::decay_t<decltype(Aij)>; + if constexpr (std::is_arithmetic_v<Aij_T>) { + x(i, j) = Aij; + } else { + // LCOV_EXCL_START + throw UnexpectedError("expecting arithmetic value"); + // LCOV_EXCL_STOP + } + }, + v[l]); + } + } + return x; + }; + } + case ASTNodeDataType::matrix_t: { + return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); }; + } + case ASTNodeDataType::bool_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return + [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; }; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::unsigned_int_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return [](DataVariant&& result) -> OutputType { + return OutputType(static_cast<double>(std::get<uint64_t>(result))); + }; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return [](DataVariant&& result) -> OutputType { + return OutputType{static_cast<double>(std::get<int64_t>(result))}; + }; + } else { + // If this point is reached must be a 0 matrix + return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; }; + } + } + case ASTNodeDataType::double_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; }; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP + } + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + } + // LCOV_EXCL_STOP + } } else if constexpr (std::is_arithmetic_v<OutputType>) { switch (data_type) { case ASTNodeDataType::bool_t: { diff --git a/src/language/utils/UnaryOperatorRegisterForRnxn.cpp b/src/language/utils/UnaryOperatorRegisterForRnxn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..798a60086b30d008901f0d9b5fa2502e745d6fe9 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForRnxn.cpp @@ -0,0 +1,28 @@ +#include <language/utils/UnaryOperatorRegisterForRnxn.hpp> + +#include <language/utils/OperatorRepository.hpp> +#include <language/utils/UnaryOperatorProcessorBuilder.hpp> + +template <size_t Dimension> +void +UnaryOperatorRegisterForRnxn<Dimension>::_register_unary_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository + .addUnaryOperator<language::unary_minus>(Rnxn, + std::make_shared<UnaryOperatorProcessorBuilder< + language::unary_minus, TinyMatrix<Dimension>, TinyMatrix<Dimension>>>()); +} + +template <size_t Dimension> +UnaryOperatorRegisterForRnxn<Dimension>::UnaryOperatorRegisterForRnxn() +{ + this->_register_unary_minus(); +} + +template class UnaryOperatorRegisterForRnxn<1>; +template class UnaryOperatorRegisterForRnxn<2>; +template class UnaryOperatorRegisterForRnxn<3>; diff --git a/src/language/utils/UnaryOperatorRegisterForRnxn.hpp b/src/language/utils/UnaryOperatorRegisterForRnxn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..42084f2961504130189f40967b50757401b7b939 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForRnxn.hpp @@ -0,0 +1,16 @@ +#ifndef UNARY_OPERATOR_REGISTER_FOR_RNXN_HPP +#define UNARY_OPERATOR_REGISTER_FOR_RNXN_HPP + +#include <cstdlib> + +template <size_t Dimension> +class UnaryOperatorRegisterForRnxn +{ + private: + void _register_unary_minus(); + + public: + UnaryOperatorRegisterForRnxn(); +}; + +#endif // UNARY_OPERATOR_REGISTER_FOR_RNXN_HPP diff --git a/src/utils/PugsTraits.hpp b/src/utils/PugsTraits.hpp index e34e0433bb69316af695fec1c8d571ef23e03155..f51b455cb5ea4efa678dd980d1fec2721ab1fd2a 100644 --- a/src/utils/PugsTraits.hpp +++ b/src/utils/PugsTraits.hpp @@ -84,7 +84,14 @@ inline constexpr bool is_tiny_vector_v = false; template <size_t N, typename T> inline constexpr bool is_tiny_vector_v<TinyVector<N, T>> = true; -// Traits is_tiny_vector +// Traits is_tiny_matrix + +template <typename T> +inline constexpr bool is_tiny_matrix_v = false; + +template <size_t N, typename T> +inline constexpr bool is_tiny_matrix_v<TinyMatrix<N, T>> = true; + // helper to check if a type is part of a variant template <typename T, typename V> diff --git a/tests/test_ASTNodeAffectationExpressionBuilder.cpp b/tests/test_ASTNodeAffectationExpressionBuilder.cpp index ecd72b1f04b52b47431e10d07b83d3dab441066e..d2aeabb06ed1e69c51e04a9c83842c260b796852 100644 --- a/tests/test_ASTNodeAffectationExpressionBuilder.cpp +++ b/tests/test_ASTNodeAffectationExpressionBuilder.cpp @@ -1548,8 +1548,7 @@ let x : R, x=1; x/=2.3; ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); ast->children.emplace_back(std::make_unique<ASTNode>()); - REQUIRE_THROWS_WITH(ASTNodeAffectationExpressionBuilder{*ast}, - "unexpected error: invalid implicit conversion: undefined -> string"); + REQUIRE_THROWS_WITH(ASTNodeAffectationExpressionBuilder{*ast}, "undefined affectation type: string = undefined"); } SECTION("Invalid string affectation operator") @@ -1629,7 +1628,7 @@ let s : string, s="foo"; s*=2; let x : R^3; let y : R^1; x = y; )"; - std::string error_message = "invalid implicit conversion: R^1 -> R^3"; + std::string error_message = "undefined affectation type: R^3 = R^1"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1640,7 +1639,7 @@ let x : R^3; let y : R^1; x = y; let x : R^3; let y : R^2; x = y; )"; - std::string error_message = "invalid implicit conversion: R^2 -> R^3"; + std::string error_message = "undefined affectation type: R^3 = R^2"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1651,7 +1650,7 @@ let x : R^3; let y : R^2; x = y; let x : R^2; let y : R^1; x = y; )"; - std::string error_message = "invalid implicit conversion: R^1 -> R^2"; + std::string error_message = "undefined affectation type: R^2 = R^1"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1662,7 +1661,7 @@ let x : R^2; let y : R^1; x = y; let x : R^2; let y : R^3; x = y; )"; - std::string error_message = "invalid implicit conversion: R^3 -> R^2"; + std::string error_message = "undefined affectation type: R^2 = R^3"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1673,7 +1672,7 @@ let x : R^2; let y : R^3; x = y; let x : R^1; let y : R^2; x = y; )"; - std::string error_message = "invalid implicit conversion: R^2 -> R^1"; + std::string error_message = "undefined affectation type: R^1 = R^2"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1684,7 +1683,7 @@ let x : R^1; let y : R^2; x = y; let x : R^1; let y : R^3; x = y; )"; - std::string error_message = "invalid implicit conversion: R^3 -> R^1"; + std::string error_message = "undefined affectation type: R^1 = R^3"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1698,7 +1697,7 @@ let x : R^1; let y : R^3; x = y; let x : R^3, x = 3; )"; - std::string error_message = "invalid implicit conversion: Z -> R^3"; + std::string error_message = "invalid integral value (0 is the solely valid value)"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1709,7 +1708,7 @@ let x : R^3, x = 3; let x : R^2, x = 2; )"; - std::string error_message = "invalid implicit conversion: Z -> R^2"; + std::string error_message = "invalid integral value (0 is the solely valid value)"; CHECK_AST_THROWS_WITH(data, error_message); } diff --git a/tests/test_ASTNodeArraySubscriptExpressionBuilder.cpp b/tests/test_ASTNodeArraySubscriptExpressionBuilder.cpp index 45b1303c3db5c69457921d98f76f43403274d7ed..7eadd2067e54e670021aa2bee03e92fe0abedccd 100644 --- a/tests/test_ASTNodeArraySubscriptExpressionBuilder.cpp +++ b/tests/test_ASTNodeArraySubscriptExpressionBuilder.cpp @@ -52,6 +52,45 @@ TEST_CASE("ASTNodeArraySubscriptExpressionBuilder", "[language]") auto& node_processor = *node->m_node_processor; REQUIRE(typeid(node_processor).name() == typeid(ArraySubscriptProcessor<TinyVector<3>>).name()); } + + SECTION("R^1x1") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + node->emplace_back(std::move(array_node)); + } + REQUIRE_NOTHROW(ASTNodeArraySubscriptExpressionBuilder{*node}); + REQUIRE(bool{node->m_node_processor}); + auto& node_processor = *node->m_node_processor; + REQUIRE(typeid(node_processor).name() == typeid(ArraySubscriptProcessor<TinyMatrix<1>>).name()); + } + + SECTION("R^2x2") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + node->emplace_back(std::move(array_node)); + } + REQUIRE_NOTHROW(ASTNodeArraySubscriptExpressionBuilder{*node}); + REQUIRE(bool{node->m_node_processor}); + auto& node_processor = *node->m_node_processor; + REQUIRE(typeid(node_processor).name() == typeid(ArraySubscriptProcessor<TinyMatrix<2>>).name()); + } + + SECTION("R^3x3") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + node->emplace_back(std::move(array_node)); + } + REQUIRE_NOTHROW(ASTNodeArraySubscriptExpressionBuilder{*node}); + REQUIRE(bool{node->m_node_processor}); + auto& node_processor = *node->m_node_processor; + REQUIRE(typeid(node_processor).name() == typeid(ArraySubscriptProcessor<TinyMatrix<3>>).name()); + } } SECTION("R^d component bad access") @@ -67,6 +106,7 @@ TEST_CASE("ASTNodeArraySubscriptExpressionBuilder", "[language]") } REQUIRE_THROWS_WITH(ASTNodeArraySubscriptExpressionBuilder{*node}, "unexpected error: invalid array dimension"); } + SECTION("R^d (d > 3)") { { @@ -76,6 +116,26 @@ TEST_CASE("ASTNodeArraySubscriptExpressionBuilder", "[language]") } REQUIRE_THROWS_WITH(ASTNodeArraySubscriptExpressionBuilder{*node}, "unexpected error: invalid array dimension"); } + + SECTION("R^dxd (d < 1)") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(0, 0); + node->emplace_back(std::move(array_node)); + } + REQUIRE_THROWS_WITH(ASTNodeArraySubscriptExpressionBuilder{*node}, "unexpected error: invalid array dimension"); + } + + SECTION("R^dxd (d > 3)") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(4, 4); + node->emplace_back(std::move(array_node)); + } + REQUIRE_THROWS_WITH(ASTNodeArraySubscriptExpressionBuilder{*node}, "unexpected error: invalid array dimension"); + } } SECTION("invalid array expression") diff --git a/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp b/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp index 2954cd8db2e691ea2206b30b158f9596a8290081..25115bfc2ff3011e582029ab495e17f4a9d108e2 100644 --- a/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -133,7 +133,7 @@ RtoR(true); } } - SECTION("R -> R1") + SECTION("R -> R^1") { SECTION("from R") { @@ -201,9 +201,77 @@ RtoR1(true); } } - SECTION("R1 -> R") + SECTION("R -> R^1x1") { - SECTION("from R1") + SECTION("from R") + { + std::string_view data = R"( +RtoR11(1.); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::real:1.:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from Z") + { + std::string_view data = R"( +RtoR11(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from N") + { + std::string_view data = R"( +let n : N, n = 1; +RtoR11(n); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::name:n:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from B") + { + std::string_view data = R"( +RtoR11(true); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::true_kw:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("R^1 -> R") + { + SECTION("from R^1") { std::string_view data = R"( let x : R^1, x = 2; @@ -286,7 +354,92 @@ R1toR(true); } } - SECTION("R2 -> R") + SECTION("R^1x1 -> R") + { + SECTION("from R^1x1") + { + std::string_view data = R"( +let x : R^1x1, x = 2; +R11toR(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from R") + { + std::string_view data = R"( +R11toR(1.); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::real:1.:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from Z") + { + std::string_view data = R"( +R11toR(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from N") + { + std::string_view data = R"( +let n : N, n = 1; +R11toR(n); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::name:n:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from B") + { + std::string_view data = R"( +R11toR(true); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::true_kw:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("R^2 -> R") { SECTION("from 0") { @@ -340,7 +493,63 @@ R2toR((1,2)); } } - SECTION("R3 -> R") + SECTION("R^2x2 -> R") + { + SECTION("from 0") + { + std::string_view data = R"( +R22toR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R22toR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from R^2x2") + { + std::string_view data = R"( +let x:R^2x2, x = (1,2,3,4); +R22toR(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R22toR:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from list") + { + std::string_view data = R"( +R22toR((1,2,3,4)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R22toR:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("R^3 -> R") { SECTION("from 0") { @@ -395,6 +604,67 @@ R3toR((1,2,3)); } } + SECTION("R^3x3 -> R") + { + SECTION("from 0") + { + std::string_view data = R"( +R33toR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33toR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from R^3x3") + { + std::string_view data = R"( +let x:R^3x3, x = (1,2,3,4,5,6,7,8,9); +R33toR(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33toR:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from list") + { + std::string_view data = R"( +R33toR((1,2,3,4,5,6,7,8,9)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33toR:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + +-(language::integer:4:ValueProcessor) + +-(language::integer:5:ValueProcessor) + +-(language::integer:6:ValueProcessor) + +-(language::integer:7:ValueProcessor) + +-(language::integer:8:ValueProcessor) + `-(language::integer:9:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + SECTION("Z -> R") { SECTION("from Z") @@ -603,6 +873,87 @@ R3R2toR((1,2,3),0); } } + SECTION("R^3x3*R^2x2 -> R") + { + SECTION("from R^3x3*R^2x2") + { + std::string_view data = R"( +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +let y : R^2x2, y = (1,2,3,4); +R33R22toR(x,y); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33R22toR:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::name:x:NameProcessor) + `-(language::name:y:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from (R,R,R,R,R,R,R,R,R)*(R,R,R,R)") + { + std::string_view data = R"( +R33R22toR((1,2,3,4,5,6,7,8,9),(1,2,3,4)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33R22toR:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | +-(language::integer:6:ValueProcessor) + | +-(language::integer:7:ValueProcessor) + | +-(language::integer:8:ValueProcessor) + | `-(language::integer:9:ValueProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from (R,R,R,R,R,R,R,R,R)*(0)") + { + std::string_view data = R"( +R33R22toR((1,2,3,4,5,6,7,8,9),0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33R22toR:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | +-(language::integer:6:ValueProcessor) + | +-(language::integer:7:ValueProcessor) + | +-(language::integer:8:ValueProcessor) + | `-(language::integer:9:ValueProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + SECTION("string -> B") { std::string_view data = R"( @@ -1012,7 +1363,7 @@ tuple_builtinToB(t); CHECK_AST(data, result); } - SECTION("Z -> tuple(R1)") + SECTION("Z -> tuple(R^1)") { std::string_view data = R"( tuple_R1ToR(1); @@ -1028,7 +1379,7 @@ tuple_R1ToR(1); CHECK_AST(data, result); } - SECTION("R -> tuple(R1)") + SECTION("R -> tuple(R^1)") { std::string_view data = R"( tuple_R1ToR(1.2); @@ -1044,7 +1395,7 @@ tuple_R1ToR(1.2); CHECK_AST(data, result); } - SECTION("R1 -> tuple(R1)") + SECTION("R^1 -> tuple(R^1)") { std::string_view data = R"( let r:R^1, r = 3; @@ -1061,7 +1412,56 @@ tuple_R1ToR(r); CHECK_AST(data, result); } - SECTION("0 -> tuple(R2)") + SECTION("Z -> tuple(R^1x1)") + { + std::string_view data = R"( +tuple_R11ToR(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R11ToR:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R -> tuple(R^1x1)") + { + std::string_view data = R"( +tuple_R11ToR(1.2); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R11ToR:NameProcessor) + `-(language::real:1.2:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^1x1 -> tuple(R^1x1)") + { + std::string_view data = R"( +let r:R^1x1, r = 3; +tuple_R11ToR(r); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R11ToR:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("0 -> tuple(R^2)") { std::string_view data = R"( tuple_R2ToR(0); @@ -1077,7 +1477,7 @@ tuple_R2ToR(0); CHECK_AST(data, result); } - SECTION("R2 -> tuple(R2)") + SECTION("R^2 -> tuple(R^2)") { std::string_view data = R"( let r:R^2, r = (1,2); @@ -1094,7 +1494,7 @@ tuple_R2ToR(r); CHECK_AST(data, result); } - SECTION("compound_list -> tuple(R2)") + SECTION("compound_list -> tuple(R^2)") { std::string_view data = R"( let r:R^2, r = (1,2); @@ -1113,7 +1513,59 @@ tuple_R2ToR((r,r)); CHECK_AST(data, result); } - SECTION("0 -> tuple(R3)") + SECTION("0 -> tuple(R^2x2)") + { + std::string_view data = R"( +tuple_R22ToR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R22ToR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^2x2 -> tuple(R^2x2)") + { + std::string_view data = R"( +let r:R^2x2, r = (1,2,3,4); +tuple_R22ToR(r); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R22ToR:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("compound_list -> tuple(R^2x2)") + { + std::string_view data = R"( +let r:R^2x2, r = (1,2,3,4); +tuple_R22ToR((r,r)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R22ToR:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::name:r:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("0 -> tuple(R^3)") { std::string_view data = R"( tuple_R3ToR(0); @@ -1129,7 +1581,7 @@ tuple_R3ToR(0); CHECK_AST(data, result); } - SECTION("R3 -> tuple(R3)") + SECTION("R^3 -> tuple(R^3)") { std::string_view data = R"( let r:R^3, r = (1,2,3); @@ -1146,6 +1598,39 @@ tuple_R3ToR(r); CHECK_AST(data, result); } + SECTION("0 -> tuple(R^3x3)") + { + std::string_view data = R"( +tuple_R33ToR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R33ToR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^3x3 -> tuple(R^3x3)") + { + std::string_view data = R"( +let r:R^3x3, r = (1,2,3,4,5,6,7,8,9); +tuple_R33ToR(r); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R33ToR:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("FunctionSymbolId -> R") { std::string_view data = R"( diff --git a/tests/test_ASTNodeDataType.cpp b/tests/test_ASTNodeDataType.cpp index 2e6e9afb4715b86c67fa2511e7472418e912f8e8..4e4e5759b68c3f94a288bc785f313711b41ec7d0 100644 --- a/tests/test_ASTNodeDataType.cpp +++ b/tests/test_ASTNodeDataType.cpp @@ -9,6 +9,7 @@ namespace language struct integer; struct real; struct vector_type; +struct matrix_type; } // namespace language // clazy:excludeall=non-pod-global-static @@ -62,6 +63,14 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)) == "R^2"); REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)) == "R^3"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(7)) == "R^7"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)) == "R^1x1"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)) == "R^2x2"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)) == "R^3x3"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(7, 3)) == "R^7x3"); + REQUIRE(dataTypeName(std::vector<ASTNodeDataType>{}) == "void"); REQUIRE(dataTypeName(std::vector<ASTNodeDataType>{bool_dt}) == "B"); REQUIRE(dataTypeName(std::vector<ASTNodeDataType>{bool_dt, unsigned_int_dt}) == "(B,N)"); @@ -104,7 +113,7 @@ TEST_CASE("ASTNodeDataType", "[language]") { std::unique_ptr dimension_node = std::make_unique<ASTNode>(); dimension_node->set_type<language::integer>(); - dimension_node->source = "17"; + dimension_node->source = "3"; auto& source = dimension_node->source; dimension_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source[0]}; dimension_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source[source.size()]}; @@ -113,8 +122,8 @@ TEST_CASE("ASTNodeDataType", "[language]") SECTION("good node") { - REQUIRE(getVectorDataType(*type_node) == ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)); - REQUIRE(getVectorDataType(*type_node).dimension() == 17); + REQUIRE(getVectorDataType(*type_node) == ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)); + REQUIRE(getVectorDataType(*type_node).dimension() == 3); } SECTION("bad node type") @@ -140,6 +149,99 @@ TEST_CASE("ASTNodeDataType", "[language]") type_node->children[1]->set_type<language::real>(); REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "unexpected non integer constant dimension"); } + + SECTION("bad dimension value") + { + type_node->children[1]->source = "0"; + REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + + type_node->children[1]->source = "4"; + REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + } + } + + SECTION("getMatrixDataType") + { + std::unique_ptr type_node = std::make_unique<ASTNode>(); + type_node->set_type<language::matrix_type>(); + + type_node->emplace_back(std::make_unique<ASTNode>()); + + { + { + std::unique_ptr dimension0_node = std::make_unique<ASTNode>(); + dimension0_node->set_type<language::integer>(); + dimension0_node->source = "3"; + auto& source0 = dimension0_node->source; + dimension0_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source0[0]}; + dimension0_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source0[source0.size()]}; + type_node->emplace_back(std::move(dimension0_node)); + } + { + std::unique_ptr dimension1_node = std::make_unique<ASTNode>(); + dimension1_node->set_type<language::integer>(); + dimension1_node->source = "3"; + auto& source1 = dimension1_node->source; + dimension1_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source1[0]}; + dimension1_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source1[source1.size()]}; + type_node->emplace_back(std::move(dimension1_node)); + } + } + + SECTION("good node") + { + REQUIRE(getMatrixDataType(*type_node) == ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + REQUIRE(getMatrixDataType(*type_node).nbRows() == 3); + REQUIRE(getMatrixDataType(*type_node).nbColumns() == 3); + } + + SECTION("bad node type") + { + type_node->set_type<language::integer>(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); + } + + SECTION("bad children size 1") + { + type_node->children.clear(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); + } + + SECTION("bad children size 1") + { + type_node->children.emplace_back(std::unique_ptr<ASTNode>()); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); + } + + SECTION("bad dimension 0 type") + { + type_node->children[1]->set_type<language::real>(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected non integer constant dimension"); + } + + SECTION("bad dimension 1 type") + { + type_node->children[2]->set_type<language::real>(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected non integer constant dimension"); + } + + SECTION("bad nb rows value") + { + type_node->children[1]->source = "0"; + type_node->children[2]->source = "0"; + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + + type_node->children[1]->source = "4"; + type_node->children[2]->source = "4"; + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + } + + SECTION("none square matrices") + { + type_node->children[1]->source = "1"; + type_node->children[2]->source = "2"; + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "only square matrices are supported"); + } } SECTION("isNaturalConversion") @@ -153,6 +255,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(not isNaturalConversion(string_dt, bool_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt), bool_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), bool_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), bool_dt)); } SECTION("-> N") @@ -165,6 +268,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE( not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt), unsigned_int_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), unsigned_int_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), unsigned_int_dt)); } SECTION("-> Z") @@ -176,6 +280,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(not isNaturalConversion(string_dt, int_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt), int_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), int_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), int_dt)); } SECTION("-> R") @@ -187,6 +292,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(not isNaturalConversion(string_dt, double_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt), double_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), double_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), double_dt)); } SECTION("-> string") @@ -198,6 +304,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(isNaturalConversion(string_dt, string_dt)); REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt), string_dt)); REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), string_dt)); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), string_dt)); } SECTION("-> tuple") @@ -227,6 +334,21 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt), ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(4))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(9))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), @@ -249,6 +371,53 @@ TEST_CASE("ASTNodeDataType", "[language]") ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); } + SECTION("-> matrix") + { + REQUIRE(not isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(unsigned_int_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(int_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(double_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(string_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(4), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(9), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 2))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + } + SECTION("-> type_id") { REQUIRE(not isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"))); diff --git a/tests/test_ASTNodeDataTypeBuilder.cpp b/tests/test_ASTNodeDataTypeBuilder.cpp index 6c37b1afe9a7c411e2c1643ac1bd1e3cb1a72aeb..369374986fa27f69b12e147434ab3605872cac16 100644 --- a/tests/test_ASTNodeDataTypeBuilder.cpp +++ b/tests/test_ASTNodeDataTypeBuilder.cpp @@ -305,7 +305,46 @@ let x : R; x[2]; auto ast = ASTBuilder::build(input); ASTSymbolTableBuilder{*ast}; - REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid types 'R[Z]' for array subscript"); + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid subscript expression: R cannot be indexed"); + } + + SECTION("invalid R^d subscript index list") + { + std::string_view data = R"( +let x : R^2; x[2,2]; +)"; + + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid index type: R^2 requires a single integer"); + } + + SECTION("invalid R^dxd subscript index list 1") + { + std::string_view data = R"( +let x : R^2x2; x[2]; +)"; + + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid index type: R^2x2 requires two integers"); + } + + SECTION("invalid R^dxd subscript index list 2") + { + std::string_view data = R"( +let x : R^2x2; x[2,3,1]; +)"; + + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid index type: R^2x2 requires two integers"); } SECTION("too many variables") @@ -510,6 +549,61 @@ let t2 : (R^3), t2 = (0, 0); CHECK_AST(data, result); } + SECTION("R^dxd tuples") + { + std::string_view data = R"( +let a : R^2x2, a = (2, 3.1, -1.2, 4); +let t1 : (R^2x2), t1 = (a, (1,2,1,3), 0); +let t2 : (R^3x3), t2 = (0, 0); +)"; + + std::string_view result = R"( +(root:void) + +-(language::var_declaration:void) + | +-(language::name:a:R^2x2) + | +-(language::matrix_type:typename(R^2x2)) + | | +-(language::R_set:typename(R)) + | | +-(language::integer:2:Z) + | | `-(language::integer:2:Z) + | +-(language::name:a:R^2x2) + | `-(language::expression_list:list(Z*R*R*Z)) + | +-(language::integer:2:Z) + | +-(language::real:3.1:R) + | +-(language::unary_minus:R) + | | `-(language::real:1.2:R) + | `-(language::integer:4:Z) + +-(language::var_declaration:void) + | +-(language::name:t1:tuple(R^2x2)) + | +-(language::tuple_type_specifier:typename(tuple(R^2x2))) + | | `-(language::matrix_type:typename(R^2x2)) + | | +-(language::R_set:typename(R)) + | | +-(language::integer:2:Z) + | | `-(language::integer:2:Z) + | +-(language::name:t1:tuple(R^2x2)) + | `-(language::expression_list:list(R^2x2*list(Z*Z*Z*Z)*Z)) + | +-(language::name:a:R^2x2) + | +-(language::tuple_expression:list(Z*Z*Z*Z)) + | | +-(language::integer:1:Z) + | | +-(language::integer:2:Z) + | | +-(language::integer:1:Z) + | | `-(language::integer:3:Z) + | `-(language::integer:0:Z) + `-(language::var_declaration:void) + +-(language::name:t2:tuple(R^3x3)) + +-(language::tuple_type_specifier:typename(tuple(R^3x3))) + | `-(language::matrix_type:typename(R^3x3)) + | +-(language::R_set:typename(R)) + | +-(language::integer:3:Z) + | `-(language::integer:3:Z) + +-(language::name:t2:tuple(R^3x3)) + `-(language::expression_list:list(Z*Z)) + +-(language::integer:0:Z) + `-(language::integer:0:Z) +)"; + + CHECK_AST(data, result); + } + SECTION("string tuples") { std::string_view data = R"( @@ -647,6 +741,84 @@ let square : R -> R^2, x -> (x, x*x); } } + SECTION("R^dxd-functions") + { + SECTION("matrix function") + { + std::string_view data = R"( +let double : R^2x2 -> R^2x2, x -> 2*x; +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:double:function) +)"; + + CHECK_AST(data, result); + } + + SECTION("matrix vector product") + { + std::string_view data = R"( +let prod : R^2x2*R^2 -> R^2, (A,x) -> A*x; +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:prod:function) +)"; + + CHECK_AST(data, result); + } + + SECTION("matrix function") + { + std::string_view data = R"( +let det : R^2x2 -> R, x -> x[0,0]*x[1,1]-x[1,0]*x[0,1]; +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:det:function) +)"; + + CHECK_AST(data, result); + } + + SECTION("R-list -> R^dxd") + { + std::string_view data = R"( +let f : R -> R^2x2, x -> (x, x*x, 2-x, 0); +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:f:function) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^d*R^d -> R^dxd") + { + std::string_view data = R"( +let f : R^2*R^2 -> R^2x2, (x,y) -> (x[0], y[0], x[1], y[1]); +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:f:function) +)"; + + CHECK_AST(data, result); + } + } + SECTION("R-functions") { SECTION("multiple variable") @@ -866,6 +1038,19 @@ let f : R -> R*R, x -> x*x*x; REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "number of image spaces (2) R*R differs from number of expressions (1) x*x*x"); } + + SECTION("wrong image size 3") + { + std::string_view data = R"( +let f : R -> R^2x2, x -> (x, 2*x, 2); +)"; + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, + "expecting 4 scalar expressions or an R^2x2, found 3 scalar expressions"); + } } } diff --git a/tests/test_ASTNodeFunctionExpressionBuilder.cpp b/tests/test_ASTNodeFunctionExpressionBuilder.cpp index 5a071366578271161d21a27d2ce0d5e883f9c7c3..fd0163b7aed4f94bf99bee19eec4ae46ed9b7f3b 100644 --- a/tests/test_ASTNodeFunctionExpressionBuilder.cpp +++ b/tests/test_ASTNodeFunctionExpressionBuilder.cpp @@ -397,6 +397,60 @@ f(x); CHECK_AST(data, result); } + SECTION("Return R^1x1 -> R^1x1") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> x+x; +let x : R^1x1, x = 1; +f(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return R^2x2 -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x+x; +let x : R^2x2, x = (1,2,3,4); +f(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return R^3x3 -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x+x; +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +f(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return scalar -> R^1") { std::string_view data = R"( @@ -453,6 +507,73 @@ f(1,2,3); CHECK_AST(data, result); } + SECTION("Return scalar -> R^1x1") + { + std::string_view data = R"( +let f : R -> R^1x1, x -> x+1; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return tuple -> R^2x2") + { + std::string_view data = R"( +let f : R*R*R*R -> R^2x2, (x,y,z,t) -> (x,y,z,t); +f(1,2,3,4); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:TupleToTinyMatrixProcessor<FunctionProcessor, 2ul>) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return tuple -> R^3x3") + { + std::string_view data = R"( +let f : R^3*R^3*R^3 -> R^3x3, (x,y,z) -> (x[0],x[1],x[2],y[0],y[1],y[2],z[0],z[1],z[2]); +f((1,2,3),(4,5,6),(7,8,9)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:TupleToTinyMatrixProcessor<FunctionProcessor, 3ul>) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | `-(language::integer:3:ValueProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | `-(language::integer:6:ValueProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:7:ValueProcessor) + +-(language::integer:8:ValueProcessor) + `-(language::integer:9:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return '0' -> R^1") { std::string_view data = R"( @@ -504,6 +625,57 @@ f(1); CHECK_AST(data, result); } + SECTION("Return '0' -> R^1x1") + { + std::string_view data = R"( +let f : R -> R^1x1, x -> 0; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<1ul, double>, ZeroType>) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return '0' -> R^2x2") + { + std::string_view data = R"( +let f : R -> R^2x2, x -> 0; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<2ul, double>, ZeroType>) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return '0' -> R^3x3") + { + std::string_view data = R"( +let f : R -> R^3x3, x -> 0; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<3ul, double>, ZeroType>) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return embedded R^d compound") { std::string_view data = R"( @@ -525,6 +697,27 @@ f(1,2,3,4); CHECK_AST(data, result); } + SECTION("Return embedded R^dxd compound") + { + std::string_view data = R"( +let f : R*R*R*R -> R*R^1x1*R^2x2*R^3x3, (x,y,z,t) -> (t, (x), (x,y,z,t), (x,y,z, x,x,x, t,t,t)); +f(1,2,3,4); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return embedded R^d compound with '0'") { std::string_view data = R"( @@ -546,6 +739,27 @@ f(1,2,3,4); CHECK_AST(data, result); } + SECTION("Return embedded R^dxd compound with '0'") + { + std::string_view data = R"( +let f : R*R*R*R -> R*R^1x1*R^2x2*R^3x3, (x,y,z,t) -> (t, 0, 0, (x, y, z, t, x, y, z, t, x)); +f(1,2,3,4); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Arguments '0' -> R^1") { std::string_view data = R"( @@ -597,6 +811,57 @@ f(0); CHECK_AST(data, result); } + SECTION("Arguments '0' -> R^1x1") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> x; +f(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Arguments '0' -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x; +f(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Arguments '0' -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x; +f(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Arguments tuple -> R^d") { std::string_view data = R"( @@ -617,11 +882,37 @@ f((1,2,3)); CHECK_AST(data, result); } + SECTION("Arguments tuple -> R^dxd") + { + std::string_view data = R"( +let f: R^3x3 -> R, x -> x[0,0]+x[0,1]+x[0,2]; +f((1,2,3,4,5,6,7,8,9)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + +-(language::integer:4:ValueProcessor) + +-(language::integer:5:ValueProcessor) + +-(language::integer:6:ValueProcessor) + +-(language::integer:7:ValueProcessor) + +-(language::integer:8:ValueProcessor) + `-(language::integer:9:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Arguments compound with tuple") { std::string_view data = R"( -let f: R*R^3*R^2->R, (t,x,y) -> t*(x[0]+x[1]+x[2])*y[0]+y[1]; -f(2,(1,2,3),(2,1.3)); +let f: R*R^3*R^2x2->R, (t,x,y) -> t*(x[0]+x[1]+x[2])*y[0,0]+y[1,1]; +f(2,(1,2,3),(2,3,-1,1.3)); )"; std::string_view result = R"( @@ -636,6 +927,9 @@ f(2,(1,2,3),(2,1.3)); | `-(language::integer:3:ValueProcessor) `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + +-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, long>) + | `-(language::integer:1:ValueProcessor) `-(language::real:1.3:ValueProcessor) )"; @@ -785,7 +1079,9 @@ let f : R^2 -> R, x->x[0]; f((1,2,3)); )"; - CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"incompatible dimensions in affectation"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, + std::string{ + "incompatible dimensions in affectation: expecting 2, but provided 3"}); } SECTION("tuple[2] -> R^3") @@ -795,7 +1091,9 @@ let f : R^3 -> R, x->x[0]; f((1,2)); )"; - CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"incompatible dimensions in affectation"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, + std::string{ + "incompatible dimensions in affectation: expecting 3, but provided 2"}); } SECTION("compound tuple[3] -> R^2") @@ -805,7 +1103,9 @@ let f : R*R^2 -> R, (t,x)->x[0]; f(1,(1,2,3)); )"; - CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"incompatible dimensions in affectation"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, + std::string{ + "incompatible dimensions in affectation: expecting 2, but provided 3"}); } SECTION("compound tuple[2] -> R^3") @@ -815,7 +1115,9 @@ let f : R^3*R^2 -> R, (x,y)->x[0]*y[1]; f((1,2),(3,4)); )"; - CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"incompatible dimensions in affectation"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, + std::string{ + "incompatible dimensions in affectation: expecting 3, but provided 2"}); } SECTION("list instead of tuple -> R^3") diff --git a/tests/test_ASTNodeListAffectationExpressionBuilder.cpp b/tests/test_ASTNodeListAffectationExpressionBuilder.cpp index c15be1db66c0bbbe9a5e9eee4c895f21f3bb30d7..62b584eaa1f9e0bac8b05ae8f8d1072cf4e46bae 100644 --- a/tests/test_ASTNodeListAffectationExpressionBuilder.cpp +++ b/tests/test_ASTNodeListAffectationExpressionBuilder.cpp @@ -199,6 +199,56 @@ let (x1,x2,x3,x) : R^1*R^2*R^3*R, CHECK_AST(data, result); } + SECTION("without conversion R^1x1*R^2x2*R^3x3*R") + { + std::string_view data = R"( +let a:R^1x1, a = 0; +let b:R^2x2, b = (1, 2, 3, 4); +let c:R^3x3, c = (9, 8, 7, 6, 5, 4, 3, 2, 1); +let (x1,x2,x3,x) : R^1x1*R^2x2*R^3x3*R, + (x1,x2,x3,x) = (a, b, c, 2); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + +-(language::eq_op:AffectationProcessor<language::eq_op, TinyMatrix<1ul, double>, long>) + | +-(language::name:a:NameProcessor) + | `-(language::integer:0:ValueProcessor) + +-(language::eq_op:AffectationToTinyMatrixFromListProcessor<language::eq_op, TinyMatrix<2ul, double> >) + | +-(language::name:b:NameProcessor) + | `-(language::expression_list:ASTNodeExpressionListProcessor) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | `-(language::integer:4:ValueProcessor) + +-(language::eq_op:AffectationToTinyMatrixFromListProcessor<language::eq_op, TinyMatrix<3ul, double> >) + | +-(language::name:c:NameProcessor) + | `-(language::expression_list:ASTNodeExpressionListProcessor) + | +-(language::integer:9:ValueProcessor) + | +-(language::integer:8:ValueProcessor) + | +-(language::integer:7:ValueProcessor) + | +-(language::integer:6:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | `-(language::integer:1:ValueProcessor) + `-(language::eq_op:ListAffectationProcessor<language::eq_op>) + +-(language::name_list:FakeProcessor) + | +-(language::name:x1:NameProcessor) + | +-(language::name:x2:NameProcessor) + | +-(language::name:x3:NameProcessor) + | `-(language::name:x:NameProcessor) + `-(language::expression_list:ASTNodeExpressionListProcessor) + +-(language::name:a:NameProcessor) + +-(language::name:b:NameProcessor) + +-(language::name:c:NameProcessor) + `-(language::integer:2:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Zero initialization") { std::string_view data = R"( @@ -223,6 +273,30 @@ let (x1,x2,x3,x) : R^1*R^2*R^3*R, (x1,x2,x3,x) = (0, 0, 0, 0); CHECK_AST(data, result); } + SECTION("Zero initialization") + { + std::string_view data = R"( +let (x1,x2,x3,x) : R^1x1*R^2x2*R^3x3*R, (x1,x2,x3,x) = (0, 0, 0, 0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::eq_op:ListAffectationProcessor<language::eq_op>) + +-(language::name_list:FakeProcessor) + | +-(language::name:x1:NameProcessor) + | +-(language::name:x2:NameProcessor) + | +-(language::name:x3:NameProcessor) + | `-(language::name:x:NameProcessor) + `-(language::expression_list:ASTNodeExpressionListProcessor) + +-(language::integer:0:ValueProcessor) + +-(language::integer:0:ValueProcessor) + +-(language::integer:0:ValueProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("from function") { std::string_view data = R"( @@ -385,7 +459,7 @@ let x:R^2, x = (1,2); let y:R^3, y = x; )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: R^2 -> R^3"}); + CHECK_AST_THROWS_WITH(data, std::string{"undefined affectation type: R^3 = R^2"}); } SECTION("invalid Z -> R^d conversion (non-zero)") @@ -394,7 +468,7 @@ let y:R^3, y = x; let x:R^2, x = 1; )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^2"}); + CHECK_AST_THROWS_WITH(data, std::string{"invalid integral value (0 is the solely valid value)"}); } } } diff --git a/tests/test_ASTNodeNaturalConversionChecker.cpp b/tests/test_ASTNodeNaturalConversionChecker.cpp index ed1e834c016ac9e7c6f3d796a7018ce290aa9324..a9e10344d3f6eaacd1ac6fde102b1d4b82214927 100644 --- a/tests/test_ASTNodeNaturalConversionChecker.cpp +++ b/tests/test_ASTNodeNaturalConversionChecker.cpp @@ -2,7 +2,7 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNode.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> namespace language { @@ -84,6 +84,145 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") } } + SECTION("-> R^dxd") + { + SECTION("R^1x1 -> R^1x1") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}); + } + + SECTION("list -> R^1x1") + { + data_node->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::list_t>({std::make_shared<const ASTNodeDataType>(double_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + } + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}); + } + + SECTION("'0' -> R^dxd") + { + data_node->m_data_type = int_dt; + data_node->set_type<language::integer>(); + data_node->source = "0"; + auto& source = data_node->source; + data_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source[0]}; + data_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source[source.size()]}; + + SECTION("d = 1") + { + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}); + } + SECTION("d = 2") + { + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}); + } + SECTION("d = 3") + { + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}); + } + } + + SECTION("R^2x2 -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}); + } + + SECTION("list -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>( + {std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(unsigned_int_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + + std::unique_ptr list1_node = std::make_unique<ASTNode>(); + list1_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list1_node)); + + std::unique_ptr list2_node = std::make_unique<ASTNode>(); + list2_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list2_node)); + + std::unique_ptr list3_node = std::make_unique<ASTNode>(); + list3_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list3_node)); + } + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}); + } + + SECTION("R^3x3 -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}); + } + + SECTION("list -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>( + {std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(int_dt), std::make_shared<const ASTNodeDataType>(double_dt), + std::make_shared<const ASTNodeDataType>(unsigned_int_dt), std::make_shared<const ASTNodeDataType>(int_dt), + std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(int_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + + std::unique_ptr list1_node = std::make_unique<ASTNode>(); + list1_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list1_node)); + + std::unique_ptr list2_node = std::make_unique<ASTNode>(); + list2_node->m_data_type = int_dt; + data_node->emplace_back(std::move(list2_node)); + + std::unique_ptr list3_node = std::make_unique<ASTNode>(); + list3_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list3_node)); + + std::unique_ptr list4_node = std::make_unique<ASTNode>(); + list4_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list4_node)); + + std::unique_ptr list5_node = std::make_unique<ASTNode>(); + list5_node->m_data_type = int_dt; + data_node->emplace_back(std::move(list5_node)); + + std::unique_ptr list6_node = std::make_unique<ASTNode>(); + list6_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list6_node)); + + std::unique_ptr list7_node = std::make_unique<ASTNode>(); + list7_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list7_node)); + + std::unique_ptr list8_node = std::make_unique<ASTNode>(); + list8_node->m_data_type = int_dt; + data_node->emplace_back(std::move(list8_node)); + } + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}); + } + } + SECTION("-> R^d") { SECTION("R^1 -> R^1") @@ -718,6 +857,423 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { std::unique_ptr data_node = std::make_unique<ASTNode>(); + SECTION("-> R^dxd") + { + SECTION("R^2x2 -> R^1x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}), + "invalid implicit conversion: R^2x2 -> R^1x1"); + } + + SECTION("R^3x3 -> R^1x1") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}), + "invalid implicit conversion: R^3x3 -> R^1x1"); + } + + SECTION("R^1x1 -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}), + "invalid implicit conversion: R^1x1 -> R^2x2"); + } + + SECTION("R^3x3 -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}), + "invalid implicit conversion: R^3x3 -> R^2x2"); + } + + SECTION("R^1x1 -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}), + "invalid implicit conversion: R^1x1 -> R^3x3"); + } + + SECTION("R^2x2 -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}), + "invalid implicit conversion: R^2x2 -> R^3x3"); + } + + SECTION("list1 -> R^dxd") + { + data_node->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::list_t>({std::make_shared<const ASTNodeDataType>(double_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "incompatible dimensions in affectation: expecting 4, but provided 1"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "incompatible dimensions in affectation: expecting 9, but provided 1"); + } + } + + SECTION("list2 -> R^dxd") + { + data_node->m_data_type = list_dt; + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + + std::unique_ptr list1_node = std::make_unique<ASTNode>(); + list1_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list1_node)); + } + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "incompatible dimensions in affectation: expecting 1, but provided 2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "incompatible dimensions in affectation: expecting 9, but provided 2"); + } + } + + SECTION("list3 -> R^dxd") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>( + {std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(int_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + + std::unique_ptr list1_node = std::make_unique<ASTNode>(); + list1_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list1_node)); + + std::unique_ptr list2_node = std::make_unique<ASTNode>(); + list2_node->m_data_type = int_dt; + data_node->emplace_back(std::move(list2_node)); + } + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "incompatible dimensions in affectation: expecting 1, but provided 3"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "incompatible dimensions in affectation: expecting 4, but provided 3"); + } + } + + SECTION("tuple -> R^dxd") + { + SECTION("tuple(N) -> R^1x1") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: tuple(N) -> R^1x1"); + } + + SECTION("tuple(R) -> R^1x1") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: tuple(R) -> R^1x1"); + } + + SECTION("tuple(R) -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: tuple(R) -> R^2x2"); + } + + SECTION("tuple(B) -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: tuple(B) -> R^2x2"); + } + + SECTION("tuple(Z) -> R^3x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: tuple(Z) -> R^3x3"); + } + + SECTION("tuple(R) -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: tuple(R) -> R^3x3"); + } + + SECTION("tuple(R^1) -> tuple(R^3x3)") + { + auto tuple_R1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)); + auto tuple_R3x3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + data_node->m_data_type = tuple_R1; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3x3}), + "invalid implicit conversion: R^1 -> R^3x3"); + } + + SECTION("tuple(R^2) -> tuple(R^3x3)") + { + auto tuple_R2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + auto tuple_R3x3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + data_node->m_data_type = tuple_R2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3x3}), + "invalid implicit conversion: R^2 -> R^3x3"); + } + + SECTION("tuple(R^2) -> tuple(R^1x1)") + { + auto tuple_R1x1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)); + auto tuple_R2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + data_node->m_data_type = tuple_R2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R1x1}), + "invalid implicit conversion: R^2 -> R^1x1"); + } + + SECTION("tuple(R^1x1) -> tuple(R^3x3)") + { + auto tuple_R1x1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)); + auto tuple_R3x3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + data_node->m_data_type = tuple_R1x1; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3x3}), + "invalid implicit conversion: R^1x1 -> R^3x3"); + } + + SECTION("tuple(R^2x2) -> tuple(R^3x3)") + { + auto tuple_R2x2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)); + auto tuple_R3x3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + data_node->m_data_type = tuple_R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3x3}), + "invalid implicit conversion: R^2x2 -> R^3x3"); + } + + SECTION("tuple(R^2x2) -> tuple(R^1x1)") + { + auto tuple_R1x1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)); + auto tuple_R2x2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)); + data_node->m_data_type = tuple_R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R1x1}), + "invalid implicit conversion: R^2x2 -> R^1x1"); + } + } + + SECTION("R -> R^dxd") + { + data_node->m_data_type = double_dt; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: R -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: R -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: R -> R^3x3"); + } + } + + SECTION("Z -> R^dxd (non-zero)") + { + data_node->m_data_type = int_dt; + data_node->set_type<language::integer>(); + data_node->source = "1"; + auto& source = data_node->source; + data_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source[0]}; + data_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source[source.size()]}; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: Z -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: Z -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: Z -> R^3x3"); + } + } + + SECTION("N -> R^dxd") + { + data_node->m_data_type = unsigned_int_dt; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: N -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: N -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: N -> R^3x3"); + } + } + + SECTION("B -> R^dxd") + { + data_node->m_data_type = bool_dt; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: B -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: B -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: B -> R^3x3"); + } + } + + SECTION("string -> R^dxd") + { + data_node->m_data_type = string_dt; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: string -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: string -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: string -> R^3x3"); + } + } + } + SECTION("-> R^d") { SECTION("R^2 -> R^1") @@ -782,14 +1338,14 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), - "incompatible dimensions in affectation"); + "incompatible dimensions in affectation: expecting 2, but provided 1"); } SECTION("d=3") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), - "incompatible dimensions in affectation"); + "incompatible dimensions in affectation: expecting 3, but provided 1"); } } @@ -810,14 +1366,14 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), - "incompatible dimensions in affectation"); + "incompatible dimensions in affectation: expecting 1, but provided 2"); } SECTION("d=3") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), - "incompatible dimensions in affectation"); + "incompatible dimensions in affectation: expecting 3, but provided 2"); } } @@ -844,14 +1400,14 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), - "incompatible dimensions in affectation"); + "incompatible dimensions in affectation: expecting 1, but provided 3"); } SECTION("d=2") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), - "incompatible dimensions in affectation"); + "incompatible dimensions in affectation: expecting 2, but provided 3"); } } @@ -905,6 +1461,39 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") "invalid implicit conversion: tuple(R) -> R^3"); } + SECTION("tuple(R^1x1) -> tuple(R^3)") + { + auto tuple_R1x1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)); + auto tuple_R3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)); + data_node->m_data_type = tuple_R1x1; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3}), + "invalid implicit conversion: R^1x1 -> R^3"); + } + + SECTION("tuple(R^2x2) -> tuple(R^3)") + { + auto tuple_R2x2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)); + auto tuple_R3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)); + data_node->m_data_type = tuple_R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3}), + "invalid implicit conversion: R^2x2 -> R^3"); + } + + SECTION("tuple(R^2x2) -> tuple(R^1)") + { + auto tuple_R1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)); + auto tuple_R2x2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)); + data_node->m_data_type = tuple_R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R1}), + "invalid implicit conversion: R^2x2 -> R^1"); + } + SECTION("tuple(R^1) -> tuple(R^3)") { auto tuple_R1 = @@ -1415,6 +2004,36 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") "invalid implicit conversion: R -> N"); } + SECTION("R^1x1 -> tuple(R^2x2)") + { + auto R1x1 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + auto R2x2 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + data_node->m_data_type = R1x1; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2x2)}), + "invalid implicit conversion: R^1x1 -> R^2x2"); + } + + SECTION("R^2x2 -> tuple(R^3x3)") + { + auto R2x2 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + auto R3x3 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + data_node->m_data_type = R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R3x3)}), + "invalid implicit conversion: R^2x2 -> R^3x3"); + } + + SECTION("R^3x3 -> tuple(R^2x2)") + { + auto R3x3 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + auto R2x2 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + data_node->m_data_type = R3x3; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2x2)}), + "invalid implicit conversion: R^3x3 -> R^2x2"); + } + SECTION("R^1 -> tuple(R^2)") { auto R1 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); diff --git a/tests/test_AffectationProcessor.cpp b/tests/test_AffectationProcessor.cpp index 75ba801b59cc8d06b213fd56b2dc18dbb74b7c45..331046feb18edffb9e110a27a9050cdb94795bea 100644 --- a/tests/test_AffectationProcessor.cpp +++ b/tests/test_AffectationProcessor.cpp @@ -102,7 +102,6 @@ TEST_CASE("AffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT("let x : R^1; x[0] = -2.3;", "x", (TinyVector<1>{-2.3})); CHECK_AFFECTATION_RESULT("let x : R^1, x = 0;", "x", (TinyVector<1>{zero})); - CHECK_AFFECTATION_RESULT("let x : R^1; x = 0;", "x", (TinyVector<1>{zero})); } SECTION("R^2") @@ -115,7 +114,6 @@ TEST_CASE("AffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT("let x : R^2; x[0] = -0.3; x[1] = 12;", "x", (TinyVector<2>{-0.3, 12})); CHECK_AFFECTATION_RESULT("let x : R^2, x = 0;", "x", (TinyVector<2>{zero})); - CHECK_AFFECTATION_RESULT("let x : R^2; x = 0;", "x", (TinyVector<2>{zero})); } SECTION("R^3") @@ -126,9 +124,51 @@ TEST_CASE("AffectationProcessor", "[language]") (TinyVector<3>{-1, true, false})); CHECK_AFFECTATION_RESULT("let x : R^3; x[0] = -0.3; x[1] = 12; x[2] = 6.2;", "x", (TinyVector<3>{-0.3, 12, 6.2})); - CHECK_AFFECTATION_RESULT("let x : R^3, x = 0;", "x", (TinyVector<3>{zero})); CHECK_AFFECTATION_RESULT("let x : R^3; x = 0;", "x", (TinyVector<3>{zero})); } + + SECTION("R^1x1") + { + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = -1;", "x", (TinyMatrix<1>{-1})); + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = true;", "x", (TinyMatrix<1>{true})); + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = false;", "x", (TinyMatrix<1>{false})); + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = -2.3;", "x", (TinyMatrix<1>{-2.3})); + CHECK_AFFECTATION_RESULT("let x : R^1x1; x[0,0] = -1;", "x", (TinyMatrix<1>{-1})); + CHECK_AFFECTATION_RESULT("let x : R^1x1; x[0,0] = true;", "x", (TinyMatrix<1>{true})); + CHECK_AFFECTATION_RESULT("let x : R^1x1; x[0,0] = false;", "x", (TinyMatrix<1>{false})); + CHECK_AFFECTATION_RESULT("let x : R^1x1; x[0,0] = -2.3;", "x", (TinyMatrix<1>{-2.3})); + + CHECK_AFFECTATION_RESULT("let x : R^1x1; x = 0;", "x", (TinyMatrix<1>{zero})); + } + + SECTION("R^2x2") + { + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-1, true, 3, 5);", "x", (TinyMatrix<2>{-1, true, 3, 5})); + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (true, false, 1==2, 2==2);", "x", + (TinyMatrix<2>{true, false, false, true})); + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-0.3, 12, 2, -3);", "x", (TinyMatrix<2>{-0.3, 12, 2, -3})); + CHECK_AFFECTATION_RESULT("let x : R^2x2; x[0,0] = -1; x[0,1] = true; x[1,0] = 2; x[1,1] = 3.3;", "x", + (TinyMatrix<2>{-1, true, 2, 3.3})); + CHECK_AFFECTATION_RESULT("let x : R^2x2; x[0,0] = true; x[0,1] = false; x[1,0] = 2.1; x[1,1] = -1;", "x", + (TinyMatrix<2>{true, false, 2.1, -1})); + CHECK_AFFECTATION_RESULT("let x : R^2x2; x[0,0] = -0.3; x[0,1] = 12; x[1,0] = 1.3; x[1,1] = 7;", "x", + (TinyMatrix<2>{-0.3, 12, 1.3, 7})); + + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = 0;", "x", (TinyMatrix<2>{zero})); + } + + SECTION("R^3x3") + { + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-1, true, false, 2, 3.1, 4, -1, true, 2);", "x", + (TinyMatrix<3>{-1, true, false, 2, 3.1, 4, -1, true, 2})); + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-0.3, 12, 6.2, 7.1, 3.2, 2-3, 2, -1, 0);", "x", + (TinyMatrix<3>{-0.3, 12, 6.2, 7.1, 3.2, 2 - 3, 2, -1, 0})); + CHECK_AFFECTATION_RESULT("let x : R^3x3; x[0,0] = -1; x[0,1] = true; x[0,2] = false; x[1,0] = -11; x[1,1] = 4; " + "x[1,2] = 3; x[2,0] = 6; x[2,1] = -3; x[2,2] = 5;", + "x", (TinyMatrix<3>{-1, true, false, -11, 4, 3, 6, -3, 5})); + + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = 0;", "x", (TinyMatrix<3>{zero})); + } } SECTION("+=") @@ -281,6 +321,29 @@ TEST_CASE("AffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT("let x : R^3, x = (-0.3, 12, 6.2); x[0] *= -1; x[1] *= -3; x[2] *= 2;", "x", (TinyVector<3>{-0.3 * -1, 12 * -3, 6.2 * 2})); } + + SECTION("R^1x1") + { + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = 2; x *= 2;", "x", (TinyMatrix<1>{TinyMatrix<1>{2} *= 2})); + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = 2; x[0,0] *= 1.3;", "x", (TinyMatrix<1>{2 * 1.3})); + } + + SECTION("R^2x2") + { + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-1, true, 3, 6); x *= 3;", "x", + (TinyMatrix<2>{TinyMatrix<2>{-1, true, 3, 6} *= 3})); + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-1, true, 3, 6); x[0,0] *= 2; x[1,1] *= 3;", "x", + (TinyMatrix<2>{-1 * 2, true, 3, 6 * 3})); + } + + SECTION("R^3x3") + { + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-1, true, false, 2, -3, 11, 5, -4, 2); x*=5.2;", "x", + (TinyMatrix<3>{TinyMatrix<3>{-1, true, false, 2, -3, 11, 5, -4, 2} *= 5.2})); + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-0.3, 12, 6.2, 2, -3, 11, 5, -4, 2); x[0,0] *= -1; x[0,1] *= -3; " + "x[0,2] *= 2; x[1,1] *= 2; x[2,1] *= 6; x[2,2] *= 2;", + "x", (TinyMatrix<3>{-0.3 * -1, 12 * -3, 6.2 * 2, 2, -3 * 2, 11, 5, (-4) * 6, 2 * 2})); + } } SECTION("/=") @@ -323,6 +386,28 @@ TEST_CASE("AffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT("let x : R^3, x = (-0.3, 12, 6.2); x[0] /= -1.2; x[1] /= -3.1; x[2] /= 2.4;", "x", (TinyVector<3>{-0.3 / -1.2, 12 / -3.1, 6.2 / 2.4})); } + + SECTION("R^1x1") + { + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = 2; x[0,0] /= 1.3;", "x", (TinyMatrix<1>{2 / 1.3})); + } + + SECTION("R^2x2") + { + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-1, true, 3, 1); x[0,0] /= 2; x[0,1] /= 3; x[1,0] /= 0.5; x[1,1] " + "/= 4;", + "x", (TinyMatrix<2>{-1. / 2., true / 3., 3 / 0.5, 1. / 4})); + } + + SECTION("R^3x3") + { + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-0.3, 12, 6.2, 1.2, 3, 5, 1, 11, 2); x[0,0] /= -1.2; x[0,1] /= " + "-3.1; x[0,2] /= 2.4; x[1,0] /= -1.6; x[1,1] /= -3.1; x[1,2] /= 2.4; x[2,0] /= 0.4; " + "x[2,1] /= -1.7; x[2,2] /= 1.2;", + "x", + (TinyMatrix<3>{-0.3 / -1.2, 12 / -3.1, 6.2 / 2.4, 1.2 / -1.6, 3 / -3.1, 5 / 2.4, 1 / 0.4, + 11 / -1.7, 2 / 1.2})); + } } SECTION("errors") @@ -331,83 +416,83 @@ TEST_CASE("AffectationProcessor", "[language]") { SECTION("-> B") { - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 1; let b : B; b = n;", "invalid implicit conversion: N -> B"); - CHECK_AFFECTATION_THROWS_WITH("let b : B; b = 1;", "invalid implicit conversion: Z -> B"); - CHECK_AFFECTATION_THROWS_WITH("let b : B; b = 2.3;", "invalid implicit conversion: R -> B"); - CHECK_AFFECTATION_THROWS_WITH("let b : B; b = \"foo\";", "invalid implicit conversion: string -> B"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 1; let b : B; b = n;", "undefined affectation type: B = N"); + CHECK_AFFECTATION_THROWS_WITH("let b : B; b = 1;", "undefined affectation type: B = Z"); + CHECK_AFFECTATION_THROWS_WITH("let b : B; b = 2.3;", "undefined affectation type: B = R"); + CHECK_AFFECTATION_THROWS_WITH("let b : B; b = \"foo\";", "undefined affectation type: B = string"); } SECTION("-> N") { - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2.3;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = \"bar\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2.3;", "undefined affectation type: N = R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = \"bar\";", "undefined affectation type: N = string"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n += 1.1;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n += \"foo\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n += 1.1;", "undefined affectation type: N += R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n += \"foo\";", "undefined affectation type: N += string"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n -= 1.1;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n -= \"bar\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n -= 1.1;", "undefined affectation type: N -= R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n -= \"bar\";", "undefined affectation type: N -= string"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n *= 2.51;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n *= \"foobar\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n *= 2.51;", "undefined affectation type: N *= R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n *= \"foobar\";", "undefined affectation type: N *= string"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n /= 2.51;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n /= \"foo\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n /= 2.51;", "undefined affectation type: N /= R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n /= \"foo\";", "undefined affectation type: N /= string"); } SECTION("-> Z") { - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = -2.3;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = \"foobar\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = -2.3;", "undefined affectation type: Z = R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = \"foobar\";", "undefined affectation type: Z = string"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z += 1.1;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z += \"foo\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z += 1.1;", "undefined affectation type: Z += R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z += \"foo\";", "undefined affectation type: Z += string"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z -= 2.1;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z -= \"bar\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z -= 2.1;", "undefined affectation type: Z -= R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z -= \"bar\";", "undefined affectation type: Z -= string"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z *= -2.51;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z *= \"foobar\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z *= -2.51;", "undefined affectation type: Z *= R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z *= \"foobar\";", "undefined affectation type: Z *= string"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 4; z /= -2.;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z /= \"foo\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 4; z /= -2.;", "undefined affectation type: Z /= R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z /= \"foo\";", "undefined affectation type: Z /= string"); } SECTION("-> R") { - CHECK_AFFECTATION_THROWS_WITH("let x : R, x = \"foobar\";", "invalid implicit conversion: string -> R"); - CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 2.3; x += \"foo\";", "invalid implicit conversion: string -> R"); - CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 2.1; x -= \"bar\";", "invalid implicit conversion: string -> R"); + CHECK_AFFECTATION_THROWS_WITH("let x : R, x = \"foobar\";", "undefined affectation type: R = string"); + CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 2.3; x += \"foo\";", "undefined affectation type: R += string"); + CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 2.1; x -= \"bar\";", "undefined affectation type: R -= string"); CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 1.2; x *= \"foobar\";", - "invalid implicit conversion: string -> R"); - CHECK_AFFECTATION_THROWS_WITH("let x : R, x =-2.3; x /= \"foo\";", "invalid implicit conversion: string -> R"); + "undefined affectation type: R *= string"); + CHECK_AFFECTATION_THROWS_WITH("let x : R, x =-2.3; x /= \"foo\";", "undefined affectation type: R /= string"); } SECTION("-> R^n") { - CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = \"foobar\";", "invalid implicit conversion: string -> R^2"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = \"foobar\";", "invalid implicit conversion: string -> R^3"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = \"foobar\";", "undefined affectation type: R^2 = string"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = \"foobar\";", "undefined affectation type: R^3 = string"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 3.2;", "invalid implicit conversion: R -> R^2"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 2.3;", "invalid implicit conversion: R -> R^3"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 3.2;", "undefined affectation type: R^2 = R"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 2.3;", "undefined affectation type: R^3 = R"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 4;", "invalid implicit conversion: Z -> R^2"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 3;", "invalid implicit conversion: Z -> R^3"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 4;", "invalid integral value (0 is the solely valid value)"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 3;", "invalid integral value (0 is the solely valid value)"); CHECK_AFFECTATION_THROWS_WITH("let x : R^1, x = 0; let y : R^2, y = x;", - "invalid implicit conversion: R^1 -> R^2"); + "undefined affectation type: R^2 = R^1"); CHECK_AFFECTATION_THROWS_WITH("let x : R^1, x = 0; let y : R^3, y = x;", - "invalid implicit conversion: R^1 -> R^3"); + "undefined affectation type: R^3 = R^1"); CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 0; let y : R^1, y = x;", - "invalid implicit conversion: R^2 -> R^1"); + "undefined affectation type: R^1 = R^2"); CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 0; let y : R^3, y = x;", - "invalid implicit conversion: R^2 -> R^3"); + "undefined affectation type: R^3 = R^2"); CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 0; let y : R^1, y = x;", - "invalid implicit conversion: R^3 -> R^1"); + "undefined affectation type: R^1 = R^3"); CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 0; let y : R^2, y = x;", - "invalid implicit conversion: R^3 -> R^2"); + "undefined affectation type: R^2 = R^3"); } } } diff --git a/tests/test_AffectationToTupleProcessor.cpp b/tests/test_AffectationToTupleProcessor.cpp index 963e931bb0bba74e10063649896e0e11db109601..fccd584c6eb7c164e9b9389d8395f56bc3803ca0 100644 --- a/tests/test_AffectationToTupleProcessor.cpp +++ b/tests/test_AffectationToTupleProcessor.cpp @@ -79,6 +79,23 @@ let s :(string); s = x; let s :(R^1); s = 1.3; )", "s", (std::vector<TinyVector<1>>{TinyVector<1>{1.3}})); + + const std::string A_string = []() -> std::string { + std::ostringstream os; + os << TinyMatrix<2, double>{1, 2, 3, 4}; + return os.str(); + }(); + + CHECK_AFFECTATION_RESULT(R"( +let A :R^2x2, A = (1,2,3,4); +let s :(string); s = A; +)", + "s", (std::vector<std::string>{A_string})); + + CHECK_AFFECTATION_RESULT(R"( +let s :(R^1x1); s = 1.3; +)", + "s", (std::vector<TinyMatrix<1>>{TinyMatrix<1>{1.3}})); } SECTION("Affectations from list") @@ -137,6 +154,45 @@ let x : R^1, x = 1; let t :(R^1); t = (x,2); )", "t", (std::vector<TinyVector<1>>{TinyVector<1>{1}, TinyVector<1>{2}})); + + const std::string A_string = []() -> std::string { + std::ostringstream os; + os << TinyMatrix<2, double>{1, 2, 3, 4}; + return os.str(); + }(); + + CHECK_AFFECTATION_RESULT(R"( +let A : R^2x2, A = (1,2,3,4); +let s : (string); s = (2.,3, A); +)", + "s", (std::vector<std::string>{std::to_string(2.), std::to_string(3), A_string})); + + CHECK_AFFECTATION_RESULT(R"( +let A : R^2x2, A = (1,2,3,4); +let t :(R^2x2); t = (A,0); +)", + "t", (std::vector<TinyMatrix<2>>{TinyMatrix<2>{1, 2, 3, 4}, TinyMatrix<2>{0, 0, 0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let t :(R^2x2); t = ((1,2,3,4),0); +)", + "t", (std::vector<TinyMatrix<2>>{TinyMatrix<2>{1, 2, 3, 4}, TinyMatrix<2>{0, 0, 0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let t :(R^2x2); t = (0); +)", + "t", (std::vector<TinyMatrix<2>>{TinyMatrix<2>{0, 0, 0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let t :(R^3x3); t = 0; +)", + "t", (std::vector<TinyMatrix<3>>{TinyMatrix<3>{0, 0, 0, 0, 0, 0, 0, 0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let x : R^1x1, x = 1; +let t :(R^1x1); t = (x,2); +)", + "t", (std::vector<TinyMatrix<1>>{TinyMatrix<1>{1}, TinyMatrix<1>{2}})); } SECTION("Affectations from tuple") @@ -153,6 +209,18 @@ let s :(string); s = x; )", "s", (std::vector<std::string>{x_string})); + const std::string A_string = []() -> std::string { + std::ostringstream os; + os << TinyMatrix<3, double>{1, 2, 3, 4, 5, 6, 7, 8, 9}; + return os.str(); + }(); + + CHECK_AFFECTATION_RESULT(R"( +let A :(R^3x3), A = ((1,2,3,4,5,6,7,8,9)); +let s :(string); s = A; +)", + "s", (std::vector<std::string>{A_string})); + CHECK_AFFECTATION_RESULT(R"( let x :(R), x = (1,2,3); let s :(string); s = x; diff --git a/tests/test_ArraySubscriptProcessor.cpp b/tests/test_ArraySubscriptProcessor.cpp index a48390214ac9cfd38bfa063613c0dc2880fe1f6a..43386ea66a5a207d319c8dd681dad4bb4a779b28 100644 --- a/tests/test_ArraySubscriptProcessor.cpp +++ b/tests/test_ArraySubscriptProcessor.cpp @@ -108,6 +108,55 @@ let x2 : R, x2 = x[2]; CHECK_EVALUATION_RESULT(data, "x2", double{3}); } + SECTION("R^1x1 component access") + { + std::string_view data = R"( +let x : R^1x1, x = 1; +let x00: R, x00 = x[0,0]; +)"; + CHECK_EVALUATION_RESULT(data, "x00", double{1}); + } + + SECTION("R^2x2 component access") + { + std::string_view data = R"( +let x : R^2x2, x = (1,2,3,4); +let x00: R, x00 = x[0,0]; +let x01: R, x01 = x[0,1]; +let x10: R, x10 = x[1,0]; +let x11: R, x11 = x[1,1]; +)"; + CHECK_EVALUATION_RESULT(data, "x00", double{1}); + CHECK_EVALUATION_RESULT(data, "x01", double{2}); + CHECK_EVALUATION_RESULT(data, "x10", double{3}); + CHECK_EVALUATION_RESULT(data, "x11", double{4}); + } + + SECTION("R^3x3 component access") + { + std::string_view data = R"( +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +let x00 : R, x00 = x[0,0]; +let x01 : R, x01 = x[0,1]; +let x02 : R, x02 = x[0,2]; +let x10 : R, x10 = x[1,0]; +let x11 : R, x11 = x[1,1]; +let x12 : R, x12 = x[1,2]; +let x20 : R, x20 = x[2,0]; +let x21 : R, x21 = x[2,1]; +let x22 : R, x22 = x[2,2]; +)"; + CHECK_EVALUATION_RESULT(data, "x00", double{1}); + CHECK_EVALUATION_RESULT(data, "x01", double{2}); + CHECK_EVALUATION_RESULT(data, "x02", double{3}); + CHECK_EVALUATION_RESULT(data, "x10", double{4}); + CHECK_EVALUATION_RESULT(data, "x11", double{5}); + CHECK_EVALUATION_RESULT(data, "x12", double{6}); + CHECK_EVALUATION_RESULT(data, "x20", double{7}); + CHECK_EVALUATION_RESULT(data, "x21", double{8}); + CHECK_EVALUATION_RESULT(data, "x22", double{9}); + } + SECTION("R^d component access from integer expression") { std::string_view data = R"( @@ -125,6 +174,23 @@ let z0: R, z0 = z[(2-2)*1]; CHECK_EVALUATION_RESULT(data, "z0", double{8}); } + SECTION("R^dxd component access from integer expression") + { + std::string_view data = R"( +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +let x01: R, x01 = x[3-2-1,2+3-4]; + +let y : R^2x2, y = (2,7,6,-2); +let y11: R, y11 = y[2/2, 3/1-2]; + +let z : R^1x1, z = 8; +let z00: R, z00 = z[(2-2)*1, (3-1)*2-4]; +)"; + CHECK_EVALUATION_RESULT(data, "x01", double{2}); + CHECK_EVALUATION_RESULT(data, "y11", double{-2}); + CHECK_EVALUATION_RESULT(data, "z00", double{8}); + } + SECTION("error invalid index type") { SECTION("R index type") diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index 00f757fa68d935cdffbff876205e4ba5a16d1e68..b00a9649b6cf8a5f3b57c27d9d63e74ad5ed68b6 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -59,6 +59,52 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::unsigned_int_t); } + SECTION("R*R^2 -> R^2") + { + std::function c = [](double a, TinyVector<2> x) -> TinyVector<2> { return a * x; }; + + BuiltinFunctionEmbedder<TinyVector<2>(double, TinyVector<2>)> embedded_c{c}; + + double a_arg = 2.3; + TinyVector<2> x_arg{3, 2}; + + std::vector<DataVariant> args; + args.push_back(a_arg); + args.push_back(x_arg); + + DataVariant result = embedded_c.apply(args); + + REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg)); + REQUIRE(embedded_c.numberOfParameters() == 2); + + REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t); + REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::double_t); + REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t); + } + + SECTION("R^2x2*R^2 -> R^2") + { + std::function c = [](TinyMatrix<2> A, TinyVector<2> x) -> TinyVector<2> { return A * x; }; + + BuiltinFunctionEmbedder<TinyVector<2>(TinyMatrix<2>, TinyVector<2>)> embedded_c{c}; + + TinyMatrix<2> a_arg = {2.3, 1, -2, 3}; + TinyVector<2> x_arg{3, 2}; + + std::vector<DataVariant> args; + args.push_back(a_arg); + args.push_back(x_arg); + + DataVariant result = embedded_c.apply(args); + + REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg)); + REQUIRE(embedded_c.numberOfParameters() == 2); + + REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t); + REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::matrix_t); + REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t); + } + SECTION("POD BuiltinFunctionEmbedder") { std::function c = [](double x, uint64_t i) -> bool { return x > i; }; diff --git a/tests/test_BuiltinFunctionRegister.hpp b/tests/test_BuiltinFunctionRegister.hpp index c1c1eb29b74fb9c5e6f8a170f76c7b4b2157c7b7..bc36e2ea29d37cbcd2464ddc1e080bff9e0dd591 100644 --- a/tests/test_BuiltinFunctionRegister.hpp +++ b/tests/test_BuiltinFunctionRegister.hpp @@ -68,6 +68,28 @@ class test_BuiltinFunctionRegister std::make_shared<BuiltinFunctionEmbedder<double(TinyVector<3>, TinyVector<2>)>>( [](TinyVector<3> x, TinyVector<2> y) -> double { return x[0] * y[1] + (y[0] - x[2]) * x[1]; }))); + m_name_builtin_function_map.insert( + std::make_pair("RtoR11", std::make_shared<BuiltinFunctionEmbedder<TinyMatrix<1>(double)>>( + [](double r) -> TinyMatrix<1> { return {r}; }))); + + m_name_builtin_function_map.insert( + std::make_pair("R11toR", std::make_shared<BuiltinFunctionEmbedder<double(TinyMatrix<1>)>>( + [](TinyMatrix<1> x) -> double { return x(0, 0); }))); + + m_name_builtin_function_map.insert( + std::make_pair("R22toR", std::make_shared<BuiltinFunctionEmbedder<double(TinyMatrix<2>)>>( + [](TinyMatrix<2> x) -> double { return x(0, 0) + x(0, 1) + x(1, 0) + x(1, 1); }))); + + m_name_builtin_function_map.insert( + std::make_pair("R33toR", std::make_shared<BuiltinFunctionEmbedder<double(const TinyMatrix<3>&)>>( + [](const TinyMatrix<3>& x) -> double { return x(0, 0) + x(1, 1) + x(2, 2); }))); + + m_name_builtin_function_map.insert( + std::make_pair("R33R22toR", std::make_shared<BuiltinFunctionEmbedder<double(TinyMatrix<3>, TinyMatrix<2>)>>( + [](TinyMatrix<3> x, TinyMatrix<2> y) -> double { + return (x(0, 0) + x(1, 1) + x(2, 2)) * (y(0, 0) + y(0, 1) + y(1, 0) + y(1, 1)); + }))); + m_name_builtin_function_map.insert( std::make_pair("fidToR", std::make_shared<BuiltinFunctionEmbedder<double(const FunctionSymbolId&)>>( [](const FunctionSymbolId&) -> double { return 0; }))); @@ -116,6 +138,21 @@ class test_BuiltinFunctionRegister m_name_builtin_function_map.insert( std::make_pair("tuple_R3ToR", std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyVector<3>>)>>( [](const std::vector<TinyVector<3>>&) -> double { return 0; }))); + + m_name_builtin_function_map.insert( + std::make_pair("tuple_R11ToR", + std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<1>>&)>>( + [](const std::vector<TinyMatrix<1>>&) -> double { return 1; }))); + + m_name_builtin_function_map.insert( + std::make_pair("tuple_R22ToR", + std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<2>>&)>>( + [](const std::vector<TinyMatrix<2>>&) -> double { return 1; }))); + + m_name_builtin_function_map.insert( + std::make_pair("tuple_R33ToR", + std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<3>>)>>( + [](const std::vector<TinyMatrix<3>>&) -> double { return 0; }))); } void diff --git a/tests/test_FunctionArgumentConverter.cpp b/tests/test_FunctionArgumentConverter.cpp index dfa637bf735c66e02dab8c4345ccbf74e7a1e069..b0a191fc5c2474a65c34c8c5e5b822b5046cb8a2 100644 --- a/tests/test_FunctionArgumentConverter.cpp +++ b/tests/test_FunctionArgumentConverter.cpp @@ -75,6 +75,32 @@ TEST_CASE("FunctionArgumentConverter", "[language]") " unexpected aggregate value type"); } + SECTION("FunctionTinyMatrixArgumentConverter") + { + const TinyMatrix<3> x3{1.7, 2.9, -3, 4, 5.2, 6.1, -7, 8.3, 9.05}; + FunctionTinyMatrixArgumentConverter<TinyMatrix<3>, TinyMatrix<3>> converter0{0}; + converter0.convert(execution_policy, TinyMatrix{x3}); + + const double x1 = 6.3; + FunctionTinyMatrixArgumentConverter<TinyMatrix<1>, double> converter1{1}; + converter1.convert(execution_policy, double{x1}); + + AggregateDataVariant values{std::vector<DataVariant>{6.3, 3.2, 4ul, 2.3, -3.1, 6.7, 3.6, 2ul, 1.1}}; + FunctionTinyMatrixArgumentConverter<TinyMatrix<3>, TinyMatrix<3>> converter2{2}; + converter2.convert(execution_policy, values); + + REQUIRE(std::get<TinyMatrix<3>>(execution_policy.currentContext()[0]) == x3); + REQUIRE(std::get<TinyMatrix<1>>(execution_policy.currentContext()[1]) == TinyMatrix<1>{x1}); + REQUIRE(std::get<TinyMatrix<3>>(execution_policy.currentContext()[2]) == + TinyMatrix<3>{6.3, 3.2, 4ul, 2.3, -3.1, 6.7, 3.6, 2ul, 1.1}); + + AggregateDataVariant bad_values{std::vector<DataVariant>{6.3, 3.2, std::string{"bar"}, true}}; + + REQUIRE_THROWS_WITH(converter2.convert(execution_policy, bad_values), std::string{"unexpected error: "} + + demangle<std::string>() + + " unexpected aggregate value type"); + } + SECTION("FunctionTupleArgumentConverter") { const TinyVector<3> x3{1.7, 2.9, -3}; diff --git a/tests/test_FunctionProcessor.cpp b/tests/test_FunctionProcessor.cpp index 59386bbae65a807841ead561899e0a61719a82d3..64fe6776ee427db8b2e3641ee952cff2477279ca 100644 --- a/tests/test_FunctionProcessor.cpp +++ b/tests/test_FunctionProcessor.cpp @@ -406,6 +406,79 @@ let fx:R^3, fx = f(3); } } + SECTION("R^dxd functions (single value)") + { + SECTION(" R^1x1 -> R^1x1") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> 2*x; +let x:R^1x1, x = 3; + +let fx:R^1x1, fx = f(x); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{3})); + } + + SECTION(" R^2x2 -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> 2*x; +let x:R^2x2, x = (3, 7, 6, -2); + +let fx:R^2x2, fx = f(x); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<2>{3, 7, 6, -2})); + } + + SECTION(" R^3x3 -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> 2*x; +let x:R^3x3, x = (2, 4, 7, 1, 3, 5, -6, 2, -3); + +let fx:R^3x3, fx = f(x); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<3>{2, 4, 7, 1, 3, 5, -6, 2, -3})); + } + + SECTION(" R -> R^1x1") + { + std::string_view data = R"( +let f : R -> R^1x1, x -> 2*x; +let x:R, x = 3; + +let fx:R^1x1, fx = f(x); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{3})); + } + + SECTION(" R*R -> R^2x2") + { + std::string_view data = R"( +let f : R*R -> R^2x2, (x,y) -> (2*x, 3*y, 5*(x-y), 2*x-y); +let fx:R^2x2, fx = f(2, 3); +)"; + + const double x = 2; + const double y = 3; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (TinyMatrix<2>{2 * x, 3 * y, 5 * (x - y), 2 * x - y})); + } + + SECTION(" R -> R^3x3") + { + std::string_view data = R"( +let f : R -> R^3x3, x -> (x, 2*x, x*x, 3*x, 2+x, x-1, x+0.5, 2*x-1, 1/x); + +let fx:R^3x3, fx = f(3); +)"; + + const double x = 3; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", + (TinyMatrix<3>{x, 2 * x, x * x, 3 * x, 2 + x, x - 1, x + 0.5, 2 * x - 1, + 1 / x})); + } + } + SECTION("multi-expression functions (using R^d)") { SECTION(" R -> R*R^1*R^2*R^3") @@ -479,6 +552,89 @@ let (x, x1, x2, x3):R*R^1*R^2*R^3, (x, x1, x2, x3) = f(y2, 0); } } + SECTION("multi-expression functions (using R^dxd)") + { + SECTION(" R -> R*R^1x1*R^2x2*R^3x3") + { + std::string_view data = R"( +let f : R -> R*R^1x1*R^2x2*R^3x3, x -> (x+1, 2*x, (x-2, x+2, 3, 2), (1, 0.5*x, x*x, x+1, 1/x, 2, x*x, 2*x-1, 3*x)); + +let (x, x11, x22, x33):R*R^1x1*R^2x2*R^3x3, (x, x11, x22, x33) = f(3); +)"; + + const double x = 3; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", (double{x + 1})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x11", (TinyMatrix<1>{2 * x})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x22", (TinyMatrix<2>{x - 2, x + 2, 3, 2})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x33", + (TinyMatrix<3>{1, 0.5 * x, x * x, x + 1, 1 / x, 2, x * x, 2 * x - 1, 3 * x})); + } + + SECTION(" R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3") + { + std::string_view data = R"( +let f : R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3, + (x22, x33) -> (x22[0,0]+x33[2,0], x33[1,2], (x33[0,1], x22[1,1], x22[0,0], x33[2,2]), x22[0,0]*x33); + +let y22:R^2x2, y22 = (2.3, 4.1, 6, -3); +let y33:R^3x3, y33 = (1.2, 1.3, 2.1, 3.2, -1.5, 2.3, -0.2, 3.1, -2.6); +let(x, x11, x22, x33) : R*R^1x1*R^2x2*R^3x3, (x, x11, x22, x33) = f(y22, y33); +)"; + + const TinyMatrix<2> x22{2.3, 4.1, 6, -3}; + const TinyMatrix<3> x33{1.2, 1.3, 2.1, 3.2, -1.5, 2.3, -0.2, 3.1, -2.6}; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", (double{x22(0, 0) + x33(2, 0)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x11", (TinyMatrix<1>{x33(1, 2)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x22", (TinyMatrix<2>{x33(0, 1), x22(1, 1), x22(0, 0), x33(2, 2)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x33", (TinyMatrix<3>{x22(0, 0) * x33})); + } + + SECTION(" R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3 [with 0 as argument]") + { + std::string_view data = R"( +let f : R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3, + (x22, x33) -> (x22[0,0]+x33[2,1], x33[1,2], (x33[0,1], x22[1,0], x22[0,1], x33[2,2]), + (x22[1,0], x33[0,2]+x22[1,1], x33[2,2], + x33[2,0], x33[2,0]+x22[0,0], x33[1,1], + x33[2,1], x33[1,2]+x22[1,1], x33[0,0])); + +let y22:R^2x2, y22 = (2.3, 4.1, 3.1, 1.7); +let y33:R^3x3, y33 = (2.7, 3.1, 2.1, + 0.3, 1.2, 1.6, + 1.7, 2.2, 1.4); +let (x, x11, x22, x33) : R*R^1x1*R^2x2*R^3x3, (x, x11, x22, x33) = f(y22, y33); +)"; + + TinyMatrix<2> x22{2.3, 4.1, 3.1, 1.7}; + TinyMatrix<3> x33{2.7, 3.1, 2.1, 0.3, 1.2, 1.6, 1.7, 2.2, 1.4}; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", (double{x22(0, 0) + x33(2, 1)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x11", (TinyMatrix<1>{x33(1, 2)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x22", (TinyMatrix<2>{x33(0, 1), x22(1, 0), x22(0, 1), x33(2, 2)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x33", + (TinyMatrix<3>{x22(1, 0), x33(0, 2) + x22(1, 1), x33(2, 2), x33(2, 0), + x33(2, 0) + x22(0, 0), x33(1, 1), x33(2, 1), + x33(1, 2) + x22(1, 1), x33(0, 0)})); + } + + SECTION(" R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3 [with 0 in result]") + { + std::string_view data = R"( +let f : R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3, + (x22, x33) -> (x22[0,0]+x33[2,0], x33[1,1], 0, 0); + +let y22:R^2x2, y22 = (2.3, 4.1, 3.1, 1.7); +let (x, x11, x22, x33):R*R^1x1*R^2x2*R^3x3, (x, x11, x22, x33) = f(y22, 0); +)"; + + TinyMatrix<2> x22{2.3, 4.1, 3.1, 1.7}; + TinyMatrix<3> x33{zero}; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", (double{x22(0, 0) + x33(2, 0)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x11", (TinyMatrix<1>{x33(1, 1)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x22", (TinyMatrix<2>{zero})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x33", (TinyMatrix<3>{zero})); + } + } + SECTION("function composition") { SECTION("N -> N -> R") @@ -518,7 +674,6 @@ let x:R, x = pow(f(2)); SECTION("R -> R^2 -> R") { std::string_view data = R"( -import math; let f : R -> R^2, x -> (x+1, x*2); let g : R^2 -> R, x -> x[0] + x[1]; @@ -532,7 +687,6 @@ let x:R, x = g(f(3)); SECTION("R -> R^2*R^3 -> R") { std::string_view data = R"( -import math; let f : R -> R^2*R^3, x -> ((x+1, x*2), (6*x, 7-x, x/2.3)); let g : R^2*R^3 -> R, (x, y) -> x[0]*x[1] + y[0]*y[1]-y[2]; @@ -542,5 +696,37 @@ let x:R, x = g(f(3)); double x0 = 3; CHECK_FUNCTION_EVALUATION_RESULT(data, "x", double{(x0 + 1) * x0 * 2 + 6 * x0 * (7 - x0) - x0 / 2.3}); } + + SECTION("R -> R^2x2 -> R") + { + std::string_view data = R"( +let f : R -> R^2x2, x -> (x+1, x*2, x-1, x); +let g : R^2x2 -> R, A -> A[0,0] + 2*A[1,1] + 3*A[0,1]+ A[1, 0]; + +let x:R, x = g(f(3)); +)"; + + const double x = 3; + const TinyMatrix<2> A{x + 1, x * 2, x - 1, x}; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", double{A(0, 0) + 2 * A(1, 1) + 3 * A(0, 1) + A(1, 0)}); + } + + SECTION("R -> R^2x2*R^3x3 -> R") + { + std::string_view data = R"( +let f : R -> R^2x2*R^3x3, x -> ((x+1, x*2, x-1, x), (6*x, 7-x, x/2.3, -x, 2*x, x/2.5, x*x, 2*x, x)); +let g : R^2x2*R^3x3 -> R, (A22, A33) -> A22[0,0]*A22[1,1] + (A33[0,0]*A33[1,0]-A33[2,2])*A22[0,1]-A33[2,0]*A33[0,2]-A22[1,1]; + +let x:R, x = g(f(3)); +)"; + + const double x = 3; + const TinyMatrix<2> A22{x + 1, x * 2, x - 1, x}; + const TinyMatrix<3> A33{6 * x, 7 - x, x / 2.3, -x, 2 * x, x / 2.5, x * x, 2 * x, x}; + + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", + double{A22(0, 0) * A22(1, 1) + (A33(0, 0) * A33(1, 0) - A33(2, 2)) * A22(0, 1) - + A33(2, 0) * A33(0, 2) - A22(1, 1)}); + } } } diff --git a/tests/test_ListAffectationProcessor.cpp b/tests/test_ListAffectationProcessor.cpp index e2a0919952220b6ba6a2c723a31604674254357d..0dc1b8d77c0e6dc23cb2681c30e8006d9e939f86 100644 --- a/tests/test_ListAffectationProcessor.cpp +++ b/tests/test_ListAffectationProcessor.cpp @@ -75,12 +75,16 @@ TEST_CASE("ListAffectationProcessor", "[language]") { SECTION("ListAffectations") { - SECTION("R*R^2*string") + SECTION("R*R^2*R^2x2*string") { - CHECK_AFFECTATION_RESULT(R"(let (x,u,s): R*R^2*string, (x,u,s) = (1.2, (2,3), "foo");)", "x", double{1.2}); - CHECK_AFFECTATION_RESULT(R"(let (x,u,s): R*R^2*string, (x,u,s) = (1.2, (2,3), "foo");)", "u", + CHECK_AFFECTATION_RESULT(R"(let (x,u,A,s): R*R^2*R^2x2*string, (x,u,A,s) = (1.2, (2,3), (4,3,2,1), "foo");)", "x", + double{1.2}); + CHECK_AFFECTATION_RESULT(R"(let (x,u,A,s): R*R^2*R^2x2*string, (x,u,A,s) = (1.2, (2,3), (4,3,2,1), "foo");)", "u", (TinyVector<2>{2, 3})); - CHECK_AFFECTATION_RESULT(R"(let (x,u,s): R*R^2*string, (x,u,s) = (1.2, (2,3), "foo");)", "s", std::string{"foo"}); + CHECK_AFFECTATION_RESULT(R"(let (x,u,A,s): R*R^2*R^2x2*string, (x,u,A,s) = (1.2, (2,3), (4,3,2,1), "foo");)", "A", + (TinyMatrix<2>{4, 3, 2, 1})); + CHECK_AFFECTATION_RESULT(R"(let (x,u,A,s): R*R^2*R^2x2*string, (x,u,A,s) = (1.2, (2,3), (4,3,2,1), "foo");)", "s", + std::string{"foo"}); } SECTION("compound with string conversion") @@ -114,11 +118,28 @@ TEST_CASE("ListAffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT(R"(let (x,y,z):R^3*R^2*R^1, (x,y,z) = (0,0,0);)", "z", (TinyVector<1>{zero})); } + SECTION("compound R^dxd from '0'") + { + CHECK_AFFECTATION_RESULT(R"(let (x,y,z):R^3x3*R^2x2*R^1x1, (x,y,z) = (0,0,0);)", "x", (TinyMatrix<3>{zero})); + CHECK_AFFECTATION_RESULT(R"(let (x,y,z):R^3x3*R^2x2*R^1x1, (x,y,z) = (0,0,0);)", "y", (TinyMatrix<2>{zero})); + CHECK_AFFECTATION_RESULT(R"(let (x,y,z):R^3x3*R^2x2*R^1x1, (x,y,z) = (0,0,0);)", "z", (TinyMatrix<1>{zero})); + } + SECTION("compound with subscript values") { CHECK_AFFECTATION_RESULT(R"(let x:R^3; (x[0], x[2], x[1]) = (4, 6, 5);)", "x", (TinyVector<3>{4, 5, 6})); CHECK_AFFECTATION_RESULT(R"(let x:R^2; (x[1], x[0]) = (3, 6);)", "x", (TinyVector<2>{6, 3})); CHECK_AFFECTATION_RESULT(R"(let x:R^1; let y:R; (y, x[0]) = (4, 2.3);)", "x", (TinyVector<1>{2.3})); } + + SECTION("compound with subscript values") + { + CHECK_AFFECTATION_RESULT( + R"(let x:R^3x3; (x[0,0], x[1,0], x[1,2], x[2,0], x[0,1], x[0,2], x[1,1], x[2,1], x[2,2]) = (1, 4, 6, 7, 2, 3, 5, 8, 9);)", + "x", (TinyMatrix<3>{1, 2, 3, 4, 5, 6, 7, 8, 9})); + CHECK_AFFECTATION_RESULT(R"(let x:R^2x2; (x[1,1], x[0,0], x[1,0], x[0,1]) = (3, 6, 2, 4);)", "x", + (TinyMatrix<2>{6, 4, 2, 3})); + CHECK_AFFECTATION_RESULT(R"(let x:R^1x1; let y:R; (y, x[0,0]) = (4, 2.3);)", "x", (TinyMatrix<1>{2.3})); + } } } diff --git a/tests/test_PugsFunctionAdapter.cpp b/tests/test_PugsFunctionAdapter.cpp index 3a7ca2813845fa72b0db312b6d259195c162c09d..4b04c7f5e088c1bdec314c91414fac4924b8bd6b 100644 --- a/tests/test_PugsFunctionAdapter.cpp +++ b/tests/test_PugsFunctionAdapter.cpp @@ -77,11 +77,17 @@ let ZplusZ: Z*Z -> Z, (x,y) -> x+y; let RplusR: R*R -> R, (x,y) -> x+y; let RRtoR2: R*R -> R^2, (x,y) -> (x+y, x-y); let R3times2: R^3 -> R^3, x -> 2*x; +let R33times2: R^3x3 -> R^3x3, x -> 2*x; let BtoR1: B -> R^1, b -> not b; +let BtoR11: B -> R^1x1, b -> not b; let NtoR1: N -> R^1, n -> n*n; +let NtoR11: N -> R^1x1, n -> n*n; let ZtoR1: Z -> R^1, z -> -z; +let ZtoR11: Z -> R^1x1, z -> -z; let RtoR1: R -> R^1, x -> x*x; +let RtoR11: R -> R^1x1, x -> x*x; let R3toR3zero: R^3 -> R^3, x -> 0; +let R33toR33zero: R^3x3 -> R^3x3, x -> 0; )"; string_input input{data, "test.pgs"}; @@ -194,6 +200,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == 2 * x); } + { + auto [i_symbol, found] = symbol_table->find("R33times2", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyMatrix<3> x{2, 3, 4, 1, 6, 5, 9, 7, 8}; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<3> result = tests_adapter::TestBinary<TinyMatrix<3>(TinyMatrix<3>)>::one_arg(function_symbol_id, x); + + REQUIRE(result == 2 * x); + } + { auto [i_symbol, found] = symbol_table->find("BtoR1", position); REQUIRE(found); @@ -218,6 +237,30 @@ let R3toR3zero: R^3 -> R^3, x -> 0; } } + { + auto [i_symbol, found] = symbol_table->find("BtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + { + const bool b = true; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(bool)>::one_arg(function_symbol_id, b); + + REQUIRE(result == not b); + } + + { + const bool b = false; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(bool)>::one_arg(function_symbol_id, b); + + REQUIRE(result == not b); + } + } + { auto [i_symbol, found] = symbol_table->find("NtoR1", position); REQUIRE(found); @@ -231,6 +274,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == n * n); } + { + auto [i_symbol, found] = symbol_table->find("NtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const uint64_t n = 4; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(uint64_t)>::one_arg(function_symbol_id, n); + + REQUIRE(result == n * n); + } + { auto [i_symbol, found] = symbol_table->find("ZtoR1", position); REQUIRE(found); @@ -244,6 +300,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == -z); } + { + auto [i_symbol, found] = symbol_table->find("ZtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const int64_t z = 3; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(int64_t)>::one_arg(function_symbol_id, z); + + REQUIRE(result == -z); + } + { auto [i_symbol, found] = symbol_table->find("RtoR1", position); REQUIRE(found); @@ -257,6 +326,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == x * x); } + { + auto [i_symbol, found] = symbol_table->find("RtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 3.3; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(double)>::one_arg(function_symbol_id, x); + + REQUIRE(result == x * x); + } + { auto [i_symbol, found] = symbol_table->find("R3toR3zero", position); REQUIRE(found); @@ -269,6 +351,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == TinyVector<3>{0, 0, 0}); } + + { + auto [i_symbol, found] = symbol_table->find("R33toR33zero", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyMatrix<3> x{1, 0, 0, 0, 1, 0, 0, 0, 1}; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<3> result = tests_adapter::TestBinary<TinyMatrix<3>(TinyMatrix<3>)>::one_arg(function_symbol_id, x); + + REQUIRE(result == TinyMatrix<3>{0, 0, 0, 0, 0, 0, 0, 0, 0}); + } } SECTION("Errors calls") @@ -280,6 +375,7 @@ let RRRtoR3: R*R*R -> R^3, (x,y,z) -> (x,y,z); let R3toR2: R^3 -> R^2, x -> (x[0],x[1]+x[2]); let RtoNS: R -> N*string, x -> (1, "foo"); let RtoR: R -> R, x -> 2*x; +let R33toR22: R^3x3 -> R^2x2, x -> (x[0,0], x[0,1]+x[0,2], x[2,0]*x[1,1], x[2,1]+x[2,2]); )"; string_input input{data, "test.pgs"}; @@ -385,5 +481,19 @@ let RtoR: R -> R, x -> 2*x; "note: expecting R -> R^3\n" "note: provided function RtoR: R -> R"); } + + { + auto [i_symbol, found] = symbol_table->find("R33toR22", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 1; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<double(double)>::one_arg(function_symbol_id, x), + "error: invalid function type\n" + "note: expecting R -> R\n" + "note: provided function R33toR22: R^3x3 -> R^2x2"); + } } } diff --git a/tests/test_TinyMatrix.cpp b/tests/test_TinyMatrix.cpp index 0d3510dc361a9fd43b1a17559b514c2c3748bc08..ca5a12fbddb68e01324158c14ef42546b27ddec1 100644 --- a/tests/test_TinyMatrix.cpp +++ b/tests/test_TinyMatrix.cpp @@ -202,10 +202,25 @@ TEST_CASE("TinyMatrix", "[algebra]") } } + SECTION("checking for sizes") + { + REQUIRE(TinyMatrix<1>{}.nbRows() == 1); + REQUIRE(TinyMatrix<1>{}.nbColumns() == 1); + REQUIRE(TinyMatrix<1>{}.dimension() == 1); + + REQUIRE(TinyMatrix<2>{}.nbRows() == 2); + REQUIRE(TinyMatrix<2>{}.nbColumns() == 2); + REQUIRE(TinyMatrix<2>{}.dimension() == 4); + + REQUIRE(TinyMatrix<3>{}.nbRows() == 3); + REQUIRE(TinyMatrix<3>{}.nbColumns() == 3); + REQUIRE(TinyMatrix<3>{}.dimension() == 9); + } + SECTION("checking for matrices output") { REQUIRE(Catch::Detail::stringify(A) == "[(1,2,3)(4,5,6)(7,8,9)]"); - REQUIRE(Catch::Detail::stringify(TinyMatrix<1, int>(7)) == "7"); + REQUIRE(Catch::Detail::stringify(TinyMatrix<1, int>(7)) == "[(7)]"); } #ifndef NDEBUG