diff --git a/tests/mpi_test_Messenger.cpp b/tests/mpi_test_Messenger.cpp index d39c18c19426d88620d1a3724970845b9f904279..47bf8d50e6afd1ed4b65a1577c63f3d74d50bb13 100644 --- a/tests/mpi_test_Messenger.cpp +++ b/tests/mpi_test_Messenger.cpp @@ -1,19 +1,47 @@ #include <catch.hpp> #include <Messenger.hpp> +#include <Array.hpp> + +#include <pastis_config.hpp> + +#ifdef PASTIS_HAS_MPI #include <mpi.h> +#define IF_MPI(INSTRUCTION) INSTRUCTION +#else +#define IF_MPI(INSTRUCTION) +#endif // PASTIS_HAS_MPI + +namespace mpi_check +{ +template <typename T> +void test_allToAll() +{ + Array<T> data_array(commSize()); + for (size_t i=0; i< data_array.size(); ++i) { + data_array[i] = commRank(); + } + auto exchanged_array = allToAll(data_array); + + for (size_t i=0; i< data_array.size(); ++i) { + REQUIRE(exchanged_array[i] == i); + } +} +} TEST_CASE("Messenger", "[mpi]") { - SECTION("check for parallel test") { - { - int rank; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - REQUIRE(rank == commRank()); + SECTION("communication info") { + int rank=0; + IF_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); + REQUIRE(rank == commRank()); + + int size=1; + IF_MPI(MPI_Comm_size(MPI_COMM_WORLD, &size)); + REQUIRE(size == commSize()); + } - int size; - MPI_Comm_size(MPI_COMM_WORLD, &size); - REQUIRE(size == commSize()); - } + SECTION("allToAll") { + mpi_check::test_allToAll<int8_t>(); } }