diff --git a/src/language/PEGGrammar.hpp b/src/language/PEGGrammar.hpp index d173f50745e781d2be5cade0372a0b15524f4850..2771a48e6a481a42400a6dd55ad7c67af44e8af2 100644 --- a/src/language/PEGGrammar.hpp +++ b/src/language/PEGGrammar.hpp @@ -55,6 +55,9 @@ struct character : if_must_else< one< '\\' >, escaped_c, ascii::any> {}; struct open_parent : seq< one< '(' >, ignored > {}; struct close_parent : seq< one< ')' >, ignored > {}; +struct open_bracket : seq< one< '[' >, ignored > {}; +struct close_bracket : seq< one< ']' >, ignored > {}; + struct literal : star< minus<character, one < '"' > > >{}; struct quoted_literal : if_must< one< '"' >, seq< literal, one< '"' > > >{}; @@ -153,11 +156,16 @@ struct COMMA : seq< comma , ignored > {}; struct expression; struct parented_expression : if_must< open_parent, expression, close_parent >{}; +struct vector_expression : if_must< open_bracket, list_must<expression, COMMA >, close_bracket >{}; + +struct row_expression : if_must< open_bracket, list_must<expression, COMMA >, close_bracket >{}; +struct matrix_expression : seq< open_bracket, list_must<row_expression, COMMA >, close_bracket >{}; + struct tuple_expression; struct function_argument_list : if_must< open_parent, opt< list_must< sor< tuple_expression, expression >, COMMA > >, close_parent >{}; struct function_evaluation : seq< NAME, function_argument_list > {}; -struct primary_expression : sor< BOOL, REAL, INTEGER, LITERAL, function_evaluation, NAME, parented_expression > {}; +struct primary_expression : sor< BOOL, REAL, INTEGER, LITERAL, function_evaluation, NAME, parented_expression, matrix_expression, vector_expression >{}; struct unary_plusplus : TAO_PEGTL_STRING("++") {}; struct unary_minusminus : TAO_PEGTL_STRING("--") {}; @@ -174,9 +182,6 @@ struct post_minusminus : TAO_PEGTL_STRING("--") {}; struct postfix_operator : seq< sor< post_plusplus, post_minusminus>, ignored > {}; -struct open_bracket : seq< one< '[' >, ignored > {}; -struct close_bracket : seq< one< ']' >, ignored > {}; - struct subscript_expression : if_must< open_bracket, list_must<expression, COMMA>, close_bracket >{}; struct postfix_expression : seq< primary_expression, star< sor< subscript_expression , postfix_operator> > >{}; diff --git a/src/language/ast/ASTBuilder.cpp b/src/language/ast/ASTBuilder.cpp index db7f77b29695bf82fc650ade4176592eefe5b651..47642d04e2bc9a745a7962954844cd54f74d2ba0 100644 --- a/src/language/ast/ASTBuilder.cpp +++ b/src/language/ast/ASTBuilder.cpp @@ -230,7 +230,10 @@ using selector = TAO_PEGTL_NAMESPACE::parse_tree::selector< language::R_set, language::type_name_id, language::tuple_expression, + language::vector_expression, language::vector_type, + language::matrix_expression, + language::row_expression, language::matrix_type, language::string_type, language::var_declaration, diff --git a/src/language/ast/ASTNodeDataTypeBuilder.cpp b/src/language/ast/ASTNodeDataTypeBuilder.cpp index cb9e479df5f2e1f64edb4d5cb732acebcb505af6..7433ad9f72e7e7314705ce36dfec44114172955e 100644 --- a/src/language/ast/ASTNodeDataTypeBuilder.cpp +++ b/src/language/ast/ASTNodeDataTypeBuilder.cpp @@ -154,6 +154,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } else if (n.is_type<language::integer>()) { n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + } else if (n.is_type<language::row_expression>()) { + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); } else if (n.is_type<language::vector_type>()) { n.m_data_type = getVectorDataType(n); } else if (n.is_type<language::matrix_type>()) { @@ -322,6 +324,11 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const this->_buildNodeDataTypes(*child); } + if (n.is_type<language::vector_expression>()) { + n.m_data_type = getVectorExpressionType(n); + } else if (n.is_type<language::matrix_expression>()) { + n.m_data_type = getMatrixExpressionType(n); + } if (n.is_type<language::break_kw>() or n.is_type<language::continue_kw>()) { n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); } else if (n.is_type<language::eq_op>() or n.is_type<language::multiplyeq_op>() or diff --git a/src/language/ast/ASTNodeExpressionBuilder.cpp b/src/language/ast/ASTNodeExpressionBuilder.cpp index bbbcda43e788356d6bf4ec338643a79049069fc0..78e192767fc869ed78a61c8c951e6bfd9c806f8d 100644 --- a/src/language/ast/ASTNodeExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeExpressionBuilder.cpp @@ -18,6 +18,8 @@ #include <language/node_processor/IfProcessor.hpp> #include <language/node_processor/LocalNameProcessor.hpp> #include <language/node_processor/NameProcessor.hpp> +#include <language/node_processor/TinyMatrixExpressionProcessor.hpp> +#include <language/node_processor/TinyVectorExpressionProcessor.hpp> #include <language/node_processor/TupleToTinyVectorProcessor.hpp> #include <language/node_processor/TupleToVectorProcessor.hpp> #include <language/node_processor/ValueProcessor.hpp> @@ -75,6 +77,50 @@ ASTNodeExpressionBuilder::_buildExpression(ASTNode& n) } else if (n.is_type<language::false_kw>()) { n.m_node_processor = std::make_unique<ValueProcessor>(n); + } else if (n.is_type<language::vector_expression>()) { + Assert(n.m_data_type == ASTNodeDataType::vector_t); + switch (n.m_data_type.dimension()) { + case 1: { + n.m_node_processor = std::make_unique<TinyVectorExpressionProcessor<1>>(n); + break; + } + case 2: { + n.m_node_processor = std::make_unique<TinyVectorExpressionProcessor<2>>(n); + break; + } + case 3: { + n.m_node_processor = std::make_unique<TinyVectorExpressionProcessor<3>>(n); + break; + } + default: { + throw UnexpectedError("invalid vector dimension"); + } + } + } else if (n.is_type<language::matrix_expression>()) { + Assert(n.m_data_type == ASTNodeDataType::matrix_t); + Assert(n.m_data_type.numberOfRows() == n.m_data_type.numberOfColumns()); + + switch (n.m_data_type.numberOfRows()) { + case 1: { + n.m_node_processor = std::make_unique<TinyMatrixExpressionProcessor<1, 1>>(n); + break; + } + case 2: { + n.m_node_processor = std::make_unique<TinyMatrixExpressionProcessor<2, 2>>(n); + break; + } + case 3: { + n.m_node_processor = std::make_unique<TinyMatrixExpressionProcessor<3, 3>>(n); + break; + } + default: { + throw UnexpectedError("invalid matrix dimension"); + } + } + + } else if ((n.is_type<language::row_expression>())) { + n.m_node_processor = std::make_unique<FakeProcessor>(); + } else if ((n.is_type<language::function_argument_list>()) or (n.is_type<language::expression_list>())) { n.m_node_processor = std::make_unique<ASTNodeExpressionListProcessor>(n); diff --git a/src/language/node_processor/TinyMatrixExpressionProcessor.hpp b/src/language/node_processor/TinyMatrixExpressionProcessor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bcb3e65f35b301c6adbe12eeb02e6024cb9ea0da --- /dev/null +++ b/src/language/node_processor/TinyMatrixExpressionProcessor.hpp @@ -0,0 +1,46 @@ +#ifndef TINY_MATRIX_EXPRESSION_PROCESSOR_HPP +#define TINY_MATRIX_EXPRESSION_PROCESSOR_HPP + +#include <algebra/TinyMatrix.hpp> +#include <language/PEGGrammar.hpp> +#include <language/ast/ASTNode.hpp> +#include <language/node_processor/INodeProcessor.hpp> + +template <size_t M, size_t N> +class TinyMatrixExpressionProcessor final : public INodeProcessor +{ + private: + ASTNode& m_node; + + public: + PUGS_INLINE + DataVariant + execute(ExecutionPolicy& exec_policy) + { + TinyMatrix<M, N> A{}; + Assert(m_node.children.size() == M); + + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + std::visit( + [&](auto&& x) { + using ValueT = std::decay_t<decltype(x)>; + if constexpr (std::is_arithmetic_v<ValueT>) { + A(i, j) = x; + } else { + // LCOV_EXCL_START + Assert(false, "unexpected value type"); + // LCOV_EXCL_STOP + } + }, + m_node.children[i]->children[j]->execute(exec_policy)); + } + } + + return A; + } + + TinyMatrixExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + +#endif // TINY_MATRIX_EXPRESSION_PROCESSOR_HPP diff --git a/src/language/node_processor/TinyVectorExpressionProcessor.hpp b/src/language/node_processor/TinyVectorExpressionProcessor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac86b9025b54f673c3cad636aaf1178a55e436a1 --- /dev/null +++ b/src/language/node_processor/TinyVectorExpressionProcessor.hpp @@ -0,0 +1,44 @@ +#ifndef TINY_VECTOR_EXPRESSION_PROCESSOR_HPP +#define TINY_VECTOR_EXPRESSION_PROCESSOR_HPP + +#include <algebra/TinyVector.hpp> +#include <language/PEGGrammar.hpp> +#include <language/ast/ASTNode.hpp> +#include <language/node_processor/INodeProcessor.hpp> + +template <size_t Dimension> +class TinyVectorExpressionProcessor final : public INodeProcessor +{ + private: + ASTNode& m_node; + + public: + PUGS_INLINE + DataVariant + execute(ExecutionPolicy& exec_policy) + { + TinyVector<Dimension> v{}; + Assert(m_node.children.size() == Dimension); + + for (size_t i = 0; i < Dimension; ++i) { + std::visit( + [&](auto&& x) { + using ValueT = std::decay_t<decltype(x)>; + if constexpr (std::is_arithmetic_v<ValueT>) { + v[i] = x; + } else { + // LCOV_EXCL_START + Assert(false, "unexpected value type"); + // LCOV_EXCL_STOP + } + }, + m_node.children[i]->execute(exec_policy)); + } + + return v; + } + + TinyVectorExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + +#endif // TINY_VECTOR_EXPRESSION_PROCESSOR_HPP diff --git a/src/language/utils/ASTNodeDataType.cpp b/src/language/utils/ASTNodeDataType.cpp index ddea3a6dd9a6a79fda474bb93e05645d55dd7503..462e1e26532e7a6f2ab73369b8a0e462a9eaa6e9 100644 --- a/src/language/utils/ASTNodeDataType.cpp +++ b/src/language/utils/ASTNodeDataType.cpp @@ -2,6 +2,7 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNode.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/ParseError.hpp> #include <utils/PugsAssert.hpp> @@ -19,6 +20,28 @@ getVectorDataType(const ASTNode& type_node) if (not(dimension > 0 and dimension <= 3)) { throw ParseError("invalid dimension (must be 1, 2 or 3)", dimension_node.begin()); } + + return ASTNodeDataType::build<ASTNodeDataType::vector_t>(dimension); +} + +ASTNodeDataType +getVectorExpressionType(const ASTNode& vector_expression_node) +{ + if (not(vector_expression_node.is_type<language::vector_expression>() and + (vector_expression_node.children.size() > 0))) { + throw ParseError("unexpected node type", vector_expression_node.begin()); + } + + const size_t dimension = vector_expression_node.children.size(); + if (not(dimension > 0 and dimension <= 3)) { + throw ParseError("invalid dimension (must be 1, 2 or 3)", vector_expression_node.begin()); + } + + for (size_t i = 0; i < dimension; ++i) { + ASTNodeNaturalConversionChecker(*vector_expression_node.children[i], + ASTNodeDataType::build<ASTNodeDataType::double_t>()); + } + return ASTNodeDataType::build<ASTNodeDataType::vector_t>(dimension); } @@ -52,6 +75,44 @@ getMatrixDataType(const ASTNode& type_node) return ASTNodeDataType::build<ASTNodeDataType::matrix_t>(dimension0, dimension1); } +ASTNodeDataType +getMatrixExpressionType(const ASTNode& matrix_expression_node) +{ + if (not matrix_expression_node.is_type<language::matrix_expression>()) { + throw ParseError("unexpected node type", matrix_expression_node.begin()); + } + + const size_t dimension0 = matrix_expression_node.children.size(); + if (not(dimension0 > 0 and dimension0 <= 3)) { + throw ParseError("invalid dimension (must be 1, 2 or 3)", matrix_expression_node.begin()); + } + for (size_t i = 0; i < dimension0; ++i) { + if (not matrix_expression_node.children[i]->is_type<language::row_expression>()) { + throw ParseError("expecting row expression", matrix_expression_node.children[i]->begin()); + } + } + + const size_t dimension1 = matrix_expression_node.children[0]->children.size(); + if (dimension0 != dimension1) { + throw ParseError("only square matrices are supported", matrix_expression_node.begin()); + } + + for (size_t i = 1; i < dimension0; ++i) { + if (matrix_expression_node.children[i]->children.size() != dimension1) { + throw ParseError("row must have same sizes", matrix_expression_node.begin()); + } + } + + for (size_t i = 0; i < dimension0; ++i) { + for (size_t j = 0; j < dimension1; ++j) { + ASTNodeNaturalConversionChecker(*matrix_expression_node.children[i]->children[j], + ASTNodeDataType::build<ASTNodeDataType::double_t>()); + } + } + + return ASTNodeDataType::build<ASTNodeDataType::matrix_t>(dimension0, dimension1); +} + std::string dataTypeName(const ASTNodeDataType& data_type) { diff --git a/src/language/utils/ASTNodeDataType.hpp b/src/language/utils/ASTNodeDataType.hpp index ea744757cf68b5055605111d35acc14564b06b30..27f08bf5a003727b4dea779ce7156870d29af448 100644 --- a/src/language/utils/ASTNodeDataType.hpp +++ b/src/language/utils/ASTNodeDataType.hpp @@ -13,8 +13,10 @@ class ASTNode; class ASTNodeDataType; ASTNodeDataType getVectorDataType(const ASTNode& type_node); +ASTNodeDataType getVectorExpressionType(const ASTNode& vector_expression_node); ASTNodeDataType getMatrixDataType(const ASTNode& type_node); +ASTNodeDataType getMatrixExpressionType(const ASTNode& matrix_expression_node); std::string dataTypeName(const std::vector<ASTNodeDataType>& data_type_vector);