#include <output/WriterBase.hpp>

#include <mesh/IMesh.hpp>
#include <output/OutputNamedItemValueSet.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/IDiscreteFunction.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <utils/Exceptions.hpp>

#include <cmath>

template <size_t Dimension, typename DataType>
void
WriterBase::registerDiscreteFunctionP0(const std::string& name,
                                       const IDiscreteFunction& i_discrete_function,
                                       OutputNamedItemValueSet& named_item_value_set)
{
  const DiscreteFunctionP0<Dimension, DataType>& discrete_function =
    dynamic_cast<const DiscreteFunctionP0<Dimension, DataType>&>(i_discrete_function);
  named_item_value_set.add(NamedItemValue{name, discrete_function.cellValues()});
}

/// Dispatches the registration of a P0 discrete function on its runtime value
/// type to the fully typed overload registerDiscreteFunctionP0<Dimension, DataType>.
///
/// Handled value types: bool, uint64_t, int64_t, double, TinyVector<d, double>
/// and square TinyMatrix<d, double>, for d in {1, 2, 3}.
///
/// @param name                  name under which the cell values are stored
/// @param i_discrete_function   type-erased function whose dataType() drives the dispatch
/// @param named_item_value_set  set that receives the named cell values
/// @throws UnexpectedError for any other data type, vector dimension or matrix shape
template <size_t Dimension>
void
WriterBase::registerDiscreteFunctionP0(const std::string& name,
                                       const IDiscreteFunction& i_discrete_function,
                                       OutputNamedItemValueSet& named_item_value_set)
{
  const ASTNodeDataType& data_type = i_discrete_function.dataType();
  switch (data_type) {
  case ASTNodeDataType::bool_t: {
    registerDiscreteFunctionP0<Dimension, bool>(name, i_discrete_function, named_item_value_set);
    break;
  }
  case ASTNodeDataType::unsigned_int_t: {
    registerDiscreteFunctionP0<Dimension, uint64_t>(name, i_discrete_function, named_item_value_set);
    break;
  }
  case ASTNodeDataType::int_t: {
    registerDiscreteFunctionP0<Dimension, int64_t>(name, i_discrete_function, named_item_value_set);
    break;
  }
  case ASTNodeDataType::double_t: {
    registerDiscreteFunctionP0<Dimension, double>(name, i_discrete_function, named_item_value_set);
    break;
  }
  case ASTNodeDataType::vector_t: {
    switch (data_type.dimension()) {
    case 1: {
      registerDiscreteFunctionP0<Dimension, TinyVector<1, double>>(name, i_discrete_function, named_item_value_set);
      break;
    }
    case 2: {
      registerDiscreteFunctionP0<Dimension, TinyVector<2, double>>(name, i_discrete_function, named_item_value_set);
      break;
    }
    case 3: {
      registerDiscreteFunctionP0<Dimension, TinyVector<3, double>>(name, i_discrete_function, named_item_value_set);
      break;
    }
    default: {
      throw UnexpectedError("invalid vector dimension");
    }
    }
    break;
  }
  case ASTNodeDataType::matrix_t: {
    // Only square matrices are instantiated below. This check runs in every build:
    // if it were only an Assert and that is compiled out, a non-square type would be
    // dispatched on nbRows() alone and fail later in the dynamic_cast with
    // std::bad_cast instead of a meaningful error.
    if (data_type.nbRows() != data_type.nbColumns()) {
      throw UnexpectedError("invalid matrix dimensions");
    }
    switch (data_type.nbRows()) {
    case 1: {
      registerDiscreteFunctionP0<Dimension, TinyMatrix<1, double>>(name, i_discrete_function, named_item_value_set);
      break;
    }
    case 2: {
      registerDiscreteFunctionP0<Dimension, TinyMatrix<2, double>>(name, i_discrete_function, named_item_value_set);
      break;
    }
    case 3: {
      registerDiscreteFunctionP0<Dimension, TinyMatrix<3, double>>(name, i_discrete_function, named_item_value_set);
      break;
    }
    default: {
      throw UnexpectedError("invalid matrix dimension");
    }
    }
    break;
  }
  default: {
    throw UnexpectedError("invalid data type " + dataTypeName(data_type));
  }
  }
}

void
WriterBase::registerDiscreteFunctionP0(const NamedDiscreteFunction& named_discrete_function,
                                       OutputNamedItemValueSet& named_item_value_set)
{
  const IDiscreteFunction& i_discrete_function = *named_discrete_function.discreteFunction();
  const std::string& name                      = named_discrete_function.name();
  switch (i_discrete_function.mesh()->dimension()) {
  case 1: {
    registerDiscreteFunctionP0<1>(name, i_discrete_function, named_item_value_set);
    break;
  }
  case 2: {
    registerDiscreteFunctionP0<2>(name, i_discrete_function, named_item_value_set);
    break;
  }
  case 3: {
    registerDiscreteFunctionP0<3>(name, i_discrete_function, named_item_value_set);
    break;
  }
  default: {
    throw UnexpectedError("invalid mesh dimension");
  }
  }
}

