diff --git a/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp b/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp index 037bd767fc1064aab5b0848497749860b89a7350..3d9e58f07a0fc8ec1e9af130d886a2cf91aef407 100644 --- a/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp @@ -30,17 +30,6 @@ ASTNodeAffectationExpressionBuilder::ASTNodeAffectationExpressionBuilder(ASTNode } }(); - // // Special treatment dedicated to R^1 to be able to initialize them - // if (((target_data_type != source_data_type) and (target_data_type == ASTNodeDataType::vector_t) and - // (target_data_type.dimension() == 1)) or - // // Special treatment for R^d vectors and operator *= - // ((target_data_type == ASTNodeDataType::vector_t) and (source_data_type != ASTNodeDataType::vector_t) and - // n.is_type<language::multiplyeq_op>())) { - // ASTNodeNaturalConversionChecker{*n.children[1], ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - // } else { - // ASTNodeNaturalConversionChecker{*n.children[1], target_data_type}; - // } - const auto& optional_processor_builder = OperatorRepository::instance().getAffectationProcessorBuilder(affectation_name); diff --git a/src/language/ast/ASTNodeDataTypeBuilder.cpp b/src/language/ast/ASTNodeDataTypeBuilder.cpp index f6fe085672ad9e6fa0443bd113f343c2e3603dae..1cac7a59f967ea930bcd661c268044a09f778742 100644 --- a/src/language/ast/ASTNodeDataTypeBuilder.cpp +++ b/src/language/ast/ASTNodeDataTypeBuilder.cpp @@ -72,6 +72,8 @@ ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNo content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } else if (content_node->is_type<language::vector_type>()) { content_node->m_data_type = getVectorDataType(*type_node.children[0]); + } else if (content_node->is_type<language::matrix_type>()) { + content_node->m_data_type = getMatrixDataType(*type_node.children[0]); } else if (content_node->is_type<language::string_type>()) { content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } else { @@ -153,6 +155,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); } else if (n.is_type<language::vector_type>()) { n.m_data_type = getVectorDataType(n); + } else if (n.is_type<language::matrix_type>()) { + n.m_data_type = getMatrixDataType(n); } else if (n.is_type<language::literal>()) { n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); @@ -317,6 +321,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const value_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } else if (image_node.is_type<language::vector_type>()) { value_type = getVectorDataType(image_node); + } else if (image_node.is_type<language::matrix_type>()) { + value_type = getMatrixDataType(image_node); } else if (image_node.is_type<language::string_type>()) { value_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } @@ -549,6 +555,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ASTNodeDataType::build<ASTNodeDataType::typename_t>(ASTNodeDataType::build<ASTNodeDataType::double_t>()); } else if (n.is_type<language::vector_type>()) { n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::typename_t>(getVectorDataType(n)); + } else if (n.is_type<language::matrix_type>()) { + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::typename_t>(getMatrixDataType(n)); } else if (n.is_type<language::name_list>() or n.is_type<language::lvalue_list>() or n.is_type<language::function_argument_list>() or n.is_type<language::expression_list>()) { std::vector<std::shared_ptr<const ASTNodeDataType>> sub_data_type_list; diff --git a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp index af8b1cbc0b4ca138f86601755493e90a089ff0c2..5363575e202cb514c0d738d6935c8f1061fc057a 100644 --- a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp @@ -83,6 +83,48 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy } }; + auto get_function_argument_converter_for_matrix = + [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> { + using ParameterT = std::decay_t<decltype(parameter_v)>; + switch (node_sub_data_type.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((node_sub_data_type.m_data_type.nbRows() == parameter_v.nbRows()) and + (node_sub_data_type.m_data_type.nbColumns() == parameter_v.nbColumns())) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::list_t: { + if (node_sub_data_type.m_parent_node.children.size() == parameter_v.dimension()) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ZeroType>>(parameter_id); + } + } + [[fallthrough]]; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid argument type", + std::vector{node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + }; + auto get_function_argument_converter_for_string = [&]() -> std::unique_ptr<IFunctionArgumentConverter> { return std::make_unique<FunctionArgumentToStringConverter>(parameter_id); }; @@ -116,9 +158,28 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy return get_function_argument_converter_for_vector(TinyVector<3>{}); } } - [[fallthrough]]; + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined parameter type", std::vector{m_node.begin()}); + // LCOV_EXCL_STOP } + case ASTNodeDataType::matrix_t: { + Assert(parameter_symbol.attributes().dataType().nbRows() == parameter_symbol.attributes().dataType().nbColumns()); + switch (parameter_symbol.attributes().dataType().nbRows()) { + case 1: { + return get_function_argument_converter_for_matrix(TinyMatrix<1>{}); + } + case 2: { + return get_function_argument_converter_for_matrix(TinyMatrix<2>{}); + } + case 3: { + return get_function_argument_converter_for_matrix(TinyMatrix<3>{}); + } + } + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined parameter type", std::vector{m_node.begin()}); + // LCOV_EXCL_STOP + } // LCOV_EXCL_START default: { throw ParseError("unexpected error: undefined parameter type", std::vector{m_node.begin()}); @@ -268,6 +329,51 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType& r } }; + auto get_function_processor_for_expression_matrix = [&](const auto& return_v) -> std::unique_ptr<INodeProcessor> { + using ReturnT = std::decay_t<decltype(return_v)>; + switch (function_component_expression.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((function_component_expression.m_data_type.nbRows() == return_v.nbRows()) and + (function_component_expression.m_data_type.nbColumns() == return_v.nbColumns())) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, ReturnT>>(function_component_expression); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid dimension for returned vector", + std::vector{function_component_expression.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::list_t: { + if (function_component_expression.children.size() == return_v.dimension()) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, AggregateDataVariant>>( + function_component_expression); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid dimension for returned vector", + std::vector{function_component_expression.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if (function_component_expression.is_type<language::integer>()) { + if (std::stoi(function_component_expression.string()) == 0) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, ZeroType>>(function_component_expression); + } + } + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined expression value type for function", + std::vector{function_component_expression.begin()}); + // LCOV_EXCL_STOP + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: undefined expression value type for function", + std::vector{function_component_expression.begin()}); + } + // LCOV_EXCL_STOP + } + }; + auto get_function_processor_for_value = [&]() { switch (return_value_type) { case ASTNodeDataType::bool_t: { @@ -304,6 +410,30 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType& r // LCOV_EXCL_STOP } } + case ASTNodeDataType::matrix_t: { + Assert(return_value_type.nbRows() == return_value_type.nbColumns()); + + switch (return_value_type.nbRows()) { + case 1: { + if (function_component_expression.m_data_type == ASTNodeDataType::matrix_t) { + return get_function_processor_for_expression_matrix(TinyMatrix<1>{}); + } else { + return get_function_processor_for_expression_value(TinyMatrix<1>{}); + } + } + case 2: { + return get_function_processor_for_expression_matrix(TinyMatrix<2>{}); + } + case 3: { + return get_function_processor_for_expression_matrix(TinyMatrix<3>{}); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid dimension in returned type", std::vector{node.begin()}); + } + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::string_t: { return get_function_processor_for_expression_value(std::string{}); } diff --git a/src/language/node_processor/AffectationProcessor.hpp b/src/language/node_processor/AffectationProcessor.hpp index 85251f8446b73f2947decc15b1857db13f2386f7..59ac02e78d2ac0c77a9138c8e5cd6cb341df35a5 100644 --- a/src/language/node_processor/AffectationProcessor.hpp +++ b/src/language/node_processor/AffectationProcessor.hpp @@ -559,6 +559,41 @@ class AffectationToTupleFromListProcessor final : public INodeProcessor m_node.children[1]->children[i]->begin()); // LCOV_EXCL_STOP } + } else if constexpr (is_tiny_matrix_v<ValueT>) { + if constexpr (std::is_same_v<T, AggregateDataVariant>) { + ValueT& A = tuple_value[i]; + Assert(A.nbRows() * A.nbColumns() == child_value.size()); + for (size_t j = 0, l = 0; j < A.nbRows(); ++j) { + for (size_t k = 0; k < A.nbColumns(); ++k, ++l) { + std::visit( + [&](auto&& Ajk) { + using Ti = std::decay_t<decltype(Ajk)>; + if constexpr (std::is_convertible_v<Ti, typename ValueT::data_type>) { + A(k, k) = Ajk; + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", + m_node.children[1]->children[i]->begin()); + // LCOV_EXCL_STOP + } + }, + child_value[l]); + } + } + } else if constexpr (std::is_same_v<T, int64_t>) { + if constexpr (std::is_same_v<ValueT, TinyMatrix<1>>) { + tuple_value[i](0, 0) = child_value; + } else { + // in this case a 0 is given + Assert(child_value == 0); + tuple_value[i] = ZeroType{}; + } + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", + m_node.children[1]->children[i]->begin()); + // LCOV_EXCL_STOP + } } else { // LCOV_EXCL_START throw ParseError("unexpected error: unexpected right hand side type in affectation", diff --git a/src/language/node_processor/FunctionArgumentConverter.hpp b/src/language/node_processor/FunctionArgumentConverter.hpp index e226e5336f4add578c5df2abe102121d067f6402..3f8cac14f0558d7a2fa86cccfbfebe5199389347 100644 --- a/src/language/node_processor/FunctionArgumentConverter.hpp +++ b/src/language/node_processor/FunctionArgumentConverter.hpp @@ -121,6 +121,55 @@ class FunctionTinyVectorArgumentConverter final : public IFunctionArgumentConver FunctionTinyVectorArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {} }; +template <typename ExpectedValueType, typename ProvidedValueType> +class FunctionTinyMatrixArgumentConverter final : public IFunctionArgumentConverter +{ + private: + size_t m_argument_id; + + public: + DataVariant + convert(ExecutionPolicy& exec_policy, DataVariant&& value) + { + if constexpr (std::is_same_v<ExpectedValueType, ProvidedValueType>) { + std::visit( + [&](auto&& v) { + using ValueT = std::decay_t<decltype(v)>; + if constexpr (std::is_same_v<ValueT, ExpectedValueType>) { + exec_policy.currentContext()[m_argument_id] = std::move(value); + } else if constexpr (std::is_same_v<ValueT, AggregateDataVariant>) { + ExpectedValueType matrix_value{}; + for (size_t i = 0, l = 0; i < matrix_value.nbRows(); ++i) { + for (size_t j = 0; j < matrix_value.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& A_ij) { + using Vi_T = std::decay_t<decltype(A_ij)>; + if constexpr (std::is_arithmetic_v<Vi_T>) { + matrix_value(i, j) = A_ij; + } else { + throw UnexpectedError(demangle<Vi_T>() + " unexpected aggregate value type"); + } + }, + v[l]); + } + } + exec_policy.currentContext()[m_argument_id] = std::move(matrix_value); + } + }, + value); + } else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero}; + } else { + static_assert(std::is_same_v<ExpectedValueType, TinyMatrix<1>>); + exec_policy.currentContext()[m_argument_id] = + std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); + } + return {}; + } + + FunctionTinyMatrixArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {} +}; + template <typename ContentType, typename ProvidedValueType> class FunctionTupleArgumentConverter final : public IFunctionArgumentConverter { diff --git a/src/language/node_processor/FunctionProcessor.hpp b/src/language/node_processor/FunctionProcessor.hpp index b28c2ce499538818ce60b1ac36a1d158114409de..f8c86ea9e5cb51e169e48ecadc0b4770ef873cc8 100644 --- a/src/language/node_processor/FunctionProcessor.hpp +++ b/src/language/node_processor/FunctionProcessor.hpp @@ -21,18 +21,35 @@ class FunctionExpressionProcessor final : public INodeProcessor if constexpr (std::is_same_v<ReturnType, ExpressionValueType>) { return m_function_expression.execute(exec_policy); } else if constexpr (std::is_same_v<AggregateDataVariant, ExpressionValueType>) { - static_assert(is_tiny_vector_v<ReturnType>, "unexpected return type"); + static_assert(is_tiny_vector_v<ReturnType> or is_tiny_matrix_v<ReturnType>, "unexpected return type"); ReturnType return_value{}; auto value = std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)); - for (size_t i = 0; i < ReturnType::Dimension; ++i) { - std::visit( - [&](auto&& vi) { - using Vi_T = std::decay_t<decltype(vi)>; - if constexpr (std::is_convertible_v<Vi_T, double>) { - return_value[i] = vi; - } - }, - value[i]); + if constexpr (is_tiny_vector_v<ReturnType>) { + for (size_t i = 0; i < ReturnType::Dimension; ++i) { + std::visit( + [&](auto&& vi) { + using Vi_T = std::decay_t<decltype(vi)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + return_value[i] = vi; + } + }, + value[i]); + } + } else { + static_assert(is_tiny_matrix_v<ReturnType>); + + for (size_t i = 0, l = 0; i < return_value.nbRows(); ++i) { + for (size_t j = 0; j < return_value.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using Vi_T = std::decay_t<decltype(Aij)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + return_value(i, j) = Aij; + } + }, + value[l]); + } + } } return return_value; } else if constexpr (std::is_same_v<ReturnType, std::string>) { diff --git a/src/utils/PugsTraits.hpp b/src/utils/PugsTraits.hpp index e34e0433bb69316af695fec1c8d571ef23e03155..f51b455cb5ea4efa678dd980d1fec2721ab1fd2a 100644 --- a/src/utils/PugsTraits.hpp +++ b/src/utils/PugsTraits.hpp @@ -84,7 +84,14 @@ inline constexpr bool is_tiny_vector_v = false; template <size_t N, typename T> inline constexpr bool is_tiny_vector_v<TinyVector<N, T>> = true; -// Traits is_tiny_vector +// Traits is_tiny_matrix + +template <typename T> +inline constexpr bool is_tiny_matrix_v = false; + +template <size_t N, typename T> +inline constexpr bool is_tiny_matrix_v<TinyMatrix<N, T>> = true; + // helper to check if a type is part of a variant template <typename T, typename V>