Skip to content
Snippets Groups Projects
Commit 20e5e2a7 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add various tests

- all gather for arrays
- barrier
- add assertions in array exchanges
parent ee9030ec
No related branches found
No related tags found
1 merge request!11Feature/mpi
......@@ -34,8 +34,10 @@ Messenger(int& argc, char* argv[])
MPI_Comm_size(MPI_COMM_WORLD, &m_size);
if (m_rank != 0) {
// LCOV_EXCL_START
pout.setOutput(null_stream);
perr.setOutput(null_stream);
// LCOV_EXCL_STOP
}
#endif // PASTIS_HAS_MPI
}
......
......@@ -228,8 +228,10 @@ class Messenger
std::vector<MPI_Status> status_list(request_list.size());
if (MPI_SUCCESS != MPI_Waitall(request_list.size(), &(request_list[0]), &(status_list[0]))) {
// LCOV_EXCL_START
std::cerr << "Communication error!\n";
std::exit(1);
// LCOV_EXCL_STOP
}
#else // PASTIS_HAS_MPI
......@@ -452,7 +454,7 @@ class Messenger
template <typename SendDataType,
typename RecvDataType>
PASTIS_INLINE
void exchange(const std::vector<Array<SendDataType>>& sent_array_list,
void exchange(const std::vector<Array<SendDataType>>& send_array_list,
std::vector<Array<RecvDataType>>& recv_array_list) const
{
static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>,
......@@ -461,11 +463,26 @@ class Messenger
"receive data type cannot be const");
using DataType = std::remove_const_t<SendDataType>;
Assert(send_array_list.size() == m_size); // LCOV_EXCL_LINE
Assert(recv_array_list.size() == m_size); // LCOV_EXCL_LINE
#ifndef NDEBUG
Array<size_t> send_size(m_size);
for (size_t i=0; i<m_size; ++i) {
send_size[i] = send_array_list[i].size();
}
Array<size_t> recv_size = allToAll(send_size);
bool correct_sizes = true;
for (size_t i=0; i<m_size; ++i) {
correct_sizes &= (recv_size[i] == recv_array_list[i].size());
}
Assert(correct_sizes); // LCOV_EXCL_LINE
#endif // NDEBUG
if constexpr(std::is_arithmetic_v<DataType>) {
_exchange(sent_array_list, recv_array_list);
_exchange(send_array_list, recv_array_list);
} else if constexpr(std::is_trivial_v<DataType>) {
using CastType = helper::split_cast_t<DataType>;
_exchange_through_cast<SendDataType, CastType>(sent_array_list, recv_array_list);
_exchange_through_cast<SendDataType, CastType>(send_array_list, recv_array_list);
} else {
static_assert(std::is_trivial_v<RecvDataType>,
"unexpected non trivial type of data");
......
......@@ -134,12 +134,14 @@ TEST_CASE("Messenger", "[mpi]") {
#ifndef NDEBUG
SECTION("checking invalid all to all") {
if (parallel::commSize() > 1) {
Array<int> invalid_all_to_all(parallel::commSize()+1);
REQUIRE_THROWS_AS(parallel::allToAll(invalid_all_to_all), AssertError);
Array<int> different_size_all_to_all(parallel::commSize()*(parallel::commRank()+1));
REQUIRE_THROWS_AS(parallel::allToAll(different_size_all_to_all), AssertError);
}
}
#endif // NDEBUG
}
......@@ -243,4 +245,182 @@ TEST_CASE("Messenger", "[mpi]") {
}
}
SECTION("all gather array") {
{
// simple type
Array<int> array(3);
for (size_t i=0; i<array.size(); ++i) {
array[i] = (3+parallel::commRank())*2+i;
}
Array<int> gather_array = parallel::allGather(array);
REQUIRE(gather_array.size() == array.size()*parallel::commSize());
for (size_t i=0; i<gather_array.size(); ++i) {
REQUIRE((gather_array[i] == (3+i/array.size())*2+(i%array.size())));
}
}
{
// trivial simple type
Array<mpi_check::integer> array(3);
for (size_t i=0; i<array.size(); ++i) {
array[i] = (3+parallel::commRank())*2+i;
}
Array<mpi_check::integer> gather_array = parallel::allGather(array);
REQUIRE(gather_array.size() == array.size()*parallel::commSize());
for (size_t i=0; i<gather_array.size(); ++i) {
REQUIRE((gather_array[i] == (3+i/array.size())*2+(i%array.size())));
}
}
{
// compound trivial type
Array<mpi_check::tri_int> array(3);
for (size_t i=0; i<array.size(); ++i) {
array[i] = mpi_check::tri_int{static_cast<int>((3+parallel::commRank())*2),
static_cast<int>(2+parallel::commRank()+i),
static_cast<int>(4-parallel::commRank()-i)};
}
Array<mpi_check::tri_int> gather_array
= parallel::allGather(array);
REQUIRE(gather_array.size() == array.size()*parallel::commSize());
for (size_t i=0; i<gather_array.size(); ++i) {
mpi_check::tri_int expected_value{static_cast<int>((3+i/array.size())*2),
static_cast<int>(2+i/array.size()+(i%array.size())),
static_cast<int>(4-i/array.size()-(i%array.size()))};
REQUIRE((gather_array[i] == expected_value));
}
}
}
SECTION("all array exchanges") {
{ // simple type
std::vector<Array<const int>> send_array_list(parallel::commSize());
for (size_t i=0; i<send_array_list.size(); ++i) {
Array<int> send_array(i+1);
for (size_t j=0; j<send_array.size(); ++j) {
send_array[j] = (parallel::commRank()+1)*j;
}
send_array_list[i] = send_array;
}
std::vector<Array<int>> recv_array_list(parallel::commSize());
for (size_t i=0; i<recv_array_list.size(); ++i) {
recv_array_list[i] = Array<int>(parallel::commRank()+1);
}
parallel::exchange(send_array_list, recv_array_list);
for (size_t i=0; i<parallel::commSize(); ++i) {
const Array<const int> recv_array = recv_array_list[i];
for (size_t j=0; j<recv_array.size(); ++j) {
REQUIRE(recv_array[j] == (i+1)*j);
}
}
}
{ // trivial simple type
std::vector<Array<mpi_check::integer>> send_array_list(parallel::commSize());
for (size_t i=0; i<send_array_list.size(); ++i) {
Array<mpi_check::integer> send_array(i+1);
for (size_t j=0; j<send_array.size(); ++j) {
send_array[j] = static_cast<int>((parallel::commRank()+1)*j);
}
send_array_list[i] = send_array;
}
std::vector<Array<mpi_check::integer>> recv_array_list(parallel::commSize());
for (size_t i=0; i<recv_array_list.size(); ++i) {
recv_array_list[i] = Array<mpi_check::integer>(parallel::commRank()+1);
}
parallel::exchange(send_array_list, recv_array_list);
for (size_t i=0; i<parallel::commSize(); ++i) {
const Array<const mpi_check::integer> recv_array = recv_array_list[i];
for (size_t j=0; j<recv_array.size(); ++j) {
REQUIRE(recv_array[j] == (i+1)*j);
}
}
}
{
// compound trivial type
std::vector<Array<mpi_check::tri_int>> send_array_list(parallel::commSize());
for (size_t i=0; i<send_array_list.size(); ++i) {
Array<mpi_check::tri_int> send_array(i+1);
for (size_t j=0; j<send_array.size(); ++j) {
send_array[j] = mpi_check::tri_int{static_cast<int>((parallel::commRank()+1)*j),
static_cast<int>(parallel::commRank()),
static_cast<int>(j)};
}
send_array_list[i] = send_array;
}
std::vector<Array<mpi_check::tri_int>> recv_array_list(parallel::commSize());
for (size_t i=0; i<recv_array_list.size(); ++i) {
recv_array_list[i] = Array<mpi_check::tri_int>(parallel::commRank()+1);
}
parallel::exchange(send_array_list, recv_array_list);
for (size_t i=0; i<parallel::commSize(); ++i) {
const Array<const mpi_check::tri_int> recv_array = recv_array_list[i];
for (size_t j=0; j<recv_array.size(); ++j) {
mpi_check::tri_int expected_value{static_cast<int>((i+1)*j),
static_cast<int>(i),
static_cast<int>(j)};
REQUIRE((recv_array[j] == expected_value));
}
}
}
}
#ifndef NDEBUG
SECTION("checking all array exchange invalid sizes") {
std::vector<Array<const int>> send_array_list(parallel::commSize());
for (size_t i=0; i<send_array_list.size(); ++i) {
Array<int> send_array(i+1);
send_array.fill(parallel::commRank());
send_array_list[i] = send_array;
}
std::vector<Array<int>> recv_array_list(parallel::commSize());
REQUIRE_THROWS_AS(parallel::exchange(send_array_list, recv_array_list), AssertError);
}
#endif // NDEBUG
SECTION("checking barrier") {
for (size_t i=0; i<parallel::commSize(); ++i) {
if (i==parallel::commRank()) {
std::ofstream file;
if (i==0) {
file.open("barrier_test", std::ios_base::out);
} else {
file.open("barrier_test", std::ios_base::app);
}
file << i << "\n" << std::flush;
}
parallel::barrier();
}
{ // reading produced file
std::ifstream file("barrier_test");
std::vector<size_t> number_list;
while (file) {
size_t value;
file >> value;
if (file) {
number_list.push_back(value);
}
}
REQUIRE(number_list.size() == parallel::commSize());
for (size_t i=0; i<number_list.size(); ++i) {
REQUIRE(number_list[i] == i);
}
}
parallel::barrier();
std::remove("barrier_test");
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment