Skip to content
Snippets Groups Projects
Commit 2fd59d93 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add tests for Socket and clean-up

parent 0f8ef0c2
No related branches found
No related tags found
1 merge request!126Begin socket handling: core functionalities are available
...@@ -28,6 +28,10 @@ class Socket::Internals ...@@ -28,6 +28,10 @@ class Socket::Internals
friend std::ostream& friend std::ostream&
operator<<(std::ostream& os, const Socket::Internals& internals) operator<<(std::ostream& os, const Socket::Internals& internals)
{ {
// This function's coverage is not performed since it's quite
// complex to create its various conditions
//
// LCOV_EXCL_START
char hbuf[NI_MAXHOST], sbuf[NI_MAXSERV]; char hbuf[NI_MAXHOST], sbuf[NI_MAXSERV];
if (::getnameinfo(reinterpret_cast<const sockaddr*>(&internals.m_address), sizeof(internals.m_address), hbuf, if (::getnameinfo(reinterpret_cast<const sockaddr*>(&internals.m_address), sizeof(internals.m_address), hbuf,
sizeof(hbuf), sbuf, sizeof(sbuf), NI_NAMEREQD) == 0) { sizeof(hbuf), sbuf, sizeof(sbuf), NI_NAMEREQD) == 0) {
...@@ -46,6 +50,7 @@ class Socket::Internals ...@@ -46,6 +50,7 @@ class Socket::Internals
} else { } else {
os << "<unknown host>"; os << "<unknown host>";
} }
// LCOV_EXCL_STOP
return os; return os;
} }
...@@ -73,7 +78,9 @@ class Socket::Internals ...@@ -73,7 +78,9 @@ class Socket::Internals
Internals(bool is_server_socket = false) : m_is_server_socket{is_server_socket} Internals(bool is_server_socket = false) : m_is_server_socket{is_server_socket}
{ {
if (parallel::size() > 1) { if (parallel::size() > 1) {
// LCOV_EXCL_START
throw NotImplementedError("Sockets are not managed in parallel"); throw NotImplementedError("Sockets are not managed in parallel");
// LCOV_EXCL_STOP
} }
} }
...@@ -103,7 +110,8 @@ createServerSocket(int port_number) ...@@ -103,7 +110,8 @@ createServerSocket(int port_number)
socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0);
if (socket_internals.m_socket_fd < 0) { if (socket_internals.m_socket_fd < 0) {
throw NormalError(strerror(errno)); // This should never happen
throw UnexpectedError(strerror(errno)); // LCOV_EXCL_LINE
} }
socket_internals.m_address.sin_family = AF_INET; socket_internals.m_address.sin_family = AF_INET;
...@@ -121,7 +129,8 @@ createServerSocket(int port_number) ...@@ -121,7 +129,8 @@ createServerSocket(int port_number)
if (::getsockname(socket_internals.m_socket_fd, reinterpret_cast<sockaddr*>(&socket_internals.m_address), &length) == if (::getsockname(socket_internals.m_socket_fd, reinterpret_cast<sockaddr*>(&socket_internals.m_address), &length) ==
-1) { -1) {
throw NormalError(strerror(errno)); // This should never happen
throw UnexpectedError(strerror(errno)); // LCOV_EXCL_LINE
} }
::listen(socket_internals.m_socket_fd, 1); ::listen(socket_internals.m_socket_fd, 1);
...@@ -140,7 +149,8 @@ acceptClientSocket(const Socket& server) ...@@ -140,7 +149,8 @@ acceptClientSocket(const Socket& server)
reinterpret_cast<sockaddr*>(&socket_internals.m_address), &address_lenght); reinterpret_cast<sockaddr*>(&socket_internals.m_address), &address_lenght);
if (socket_internals.m_socket_fd < 0) { if (socket_internals.m_socket_fd < 0) {
throw NormalError(strerror(errno)); // This should never happen
throw UnexpectedError(strerror(errno)); // LCOV_EXCL_LINE
} }
return Socket{p_socket_internals}; return Socket{p_socket_internals};
...@@ -154,7 +164,8 @@ connectServerSocket(const std::string& server_name, int port_number) ...@@ -154,7 +164,8 @@ connectServerSocket(const std::string& server_name, int port_number)
socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0); socket_internals.m_socket_fd = ::socket(AF_INET, SOCK_STREAM, 0);
if (socket_internals.m_socket_fd < 0) { if (socket_internals.m_socket_fd < 0) {
throw NormalError(strerror(errno)); // This should never happen
throw UnexpectedError(strerror(errno)); // LCOV_EXCL_LINE
} }
hostent* server = ::gethostbyname(server_name.c_str()); hostent* server = ::gethostbyname(server_name.c_str());
...@@ -182,11 +193,12 @@ void ...@@ -182,11 +193,12 @@ void
Socket::_write(const char* data, const size_t lenght) const Socket::_write(const char* data, const size_t lenght) const
{ {
if (this->m_internals->isServerSocket()) { if (this->m_internals->isServerSocket()) {
throw NormalError("Server cannot write to server socket!"); throw NormalError("Server cannot write to server socket");
} }
if (::write(this->m_internals->fileDescriptor(), data, lenght) < 0) { if (::write(this->m_internals->fileDescriptor(), data, lenght) < 0) {
throw NormalError(strerror(errno)); // Quite complex to test
throw NormalError(strerror(errno)); // LCOV_EXCL_LINE
} }
} }
...@@ -194,15 +206,15 @@ void ...@@ -194,15 +206,15 @@ void
Socket::_read(char* data, const size_t length) const Socket::_read(char* data, const size_t length) const
{ {
if (this->m_internals->isServerSocket()) { if (this->m_internals->isServerSocket()) {
throw NormalError("Server cannot read from server socket!"); throw NormalError("Server cannot read from server socket");
} }
size_t received = 0; size_t received = 0;
do { while (received < length) {
int n = ::read(this->m_internals->fileDescriptor(), reinterpret_cast<char*>(data) + received, length - received); int n = ::read(this->m_internals->fileDescriptor(), reinterpret_cast<char*>(data) + received, length - received);
if (n <= 0) { if (n <= 0) {
throw NormalError("Could not read data"); throw NormalError("Could not read data");
} }
received += n; received += n;
} while (received < length); }
} }
...@@ -78,7 +78,9 @@ read(const Socket& socket, ArrayT<T, R...>& array) ...@@ -78,7 +78,9 @@ read(const Socket& socket, ArrayT<T, R...>& array)
{ {
static_assert(not std::is_const_v<T>, "cannot read values into const data"); static_assert(not std::is_const_v<T>, "cannot read values into const data");
static_assert(std::is_arithmetic_v<T> or is_tiny_vector_v<T> or is_tiny_matrix_v<T>, "unexpected value type"); static_assert(std::is_arithmetic_v<T> or is_tiny_vector_v<T> or is_tiny_matrix_v<T>, "unexpected value type");
if (array.size() > 0) {
socket._read(reinterpret_cast<char*>(&array[0]), array.size() * sizeof(T) / sizeof(char)); socket._read(reinterpret_cast<char*>(&array[0]), array.size() * sizeof(T) / sizeof(char));
} }
}
#endif // SOCKET_HPP #endif // SOCKET_HPP
...@@ -120,6 +120,7 @@ add_executable (unit_tests ...@@ -120,6 +120,7 @@ add_executable (unit_tests
test_RevisionInfo.cpp test_RevisionInfo.cpp
test_SmallArray.cpp test_SmallArray.cpp
test_SmallVector.cpp test_SmallVector.cpp
test_Socket.cpp
test_SquareGaussQuadrature.cpp test_SquareGaussQuadrature.cpp
test_SquareTransformation.cpp test_SquareTransformation.cpp
test_SymbolTable.cpp test_SymbolTable.cpp
......
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>
#include <utils/PugsAssert.hpp>
#include <utils/Socket.hpp>
#include <sstream>
#include <thread>
#include <netdb.h>
// clazy:excludeall=non-pod-global-static
TEST_CASE("Socket", "[utils]")
{
SECTION("create/connect and simple read/write")
{
auto self_client = [](int port) {
Socket self_server = connectServerSocket("localhost", port);
std::vector<int> values = {1, 2, 4, 7};
write(self_server, values.size());
write(self_server, values);
};
Socket server = createServerSocket(0);
std::thread t(self_client, server.portNumber());
Socket client = acceptClientSocket(server);
size_t vector_size = [&] {
size_t vector_size;
read(client, vector_size);
return vector_size;
}();
REQUIRE(vector_size == 4);
std::vector<int> v(vector_size);
read(client, v);
REQUIRE(v == std::vector<int>{1, 2, 4, 7});
t.join();
}
SECTION("move constructor")
{
Socket server = createServerSocket(0);
int port_number = server.portNumber();
Socket moved_server(std::move(server));
REQUIRE(port_number == moved_server.portNumber());
}
SECTION("host and port info")
{
Socket server = createServerSocket(0);
int port_number = server.portNumber();
std::ostringstream info;
info << server;
std::ostringstream expected;
char hbuf[NI_MAXHOST];
::gethostname(hbuf, NI_MAXHOST);
expected << hbuf << ':' << port_number;
REQUIRE(info.str() == expected.str());
}
SECTION("errors")
{
SECTION("connection")
{
REQUIRE_THROWS_WITH(createServerSocket(1), "error: Permission denied");
REQUIRE_THROWS_WITH(connectServerSocket("localhost", 1), "error: Connection refused");
// The error message is not checked since it can depend on the
// network connection itself
REQUIRE_THROWS(connectServerSocket("an invalid host name", 1));
}
SECTION("server <-> server")
{
Socket server = createServerSocket(0);
REQUIRE_THROWS_WITH(write(server, 12), "error: Server cannot write to server socket");
REQUIRE_THROWS_WITH(write(server, std::vector<int>{1, 2, 3}), "error: Server cannot write to server socket");
double x;
REQUIRE_THROWS_WITH(read(server, x), "error: Server cannot read from server socket");
std::vector<double> v(1);
REQUIRE_THROWS_WITH(read(server, v), "error: Server cannot read from server socket");
}
SECTION("invalid read")
{
auto self_client = [](int port) { Socket self_server = connectServerSocket("localhost", port); };
Socket server = createServerSocket(0);
std::thread t(self_client, server.portNumber());
Socket client = acceptClientSocket(server);
t.join();
double x;
REQUIRE_THROWS_WITH(read(client, x), "error: Could not read data");
// One cannot test easily write errors...
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment