#include <utils/ExecutionStatManager.hpp>

#include <utils/Exceptions.hpp>
#include <utils/Messenger.hpp>

#include <array>
#include <cmath>
#include <iomanip>
#include <rang.hpp>
#include <sstream>
#include <string>
#include <sys/resource.h>

// Singleton storage: null until create() is called, reset to null by destroy()
ExecutionStatManager* ExecutionStatManager::m_instance{nullptr};

// Formats a duration given in seconds as "[<d>d ][hh h][mm mn]ss s" (bold), e.g. "1d 02h03mn04s".
// Fractional seconds are truncated (floor). Leading zero units are omitted; once a unit has
// been printed, every smaller unit is printed zero-padded to two digits.
// Seconds are always printed, so durations below one minute give e.g. "42s" rather than an
// empty string. Negative or non-finite durations are treated as zero, because converting
// them to an unsigned count would be undefined behaviour.
std::string
ExecutionStatManager::_prettyPrintTime(double time_in_seconds) const
{
  const bool is_valid = std::isfinite(time_in_seconds) and (time_in_seconds > 0);
  size_t seconds      = is_valid ? static_cast<size_t>(std::floor(time_in_seconds)) : 0;

  const size_t days = seconds / (24 * 3600);
  seconds -= days * (24 * 3600);
  const size_t hours = seconds / 3600;
  seconds -= hours * 3600;
  const size_t minutes = seconds / 60;
  seconds -= minutes * 60;

  std::ostringstream os;
  os << rang::style::bold;
  // Becomes true as soon as a unit has been written; smaller units must then always be shown
  bool has_larger_unit = false;
  if (days > 0) {
    has_larger_unit = true;
    os << days << "d" << ' ';
  }
  if (has_larger_unit or (hours > 0)) {
    has_larger_unit = true;
    os << std::setw(2) << std::setfill('0') << hours << "h";
  }
  if (has_larger_unit or (minutes > 0)) {
    has_larger_unit = true;
    os << std::setw(2) << std::setfill('0') << minutes << "mn";
  }
  // Zero-pad seconds only when they follow a larger unit ("02h03mn04s" but "4s")
  if (has_larger_unit) {
    os << std::setw(2) << std::setfill('0');
  }
  os << seconds << "s" << rang::style::reset;

  return os.str();
}

void
ExecutionStatManager::_printMaxResidentMemory(std::ostream& os) const
{
  class Memory
  {
   private:
    double m_value;

   public:
    PUGS_INLINE const double&
    value() const
    {
      return m_value;
    }

    std::string
    prettyPrint() const
    {
      const std::vector<std::string> units = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};

      double local_memory = m_value;
      size_t i_unit       = 0;
      while ((local_memory >= 1024) and (i_unit < units.size())) {
        ++i_unit;
        local_memory /= 1024;
      }
      std::ostringstream os;
      os << local_memory << units[i_unit];
      return os.str();
    }

    Memory()
    {
      rusage u;
      getrusage(RUSAGE_SELF, &u);
      m_value = u.ru_maxrss * 1024;
    }

    Memory(double value) : m_value{value} {}
  };

  Memory memory;
  os << "Memory: " << rang::style::bold << Memory{parallel::allReduceSum(memory.value())}.prettyPrint()
     << rang::style::reset;
  os << " (over " << parallel::size() << " processes)";
  os << " Avg: " << rang::style::bold << Memory{parallel::allReduceSum(memory.value()) / parallel::size()}.prettyPrint()
     << rang::style::reset;
  os << " Min: " << rang::style::bold << Memory{parallel::allReduceMin(memory.value())}.prettyPrint()
     << rang::style::reset;
  os << " Max: " << rang::style::bold << Memory{parallel::allReduceMax(memory.value())}.prettyPrint()
     << rang::style::reset;
  os << '\n';
}

void
ExecutionStatManager::_printElapseTime(std::ostream& os) const
{
  const double elapse_time = m_instance->m_elapse_time.seconds();
  os << "Execution: " << rang::style::bold << elapse_time << 's' << rang::style::reset;
  if (elapse_time > 60) {
    os << " [" << rang::style::bold << this->_prettyPrintTime(elapse_time) << rang::style::reset << ']';
  }
  if (m_run_number > 1) {
    const double cumulative_elapse_time = elapse_time + m_previous_cumulative_elapse_time;
    os << " (Run number " << m_run_number << ").\n - Cumulative execution time: " << rang::style::bold
       << cumulative_elapse_time << 's' << rang::style::reset;
    if (cumulative_elapse_time > 60) {
      os << " [" << rang::style::bold << this->_prettyPrintTime(cumulative_elapse_time) << rang::style::reset << ']';
    }
  }
  os << '\n';
}

void
ExecutionStatManager::_printTotalCPUTime(std::ostream& os) const
{
  rusage u;
  getrusage(RUSAGE_SELF, &u);

  const double total_cpu_time =
    parallel::allReduceSum(u.ru_utime.tv_sec + u.ru_stime.tv_sec + (u.ru_utime.tv_usec + u.ru_stime.tv_usec) * 1E-6);

  os << "Total CPU: " << rang::style::bold << total_cpu_time << 's' << rang::style::reset;
  os << " (" << parallel::allReduceSum(Kokkos::DefaultHostExecutionSpace::concurrency()) << " threads over "
     << parallel::size() << " processes)";
  if (total_cpu_time > 60) {
    os << " [" << _prettyPrintTime(total_cpu_time) << ']';
  }

  if (m_run_number > 1) {
    const double cumulative_total_cpu_time = total_cpu_time + m_previous_cumulative_total_cpu_time;
    os << "\n - Cumulative total CPU: " << rang::style::bold << cumulative_total_cpu_time << 's' << rang::style::reset;
    if (cumulative_total_cpu_time > 60) {
      os << " [" << rang::style::bold << this->_prettyPrintTime(cumulative_total_cpu_time) << rang::style::reset << ']';
    }
  }

  os << '\n';
}

// Prints the execution statistics banner (elapse time, total CPU time, peak memory) to os.
// Does nothing if the singleton's doPrint() flag is false.
void
ExecutionStatManager::printInfo(std::ostream& os)
{
  auto& manager = ExecutionStatManager::getInstance();
  if (not manager.doPrint()) {
    return;
  }

  os << "----------------- " << rang::fg::green << "pugs exec stats" << rang::fg::reset << " ---------------------\n";

  manager._printElapseTime(os);
  manager._printTotalCPUTime(os);
  manager._printMaxResidentMemory(os);
}

double
ExecutionStatManager::getElapseTime() const
{
  return m_elapse_time.seconds();
}

double
ExecutionStatManager::getCumulativeElapseTime() const
{
  return m_previous_cumulative_elapse_time + m_elapse_time.seconds();
}

// Returns the CPU time (user + system) of the current run summed over all processes,
// plus the CPU time accumulated by previous runs.
// The reduction is a collective call, so every process must call this function.
double
ExecutionStatManager::getCumulativeTotalCPUTime() const
{
  rusage usage;
  getrusage(RUSAGE_SELF, &usage);

  const double local_cpu_time =
    usage.ru_utime.tv_sec + usage.ru_stime.tv_sec + (usage.ru_utime.tv_usec + usage.ru_stime.tv_usec) * 1E-6;
  const double current_run_cpu_time = parallel::allReduceSum(local_cpu_time);

  return m_previous_cumulative_total_cpu_time + current_run_cpu_time;
}

// Creates the singleton instance; throws UnexpectedError if it already exists
void
ExecutionStatManager::create()
{
  if (ExecutionStatManager::m_instance != nullptr) {
    throw UnexpectedError("ExecutionStatManager already created");
  }
  ExecutionStatManager::m_instance = new ExecutionStatManager;
}

// Destroys the singleton instance and resets the pointer.
// Calling it several times is allowed (to handle unexpected code exit): deleting a null
// pointer is a no-op.
void
ExecutionStatManager::destroy()
{
  delete ExecutionStatManager::m_instance;
  ExecutionStatManager::m_instance = nullptr;
}