/// Returns the mesh shared by all the discrete functions of the list.
///
/// Meshes are compared by identity (shared_ptr pointer equality), not by content.
///
/// @param named_discrete_function_list  functions meant to be written in the same file
/// @return the mesh of the first function, which every other function must share
/// @throws UnexpectedError if the list is empty
/// @throws NormalError if two functions are defined on different meshes
std::shared_ptr<const IMesh>
WriterBase::_getMesh(
  const std::vector<std::shared_ptr<const NamedDiscreteFunction>>& named_discrete_function_list) const
{
  // Checked in every build rather than through Assert alone: if the assertion is
  // compiled out, indexing element 0 of an empty list is undefined behavior.
  if (named_discrete_function_list.empty()) {
    throw UnexpectedError("discrete function list is empty");
  }

  const NamedDiscreteFunction& first_function = *named_discrete_function_list[0];

  std::shared_ptr mesh = first_function.discreteFunction()->mesh();
  for (size_t i = 1; i < named_discrete_function_list.size(); ++i) {
    const NamedDiscreteFunction& other_function = *named_discrete_function_list[i];
    if (mesh != other_function.discreteFunction()->mesh()) {
      std::ostringstream error_msg;
      error_msg << "discrete functions must be defined on the same mesh!\n"
                << rang::fgB::yellow << "note:" << rang::style::reset << " cannot write " << rang::fgB::blue
                << first_function.name() << rang::style::reset << " and " << rang::fgB::blue
                << other_function.name() << rang::style::reset << " in the same file.";
      throw NormalError(error_msg.str());
    }
  }

  return mesh;
}

/// Collects the cell values of every discrete function of the list into a set of
/// named item values, ready to be written.
///
/// @param named_discrete_function_list  functions to output, each with its name
/// @return the set of named values, one entry registered per function
/// @throws NormalError if a function is not of type DiscreteFunctionType::P0
OutputNamedItemValueSet
WriterBase::_getOutputNamedItemValueSet(
  const std::vector<std::shared_ptr<const NamedDiscreteFunction>>& named_discrete_function_list) const
{
  OutputNamedItemValueSet output_value_set;

  for (const auto& p_named_function : named_discrete_function_list) {
    const auto function_type = p_named_function->discreteFunction()->descriptor().type();

    // only P0 functions can be written for now
    if (function_type != DiscreteFunctionType::P0) {
      std::ostringstream error_msg;
      error_msg << "the type of discrete function of " << rang::fgB::blue << p_named_function->name()
                << rang::style::reset << " is not supported";
      throw NormalError(error_msg.str());
    }

    WriterBase::registerDiscreteFunctionP0(*p_named_function, output_value_set);
  }

  return output_value_set;
}

/// Returns the time of the most recent recorded output.
///
/// @return the last saved time, or -std::numeric_limits<double>::max() as a
///         sentinel when no output has been recorded yet
double
WriterBase::getLastTime() const
{
  if (m_saved_times.empty()) {
    return -std::numeric_limits<double>::max();
  }
  return m_saved_times.back();
}

/// Builds a writer from a base file name (presumably the prefix of the produced
/// files) and the period between two outputs triggered by writeIfNeeded().
/// The first periodic output is due at time 0.
WriterBase::WriterBase(const std::string& base_filename, const double& time_period)
  : m_base_filename(base_filename), m_time_period(time_period), m_next_time(0)
{}

/// Writes the discrete functions if a periodic output is due at @p time.
///
/// An output is due when @p time reaches the next scheduled output time. Before
/// writing, the schedule is advanced to the first grid time (a whole number of
/// periods ahead) strictly after @p time. Without this, a jump over several
/// periods would leave the schedule behind @p time and trigger an output at each
/// of the following calls until it caught up.
///
/// Nothing is done if an output was already performed at exactly @p time.
///
/// @param named_discrete_function_list  functions to write
/// @param time                          current time
void
WriterBase::writeIfNeeded(const std::vector<std::shared_ptr<const NamedDiscreteFunction>>& named_discrete_function_list,
                          double time) const
{
  const double last_time = getLastTime();
  if (time == last_time)
    return;   // output already performed

  if (time >= m_next_time) {
    m_next_time += m_time_period;
    if ((m_time_period > 0) and (m_next_time <= time)) {
      // Skip every period already elapsed. This is computed in one step (rather
      // than looping) so that it cannot spin when m_time_period is negligible
      // compared to m_next_time.
      const double elapsed_periods = std::floor((time - m_next_time) / m_time_period) + 1;
      m_next_time += elapsed_periods * m_time_period;
    }
    this->write(named_discrete_function_list, time);
    m_saved_times.push_back(time);
  }
}

/// Writes the discrete functions at @p time regardless of the periodic schedule,
/// and records @p time as the last output time.
///
/// Nothing is done if an output was already performed at exactly @p time.
/// The next scheduled time used by writeIfNeeded() is not modified.
///
/// @param named_discrete_function_list  functions to write
/// @param time                          current time
void
WriterBase::writeForced(const std::vector<std::shared_ptr<const NamedDiscreteFunction>>& named_discrete_function_list,
                        double time) const
{
  const bool already_written = (time == getLastTime());
  if (!already_written) {
    this->write(named_discrete_function_list, time);
    m_saved_times.push_back(time);
  }
}
