#include <output/GnuplotWriter1D.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/ItemValue.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <utils/Messenger.hpp>
#include <utils/PugsTraits.hpp>
#include <utils/RevisionInfo.hpp>

#include <utils/Demangle.hpp>

#include <algorithm>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <numeric>
#include <sstream>
#include <vector>

std::string
GnuplotWriter1D::_getDateAndVersionComment() const
{
  // Builds the commented header placed at the top of every gnuplot file:
  // generation date followed by the pugs version and git revision state.
  const std::time_t current_time = std::time(nullptr);

  std::ostringstream comment;
  // the string returned by std::ctime already ends with a newline
  comment << "#  Generated by pugs: " << std::ctime(&current_time)
          << "#  version: " << RevisionInfo::version() << '\n'
          << "#  tag:  " << RevisionInfo::gitTag() << '\n'
          << "#  HEAD: " << RevisionInfo::gitHead() << '\n'
          << "#  hash: " << RevisionInfo::gitHash() << " (" << (RevisionInfo::gitIsClean() ? "clean" : "dirty")
          << ")\n"
          << '\n';

  return comment.str();
}

std::string
GnuplotWriter1D::_getFilename() const
{
  // File name pattern: <base>.<NNNN>.gnu, where NNNN is the current size of
  // m_saved_times written on at least 4 digits, left-padded with zeros.
  std::ostringstream filename;
  filename << m_base_filename << '.' << std::setfill('0') << std::setw(4) << m_saved_times.size() << ".gnu";
  return filename.str();
}

// Item-kind probes used while visiting the variants of an OutputNamedItemValueSet.
// Overload resolution on the argument type alone decides the answer, so each
// overload simply returns a constant.

template <typename DataType>
bool
GnuplotWriter1D::_is_cell_value(const CellValue<const DataType>&) const
{
  return true;   // argument is a CellValue
}

template <typename DataType>
bool
GnuplotWriter1D::_is_node_value(const CellValue<const DataType>&) const
{
  return false;   // argument is a CellValue, not a NodeValue
}

template <typename DataType>
bool
GnuplotWriter1D::_is_cell_value(const NodeValue<const DataType>&) const
{
  return false;   // argument is a NodeValue, not a CellValue
}

template <typename DataType>
bool
GnuplotWriter1D::_is_node_value(const NodeValue<const DataType>&) const
{
  return true;   // argument is a NodeValue
}

void
GnuplotWriter1D::_writePreamble(const OutputNamedItemValueSet& output_named_item_value_set, std::ostream& fout) const
{
  // Writes a comment line listing the 1-based column index of every field.
  // Column 1 is the coordinate x. Each value then takes one column per scalar
  // component: scalars use "name", tiny vectors "name[j]" and tiny matrices
  // "name(i,j)" (row by row).
  fout << "# list of data\n"
       << "# 1:x";

  uint64_t column_index = 2;
  for (const auto& named_item_value : output_named_item_value_set) {
    const std::string& name = named_item_value.first;
    std::visit(
      [&](auto&& item_value) {
        using DataType = std::decay_t<typename std::decay_t<decltype(item_value)>::data_type>;

        if constexpr (std::is_arithmetic_v<DataType>) {
          fout << ' ' << column_index++ << ':' << name;
        } else if constexpr (is_tiny_vector_v<DataType>) {
          const size_t nb_components = DataType{}.dimension();
          for (size_t component = 0; component < nb_components; ++component) {
            fout << ' ' << column_index++ << ':' << name << '[' << component << ']';
          }
        } else if constexpr (is_tiny_matrix_v<DataType>) {
          const size_t nb_rows    = DataType{}.nbRows();
          const size_t nb_columns = DataType{}.nbColumns();
          for (size_t row = 0; row < nb_rows; ++row) {
            for (size_t col = 0; col < nb_columns; ++col) {
              fout << ' ' << column_index++ << ':' << name << '(' << row << ',' << col << ')';
            }
          }
        } else {
          throw UnexpectedError("invalid data type");
        }
      },
      named_item_value.second);
  }
  fout << "\n\n";
}

