#ifndef PARALLEL_CHECKER_HPP
#define PARALLEL_CHECKER_HPP

#include <mesh/Connectivity.hpp>
#include <mesh/ItemValue.hpp>
#include <utils/HDF5.hpp>
#include <utils/Messenger.hpp>

#include <utils/SourceLocation.hpp>

#include <fstream>
#include <utils/Demangle.hpp>

namespace parallel
{
#ifdef PUGS_HAS_HDF5

// Parallel checker entry point, declared ahead of ParallelChecker which
// befriends it. With a single process (parallel::size() == 1) the values of
// item_value are written as reference data, otherwise they are compared with
// that reference. The default argument for source_location lives on this
// declaration: a friend declaration cannot carry a default argument and the
// definition does not repeat it.
template <typename DataType, ItemType item_type, typename ConnectivityPtr>
void check(const ItemValue<DataType, item_type, ConnectivityPtr>& item_value,
           const std::string& name,
           const SourceLocation& source_location = SourceLocation{});

class ParallelChecker
{
 private:
  static ParallelChecker* m_instance;

  size_t m_tag = 0;

  std::string m_filename = "parallel_checker.h5";

  ParallelChecker() = default;

  void
  _printHeader(const std::string& name, const SourceLocation& source_location) const
  {
    std::cout << rang::fg::cyan << " | " << rang::fgB::cyan << "parallel checker" << rang::fg::cyan << " for \""
              << rang::fgB::magenta << name << rang::fg::cyan << "\" tag " << rang::fgB::blue << m_tag
              << rang::fg::reset << '\n';
    std::cout << rang::fg::cyan << " | from " << rang::fgB::blue << source_location.filename() << rang::fg::reset << ':'
              << rang::style::bold << source_location.line() << rang::style::reset << '\n';
  }

 public:
  static void create();
  static void destroy();

  static ParallelChecker&
  instance()
  {
    return *m_instance;
  }

  template <typename DataType, ItemType item_type, typename ConnectivityPtr>
  friend void check(const ItemValue<DataType, item_type, ConnectivityPtr>&, const std::string&, const SourceLocation&);

 private:
  template <typename DataType, ItemType item_type, typename ConnectivityPtr>
  void
  write(const ItemValue<DataType, item_type, ConnectivityPtr>& item_value,
        const std::string& name,
        const SourceLocation& source_location)
  {
    this->_printHeader(name, source_location);

    auto file_id = [&] {
      if (m_tag == 0) {
        return HDF5::create(m_filename);
      } else {
        return HDF5::openFileRW(m_filename);
      }
    }();

    auto values_group_id = HDF5::createOrOpenGroup(file_id, "/values");
    auto group_id        = HDF5::createOrOpenGroup(values_group_id, std::to_string(m_tag));

    HDF5::writeAttribute(group_id, "filename", std::string{source_location.filename()});
    HDF5::writeAttribute(group_id, "function", source_location.function());
    HDF5::writeAttribute(group_id, "line", static_cast<size_t>(source_location.line()));
    HDF5::writeAttribute(group_id, "name", name);

    std::shared_ptr<const IConnectivity> i_connectivity = item_value.connectivity_ptr();
    HDF5::writeAttribute(group_id, "dimension", static_cast<size_t>(i_connectivity->dimension()));
    HDF5::writeAttribute(group_id, "item_type", itemName(item_type));
    HDF5::writeAttribute(group_id, "data_type", demangle<DataType>());

    HDF5::writeArray(group_id, name, item_value.arrayView());

    switch (i_connectivity->dimension()) {
    case 1: {
      const Connectivity<1>& connectivity = dynamic_cast<const Connectivity<1>&>(*i_connectivity);
      HDF5::writeArray(group_id, "numbers", connectivity.number<item_type>().arrayView());
      break;
    }
    case 2: {
      const Connectivity<2>& connectivity = dynamic_cast<const Connectivity<2>&>(*i_connectivity);
      HDF5::writeArray(group_id, "numbers", connectivity.number<item_type>().arrayView());
      break;
    }
    case 3: {
      const Connectivity<3>& connectivity = dynamic_cast<const Connectivity<3>&>(*i_connectivity);
      HDF5::writeArray(group_id, "numbers", connectivity.number<item_type>().arrayView());
      break;
    }
    default: {
      throw UnexpectedError("unexpected connectivity dimension");
    }
    }

    ++m_tag;

    HDF5::close(values_group_id);
    HDF5::close(group_id);
    HDF5::close(file_id);

    std::cout << rang::fg::cyan << " | writing " << rang::fgB::green << "success" << rang::fg::reset << '\n';
  }

  template <typename DataType, ItemType item_type, typename ConnectivityPtr>
  void
  compare(const ItemValue<DataType, item_type, ConnectivityPtr>& item_value,
          const std::string& name,
          const SourceLocation& source_location)
  {
    this->_printHeader(name, source_location);

    auto file_id = HDF5::openFileRO(m_filename);

    auto values_group_id = HDF5::openGroup(file_id, "/values");
    auto group_id        = HDF5::openGroup(values_group_id, std::to_string(m_tag));

    const std::string reference_name          = HDF5::readAttribute<std::string>(group_id, "name");
    const std::string reference_file_name     = HDF5::readAttribute<std::string>(group_id, "filename");
    const std::string reference_function_name = HDF5::readAttribute<std::string>(group_id, "function");
    const size_t reference_line_number        = HDF5::readAttribute<size_t>(group_id, "line");
    const size_t reference_dimension          = HDF5::readAttribute<size_t>(group_id, "dimension");
    const std::string reference_item_type     = HDF5::readAttribute<std::string>(group_id, "item_type");
    const std::string reference_data_type     = HDF5::readAttribute<std::string>(group_id, "data_type");

    std::shared_ptr<const IConnectivity> i_connectivity = item_value.connectivity_ptr();

    bool is_comparable = true;
    if (i_connectivity->dimension() != reference_dimension) {
      std::cout << rang::fg::cyan << " | " << rang::fgB::red << "different support dimensions: reference ("
                << rang::fgB::yellow << reference_dimension << rang::fgB::red << ") / target (" << rang::fgB::yellow
                << i_connectivity->dimension() << rang::fg::reset << ")\n";
      is_comparable = false;
    }
    if (itemName(item_type) != reference_item_type) {
      std::cout << rang::fg::cyan << " | " << rang::fgB::red << "different item types: reference (" << rang::fgB::yellow
                << reference_item_type << rang::fgB::red << ") / target (" << rang::fgB::yellow << itemName(item_type)
                << rang::fg::reset << ")\n";
      is_comparable = false;
    }
    if (demangle<DataType>() != reference_data_type) {
      std::cout << rang::fg::cyan << " | " << rang::fgB::red << "different data types: reference (" << rang::fgB::yellow
                << reference_data_type << rang::fgB::red << ") / target (" << rang::fgB::yellow << demangle<DataType>()
                << rang::fg::reset << ")\n";
      is_comparable = false;
    }
    if (name != reference_name) {
      // Just warn for different labels (maybe useful for some kind of
      // debugging...)
      std::cout << rang::fg::cyan << " | " << rang::fgB::magenta << "different names: reference (" << rang::fgB::yellow
                << reference_name << rang::fgB::magenta << ") / target (" << rang::fgB::yellow << name
                << rang::fg::reset << ")\n";
      std::cout << rang::fg::cyan << " | " << rang::fgB::magenta << "reference from " << rang::fgB::blue
                << reference_file_name << rang::fg::reset << ':' << rang::style::bold << reference_line_number
                << rang::style::reset << '\n';
      std::cout << rang::fg::cyan << " | " << rang::fgB::magenta << "reference function " << rang::fgB::blue
                << reference_function_name << rang::fg::reset << '\n';
      std::cout << rang::fg::cyan << " | " << rang::fgB::magenta << "target function " << rang::fgB::blue
                << source_location.function() << rang::fg::reset << '\n';
    }

    if (not parallel::allReduceAnd(is_comparable)) {
      throw NormalError("cannot compare data");
    }

    Array<const int> reference_item_numbers = HDF5::readArray<int>(group_id, "numbers");
    Array<const DataType> reference_item_value =
      HDF5::readArray<std::remove_const_t<DataType> >(group_id, reference_name);

    Array<const int> item_numbers = [&] {
      switch (i_connectivity->dimension()) {
      case 1: {
        const Connectivity<1>& connectivity = dynamic_cast<const Connectivity<1>&>(*i_connectivity);
        return connectivity.number<item_type>().arrayView();
      }
      case 2: {
        const Connectivity<2>& connectivity = dynamic_cast<const Connectivity<2>&>(*i_connectivity);
        return connectivity.number<item_type>().arrayView();
      }
      case 3: {
        const Connectivity<3>& connectivity = dynamic_cast<const Connectivity<3>&>(*i_connectivity);
        return connectivity.number<item_type>().arrayView();
      }
      default: {
        throw UnexpectedError("unexpected connectivity dimension");
      }
      }
    }();

    using ItemId = ItemIdT<item_type>;

    std::unordered_map<int, ItemId> item_number_to_item_id_map;

    for (ItemId item_id = 0; item_id < item_numbers.size(); ++item_id) {
      const auto& [iterator, success] =
        item_number_to_item_id_map.insert(std::make_pair(item_numbers[item_id], item_id));

      if (not success) {
        throw UnexpectedError("item numbers have duplicate values");
      }
    }

    Assert(item_number_to_item_id_map.size() == item_numbers.size());

    Array<int> index_in_reference(item_numbers.size());
    index_in_reference.fill(-1);
    for (size_t i = 0; i < reference_item_numbers.size(); ++i) {
      const auto& i_number_to_item_id = item_number_to_item_id_map.find(reference_item_numbers[i]);
      if (i_number_to_item_id != item_number_to_item_id_map.end()) {
        index_in_reference[i_number_to_item_id->second] = i;
      }
    }

    if (parallel::allReduceMin(min(index_in_reference)) < 0) {
      throw NormalError("some item numbers are not defined in reference");
    }

    Array<const int> owner = [&] {
      switch (i_connectivity->dimension()) {
      case 1: {
        const Connectivity<1>& connectivity = dynamic_cast<const Connectivity<1>&>(*i_connectivity);
        return connectivity.owner<item_type>().arrayView();
      }
      case 2: {
        const Connectivity<2>& connectivity = dynamic_cast<const Connectivity<2>&>(*i_connectivity);
        return connectivity.owner<item_type>().arrayView();
      }
      case 3: {
        const Connectivity<3>& connectivity = dynamic_cast<const Connectivity<3>&>(*i_connectivity);
        return connectivity.owner<item_type>().arrayView();
      }
      default: {
        throw UnexpectedError("unexpected connectivity dimension");
      }
      }
    }();

    bool has_own_differences = false;
    bool is_same             = true;

    for (ItemId item_id = 0; item_id < item_value.numberOfItems(); ++item_id) {
      if (reference_item_value[index_in_reference[item_id]] != item_value[item_id]) {
        is_same = false;
        if (static_cast<size_t>(owner[item_id]) == parallel::rank()) {
          has_own_differences = true;
        }
      }
    }

    is_same             = parallel::allReduceAnd(is_same);
    has_own_differences = parallel::allReduceOr(has_own_differences);

    if (is_same) {
      std::cout << rang::fg::cyan << " | compare: " << rang::fgB::green << "success" << rang::fg::reset << '\n';
    } else {
      if (has_own_differences) {
        std::cout << rang::fg::cyan << " | compare: " << rang::fgB::red << "failed!" << rang::fg::reset;
      } else {
        std::cout << rang::fg::cyan << " | compare: " << rang::fgB::yellow << "not synchronized" << rang::fg::reset;
      }
      std::cout << rang::fg::cyan << " [see \"" << rang::fgB::blue << "parallel_differences_" << m_tag << "_*"
                << rang::fg::cyan << "\" files for details]" << rang::fg::reset << '\n';

      {
        std::ofstream fout(std::string{"parallel_differences_"} + stringify(m_tag) + std::string{"_"} +
                           stringify(parallel::rank()));

        fout.precision(15);
        for (ItemId item_id = 0; item_id < item_value.numberOfItems(); ++item_id) {
          if (reference_item_value[index_in_reference[item_id]] != item_value[item_id]) {
            const bool is_own_difference = (parallel::rank() == static_cast<size_t>(owner[item_id]));
            if (is_own_difference) {
              fout << rang::fgB::red << "[ own ]" << rang::fg::reset;
            } else {
              fout << rang::fgB::yellow << "[ghost]" << rang::fg::reset;
            }
            fout << " rank=" << parallel::rank() << " owner=" << owner[item_id] << " item_id=" << item_id
                 << " number=" << item_numbers[item_id]
                 << " reference=" << reference_item_value[index_in_reference[item_id]]
                 << " target=" << item_value[item_id]
                 << " difference=" << reference_item_value[index_in_reference[item_id]] - item_value[item_id] << '\n';
            if (static_cast<size_t>(owner[item_id]) == parallel::rank()) {
              has_own_differences = true;
            }
          }
        }
      }

      if (parallel::allReduceAnd(has_own_differences)) {
        throw NormalError("calculations differ!");
      }
    }

    HDF5::close(values_group_id);
    HDF5::close(group_id);
    HDF5::close(file_id);
    ++m_tag;
  }
};

template <typename DataType, ItemType item_type, typename ConnectivityPtr>
void
check(const ItemValue<DataType, item_type, ConnectivityPtr>& item_value,
      const std::string& name,
      const SourceLocation& source_location)
{
  const bool write_mode = (parallel::size() == 1);

  std::cout << '\n';
  if (write_mode) {
    ParallelChecker::instance().write(item_value, name, source_location);
  } else {
    ParallelChecker::instance().compare(item_value, name, source_location);
  }
  std::cout << '\n';
}

#else   // PUGS_HAS_HDF5

// Fallback used when pugs is built without HDF5. The parallel checker keeps
// its reference data in an HDF5 file, so in this configuration every call is
// rejected. None of the arguments are inspected.
template <typename DataType, ItemType item_type, typename ConnectivityPtr>
void
check(const ItemValue<DataType, item_type, ConnectivityPtr>& /* item_value */,
      const std::string& /* name */,
      const SourceLocation& /* source_location */ = SourceLocation{})
{
  throw UnexpectedError("parallel checker cannot be used without HDF5 support");
}

class ParallelChecker
{
 private:
  static ParallelChecker* m_instance;

 public:
  static void create();
  static void destroy();

  static ParallelChecker&
  instance()
  {
    return *m_instance;
  }
};

#endif   // PUGS_HAS_HDF5

}   // namespace parallel

#endif   // PARALLEL_CHECKER_HPP