diff --git a/tests/mpi_test_Messenger.cpp b/tests/mpi_test_Messenger.cpp index 88a8122d077518e250642d4a65cf6850ebf8f119..66e247d998c6a23f36bfc70eba03979be7e915e8 100644 --- a/tests/mpi_test_Messenger.cpp +++ b/tests/mpi_test_Messenger.cpp @@ -93,6 +93,14 @@ TEST_CASE("Messenger", "[mpi]") { REQUIRE(size == commSize()); } + SECTION("reduction") { + const int min_value = allReduceMin(commRank()+3); + REQUIRE(min_value ==3); + + const int max_value = allReduceMax(commRank()+3); + REQUIRE(max_value == ((commSize()-1) + 3)); + } + SECTION("all to all") { // chars mpi_check::test_allToAll<char>(); @@ -123,6 +131,16 @@ TEST_CASE("Messenger", "[mpi]") { // compound trivial type mpi_check::test_allToAll<mpi_check::tri_int>(); + +#ifndef NDEBUG + SECTION("checking invalid all to all") { + Array<int> invalid_all_to_all(commSize()+1); + REQUIRE_THROWS_AS(allToAll(invalid_all_to_all), AssertError); + + Array<int> different_size_all_to_all(commSize()*(commRank()+1)); + REQUIRE_THROWS_AS(allToAll(different_size_all_to_all), AssertError); + } +#endif // NDEBUG } SECTION("broadcast value") { @@ -182,4 +200,42 @@ TEST_CASE("Messenger", "[mpi]") { } } + SECTION("all gather value") { + { + // simple type + int value{(3+commRank())*2}; + Array<int> gather_array = allGather(value); + REQUIRE(gather_array.size() == commSize()); + + for (size_t i=0; i<gather_array.size(); ++i) { + REQUIRE((gather_array[i] == (3+i)*2)); + } + } + + { + // trivial simple type + mpi_check::integer value{(3+commRank())*2}; + Array<mpi_check::integer> gather_array = allGather(value); + REQUIRE(gather_array.size() == commSize()); + + for (size_t i=0; i<gather_array.size(); ++i) { + REQUIRE((gather_array[i] == (3+i)*2)); + } + } + + { + // compound trivial type + mpi_check::tri_int value{(3+commRank())*2, 2+commRank(), 4-commRank()}; + Array<mpi_check::tri_int> gather_array = allGather(value); + REQUIRE(gather_array.size() == commSize()); + + for (size_t i=0; i<gather_array.size(); ++i) { + mpi_check::tri_int expected_value{static_cast<int>((3+i)*2), + static_cast<int>(2+i), + static_cast<int>(4-i)}; + REQUIRE((gather_array[i] == expected_value)); + } + } + } + }