Skip to content
Snippets Groups Projects
Commit bc6a2d3c authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add tuple expressions for vector initialization

This allows to write

``
let f : R -> R^3, x -> (x, 2*x+1, -2);
R^3 x = f(2.3);
``

or even

``
let f : R -> R^3*R, x -> ((x, 2*x+1, -2), x-1);
R^3*R (x,t) = f(2.3);
``

Note that it is not completely functional since
``
let f : R -> R^3, x -> (x, 2*x+1, -2);
R^3 x = 2*f(2.3);
``
does not work!
The return type of `f(2.3)` is incorrect (should be an R^3 but
remains a `typename` which is improper for these calculations).
parent 6b2407fc
No related branches found
No related tags found
1 merge request!37Feature/language
...@@ -223,6 +223,7 @@ using selector = parse_tree::selector< ...@@ -223,6 +223,7 @@ using selector = parse_tree::selector<
N_set, N_set,
Z_set, Z_set,
R_set, R_set,
tuple_expression,
vector_type, vector_type,
string_type, string_type,
cout_kw, cout_kw,
......
...@@ -94,6 +94,21 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ...@@ -94,6 +94,21 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
n.m_data_type = ASTNodeDataType::int_t; n.m_data_type = ASTNodeDataType::int_t;
} else if (n.is_type<language::vector_type>()) { } else if (n.is_type<language::vector_type>()) {
n.m_data_type = getVectorDataType(n); n.m_data_type = getVectorDataType(n);
} else if (n.is_type<language::tuple_expression>()) {
for (auto&& child : n.children) {
this->_buildNodeDataTypes(*child);
}
for (auto&& child : n.children) {
ASTNodeNaturalConversionChecker{*child, child->m_data_type, ASTNodeDataType::double_t};
}
if (n.children.size() <= 3) {
n.m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, n.children.size()};
} else {
throw parse_error("invalid vector dimension (must be lesser than 3)", n.begin());
}
} else if (n.is_type<language::literal>()) { } else if (n.is_type<language::literal>()) {
n.m_data_type = ASTNodeDataType::string_t; n.m_data_type = ASTNodeDataType::string_t;
} else if (n.is_type<language::cout_kw>() or n.is_type<language::cerr_kw>() or n.is_type<language::clog_kw>()) { } else if (n.is_type<language::cout_kw>() or n.is_type<language::cerr_kw>() or n.is_type<language::clog_kw>()) {
...@@ -194,12 +209,26 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ...@@ -194,12 +209,26 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1]; ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1];
ASTNode& image_expression_node = *function_descriptor.definitionNode().children[1]; ASTNode& image_expression_node = *function_descriptor.definitionNode().children[1];
this->_buildNodeDataTypes(image_domain_node);
for (auto& child : image_domain_node.children) {
this->_buildNodeDataTypes(*child);
}
const size_t nb_image_domains = const size_t nb_image_domains =
(image_domain_node.is_type<language::type_expression>()) ? image_domain_node.children.size() : 1; (image_domain_node.is_type<language::type_expression>()) ? image_domain_node.children.size() : 1;
const size_t nb_image_expressions = const size_t nb_image_expressions =
(image_expression_node.is_type<language::expression_list>()) ? image_expression_node.children.size() : 1; (image_expression_node.is_type<language::expression_list>()) ? image_expression_node.children.size() : 1;
if (nb_image_domains != nb_image_expressions) { if (nb_image_domains != nb_image_expressions) {
if (image_domain_node.is_type<language::vector_type>()) {
ASTNodeDataType image_type = getVectorDataType(image_domain_node);
if (image_type.dimension() != nb_image_expressions) {
std::ostringstream message;
message << "note: expecting " << image_domain_node.m_data_type.dimension() << " expressions found "
<< nb_image_expressions << std::ends;
throw parse_error(message.str(), image_domain_node.begin());
}
} else {
std::ostringstream message; std::ostringstream message;
message << "note: number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow message << "note: number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow
<< image_domain_node.string() << rang::style::reset << rang::style::bold << image_domain_node.string() << rang::style::reset << rang::style::bold
...@@ -207,10 +236,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ...@@ -207,10 +236,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
<< image_expression_node.string() << rang::style::reset << std::ends; << image_expression_node.string() << rang::style::reset << std::ends;
throw parse_error(message.str(), image_domain_node.begin()); throw parse_error(message.str(), image_domain_node.begin());
} }
this->_buildNodeDataTypes(image_domain_node);
for (auto& child : image_domain_node.children) {
this->_buildNodeDataTypes(*child);
} }
auto check_image_type = [&](const ASTNode& image_node) { auto check_image_type = [&](const ASTNode& image_node) {
...@@ -343,8 +368,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ...@@ -343,8 +368,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1]; ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1];
Assert(image_domain_node.m_data_type == ASTNodeDataType::typename_t);
ASTNodeDataType data_type{ASTNodeDataType::undefined_t}; ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
if (image_domain_node.is_type<language::type_expression>()) { if (image_domain_node.is_type<language::type_expression>()) {
data_type = image_domain_node.m_data_type; data_type = image_domain_node.m_data_type;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <node_processor/LocalNameProcessor.hpp> #include <node_processor/LocalNameProcessor.hpp>
#include <node_processor/NameProcessor.hpp> #include <node_processor/NameProcessor.hpp>
#include <node_processor/OStreamProcessor.hpp> #include <node_processor/OStreamProcessor.hpp>
#include <node_processor/TupleToVectorProcessor.hpp>
#include <node_processor/ValueProcessor.hpp> #include <node_processor/ValueProcessor.hpp>
#include <node_processor/WhileProcessor.hpp> #include <node_processor/WhileProcessor.hpp>
...@@ -38,6 +39,24 @@ ASTNodeExpressionBuilder::_buildExpression(ASTNode& n) ...@@ -38,6 +39,24 @@ ASTNodeExpressionBuilder::_buildExpression(ASTNode& n)
ASTNodeAffectationExpressionBuilder{n}; ASTNodeAffectationExpressionBuilder{n};
} }
} else if (n.is_type<language::tuple_expression>()) {
switch (n.children.size()) {
case 1: {
n.m_node_processor = std::make_unique<TupleToVectorProcessor<1>>(n);
break;
}
case 2: {
n.m_node_processor = std::make_unique<TupleToVectorProcessor<2>>(n);
break;
}
case 3: {
n.m_node_processor = std::make_unique<TupleToVectorProcessor<3>>(n);
break;
}
default: {
throw parse_error("unexpected error: invalid tuple size", n.begin());
}
}
} else if (n.is_type<language::function_definition>()) { } else if (n.is_type<language::function_definition>()) {
n.m_node_processor = std::make_unique<FakeProcessor>(); n.m_node_processor = std::make_unique<FakeProcessor>();
......
...@@ -336,12 +336,19 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node ...@@ -336,12 +336,19 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node
ASTNode& function_expression = *function_descriptor.definitionNode().children[1]; ASTNode& function_expression = *function_descriptor.definitionNode().children[1];
if (function_expression.is_type<language::expression_list>()) { if (function_expression.is_type<language::expression_list>()) {
Assert(function_image_domain.is_type<language::type_expression>()); if (function_image_domain.is_type<language::vector_type>()) {
for (size_t i = 0; i < function_expression.children.size(); ++i) {
function_processor->addFunctionExpressionProcessor(
this->_getFunctionProcessor(function_expression.children[i]->m_data_type, ASTNodeDataType::double_t, node,
*function_expression.children[i]));
}
} else {
ASTNode& image_domain_node = function_image_domain; ASTNode& image_domain_node = function_image_domain;
for (size_t i = 0; i < function_expression.children.size(); ++i) { for (size_t i = 0; i < function_expression.children.size(); ++i) {
add_component_expression(*function_expression.children[i], *image_domain_node.children[i]); add_component_expression(*function_expression.children[i], *image_domain_node.children[i]);
} }
}
} else { } else {
add_component_expression(function_expression, function_image_domain); add_component_expression(function_expression, function_image_domain);
} }
......
...@@ -213,7 +213,9 @@ struct logical_or : list_must< logical_and, or_op >{}; ...@@ -213,7 +213,9 @@ struct logical_or : list_must< logical_and, or_op >{};
struct expression : logical_or {}; struct expression : logical_or {};
struct expression_list : seq< open_parent, expression, plus< if_must< COMMA, expression > >, close_parent >{}; struct tuple_expression : seq< open_parent, expression, plus< if_must< COMMA, expression > >, close_parent >{};
struct expression_list : seq< open_parent, sor< tuple_expression, expression >, plus< if_must< COMMA, sor< tuple_expression, expression > > >, close_parent >{};
struct affect_op : sor< eq_op, multiplyeq_op, divideeq_op, pluseq_op, minuseq_op > {}; struct affect_op : sor< eq_op, multiplyeq_op, divideeq_op, pluseq_op, minuseq_op > {};
......
#ifndef TUPLE_TO_VECTOR_PROCESSOR_HPP
#define TUPLE_TO_VECTOR_PROCESSOR_HPP
#include <node_processor/INodeProcessor.hpp>
#include <node_processor/ASTNodeExpressionListProcessor.hpp>
template <size_t N>
class TupleToVectorProcessor final : public INodeProcessor
{
private:
ASTNode& m_node;
ASTNodeExpressionListProcessor m_list_processor;
public:
DataVariant
execute(ExecutionPolicy& exec_policy)
{
AggregateDataVariant v = std::get<AggregateDataVariant>(m_list_processor.execute(exec_policy));
Assert(v.size() == N);
TinyVector<N> x;
for (size_t i = 0; i < N; ++i) {
std::visit(
[&](auto&& v) {
using ValueT = std::decay_t<decltype(v)>;
if constexpr (std::is_arithmetic_v<ValueT>) {
x[i] = v;
} else {
Assert(false, "unexpected value type");
}
},
v[i]);
}
return DataVariant{std::move(x)};
}
TupleToVectorProcessor(ASTNode& node) : m_node{node}, m_list_processor{node} {}
};
#endif // TUPLE_TO_VECTOR_PROCESSOR_HPP
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment