Skip to content
Snippets Groups Projects
Select Git revision
  • db11c008a584aa4fe9b122b63a7a91abb77939bb
  • develop default protected
  • feature/variational-hydro
  • origin/stage/bouguettaia
  • feature/gmsh-reader
  • feature/reconstruction
  • save_clemence
  • feature/kinetic-schemes
  • feature/local-dt-fsi
  • feature/composite-scheme-sources
  • feature/composite-scheme-other-fluxes
  • feature/serraille
  • feature/composite-scheme
  • hyperplastic
  • feature/polynomials
  • feature/gks
  • feature/implicit-solver-o2
  • feature/coupling_module
  • feature/implicit-solver
  • feature/merge-local-dt-fsi
  • master protected
  • v0.5.0 protected
  • v0.4.1 protected
  • v0.4.0 protected
  • v0.3.0 protected
  • v0.2.0 protected
  • v0.1.0 protected
  • Kidder
  • v0.0.4 protected
  • v0.0.3 protected
  • v0.0.2 protected
  • v0 protected
  • v0.0.1 protected
33 results

StencilBuilder.cpp

Blame
  • Messenger.hpp 19.52 KiB
    #ifndef MESSENGER_HPP
    #define MESSENGER_HPP
    
    #include <PastisMacros.hpp>
    #include <PastisAssert.hpp>
    
    #include <Array.hpp>
    #include <CastArray.hpp>
    #include <ArrayUtils.hpp>
    
    #include <cstddef>
    #include <cstdint>
    #include <cstdlib>
    #include <iostream>
    #include <type_traits>
    #include <vector>
    
    #include <pastis_config.hpp>
    #ifdef PASTIS_HAS_MPI
    #include <mpi.h>
    #endif // PASTIS_HAS_MPI
    
    namespace parallel
    {
    
    /**
     * Process-wide communication facade (singleton).
     *
     * Provides all-gather, all-to-all, broadcast, min/max all-reduce and
     * pairwise exchange operations. When PASTIS_HAS_MPI is defined they are
     * implemented with MPI over MPI_COMM_WORLD; otherwise single-process
     * fallbacks (plain copies, or returning the input) are used.
     *
     * Arithmetic data is sent with its native MPI_Datatype (helper::mpiType);
     * other trivial types are sent as arrays of fixed-width integers obtained
     * through helper::split_cast and CastArray.
     */
    class Messenger
    {
     private:
      struct helper
      {
    #ifdef PASTIS_HAS_MPI
        // Maps an arithmetic C++ type to its MPI_Datatype. Supported types get
        // an explicit specialization after the class definition, and
        // const-qualified types are forwarded to their unqualified version.
        // Reaching the unqualified primary template is a compile-time error:
        // exactly one of the two static_asserts below fails, and its message
        // describes that particular case.
        template<typename DataType>
        static PASTIS_INLINE
        MPI_Datatype mpiType()
        {
          if constexpr (std::is_const_v<DataType>) {
            return mpiType<std::remove_const_t<DataType>>();
          } else {
            // Fails for non-arithmetic types.
            static_assert(std::is_arithmetic_v<DataType>,
                          "MPI_Datatype are only defined for arithmetic types!");
            // Fails for arithmetic types that have no explicit specialization.
            static_assert(not std::is_arithmetic_v<DataType>,
                          "Unexpected arithmetic type! Should not occur!");
            return MPI_Datatype();
          }
        }
    #endif // PASTIS_HAS_MPI
    
       private:
        // split_cast<T>::type is the widest of int64_t, int32_t, int16_t and
        // int8_t whose size divides sizeof(T), so a T can be viewed as
        // sizeof(T)/sizeof(type) integers. The enable_if conditions are
        // mutually exclusive and exhaustive: exactly one specialization
        // matches any complete type T.
        template <typename T,
                  typename Allowed = void>
        struct split_cast {};
    
        template <typename T>
        struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int64_t))>> {
          using type = int64_t;
          static_assert(not(sizeof(T) % sizeof(int64_t)));
        };
    
        template <typename T>
        struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int32_t))
                                             and(sizeof(T) % sizeof(int64_t))>> {
          using type = int32_t;
          static_assert(not(sizeof(T) % sizeof(int32_t)));
        };
    
        template <typename T>
        struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int16_t))
                                             and(sizeof(T) % sizeof(int32_t))
                                             and(sizeof(T) % sizeof(int64_t))>> {
          using type = int16_t;
          static_assert(not(sizeof(T) % sizeof(int16_t)));
        };
    
        template <typename T>
        struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int8_t))
                                             and(sizeof(T) % sizeof(int16_t))
                                             and(sizeof(T) % sizeof(int32_t))
                                             and(sizeof(T) % sizeof(int64_t))>> {
          using type = int8_t;
          static_assert(not(sizeof(T) % sizeof(int8_t)));
        };
    
       public:
        template <typename T>
        using split_cast_t = typename split_cast<T>::type;
      };
    
      static Messenger* m_instance;
      Messenger(int& argc, char* argv[]);
    
      // Rank of this process and number of processes; the in-class defaults
      // (0 and 1) describe a single-process run.
      size_t m_rank{0};
      size_t m_size{1};
    
      // Address of the first element of array, or a null pointer of the same
      // type when array is empty. Indexing element 0 of an empty array is out
      // of bounds, while MPI does not access the buffer of a zero-count
      // transfer; this keeps collectives well-defined for empty arrays.
      template <typename ArrayType>
      static auto _dataPtr(ArrayType& array)
      {
        using PointerType = decltype(&(array[0]));
        return (array.size() > 0) ? &(array[0]) : PointerType{nullptr};
      }
    
      // Gathers one arithmetic value per process into gather, which must have
      // m_size entries; gather[r] receives the value of rank r.
      template <typename DataType>
      void _allGather(const DataType& data,
                      Array<DataType> gather) const
      {
        static_assert(std::is_arithmetic_v<DataType>);
        Assert(gather.size() == m_size); // LCOV_EXCL_LINE
    
    #ifdef PASTIS_HAS_MPI
        MPI_Datatype mpi_datatype
            = Messenger::helper::mpiType<DataType>();
    
        MPI_Allgather(&data, 1,  mpi_datatype,
                      &(gather[0]), 1, mpi_datatype,
                      MPI_COMM_WORLD);
    #else // PASTIS_HAS_MPI
        gather[0] = data;
    #endif // PASTIS_HAS_MPI
      }
    
      // Concatenates data_array from every process, in rank order, into
      // gather_array, which must hold data_array.size()*m_size entries. Every
      // process is expected to contribute the same number of elements.
      template <template <typename ...SendT> typename SendArrayType,
                template <typename ...RecvT> typename RecvArrayType,
                typename ...SendT, typename ...RecvT>
      void _allGather(const SendArrayType<SendT...>& data_array,
                      RecvArrayType<RecvT...> gather_array) const
      {
        Assert(gather_array.size() == data_array.size()*m_size); // LCOV_EXCL_LINE
    
        using SendDataType = typename SendArrayType<SendT...>::data_type;
        using RecvDataType = typename RecvArrayType<RecvT...>::data_type;
    
        static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
        static_assert(std::is_arithmetic_v<SendDataType>);
    
    #ifdef PASTIS_HAS_MPI
        MPI_Datatype mpi_datatype
            = Messenger::helper::mpiType<RecvDataType>();
    
        MPI_Allgather(_dataPtr(data_array), data_array.size(), mpi_datatype,
                      _dataPtr(gather_array), data_array.size(), mpi_datatype,
                      MPI_COMM_WORLD);
    #else // PASTIS_HAS_MPI
        value_copy(data_array, gather_array);
    #endif // PASTIS_HAS_MPI
      }
    
      // Broadcasts one arithmetic value from root_rank to all processes
      // (no-op without MPI).
      template <typename DataType>
      void _broadcast_value(DataType& data, const size_t& root_rank) const
      {
        static_assert(not std::is_const_v<DataType>);
        static_assert(std::is_arithmetic_v<DataType>);
    
    #ifdef PASTIS_HAS_MPI
        MPI_Datatype mpi_datatype
            = Messenger::helper::mpiType<DataType>();
    
        MPI_Bcast(&data, 1,  mpi_datatype, root_rank, MPI_COMM_WORLD);
    #endif // PASTIS_HAS_MPI
      }
    
      // Broadcasts the content of an arithmetic array from root_rank; every
      // process must already hold an array of the same size (which may be 0).
      // No-op without MPI.
      template <typename ArrayType>
      void _broadcast_array(ArrayType& array, const size_t& root_rank) const
      {
        using DataType = typename ArrayType::data_type;
        static_assert(not std::is_const_v<DataType>);
        static_assert(std::is_arithmetic_v<DataType>);
    
    #ifdef PASTIS_HAS_MPI
        MPI_Datatype mpi_datatype
            = Messenger::helper::mpiType<DataType>();
        MPI_Bcast(_dataPtr(array), array.size(), mpi_datatype, root_rank, MPI_COMM_WORLD);
    #endif // PASTIS_HAS_MPI
      }
    
      // All-to-all of arithmetic arrays: sent_array is split into m_size equal
      // chunks, chunk r goes to rank r, and recv_array (same size as
      // sent_array) receives chunk r from rank r. Returns recv_array.
      template <template <typename ...SendT> typename SendArrayType,
                template <typename ...RecvT> typename RecvArrayType,
                typename ...SendT, typename ...RecvT>
      RecvArrayType<RecvT...> _allToAll(const SendArrayType<SendT...>& sent_array,
                                        RecvArrayType<RecvT...>& recv_array) const
      {
        using SendDataType = typename SendArrayType<SendT...>::data_type;
        using RecvDataType = typename RecvArrayType<RecvT...>::data_type;
    
        // Checked in every build, consistently with the other private helpers.
        static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
        static_assert(std::is_arithmetic_v<SendDataType>);
    
        Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE
        Assert(sent_array.size() == recv_array.size()); // LCOV_EXCL_LINE
    
    #ifdef PASTIS_HAS_MPI
        const size_t count = sent_array.size()/m_size;
    
        MPI_Datatype mpi_datatype
            = Messenger::helper::mpiType<SendDataType>();
    
        MPI_Alltoall(_dataPtr(sent_array), count, mpi_datatype,
                     _dataPtr(recv_array), count, mpi_datatype,
                     MPI_COMM_WORLD);
    #else  // PASTIS_HAS_MPI
        value_copy(sent_array, recv_array);
    #endif // PASTIS_HAS_MPI
        return recv_array;
      }
    
      // Point-to-point exchange of arithmetic arrays: sent_array_list[r] is
      // sent to rank r and recv_array_list[r] is filled with the data received
      // from rank r. Empty arrays generate no message. Receive sizes must match
      // what the peers send (the public exchange() checks this in debug).
      template <template <typename ...SendT> typename SendArrayType,
                template <typename ...RecvT> typename RecvArrayType,
                typename ...SendT, typename ...RecvT>
      void _exchange(const std::vector<SendArrayType<SendT...>>& sent_array_list,
                     std::vector<RecvArrayType<RecvT...>>& recv_array_list) const
      {
        using SendDataType = typename SendArrayType<SendT...>::data_type;
        using RecvDataType = typename RecvArrayType<RecvT...>::data_type;
    
        static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
        static_assert(std::is_arithmetic_v<SendDataType>);
    
    #ifdef PASTIS_HAS_MPI
        std::vector<MPI_Request> request_list;
        request_list.reserve(sent_array_list.size() + recv_array_list.size());
    
        MPI_Datatype mpi_datatype
            = Messenger::helper::mpiType<SendDataType>();
    
        for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) {
          // Bound by reference: the buffer must remain valid until MPI_Waitall,
          // which must not depend on copies of the array sharing storage.
          const auto& sent_array = sent_array_list[i_send];
          if (sent_array.size()>0) {
            MPI_Request request;
            MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request);
            request_list.push_back(request);
          }
        }
    
        for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) {
          auto& recv_array = recv_array_list[i_recv];
          if (recv_array.size()>0) {
            MPI_Request request;
            MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request);
            request_list.push_back(request);
          }
        }
    
        // request_list may be empty (nothing to send nor receive): data() is
        // valid in that case whereas &request_list[0] is not.
        std::vector<MPI_Status> status_list(request_list.size());
        if (MPI_SUCCESS != MPI_Waitall(request_list.size(), request_list.data(), status_list.data())) {
          // LCOV_EXCL_START
          std::cerr << "Communication error!\n";
          std::exit(1);
          // LCOV_EXCL_STOP
        }
    
    #else // PASTIS_HAS_MPI
        Assert(sent_array_list.size() == 1);
        Assert(recv_array_list.size() == 1);
    
        value_copy(sent_array_list[0], recv_array_list[0]);
    #endif // PASTIS_HAS_MPI
      }
    
      // Exchanges arrays of a trivial, non-arithmetic DataType by viewing their
      // storage as arrays of the integer type CastDataType, then delegating to
      // _exchange.
      template <typename DataType,
                typename CastDataType>
      void _exchange_through_cast(const std::vector<Array<DataType>>& sent_array_list,
                                  std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const
      {
        std::vector<CastArray<DataType, const CastDataType>> sent_cast_array_list;
        sent_cast_array_list.reserve(sent_array_list.size());
        for (size_t i=0; i<sent_array_list.size(); ++i) {
          sent_cast_array_list.emplace_back(cast_array_to<const CastDataType>::from(sent_array_list[i]));
        }
    
        using MutableDataType = std::remove_const_t<DataType>;
        std::vector<CastArray<MutableDataType, CastDataType>> recv_cast_array_list;
        recv_cast_array_list.reserve(recv_array_list.size());
        // Bounded by the receive list itself so it never indexes past its end.
        for (size_t i=0; i<recv_array_list.size(); ++i) {
          recv_cast_array_list.emplace_back(recv_array_list[i]);
        }
    
        _exchange(sent_cast_array_list, recv_cast_array_list);
      }
    
     public:
      static void create(int& argc, char* argv[]);
      static void destroy();
    
      // Returns the singleton; asserts that it has been created.
      PASTIS_INLINE
      static Messenger& getInstance()
      {
        Assert(m_instance != nullptr); // LCOV_EXCL_LINE
        return *m_instance;
      }
    
      // Rank of the calling process.
      PASTIS_INLINE
      const size_t& rank() const
      {
        return m_rank;
      }
    
      // Number of processes.
      PASTIS_INLINE
      const size_t& size() const
      {
        return m_size;
      }
    
      void barrier() const;
    
      // Returns the minimum of data over all processes (data itself without
      // MPI). DataType must be a non-const arithmetic type.
      template <typename DataType>
      DataType allReduceMin(const DataType& data) const
      {
        static_assert(not std::is_const_v<DataType>);
        static_assert(std::is_arithmetic_v<DataType>);
    
    #ifdef PASTIS_HAS_MPI
        MPI_Datatype mpi_datatype
            = Messenger::helper::mpiType<DataType>();
    
        DataType min_data = data;
        MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, MPI_COMM_WORLD);
    
        return min_data;
    #else // PASTIS_HAS_MPI
        return data;
    #endif // PASTIS_HAS_MPI
      }
    
      // Returns the maximum of data over all processes (data itself without
      // MPI). DataType must be a non-const arithmetic type.
      template <typename DataType>
      DataType allReduceMax(const DataType& data) const
      {
        static_assert(not std::is_const_v<DataType>);
        static_assert(std::is_arithmetic_v<DataType>);
    
    #ifdef PASTIS_HAS_MPI
        MPI_Datatype mpi_datatype
            = Messenger::helper::mpiType<DataType>();
    
        DataType max_data = data;
        MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, MPI_COMM_WORLD);
    
        return max_data;
    #else // PASTIS_HAS_MPI
        return data;
    #endif // PASTIS_HAS_MPI
      }
    
      // Returns an array of m_size values whose entry r is the data provided
      // by rank r. DataType must be trivial; non-arithmetic trivial types are
      // transported as fixed-width integers.
      template <typename DataType>
      PASTIS_INLINE
      Array<DataType>
      allGather(const DataType& data) const
      {
        static_assert(not std::is_const_v<DataType>);
    
        Array<DataType> gather_array(m_size);
    
        if constexpr(std::is_arithmetic_v<DataType>) {
          _allGather(data, gather_array);
        } else  if constexpr(std::is_trivial_v<DataType>) {
          using CastType = helper::split_cast_t<DataType>;
    
          CastArray cast_value_array = cast_value_to<const CastType>::from(data);
          CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array);
    
          _allGather(cast_value_array, cast_gather_array);
        } else {
          static_assert(std::is_trivial_v<DataType>, "unexpected type of data");
        }
        return gather_array;
      }
    
      // Returns the concatenation, in rank order, of array over all processes
      // (m_size*array.size() entries). Every process must pass an array of the
      // same size. DataType must be trivial.
      template <typename DataType>
      PASTIS_INLINE
      Array<std::remove_const_t<DataType>>
      allGather(const Array<DataType>& array) const
      {
        using MutableDataType = std::remove_const_t<DataType>;
        Array<MutableDataType> gather_array(m_size*array.size());
    
        if constexpr(std::is_arithmetic_v<DataType>) {
          _allGather(array, gather_array);
        } else  if constexpr(std::is_trivial_v<DataType>) {
          using CastType = helper::split_cast_t<DataType>;
          using MutableCastType = helper::split_cast_t<MutableDataType>;
    
          CastArray cast_array = cast_array_to<CastType>::from(array);
          CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array);
    
          _allGather(cast_array, cast_gather_array);
        } else {
          static_assert(std::is_trivial_v<DataType>, "unexpected type of data");
        }
        return gather_array;
      }
    
      // All-to-all exchange: sent_array is split into m_size equal chunks,
      // chunk r is sent to rank r, and the returned array holds, at chunk r,
      // the data received from rank r. All processes must pass arrays of the
      // same size, a multiple of m_size (checked in debug builds).
      template <typename SendDataType>
      PASTIS_INLINE
      Array<std::remove_const_t<SendDataType>>
      allToAll(const Array<SendDataType>& sent_array) const
      {
    #ifndef NDEBUG
        const size_t min_size = allReduceMin(sent_array.size());
        const size_t max_size = allReduceMax(sent_array.size());
        Assert(max_size == min_size); // LCOV_EXCL_LINE
    #endif // NDEBUG
        Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE
    
        using DataType = std::remove_const_t<SendDataType>;
        Array<DataType> recv_array(sent_array.size());
    
        if constexpr(std::is_arithmetic_v<DataType>) {
          _allToAll(sent_array, recv_array);
        } else if constexpr(std::is_trivial_v<DataType>) {
          using CastType = helper::split_cast_t<DataType>;
    
          auto send_cast_array = cast_array_to<const CastType>::from(sent_array);
          auto recv_cast_array = cast_array_to<CastType>::from(recv_array);
          _allToAll(send_cast_array, recv_cast_array);
        } else {
          static_assert(std::is_trivial_v<DataType>, "unexpected type of data");
        }
        return recv_array;
      }
    
      // Overwrites data on every process with the value held by root_rank.
      // DataType must be a non-const trivial type.
      template <typename DataType>
      PASTIS_INLINE
      void broadcast(DataType& data, const size_t& root_rank) const
      {
        static_assert(not std::is_const_v<DataType>,
                      "cannot broadcast const data");
        if constexpr(std::is_arithmetic_v<DataType>) {
          _broadcast_value(data, root_rank);
        } else if constexpr(std::is_trivial_v<DataType>) {
          using CastType = helper::split_cast_t<DataType>;
          if constexpr(sizeof(CastType) == sizeof(DataType)) {
            // Same size: the value can be sent as a single integer.
            CastType& cast_data = reinterpret_cast<CastType&>(data);
            _broadcast_value(cast_data, root_rank);
          } else {
            CastArray cast_array = cast_value_to<CastType>::from(data);
            _broadcast_array(cast_array, root_rank);
          }
        } else {
          static_assert(std::is_trivial_v<DataType>,
                        "unexpected non trivial type of data");
        }
      }
    
      // Replaces array on every process by the array held by root_rank: the
      // size is broadcast first and non-root processes reallocate, then the
      // content is broadcast. Empty arrays are supported.
      template <typename DataType>
      PASTIS_INLINE
      void broadcast(Array<DataType>& array,
                     const size_t& root_rank) const
      {
        static_assert(not std::is_const_v<DataType>,
                      "cannot broadcast array of const");
        // Arithmetic types are trivial, so this branch covers both cases.
        if constexpr(std::is_trivial_v<DataType>) {
          size_t array_size = array.size();
          _broadcast_value(array_size, root_rank);
          if (m_rank != root_rank) {
            array = Array<DataType>(array_size); // LCOV_EXCL_LINE
          }
    
          if constexpr(std::is_arithmetic_v<DataType>) {
            _broadcast_array(array, root_rank);
          } else {
            using CastType = helper::split_cast_t<DataType>;
            auto cast_array = cast_array_to<CastType>::from(array);
            _broadcast_array(cast_array, root_rank);
          }
        } else {
          static_assert(std::is_trivial_v<DataType>,
                        "unexpected non trivial type of data");
        }
      }
    
      // Pairwise exchange: send_array_list[r] is sent to rank r and
      // recv_array_list[r] receives the data sent by rank r. Both lists must
      // have m_size entries and recv_array_list[r] must already have the size
      // rank r sends (checked in debug builds).
      template <typename SendDataType,
                typename RecvDataType>
      PASTIS_INLINE
      void exchange(const std::vector<Array<SendDataType>>& send_array_list,
                    std::vector<Array<RecvDataType>>& recv_array_list) const
      {
        static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>,
                      "send and receive data type do not match");
        static_assert(not std::is_const_v<RecvDataType>,
                      "receive data type cannot be const");
        using DataType = std::remove_const_t<SendDataType>;
    
        Assert(send_array_list.size() == m_size); // LCOV_EXCL_LINE
        Assert(recv_array_list.size() == m_size); // LCOV_EXCL_LINE
    #ifndef NDEBUG
        Array<size_t> send_size(m_size);
        for (size_t i=0; i<m_size; ++i) {
          send_size[i] = send_array_list[i].size();
        }
        Array<size_t> recv_size = allToAll(send_size);
        bool correct_sizes = true;
        for (size_t i=0; i<m_size; ++i) {
          correct_sizes &= (recv_size[i] == recv_array_list[i].size());
        }
        Assert(correct_sizes); // LCOV_EXCL_LINE
    #endif // NDEBUG
    
        if constexpr(std::is_arithmetic_v<DataType>) {
          _exchange(send_array_list, recv_array_list);
        } else if constexpr(std::is_trivial_v<DataType>) {
          using CastType = helper::split_cast_t<DataType>;
          _exchange_through_cast<SendDataType, CastType>(send_array_list, recv_array_list);
        } else {
          static_assert(std::is_trivial_v<RecvDataType>,
                        "unexpected non trivial type of data");
        }
      }
    
      // The singleton is neither copyable nor copy-assignable.
      Messenger(const Messenger&) = delete;
      Messenger& operator=(const Messenger&) = delete;
      ~Messenger();
    };
    
    /// Read-only access to the process-wide Messenger singleton.
    /// Messenger::getInstance() asserts that the instance exists (presumably
    /// set up by Messenger::create()).
    PASTIS_INLINE
    const Messenger&
    messenger()
    {
      const Messenger& the_messenger = Messenger::getInstance();
      return the_messenger;
    }
    
    /// Rank of the calling process, as stored by the Messenger singleton.
    PASTIS_INLINE
    const size_t&
    rank()
    {
      const Messenger& the_messenger = messenger();
      return the_messenger.rank();
    }
    
    /// Number of processes, as stored by the Messenger singleton.
    PASTIS_INLINE
    const size_t&
    size()
    {
      const Messenger& the_messenger = messenger();
      return the_messenger.size();
    }
    
    /// Forwards to Messenger::barrier() on the singleton.
    PASTIS_INLINE
    void
    barrier()
    {
      const Messenger& the_messenger = messenger();
      the_messenger.barrier();
    }
    
    /// Maximum of data over all processes; forwards to
    /// Messenger::allReduceMax on the singleton.
    template <typename DataType>
    PASTIS_INLINE
    DataType
    allReduceMax(const DataType& data)
    {
      const Messenger& the_messenger = messenger();
      return the_messenger.allReduceMax(data);
    }
    
    /// Minimum of data over all processes; forwards to
    /// Messenger::allReduceMin on the singleton.
    template <typename DataType>
    PASTIS_INLINE
    DataType
    allReduceMin(const DataType& data)
    {
      const Messenger& the_messenger = messenger();
      return the_messenger.allReduceMin(data);
    }
    
    /// Gathers one value per process into an array indexed by rank; forwards
    /// to Messenger::allGather(const DataType&) on the singleton.
    template <typename DataType>
    PASTIS_INLINE
    Array<DataType>
    allGather(const DataType& data)
    {
      const Messenger& the_messenger = messenger();
      return the_messenger.allGather(data);
    }
    
    /// Concatenates, in rank order, the arrays of all processes; forwards to
    /// Messenger::allGather(const Array<DataType>&) on the singleton.
    template <typename DataType>
    PASTIS_INLINE
    Array<std::remove_const_t<DataType>>
    allGather(const Array<DataType>& array)
    {
      const Messenger& the_messenger = messenger();
      return the_messenger.allGather(array);
    }
    
    /// All-to-all exchange of equal chunks of array; forwards to
    /// Messenger::allToAll on the singleton.
    template <typename DataType>
    PASTIS_INLINE
    Array<std::remove_const_t<DataType>>
    allToAll(const Array<DataType>& array)
    {
      const Messenger& the_messenger = messenger();
      return the_messenger.allToAll(array);
    }
    
    /// Overwrites data with the value held by root_rank; forwards to
    /// Messenger::broadcast(DataType&, ...) on the singleton.
    template <typename DataType>
    PASTIS_INLINE
    void
    broadcast(DataType& data, const size_t& root_rank)
    {
      const Messenger& the_messenger = messenger();
      the_messenger.broadcast(data, root_rank);
    }
    
    /// Replaces array with the one held by root_rank; forwards to
    /// Messenger::broadcast(Array<DataType>&, ...) on the singleton.
    template <typename DataType>
    PASTIS_INLINE
    void
    broadcast(Array<DataType>& array, const size_t& root_rank)
    {
      const Messenger& the_messenger = messenger();
      the_messenger.broadcast(array, root_rank);
    }
    
    /// Pairwise exchange: sent_array_list[r] goes to rank r and
    /// recv_array_list[r] receives from rank r. The data types are checked
    /// here at compile time before forwarding to Messenger::exchange.
    template <typename SendDataType,
              typename RecvDataType>
    PASTIS_INLINE
    void
    exchange(const std::vector<Array<SendDataType>>& sent_array_list,
             std::vector<Array<RecvDataType>>& recv_array_list)
    {
      using MutableSendDataType = std::remove_const_t<SendDataType>;
      static_assert(std::is_same_v<MutableSendDataType, RecvDataType>,
                    "send and receive data type do not match");
      static_assert(not std::is_const_v<RecvDataType>,
                    "receive data type cannot be const");
    
      const Messenger& the_messenger = messenger();
      the_messenger.exchange(sent_array_list, recv_array_list);
    }
    
    #ifdef PASTIS_HAS_MPI
    
    // Explicit specializations of Messenger::helper::mpiType<T>(): the
    // MPI_Datatype associated with each supported arithmetic type. Any other
    // unqualified type falls back to the primary template, which fails to
    // compile.
    //
    // NOTE(review): int64_t/uint64_t and (unsigned) long long are specialized
    // separately. On platforms where int64_t is a typedef of long long these
    // would be duplicate specializations, and where size_t is not uint64_t
    // (e.g. unsigned long vs unsigned long long) size_t has no entry.
    // Confirm against the supported toolchains.
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<char>()
    {
      return MPI_CHAR;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<int8_t>()
    {
      return MPI_INT8_T;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<int16_t>()
    {
      return MPI_INT16_T;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<int32_t>()
    {
      return MPI_INT32_T;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<int64_t>()
    {
      return MPI_INT64_T;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<uint8_t>()
    {
      return MPI_UINT8_T;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<uint16_t>()
    {
      return MPI_UINT16_T;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<uint32_t>()
    {
      return MPI_UINT32_T;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<uint64_t>()
    {
      return MPI_UINT64_T;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<signed long long int>()
    {
      return MPI_LONG_LONG_INT;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<unsigned long long int>()
    {
      return MPI_UNSIGNED_LONG_LONG;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<float>()
    {
      return MPI_FLOAT;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<double>()
    {
      return MPI_DOUBLE;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<long double>()
    {
      return MPI_LONG_DOUBLE;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<wchar_t>()
    {
      return MPI_WCHAR;
    }
    
    template <>
    PASTIS_INLINE MPI_Datatype
    Messenger::helper::mpiType<bool>()
    {
      return MPI_CXX_BOOL;
    }
    
    #endif // PASTIS_HAS_MPI
    
    } // namespace parallel
    
    #endif // MESSENGER_HPP