#include <scheme/DiscreteFunctionVectorInterpoler.hpp>

#include <language/utils/InterpolateItemArray.hpp>
#include <scheme/DiscreteFunctionP0Vector.hpp>
#include <utils/Exceptions.hpp>

template <size_t Dimension, typename DataType>
std::shared_ptr<IDiscreteFunction>
DiscreteFunctionVectorInterpoler::_interpolate() const
{
  std::shared_ptr mesh = std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(m_mesh);

  using MeshDataType      = MeshData<Dimension>;
  MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

  return std::make_shared<
    DiscreteFunctionP0Vector<Dimension, DataType>>(mesh, InterpolateItemArray<DataType(TinyVector<Dimension>)>::
                                                           template interpolate<ItemType::cell>(m_function_id_list,
                                                                                                mesh_data.xj()));
}

// Validates the functions to interpolate, then interpolates them with double
// as component type.
//
// For each function of m_function_id_list, the data type carried by the second
// child (children[1]) of its domain mapping node is inspected: its content
// type must be one of bool, unsigned int, int or double.
//
// Throws NormalError as soon as a function with another content type is met.
template <size_t Dimension>
std::shared_ptr<IDiscreteFunction>
DiscreteFunctionVectorInterpoler::_interpolate() const
{
  for (const auto& function_id : m_function_id_list) {
    const auto& function_descriptor = function_id.descriptor();
    auto&& mapping_node             = function_descriptor.domainMappingNode();

    Assert(mapping_node.children[1]->m_data_type == ASTNodeDataType::typename_t);
    const ASTNodeDataType& data_type = mapping_node.children[1]->m_data_type.contentType();

    switch (data_type) {
    case ASTNodeDataType::bool_t:
    case ASTNodeDataType::unsigned_int_t:
    case ASTNodeDataType::int_t:
    case ASTNodeDataType::double_t: {
      // accepted scalar type: check the next function
      continue;
    }
    default: {
      std::ostringstream error_message;
      error_message << "vector functions require scalar value type.\n";
      error_message << "Invalid interpolation value type: " << rang::fgB::red << dataTypeName(data_type)
                    << rang::style::reset;
      throw NormalError(error_message.str());
    }
    }
  }

  // whatever the (scalar) types of the functions are, values are stored as double
  return this->_interpolate<Dimension, double>();
}

// Entry point of the vector interpolation.
//
// Requires the discrete function descriptor to be of type
// DiscreteFunctionType::P0Vector (NormalError is thrown otherwise), then
// forwards to the _interpolate<Dimension>() instance matching the dimension
// reported by m_mesh (1, 2 or 3). Any other dimension raises UnexpectedError.
std::shared_ptr<IDiscreteFunction>
DiscreteFunctionVectorInterpoler::interpolate() const
{
  if (m_discrete_function_descriptor->type() != DiscreteFunctionType::P0Vector) {
    throw NormalError("invalid discrete function type for vector interpolation");
  }

  // the runtime dimension is turned into a compile-time template argument
  const auto dimension = m_mesh->dimension();

  if (dimension == 1) {
    return this->_interpolate<1>();
  }
  if (dimension == 2) {
    return this->_interpolate<2>();
  }
  if (dimension == 3) {
    return this->_interpolate<3>();
  }

  // LCOV_EXCL_START
  throw UnexpectedError("invalid dimension");
  // LCOV_EXCL_STOP
}
