diff --git a/src/language/modules/CMakeLists.txt b/src/language/modules/CMakeLists.txt index f7f1baeafe34261a38033dd3d2db828df334ce25..c4edd2505b39320d39b25b7653667f7992567ccc 100644 --- a/src/language/modules/CMakeLists.txt +++ b/src/language/modules/CMakeLists.txt @@ -10,6 +10,7 @@ add_library(PugsLanguageModules MeshModule.cpp ModuleRepository.cpp SchemeModule.cpp + SocketModule.cpp UnaryOperatorRegisterForVh.cpp UtilsModule.cpp WriterModule.cpp diff --git a/src/language/modules/ModuleRepository.cpp b/src/language/modules/ModuleRepository.cpp index 71a1368b9dc374b19562bae27cc14a5d2b54c233..85810cbfc3f12bba920ae2ee99002be73f0db39d 100644 --- a/src/language/modules/ModuleRepository.cpp +++ b/src/language/modules/ModuleRepository.cpp @@ -6,6 +6,7 @@ #include <language/modules/MathModule.hpp> #include <language/modules/MeshModule.hpp> #include <language/modules/SchemeModule.hpp> +#include <language/modules/SocketModule.hpp> #include <language/modules/UtilsModule.hpp> #include <language/modules/WriterModule.hpp> #include <language/utils/BasicAffectationRegistrerFor.hpp> @@ -56,6 +57,7 @@ ModuleRepository::ModuleRepository() this->_subscribe(std::make_unique<MathModule>()); this->_subscribe(std::make_unique<MeshModule>()); this->_subscribe(std::make_unique<SchemeModule>()); + this->_subscribe(std::make_unique<SocketModule>()); this->_subscribe(std::make_unique<UtilsModule>()); this->_subscribe(std::make_unique<WriterModule>()); } diff --git a/src/language/modules/SocketModule.cpp b/src/language/modules/SocketModule.cpp new file mode 100644 index 0000000000000000000000000000000000000000..99eb7a702c399997e2d12fbaea1cab8214a204b8 --- /dev/null +++ b/src/language/modules/SocketModule.cpp @@ -0,0 +1,43 @@ +#include <language/modules/SocketModule.hpp> + +#include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <utils/Socket.hpp> + +SocketModule::SocketModule() +{ + this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const Socket>>); + + this->_addBuiltinFunction("createSocketServer", + std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<const Socket>(const uint64_t&)>>( + + [](const uint64_t& port_number) -> std::shared_ptr<const Socket> { + return std::make_shared<const Socket>(createSocketServer(port_number)); + } + + )); + + this->_addBuiltinFunction("acceptSocketClient", + std::make_shared< + BuiltinFunctionEmbedder<std::shared_ptr<const Socket>(std::shared_ptr<const Socket>)>>( + + [](std::shared_ptr<const Socket> server_socket) -> std::shared_ptr<const Socket> { + return std::make_shared<const Socket>(acceptSocketClient(*server_socket)); + } + + )); + + this->_addBuiltinFunction("connectSocketServer", + std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<const Socket>(const std::string&, + const uint64_t&)>>( + + [](const std::string& hostname, + const uint64_t& port_number) -> std::shared_ptr<const Socket> { + return std::make_shared<const Socket>(connectSocketServer(hostname, port_number)); + } + + )); +} + +void +SocketModule::registerOperators() const +{} diff --git a/src/language/modules/SocketModule.hpp b/src/language/modules/SocketModule.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b32b64fafc72829e6598cfa78a95ce14f61ed9b --- /dev/null +++ b/src/language/modules/SocketModule.hpp @@ -0,0 +1,28 @@ +#ifndef SOCKET_MODULE_HPP +#define SOCKET_MODULE_HPP + +#include <language/modules/BuiltinModule.hpp> +#include <language/utils/ASTNodeDataTypeTraits.hpp> + +class Socket; + +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const Socket>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("socket"); + +class SocketModule : public BuiltinModule +{ + public: + std::string_view + name() const final + { + return "socket"; + } + + void registerOperators() const final; + + SocketModule(); + ~SocketModule() = default; +}; + +#endif // SOCKET_MODULE_HPP diff --git a/src/utils/CMakeLists.txt b/src/utils/CMakeLists.txt index 19feb273c4a6e2ad09818f467caa9f0f79636b35..741b0d5104cef791cf045fd8cc76cd8b16db30a2 100644 --- a/src/utils/CMakeLists.txt +++ b/src/utils/CMakeLists.txt @@ -15,7 +15,8 @@ add_library( RandomEngine.cpp RevisionInfo.cpp SignalManager.cpp - SLEPcWrapper.cpp) + SLEPcWrapper.cpp + Socket.cpp) target_link_libraries( PugsUtils diff --git a/src/utils/Socket.cpp b/src/utils/Socket.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d8ece713dfd5e3660bf83dbd94c8ffba0ba6067a --- /dev/null +++ b/src/utils/Socket.cpp @@ -0,0 +1,177 @@ +#include <utils/Exceptions.hpp> +#include <utils/Socket.hpp> + +#include <arpa/inet.h> +#include <cstring> +#include <iostream> +#include <netdb.h> +#include <netinet/in.h> +#include <stdexcept> +#include <sys/socket.h> +#include <unistd.h> + +class Socket::Internals +{ + private: + int m_socket_fd; + sockaddr_in m_address; + + public: + friend Socket createSocketServer(int port_number); + friend Socket acceptSocketClient(const Socket& server); + friend Socket connectSocketServer(const std::string& server_name, int port_number); + + friend std::ostream& + operator<<(std::ostream& os, const Socket::Internals& internals) + { + char hbuf[NI_MAXHOST], sbuf[NI_MAXSERV]; + if (::getnameinfo(reinterpret_cast<const sockaddr*>(&internals.m_address), sizeof(internals.m_address), hbuf, + sizeof(hbuf), sbuf, sizeof(sbuf), NI_NAMEREQD) == 0) { + os << hbuf << ':' << sbuf; + } else if (::getnameinfo(reinterpret_cast<const sockaddr*>(&internals.m_address), sizeof(internals.m_address), hbuf, + sizeof(hbuf), sbuf, sizeof(sbuf), NI_NUMERICHOST) == 0) { + if (std ::string{hbuf} == "0.0.0.0") { + if (::gethostname(hbuf, NI_MAXHOST) == 0) { + os << hbuf << ':' << sbuf; + } else { + os << "localhost:" << sbuf; + } + } else { + os << hbuf << ':' << sbuf; + } + } else { + os << "<unknown host>"; + } + return os; + } + + int + fileDescriptor() const + { + return m_socket_fd; + } + + Internals(const Internals&) = delete; + Internals(Internals&&) = delete; + + Internals() = default; + + ~Internals() + { + close(m_socket_fd); + } +}; + +std::ostream& +operator<<(std::ostream& os, const Socket& socket) +{ + return os << *socket.m_internals; +} + +Socket +createSocketServer(int port_number) +{ + auto p_socket_internals = std::make_shared<Socket::Internals>(); + Socket::Internals& socket_internals = *p_socket_internals; + + socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); + if (socket_internals.m_socket_fd < 0) { + throw NormalError(strerror(errno)); + } + + socket_internals.m_address.sin_family = AF_INET; + socket_internals.m_address.sin_addr.s_addr = INADDR_ANY; + socket_internals.m_address.sin_port = htons(port_number); + + int on = 1; + ::setsockopt(socket_internals.m_socket_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + + if (::bind(socket_internals.m_socket_fd, reinterpret_cast<sockaddr*>(&socket_internals.m_address), + sizeof(socket_internals.m_address)) < 0) { + throw NormalError(strerror(errno)); + } + + ::listen(socket_internals.m_socket_fd, 1); + + return Socket{p_socket_internals}; +} + +Socket +acceptSocketClient(const Socket& server) +{ + auto p_socket_internals = std::make_shared<Socket::Internals>(); + Socket::Internals& socket_internals = *p_socket_internals; + + std::cout << "Waiting for connection on " << server << '\n'; + + socklen_t address_lenght = sizeof(socket_internals.m_address); + socket_internals.m_socket_fd = ::accept(server.m_internals->m_socket_fd, + reinterpret_cast<sockaddr*>(&socket_internals.m_address), &address_lenght); + + if (socket_internals.m_socket_fd < 0) { + throw NormalError(strerror(errno)); + } + + Socket client(p_socket_internals); + std::cout << "Connected from " << client << '\n'; + + return client; +} + +Socket +connectSocketServer(const std::string& server_name, int port_number) +{ + std::cout << "Trying to establish connection to " << server_name << ':' << port_number << '\n'; + + auto p_socket_internals = std::make_shared<Socket::Internals>(); + Socket::Internals& socket_internals = *p_socket_internals; + + socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); + if (socket_internals.m_socket_fd < 0) { + throw NormalError(strerror(errno)); + } + + hostent* server = ::gethostbyname(server_name.c_str()); + if (server == NULL) { + throw NormalError(strerror(errno)); + } + + sockaddr_in& serv_addr = socket_internals.m_address; + ::memset(reinterpret_cast<char*>(&serv_addr), 0, sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + + ::memcpy(reinterpret_cast<char*>(&serv_addr.sin_addr.s_addr), reinterpret_cast<char*>(server->h_addr), + server->h_length); + + serv_addr.sin_port = ::htons(port_number); + + if (::connect(socket_internals.m_socket_fd, reinterpret_cast<sockaddr*>(&serv_addr), sizeof(serv_addr))) { + throw NormalError(strerror(errno)); + } + + Socket server_socket{p_socket_internals}; + std::cout << "Connected to " << server_socket << '\n'; + + return server_socket; +} + +void +Socket::_write(const char* data, const size_t lenght) const +{ + if (::write(this->m_internals->fileDescriptor(), data, lenght) < 0) { + throw NormalError(strerror(errno)); + } +} + +void +Socket::_read(char* data, const size_t length) const +{ + size_t received = 0; + do { + int n = ::read(this->m_internals->fileDescriptor(), reinterpret_cast<char*>(data) + received, length - received); + if (n <= 0) { + throw NormalError(strerror(errno)); + } + received += n; + } while (received < length); +} diff --git a/src/utils/Socket.hpp b/src/utils/Socket.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89bc63b9fc46c277968cf788183511424f16c90d --- /dev/null +++ b/src/utils/Socket.hpp @@ -0,0 +1,79 @@ +#ifndef SOCKET_HPP +#define SOCKET_HPP + +#include <memory> +#include <string> +#include <type_traits> + +class Socket +{ + private: + class Internals; + + std::shared_ptr<Internals> m_internals; + + Socket(std::shared_ptr<Internals> internals) : m_internals{internals} {} + + void _write(const char* data, const size_t lenght) const; + void _read(char* data, const size_t lenght) const; + + public: + friend std::ostream& operator<<(std::ostream& os, const Socket& s); + + friend Socket createSocketServer(int port_number); + friend Socket acceptSocketClient(const Socket& server); + friend Socket connectSocketServer(const std::string& server_name, int port_number); + + template <typename T> + friend void write(const Socket& socket, const T& value); + + template <typename T> + friend void read(const Socket& socket, T& value); + + template <template <typename T, typename... R> typename ArrayT, typename T, typename... R> + friend void write(const Socket& socket, const ArrayT<T, R...>& array); + + template <template <typename T, typename... R> typename ArrayT, typename T, typename... R> + friend void read(const Socket& socket, ArrayT<T, R...>& array); + + Socket(Socket&&) = default; + Socket(const Socket&) = default; + + ~Socket() = default; +}; + +Socket createSocketServer(int port_number); +Socket acceptSocketClient(const Socket& server); +Socket connectSocketServer(const std::string& server_name, int port_number); + +template <typename T> +inline void +write(const Socket& socket, const T& value) +{ + socket._write(reinterpret_cast<const char*>(&value), sizeof(T) / sizeof(char)); +} + +template <template <typename T, typename... R> typename ArrayT, typename T, typename... R> +void +write(const Socket& socket, const ArrayT<T, R...>& array) +{ + socket._write(reinterpret_cast<const char*>(&array[0]), array.size() * sizeof(T) / sizeof(char)); +} + +template <typename T> +inline void +read(const Socket& socket, T& value) +{ + static_assert(not std::is_const_v<T>, "cannot read values into const data"); + socket._read(reinterpret_cast<char*>(&value), sizeof(T) / sizeof(char)); +} + +template <template <typename T, typename... R> typename ArrayT, typename T, typename... R> +void +read(const Socket& socket, ArrayT<T, R...>& array) +{ + static_assert(not std::is_const_v<T>, "cannot read values into const data"); + socket._read(reinterpret_cast<char*>(&array[0]), array.size() * sizeof(T) / sizeof(char)); +} + +#endif // SOCKET_HPP