From 743a0d8f7152e61422a94179a946cfe06ef3ca1c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Mon, 25 Jul 2022 00:05:31 +0200
Subject: [PATCH] Change thread allocation policy for tests to avoid
 oversubscription

Oversubscription can degrade parallel performances substantially
---
 tests/mpi_test_main.cpp | 10 +++++++++-
 tests/test_main.cpp     |  8 +++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/tests/mpi_test_main.cpp b/tests/mpi_test_main.cpp
index 51b222c08..10314a4dc 100644
--- a/tests/mpi_test_main.cpp
+++ b/tests/mpi_test_main.cpp
@@ -18,12 +18,17 @@
 
 #include <cstdlib>
 #include <filesystem>
+#include <thread>
 
 int
 main(int argc, char* argv[])
 {
   parallel::Messenger::create(argc, argv);
-  Kokkos::initialize({4, -1, -1, true});
+
+  const int nb_max_threads = std::max(std::thread::hardware_concurrency(), 1u);
+  const int nb_threads     = std::max(nb_max_threads / parallel::Messenger::getInstance().size(), 1ul);
+
+  Kokkos::initialize({nb_threads, -1, -1, true});
 
   PETScWrapper::initialize(argc, argv);
 
@@ -56,6 +61,9 @@ main(int argc, char* argv[])
         session.useConfigData(data);
       }
 
+      std::cout << "Using " << nb_threads << " threads per process [" << nb_threads << "x"
+                << parallel::Messenger::getInstance().size() << "]\n";
+
       // Disable outputs from tested classes to the standard output
       std::cout.setstate(std::ios::badbit);
 
diff --git a/tests/test_main.cpp b/tests/test_main.cpp
index 91d9a080a..3eff2f311 100644
--- a/tests/test_main.cpp
+++ b/tests/test_main.cpp
@@ -15,11 +15,15 @@
 
 #include <MeshDataBaseForTests.hpp>
 
+#include <thread>
+
 int
 main(int argc, char* argv[])
 {
   parallel::Messenger::create(argc, argv);
-  Kokkos::initialize({4, -1, -1, true});
+  const int nb_threads = std::max(std::thread::hardware_concurrency(), 1u);
+
+  Kokkos::initialize({nb_threads, -1, -1, true});
 
   PETScWrapper::initialize(argc, argv);
   SLEPcWrapper::initialize(argc, argv);
@@ -32,6 +36,8 @@ main(int argc, char* argv[])
     if (config.listReporters() or config.listTags() or config.listTests()) {
       result = session.run();
     } else {
+      std::cout << "Using " << nb_threads << " threads\n";
+
       // Disable outputs from tested classes to the standard output
       std::cout.setstate(std::ios::badbit);
 
-- 
GitLab