#ifndef VTK_WRITER_HPP
#define VTK_WRITER_HPP

#include <algebra/TinyVector.hpp>
#include <mesh/CellType.hpp>
#include <mesh/IConnectivity.hpp>
#include <mesh/ItemValue.hpp>
#include <output/OutputNamedItemValueSet.hpp>
#include <utils/Exceptions.hpp>
#include <utils/Messenger.hpp>

#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <type_traits>
#include <variant>

class VTKWriter
{
 private:
  const std::string m_base_filename;
  unsigned int m_file_number;
  double m_last_time;
  const double m_time_period;

  std::string
  _getFilenamePVTU()
  {
    std::ostringstream sout;
    sout << m_base_filename;
    sout << '.' << std::setfill('0') << std::setw(4) << m_file_number << ".pvtu";
    return sout.str();
  }
  std::string
  _getFilenameVTU(int rank_number) const
  {
    std::ostringstream sout;
    sout << m_base_filename;
    if (parallel::size() > 1) {
      sout << '-' << std::setfill('0') << std::setw(4) << rank_number;
    }
    sout << '.' << std::setfill('0') << std::setw(4) << m_file_number << ".vtu";
    return sout.str();
  }

  template <typename DataType>
  void
  _write_node_pvtu(std::ofstream& os, const std::string& name, const NodeValue<const DataType>&)
  {
    os << "<PDataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\"/>\n";
  }

  template <size_t N, typename DataType>
  void
  _write_node_pvtu(std::ofstream& os, const std::string& name, const NodeValue<const TinyVector<N, DataType>>&)
  {
    os << "<PDataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\" NumberOfComponents=\"" << N
       << "\"/>\n";
  }

  template <typename DataType>
  void
  _write_node_pvtu(std::ofstream&, const std::string&, const CellValue<const DataType>&)
  {}

  template <typename DataType>
  void
  _write_cell_pvtu(std::ofstream& os, const std::string& name, const CellValue<const DataType>&)
  {
    os << "<PDataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\"/>\n";
  }

  template <size_t N, typename DataType>
  void
  _write_cell_pvtu(std::ofstream& os, const std::string& name, const CellValue<const TinyVector<N, DataType>>&)
  {
    os << "<PDataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\" NumberOfComponents=\"" << N
       << "\"/>\n";
  }

  template <typename DataType>
  void
  _write_cell_pvtu(std::ofstream&, const std::string&, const NodeValue<const DataType>&)
  {}

  template <typename DataType>
  struct VTKType
  {
    inline const static std::string name = [] {
      static_assert(std::is_arithmetic_v<DataType>, "invalid data type");

      if constexpr (std::is_integral_v<DataType>) {
        if constexpr (std::is_unsigned_v<DataType>) {
          return "UInt" + std::to_string(sizeof(DataType) * 8);
        } else {
          return "UInt" + std::to_string(sizeof(DataType) * 8);
        }
      } else if constexpr (std::is_floating_point_v<DataType>) {
        return "Float" + std::to_string(sizeof(DataType) * 8);
      }
    }();
  };

  template <typename DataType>
  void
  _write_array(std::ofstream& os, const std::string& name, const Array<DataType>& item_value)
  {
    os << "<DataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\">\n";
    for (typename Array<DataType>::index_type i = 0; i < item_value.size(); ++i) {
      // The following '+' enforces integer output for char types
      os << +item_value[i] << ' ';
    }
    os << "\n</DataArray>\n";
  }

  template <size_t N, typename DataType>
  void
  _write_array(std::ofstream& os, const std::string& name, const Array<TinyVector<N, DataType>>& item_value)
  {
    os << "<DataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\" NumberOfComponents=\"" << N
       << "\">\n";
    for (typename Array<DataType>::index_type i = 0; i < item_value.size(); ++i) {
      for (size_t j = 0; j < N; ++j) {
        // The following '+' enforces integer output for char types
        os << +item_value[i][j] << ' ';
      }
    }
    os << "\n</DataArray>\n";
  }

  template <typename DataType>
  void
  _write_node_value(std::ofstream& os, const std::string& name, const NodeValue<const DataType>& item_value)
  {
    os << "<DataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\">\n";
    for (NodeId i = 0; i < item_value.size(); ++i) {
      // The following '+' enforces integer output for char types
      os << +item_value[i] << ' ';
    }
    os << "\n</DataArray>\n";
  }

  template <size_t N, typename DataType>
  void
  _write_node_value(std::ofstream& os,
                    const std::string& name,
                    const NodeValue<const TinyVector<N, DataType>>& item_value)
  {
    os << "<DataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\" NumberOfComponents=\"" << N
       << "\">\n";
    for (NodeId i = 0; i < item_value.size(); ++i) {
      for (size_t j = 0; j < N; ++j) {
        // The following '+' enforces integer output for char types
        os << +item_value[i][j] << ' ';
      }
    }
    os << "\n</DataArray>\n";
  }

  template <typename DataType>
  void
  _write_node_value(std::ofstream&, const std::string&, const CellValue<const DataType>&)
  {}

  template <typename DataType>
  void
  _write_cell_value(std::ofstream& os, const std::string& name, const CellValue<const DataType>& item_value)
  {
    os << "<DataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\">\n";
    for (CellId i = 0; i < item_value.size(); ++i) {
      // The following '+' enforces integer output for char types
      os << +item_value[i] << ' ';
    }
    os << "\n</DataArray>\n";
  }

  template <size_t N, typename DataType>
  void
  _write_cell_value(std::ofstream& os,
                    const std::string& name,
                    const CellValue<const TinyVector<N, DataType>>& item_value)
  {
    os << "<DataArray type=\"" << VTKType<DataType>::name << "\" Name=\"" << name << "\" NumberOfComponents=\"" << N
       << "\">\n";
    for (CellId i = 0; i < item_value.size(); ++i) {
      for (size_t j = 0; j < N; ++j) {
        // The following '+' enforces integer output for char types
        os << +item_value[i][j] << ' ';
      }
    }
    os << "\n</DataArray>\n";
  }

  template <typename DataType>
  void
  _write_cell_value(std::ofstream&, const std::string&, const NodeValue<const DataType>&)
  {}

 public:
  template <typename MeshType>
  void
  write(const MeshType& mesh,
        const OutputNamedItemValueSet& output_named_item_value_set,
        double time,
        bool forced_output = false)
  {
    if (time == m_last_time)
      return;   // output already performed
    if ((time - m_last_time >= m_time_period) or forced_output) {
      m_last_time = time;
    } else {
      return;
    }

    if (parallel::rank() == 0) {   // write PVTK file
      std::ofstream fout(_getFilenamePVTU());
      fout << "<?xml version=\"1.0\"?>\n";
      fout << "<VTKFile type=\"PUnstructuredGrid\">\n";
      fout << "<PUnstructuredGrid GhostLevel=\"0\">\n";

      fout << "<PPoints>\n";
      fout << "<PDataArray Name=\"Positions\" NumberOfComponents=\"3\" "
              "type=\"Float64\"/>\n";
      fout << "</PPoints>\n";

      fout << "<PCells>\n";
      fout << "<PDataArray type=\"Int32\" Name=\"connectivity\" "
              "NumberOfComponents=\"1\"/>\n";
      fout << "<PDataArray type=\"UInt32\" Name=\"offsets\" "
              "NumberOfComponents=\"1\"/>\n";
      fout << "<PDataArray type=\"Int8\" Name=\"types\" "
              "NumberOfComponents=\"1\"/>\n";
      for (const auto& [name, item_value_variant] : output_named_item_value_set) {
        std::visit([&, name = name](auto&& item_value) { return this->_write_cell_pvtu(fout, name, item_value); },
                   item_value_variant);
      }
      fout << "</PCells>\n";

      fout << "<PPointData>\n";
      for (const auto& [name, item_value_variant] : output_named_item_value_set) {
        std::visit([&, name = name](auto&& item_value) { return this->_write_node_pvtu(fout, name, item_value); },
                   item_value_variant);
      }
      fout << "</PPointData>\n";

      fout << "<PCellData>\n";
      for (const auto& [name, item_value_variant] : output_named_item_value_set) {
        std::visit([&, name = name](auto&& item_value) { return this->_write_cell_pvtu(fout, name, item_value); },
                   item_value_variant);
      }
      fout << "</PCellData>\n";

      for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
        fout << "<Piece Source=\"" << _getFilenameVTU(i_rank) << "\"/>\n";
      }
      fout << "</PUnstructuredGrid>\n";
      fout << "</VTKFile>\n";
    }

