diff --git a/src/scheme/DiscreteFunctionInterpoler.cpp b/src/scheme/DiscreteFunctionInterpoler.cpp index 7ac4919ab77bd0be973925a88b59a94cf7dc57d5..fda64fdcb82d1eb36b8925e9cdc438989cd92306 100644 --- a/src/scheme/DiscreteFunctionInterpoler.cpp +++ b/src/scheme/DiscreteFunctionInterpoler.cpp @@ -3,29 +3,101 @@ #include <scheme/DiscreteFunctionP0.hpp> #include <utils/Exceptions.hpp> +template <size_t Dimension, typename DataType> +std::shared_ptr<IDiscreteFunction> +DiscreteFunctionInterpoler::_interpolate() const +{ + std::shared_ptr mesh = std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(m_mesh); + return std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh, m_function_id); +} + +template <size_t Dimension> +std::shared_ptr<IDiscreteFunction> +DiscreteFunctionInterpoler::_interpolate() const +{ + const auto& function_descriptor = m_function_id.descriptor(); + Assert(function_descriptor.domainMappingNode().children[1]->m_data_type == ASTNodeDataType::typename_t); + + const ASTNodeDataType& data_type = function_descriptor.domainMappingNode().children[1]->m_data_type.contentType(); + + switch (data_type) { + case ASTNodeDataType::bool_t: { + return this->_interpolate<Dimension, bool>(); + } + case ASTNodeDataType::unsigned_int_t: { + return this->_interpolate<Dimension, uint64_t>(); + } + case ASTNodeDataType::int_t: { + return this->_interpolate<Dimension, int64_t>(); + } + case ASTNodeDataType::double_t: { + return this->_interpolate<Dimension, double>(); + } + case ASTNodeDataType::vector_t: { + switch (data_type.dimension()) { + case 1: { + return this->_interpolate<Dimension, TinyVector<1>>(); + } + case 2: { + return this->_interpolate<Dimension, TinyVector<2>>(); + } + case 3: { + return this->_interpolate<Dimension, TinyVector<3>>(); + } + default: { + std::ostringstream os; + os << "invalid vector dimension " << rang::fgB::red << data_type.dimension() << rang::style::reset; + + throw UnexpectedError(os.str()); + } + } + } + case ASTNodeDataType::matrix_t: { + Assert(data_type.nbColumns() == data_type.nbRows(), "undefined matrix type"); + switch (data_type.nbColumns()) { + case 1: { + return this->_interpolate<Dimension, TinyMatrix<1>>(); + } + case 2: { + return this->_interpolate<Dimension, TinyMatrix<2>>(); + } + case 3: { + return this->_interpolate<Dimension, TinyMatrix<3>>(); + } + default: { + std::ostringstream os; + os << "invalid vector dimension " << rang::fgB::red << data_type.dimension() << rang::style::reset; + + throw UnexpectedError(os.str()); + } + } + } + default: { + std::ostringstream os; + os << "invalid interpolation value type: " << rang::fgB::red << dataTypeName(data_type) << rang::style::reset; + + throw UnexpectedError(os.str()); + } + } +} + std::shared_ptr<IDiscreteFunction> DiscreteFunctionInterpoler::interpolate() const { std::shared_ptr<IDiscreteFunction> discrete_function; switch (m_mesh->dimension()) { case 1: { - std::shared_ptr mesh = std::dynamic_pointer_cast<const Mesh<Connectivity<1>>>(m_mesh); - discrete_function = std::make_shared<DiscreteFunctionP0<1, double>>(mesh, m_function_id); - break; + return this->_interpolate<1>(); } case 2: { - std::shared_ptr mesh = std::dynamic_pointer_cast<const Mesh<Connectivity<2>>>(m_mesh); - discrete_function = std::make_shared<DiscreteFunctionP0<2, double>>(mesh, m_function_id); - break; + return this->_interpolate<2>(); } case 3: { - std::shared_ptr mesh = std::dynamic_pointer_cast<const Mesh<Connectivity<3>>>(m_mesh); - discrete_function = std::make_shared<DiscreteFunctionP0<3, double>>(mesh, m_function_id); - break; + return this->_interpolate<3>(); } default: { throw UnexpectedError("invalid dimension"); } } - return discrete_function; + return nullptr; } diff --git a/src/scheme/DiscreteFunctionInterpoler.hpp b/src/scheme/DiscreteFunctionInterpoler.hpp index 8f679c7637d85fc9c499f27f2273a5bc3b69de16..758ff2454ec673ba2cc46c0c561eca1a4a256b51 100644 --- a/src/scheme/DiscreteFunctionInterpoler.hpp +++ b/src/scheme/DiscreteFunctionInterpoler.hpp @@ -15,6 +15,9 @@ class DiscreteFunctionInterpoler std::shared_ptr<const IDiscreteFunctionDescriptor> m_discrete_function_descriptor; const FunctionSymbolId m_function_id; + template <size_t Dimension, typename DataType> + std::shared_ptr<IDiscreteFunction> _interpolate() const; + template <size_t Dimension> std::shared_ptr<IDiscreteFunction> _interpolate() const;