template <typename DataType, ItemType item_type>
size_t
GnuplotWriter1D::_itemValueNbRow(const ItemValue<DataType, item_type>&) const
{
  // Number of output columns (scalar components) needed to print one value of
  // type DataType: 1 for scalars, the vector size for tiny vectors, and
  // dimension() for tiny matrices. Any other type is an internal error.
  using ValueType = std::decay_t<DataType>;

  if constexpr (std::is_arithmetic_v<ValueType>) {
    return 1;
  } else if constexpr (is_tiny_vector_v<ValueType>) {
    return ValueType::Dimension;
  } else if constexpr (is_tiny_matrix_v<ValueType>) {
    return ValueType{}.dimension();
  } else {
    throw UnexpectedError("invalid data type for cell value output: " + demangle<DataType>());
  }
}

// Writes the values attached to items of type item_type as a gnuplot table.
//
// Each rank builds one line per item it owns. A line holds the item
// x-coordinate (cell center for cells, node position for nodes), followed by
// the scalar components of every value of output_named_item_value_set, in the
// set's iteration order (tiny vectors component by component, tiny matrices
// row by row). With several ranks, the lines are gathered on rank 0
// (parallel::gatherVariable). On rank 0 they are sorted by increasing x and
// printed to fout; the other ranks write nothing.
//
// Values whose item type differs from item_type still reserve their columns
// but are not filled. Callers are expected to pass a set holding a single
// item type.
template <typename MeshType, ItemType item_type>
void
GnuplotWriter1D::_writeItemValues(const std::shared_ptr<const MeshType>& mesh,
                                  const OutputNamedItemValueSet& output_named_item_value_set,
                                  std::ostream& fout) const
{
  using ItemId = ItemIdT<item_type>;

  // one column for the coordinate plus one per scalar component of every value
  const size_t number_of_columns = [&] {
    size_t nb_columns = 1;
    for (const auto& [name, item_value] : output_named_item_value_set) {
      std::visit([&](auto&& value) { nb_columns += _itemValueNbRow(value); }, item_value);
    }
    return nb_columns;
  }();

  auto is_owned = mesh->connectivity().template isOwned<item_type>();

  // each rank contributes its owned items only; sequential runs skip the count
  // and take every item of the mesh
  const size_t number_of_owned_lines = [&]() {
    if (parallel::size() > 1) {
      size_t number_of_owned_items = 0;
      for (ItemId item_id = 0; item_id < mesh->template numberOf<item_type>(); ++item_id) {
        if (is_owned[item_id]) {
          ++number_of_owned_items;
        }
      }

      return number_of_owned_items;
    } else {
      return mesh->template numberOf<item_type>();
    }
  }();

  // row-major table: line `index` occupies [index * number_of_columns, (index + 1) * number_of_columns)
  Array<double> values{number_of_columns * number_of_owned_lines};

  // column 0: x-coordinate of each owned item
  if constexpr (item_type == ItemType::cell) {
    auto& mesh_data         = MeshDataManager::instance().getMeshData(*mesh);
    const auto& cell_center = mesh_data.xj();

    size_t index = 0;
    for (ItemId item_id = 0; item_id < mesh->template numberOf<item_type>(); ++item_id) {
      if (is_owned[item_id]) {
        values[number_of_columns * index++] = cell_center[item_id][0];
      }
    }
  } else if constexpr (item_type == ItemType::node) {
    const auto& node_position = mesh->xr();

    size_t index = 0;
    for (ItemId item_id = 0; item_id < mesh->template numberOf<item_type>(); ++item_id) {
      if (is_owned[item_id]) {
        values[number_of_columns * index++] = node_position[item_id][0];
      }
    }
  } else {
    throw UnexpectedError("invalid item type");
  }

  // remaining columns: components of each value, in the set's iteration order
  size_t column_number = 1;
  for (const auto& [name, output_item_value] : output_named_item_value_set) {
    std::visit(
      [&](auto&& item_value) {
        using ItemValueT = std::decay_t<decltype(item_value)>;
        if constexpr (ItemValueT::item_t == item_type) {
          using DataT  = std::decay_t<typename ItemValueT::data_type>;
          size_t index = 0;
          for (ItemId item_id = 0; item_id < item_value.size(); ++item_id) {
            if (is_owned[item_id]) {
              if constexpr (std::is_arithmetic_v<DataT>) {
                values[number_of_columns * index + column_number] = item_value[item_id];
              } else if constexpr (is_tiny_vector_v<DataT>) {
                const size_t k = number_of_columns * index + column_number;
                for (size_t j = 0; j < DataT::Dimension; ++j) {
                  values[k + j] = item_value[item_id][j];
                }
              } else if constexpr (is_tiny_matrix_v<DataT>) {
                size_t k = number_of_columns * index + column_number;
                for (size_t i = 0; i < DataT{}.nbRows(); ++i) {
                  for (size_t j = 0; j < DataT{}.nbColumns(); ++j) {
                    values[k++] = item_value[item_id](i, j);
                  }
                }
              }
              ++index;
            }
          }
        }
        // columns are reserved even when the item type does not match
        column_number += _itemValueNbRow(item_value);
      },
      output_item_value);
  }

  if (parallel::size() > 1) {
    values = parallel::gatherVariable(values, 0);
  }

  if (parallel::rank() == 0) {
    Assert(values.size() % number_of_columns == 0);

    // sort line indices (not the table itself) by increasing x-coordinate
    std::vector<size_t> line_numbers(values.size() / number_of_columns);
    std::iota(line_numbers.begin(), line_numbers.end(), size_t{0});

    std::sort(line_numbers.begin(), line_numbers.end(),
              [&](size_t i, size_t j) { return values[i * number_of_columns] < values[j * number_of_columns]; });

    for (auto i_line : line_numbers) {
      fout << values[i_line * number_of_columns];
      for (size_t j = 1; j < number_of_columns; ++j) {
        fout << ' ' << values[i_line * number_of_columns + j];
      }
      fout << '\n';
    }
  }
}

