diff --git a/tests/mpi_test_Messenger.cpp b/tests/mpi_test_Messenger.cpp index 47bf8d50e6afd1ed4b65a1577c63f3d74d50bb13..d422cf20fa41c24908d477cfca87cf598400a61c 100644 --- a/tests/mpi_test_Messenger.cpp +++ b/tests/mpi_test_Messenger.cpp @@ -14,6 +14,27 @@ namespace mpi_check { +struct integer +{ + int m_int; + operator int&() {return m_int;} + operator const int&() const {return m_int;} + integer& operator=(const int& i) {m_int = i; return *this;} +}; + +struct tri_int +{ + int m_int_0; + int m_int_1; + int m_int_2; + bool operator==(const tri_int& t) const { + return ((m_int_0 == t.m_int_0) and + (m_int_1 == t.m_int_1) and + (m_int_2 == t.m_int_2)); + } +}; + + template <typename T> void test_allToAll() { @@ -27,6 +48,37 @@ void test_allToAll() REQUIRE(exchanged_array[i] == i); } } + +template <> +void test_allToAll<bool>() +{ + Array<bool> data_array(commSize()); + for (size_t i=0; i< data_array.size(); ++i) { + data_array[i] = ((commRank()%2)==0); + } + auto exchanged_array = allToAll(data_array); + + for (size_t i=0; i< data_array.size(); ++i) { + REQUIRE(exchanged_array[i] == ((i%2)==0)); + } +} + +template <> +void test_allToAll<tri_int>() +{ + Array<tri_int> data_array(commSize()); + for (size_t i=0; i< data_array.size(); ++i) { + const int val = 1+commRank(); + data_array[i] = tri_int{val, 2*val, val+3 }; + } + auto exchanged_array = allToAll(data_array); + + for (size_t i=0; i< data_array.size(); ++i) { + const int val = 1+i; + REQUIRE(exchanged_array[i] == tri_int{val, 2*val, val+3 }); + } +} + } TEST_CASE("Messenger", "[mpi]") { @@ -42,6 +94,34 @@ TEST_CASE("Messenger", "[mpi]") { } SECTION("allToAll") { + // chars + mpi_check::test_allToAll<char>(); + mpi_check::test_allToAll<wchar_t>(); + + // integers mpi_check::test_allToAll<int8_t>(); + mpi_check::test_allToAll<int16_t>(); + mpi_check::test_allToAll<int32_t>(); + mpi_check::test_allToAll<int64_t>(); + mpi_check::test_allToAll<uint8_t>(); + mpi_check::test_allToAll<uint16_t>(); + mpi_check::test_allToAll<uint32_t>(); + mpi_check::test_allToAll<uint64_t>(); + mpi_check::test_allToAll<signed long long int>(); + mpi_check::test_allToAll<unsigned long long int>(); + + // floats + mpi_check::test_allToAll<float>(); + mpi_check::test_allToAll<double>(); + mpi_check::test_allToAll<long double>(); + + // bools + mpi_check::test_allToAll<bool>(); + + // trivial simple type + mpi_check::test_allToAll<mpi_check::integer>(); + + // compound trivial type + mpi_check::test_allToAll<mpi_check::tri_int>(); } }