Skip to content
Snippets Groups Projects
Commit c013d0d6 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Use function domains to check its PugsFunctionAdapter compatibility

Related to issue #21
parent 7d6874a1
Branches
Tags
1 merge request!52Issue/21
...@@ -36,12 +36,14 @@ class PugsFunctionAdapter<OutputType(InputType...)> ...@@ -36,12 +36,14 @@ class PugsFunctionAdapter<OutputType(InputType...)>
template <size_t I> template <size_t I>
[[nodiscard]] PUGS_INLINE static bool [[nodiscard]] PUGS_INLINE static bool
_checkValidArgumentDataType(const ASTNode& arg_expression) noexcept _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept(NO_ASSERT)
{ {
using Arg = std::tuple_element_t<I, InputTuple>; using Arg = std::tuple_element_t<I, InputTuple>;
constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>; constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>;
const ASTNodeDataType& arg_data_type = arg_expression.m_data_type;
Assert(arg_expression.m_data_type == ASTNodeDataType::typename_t);
const ASTNodeDataType& arg_data_type = arg_expression.m_data_type.contentType();
return isNaturalConversion(expected_input_data_type, arg_data_type); return isNaturalConversion(expected_input_data_type, arg_data_type);
} }
...@@ -55,53 +57,28 @@ class PugsFunctionAdapter<OutputType(InputType...)> ...@@ -55,53 +57,28 @@ class PugsFunctionAdapter<OutputType(InputType...)>
} }
[[nodiscard]] PUGS_INLINE static bool [[nodiscard]] PUGS_INLINE static bool
_checkValidInputDataType(const ASTNode& input_expression) noexcept _checkValidInputDomain(const ASTNode& input_domain_expression) noexcept
{ {
if constexpr (NArgs == 1) { if constexpr (NArgs == 1) {
return _checkValidArgumentDataType<0>(input_expression); return _checkValidArgumentDataType<0>(input_domain_expression);
} else { } else {
if (input_expression.children.size() != NArgs) { if ((input_domain_expression.m_data_type.contentType() != ASTNodeDataType::list_t) or
(input_domain_expression.children.size() != NArgs)) {
return false; return false;
} }
using IndexSequence = std::make_index_sequence<NArgs>; using IndexSequence = std::make_index_sequence<NArgs>;
return _checkAllInputDataType(input_expression, IndexSequence{}); return _checkAllInputDataType(input_domain_expression, IndexSequence{});
} }
} }
[[nodiscard]] PUGS_INLINE static bool [[nodiscard]] PUGS_INLINE static bool
_checkValidOutputDataType(const ASTNode& return_expression) noexcept _checkValidOutputDomain(const ASTNode& output_domain_expression) noexcept(NO_ASSERT)
{ {
constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>; constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>;
const ASTNodeDataType& return_data_type = return_expression.m_data_type; const ASTNodeDataType& return_data_type = output_domain_expression.m_data_type.contentType();
if (not isNaturalConversion(return_data_type, expected_return_data_type)) { return isNaturalConversion(return_data_type, expected_return_data_type);
if (expected_return_data_type == ASTNodeDataType::vector_t) {
if (return_data_type == ASTNodeDataType::list_t) {
if (expected_return_data_type.dimension() != return_expression.children.size()) {
return false;
} else {
for (const auto& child : return_expression.children) {
const ASTNodeDataType& data_type = child->m_data_type;
if (not isNaturalConversion(data_type, ast_node_data_type_from<double>)) {
return false;
}
}
}
} else if ((expected_return_data_type.dimension() == 1) and
isNaturalConversion(return_data_type, ast_node_data_type_from<double>)) {
return true;
} else if (return_data_type == ast_node_data_type_from<int64_t>) {
// 0 is the only valid value for vectors
return (return_expression.string() == "0");
} else {
return false;
}
} else {
return false;
}
}
return true;
} }
template <typename Arg, typename... RemainingArgs> template <typename Arg, typename... RemainingArgs>
...@@ -124,10 +101,10 @@ class PugsFunctionAdapter<OutputType(InputType...)> ...@@ -124,10 +101,10 @@ class PugsFunctionAdapter<OutputType(InputType...)>
PUGS_INLINE static void PUGS_INLINE static void
_checkFunction(const FunctionDescriptor& function) _checkFunction(const FunctionDescriptor& function)
{ {
bool has_valid_input = _checkValidInputDataType(*function.definitionNode().children[0]); bool has_valid_input_domain = _checkValidInputDomain(*function.domainMappingNode().children[0]);
bool has_valid_output = _checkValidOutputDataType(*function.definitionNode().children[1]); bool has_valid_output = _checkValidOutputDomain(*function.domainMappingNode().children[1]);
if (not(has_valid_input and has_valid_output)) { if (not(has_valid_input_domain and has_valid_output)) {
std::ostringstream error_message; std::ostringstream error_message;
error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow
<< _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>) << _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>)
......
...@@ -275,7 +275,7 @@ let R3toR3zero: R^3 -> R^3, x -> 0; ...@@ -275,7 +275,7 @@ let R3toR3zero: R^3 -> R^3, x -> 0;
{ {
std::string_view data = R"( std::string_view data = R"(
let R1toR1: R^1 -> R^1, x -> x; let R1toR1: R^1 -> R^1, x -> x;
let R3toR3: R^3 -> R^3, x -> 1; let R3toR3: R^3 -> R^3, x -> 0;
let RRRtoR3: R*R*R -> R^3, (x,y,z) -> (x,y,z); let RRRtoR3: R*R*R -> R^3, (x,y,z) -> (x,y,z);
let R3toR2: R^3 -> R^2, x -> (x[0],x[1]+x[2]); let R3toR2: R^3 -> R^2, x -> (x[0],x[1]+x[2]);
let RtoNS: R -> N*string, x -> (1, "foo"); let RtoNS: R -> N*string, x -> (1, "foo");
...@@ -322,9 +322,9 @@ let RtoR: R -> R, x -> 2*x; ...@@ -322,9 +322,9 @@ let RtoR: R -> R, x -> 2*x;
const TinyVector<3> x{2, 1, 3}; const TinyVector<3> x{2, 1, 3};
FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table);
REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<3>(TinyVector<3>)>::one_arg(function_symbol_id, x), REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<2>(TinyVector<3>)>::one_arg(function_symbol_id, x),
"error: invalid function type\n" "error: invalid function type\n"
"note: expecting R^3 -> R^3\n" "note: expecting R^3 -> R^2\n"
"note: provided function R3toR3: R^3 -> R^3"); "note: provided function R3toR3: R^3 -> R^3");
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment