diff --git a/src/language/modules/SocketModule.cpp b/src/language/modules/SocketModule.cpp index 99eb7a702c399997e2d12fbaea1cab8214a204b8..7a8a503682eab3b091f7a188ded29b21b07ea156 100644 --- a/src/language/modules/SocketModule.cpp +++ b/src/language/modules/SocketModule.cpp @@ -11,7 +11,7 @@ SocketModule::SocketModule() std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<const Socket>(const uint64_t&)>>( [](const uint64_t& port_number) -> std::shared_ptr<const Socket> { - return std::make_shared<const Socket>(createSocketServer(port_number)); + return std::make_shared<const Socket>(createServerSocket(port_number)); } )); @@ -21,7 +21,7 @@ SocketModule::SocketModule() BuiltinFunctionEmbedder<std::shared_ptr<const Socket>(std::shared_ptr<const Socket>)>>( [](std::shared_ptr<const Socket> server_socket) -> std::shared_ptr<const Socket> { - return std::make_shared<const Socket>(acceptSocketClient(*server_socket)); + return std::make_shared<const Socket>(acceptClientSocket(*server_socket)); } )); @@ -32,10 +32,249 @@ SocketModule::SocketModule() [](const std::string& hostname, const uint64_t& port_number) -> std::shared_ptr<const Socket> { - return std::make_shared<const Socket>(connectSocketServer(hostname, port_number)); + return std::make_shared<const Socket>(connectServerSocket(hostname, port_number)); } )); + + this->_addBuiltinFunction("write", + std::make_shared< + BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, const bool&)>>( + + [](const std::shared_ptr<const Socket>& socket, const bool& value) -> void { + write(*socket, value); + } + + )); + + this->_addBuiltinFunction("write", + std::make_shared< + BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, const uint64_t&)>>( + + [](const std::shared_ptr<const Socket>& socket, const uint64_t& value) -> void { + write(*socket, value); + } + + )); + + this->_addBuiltinFunction("write", + std::make_shared< + BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, const int64_t&)>>( + + [](const std::shared_ptr<const Socket>& socket, const int64_t& value) -> void { + write(*socket, value); + } + + )); + + this->_addBuiltinFunction("write", + std::make_shared< + BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, const double&)>>( + + [](const std::shared_ptr<const Socket>& socket, const double& value) -> void { + write(*socket, value); + } + + )); + + this->_addBuiltinFunction("write", std::make_shared<BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, + const TinyVector<1>&)>>( + + [](const std::shared_ptr<const Socket>& socket, + const TinyVector<1>& value) -> void { write(*socket, value); } + + )); + + this->_addBuiltinFunction("write", std::make_shared<BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, + const TinyVector<2>&)>>( + + [](const std::shared_ptr<const Socket>& socket, + const TinyVector<2>& value) -> void { write(*socket, value); } + + )); + + this->_addBuiltinFunction("write", std::make_shared<BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, + const TinyVector<3>&)>>( + + [](const std::shared_ptr<const Socket>& socket, + const TinyVector<3>& value) -> void { write(*socket, value); } + + )); + + this->_addBuiltinFunction("write", std::make_shared<BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, + const TinyMatrix<1>&)>>( + + [](const std::shared_ptr<const Socket>& socket, + const TinyMatrix<1>& value) -> void { write(*socket, value); } + + )); + + this->_addBuiltinFunction("write", std::make_shared<BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, + const TinyMatrix<2>&)>>( + + [](const std::shared_ptr<const Socket>& socket, + const TinyMatrix<2>& value) -> void { write(*socket, value); } + + )); + + this->_addBuiltinFunction("write", std::make_shared<BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, + const TinyMatrix<3>&)>>( + + [](const std::shared_ptr<const Socket>& socket, + const TinyMatrix<3>& value) -> void { write(*socket, value); } + + )); + + this->_addBuiltinFunction("write", + std::make_shared< + BuiltinFunctionEmbedder<void(const std::shared_ptr<const Socket>&, const std::string&)>>( + + [](const std::shared_ptr<const Socket>& socket, const std::string& value) -> void { + write(*socket, value.size()); + write(*socket, value); + } + + )); + + this->_addBuiltinFunction("read_B", + std::make_shared<BuiltinFunctionEmbedder<bool(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> bool { + bool value; + read(*socket, value); + + return value; + } + + )); + + this->_addBuiltinFunction("read_N", + std::make_shared<BuiltinFunctionEmbedder<uint64_t(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> uint64_t { + uint64_t value; + read(*socket, value); + + return value; + } + + )); + + this->_addBuiltinFunction("read_Z", + std::make_shared<BuiltinFunctionEmbedder<int64_t(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> int64_t { + int64_t value; + read(*socket, value); + + return value; + } + + )); + + this->_addBuiltinFunction("read_R", + std::make_shared<BuiltinFunctionEmbedder<double(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> double { + double value; + read(*socket, value); + + return value; + } + + )); + + this->_addBuiltinFunction("read_R1", std::make_shared< + BuiltinFunctionEmbedder<TinyVector<1>(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> TinyVector<1> { + TinyVector<1> value; + read(*socket, value); + + return value; + } + + )); + + this->_addBuiltinFunction("read_R2", std::make_shared< + BuiltinFunctionEmbedder<TinyVector<2>(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> TinyVector<2> { + TinyVector<2> value; + read(*socket, value); + + return value; + } + + )); + + this->_addBuiltinFunction("read_R3", std::make_shared< + BuiltinFunctionEmbedder<TinyVector<3>(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> TinyVector<3> { + TinyVector<3> value; + read(*socket, value); + + return value; + } + + )); + + this->_addBuiltinFunction("read_R1x1", + std::make_shared< + BuiltinFunctionEmbedder<TinyMatrix<1>(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> TinyMatrix<1> { + TinyMatrix<1> value; + read(*socket, value); + + return value; + } + + )); + + this->_addBuiltinFunction("read_R2x2", + std::make_shared< + BuiltinFunctionEmbedder<TinyMatrix<2>(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> TinyMatrix<2> { + TinyMatrix<2> value; + read(*socket, value); + + return value; + } + + )); + + this->_addBuiltinFunction("read_R3x3", + std::make_shared< + BuiltinFunctionEmbedder<TinyMatrix<3>(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> TinyMatrix<3> { + TinyMatrix<3> value; + read(*socket, value); + + return value; + } + + )); + + this + ->_addBuiltinFunction("read_string", + std::make_shared<BuiltinFunctionEmbedder<std::string(const std::shared_ptr<const Socket>&)>>( + + [](const std::shared_ptr<const Socket>& socket) -> std::string { + size_t size; + read(*socket, size); + std::string value; + if (size > 0) { + value.resize(size); + read(*socket, value); + } + return value; + } + + )); } void diff --git a/src/utils/Socket.cpp b/src/utils/Socket.cpp index d8ece713dfd5e3660bf83dbd94c8ffba0ba6067a..b74cd4cc78e381994552f86113b87646afe3b1db 100644 --- a/src/utils/Socket.cpp +++ b/src/utils/Socket.cpp @@ -16,10 +16,12 @@ class Socket::Internals int m_socket_fd; sockaddr_in m_address; + const bool m_is_server_socket; + public: - friend Socket createSocketServer(int port_number); - friend Socket acceptSocketClient(const Socket& server); - friend Socket connectSocketServer(const std::string& server_name, int port_number); + friend Socket createServerSocket(int port_number); + friend Socket acceptClientSocket(const Socket& server); + friend Socket connectServerSocket(const std::string& server_name, int port_number); friend std::ostream& operator<<(std::ostream& os, const Socket::Internals& internals) @@ -45,6 +47,12 @@ class Socket::Internals return os; } + bool + isServerSocket() const + { + return m_is_server_socket; + } + int fileDescriptor() const { @@ -54,7 +62,7 @@ class Socket::Internals Internals(const Internals&) = delete; Internals(Internals&&) = delete; - Internals() = default; + Internals(bool is_server_socket = false) : m_is_server_socket{is_server_socket} {} ~Internals() { @@ -69,9 +77,9 @@ operator<<(std::ostream& os, const Socket& socket) } Socket -createSocketServer(int port_number) +createServerSocket(int port_number) { - auto p_socket_internals = std::make_shared<Socket::Internals>(); + auto p_socket_internals = std::make_shared<Socket::Internals>(true); Socket::Internals& socket_internals = *p_socket_internals; socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); @@ -97,7 +105,7 @@ createSocketServer(int port_number) } Socket -acceptSocketClient(const Socket& server) +acceptClientSocket(const Socket& server) { auto p_socket_internals = std::make_shared<Socket::Internals>(); Socket::Internals& socket_internals = *p_socket_internals; @@ -119,7 +127,7 @@ acceptSocketClient(const Socket& server) } Socket -connectSocketServer(const std::string& server_name, int port_number) +connectServerSocket(const std::string& server_name, int port_number) { std::cout << "Trying to establish connection to " << server_name << ':' << port_number << '\n'; @@ -158,6 +166,10 @@ connectSocketServer(const std::string& server_name, int port_number) void Socket::_write(const char* data, const size_t lenght) const { + if (this->m_internals->isServerSocket()) { + throw NormalError("Server cannot write to server socket!"); + } + if (::write(this->m_internals->fileDescriptor(), data, lenght) < 0) { throw NormalError(strerror(errno)); } @@ -166,11 +178,15 @@ Socket::_write(const char* data, const size_t lenght) const void Socket::_read(char* data, const size_t length) const { + if (this->m_internals->isServerSocket()) { + throw NormalError("Server cannot read from server socket!"); + } + size_t received = 0; do { int n = ::read(this->m_internals->fileDescriptor(), reinterpret_cast<char*>(data) + received, length - received); if (n <= 0) { - throw NormalError(strerror(errno)); + throw NormalError("Could not read data"); } received += n; } while (received < length); diff --git a/src/utils/Socket.hpp b/src/utils/Socket.hpp index 89bc63b9fc46c277968cf788183511424f16c90d..b8ee7201de755cea95b47a2b6db90f127cfeab4c 100644 --- a/src/utils/Socket.hpp +++ b/src/utils/Socket.hpp @@ -4,6 +4,7 @@ #include <memory> #include <string> #include <type_traits> +#include <utils/PugsTraits.hpp> class Socket { @@ -20,9 +21,9 @@ class Socket public: friend std::ostream& operator<<(std::ostream& os, const Socket& s); - friend Socket createSocketServer(int port_number); - friend Socket acceptSocketClient(const Socket& server); - friend Socket connectSocketServer(const std::string& server_name, int port_number); + friend Socket createServerSocket(int port_number); + friend Socket acceptClientSocket(const Socket& server); + friend Socket connectServerSocket(const std::string& server_name, int port_number); template <typename T> friend void write(const Socket& socket, const T& value); @@ -42,14 +43,15 @@ class Socket ~Socket() = default; }; -Socket createSocketServer(int port_number); -Socket acceptSocketClient(const Socket& server); -Socket connectSocketServer(const std::string& server_name, int port_number); +Socket createServerSocket(int port_number); +Socket acceptClientSocket(const Socket& server); +Socket connectServerSocket(const std::string& server_name, int port_number); template <typename T> inline void write(const Socket& socket, const T& value) { + static_assert(std::is_arithmetic_v<T> or is_tiny_vector_v<T> or is_tiny_matrix_v<T>, "unexpected value type"); socket._write(reinterpret_cast<const char*>(&value), sizeof(T) / sizeof(char)); } @@ -73,6 +75,7 @@ void read(const Socket& socket, ArrayT<T, R...>& array) { static_assert(not std::is_const_v<T>, "cannot read values into const data"); + static_assert(std::is_arithmetic_v<T> or is_tiny_vector_v<T> or is_tiny_matrix_v<T>, "unexpected value type"); socket._read(reinterpret_cast<char*>(&array[0]), array.size() * sizeof(T) / sizeof(char)); }