diff --git a/tests/test_Messenger.cpp b/tests/test_Messenger.cpp index eebd5dc939c23c7875c3b92ff697cb1f00afca63..4f76e250ddab60d81e1d49eacb1bcfc23425e317 100644 --- a/tests/test_Messenger.cpp +++ b/tests/test_Messenger.cpp @@ -1,6 +1,7 @@ #include <catch2/catch_test_macros.hpp> #include <catch2/matchers/catch_matchers_all.hpp> +#include <algebra/TinyVector.hpp> #include <utils/Array.hpp> #include <utils/Messenger.hpp> @@ -143,6 +144,14 @@ TEST_CASE("Messenger", "[mpi]") const bool or_value_2 = parallel::allReduceOr(parallel::rank() > 0); REQUIRE(or_value_2 == (parallel::size() > 1)); + + const size_t sum_value = parallel::allReduceSum(parallel::rank() + 1); + REQUIRE(sum_value == parallel::size() * (parallel::size() + 1) / 2); + + const TinyVector<2, size_t> sum_tiny_vector = + parallel::allReduceSum(TinyVector<2, size_t>(parallel::rank() + 1, 1)); + REQUIRE( + (sum_tiny_vector == TinyVector<2, size_t>{parallel::size() * (parallel::size() + 1) / 2, parallel::size()})); } SECTION("all to all")