#include <scheme/DiscreteFunctionIntegrator.hpp>

#include <language/utils/IntegrateCellValue.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <utils/Exceptions.hpp>

template <size_t Dimension, typename DataType, typename ValueType>
std::shared_ptr<IDiscreteFunction>
DiscreteFunctionIntegrator::_integrate() const
{
  using MeshType       = Mesh<Connectivity<Dimension>>;
  std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(m_mesh);

  if constexpr (std::is_same_v<DataType, ValueType>) {
    return std::make_shared<
      DiscreteFunctionP0<Dimension, ValueType>>(mesh,
                                                IntegrateCellValue<DataType(TinyVector<Dimension>)>::template integrate<
                                                  MeshType>(m_function_id, *m_quadrature_descriptor, *mesh));
  } else {
    static_assert(std::is_convertible_v<DataType, ValueType>);

    CellValue<DataType> cell_data =
      IntegrateCellValue<DataType(TinyVector<Dimension>)>::template integrate<MeshType>(m_function_id,
                                                                                        *m_quadrature_descriptor,
                                                                                        *mesh);

    CellValue<ValueType> cell_value{mesh->connectivity()};

    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(const CellId cell_id) { cell_value[cell_id] = cell_data[cell_id]; });

    return std::make_shared<DiscreteFunctionP0<Dimension, ValueType>>(mesh, cell_value);
  }
}

// Selects the C++ value type matching the codomain of the integrated function
// and forwards to _integrate<Dimension, DataType, ValueType>().
//
// The codomain type is the content type of the second child of the function
// descriptor's domain mapping node (asserted to be a typename node):
//  - bool, unsigned int and int results are integrated in their own type and
//    stored as double in the resulting discrete function;
//  - double, TinyVector<1..3> and square TinyMatrix<1..3> results are stored
//    with their own type.
// Throws UnexpectedError for any other type or size.
template <size_t Dimension>
std::shared_ptr<IDiscreteFunction>
DiscreteFunctionIntegrator::_integrate() const
{
  const auto& function_descriptor = m_function_id.descriptor();
  Assert(function_descriptor.domainMappingNode().children[1]->m_data_type == ASTNodeDataType::typename_t);

  const ASTNodeDataType& data_type = function_descriptor.domainMappingNode().children[1]->m_data_type.contentType();

  switch (data_type) {
  case ASTNodeDataType::bool_t: {
    return this->_integrate<Dimension, bool, double>();
  }
  case ASTNodeDataType::unsigned_int_t: {
    return this->_integrate<Dimension, uint64_t, double>();
  }
  case ASTNodeDataType::int_t: {
    return this->_integrate<Dimension, int64_t, double>();
  }
  case ASTNodeDataType::double_t: {
    return this->_integrate<Dimension, double>();
  }
  case ASTNodeDataType::vector_t: {
    switch (data_type.dimension()) {
    case 1: {
      return this->_integrate<Dimension, TinyVector<1>>();
    }
    case 2: {
      return this->_integrate<Dimension, TinyVector<2>>();
    }
    case 3: {
      return this->_integrate<Dimension, TinyVector<3>>();
    }
      // LCOV_EXCL_START
    default: {
      std::ostringstream os;
      os << "invalid vector dimension " << rang::fgB::red << data_type.dimension() << rang::style::reset;

      throw UnexpectedError(os.str());
    }
      // LCOV_EXCL_STOP
    }
  }
  case ASTNodeDataType::matrix_t: {
    // Only square matrices are handled; the size is taken from the column count.
    Assert(data_type.numberOfColumns() == data_type.numberOfRows(), "undefined matrix type");
    switch (data_type.numberOfColumns()) {
    case 1: {
      return this->_integrate<Dimension, TinyMatrix<1>>();
    }
    case 2: {
      return this->_integrate<Dimension, TinyMatrix<2>>();
    }
    case 3: {
      return this->_integrate<Dimension, TinyMatrix<3>>();
    }
      // LCOV_EXCL_START
    default: {
      // Report the actual matrix sizes (the vector dimension is meaningless here).
      std::ostringstream os;
      os << "invalid matrix dimensions " << rang::fgB::red << data_type.numberOfRows() << 'x'
         << data_type.numberOfColumns() << rang::style::reset;

      throw UnexpectedError(os.str());
    }
      // LCOV_EXCL_STOP
    }
  }
    // LCOV_EXCL_START
  default: {
    std::ostringstream os;
    os << "invalid integrate value type: " << rang::fgB::red << dataTypeName(data_type) << rang::style::reset;

    throw UnexpectedError(os.str());
  }
    // LCOV_EXCL_STOP
  }
}

// Integrates the stored function on m_mesh and returns the resulting P0
// discrete function. Dispatches on the mesh dimension (1, 2 or 3) and throws
// UnexpectedError for any other dimension.
std::shared_ptr<IDiscreteFunction>
DiscreteFunctionIntegrator::integrate() const
{
  switch (m_mesh->dimension()) {
  case 1: {
    return this->_integrate<1>();
  }
  case 2: {
    return this->_integrate<2>();
  }
  case 3: {
    return this->_integrate<3>();
  }
    // LCOV_EXCL_START
  default: {
    throw UnexpectedError("invalid dimension");
  }
    // LCOV_EXCL_STOP
  }
}
