#include <scheme/DiscreteFunctionVectorInterpoler.hpp>

#include <language/utils/InterpolateItemArray.hpp>
#include <mesh/MeshCellZone.hpp>
#include <scheme/DiscreteFunctionP0Vector.hpp>
#include <scheme/DiscreteFunctionVariant.hpp>
#include <utils/Exceptions.hpp>

template <size_t Dimension, typename DataType>
DiscreteFunctionVariant
DiscreteFunctionVectorInterpoler::_interpolateOnZoneList() const
{
  // Interpolates each function of m_function_id_list at the points
  // mesh_data.xj() of the cells belonging to the union of the zones of
  // m_zone_list. Components on cells outside these zones are set to 0.
  // Throws NormalError if a cell is found more than once in the zone list.
  Assert(m_zone_list.size() > 0, "no zone list provided");

  using MeshType         = Mesh<Connectivity<Dimension>>;
  std::shared_ptr p_mesh = std::dynamic_pointer_cast<const MeshType>(m_mesh);

  const auto& connectivity       = p_mesh->connectivity();
  MeshData<Dimension>& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);

  // Flag every zone cell, rejecting cells that are listed twice
  CellValue<bool> is_in_zone{connectivity};
  is_in_zone.fill(false);

  size_t zone_cell_count = 0;
  for (const auto& zone : m_zone_list) {
    auto mesh_cell_zone    = getMeshCellZone(*p_mesh, *zone);
    const auto& zone_cells = mesh_cell_zone.cellList();
    for (size_t k = 0; k < zone_cells.size(); ++k) {
      const CellId cell_id = zone_cells[k];
      if (is_in_zone[cell_id]) {
        std::ostringstream error_message;
        error_message << "cell " << cell_id << " (number " << connectivity.cellNumber()[cell_id]
                      << ") is present multiple times in zone list";
        throw NormalError(error_message.str());
      }
      is_in_zone[cell_id] = true;
      ++zone_cell_count;
    }
  }

  // Gather the flagged cells in increasing CellId order
  Array<CellId> zone_cell_list{zone_cell_count};
  size_t next_position = 0;
  for (CellId cell_id = 0; cell_id < p_mesh->numberOfCells(); ++cell_id) {
    if (is_in_zone[cell_id]) {
      zone_cell_list[next_position] = cell_id;
      ++next_position;
    }
  }

  // One row per zone cell, one column per function
  using Interpolator = InterpolateItemArray<DataType(TinyVector<Dimension>)>;
  Table<const DataType> values =
    Interpolator::template interpolate<ItemType::cell>(m_function_id_list, mesh_data.xj(), zone_cell_list);

  CellArray<DataType> cell_array{connectivity, m_function_id_list.size()};
  cell_array.fill(0);

  // Scatter the interpolated rows back to their cells
  parallel_for(
    zone_cell_list.size(), PUGS_LAMBDA(const size_t k) {
      const CellId cell_id = zone_cell_list[k];
      for (size_t j = 0; j < values.numberOfColumns(); ++j) {
        cell_array[cell_id][j] = values(k, j);
      }
    });

  return DiscreteFunctionP0Vector<Dimension, DataType>(p_mesh, cell_array);
}

template <size_t Dimension, typename DataType>
DiscreteFunctionVariant
DiscreteFunctionVectorInterpoler::_interpolateGlobally() const
{
  // Interpolates each function of m_function_id_list at the points
  // mesh_data.xj() of every cell of the mesh. Must only be called when
  // no zone is defined.
  Assert(m_zone_list.size() == 0, "invalid call when zones are defined");

  using MeshType         = Mesh<Connectivity<Dimension>>;
  std::shared_ptr p_mesh = std::dynamic_pointer_cast<const MeshType>(m_mesh);

  MeshData<Dimension>& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);

  using Interpolator = InterpolateItemArray<DataType(TinyVector<Dimension>)>;
  auto cell_values   = Interpolator::template interpolate<ItemType::cell>(m_function_id_list, mesh_data.xj());

  return DiscreteFunctionP0Vector<Dimension, DataType>(p_mesh, cell_values);
}

template <size_t Dimension, typename DataType>
DiscreteFunctionVariant
DiscreteFunctionVectorInterpoler::_interpolate() const
{
  // Restrict the interpolation to the zones when some are given,
  // otherwise interpolate on the whole mesh.
  if (m_zone_list.size() > 0) {
    return this->_interpolateOnZoneList<Dimension, DataType>();
  }
  return this->_interpolateGlobally<Dimension, DataType>();
}

template <size_t Dimension>
DiscreteFunctionVariant
DiscreteFunctionVectorInterpoler::_interpolate() const
{
  // Each component function must have a scalar value type (bool, unsigned
  // int, int or double); all components are then stored as double.
  for (const auto& function_id : m_function_id_list) {
    const auto& descriptor      = function_id.descriptor();
    const auto& image_type_node = descriptor.domainMappingNode().children[1];
    Assert(image_type_node->m_data_type == ASTNodeDataType::typename_t);
    const ASTNodeDataType& value_type = image_type_node->m_data_type.contentType();

    bool is_scalar_value_type = false;
    switch (value_type) {
    case ASTNodeDataType::bool_t:
    case ASTNodeDataType::unsigned_int_t:
    case ASTNodeDataType::int_t:
    case ASTNodeDataType::double_t: {
      is_scalar_value_type = true;
      break;
    }
    default: {
      break;
    }
    }

    if (not is_scalar_value_type) {
      std::ostringstream error_message;
      error_message << "vector functions require scalar value type.\n"
                    << "Invalid interpolation value type: " << rang::fgB::red << dataTypeName(value_type)
                    << rang::style::reset;
      throw NormalError(error_message.str());
    }
  }

  return this->_interpolate<Dimension, double>();
}

DiscreteFunctionVariant
DiscreteFunctionVectorInterpoler::interpolate() const
{
  // This interpoler only builds P0Vector discrete functions
  const bool is_p0_vector = (m_discrete_function_descriptor->type() == DiscreteFunctionType::P0Vector);
  if (not is_p0_vector) {
    throw NormalError("invalid discrete function type for vector interpolation");
  }

  // Dispatch to the dimension-specific implementation
  const auto dimension = m_mesh->dimension();
  if (dimension == 1) {
    return this->_interpolate<1>();
  } else if (dimension == 2) {
    return this->_interpolate<2>();
  } else if (dimension == 3) {
    return this->_interpolate<3>();
  }
  // LCOV_EXCL_START
  throw UnexpectedError("invalid dimension");
  // LCOV_EXCL_STOP
}
