diff --git a/src/utils/Socket.cpp b/src/utils/Socket.cpp index 503dac1831832fe4887277f24337c88fb3e49e49..3582e359be22a257b2b52109f7a3ed57b2944a1b 100644 --- a/src/utils/Socket.cpp +++ b/src/utils/Socket.cpp @@ -28,6 +28,10 @@ class Socket::Internals friend std::ostream& operator<<(std::ostream& os, const Socket::Internals& internals) { + // This function's coverage is not performed since it's quite + // complex to create its various conditions + // + // LCOV_EXCL_START char hbuf[NI_MAXHOST], sbuf[NI_MAXSERV]; if (::getnameinfo(reinterpret_cast<const sockaddr*>(&internals.m_address), sizeof(internals.m_address), hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), NI_NAMEREQD) == 0) { @@ -46,6 +50,7 @@ class Socket::Internals } else { os << "<unknown host>"; } + // LCOV_EXCL_STOP return os; } @@ -73,7 +78,9 @@ class Socket::Internals Internals(bool is_server_socket = false) : m_is_server_socket{is_server_socket} { if (parallel::size() > 1) { + // LCOV_EXCL_START throw NotImplementedError("Sockets are not managed in parallel"); + // LCOV_EXCL_STOP } } @@ -103,7 +110,8 @@ createServerSocket(int port_number) socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); if (socket_internals.m_socket_fd < 0) { - throw NormalError(strerror(errno)); + // This should never happen + throw UnexpectedError(strerror(errno)); // LCOV_EXCL_LINE } socket_internals.m_address.sin_family = AF_INET; @@ -121,7 +129,8 @@ createServerSocket(int port_number) if (::getsockname(socket_internals.m_socket_fd, reinterpret_cast<sockaddr*>(&socket_internals.m_address), &length) == -1) { - throw NormalError(strerror(errno)); + // This should never happen + throw UnexpectedError(strerror(errno)); // LCOV_EXCL_LINE } ::listen(socket_internals.m_socket_fd, 1); @@ -140,7 +149,8 @@ acceptClientSocket(const Socket& server) reinterpret_cast<sockaddr*>(&socket_internals.m_address), &address_lenght); if (socket_internals.m_socket_fd < 0) { - throw NormalError(strerror(errno)); + // This should never happen + throw UnexpectedError(strerror(errno)); // LCOV_EXCL_LINE } return Socket{p_socket_internals}; @@ -154,7 +164,8 @@ connectServerSocket(const std::string& server_name, int port_number) socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); if (socket_internals.m_socket_fd < 0) { - throw NormalError(strerror(errno)); + // This should never happen + throw UnexpectedError(strerror(errno)); // LCOV_EXCL_LINE } hostent* server = ::gethostbyname(server_name.c_str()); @@ -182,11 +193,12 @@ void Socket::_write(const char* data, const size_t lenght) const { if (this->m_internals->isServerSocket()) { - throw NormalError("Server cannot write to server socket!"); + throw NormalError("Server cannot write to server socket"); } if (::write(this->m_internals->fileDescriptor(), data, lenght) < 0) { - throw NormalError(strerror(errno)); + // Quite complex to test + throw NormalError(strerror(errno)); // LCOV_EXCL_LINE } } @@ -194,15 +206,15 @@ void Socket::_read(char* data, const size_t length) const { if (this->m_internals->isServerSocket()) { - throw NormalError("Server cannot read from server socket!"); + throw NormalError("Server cannot read from server socket"); } size_t received = 0; - do { + while (received < length) { int n = ::read(this->m_internals->fileDescriptor(), reinterpret_cast<char*>(data) + received, length - received); if (n <= 0) { throw NormalError("Could not read data"); } received += n; - } while (received < length); + } } diff --git a/src/utils/Socket.hpp b/src/utils/Socket.hpp index 6b47fa4f01d84a7782fd310bb38e78f1c4f85c6e..e6d25ddd1328d559e4786e677aff8cde9af06549 100644 --- a/src/utils/Socket.hpp +++ b/src/utils/Socket.hpp @@ -78,7 +78,9 @@ read(const Socket& socket, ArrayT<T, R...>& array) { static_assert(not std::is_const_v<T>, "cannot read values into const data"); static_assert(std::is_arithmetic_v<T> or is_tiny_vector_v<T> or is_tiny_matrix_v<T>, "unexpected value type"); - socket._read(reinterpret_cast<char*>(&array[0]), array.size() * sizeof(T) / sizeof(char)); + if (array.size() > 0) { + socket._read(reinterpret_cast<char*>(&array[0]), array.size() * sizeof(T) / sizeof(char)); + } } #endif // SOCKET_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 96729d8acb0fc5dc87df0ddb732d63fc0d2a8ee7..380f070b9e8886849280c9d11f16078786b1c62b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -120,6 +120,7 @@ add_executable (unit_tests test_RevisionInfo.cpp test_SmallArray.cpp test_SmallVector.cpp + test_Socket.cpp test_SquareGaussQuadrature.cpp test_SquareTransformation.cpp test_SymbolTable.cpp diff --git a/tests/test_Socket.cpp b/tests/test_Socket.cpp new file mode 100644 index 0000000000000000000000000000000000000000..774725b9ff48662e112922e43b1e693e736dc1d3 --- /dev/null +++ b/tests/test_Socket.cpp @@ -0,0 +1,113 @@ +#include <catch2/catch_test_macros.hpp> +#include <catch2/matchers/catch_matchers_all.hpp> + +#include <utils/PugsAssert.hpp> + +#include <utils/Socket.hpp> + +#include <sstream> +#include <thread> + +#include <netdb.h> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("Socket", "[utils]") +{ + SECTION("create/connect and simple read/write") + { + auto self_client = [](int port) { + Socket self_server = connectServerSocket("localhost", port); + std::vector<int> values = {1, 2, 4, 7}; + write(self_server, values.size()); + write(self_server, values); + }; + + Socket server = createServerSocket(0); + + std::thread t(self_client, server.portNumber()); + + Socket client = acceptClientSocket(server); + + size_t vector_size = [&] { + size_t vector_size; + read(client, vector_size); + return vector_size; + }(); + + REQUIRE(vector_size == 4); + + std::vector<int> v(vector_size); + read(client, v); + + REQUIRE(v == std::vector<int>{1, 2, 4, 7}); + + t.join(); + } + + SECTION("move constructor") + { + Socket server = createServerSocket(0); + int port_number = server.portNumber(); + + Socket moved_server(std::move(server)); + + REQUIRE(port_number == moved_server.portNumber()); + } + + SECTION("host and port info") + { + Socket server = createServerSocket(0); + int port_number = server.portNumber(); + + std::ostringstream info; + info << server; + + std::ostringstream expected; + char hbuf[NI_MAXHOST]; + ::gethostname(hbuf, NI_MAXHOST); + expected << hbuf << ':' << port_number; + + REQUIRE(info.str() == expected.str()); + } + + SECTION("errors") + { + SECTION("connection") + { + REQUIRE_THROWS_WITH(createServerSocket(1), "error: Permission denied"); + REQUIRE_THROWS_WITH(connectServerSocket("localhost", 1), "error: Connection refused"); + + // The error message is not checked since it can depend on the + // network connection itself + REQUIRE_THROWS(connectServerSocket("an invalid host name", 1)); + } + + SECTION("server <-> server") + { + Socket server = createServerSocket(0); + REQUIRE_THROWS_WITH(write(server, 12), "error: Server cannot write to server socket"); + REQUIRE_THROWS_WITH(write(server, std::vector<int>{1, 2, 3}), "error: Server cannot write to server socket"); + + double x; + REQUIRE_THROWS_WITH(read(server, x), "error: Server cannot read from server socket"); + std::vector<double> v(1); + REQUIRE_THROWS_WITH(read(server, v), "error: Server cannot read from server socket"); + } + + SECTION("invalid read") + { + auto self_client = [](int port) { Socket self_server = connectServerSocket("localhost", port); }; + + Socket server = createServerSocket(0); + + std::thread t(self_client, server.portNumber()); + Socket client = acceptClientSocket(server); + t.join(); + + double x; + REQUIRE_THROWS_WITH(read(client, x), "error: Could not read data"); + // One cannot test easily write errors... + } + } +}