#include <output/GnuplotWriterRaw.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/ItemValue.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshTraits.hpp>
#include <mesh/MeshVariant.hpp>
#include <utils/Filesystem.hpp>
#include <utils/Messenger.hpp>
#include <utils/PugsTraits.hpp>
#include <utils/RevisionInfo.hpp>
#include <utils/Stringify.hpp>

#include <utils/Demangle.hpp>

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <numeric>
#include <sstream>

// Number of output columns occupied by an ItemValue: one for arithmetic
// values, the dimension for tiny vectors and tiny matrices. Any other value
// type cannot be written and raises an UnexpectedError.
template <typename DataType, ItemType item_type>
size_t
GnuplotWriterRaw::_itemDataNbRow(const ItemValue<DataType, item_type>&) const
{
  using ValueT = std::decay_t<DataType>;

  if constexpr (is_tiny_matrix_v<ValueT>) {
    return ValueT{}.dimension();
  } else if constexpr (is_tiny_vector_v<ValueT>) {
    return ValueT::Dimension;
  } else if constexpr (std::is_arithmetic_v<DataType>) {
    return 1;
  } else {
    throw UnexpectedError("invalid data type for cell value output: " + demangle<DataType>());
  }
}

// Number of output columns occupied by an ItemArray: each entry of the
// per-item arrays is written in its own column.
template <typename DataType, ItemType item_type>
size_t
GnuplotWriterRaw::_itemDataNbRow(const ItemArray<DataType, item_type>& item_array) const
{
  const size_t array_size = item_array.sizeOfArrays();
  return array_size;
}

/**
 * Writes the item data of \p output_named_item_data_set that live on items of
 * type \p item_type into \p fout as a whitespace-separated table.
 *
 * Each owned item yields one line. Each item data contributes as many columns
 * as _itemDataNbRow reports: 1 for arithmetic values, the dimension for tiny
 * vectors and matrices, and the array size for item arrays. When several
 * processes are used, the local tables are gathered with
 * parallel::gatherVariable(values, 0). Only rank 0 writes, and it writes the
 * lines in increasing order of their first column.
 *
 * Nothing is written when the data set contributes no column at all.
 */
template <MeshConcept MeshType, ItemType item_type>
void
GnuplotWriterRaw::_writeItemDatas(const MeshType& mesh,
                                  const OutputNamedItemDataSet& output_named_item_data_set,
                                  std::ostream& fout) const
{
  using ItemId = ItemIdT<item_type>;

  // Total number of scalar columns over all item data.
  const size_t number_of_columns = [&] {
    size_t nb_columns = 0;
    for (const auto& [name, item_data] : output_named_item_data_set) {
      std::visit([&](auto&& value) { nb_columns += _itemDataNbRow(value); }, item_data);
    }
    return nb_columns;
  }();

  auto is_owned = mesh.connectivity().template isOwned<item_type>();

  // Only owned items are written. This is presumably so that items known by
  // several processes appear only once in the gathered table.
  const size_t number_of_owned_lines = [&]() -> size_t {
    if (parallel::size() > 1) {
      size_t number_of_owned_items = 0;
      for (ItemId item_id = 0; item_id < mesh.template numberOf<item_type>(); ++item_id) {
        if (is_owned[item_id]) {
          ++number_of_owned_items;
        }
      }

      return number_of_owned_items;
    } else {
      return mesh.template numberOf<item_type>();
    }
  }();

  // Row-major table: line `index` starts at values[number_of_columns * index].
  Array<double> values{number_of_columns * number_of_owned_lines};

  size_t column_number = 0;
  for (const auto& [name, output_item_data] : output_named_item_data_set) {
    std::visit(
      [&](auto&& item_data) {
        using ItemDataT = std::decay_t<decltype(item_data)>;
        if constexpr (ItemDataT::item_t == item_type) {
          if constexpr (is_item_value_v<ItemDataT>) {
            using DataT  = std::decay_t<typename ItemDataT::data_type>;
            size_t index = 0;
            for (ItemId item_id = 0; item_id < item_data.numberOfItems(); ++item_id) {
              if (is_owned[item_id]) {
                if constexpr (std::is_arithmetic_v<DataT>) {
                  values[number_of_columns * index + column_number] = item_data[item_id];
                } else if constexpr (is_tiny_vector_v<DataT>) {
                  const size_t k = number_of_columns * index + column_number;
                  for (size_t j = 0; j < DataT::Dimension; ++j) {
                    values[k + j] = item_data[item_id][j];
                  }
                } else if constexpr (is_tiny_matrix_v<DataT>) {
                  // Matrix entries are stored row by row in consecutive columns.
                  size_t k = number_of_columns * index + column_number;
                  for (size_t i = 0; i < DataT{}.numberOfRows(); ++i) {
                    for (size_t j = 0; j < DataT{}.numberOfColumns(); ++j) {
                      values[k++] = item_data[item_id](i, j);
                    }
                  }
                }
                ++index;
              }
            }
          } else {
            // Item arrays: each array entry goes into its own column.
            using DataT  = std::decay_t<typename ItemDataT::data_type>;
            size_t index = 0;
            for (ItemId item_id = 0; item_id < item_data.numberOfItems(); ++item_id) {
              if (is_owned[item_id]) {
                if constexpr (std::is_arithmetic_v<DataT>) {
                  const size_t k = number_of_columns * index + column_number;
                  for (size_t j = 0; j < item_data.sizeOfArrays(); ++j) {
                    values[k + j] = item_data[item_id][j];
                  }
                }
                ++index;
              }
            }
          }
        }
        column_number += _itemDataNbRow(item_data);
      },
      output_item_data);
  }

  // Collective call: every rank must take part, whatever number_of_columns is.
  if (parallel::size() > 1) {
    values = parallel::gatherVariable(values, 0);
  }

  // Guarding on number_of_columns avoids a division by zero when the data set
  // contributes no column (e.g. it is empty).
  if ((parallel::rank() == 0) and (number_of_columns > 0)) {
    Assert(values.size() % number_of_columns == 0);

    std::vector<size_t> line_numbers(values.size() / number_of_columns);
    std::iota(line_numbers.begin(), line_numbers.end(), size_t{0});

    // Lines are output in increasing order of their first column.
    std::sort(line_numbers.begin(), line_numbers.end(),
              [&](size_t i, size_t j) { return values[i * number_of_columns] < values[j * number_of_columns]; });

    for (auto i_line : line_numbers) {
      const size_t first = i_line * number_of_columns;
      fout << values[first];
      for (size_t j = 1; j < number_of_columns; ++j) {
        fout << ' ' << values[first + j];
      }
      fout << '\n';
    }
  }
}

