diff --git a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp index d9ae87fd35539c33dcf05014cef9004e43607feb..b135bbcd3f67127c23c044880d96ad28ee6527c5 100644 --- a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp @@ -77,6 +77,44 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } }; + auto add_affectation_processor_for_matrix_data = [&](const auto& value, + const ASTNodeSubDataType& node_sub_data_type) { + using ValueT = std::decay_t<decltype(value)>; + if constexpr (std::is_same_v<ValueT, TinyMatrix<1>>) { + if ((node_sub_data_type.m_data_type == ASTNodeDataType::matrix_t) and + (node_sub_data_type.m_data_type.nbRows() == value.nbRows()) and + (node_sub_data_type.m_data_type.nbColumns() == value.nbColumns())) { + list_affectation_processor->template add<ValueT, ValueT>(value_node); + } else { + add_affectation_processor_for_data(value, node_sub_data_type); + } + } else if constexpr (std::is_same_v<ValueT, TinyMatrix<2>> or std::is_same_v<ValueT, TinyMatrix<3>>) { + if ((node_sub_data_type.m_data_type == ASTNodeDataType::matrix_t) and + (node_sub_data_type.m_data_type.nbRows() == value.nbRows()) and + (node_sub_data_type.m_data_type.nbColumns() == value.nbColumns())) { + list_affectation_processor->template add<ValueT, ValueT>(value_node); + } else if ((node_sub_data_type.m_data_type == ASTNodeDataType::list_t) and + (node_sub_data_type.m_parent_node.children.size() == value.nbRows() * value.nbColumns())) { + list_affectation_processor->template add<ValueT, AggregateDataVariant>(value_node); + } else if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { + list_affectation_processor->template add<ValueT, ZeroType>(value_node); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid operand value", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid dimension", std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } else { + throw ParseError("unexpected error: invalid value type", std::vector{node_sub_data_type.m_parent_node.begin()}); + } + }; + auto add_affectation_processor_for_string_data = [&](const ASTNodeSubDataType& node_sub_data_type) { if constexpr (std::is_same_v<OperatorT, language::eq_op>) { switch (node_sub_data_type.m_data_type) { @@ -176,6 +214,29 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } break; } + case ASTNodeDataType::matrix_t: { + Assert(value_type.nbRows() == value_type.nbColumns()); + switch (value_type.nbRows()) { + case 1: { + add_affectation_processor_for_matrix_data(TinyMatrix<1>{}, node_sub_data_type); + break; + } + case 2: { + add_affectation_processor_for_matrix_data(TinyMatrix<2>{}, node_sub_data_type); + break; + } + case 3: { + add_affectation_processor_for_matrix_data(TinyMatrix<3>{}, node_sub_data_type); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("invalid dimension", std::vector{value_node.begin()}); + } + // LCOV_EXCL_STOP + } + break; + } case ASTNodeDataType::string_t: { add_affectation_processor_for_string_data(node_sub_data_type); break; @@ -188,12 +249,7 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } }; - if ((value_node.m_data_type != rhs_node_sub_data_type.m_data_type) and - (value_node.m_data_type == ASTNodeDataType::vector_t) and (value_node.m_data_type.dimension() == 1)) { - ASTNodeNaturalConversionChecker{rhs_node_sub_data_type, ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - } else { - ASTNodeNaturalConversionChecker{rhs_node_sub_data_type, value_node.m_data_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>(rhs_node_sub_data_type, value_node.m_data_type); add_affectation_processor_for_value(value_node.m_data_type, rhs_node_sub_data_type); } diff --git a/src/language/node_processor/AffectationProcessor.hpp b/src/language/node_processor/AffectationProcessor.hpp index c2d52bd0568a1df9a57b7fdbc2ce85cdee36ba4b..959578619b023bc28236b1ae52141b018aad8c69 100644 --- a/src/language/node_processor/AffectationProcessor.hpp +++ b/src/language/node_processor/AffectationProcessor.hpp @@ -125,20 +125,40 @@ class AffectationExecutor final : public IAffectationExecutor m_lhs = std::get<DataT>(rhs); } else if constexpr (std::is_same_v<DataT, AggregateDataVariant>) { const AggregateDataVariant& v = std::get<AggregateDataVariant>(rhs); - static_assert(is_tiny_vector_v<ValueT>, "expecting lhs TinyVector"); - for (size_t i = 0; i < m_lhs.dimension(); ++i) { - std::visit( - [&](auto&& vi) { - using Vi_T = std::decay_t<decltype(vi)>; - if constexpr (std::is_convertible_v<Vi_T, double>) { - m_lhs[i] = vi; - } else { - // LCOV_EXCL_START - throw UnexpectedError("unexpected rhs type in affectation"); - // LCOV_EXCL_STOP - } - }, - v[i]); + if constexpr (is_tiny_vector_v<ValueT>) { + for (size_t i = 0; i < m_lhs.dimension(); ++i) { + std::visit( + [&](auto&& vi) { + using Vi_T = std::decay_t<decltype(vi)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + m_lhs[i] = vi; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + v[i]); + } + } else if constexpr (is_tiny_matrix_v<ValueT>) { + for (size_t i = 0, l = 0; i < m_lhs.nbRows(); ++i) { + for (size_t j = 0; j < m_lhs.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using Aij_T = std::decay_t<decltype(Aij)>; + if constexpr (std::is_convertible_v<Aij_T, double>) { + m_lhs(i, j) = Aij; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + v[l]); + } + } + } else { + static_assert(is_tiny_matrix_v<ValueT> or is_tiny_vector_v<ValueT>, "invalid rhs type"); } } else if constexpr (std::is_same_v<TinyVector<1>, ValueT>) { std::visit(