#include <output/GnuplotWriter.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/ItemValue.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <utils/Filesystem.hpp>
#include <utils/Messenger.hpp>
#include <utils/PugsTraits.hpp>
#include <utils/RevisionInfo.hpp>
#include <utils/Stringify.hpp>

#include <utils/Demangle.hpp>

#include <fstream>
#include <iomanip>

/// Builds the comment header placed at the top of every gnuplot file: generation date followed by
/// the pugs version, git tag, HEAD and hash (with clean/dirty state), then a blank line.
std::string
GnuplotWriter::_getDateAndVersionComment() const
{
  const std::time_t generation_time = std::time(nullptr);

  std::ostringstream header;
  // std::ctime's result already ends with '\n', so no explicit newline follows it.
  header << "#  Generated by pugs: " << std::ctime(&generation_time)
         << "#  version: " << RevisionInfo::version() << '\n'
         << "#  tag:  " << RevisionInfo::gitTag() << '\n'
         << "#  HEAD: " << RevisionInfo::gitHead() << '\n'
         << "#  hash: " << RevisionInfo::gitHash() << " (" << (RevisionInfo::gitIsClean() ? "clean" : "dirty")
         << ")\n"
         << '\n';

  return header.str();
}

/// Returns the output file name: m_base_filename, then (only when a period manager is set) a dot
/// and the number of saved times zero-padded to at least 4 digits, then the ".gnu" extension.
std::string
GnuplotWriter::_getFilename() const
{
  std::ostringstream name_builder;
  name_builder << m_base_filename;

  if (m_period_manager.has_value()) {
    const auto saved_count = m_period_manager->nbSavedTimes();
    name_builder << '.' << std::setw(4) << std::setfill('0') << saved_count;
  }

  name_builder << ".gnu";
  return name_builder.str();
}

/// Writes the "# list of data" header that names every column of the data lines as
/// "<index>:<label>".
///
/// Column 1 is x (and column 2 is y when Dimension > 1). Each named item datum then contributes
/// one column per scalar entry:
/// - a single column for an arithmetic value;
/// - one column per component for a tiny vector;
/// - one column per (row, column) entry for a tiny matrix;
/// - one column per array entry for an item array of arithmetic type.
///
/// @throws UnexpectedError for any other data type or item-data kind.
template <size_t Dimension>
void
GnuplotWriter::_writePreamble(const OutputNamedItemDataSet& output_named_item_data_set, std::ostream& fout) const
{
  fout << "# list of data\n"
       << "# 1:x";
  if constexpr (Dimension > 1) {
    fout << " 2:y";
  }

  // Data columns are numbered right after the coordinate columns.
  uint64_t column_index = Dimension + 1;

  for (const auto& named_item_data : output_named_item_data_set) {
    const std::string& name       = named_item_data.first;
    const auto& item_data_variant = named_item_data.second;

    std::visit(
      [&](auto&& item_data) {
        using ItemDataType = std::decay_t<decltype(item_data)>;
        using DataType     = std::decay_t<typename ItemDataType::data_type>;

        if constexpr (is_item_value_v<ItemDataType>) {
          if constexpr (std::is_arithmetic_v<DataType>) {
            fout << ' ' << column_index++ << ':' << name;
          } else if constexpr (is_tiny_vector_v<DataType>) {
            const size_t nb_components = DataType{}.dimension();
            for (size_t component = 0; component < nb_components; ++component) {
              fout << ' ' << column_index++ << ':' << name << '[' << component << ']';
            }
          } else if constexpr (is_tiny_matrix_v<DataType>) {
            const size_t nb_rows    = DataType{}.numberOfRows();
            const size_t nb_columns = DataType{}.numberOfColumns();
            for (size_t row = 0; row < nb_rows; ++row) {
              for (size_t col = 0; col < nb_columns; ++col) {
                fout << ' ' << column_index++ << ':' << name << '(' << row << ',' << col << ')';
              }
            }
          } else {
            throw UnexpectedError("invalid data type");
          }
        } else if constexpr (is_item_array_v<ItemDataType>) {
          if constexpr (std::is_arithmetic_v<DataType>) {
            const size_t nb_entries = item_data.sizeOfArrays();
            for (size_t entry = 0; entry < nb_entries; ++entry) {
              fout << ' ' << column_index++ << ':' << name << '[' << entry << ']';
            }
          } else {
            throw UnexpectedError("invalid data type");
          }
        } else {
          throw UnexpectedError("invalid ItemData type");
        }
      },
      item_data_variant);
  }

  fout << "\n\n";
}

template <typename DataType>
void
GnuplotWriter::_writeCellData(const CellValue<DataType>& cell_value, CellId cell_id, std::ostream& fout) const
{
  const auto& value = cell_value[cell_id];
  if constexpr (std::is_arithmetic_v<DataType>) {
    fout << ' ' << value;
  } else if constexpr (is_tiny_vector_v<std::decay_t<DataType>>) {
    for (size_t i = 0; i < value.dimension(); ++i) {
      fout << ' ' << value[i];
    }
  } else if constexpr (is_tiny_matrix_v<std::decay_t<DataType>>) {
    for (size_t i = 0; i < value.numberOfRows(); ++i) {
      for (size_t j = 0; j < value.numberOfColumns(); ++j) {
        fout << ' ' << value(i, j);
      }
    }
  } else {
    throw UnexpectedError("invalid data type for cell value output: " + demangle<DataType>());
  }
}

/// Appends every entry of the array stored at cell_id to fout, each preceded by a single space.
/// Only arrays of arithmetic type are supported.
///
/// @throws UnexpectedError for non-arithmetic DataType. The message names a cell *array*; it
///         previously said "cell value", which was a copy-paste from the CellValue overload.
template <typename DataType>
void
GnuplotWriter::_writeCellData(const CellArray<DataType>& cell_array, CellId cell_id, std::ostream& fout) const
{
  const auto& array = cell_array[cell_id];
  if constexpr (std::is_arithmetic_v<DataType>) {
    for (size_t i = 0; i < array.size(); ++i) {
      fout << ' ' << array[i];
    }
  } else {
    throw UnexpectedError("invalid data type for cell array output: " + demangle<DataType>());
  }
}

/// Dispatches one sample of item_data to the matching writer.
///
/// Cell data is read at cell_id and node data at node_id. The id that does not match the data's
/// item type is ignored.
///
/// @throws UnexpectedError if item_data lives neither on cells nor on nodes.
template <typename ItemDataT>
void
GnuplotWriter::_writeData(const ItemDataT& item_data,
                          [[maybe_unused]] CellId cell_id,
                          [[maybe_unused]] NodeId node_id,
                          std::ostream& fout) const
{
  constexpr auto item_type = ItemDataT::item_t;

  if constexpr (item_type == ItemType::node) {
    this->_writeNodeData(item_data, node_id, fout);
  } else if constexpr (item_type == ItemType::cell) {
    this->_writeCellData(item_data, cell_id, fout);
  } else {
    throw UnexpectedError{"invalid item type"};
  }
}

template <typename MeshType>
void
GnuplotWriter::_writeDataAtNodes(const MeshType& mesh,
                                 const OutputNamedItemDataSet& output_named_item_data_set,
                                 std::ostream& fout) const
{
  if constexpr (MeshType::Dimension == 1) {
    auto cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
    auto cell_is_owned       = mesh.connectivity().cellIsOwned();

    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      if (cell_is_owned[cell_id]) {
        const auto& cell_nodes = cell_to_node_matrix[cell_id];
        for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
          const NodeId& node_id                     = cell_nodes[i_node];
          const TinyVector<MeshType::Dimension>& xr = mesh.xr()[node_id];
          fout << xr[0];
          for (const auto& [name, item_data_variant] : output_named_item_data_set) {
            std::visit([&](auto&& item_data) { _writeData(item_data, cell_id, node_id, fout); }, item_data_variant);
          }
          fout << '\n';
        }

        fout << "\n\n";
      }
    }
  } else if constexpr (MeshType::Dimension == 2) {
    auto cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
    auto cell_is_owned       = mesh.connectivity().cellIsOwned();

    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      if (cell_is_owned[cell_id]) {
        const auto& cell_nodes = cell_to_node_matrix[cell_id];
        for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
          const NodeId& node_id                     = cell_nodes[i_node];
          const TinyVector<MeshType::Dimension>& xr = mesh.xr()[node_id];
          fout << xr[0] << ' ' << xr[1];
          for (const auto& [name, item_data_variant] : output_named_item_data_set) {
            std::visit([&](auto&& item_data) { _writeData(item_data, cell_id, node_id, fout); }, item_data_variant);
          }
          fout << '\n';
        }
        const NodeId& node_id                     = cell_nodes[0];
        const TinyVector<MeshType::Dimension>& xr = mesh.xr()[node_id];
        fout << xr[0] << ' ' << xr[1];
        for (const auto& [name, item_data_variant] : output_named_item_data_set) {
          std::visit([&](auto&& item_data) { _writeData(item_data, cell_id, node_id, fout); }, item_data_variant);
        }
        fout << "\n\n\n";
      }
    }
  } else {
    throw UnexpectedError("invalid mesh dimension");
  }
}

/// Appends the value stored at node_id to fout, each field preceded by a single space.
///
/// A scalar produces one field. A tiny vector produces all its components. A tiny matrix produces
/// all its entries, rows outer and columns inner.
///
/// @throws UnexpectedError for any other DataType. The message now says "node value"; it
///         previously said "cell value", a copy-paste from the cell writer.
template <typename DataType>
void
GnuplotWriter::_writeNodeData(const NodeValue<DataType>& node_value, NodeId node_id, std::ostream& fout) const
{
  const auto& value = node_value[node_id];
  if constexpr (std::is_arithmetic_v<DataType>) {
    fout << ' ' << value;
  } else if constexpr (is_tiny_vector_v<std::decay_t<DataType>>) {
    for (size_t i = 0; i < value.dimension(); ++i) {
      fout << ' ' << value[i];
    }
  } else if constexpr (is_tiny_matrix_v<std::decay_t<DataType>>) {
    for (size_t i = 0; i < value.numberOfRows(); ++i) {
      for (size_t j = 0; j < value.numberOfColumns(); ++j) {
        fout << ' ' << value(i, j);
      }
    }
  } else {
    throw UnexpectedError("invalid data type for node value output: " + demangle<DataType>());
  }
}

/// Appends every entry of the array stored at node_id to fout, each preceded by a single space.
/// Only arrays of arithmetic type are supported.
///
/// @throws UnexpectedError for non-arithmetic DataType. The message now names a node *array*; it
///         previously said "cell value", a copy-paste from the cell writer.
template <typename DataType>
void
GnuplotWriter::_writeNodeData(const NodeArray<DataType>& node_array, NodeId node_id, std::ostream& fout) const
{
  const auto& array = node_array[node_id];
  if constexpr (std::is_arithmetic_v<DataType>) {
    for (size_t i = 0; i < array.size(); ++i) {
      fout << ' ' << array[i];
    }
  } else {
    throw UnexpectedError("invalid data type for node array output: " + demangle<DataType>());
  }
}

template <typename MeshType>
void
GnuplotWriter::_write(const MeshType& mesh,
                      const OutputNamedItemDataSet& output_named_item_data_set,
                      std::optional<double> time) const
{
  createDirectoryIfNeeded(_getFilename());

  if (parallel::rank() == 0) {
    std::ofstream fout{_getFilename()};
    if (not fout) {
      std::ostringstream error_msg;
      error_msg << "cannot create file \"" << rang::fgB::yellow << _getFilename() << rang::fg::reset << '"';
      throw NormalError(error_msg.str());
    }

    fout.precision(15);
    fout.setf(std::ios_base::scientific);

    fout << this->_getDateAndVersionComment();
    if (time.has_value()) {
      fout << "# time = " << *time << "\n\n";
    }
    this->_writePreamble<MeshType::Dimension>(output_named_item_data_set, fout);
  }

  for (const auto& [name, item_data_variant] : output_named_item_data_set) {
    std::visit(
      [&, name = name](auto&& item_data) {
        using ItemDataType = std::decay_t<decltype(item_data)>;
        if (ItemDataType::item_t == ItemType::face) {
          std::ostringstream error_msg;
          error_msg << "gnuplot_writer does not support face data, cannot save variable \"" << rang::fgB::yellow << name
                    << rang::fg::reset << '"';
          throw NormalError(error_msg.str());
        }
      },
      item_data_variant);
  }

  for (const auto& [name, item_data_variant] : output_named_item_data_set) {
    std::visit(
      [&, name = name](auto&& item_data) {
        using ItemDataType = std::decay_t<decltype(item_data)>;
        if (ItemDataType::item_t == ItemType::edge) {
          std::ostringstream error_msg;
          error_msg << "gnuplot_writer does not support edge data, cannot save variable \"" << rang::fgB::yellow << name
                    << rang::fg::reset << '"';
          throw NormalError(error_msg.str());
        }
      },
      item_data_variant);
  }

  for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
    if (i_rank == parallel::rank()) {
      std::ofstream fout(_getFilename(), std::ios_base::app);
      if (not fout) {
        std::ostringstream error_msg;
        error_msg << "cannot open file \"" << rang::fgB::yellow << _getFilename() << rang::fg::reset << '"';
        throw NormalError(error_msg.str());
      }

      fout.precision(15);
      fout.setf(std::ios_base::scientific);

      this->_writeDataAtNodes(mesh, output_named_item_data_set, fout);
    }
    parallel::barrier();
  }
}

void
GnuplotWriter::_writeAtTime(const IMesh& mesh,
                            const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list,
                            double time) const
{
  OutputNamedItemDataSet output_named_item_data_set = this->_getOutputNamedItemDataSet(named_discrete_data_list);

  switch (mesh.dimension()) {
  case 1: {
    this->_write(dynamic_cast<const Mesh<Connectivity<1>>&>(mesh), output_named_item_data_set, time);
    break;
  }
  case 2: {
    this->_write(dynamic_cast<const Mesh<Connectivity<2>>&>(mesh), output_named_item_data_set, time);
    break;
  }
  default: {
    throw NormalError("gnuplot format is not available in dimension " + stringify(mesh.dimension()));
  }
  }
}

void
GnuplotWriter::_writeMesh(const IMesh& mesh) const
{
  OutputNamedItemDataSet output_named_item_data_set{};

  switch (mesh.dimension()) {
  case 1: {
    this->_write(dynamic_cast<const Mesh<Connectivity<1>>&>(mesh), output_named_item_data_set, {});
    break;
  }
  case 2: {
    this->_write(dynamic_cast<const Mesh<Connectivity<2>>&>(mesh), output_named_item_data_set, {});
    break;
  }
  default: {
    throw NormalError("gnuplot format is not available in dimension " + stringify(mesh.dimension()));
  }
  }
}

void
GnuplotWriter::_write(const IMesh& mesh,
                      const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const
{
  OutputNamedItemDataSet output_named_item_data_set = this->_getOutputNamedItemDataSet(named_discrete_data_list);

  switch (mesh.dimension()) {
  case 1: {
    this->_write(dynamic_cast<const Mesh<Connectivity<1>>&>(mesh), output_named_item_data_set, {});
    break;
  }
  case 2: {
    this->_write(dynamic_cast<const Mesh<Connectivity<2>>&>(mesh), output_named_item_data_set, {});
    break;
  }
  default: {
    throw NormalError("gnuplot format is not available in dimension " + stringify(mesh.dimension()));
  }
  }
}
