Skip to content
Snippets Groups Projects
Commit 708facd9 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add a few missing infrastructures for function dealing with R^dxd

This commit namely adds tuple to matrices conversions and allows to
deal with product of spaces containing R^dxd
parent 472d646f
Branches
Tags
1 merge request!71Feature/language tiny matrices
This commit is part of merge request !71. Comments created here will be created in the context of that merge request.
...@@ -256,6 +256,14 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ...@@ -256,6 +256,14 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
<< dataTypeName(image_type) << ", found " << nb_image_expressions << " scalar expressions"; << dataTypeName(image_type) << ", found " << nb_image_expressions << " scalar expressions";
throw ParseError(message.str(), image_domain_node.begin()); throw ParseError(message.str(), image_domain_node.begin());
} }
} else if (image_domain_node.is_type<language::matrix_type>()) {
ASTNodeDataType image_type = getMatrixDataType(image_domain_node);
if (image_type.nbRows() * image_type.nbColumns() != nb_image_expressions) {
std::ostringstream message;
message << "expecting " << image_type.nbRows() * image_type.nbColumns() << " scalar expressions or an "
<< dataTypeName(image_type) << ", found " << nb_image_expressions << " scalar expressions";
throw ParseError(message.str(), image_domain_node.begin());
}
} else { } else {
std::ostringstream message; std::ostringstream message;
message << "number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow message << "number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow
......
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
#include <language/ast/ASTNodeDataTypeFlattener.hpp> #include <language/ast/ASTNodeDataTypeFlattener.hpp>
#include <language/ast/ASTNodeNaturalConversionChecker.hpp> #include <language/ast/ASTNodeNaturalConversionChecker.hpp>
#include <language/node_processor/FunctionProcessor.hpp> #include <language/node_processor/FunctionProcessor.hpp>
#include <language/node_processor/TupleToTinyMatrixProcessor.hpp>
#include <language/node_processor/TupleToTinyVectorProcessor.hpp> #include <language/node_processor/TupleToTinyVectorProcessor.hpp>
#include <language/utils/FunctionTable.hpp> #include <language/utils/FunctionTable.hpp>
#include <language/utils/SymbolTable.hpp> #include <language/utils/SymbolTable.hpp>
#include <utils/Exceptions.hpp>
template <typename SymbolType> template <typename SymbolType>
std::unique_ptr<IFunctionArgumentConverter> std::unique_ptr<IFunctionArgumentConverter>
...@@ -551,6 +553,77 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node ...@@ -551,6 +553,77 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node
node.m_node_processor = std::move(function_processor); node.m_node_processor = std::move(function_processor);
} }
} else if (function_image_domain.is_type<language::matrix_type>()) {
ASTNodeDataType matrix_type = getMatrixDataType(function_image_domain);
if ((matrix_type.nbRows() == 1) and (matrix_type.nbColumns() == 1) and
(function_expression.m_data_type != ASTNodeDataType::vector_t)) {
ASTNodeNaturalConversionChecker{function_expression, ASTNodeDataType::build<ASTNodeDataType::double_t>()};
} else {
ASTNodeNaturalConversionChecker{function_expression, matrix_type};
}
if (function_expression.is_type<language::expression_list>()) {
Assert(matrix_type.nbRows() * matrix_type.nbColumns() == function_expression.children.size());
for (size_t i = 0; i < matrix_type.nbRows() * matrix_type.nbColumns(); ++i) {
function_processor->addFunctionExpressionProcessor(
this->_getFunctionProcessor(ASTNodeDataType::build<ASTNodeDataType::double_t>(), node,
*function_expression.children[i]));
}
switch (matrix_type.nbRows()) {
case 2: {
node.m_node_processor =
std::make_unique<TupleToTinyMatrixProcessor<FunctionProcessor, 2>>(node, std::move(function_processor));
break;
}
case 3: {
node.m_node_processor =
std::make_unique<TupleToTinyMatrixProcessor<FunctionProcessor, 3>>(node, std::move(function_processor));
break;
}
// LCOV_EXCL_START
default: {
throw ParseError("unexpected error: invalid vector_t dimension", std::vector{node.begin()});
}
// LCOV_EXCL_STOP
}
} else if (function_expression.is_type<language::integer>()) {
if (std::stoi(function_expression.string()) == 0) {
switch (matrix_type.nbRows()) {
case 1: {
node.m_node_processor =
std::make_unique<FunctionExpressionProcessor<TinyMatrix<1>, ZeroType>>(function_expression);
break;
}
case 2: {
node.m_node_processor =
std::make_unique<FunctionExpressionProcessor<TinyMatrix<2>, ZeroType>>(function_expression);
break;
}
case 3: {
node.m_node_processor =
std::make_unique<FunctionExpressionProcessor<TinyMatrix<3>, ZeroType>>(function_expression);
break;
}
// LCOV_EXCL_START
default: {
throw UnexpectedError("invalid matrix dimensions");
}
// LCOV_EXCL_STOP
}
} else {
// LCOV_EXCL_START
throw UnexpectedError("expecting 0");
// LCOV_EXCL_STOP
}
} else {
function_processor->addFunctionExpressionProcessor(
this->_getFunctionProcessor(matrix_type, node, function_expression));
node.m_node_processor = std::move(function_processor);
}
} else { } else {
if (function_expression.is_type<language::expression_list>()) { if (function_expression.is_type<language::expression_list>()) {
ASTNode& image_domain_node = function_image_domain; ASTNode& image_domain_node = function_image_domain;
......
#ifndef TUPLE_TO_TINY_MATRIX_PROCESSOR_HPP
#define TUPLE_TO_TINY_MATRIX_PROCESSOR_HPP
#include <language/ast/ASTNode.hpp>
#include <language/node_processor/INodeProcessor.hpp>
template <typename TupleProcessorT, size_t N>
class TupleToTinyMatrixProcessor final : public INodeProcessor
{
private:
ASTNode& m_node;
std::unique_ptr<TupleProcessorT> m_tuple_processor;
public:
DataVariant
execute(ExecutionPolicy& exec_policy)
{
AggregateDataVariant v = std::get<AggregateDataVariant>(m_tuple_processor->execute(exec_policy));
Assert(v.size() == N * N);
TinyMatrix<N> A;
for (size_t i = 0, l = 0; i < N; ++i) {
for (size_t j = 0; j < N; ++j, ++l) {
std::visit(
[&](auto&& Aij) {
using ValueT = std::decay_t<decltype(Aij)>;
if constexpr (std::is_arithmetic_v<ValueT>) {
A(i, j) = Aij;
} else {
// LCOV_EXCL_START
Assert(false, "unexpected value type");
// LCOV_EXCL_STOP
}
},
v[l]);
}
}
return DataVariant{std::move(A)};
}
TupleToTinyMatrixProcessor(ASTNode& node) : m_node{node}, m_tuple_processor{std::make_unique<TupleProcessorT>(node)}
{}
TupleToTinyMatrixProcessor(ASTNode& node, std::unique_ptr<TupleProcessorT>&& tuple_processor)
: m_node{node}, m_tuple_processor{std::move(tuple_processor)}
{}
};
#endif // TUPLE_TO_TINY_MATRIX_PROCESSOR_HPP
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment