#include <utils/Exceptions.hpp>
#include <utils/Socket.hpp>

#include <arpa/inet.h>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <netdb.h>
#include <netinet/in.h>
#include <stdexcept>
#include <sys/socket.h>
#include <unistd.h>

/**
 * Private implementation of Socket.
 *
 * Holds one socket file descriptor and the IPv4 address associated with it:
 * the local bind address for sockets built by createSocketServer(), the peer
 * address for sockets built by acceptSocketClient() / connectSocketServer().
 * The descriptor is closed when the object is destroyed. Copying and moving
 * are disabled, so each descriptor has exactly one owner.
 */
class Socket::Internals
{
 private:
  // -1 marks "no descriptor yet", so the destructor never closes an fd this
  // object does not own (e.g. fd 0 after value-initialization).
  int m_socket_fd{-1};
  sockaddr_in m_address{};

 public:
  friend Socket createSocketServer(int port_number);
  friend Socket acceptSocketClient(const Socket& server);
  friend Socket connectSocketServer(const std::string& server_name, int port_number);

  /**
   * Prints the stored address as "host:service".
   *
   * Tries a name that requires a successful reverse lookup first (NI_NAMEREQD),
   * then falls back to the numeric host. The wildcard address 0.0.0.0 is shown
   * as this machine's host name, or "localhost" if that cannot be obtained.
   * Prints "<unknown host>" if both lookups fail.
   */
  friend std::ostream&
  operator<<(std::ostream& os, const Socket::Internals& internals)
  {
    const sockaddr* address = reinterpret_cast<const sockaddr*>(&internals.m_address);
    const socklen_t address_size = sizeof(internals.m_address);

    char hbuf[NI_MAXHOST], sbuf[NI_MAXSERV];
    if (::getnameinfo(address, address_size, hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), NI_NAMEREQD) == 0) {
      os << hbuf << ':' << sbuf;
    } else if (::getnameinfo(address, address_size, hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), NI_NUMERICHOST) == 0) {
      if (std::strcmp(hbuf, "0.0.0.0") == 0) {
        if (::gethostname(hbuf, NI_MAXHOST) == 0) {
          // POSIX leaves termination unspecified when the name is truncated.
          hbuf[NI_MAXHOST - 1] = '\0';
          os << hbuf << ':' << sbuf;
        } else {
          os << "localhost:" << sbuf;
        }
      } else {
        os << hbuf << ':' << sbuf;
      }
    } else {
      os << "<unknown host>";
    }
    return os;
  }

  /// Returns the owned socket descriptor (-1 if none has been assigned).
  int
  fileDescriptor() const
  {
    return m_socket_fd;
  }

  Internals(const Internals&)            = delete;
  Internals(Internals&&)                 = delete;
  Internals& operator=(const Internals&) = delete;
  Internals& operator=(Internals&&)      = delete;

  Internals() = default;

  /// Closes the descriptor if one was successfully opened.
  ~Internals()
  {
    if (m_socket_fd >= 0) {
      ::close(m_socket_fd);
    }
  }
};

std::ostream&
operator<<(std::ostream& os, const Socket& socket)
{
  return os << *socket.m_internals;
}

/**
 * Creates a TCP/IPv4 server socket bound to all local interfaces.
 *
 * @param port_number local TCP port to bind to
 * @return a Socket that is bound and listening (backlog of 1)
 * @throws NormalError if socket(), bind() or listen() fails
 */
Socket
createSocketServer(int port_number)
{
  auto p_socket_internals             = std::make_shared<Socket::Internals>();
  Socket::Internals& socket_internals = *p_socket_internals;

  socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0);
  if (socket_internals.m_socket_fd < 0) {
    throw NormalError(strerror(errno));
  }

  // Zero the whole structure (including sin_zero) before filling it in.
  ::memset(&socket_internals.m_address, 0, sizeof(socket_internals.m_address));
  socket_internals.m_address.sin_family      = AF_INET;
  socket_internals.m_address.sin_addr.s_addr = htonl(INADDR_ANY);
  socket_internals.m_address.sin_port        = htons(port_number);

  // SO_REUSEADDR allows rebinding a recently used port; best-effort, result ignored.
  int on = 1;
  ::setsockopt(socket_internals.m_socket_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));

  if (::bind(socket_internals.m_socket_fd, reinterpret_cast<sockaddr*>(&socket_internals.m_address),
             sizeof(socket_internals.m_address)) < 0) {
    throw NormalError(strerror(errno));
  }

  // Without a successful listen(), later accept() calls on this socket would fail.
  if (::listen(socket_internals.m_socket_fd, 1) < 0) {
    throw NormalError(strerror(errno));
  }

  return Socket{p_socket_internals};
}

/**
 * Blocks until a client connects to the listening socket @p server.
 *
 * Prints a "waiting" message before blocking and a "connected" message with the
 * peer address afterwards, both on std::cout.
 *
 * @param server a listening socket, as returned by createSocketServer()
 * @return the Socket for the accepted connection (stores the peer address)
 * @throws NormalError if accept() fails
 */
Socket
acceptSocketClient(const Socket& server)
{
  auto client_internals = std::make_shared<Socket::Internals>();

  std::cout << "Waiting for connection on " << server << '\n';

  sockaddr_in& peer_address   = client_internals->m_address;
  socklen_t peer_address_size = sizeof(peer_address);
  const int listening_fd      = server.m_internals->m_socket_fd;

  const int accepted_fd =
    ::accept(listening_fd, reinterpret_cast<sockaddr*>(&peer_address), &peer_address_size);
  client_internals->m_socket_fd = accepted_fd;

  if (!(accepted_fd >= 0)) {
    throw NormalError(strerror(errno));
  }

  Socket client{client_internals};
  std::cout << "Connected from " << client << '\n';
  return client;
}

/**
 * Opens a TCP/IPv4 connection to a server.
 *
 * @p server_name is resolved with getaddrinfo() (restricted to IPv4 stream
 * sockets), and the first returned address is used with @p port_number.
 * Progress messages are printed on std::cout.
 *
 * @param server_name host name or dotted-decimal IPv4 address
 * @param port_number remote TCP port
 * @return the connected Socket (stores the server address)
 * @throws NormalError if socket creation, name resolution or connect() fails
 */
Socket
connectSocketServer(const std::string& server_name, int port_number)
{
  std::cout << "Trying to establish connection to " << server_name << ':' << port_number << '\n';

  auto p_socket_internals             = std::make_shared<Socket::Internals>();
  Socket::Internals& socket_internals = *p_socket_internals;

  socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0);
  if (socket_internals.m_socket_fd < 0) {
    throw NormalError(strerror(errno));
  }

  addrinfo hints{};
  hints.ai_family   = AF_INET;
  hints.ai_socktype = SOCK_STREAM;

  addrinfo* resolved = nullptr;
  const int status   = ::getaddrinfo(server_name.c_str(), nullptr, &hints, &resolved);
  if (status != 0) {
    // getaddrinfo() reports failures through its return code; errno is only
    // meaningful for EAI_SYSTEM.
    throw NormalError((status == EAI_SYSTEM) ? strerror(errno) : ::gai_strerror(status));
  }

  sockaddr_in& serv_addr = socket_internals.m_address;
  ::memset(&serv_addr, 0, sizeof(serv_addr));
  // ai_family was restricted to AF_INET, so ai_addr points to a sockaddr_in.
  ::memcpy(&serv_addr, resolved->ai_addr, sizeof(serv_addr));
  ::freeaddrinfo(resolved);

  serv_addr.sin_family = AF_INET;
  serv_addr.sin_port   = htons(port_number);

  if (::connect(socket_internals.m_socket_fd, reinterpret_cast<sockaddr*>(&serv_addr), sizeof(serv_addr)) < 0) {
    throw NormalError(strerror(errno));
  }

  Socket server_socket{p_socket_internals};
  std::cout << "Connected to " << server_socket << '\n';

  return server_socket;
}

/**
 * Writes exactly @p length bytes from @p data to the socket.
 *
 * write() may transfer fewer bytes than requested, so it is called repeatedly
 * until the whole buffer has been sent. Calls interrupted by a signal (EINTR)
 * are retried.
 *
 * @throws NormalError if write() fails
 */
void
Socket::_write(const char* data, const size_t length) const
{
  const int fd = this->m_internals->fileDescriptor();

  size_t sent = 0;
  while (sent < length) {
    const ssize_t n = ::write(fd, data + sent, length - sent);
    if (n < 0) {
      if (errno == EINTR) {
        continue;
      }
      throw NormalError(strerror(errno));
    }
    sent += static_cast<size_t>(n);
  }
}

/**
 * Reads exactly @p length bytes from the socket into @p data.
 *
 * Loops until the requested amount has been received. Calls interrupted by a
 * signal (EINTR) are retried. A zero-length request returns immediately without
 * touching the socket.
 *
 * @throws NormalError if read() fails, or if the peer closes the connection
 *         (read() returns 0) before @p length bytes have arrived
 */
void
Socket::_read(char* data, const size_t length) const
{
  const int fd = this->m_internals->fileDescriptor();

  size_t received = 0;
  while (received < length) {
    const ssize_t n = ::read(fd, data + received, length - received);
    if (n < 0) {
      if (errno == EINTR) {
        continue;
      }
      throw NormalError(strerror(errno));
    }
    if (n == 0) {
      // End of stream: errno is not set here, so report the condition explicitly.
      throw NormalError("connection closed by peer");
    }
    received += static_cast<size_t>(n);
  }
}
