From 8cbb6ed25cc9dabd4007654002a579e0f6ba2d24 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Wed, 24 Oct 2018 11:20:51 +0200 Subject: [PATCH] Add all gather and reduce min/max tests --- tests/mpi_test_Messenger.cpp | 56 ++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/mpi_test_Messenger.cpp b/tests/mpi_test_Messenger.cpp index 88a8122d0..66e247d99 100644 --- a/tests/mpi_test_Messenger.cpp +++ b/tests/mpi_test_Messenger.cpp @@ -93,6 +93,14 @@ TEST_CASE("Messenger", "[mpi]") { REQUIRE(size == commSize()); } + SECTION("reduction") { + const int min_value = allReduceMin(commRank()+3); + REQUIRE(min_value ==3); + + const int max_value = allReduceMax(commRank()+3); + REQUIRE(max_value == ((commSize()-1) + 3)); + } + SECTION("all to all") { // chars mpi_check::test_allToAll<char>(); @@ -123,6 +131,16 @@ TEST_CASE("Messenger", "[mpi]") { // compound trivial type mpi_check::test_allToAll<mpi_check::tri_int>(); + +#ifndef NDEBUG + SECTION("checking invalid all to all") { + Array<int> invalid_all_to_all(commSize()+1); + REQUIRE_THROWS_AS(allToAll(invalid_all_to_all), AssertError); + + Array<int> different_size_all_to_all(commSize()*(commRank()+1)); + REQUIRE_THROWS_AS(allToAll(different_size_all_to_all), AssertError); + } +#endif // NDEBUG } SECTION("broadcast value") { @@ -182,4 +200,42 @@ TEST_CASE("Messenger", "[mpi]") { } } + SECTION("all gather value") { + { + // simple type + int value{(3+commRank())*2}; + Array<int> gather_array = allGather(value); + REQUIRE(gather_array.size() == commSize()); + + for (size_t i=0; i<gather_array.size(); ++i) { + REQUIRE((gather_array[i] == (3+i)*2)); + } + } + + { + // trivial simple type + mpi_check::integer value{(3+commRank())*2}; + Array<mpi_check::integer> gather_array = allGather(value); + REQUIRE(gather_array.size() == commSize()); + + for (size_t i=0; i<gather_array.size(); ++i) { + REQUIRE((gather_array[i] == (3+i)*2)); + } + } + + { + // compound trivial type + mpi_check::tri_int value{(3+commRank())*2, 2+commRank(), 4-commRank()}; + Array<mpi_check::tri_int> gather_array = allGather(value); + REQUIRE(gather_array.size() == commSize()); + + for (size_t i=0; i<gather_array.size(); ++i) { + mpi_check::tri_int expected_value{static_cast<int>((3+i)*2), + static_cast<int>(2+i), + static_cast<int>(4-i)}; + REQUIRE((gather_array[i] == expected_value)); + } + } + } + } -- GitLab