template <MeshConcept MeshType>
void
GnuplotWriterRaw::_write(const MeshType& mesh,
                         const OutputNamedItemDataSet& output_named_item_data_set,
                         std::optional<double> time) const
{
  bool has_cell_data = false;
  for (const auto& [name, item_data_variant] : output_named_item_data_set) {
    has_cell_data |= std::visit([&](auto&& item_data) { return this->_is_cell_data(item_data); }, item_data_variant);
  }

  for (const auto& [name, item_data_variant] : output_named_item_data_set) {
    std::visit(
      [&, var_name = name](auto&& item_data) {
        if (this->_is_face_data(item_data)) {
          std::ostringstream error_msg;
          error_msg << "gnuplot_raw_writer does not support face data, cannot save variable \"" << rang::fgB::yellow
                    << var_name << rang::fg::reset << '"';
          throw NormalError(error_msg.str());
        }
      },
      item_data_variant);
  }

  for (const auto& [name, item_data_variant] : output_named_item_data_set) {
    std::visit(
      [&, var_name = name](auto&& item_data) {
        if (this->_is_edge_data(item_data)) {
          std::ostringstream error_msg;
          error_msg << "gnuplot_1d_writer does not support edge data, cannot save variable \"" << rang::fgB::yellow
                    << var_name << rang::fg::reset << '"';
          throw NormalError(error_msg.str());
        }
      },
      item_data_variant);
  }

  bool has_node_data = false;
  for (const auto& [name, item_data_variant] : output_named_item_data_set) {
    has_node_data |= std::visit([&](auto&& item_data) { return this->_is_node_data(item_data); }, item_data_variant);
  }

  if (has_cell_data and has_node_data) {
    throw NormalError("cannot store both node and cell data in the same gnuplot file");
  }

  createDirectoryIfNeeded(_getFilename());

  std::ofstream fout;

  if (parallel::rank() == 0) {
    fout.open(_getFilename());
    if (not fout) {
      std::ostringstream error_msg;
      error_msg << "cannot create file \"" << rang::fgB::yellow << _getFilename() << rang::fg::reset << '"';
      throw NormalError(error_msg.str());
    }

    fout.precision(15);
    fout.setf(std::ios_base::scientific);
    fout << _getDateAndVersionComment();

    if (time.has_value()) {
      fout << "# time = " << *time << "\n\n";
    }

    _writePreamble(MeshType::Dimension, output_named_item_data_set, false /*do not store coordinates*/, fout);
  }

  if (has_cell_data) {
    this->_writeItemDatas<MeshType, ItemType::cell>(mesh, output_named_item_data_set, fout);
  } else {   // has_node_value
    this->_writeItemDatas<MeshType, ItemType::node>(mesh, output_named_item_data_set, fout);
  }
}

// This writer only outputs item data. Mesh output is always rejected with a
// NormalError that points the user to gnuplot_writer.
void
GnuplotWriterRaw::_writeMesh(const MeshVariant&) const
{
  std::ostringstream message;
  message << "gnuplot_raw_writer does not write meshes\n";
  message << rang::style::bold << "note:" << rang::style::reset;
  message << " one can use " << rang::fgB::blue << "gnuplot_writer" << rang::fg::reset << " instead";
  throw NormalError(message.str());
}

void
GnuplotWriterRaw::_writeAtTime(const MeshVariant& mesh_v,
                               const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list,
                               double time) const
{
  OutputNamedItemDataSet output_named_item_data_set = this->_getOutputNamedItemDataSet(named_discrete_data_list);

  std::visit(
    [&](auto&& p_mesh) {
      using MeshType = mesh_type_t<decltype(p_mesh)>;
      if constexpr (MeshType::Dimension == 1) {
        this->_write(*p_mesh, output_named_item_data_set, time);
      } else if constexpr (MeshType::Dimension == 2) {
        std::ostringstream errorMsg;
        errorMsg << "gnuplot_raw_writer is not available in dimension " << stringify(MeshType::Dimension) << '\n'
                 << rang::style::bold << "note:" << rang::style::reset << " one can use " << rang::fgB::blue
                 << "gnuplot_writer" << rang::fg::reset << " in dimension 2";
        throw NormalError(errorMsg.str());
      } else {
        throw NormalError("gnuplot format is not available in dimension " + stringify(MeshType::Dimension));
      }
    },
    mesh_v.variant());
}

void
GnuplotWriterRaw::_write(const MeshVariant& mesh_v,
                         const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const
{
  OutputNamedItemDataSet output_named_item_data_set = this->_getOutputNamedItemDataSet(named_discrete_data_list);

  std::visit([&](auto&& p_mesh) { this->_write(*p_mesh, output_named_item_data_set, {}); }, mesh_v.variant());
}