#include <output/WriterBase.hpp>

#include <mesh/IMesh.hpp>
#include <mesh/ItemValueVariant.hpp>
#include <output/NamedDiscreteFunction.hpp>
#include <output/NamedItemValueVariant.hpp>
#include <output/OutputNamedItemValueSet.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionP0Vector.hpp>
#include <scheme/IDiscreteFunction.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <utils/Exceptions.hpp>

template <typename DiscreteFunctionType>
void
WriterBase::_registerDiscreteFunction(const std::string& name,
                                      const DiscreteFunctionType& discrete_function,
                                      OutputNamedItemDataSet& named_item_data_set)
{
  // Adds the per-cell data of a concretely typed discrete function to the output set:
  // "value" functions expose cellValues(), every other kind exposes cellArrays().
  constexpr bool stores_cell_values =
    (DiscreteFunctionType::handled_data_type == IDiscreteFunction::HandledItemDataType::value);

  if constexpr (stores_cell_values) {
    named_item_data_set.add(NamedItemData{name, discrete_function.cellValues()});
  } else {
    named_item_data_set.add(NamedItemData{name, discrete_function.cellArrays()});
  }
}

// Resolves the run-time data type of `i_discrete_function` into a concrete
// DiscreteFunctionType<Dimension, DataType> and forwards to the typed overload.
//
// Mapping used:
//   bool_t         -> bool
//   unsigned_int_t -> uint64_t
//   int_t          -> int64_t
//   double_t       -> double
//   vector_t       -> TinyVector<d, double>      (d = 1, 2, 3)
//   matrix_t       -> TinyMatrix<d, d, double>   (square, d = 1, 2, 3)
//
// Each dynamic_cast is to a reference, so a mismatch between dataType() and the
// actual dynamic type raises std::bad_cast.
// Throws UnexpectedError for unsupported data types / dimensions, and for
// vector or matrix data when DiscreteFunctionType's handled_data_type is `vector`.
template <size_t Dimension, template <size_t DimensionT, typename DataTypeT> typename DiscreteFunctionType>
void
WriterBase::_registerDiscreteFunction(const std::string& name,
                                      const IDiscreteFunction& i_discrete_function,
                                      OutputNamedItemDataSet& named_item_data_set)
{
  const ASTNodeDataType& data_type = i_discrete_function.dataType();
  switch (data_type) {
  case ASTNodeDataType::bool_t: {
    _registerDiscreteFunction(name, dynamic_cast<const DiscreteFunctionType<Dimension, bool>&>(i_discrete_function),
                              named_item_data_set);
    break;
  }
  case ASTNodeDataType::unsigned_int_t: {
    _registerDiscreteFunction(name, dynamic_cast<const DiscreteFunctionType<Dimension, uint64_t>&>(i_discrete_function),
                              named_item_data_set);
    break;
  }
  case ASTNodeDataType::int_t: {
    _registerDiscreteFunction(name, dynamic_cast<const DiscreteFunctionType<Dimension, int64_t>&>(i_discrete_function),
                              named_item_data_set);
    break;
  }
  case ASTNodeDataType::double_t: {
    _registerDiscreteFunction(name, dynamic_cast<const DiscreteFunctionType<Dimension, double>&>(i_discrete_function),
                              named_item_data_set);
    break;
  }
  case ASTNodeDataType::vector_t: {
    // Function types handling "vector" item data (probed through their
    // <Dimension, double> instantiation) reject TinyVector data; `if constexpr`
    // also keeps the TinyVector instantiations below from being compiled for them.
    if constexpr (DiscreteFunctionType<Dimension, double>::handled_data_type ==
                  IDiscreteFunction::HandledItemDataType::vector) {
      throw UnexpectedError("invalid data type for vector data");
    } else {
      switch (data_type.dimension()) {
      case 1: {
        _registerDiscreteFunction(name,
                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyVector<1, double>>&>(
                                    i_discrete_function),
                                  named_item_data_set);
        break;
      }
      case 2: {
        _registerDiscreteFunction(name,
                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyVector<2, double>>&>(
                                    i_discrete_function),
                                  named_item_data_set);
        break;
      }
      case 3: {
        _registerDiscreteFunction(name,
                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyVector<3, double>>&>(
                                    i_discrete_function),
                                  named_item_data_set);
        break;
      }
      default: {
        throw UnexpectedError("invalid vector dimension");
      }
      }
    }
    break;
  }
  case ASTNodeDataType::matrix_t: {
    // Same restriction as for vector_t: "vector"-handling function types cannot
    // hold TinyMatrix data (the error text is shared with the vector branch).
    if constexpr (DiscreteFunctionType<Dimension, double>::handled_data_type ==
                  IDiscreteFunction::HandledItemDataType::vector) {
      throw UnexpectedError("invalid data type for vector data");
    } else {
      // Only square matrices are dispatched; the row count selects the instantiation
      Assert(data_type.numberOfRows() == data_type.numberOfColumns(), "invalid matrix dimensions");
      switch (data_type.numberOfRows()) {
      case 1: {
        _registerDiscreteFunction(name,
                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyMatrix<1, 1, double>>&>(
                                    i_discrete_function),
                                  named_item_data_set);
        break;
      }
      case 2: {
        _registerDiscreteFunction(name,
                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyMatrix<2, 2, double>>&>(
                                    i_discrete_function),
                                  named_item_data_set);
        break;
      }
      case 3: {
        _registerDiscreteFunction(name,
                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyMatrix<3, 3, double>>&>(
                                    i_discrete_function),
                                  named_item_data_set);
        break;
      }
      default: {
        throw UnexpectedError("invalid matrix dimension");
      }
      }
    }
    break;
  }
  default: {
    throw UnexpectedError("invalid data type " + dataTypeName(data_type));
  }
  }
}