    {   // write VTK files
      std::ofstream fout(_getFilenameVTU(parallel::rank()));
      fout << "<?xml version=\"1.0\"?>\n";
      fout << "<VTKFile type=\"UnstructuredGrid\">\n";
      fout << "<UnstructuredGrid>\n";
      fout << "<Piece NumberOfPoints=\"" << mesh.numberOfNodes() << "\" NumberOfCells=\"" << mesh.numberOfCells()
           << "\">\n";
      fout << "<CellData>\n";
      for (const auto& [name, item_value_variant] : output_named_item_value_set) {
        std::visit([&, name = name](auto&& item_value) { return this->_write_cell_value(fout, name, item_value); },
                   item_value_variant);
      }
      fout << "</CellData>\n";
      fout << "<PointData>\n";
      for (const auto& [name, item_value_variant] : output_named_item_value_set) {
        std::visit([&, name = name](auto&& item_value) { return this->_write_node_value(fout, name, item_value); },
                   item_value_variant);
      }
      fout << "</PointData>\n";
      fout << "<Points>\n";
      {
        using Rd                      = TinyVector<MeshType::Dimension>;
        const NodeValue<const Rd>& xr = mesh.xr();
        Array<TinyVector<3>> positions(mesh.numberOfNodes());
        parallel_for(
          mesh.numberOfNodes(), PUGS_LAMBDA(NodeId r) {
            for (unsigned short i = 0; i < MeshType::Dimension; ++i) {
              positions[r][i] = xr[r][i];
            }
            for (unsigned short i = MeshType::Dimension; i < 3; ++i) {
              positions[r][i] = 0;
            }
          });
        _write_array(fout, "Positions", positions);
      }
      fout << "</Points>\n";

      fout << "<Cells>\n";
      {
        const auto& cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();

        _write_array(fout, "connectivity", cell_to_node_matrix.entries());
      }

      {
        const auto& cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
        Array<unsigned int> offsets(mesh.numberOfCells());
        unsigned int offset = 0;
        for (CellId j = 0; j < mesh.numberOfCells(); ++j) {
          const auto& cell_nodes = cell_to_node_matrix[j];
          offset += cell_nodes.size();
          offsets[j] = offset;
        }
        _write_array(fout, "offsets", offsets);
      }

      {
        Array<int8_t> types(mesh.numberOfCells());
        const auto& cell_type = mesh.connectivity().cellType();
        parallel_for(
          mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
            switch (cell_type[j]) {
            case CellType::Line: {
              types[j] = 3;
              break;
            }
            case CellType::Triangle: {
              types[j] = 5;
              break;
            }
            case CellType::Quadrangle: {
              types[j] = 9;
              break;
            }
            case CellType::Tetrahedron: {
              types[j] = 10;
              break;
            }
            case CellType::Pyramid: {
              types[j] = 14;
              break;
            }
            case CellType::Prism: {
              types[j] = 13;
              break;
            }
            case CellType::Hexahedron: {
              types[j] = 12;
              break;
            }
            default: {
              std::ostringstream os;
              os << __FILE__ << ':' << __LINE__ << ": unknown cell type";
              throw UnexpectedError(os.str());
            }
            }
          });
        _write_array(fout, "types", types);
      }

      fout << "</Cells>\n";
      fout << "</Piece>\n";
      fout << "</UnstructuredGrid>\n";
      fout << "</VTKFile>\n";
    }
    m_file_number++;
  }

  VTKWriter(const std::string& base_filename, const double time_period)
    : m_base_filename(base_filename),
      m_file_number(0),
      m_last_time(-std::numeric_limits<double>::max()),
      m_time_period(time_period)
  {}

  ~VTKWriter() = default;
};

#endif   // VTK_WRITER_HPP