template <typename MeshType>
void
GnuplotWriter1D::_write(const std::shared_ptr<const MeshType>& mesh,
                        const OutputNamedItemValueSet& output_named_item_value_set,
                        double time) const
{
  // A single gnuplot table is written, so all values must live on the same
  // kind of item: detect which kinds are present in one pass over the set.
  bool has_cell_value = false;
  bool has_node_value = false;
  for (const auto& [name, item_value_variant] : output_named_item_value_set) {
    std::visit(
      [&](auto&& item_value) {
        has_cell_value |= this->_is_cell_value(item_value);
        has_node_value |= this->_is_node_value(item_value);
      },
      item_value_variant);
  }

  if (has_cell_value and has_node_value) {
    throw NormalError("cannot store both node and cell values in a gnuplot file");
  }

  // only rank 0 opens and writes the file; other ranks pass an unopened stream
  std::ofstream fout;
  if (parallel::rank() == 0) {
    fout.open(_getFilename());
    fout.precision(15);
    fout.setf(std::ios_base::scientific);

    fout << _getDateAndVersionComment() << "# time = " << time << "\n\n";
    _writePreamble(output_named_item_value_set, fout);
  }

  // without any cell value (including an empty set) the table is node based
  if (not has_cell_value) {
    this->_writeItemValues<MeshType, ItemType::node>(mesh, output_named_item_value_set, fout);
  } else {
    this->_writeItemValues<MeshType, ItemType::cell>(mesh, output_named_item_value_set, fout);
  }
}

// This writer does not support writing a bare mesh; reaching this function is
// reported as an internal error.
void
GnuplotWriter1D::writeMesh(const std::shared_ptr<const IMesh>& /* mesh */) const
{
  throw UnexpectedError("This function should not be called");
}

void
GnuplotWriter1D::write(const std::vector<std::shared_ptr<const NamedDiscreteFunction>>& named_discrete_function_list,
                       double time) const
{
  // Writes the given discrete functions at the given time. Only 1D meshes are
  // supported; 2D gets a dedicated hint toward gnuplot_writer, and any other
  // dimension gets a generic error.
  std::shared_ptr mesh = this->_getMesh(named_discrete_function_list);

  OutputNamedItemValueSet output_named_item_value_set = this->_getOutputNamedItemValueSet(named_discrete_function_list);

  const auto dimension = mesh->dimension();

  if (dimension == 1) {
    this->_write(std::dynamic_pointer_cast<const Mesh<Connectivity<1>>>(mesh), output_named_item_value_set, time);
    return;
  }

  if (dimension == 2) {
    std::ostringstream error_message;
    error_message << "gnuplot_1d_writer is not available in dimension " << std::to_string(dimension) << '\n'
                  << rang::style::bold << "note:" << rang::style::reset << " one can use " << rang::fgB::blue
                  << "gnuplot_writer" << rang::style::reset << " in dimension 2";
    throw NormalError(error_message.str());
  }

  throw NormalError("gnuplot format is not available in dimension " + std::to_string(dimension));
}