From 67f8366526ad113eb5558d591280c1e093fcb558 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Wed, 17 Oct 2018 19:05:39 +0200
Subject: [PATCH] Write broadcast of any trivial type values

Remind that TinyVector, TinyMatrix, CellType,... are trivial types
---
 src/utils/Messenger.cpp |  9 ---------
 src/utils/Messenger.hpp | 41 +++++++++++++++++++++++++++++++++++------
 2 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp
index 936c81e20..f627f016b 100644
--- a/src/utils/Messenger.cpp
+++ b/src/utils/Messenger.cpp
@@ -66,12 +66,3 @@ _allGather(int& data) const
 
   return gather;
 }
-
-int Messenger::
-_broadcast_value(int& data, int root_rank) const
-{
-#ifdef PASTIS_HAS_MPI
-  MPI_Bcast(&data, 1,  MPI_INT, root_rank, MPI_COMM_WORLD);
-#endif // PASTIS_HAS_MPI
-  return data;
-}
diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp
index 0c56c2631..c316a4887 100644
--- a/src/utils/Messenger.hpp
+++ b/src/utils/Messenger.hpp
@@ -84,13 +84,26 @@ class Messenger
 
   Array<int> _allGather(int& data) const;
 
-  int _broadcast_value(int& data, int root_rank) const;
+  template <typename DataType>
+  void _broadcast_value(DataType& data, int root_rank) const
+  {
+#ifdef PASTIS_HAS_MPI
+    static_assert(not std::is_const_v<DataType>);
+    static_assert(std::is_arithmetic_v<DataType>);
+
+    MPI_Datatype mpi_datatype
+        = Messenger::helper::mpiType<DataType>();
+
+    MPI_Bcast(&data, 1,  mpi_datatype, root_rank, MPI_COMM_WORLD);
+#endif // PASTIS_HAS_MPI
+  }
 
   template <typename ArrayType>
   void _broadcast_array(ArrayType& array, int root_rank) const
   {
     using DataType = typename ArrayType::data_type;
     static_assert(not std::is_const_v<DataType>);
+    static_assert(std::is_arithmetic_v<DataType>);
 
 #ifdef PASTIS_HAS_MPI
     MPI_Datatype mpi_datatype
@@ -260,10 +273,26 @@ class Messenger
   PASTIS_INLINE
   void broadcast(DataType& data, int root_rank) const
   {
-    if constexpr(std::is_same<DataType, int>()) {
-      return _broadcast_value(data, root_rank);
+    static_assert(not std::is_const_v<DataType>,
+                  "cannot broadcast const data");
+    if constexpr(std::is_arithmetic_v<DataType>) {
+      _broadcast_value(data, root_rank);
+    } else if constexpr(std::is_trivial_v<DataType>) {
+      using CastType = helper::split_cast_t<DataType>;
+      if constexpr(sizeof(CastType) == sizeof(DataType)) {
+        CastType& cast_data = reinterpret_cast<CastType&>(data);
+        _broadcast_value(cast_data, root_rank);
+      } else {
+#ifdef PASTIS_HAS_MPI
+    MPI_Datatype mpi_datatype
+        = Messenger::helper::mpiType<CastType>();
+    MPI_Bcast(reinterpret_cast<CastType*>(&data), sizeof(DataType)/sizeof(CastType),
+              mpi_datatype, root_rank, MPI_COMM_WORLD);
+#endif // PASTIS_HAS_MPI
+      }
     } else {
-      static_assert(std::is_same<DataType, int>(), "unexpected type of data");
+      static_assert(std::is_trivial_v<DataType>,
+                    "unexpected non trivial type of data");
     }
   }
 
@@ -351,9 +380,9 @@ void barrier()
 
 template <typename DataType>
 PASTIS_INLINE
-DataType broadcast(const DataType& data, int root_rank)
+void broadcast(DataType& data, int root_rank)
 {
-  return messenger().broadcast(data, root_rank);
+  messenger().broadcast(data, root_rank);
 }
 
 template <typename DataType>
-- 
GitLab