#include <language/utils/ItemArrayVariantFunctionInterpoler.hpp>

#include <language/utils/InterpolateItemArray.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ItemArrayVariant.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <utils/Exceptions.hpp>

#include <memory>

template <size_t Dimension, typename DataType, typename ArrayType>
std::shared_ptr<ItemArrayVariant>
ItemArrayVariantFunctionInterpoler::_interpolate() const
{
  std::shared_ptr p_mesh = std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(m_mesh);
  using MeshDataType     = MeshData<Dimension>;

  switch (m_item_type) {
  case ItemType::cell: {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
    return std::make_shared<ItemArrayVariant>(
      InterpolateItemArray<DataType(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(m_function_id_list,
                                                                                                  mesh_data.xj()));
  }
  case ItemType::face: {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
    return std::make_shared<ItemArrayVariant>(
      InterpolateItemArray<DataType(TinyVector<Dimension>)>::template interpolate<ItemType::face>(m_function_id_list,
                                                                                                  mesh_data.xl()));
  }
  case ItemType::edge: {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
    return std::make_shared<ItemArrayVariant>(
      InterpolateItemArray<DataType(TinyVector<Dimension>)>::template interpolate<ItemType::edge>(m_function_id_list,
                                                                                                  mesh_data.xe()));
  }
  case ItemType::node: {
    return std::make_shared<ItemArrayVariant>(
      InterpolateItemArray<DataType(TinyVector<Dimension>)>::template interpolate<ItemType::node>(m_function_id_list,
                                                                                                  p_mesh->xr()));
  }
    // LCOV_EXCL_START
  default: {
    throw UnexpectedError("invalid item type");
  }
    // LCOV_EXCL_STOP
  }
}

template <size_t Dimension>
std::shared_ptr<ItemArrayVariant>
ItemArrayVariantFunctionInterpoler::_interpolate() const
{
  const ASTNodeDataType data_type = [&] {
    const auto& function0_descriptor = m_function_id_list[0].descriptor();
    Assert(function0_descriptor.domainMappingNode().children[1]->m_data_type == ASTNodeDataType::typename_t);

    ASTNodeDataType data_type = function0_descriptor.domainMappingNode().children[1]->m_data_type.contentType();

    for (size_t i = 1; i < m_function_id_list.size(); ++i) {
      const auto& function_descriptor = m_function_id_list[i].descriptor();
      Assert(function_descriptor.domainMappingNode().children[1]->m_data_type == ASTNodeDataType::typename_t);
      if (data_type != function_descriptor.domainMappingNode().children[1]->m_data_type.contentType()) {
        throw NormalError("functions must have the same type");
      }
    }

    return data_type;
  }();

  switch (data_type) {
  case ASTNodeDataType::bool_t: {
    return this->_interpolate<Dimension, bool>();
  }
  case ASTNodeDataType::unsigned_int_t: {
    return this->_interpolate<Dimension, uint64_t>();
  }
  case ASTNodeDataType::int_t: {
    return this->_interpolate<Dimension, int64_t>();
  }
  case ASTNodeDataType::double_t: {
    return this->_interpolate<Dimension, double>();
  }
  case ASTNodeDataType::vector_t: {
    switch (data_type.dimension()) {
    case 1: {
      return this->_interpolate<Dimension, TinyVector<1>>();
    }
    case 2: {
      return this->_interpolate<Dimension, TinyVector<2>>();
    }
    case 3: {
      return this->_interpolate<Dimension, TinyVector<3>>();
    }
      // LCOV_EXCL_START
    default: {
      std::ostringstream os;
      os << "invalid vector dimension " << rang::fgB::red << data_type.dimension() << rang::style::reset;

      throw UnexpectedError(os.str());
    }
      // LCOV_EXCL_STOP
    }
  }
  case ASTNodeDataType::matrix_t: {
    Assert(data_type.numberOfColumns() == data_type.numberOfRows(), "undefined matrix type");
    switch (data_type.numberOfColumns()) {
    case 1: {
      return this->_interpolate<Dimension, TinyMatrix<1>>();
    }
    case 2: {
      return this->_interpolate<Dimension, TinyMatrix<2>>();
    }
    case 3: {
      return this->_interpolate<Dimension, TinyMatrix<3>>();
    }
      // LCOV_EXCL_START
    default: {
      std::ostringstream os;
      os << "invalid vector dimension " << rang::fgB::red << data_type.dimension() << rang::style::reset;

      throw UnexpectedError(os.str());
    }
      // LCOV_EXCL_STOP
    }
  }
    // LCOV_EXCL_START
  default: {
    std::ostringstream os;
    os << "invalid interpolation array type: " << rang::fgB::red << dataTypeName(data_type) << rang::style::reset;

    throw UnexpectedError(os.str());
  }
    // LCOV_EXCL_STOP
  }
}

/// Public entry point: selects the compile-time mesh dimension (1, 2 or 3)
/// from m_mesh->dimension() and forwards to the matching _interpolate<>.
///
/// @return a new ItemArrayVariant holding the interpolated values
/// @throws UnexpectedError if the mesh dimension is not 1, 2 or 3
std::shared_ptr<ItemArrayVariant>
ItemArrayVariantFunctionInterpoler::interpolate() const
{
  const auto mesh_dimension = m_mesh->dimension();

  if (mesh_dimension == 1) {
    return this->_interpolate<1>();
  } else if (mesh_dimension == 2) {
    return this->_interpolate<2>();
  } else if (mesh_dimension == 3) {
    return this->_interpolate<3>();
  }

  // LCOV_EXCL_START
  throw UnexpectedError("invalid dimension");
  // LCOV_EXCL_STOP
}