#ifndef AFFECTATION_PROCESSOR_HPP #define AFFECTATION_PROCESSOR_HPP #include <node_processor/INodeProcessor.hpp> #include <SymbolTable.hpp> template <typename Op> struct AffOp; template <> struct AffOp<language::multiplyeq_op> { template <typename A, typename B> PUGS_INLINE void eval(A& a, const B& b) { a *= b; } }; template <> struct AffOp<language::divideeq_op> { template <typename A, typename B> PUGS_INLINE void eval(A& a, const B& b) { a /= b; } }; template <> struct AffOp<language::pluseq_op> { template <typename A, typename B> PUGS_INLINE void eval(A& a, const B& b) { a += b; } }; template <> struct AffOp<language::minuseq_op> { template <typename A, typename B> PUGS_INLINE void eval(A& a, const B& b) { a -= b; } }; struct IAffectationExecutor { virtual void affect(ExecutionPolicy& exec_policy, DataVariant&& rhs) = 0; virtual ~IAffectationExecutor() = default; }; template <typename OperatorT, typename ValueT, typename DataT> class AffectationExecutor final : public IAffectationExecutor { private: ValueT& m_lhs; static inline const bool _is_defined{[] { if constexpr (std::is_same_v<std::decay_t<ValueT>, bool>) { if constexpr (not std::is_same_v<OperatorT, language::eq_op>) { return false; } } return true; }()}; public: AffectationExecutor(ASTNode& node, ValueT& lhs) : m_lhs(lhs) { // LCOV_EXCL_START if constexpr (not _is_defined) { throw parse_error("unexpected error: invalid operands to affectation expression", std::vector{node.begin()}); } // LCOV_EXCL_STOP } PUGS_INLINE void affect(ExecutionPolicy&, DataVariant&& rhs) { if constexpr (_is_defined) { if constexpr (not std::is_same_v<DataT, ZeroType>) { if constexpr (std::is_same_v<ValueT, std::string>) { if constexpr (std::is_same_v<OperatorT, language::eq_op>) { if constexpr (std::is_same_v<std::string, DataT>) { m_lhs = std::get<DataT>(rhs); } else if constexpr (std::is_arithmetic_v<DataT>) { m_lhs = std::to_string(std::get<DataT>(rhs)); } else { std::ostringstream os; os << std::get<DataT>(rhs) << std::ends; m_lhs = os.str(); } } else { static_assert(std::is_same_v<OperatorT, language::pluseq_op>, "unexpected operator type"); if constexpr (std::is_same_v<std::string, DataT>) { m_lhs += std::get<std::string>(rhs); } else if constexpr (std::is_arithmetic_v<DataT>) { m_lhs += std::to_string(std::get<DataT>(rhs)); } else { std::ostringstream os; os << std::get<DataT>(rhs) << std::ends; m_lhs += os.str(); } } } else { if constexpr (std::is_same_v<OperatorT, language::eq_op>) { if constexpr (std::is_same_v<ValueT, DataT>) { m_lhs = std::get<DataT>(rhs); } else { m_lhs = static_cast<ValueT>(std::get<DataT>(rhs)); } } else { AffOp<OperatorT>().eval(m_lhs, std::get<DataT>(rhs)); } } } else if (std::is_same_v<OperatorT, language::eq_op>) { m_lhs = ValueT{zero}; } else { static_assert(std::is_same_v<OperatorT, language::eq_op>, "unexpected operator type"); } } } }; template <typename OperatorT, typename ArrayT, typename ValueT, typename DataT> class ComponentAffectationExecutor final : public IAffectationExecutor { private: ArrayT& m_lhs_array; ASTNode& m_index_expression; static inline const bool _is_defined{[] { if constexpr (not std::is_same_v<typename ArrayT::data_type, ValueT>) { return false; } else if constexpr (std::is_same_v<std::decay_t<ValueT>, bool>) { if constexpr (not std::is_same_v<OperatorT, language::eq_op>) { return false; } } return true; }()}; public: ComponentAffectationExecutor(ASTNode& node, ArrayT& lhs_array, ASTNode& index_expression) : m_lhs_array{lhs_array}, m_index_expression{index_expression} { // LCOV_EXCL_START if constexpr (not _is_defined) { throw parse_error("unexpected error: invalid operands to affectation expression", std::vector{node.begin()}); } // LCOV_EXCL_STOP } PUGS_INLINE void affect(ExecutionPolicy& exec_policy, DataVariant&& rhs) { if constexpr (_is_defined) { const int64_t index_value = [&](DataVariant&& value_variant) -> int64_t { int64_t index_value = 0; std::visit( [&](auto&& value) { using IndexValueT = std::decay_t<decltype(value)>; if constexpr (std::is_integral_v<IndexValueT>) { index_value = value; } else { throw parse_error("unexpected error: invalid index type", std::vector{m_index_expression.begin()}); } }, value_variant); return index_value; }(m_index_expression.execute(exec_policy)); if constexpr (std::is_same_v<ValueT, std::string>) { if constexpr (std::is_same_v<OperatorT, language::eq_op>) { if constexpr (std::is_same_v<std::string, DataT>) { m_lhs_array[index_value] = std::get<DataT>(rhs); } else { m_lhs_array[index_value] = std::to_string(std::get<DataT>(rhs)); } } else { static_assert(std::is_same_v<OperatorT, language::pluseq_op>, "unexpected operator type"); if constexpr (std::is_same_v<std::string, DataT>) { m_lhs_array[index_value] += std::get<std::string>(rhs); } else { m_lhs_array[index_value] += std::to_string(std::get<DataT>(rhs)); } } } else { if constexpr (std::is_same_v<OperatorT, language::eq_op>) { if constexpr (std::is_same_v<ValueT, DataT>) { m_lhs_array[index_value] = std::get<DataT>(rhs); } else { m_lhs_array[index_value] = static_cast<ValueT>(std::get<DataT>(rhs)); } } else { AffOp<OperatorT>().eval(m_lhs_array[index_value], std::get<DataT>(rhs)); } } } } }; template <typename OperatorT, typename ValueT, typename DataT> class AffectationProcessor final : public INodeProcessor { private: ASTNode& m_node; std::unique_ptr<IAffectationExecutor> m_affectation_executor; public: DataVariant execute(ExecutionPolicy& exec_policy) { m_affectation_executor->affect(exec_policy, m_node.children[1]->execute(exec_policy)); return {}; } AffectationProcessor(ASTNode& node) : m_node{node} { if (node.children[0]->is_type<language::name>()) { const std::string& symbol = m_node.children[0]->string(); auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, m_node.children[0]->begin()); Assert(found); DataVariant& value = i_symbol->attributes().value(); if (not std::holds_alternative<ValueT>(value)) { value = ValueT{}; } using AffectationExecutorT = AffectationExecutor<OperatorT, ValueT, DataT>; m_affectation_executor = std::make_unique<AffectationExecutorT>(m_node, std::get<ValueT>(value)); } else if (node.children[0]->is_type<language::subscript_expression>()) { auto& array_subscript_expression = *node.children[0]; auto& array_expression = *array_subscript_expression.children[0]; Assert(array_expression.is_type<language::name>()); const std::string& symbol = array_expression.string(); auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, array_subscript_expression.begin()); Assert(found); DataVariant& value = i_symbol->attributes().value(); if (array_expression.m_data_type != ASTNodeDataType::vector_t) { throw parse_error("unexpected error: invalid lhs (expecting R^d)", std::vector{array_subscript_expression.begin()}); } auto& index_expression = *array_subscript_expression.children[1]; switch (array_expression.m_data_type.dimension()) { case 1: { using ArrayTypeT = TinyVector<1>; if (not std::holds_alternative<ArrayTypeT>(value)) { value = ArrayTypeT{}; } using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; m_affectation_executor = std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); break; } case 2: { using ArrayTypeT = TinyVector<2>; if (not std::holds_alternative<ArrayTypeT>(value)) { value = ArrayTypeT{}; } using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; m_affectation_executor = std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); break; } case 3: { using ArrayTypeT = TinyVector<3>; if (not std::holds_alternative<ArrayTypeT>(value)) { value = ArrayTypeT{}; } using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; m_affectation_executor = std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); break; } default: { throw parse_error("unexpected error: invalid vector dimension", std::vector{array_subscript_expression.begin()}); } } } else { throw parse_error("unexpected error: invalid lhs", std::vector{node.children[0]->begin()}); } } }; template <typename OperatorT, typename ValueT> class AffectationFromListProcessor final : public INodeProcessor { private: ASTNode& m_node; DataVariant* m_lhs; public: DataVariant execute(ExecutionPolicy& exec_policy) { AggregateDataVariant children_values = std::get<AggregateDataVariant>(m_node.children[1]->execute(exec_policy)); static_assert(std::is_same_v<OperatorT, language::eq_op>, "forbidden affection operator for list to vectors"); ValueT v; for (size_t i = 0; i < v.dimension(); ++i) { std::visit( [&](auto&& child_value) { using T = std::decay_t<decltype(child_value)>; if constexpr (std::is_same_v<T, bool> or std::is_same_v<T, uint64_t> or std::is_same_v<T, int64_t> or std::is_same_v<T, double>) { v[i] = child_value; } else { throw parse_error("unexpected error: unexpected right hand side type in affectation", m_node.begin()); } }, children_values[i]); } *m_lhs = v; return {}; } AffectationFromListProcessor(ASTNode& node) : m_node{node} { const std::string& symbol = m_node.children[0]->string(); auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, m_node.children[0]->begin()); Assert(found); m_lhs = &i_symbol->attributes().value(); } }; template <typename ValueT> class AffectationFromZeroProcessor final : public INodeProcessor { private: ASTNode& m_node; DataVariant* m_lhs; public: DataVariant execute(ExecutionPolicy&) { *m_lhs = ValueT{zero}; return {}; } AffectationFromZeroProcessor(ASTNode& node) : m_node{node} { const std::string& symbol = m_node.children[0]->string(); auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, m_node.children[0]->begin()); Assert(found); m_lhs = &i_symbol->attributes().value(); } }; template <typename OperatorT> class ListAffectationProcessor final : public INodeProcessor { private: ASTNode& m_node; std::vector<std::unique_ptr<IAffectationExecutor>> m_affectation_executor_list; public: template <typename ValueT, typename DataT> void add(ASTNode& lhs_node) { using AffectationExecutorT = AffectationExecutor<OperatorT, ValueT, DataT>; if (lhs_node.is_type<language::name>()) { const std::string& symbol = lhs_node.string(); auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, m_node.children[0]->end()); Assert(found); DataVariant& value = i_symbol->attributes().value(); if (not std::holds_alternative<ValueT>(value)) { value = ValueT{}; } m_affectation_executor_list.emplace_back(std::make_unique<AffectationExecutorT>(m_node, std::get<ValueT>(value))); } else if (lhs_node.is_type<language::subscript_expression>()) { auto& array_subscript_expression = lhs_node; auto& array_expression = *array_subscript_expression.children[0]; Assert(array_expression.is_type<language::name>()); const std::string& symbol = array_expression.string(); auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, array_subscript_expression.begin()); Assert(found); DataVariant& value = i_symbol->attributes().value(); if (array_expression.m_data_type != ASTNodeDataType::vector_t) { throw parse_error("unexpected error: invalid lhs (expecting R^d)", std::vector{array_subscript_expression.begin()}); } auto& index_expression = *array_subscript_expression.children[1]; switch (array_expression.m_data_type.dimension()) { case 1: { using ArrayTypeT = TinyVector<1>; if (not std::holds_alternative<ArrayTypeT>(value)) { value = ArrayTypeT{}; } using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; m_affectation_executor_list.emplace_back( std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); break; } case 2: { using ArrayTypeT = TinyVector<2>; if (not std::holds_alternative<ArrayTypeT>(value)) { value = ArrayTypeT{}; } using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; m_affectation_executor_list.emplace_back( std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); break; } case 3: { using ArrayTypeT = TinyVector<3>; if (not std::holds_alternative<ArrayTypeT>(value)) { value = ArrayTypeT{}; } using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; m_affectation_executor_list.emplace_back( std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); break; } default: { throw parse_error("unexpected error: invalid vector dimension", std::vector{array_subscript_expression.begin()}); } } } else { throw parse_error("unexpected error: invalid left hand side", std::vector{lhs_node.begin()}); } } DataVariant execute(ExecutionPolicy& exec_policy) { AggregateDataVariant children_values = std::get<AggregateDataVariant>(m_node.children[1]->execute(exec_policy)); Assert(m_affectation_executor_list.size() == children_values.size()); for (size_t i = 0; i < m_affectation_executor_list.size(); ++i) { m_affectation_executor_list[i]->affect(exec_policy, std::move(children_values[i])); } return {}; } ListAffectationProcessor(ASTNode& node) : m_node{node} {} }; #endif // AFFECTATION_PROCESSOR_HPP