From 8cbb6ed25cc9dabd4007654002a579e0f6ba2d24 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Wed, 24 Oct 2018 11:20:51 +0200
Subject: [PATCH] Add all gather and reduce min/max tests

---
 tests/mpi_test_Messenger.cpp | 56 ++++++++++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/tests/mpi_test_Messenger.cpp b/tests/mpi_test_Messenger.cpp
index 88a8122d0..66e247d99 100644
--- a/tests/mpi_test_Messenger.cpp
+++ b/tests/mpi_test_Messenger.cpp
@@ -93,6 +93,14 @@ TEST_CASE("Messenger", "[mpi]") {
     REQUIRE(size == commSize());
   }
 
+  SECTION("reduction") {
+    const int min_value = allReduceMin(commRank()+3);
+    REQUIRE(min_value ==3);
+
+    const int max_value = allReduceMax(commRank()+3);
+    REQUIRE(max_value == ((commSize()-1) + 3));
+  }
+
   SECTION("all to all") {
     // chars
     mpi_check::test_allToAll<char>();
@@ -123,6 +131,16 @@ TEST_CASE("Messenger", "[mpi]") {
 
     // compound trivial type
     mpi_check::test_allToAll<mpi_check::tri_int>();
+
+#ifndef NDEBUG
+    SECTION("checking invalid all to all") {
+      Array<int> invalid_all_to_all(commSize()+1);
+      REQUIRE_THROWS_AS(allToAll(invalid_all_to_all), AssertError);
+
+      Array<int> different_size_all_to_all(commSize()*(commRank()+1));
+      REQUIRE_THROWS_AS(allToAll(different_size_all_to_all), AssertError);
+    }
+#endif // NDEBUG
   }
 
   SECTION("broadcast value") {
@@ -182,4 +200,42 @@ TEST_CASE("Messenger", "[mpi]") {
     }
   }
 
+  SECTION("all gather value") {
+    {
+      // simple type
+      int value{(3+commRank())*2};
+      Array<int> gather_array = allGather(value);
+      REQUIRE(gather_array.size() == commSize());
+
+      for (size_t i=0; i<gather_array.size(); ++i) {
+        REQUIRE((gather_array[i] == (3+i)*2));
+      }
+    }
+
+    {
+      // trivial simple type
+      mpi_check::integer value{(3+commRank())*2};
+      Array<mpi_check::integer> gather_array = allGather(value);
+      REQUIRE(gather_array.size() == commSize());
+
+      for (size_t i=0; i<gather_array.size(); ++i) {
+        REQUIRE((gather_array[i] == (3+i)*2));
+      }
+    }
+
+    {
+      // compound trivial type
+      mpi_check::tri_int value{(3+commRank())*2, 2+commRank(), 4-commRank()};
+      Array<mpi_check::tri_int> gather_array = allGather(value);
+      REQUIRE(gather_array.size() == commSize());
+
+      for (size_t i=0; i<gather_array.size(); ++i) {
+        mpi_check::tri_int expected_value{static_cast<int>((3+i)*2),
+                                          static_cast<int>(2+i),
+                                          static_cast<int>(4-i)};
+        REQUIRE((gather_array[i] == expected_value));
+      }
+    }
+  }
+
 }
-- 
GitLab