Skip to content
Snippets Groups Projects
Commit 68da7c49 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add interpolation mechanism for functions with value in tuples of R

Up to now it was defined for list of functions of R. This was not
handy in view of boundary conditions for this kind of data.

It is not completely finished yet
- unit tests are missing
- one must allow other types (R^d, R^dxd for instance)
- integration is not coded
parent c0435206
No related branches found
No related tags found
1 merge request!167Improve fluxing based remapping
#ifndef EVALUATE_ARRAY_AT_POINTS_HPP
#define EVALUATE_ARRAY_AT_POINTS_HPP
#include <language/utils/PugsFunctionAdapter.hpp>
#include <utils/Array.hpp>
#include <utils/Table.hpp>
class FunctionSymbolId;
template <typename T>
class EvaluateArrayAtPoints;
template <typename OutputType, typename InputType>
class EvaluateArrayAtPoints<OutputType(InputType)> : public PugsFunctionAdapter<OutputType(InputType)>
{
using Adapter = PugsFunctionAdapter<OutputType(InputType)>;
public:
template <typename InputArrayT, typename OutputTableT>
static PUGS_INLINE void
evaluateTo(const FunctionSymbolId& function_symbol_id, const InputArrayT& position, OutputTableT& table)
{
static_assert(std::is_same_v<std::remove_const_t<typename InputArrayT::data_type>, InputType>,
"invalid input data type");
static_assert(std::is_same_v<std::remove_const_t<typename OutputTableT::data_type>, OutputType>,
"invalid output data type");
auto& expression = Adapter::getFunctionExpression(function_symbol_id);
auto convert_result = Adapter::getArrayResultConverter(expression.m_data_type);
auto context_list = Adapter::getContextList(expression);
using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space;
Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens;
if constexpr (std::is_arithmetic_v<OutputType>) {
table.fill(0);
} else if constexpr (is_tiny_vector_v<OutputType> or is_tiny_matrix_v<OutputType>) {
table.fill(zero);
} else {
static_assert(std::is_same_v<OutputType, double>, "unexpected output type");
}
parallel_for(size(position), [=, &expression, &tokens](typename InputArrayT::index_type i) {
const int32_t t = tokens.acquire();
auto& execution_policy = context_list[t];
Adapter::convertArgs(execution_policy.currentContext(), position[i]);
auto result = expression.execute(execution_policy);
auto&& array = convert_result(std::move(result));
for (size_t j = 0; j < array.size(); ++j) {
table[i][j] = array[j];
}
tokens.release(t);
});
}
template <class InputArrayT>
static PUGS_INLINE Table<OutputType>
evaluate(const FunctionSymbolId& function_symbol_id, const InputArrayT& position)
{
static_assert(std::is_same_v<std::remove_const_t<typename InputArrayT::data_type>, InputType>,
"invalid input data type");
Table<OutputType> value(size(position));
evaluateArrayTo(function_symbol_id, position, value);
return value;
}
};
#endif // EVALUATE_ARRAY_AT_POINTS_HPP
#ifndef INTERPOLATE_ITEM_ARRAY_HPP
#define INTERPOLATE_ITEM_ARRAY_HPP
#include <language/utils/EvaluateArrayAtPoints.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/ItemArray.hpp>
#include <mesh/ItemType.hpp>
......@@ -12,12 +13,34 @@ class InterpolateItemArray<OutputType(InputType)>
{
static constexpr size_t Dimension = OutputType::Dimension;
private:
PUGS_INLINE static bool
_isSingleTupleFunction(const std::vector<FunctionSymbolId>& function_symbol_id_list)
{
if (function_symbol_id_list.size() > 1) {
return false;
} else {
Assert(function_symbol_id_list.size() == 1);
const FunctionSymbolId& function_symbol_id = function_symbol_id_list[0];
return (function_symbol_id.descriptor().domainMappingNode().children[1]->m_data_type == ASTNodeDataType::tuple_t);
}
}
public:
template <ItemType item_type>
PUGS_INLINE static ItemArray<OutputType, item_type>
interpolate(const std::vector<FunctionSymbolId>& function_symbol_id_list,
const ItemValue<const InputType, item_type>& position)
{
if (_isSingleTupleFunction(function_symbol_id_list)) {
const FunctionSymbolId& function_symbol_id = function_symbol_id_list[0];
const size_t table_size = function_symbol_id.descriptor().definitionNode().children[1]->children.size();
ItemArray<OutputType, item_type> item_array{*position.connectivity_ptr(), table_size};
EvaluateArrayAtPoints<OutputType(const InputType)>::evaluateTo(function_symbol_id, position, item_array);
return item_array;
} else {
ItemArray<OutputType, item_type> item_array{*position.connectivity_ptr(), function_symbol_id_list.size()};
for (size_t i_function_symbol = 0; i_function_symbol < function_symbol_id_list.size(); ++i_function_symbol) {
......@@ -31,6 +54,7 @@ class InterpolateItemArray<OutputType(InputType)>
return item_array;
}
}
template <ItemType item_type>
PUGS_INLINE static Table<OutputType>
......@@ -38,6 +62,23 @@ class InterpolateItemArray<OutputType(InputType)>
const ItemValue<const InputType, item_type>& position,
const Array<const ItemIdT<item_type>>& list_of_items)
{
if (_isSingleTupleFunction(function_symbol_id_list)) {
Array<InputType> item_position{list_of_items.size()};
using ItemId = ItemIdT<item_type>;
parallel_for(
list_of_items.size(), PUGS_LAMBDA(size_t i_item) {
ItemId item_id = list_of_items[i_item];
item_position[i_item] = position[item_id];
});
const FunctionSymbolId& function_symbol_id = function_symbol_id_list[0];
const size_t table_size = function_symbol_id.descriptor().definitionNode().children[1]->children.size();
Table<OutputType> table{list_of_items.size(), table_size};
EvaluateArrayAtPoints<OutputType(const InputType)>::evaluateTo(function_symbol_id, item_position, table);
return table;
} else {
Table<OutputType> table{list_of_items.size(), function_symbol_id_list.size()};
for (size_t i_function_symbol = 0; i_function_symbol < function_symbol_id_list.size(); ++i_function_symbol) {
......@@ -51,6 +92,7 @@ class InterpolateItemArray<OutputType(InputType)>
return table;
}
}
template <ItemType item_type>
PUGS_INLINE static Table<OutputType>
......
......@@ -288,6 +288,44 @@ class PugsFunctionAdapter<OutputType(InputType...)>
}
}
[[nodiscard]] PUGS_INLINE static std::function<std::vector<OutputType>(DataVariant&& result)>
getArrayResultConverter(const ASTNodeDataType& data_type)
{
Assert(data_type == ASTNodeDataType::list_t);
if constexpr (std::is_arithmetic_v<OutputType>) {
return [&](DataVariant&& result) -> std::vector<OutputType> {
return std::visit(
[&](auto&& value) -> std::vector<OutputType> {
using ValueType = std::decay_t<decltype(value)>;
if constexpr (std::is_same_v<ValueType, AggregateDataVariant>) {
std::vector<OutputType> array(value.size());
for (size_t i = 0; i < value.size(); ++i) {
array[i] = std::visit(
[&](auto&& value_i) -> OutputType {
using Value_I_Type = std::decay_t<decltype(value_i)>;
if constexpr (std::is_arithmetic_v<Value_I_Type>) {
return value_i;
} else {
throw UnexpectedError("expecting arithmetic type");
}
},
value[i]);
}
return array;
} else {
throw UnexpectedError("invalid DataVariant");
}
},
result);
};
} else {
throw NotImplementedError("non-arithmetic tuple type");
}
}
PugsFunctionAdapter() = delete;
};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment