diff --git a/src/language/utils/EvaluateArrayAtPoints.hpp b/src/language/utils/EvaluateArrayAtPoints.hpp new file mode 100644 index 0000000000000000000000000000000000000000..558cdff30cc355287cd2c1233884f765b1cac1b7 --- /dev/null +++ b/src/language/utils/EvaluateArrayAtPoints.hpp @@ -0,0 +1,72 @@ +#ifndef EVALUATE_ARRAY_AT_POINTS_HPP +#define EVALUATE_ARRAY_AT_POINTS_HPP + +#include <language/utils/PugsFunctionAdapter.hpp> +#include <utils/Array.hpp> +#include <utils/Table.hpp> + +class FunctionSymbolId; + +template <typename T> +class EvaluateArrayAtPoints; +template <typename OutputType, typename InputType> +class EvaluateArrayAtPoints<OutputType(InputType)> : public PugsFunctionAdapter<OutputType(InputType)> +{ + using Adapter = PugsFunctionAdapter<OutputType(InputType)>; + + public: + template <typename InputArrayT, typename OutputTableT> + static PUGS_INLINE void + evaluateTo(const FunctionSymbolId& function_symbol_id, const InputArrayT& position, OutputTableT& table) + { + static_assert(std::is_same_v<std::remove_const_t<typename InputArrayT::data_type>, InputType>, + "invalid input data type"); + static_assert(std::is_same_v<std::remove_const_t<typename OutputTableT::data_type>, OutputType>, + "invalid output data type"); + + auto& expression = Adapter::getFunctionExpression(function_symbol_id); + auto convert_result = Adapter::getArrayResultConverter(expression.m_data_type); + + auto context_list = Adapter::getContextList(expression); + + using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space; + Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens; + + if constexpr (std::is_arithmetic_v<OutputType>) { + table.fill(0); + } else if constexpr (is_tiny_vector_v<OutputType> or is_tiny_matrix_v<OutputType>) { + table.fill(zero); + } else { + static_assert(std::is_same_v<OutputType, double>, "unexpected output type"); + } + + parallel_for(size(position), [=, &expression, &tokens](typename InputArrayT::index_type i) { + const int32_t t = tokens.acquire(); + + auto& execution_policy = context_list[t]; + + Adapter::convertArgs(execution_policy.currentContext(), position[i]); + auto result = expression.execute(execution_policy); + auto&& array = convert_result(std::move(result)); + + for (size_t j = 0; j < array.size(); ++j) { + table[i][j] = array[j]; + } + + tokens.release(t); + }); + } + + template <class InputArrayT> + static PUGS_INLINE Table<OutputType> + evaluate(const FunctionSymbolId& function_symbol_id, const InputArrayT& position) + { + static_assert(std::is_same_v<std::remove_const_t<typename InputArrayT::data_type>, InputType>, + "invalid input data type"); + Table<OutputType> value(size(position)); + evaluateArrayTo(function_symbol_id, position, value); + return value; + } +}; + +#endif // EVALUATE_ARRAY_AT_POINTS_HPP diff --git a/src/language/utils/InterpolateItemArray.hpp b/src/language/utils/InterpolateItemArray.hpp index 17b72de6d62e0ac68062343010154257c43d3006..fde135d6bcbb4c8995395660226e8d3435c599b4 100644 --- a/src/language/utils/InterpolateItemArray.hpp +++ b/src/language/utils/InterpolateItemArray.hpp @@ -1,6 +1,7 @@ #ifndef INTERPOLATE_ITEM_ARRAY_HPP #define INTERPOLATE_ITEM_ARRAY_HPP +#include <language/utils/EvaluateArrayAtPoints.hpp> #include <language/utils/InterpolateItemValue.hpp> #include <mesh/ItemArray.hpp> #include <mesh/ItemType.hpp> @@ -12,24 +13,47 @@ class InterpolateItemArray<OutputType(InputType)> { static constexpr size_t Dimension = OutputType::Dimension; + private: + PUGS_INLINE static bool + _isSingleTupleFunction(const std::vector<FunctionSymbolId>& function_symbol_id_list) + { + if (function_symbol_id_list.size() > 1) { + return false; + } else { + Assert(function_symbol_id_list.size() == 1); + const FunctionSymbolId& function_symbol_id = function_symbol_id_list[0]; + return (function_symbol_id.descriptor().domainMappingNode().children[1]->m_data_type == ASTNodeDataType::tuple_t); + } + } + public: template <ItemType item_type> PUGS_INLINE static ItemArray<OutputType, item_type> interpolate(const std::vector<FunctionSymbolId>& function_symbol_id_list, const ItemValue<const InputType, item_type>& position) { - ItemArray<OutputType, item_type> item_array{*position.connectivity_ptr(), function_symbol_id_list.size()}; + if (_isSingleTupleFunction(function_symbol_id_list)) { + const FunctionSymbolId& function_symbol_id = function_symbol_id_list[0]; + const size_t table_size = function_symbol_id.descriptor().definitionNode().children[1]->children.size(); - for (size_t i_function_symbol = 0; i_function_symbol < function_symbol_id_list.size(); ++i_function_symbol) { - const FunctionSymbolId& function_symbol_id = function_symbol_id_list[i_function_symbol]; - ItemValue<OutputType, item_type> item_value = - InterpolateItemValue<OutputType(InputType)>::interpolate(function_symbol_id, position); - parallel_for( - item_value.numberOfItems(), - PUGS_LAMBDA(ItemIdT<item_type> item_id) { item_array[item_id][i_function_symbol] = item_value[item_id]; }); - } + ItemArray<OutputType, item_type> item_array{*position.connectivity_ptr(), table_size}; + EvaluateArrayAtPoints<OutputType(const InputType)>::evaluateTo(function_symbol_id, position, item_array); + + return item_array; + } else { + ItemArray<OutputType, item_type> item_array{*position.connectivity_ptr(), function_symbol_id_list.size()}; + + for (size_t i_function_symbol = 0; i_function_symbol < function_symbol_id_list.size(); ++i_function_symbol) { + const FunctionSymbolId& function_symbol_id = function_symbol_id_list[i_function_symbol]; + ItemValue<OutputType, item_type> item_value = + InterpolateItemValue<OutputType(InputType)>::interpolate(function_symbol_id, position); + parallel_for( + item_value.numberOfItems(), + PUGS_LAMBDA(ItemIdT<item_type> item_id) { item_array[item_id][i_function_symbol] = item_value[item_id]; }); + } - return item_array; + return item_array; + } } template <ItemType item_type> @@ -38,18 +62,36 @@ class InterpolateItemArray<OutputType(InputType)> const ItemValue<const InputType, item_type>& position, const Array<const ItemIdT<item_type>>& list_of_items) { - Table<OutputType> table{list_of_items.size(), function_symbol_id_list.size()}; + if (_isSingleTupleFunction(function_symbol_id_list)) { + Array<InputType> item_position{list_of_items.size()}; + using ItemId = ItemIdT<item_type>; + parallel_for( + list_of_items.size(), PUGS_LAMBDA(size_t i_item) { + ItemId item_id = list_of_items[i_item]; + item_position[i_item] = position[item_id]; + }); - for (size_t i_function_symbol = 0; i_function_symbol < function_symbol_id_list.size(); ++i_function_symbol) { - const FunctionSymbolId& function_symbol_id = function_symbol_id_list[i_function_symbol]; - Array<OutputType> array = - InterpolateItemValue<OutputType(InputType)>::interpolate(function_symbol_id, position, list_of_items); + const FunctionSymbolId& function_symbol_id = function_symbol_id_list[0]; + const size_t table_size = function_symbol_id.descriptor().definitionNode().children[1]->children.size(); - parallel_for( - array.size(), PUGS_LAMBDA(size_t i) { table[i][i_function_symbol] = array[i]; }); - } + Table<OutputType> table{list_of_items.size(), table_size}; + EvaluateArrayAtPoints<OutputType(const InputType)>::evaluateTo(function_symbol_id, item_position, table); - return table; + return table; + } else { + Table<OutputType> table{list_of_items.size(), function_symbol_id_list.size()}; + + for (size_t i_function_symbol = 0; i_function_symbol < function_symbol_id_list.size(); ++i_function_symbol) { + const FunctionSymbolId& function_symbol_id = function_symbol_id_list[i_function_symbol]; + Array<OutputType> array = + InterpolateItemValue<OutputType(InputType)>::interpolate(function_symbol_id, position, list_of_items); + + parallel_for( + array.size(), PUGS_LAMBDA(size_t i) { table[i][i_function_symbol] = array[i]; }); + } + + return table; + } } template <ItemType item_type> diff --git a/src/language/utils/PugsFunctionAdapter.hpp b/src/language/utils/PugsFunctionAdapter.hpp index de5e884bf63bc0ac3726864c4c4571c0a0fb9f55..1a4beb32e2889467a4685c79f6afa25e1aef4b68 100644 --- a/src/language/utils/PugsFunctionAdapter.hpp +++ b/src/language/utils/PugsFunctionAdapter.hpp @@ -288,6 +288,44 @@ class PugsFunctionAdapter<OutputType(InputType...)> } } + [[nodiscard]] PUGS_INLINE static std::function<std::vector<OutputType>(DataVariant&& result)> + getArrayResultConverter(const ASTNodeDataType& data_type) + { + Assert(data_type == ASTNodeDataType::list_t); + + if constexpr (std::is_arithmetic_v<OutputType>) { + return [&](DataVariant&& result) -> std::vector<OutputType> { + return std::visit( + [&](auto&& value) -> std::vector<OutputType> { + using ValueType = std::decay_t<decltype(value)>; + if constexpr (std::is_same_v<ValueType, AggregateDataVariant>) { + std::vector<OutputType> array(value.size()); + + for (size_t i = 0; i < value.size(); ++i) { + array[i] = std::visit( + [&](auto&& value_i) -> OutputType { + using Value_I_Type = std::decay_t<decltype(value_i)>; + if constexpr (std::is_arithmetic_v<Value_I_Type>) { + return value_i; + } else { + throw UnexpectedError("expecting arithmetic type"); + } + }, + value[i]); + } + + return array; + } else { + throw UnexpectedError("invalid DataVariant"); + } + }, + result); + }; + } else { + throw NotImplementedError("non-arithmetic tuple type"); + } + } + PugsFunctionAdapter() = delete; };