diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index e5e9668c27776aaf0497040021b49516f7cec8a9..b4de28efb6263867adeaa03147def6c25aa5dc21 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -34,8 +34,10 @@ Messenger(int& argc, char* argv[]) MPI_Comm_size(MPI_COMM_WORLD, &m_size); if (m_rank != 0) { + // LCOV_EXCL_START pout.setOutput(null_stream); perr.setOutput(null_stream); + // LCOV_EXCL_STOP } #endif // PASTIS_HAS_MPI } diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index b27a952c62992b373e9effe306a9f8de7f3acfe3..bc669cfd4153311480230859b69afcd5ed6376ea 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -228,8 +228,10 @@ class Messenger std::vector<MPI_Status> status_list(request_list.size()); if (MPI_SUCCESS != MPI_Waitall(request_list.size(), &(request_list[0]), &(status_list[0]))) { + // LCOV_EXCL_START std::cerr << "Communication error!\n"; std::exit(1); + // LCOV_EXCL_STOP } #else // PASTIS_HAS_MPI @@ -452,7 +454,7 @@ class Messenger template <typename SendDataType, typename RecvDataType> PASTIS_INLINE - void exchange(const std::vector<Array<SendDataType>>& sent_array_list, + void exchange(const std::vector<Array<SendDataType>>& send_array_list, std::vector<Array<RecvDataType>>& recv_array_list) const { static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>, @@ -461,11 +463,26 @@ class Messenger "receive data type cannot be const"); using DataType = std::remove_const_t<SendDataType>; + Assert(send_array_list.size() == m_size); // LCOV_EXCL_LINE + Assert(recv_array_list.size() == m_size); // LCOV_EXCL_LINE +#ifndef NDEBUG + Array<size_t> send_size(m_size); + for (size_t i=0; i<m_size; ++i) { + send_size[i] = send_array_list[i].size(); + } + Array<size_t> recv_size = allToAll(send_size); + bool correct_sizes = true; + for (size_t i=0; i<m_size; ++i) { + correct_sizes &= (recv_size[i] == recv_array_list[i].size()); + } + Assert(correct_sizes); // LCOV_EXCL_LINE +#endif // NDEBUG + if constexpr(std::is_arithmetic_v<DataType>) { - _exchange(sent_array_list, recv_array_list); + _exchange(send_array_list, recv_array_list); } else if constexpr(std::is_trivial_v<DataType>) { using CastType = helper::split_cast_t<DataType>; - _exchange_through_cast<SendDataType, CastType>(sent_array_list, recv_array_list); + _exchange_through_cast<SendDataType, CastType>(send_array_list, recv_array_list); } else { static_assert(std::is_trivial_v<RecvDataType>, "unexpected non trivial type of data"); diff --git a/tests/mpi_test_Messenger.cpp b/tests/mpi_test_Messenger.cpp index 2837e0830a0ccc1e097651e82b1a7d50aea0e539..081ab8ef031c65cd31dac820ba60527713a4a569 100644 --- a/tests/mpi_test_Messenger.cpp +++ b/tests/mpi_test_Messenger.cpp @@ -134,11 +134,13 @@ TEST_CASE("Messenger", "[mpi]") { #ifndef NDEBUG SECTION("checking invalid all to all") { - Array<int> invalid_all_to_all(parallel::commSize()+1); - REQUIRE_THROWS_AS(parallel::allToAll(invalid_all_to_all), AssertError); + if (parallel::commSize() > 1) { + Array<int> invalid_all_to_all(parallel::commSize()+1); + REQUIRE_THROWS_AS(parallel::allToAll(invalid_all_to_all), AssertError); - Array<int> different_size_all_to_all(parallel::commSize()*(parallel::commRank()+1)); - REQUIRE_THROWS_AS(parallel::allToAll(different_size_all_to_all), AssertError); + Array<int> different_size_all_to_all(parallel::commSize()*(parallel::commRank()+1)); + REQUIRE_THROWS_AS(parallel::allToAll(different_size_all_to_all), AssertError); + } } #endif // NDEBUG } @@ -243,4 +245,182 @@ TEST_CASE("Messenger", "[mpi]") { } } + SECTION("all gather array") { + { + // simple type + Array<int> array(3); + for (size_t i=0; i<array.size(); ++i) { + array[i] = (3+parallel::commRank())*2+i; + } + Array<int> gather_array = parallel::allGather(array); + REQUIRE(gather_array.size() == array.size()*parallel::commSize()); + + for (size_t i=0; i<gather_array.size(); ++i) { + REQUIRE((gather_array[i] == (3+i/array.size())*2+(i%array.size()))); + } + } + + { + // trivial simple type + Array<mpi_check::integer> array(3); + for (size_t i=0; i<array.size(); ++i) { + array[i] = (3+parallel::commRank())*2+i; + } + Array<mpi_check::integer> gather_array = parallel::allGather(array); + REQUIRE(gather_array.size() == array.size()*parallel::commSize()); + + for (size_t i=0; i<gather_array.size(); ++i) { + REQUIRE((gather_array[i] == (3+i/array.size())*2+(i%array.size()))); + } + } + + { + // compound trivial type + Array<mpi_check::tri_int> array(3); + for (size_t i=0; i<array.size(); ++i) { + array[i] = mpi_check::tri_int{static_cast<int>((3+parallel::commRank())*2), + static_cast<int>(2+parallel::commRank()+i), + static_cast<int>(4-parallel::commRank()-i)}; + } + Array<mpi_check::tri_int> gather_array + = parallel::allGather(array); + + REQUIRE(gather_array.size() == array.size()*parallel::commSize()); + for (size_t i=0; i<gather_array.size(); ++i) { + mpi_check::tri_int expected_value{static_cast<int>((3+i/array.size())*2), + static_cast<int>(2+i/array.size()+(i%array.size())), + static_cast<int>(4-i/array.size()-(i%array.size()))}; + REQUIRE((gather_array[i] == expected_value)); + } + } + } + + SECTION("all array exchanges") { + { // simple type + std::vector<Array<const int>> send_array_list(parallel::commSize()); + for (size_t i=0; i<send_array_list.size(); ++i) { + Array<int> send_array(i+1); + for (size_t j=0; j<send_array.size(); ++j) { + send_array[j] = (parallel::commRank()+1)*j; + } + send_array_list[i] = send_array; + } + + std::vector<Array<int>> recv_array_list(parallel::commSize()); + for (size_t i=0; i<recv_array_list.size(); ++i) { + recv_array_list[i] = Array<int>(parallel::commRank()+1); + } + parallel::exchange(send_array_list, recv_array_list); + + for (size_t i=0; i<parallel::commSize(); ++i) { + const Array<const int> recv_array = recv_array_list[i]; + for (size_t j=0; j<recv_array.size(); ++j) { + REQUIRE(recv_array[j] == (i+1)*j); + } + } + } + + { // trivial simple type + std::vector<Array<mpi_check::integer>> send_array_list(parallel::commSize()); + for (size_t i=0; i<send_array_list.size(); ++i) { + Array<mpi_check::integer> send_array(i+1); + for (size_t j=0; j<send_array.size(); ++j) { + send_array[j] = static_cast<int>((parallel::commRank()+1)*j); + } + send_array_list[i] = send_array; + } + + std::vector<Array<mpi_check::integer>> recv_array_list(parallel::commSize()); + for (size_t i=0; i<recv_array_list.size(); ++i) { + recv_array_list[i] = Array<mpi_check::integer>(parallel::commRank()+1); + } + parallel::exchange(send_array_list, recv_array_list); + + for (size_t i=0; i<parallel::commSize(); ++i) { + const Array<const mpi_check::integer> recv_array = recv_array_list[i]; + for (size_t j=0; j<recv_array.size(); ++j) { + REQUIRE(recv_array[j] == (i+1)*j); + } + } + } + + { + // compound trivial type + std::vector<Array<mpi_check::tri_int>> send_array_list(parallel::commSize()); + for (size_t i=0; i<send_array_list.size(); ++i) { + Array<mpi_check::tri_int> send_array(i+1); + for (size_t j=0; j<send_array.size(); ++j) { + send_array[j] = mpi_check::tri_int{static_cast<int>((parallel::commRank()+1)*j), + static_cast<int>(parallel::commRank()), + static_cast<int>(j)}; + } + send_array_list[i] = send_array; + } + + std::vector<Array<mpi_check::tri_int>> recv_array_list(parallel::commSize()); + for (size_t i=0; i<recv_array_list.size(); ++i) { + recv_array_list[i] = Array<mpi_check::tri_int>(parallel::commRank()+1); + } + parallel::exchange(send_array_list, recv_array_list); + + for (size_t i=0; i<parallel::commSize(); ++i) { + const Array<const mpi_check::tri_int> recv_array = recv_array_list[i]; + for (size_t j=0; j<recv_array.size(); ++j) { + mpi_check::tri_int expected_value{static_cast<int>((i+1)*j), + static_cast<int>(i), + static_cast<int>(j)}; + REQUIRE((recv_array[j] == expected_value)); + } + } + } + } + +#ifndef NDEBUG + SECTION("checking all array exchange invalid sizes") { + std::vector<Array<const int>> send_array_list(parallel::commSize()); + for (size_t i=0; i<send_array_list.size(); ++i) { + Array<int> send_array(i+1); + send_array.fill(parallel::commRank()); + send_array_list[i] = send_array; + } + + std::vector<Array<int>> recv_array_list(parallel::commSize()); + REQUIRE_THROWS_AS(parallel::exchange(send_array_list, recv_array_list), AssertError); + } +#endif // NDEBUG + + + SECTION("checking barrier") { + for (size_t i=0; i<parallel::commSize(); ++i) { + if (i==parallel::commRank()) { + std::ofstream file; + if (i==0) { + file.open("barrier_test", std::ios_base::out); + } else { + file.open("barrier_test", std::ios_base::app); + } + file << i << "\n" << std::flush; + } + parallel::barrier(); + } + + { // reading produced file + std::ifstream file("barrier_test"); + std::vector<size_t> number_list; + while (file) { + size_t value; + file >> value; + if (file) { + number_list.push_back(value); + } + } + REQUIRE(number_list.size() == parallel::commSize()); + for (size_t i=0; i<number_list.size(); ++i) { + REQUIRE(number_list[i] == i); + } + } + parallel::barrier(); + + std::remove("barrier_test"); + } }