template <template <size_t Dimension, typename DataType> typename DiscreteFunctionType>
void
WriterBase::_registerDiscreteFunction(const NamedDiscreteFunction& named_discrete_function,
                                      OutputNamedItemDataSet& named_item_data_set)
{
  const IDiscreteFunction& i_discrete_function = *named_discrete_function.discreteFunction();
  const std::string& name                      = named_discrete_function.name();
  switch (i_discrete_function.mesh()->dimension()) {
  case 1: {
    _registerDiscreteFunction<1, DiscreteFunctionType>(name, i_discrete_function, named_item_data_set);
    break;
  }
  case 2: {
    _registerDiscreteFunction<2, DiscreteFunctionType>(name, i_discrete_function, named_item_data_set);
    break;
  }
  case 3: {
    _registerDiscreteFunction<3, DiscreteFunctionType>(name, i_discrete_function, named_item_data_set);
    break;
  }
  default: {
    throw UnexpectedError("invalid mesh dimension");
  }
  }
}

void
WriterBase::_checkConnectivity(
  const std::shared_ptr<const IMesh>& mesh,
  const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const
{
  // Verifies that every item-value variable of the list is built on the very
  // connectivity object of `mesh`. Other kinds of data are not inspected here.
  // Throws NormalError naming the first offending variable.
  Assert(named_discrete_data_list.size() > 0);

  std::shared_ptr<const IConnectivity> mesh_connectivity;
  switch (mesh->dimension()) {
  case 1: {
    mesh_connectivity = dynamic_cast<const Mesh<Connectivity<1>>&>(*mesh).shared_connectivity();
    break;
  }
  case 2: {
    mesh_connectivity = dynamic_cast<const Mesh<Connectivity<2>>&>(*mesh).shared_connectivity();
    break;
  }
  case 3: {
    mesh_connectivity = dynamic_cast<const Mesh<Connectivity<3>>&>(*mesh).shared_connectivity();
    break;
  }
  default: {
    throw UnexpectedError("invalid dimension");
  }
  }

  for (const auto& named_discrete_data : named_discrete_data_list) {
    if (named_discrete_data->type() != INamedDiscreteData::Type::item_value) {
      continue;
    }

    const auto& named_item_value_variant = dynamic_cast<const NamedItemValueVariant&>(*named_discrete_data);

    std::visit(
      [&](auto&& item_value) {
        // pointer identity: the item value must share the mesh's connectivity object
        if (item_value.connectivity_ptr() == mesh_connectivity) {
          return;
        }
        std::ostringstream error_msg;
        error_msg << "The variable " << rang::fgB::yellow << named_item_value_variant.name() << rang::fg::reset
                  << " is not defined on the provided same connectivity as the mesh\n";
        throw NormalError(error_msg.str());
      },
      named_item_value_variant.itemValueVariant()->itemValue());
  }
}

void
WriterBase::_checkMesh(const std::shared_ptr<const IMesh>& mesh,
                       const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const
{
  // Verifies that every discrete-function variable of the list is defined on
  // exactly `mesh` (shared_ptr identity). Other kinds of data are skipped.
  // Throws NormalError naming the first offending variable.
  Assert(named_discrete_data_list.size() > 0);

  for (const auto& named_discrete_data : named_discrete_data_list) {
    if (named_discrete_data->type() != INamedDiscreteData::Type::discrete_function) {
      continue;
    }

    const auto& named_discrete_function = dynamic_cast<const NamedDiscreteFunction&>(*named_discrete_data);
    if (named_discrete_function.discreteFunction()->mesh() == mesh) {
      continue;
    }

    std::ostringstream error_msg;
    error_msg << "The variable " << rang::fgB::yellow << named_discrete_function.name() << rang::fg::reset
              << " is not defined on the provided mesh\n";
    throw NormalError(error_msg.str());
  }
}

std::shared_ptr<const IMesh>
WriterBase::_getMesh(const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const
{
  // Deduces the single mesh that all output quantities are defined on.
  // Discrete functions contribute both their mesh and its connectivity; item
  // values contribute only their connectivity.
  // Throws NormalError if no mesh is found, if several meshes are involved, or
  // if several connectivities are involved.
  Assert(named_discrete_data_list.size() > 0);

  // Each distinct mesh / connectivity is mapped to the name of the last
  // variable seen on it; the names are only used to build error messages.
  std::map<std::shared_ptr<const IMesh>, std::string> mesh_set;
  std::map<std::shared_ptr<const IConnectivity>, std::string> connectivity_set;

  for (const auto& named_discrete_data : named_discrete_data_list) {
    const auto data_kind = named_discrete_data->type();

    if (data_kind == INamedDiscreteData::Type::discrete_function) {
      const auto& named_discrete_function = dynamic_cast<const NamedDiscreteFunction&>(*named_discrete_data);
      const std::string& variable_name    = named_discrete_function.name();

      std::shared_ptr mesh = named_discrete_function.discreteFunction()->mesh();
      mesh_set[mesh]       = variable_name;

      std::shared_ptr<const IConnectivity> connectivity;
      switch (mesh->dimension()) {
      case 1: {
        connectivity = dynamic_cast<const Mesh<Connectivity<1>>&>(*mesh).shared_connectivity();
        break;
      }
      case 2: {
        connectivity = dynamic_cast<const Mesh<Connectivity<2>>&>(*mesh).shared_connectivity();
        break;
      }
      case 3: {
        connectivity = dynamic_cast<const Mesh<Connectivity<3>>&>(*mesh).shared_connectivity();
        break;
      }
      default: {
        throw UnexpectedError("invalid dimension");
      }
      }
      connectivity_set[connectivity] = variable_name;
    } else if (data_kind == INamedDiscreteData::Type::item_value) {
      const auto& named_item_value_variant = dynamic_cast<const NamedItemValueVariant&>(*named_discrete_data);

      std::visit(
        [&](auto&& item_value) { connectivity_set[item_value.connectivity_ptr()] = named_item_value_variant.name(); },
        named_item_value_variant.itemValueVariant()->itemValue());
    }
  }

  if (mesh_set.empty()) {
    throw NormalError("cannot find any mesh associated to output quantities");
  }

  if (mesh_set.size() > 1) {
    std::ostringstream error_msg;
    error_msg << "cannot save data using different " << rang::fgB::red << "meshes" << rang::fg::reset
              << " in the same file!\n";
    error_msg << rang::fgB::yellow << "note:" << rang::fg::reset
              << "the following variables are defined on different meshes:";
    for (const auto& mesh_and_name : mesh_set) {
      error_msg << "\n- " << mesh_and_name.second;
    }
    throw NormalError(error_msg.str());
  }

  if (connectivity_set.size() > 1) {
    std::ostringstream error_msg;
    error_msg << "cannot save data using different " << rang::fgB::red << "connectivities" << rang::fg::reset
              << " in the same file!\n";
    error_msg << rang::fgB::yellow << "note:" << rang::fg::reset
              << "the following variables are defined on different connectivities:";
    for (const auto& connectivity_and_name : connectivity_set) {
      error_msg << "\n- " << connectivity_and_name.second;
    }
    throw NormalError(error_msg.str());
  }

  return mesh_set.begin()->first;
}

OutputNamedItemDataSet
WriterBase::_getOutputNamedItemDataSet(
  const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const
{
  // Converts every named quantity of the list into an entry of the output set:
  //  - P0 / P0Vector discrete functions are registered through their cell data,
  //  - item values are added as they are.
  // Throws NormalError for an unsupported discrete function type and
  // UnexpectedError for an unknown kind of discrete data.
  OutputNamedItemDataSet output_data_set;

  for (const auto& p_named_data : named_discrete_data_list) {
    const auto data_kind = p_named_data->type();

    if (data_kind == INamedDiscreteData::Type::discrete_function) {
      const auto& named_discrete_function = dynamic_cast<const NamedDiscreteFunction&>(*p_named_data);
      const IDiscreteFunction& discrete_function = *named_discrete_function.discreteFunction();
      const auto function_kind                   = discrete_function.descriptor().type();

      if (function_kind == DiscreteFunctionType::P0) {
        WriterBase::_registerDiscreteFunction<DiscreteFunctionP0>(named_discrete_function, output_data_set);
      } else if (function_kind == DiscreteFunctionType::P0Vector) {
        WriterBase::_registerDiscreteFunction<DiscreteFunctionP0Vector>(named_discrete_function, output_data_set);
      } else {
        std::ostringstream error_msg;
        error_msg << "the type of discrete function of " << rang::fgB::blue << p_named_data->name()
                  << rang::style::reset << " is not supported";
        throw NormalError(error_msg.str());
      }
    } else if (data_kind == INamedDiscreteData::Type::item_value) {
      const auto& named_item_value_variant = dynamic_cast<const NamedItemValueVariant&>(*p_named_data);
      const std::string& value_name        = named_item_value_variant.name();
      const ItemValueVariant& value_variant = *named_item_value_variant.itemValueVariant();

      std::visit([&](auto&& item_value) { output_data_set.add(NamedItemData{value_name, item_value}); },
                 value_variant.itemValue());
    } else {
      throw UnexpectedError("invalid discrete data type");
    }
  }

  return output_data_set;
}

// Builds a periodic writer: the period manager is constructed from
// `time_period`. The write* entry points below throw for such a writer unless
// they are given a time value; writeMesh is rejected.
WriterBase::WriterBase(const std::string& base_filename, const double& time_period)
  : m_base_filename{base_filename}, m_period_manager(time_period)
{}

// Builds a writer without time period. m_period_manager is left default
// constructed (presumably an empty std::optional, since the entry points test
// has_value() on it - TODO confirm), so only the entry points without a time
// value are accepted.
WriterBase::WriterBase(const std::string& base_filename) : m_base_filename{base_filename} {}

void
WriterBase::write(const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const
{
  // Time-less output, only valid for writers built without a time period.
  // The mesh is deduced (and mesh/connectivity consistency validated) from the data.
  if (m_period_manager.has_value()) {
    throw NormalError("this writer requires time value");
  }

  const std::shared_ptr<const IMesh> deduced_mesh = _getMesh(named_discrete_data_list);
  this->_write(*deduced_mesh, named_discrete_data_list);
}

void
WriterBase::writeIfNeeded(const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list,
                          double time) const
{
  // Periodic output: writes only once `time` reaches the period manager's next
  // output time, and never twice for the time of the last save.
  if (!m_period_manager.has_value()) {
    throw NormalError("this writer does not allow time value");
  }

  if (time == m_period_manager->getLastTime()) {
    return;   // output already performed
  }

  // kept as `>=` (not `!(time < ...)`) so that comparisons with NaN never trigger a write
  if (time >= m_period_manager->nextTime()) {
    const std::shared_ptr<const IMesh> deduced_mesh = _getMesh(named_discrete_data_list);
    this->_writeAtTime(*deduced_mesh, named_discrete_data_list, time);
    m_period_manager->setSaveTime(time);
  }
}

void
WriterBase::writeForced(const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list,
                        double time) const
{
  // Unconditional timed output (ignores the period), except when `time` is
  // exactly the time of the last save.
  if (!m_period_manager.has_value()) {
    throw NormalError("this writer does not allow time value");
  }

  if (time == m_period_manager->getLastTime()) {
    return;   // output already performed
  }

  const std::shared_ptr<const IMesh> deduced_mesh = _getMesh(named_discrete_data_list);
  this->_writeAtTime(*deduced_mesh, named_discrete_data_list, time);
  m_period_manager->setSaveTime(time);
}

void
WriterBase::writeOnMesh(const std::shared_ptr<const IMesh>& mesh,
                        const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const
{
  // Time-less output on an explicitly given mesh, only valid for writers built
  // without a time period.
  if (m_period_manager.has_value()) {
    throw NormalError("this writer requires time value");
  }

  // discrete functions must live on `mesh`, item values on its connectivity
  _checkMesh(mesh, named_discrete_data_list);
  _checkConnectivity(mesh, named_discrete_data_list);

  this->_write(*mesh, named_discrete_data_list);
}

void
WriterBase::writeOnMeshIfNeeded(const std::shared_ptr<const IMesh>& mesh,
                                const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list,
                                double time) const
{
  // Periodic output on an explicitly given mesh. Like writeIfNeeded(), data are
  // written only once `time` has reached the period manager's next output time,
  // and never twice for the time of the last save. (Previously the period was
  // ignored, making this identical to writeOnMeshForced().)
  // Throws NormalError if the writer has no time period, or if the data are not
  // defined on `mesh` / its connectivity.
  if (m_period_manager.has_value()) {
    if (time == m_period_manager->getLastTime())
      return;   // output already performed

    if (time >= m_period_manager->nextTime()) {
      // validation happens only when a write actually occurs, as in writeIfNeeded()
      _checkMesh(mesh, named_discrete_data_list);
      _checkConnectivity(mesh, named_discrete_data_list);
      this->_writeAtTime(*mesh, named_discrete_data_list, time);
      m_period_manager->setSaveTime(time);
    }
  } else {
    throw NormalError("this writer does not allow time value");
  }
}

void
WriterBase::writeOnMeshForced(const std::shared_ptr<const IMesh>& mesh,
                              const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list,
                              double time) const
{
  // Unconditional timed output on an explicitly given mesh (ignores the period),
  // except when `time` is exactly the time of the last save.
  if (!m_period_manager.has_value()) {
    throw NormalError("this writer does not allow time value");
  }

  if (time == m_period_manager->getLastTime()) {
    return;   // output already performed
  }

  // discrete functions must live on `mesh`, item values on its connectivity
  _checkMesh(mesh, named_discrete_data_list);
  _checkConnectivity(mesh, named_discrete_data_list);

  this->_writeAtTime(*mesh, named_discrete_data_list, time);
  m_period_manager->setSaveTime(time);
}

void
WriterBase::writeMesh(const std::shared_ptr<const IMesh>& mesh) const
{
  // Convenience overload: forwards to the reference-taking writeMesh.
  const IMesh& mesh_reference = *mesh;
  this->writeMesh(mesh_reference);
}

void
WriterBase::writeMesh(const IMesh& mesh) const
{
  // Mesh-only output is restricted to writers built without a time period.
  if (m_period_manager.has_value()) {
    throw NormalError("write_mesh requires a writer without time period");
  }

  this->_writeMesh(mesh);
}
