diff --git a/src/main.cpp b/src/main.cpp index e31ab068ed1f41f064cec799bb0de0f222bdd6fb..466806afb3a96bd5e6c28e5890327794d174fc52 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -39,7 +39,7 @@ int main(int argc, char *argv[]) std::shared_ptr<IMesh> p_mesh = gmsh_reader.mesh(); - switch (p_mesh->meshDimension()) { + switch (p_mesh->dimension()) { case 1: { std::vector<std::string> sym_boundary_name_list = {"XMIN", "XMAX"}; std::vector<std::shared_ptr<BoundaryConditionDescriptor>> bc_descriptor_list; @@ -75,9 +75,9 @@ int main(int argc, char *argv[]) const RefNodeList& ref_node_list = mesh.connectivity().refNodeList(i_ref_node_list); const RefId& ref = ref_node_list.refId(); if (ref == sym_bc_descriptor.boundaryDescriptor()) { - SymmetryBoundaryCondition<MeshType::dimension>* sym_bc - = new SymmetryBoundaryCondition<MeshType::dimension>(MeshFlatNodeBoundary<MeshType::dimension>(mesh, ref_node_list)); - std::shared_ptr<SymmetryBoundaryCondition<MeshType::dimension>> bc(sym_bc); + SymmetryBoundaryCondition<MeshType::Dimension>* sym_bc + = new SymmetryBoundaryCondition<MeshType::Dimension>(MeshFlatNodeBoundary<MeshType::Dimension>(mesh, ref_node_list)); + std::shared_ptr<SymmetryBoundaryCondition<MeshType::Dimension>> bc(sym_bc); bc_list.push_back(BoundaryConditionHandler(bc)); } } @@ -97,7 +97,7 @@ int main(int argc, char *argv[]) AcousticSolver<MeshDataType> acoustic_solver(mesh_data, bc_list); - using Rd = TinyVector<MeshType::dimension>; + using Rd = TinyVector<MeshType::Dimension>; const CellValue<const double>& Vj = mesh_data.Vj(); @@ -192,9 +192,9 @@ int main(int argc, char *argv[]) const RefFaceList& ref_face_list = mesh.connectivity().refFaceList(i_ref_face_list); const RefId& ref = ref_face_list.refId(); if (ref == sym_bc_descriptor.boundaryDescriptor()) { - SymmetryBoundaryCondition<MeshType::dimension>* sym_bc - = new SymmetryBoundaryCondition<MeshType::dimension>(MeshFlatNodeBoundary<MeshType::dimension>(mesh, ref_face_list)); - std::shared_ptr<SymmetryBoundaryCondition<MeshType::dimension>> bc(sym_bc); + SymmetryBoundaryCondition<MeshType::Dimension>* sym_bc + = new SymmetryBoundaryCondition<MeshType::Dimension>(MeshFlatNodeBoundary<MeshType::Dimension>(mesh, ref_face_list)); + std::shared_ptr<SymmetryBoundaryCondition<MeshType::Dimension>> bc(sym_bc); bc_list.push_back(BoundaryConditionHandler(bc)); } } @@ -296,9 +296,9 @@ int main(int argc, char *argv[]) const RefFaceList& ref_face_list = mesh.connectivity().refFaceList(i_ref_face_list); const RefId& ref = ref_face_list.refId(); if (ref == sym_bc_descriptor.boundaryDescriptor()) { - SymmetryBoundaryCondition<MeshType::dimension>* sym_bc - = new SymmetryBoundaryCondition<MeshType::dimension>(MeshFlatNodeBoundary<MeshType::dimension>(mesh, ref_face_list)); - std::shared_ptr<SymmetryBoundaryCondition<MeshType::dimension>> bc(sym_bc); + SymmetryBoundaryCondition<MeshType::Dimension>* sym_bc + = new SymmetryBoundaryCondition<MeshType::Dimension>(MeshFlatNodeBoundary<MeshType::Dimension>(mesh, ref_face_list)); + std::shared_ptr<SymmetryBoundaryCondition<MeshType::Dimension>> bc(sym_bc); bc_list.push_back(BoundaryConditionHandler(bc)); } } diff --git a/src/mesh/Connectivity.cpp b/src/mesh/Connectivity.cpp index b5ff6b825d42582379c94257452c5ac00095c9c8..0ea63e1418f8c879a43b59873d6ef16ad5bf526a 100644 --- a/src/mesh/Connectivity.cpp +++ b/src/mesh/Connectivity.cpp @@ -3,10 +3,12 @@ #include <Messenger.hpp> +template<size_t Dimension> +Connectivity<Dimension>::Connectivity() {} template<size_t Dimension> -Connectivity<Dimension>:: -Connectivity(const ConnectivityDescriptor& descriptor) +void Connectivity<Dimension>:: +_buildFrom(const ConnectivityDescriptor& descriptor) { #warning All of these should be checked by ConnectivityDescriptor Assert(descriptor.cell_by_node_vector.size() == descriptor.cell_type_vector.size()); @@ -22,7 +24,7 @@ Connectivity(const ConnectivityDescriptor& descriptor) cell_to_node_matrix = descriptor.cell_by_node_vector; { - CellValue<CellType> cell_type(*this); + WeakCellValue<CellType> cell_type(*this); parallel_for(this->numberOfCells(), PASTIS_LAMBDA(const CellId& j){ cell_type[j] = descriptor.cell_type_vector[j]; }); @@ -30,19 +32,19 @@ Connectivity(const ConnectivityDescriptor& descriptor) } { - CellValue<int> cell_number(*this); + WeakCellValue<int> cell_number(*this); cell_number = convert_to_array(descriptor.cell_number_vector); m_cell_number = cell_number; } { - NodeValue<int> node_number(*this); + WeakNodeValue<int> node_number(*this); node_number = convert_to_array(descriptor.node_number_vector); m_node_number = node_number; } { - CellValue<int> cell_global_index(*this); + WeakCellValue<int> cell_global_index(*this); #warning index must start accounting number of global indices of other procs #warning must take care of ghost cells int first_index = 0; @@ -54,7 +56,7 @@ Connectivity(const ConnectivityDescriptor& descriptor) { - CellValue<double> inv_cell_nb_nodes(*this); + WeakCellValue<double> inv_cell_nb_nodes(*this); parallel_for(this->numberOfCells(), PASTIS_LAMBDA(const CellId& j) { const auto& cell_nodes = cell_to_node_matrix.rowConst(j); inv_cell_nb_nodes[j] = 1./cell_nodes.length; @@ -63,14 +65,14 @@ Connectivity(const ConnectivityDescriptor& descriptor) } { - CellValue<int> cell_owner(*this); + WeakCellValue<int> cell_owner(*this); cell_owner = convert_to_array(descriptor.cell_owner_vector); m_cell_owner = cell_owner; } { const int rank = parallel::rank(); - CellValue<bool> cell_is_owned(*this); + WeakCellValue<bool> cell_is_owned(*this); parallel_for(this->numberOfCells(), PASTIS_LAMBDA(const CellId& j) { cell_is_owned[j] = (m_cell_owner[j] == rank); }); @@ -78,14 +80,14 @@ Connectivity(const ConnectivityDescriptor& descriptor) } { - NodeValue<int> node_owner(*this); + WeakNodeValue<int> node_owner(*this); node_owner = convert_to_array(descriptor.node_owner_vector); m_node_owner = node_owner; } { const int rank = parallel::rank(); - NodeValue<bool> node_is_owned(*this); + WeakNodeValue<bool> node_is_owned(*this); parallel_for(this->numberOfNodes(), PASTIS_LAMBDA(const NodeId& r) { node_is_owned[r] = (m_node_owner[r] == rank); }); @@ -113,19 +115,19 @@ Connectivity(const ConnectivityDescriptor& descriptor) m_cell_face_is_reversed = cell_face_is_reversed; } { - FaceValue<int> face_number(*this); + WeakFaceValue<int> face_number(*this); face_number = convert_to_array(descriptor.face_number_vector); m_face_number = face_number; } { - FaceValue<int> face_owner(*this); + WeakFaceValue<int> face_owner(*this); face_owner = convert_to_array(descriptor.face_owner_vector); m_face_owner = face_owner; } { const int rank = parallel::rank(); - FaceValue<bool> face_is_owned(*this); + WeakFaceValue<bool> face_is_owned(*this); parallel_for(this->numberOfFaces(), PASTIS_LAMBDA(const FaceId& l) { face_is_owned[l] = (m_face_owner[l] == rank); }); @@ -137,11 +139,10 @@ Connectivity(const ConnectivityDescriptor& descriptor) } } -template Connectivity1D:: -Connectivity(const ConnectivityDescriptor& descriptor); - -template Connectivity2D:: -Connectivity(const ConnectivityDescriptor& descriptor); +template Connectivity1D::Connectivity(); +template Connectivity2D::Connectivity(); +template Connectivity3D::Connectivity(); -template Connectivity3D:: -Connectivity(const ConnectivityDescriptor& descriptor); +template void Connectivity1D::_buildFrom(const ConnectivityDescriptor&); +template void Connectivity2D::_buildFrom(const ConnectivityDescriptor&); +template void Connectivity3D::_buildFrom(const ConnectivityDescriptor&); diff --git a/src/mesh/Connectivity.hpp b/src/mesh/Connectivity.hpp index 335689a34b2aa954c69f4facfc158d9b48606b51..3f1d83e371e05c7a986e8c1c10df30da22129055 100644 --- a/src/mesh/Connectivity.hpp +++ b/src/mesh/Connectivity.hpp @@ -80,37 +80,50 @@ class ConnectivityDescriptor ~ConnectivityDescriptor() = default; }; -template <size_t Dimension> +template <size_t Dim> class Connectivity final : public IConnectivity { + public: + PASTIS_INLINE + static std::shared_ptr<Connectivity<Dim>> + build(const ConnectivityDescriptor&); + private: - constexpr static auto& itemTId = ItemTypeId<Dimension>::itemTId; + constexpr static auto& itemTId = ItemTypeId<Dim>::itemTId; public: - static constexpr size_t dimension = Dimension; + static constexpr size_t Dimension = Dim; + PASTIS_INLINE + size_t dimension() const final + { + return Dimension; + } private: ConnectivityMatrix m_item_to_item_matrix[Dimension+1][Dimension+1]; - CellValue<const CellType> m_cell_type; + WeakCellValue<const CellType> m_cell_type; #warning is m_cell_global_index really necessary? should it be computed on demand instead? - CellValue<const int> m_cell_global_index; + WeakCellValue<const int> m_cell_global_index; - CellValue<const int> m_cell_number; - FaceValue<const int> m_face_number; + WeakCellValue<const int> m_cell_number; + WeakFaceValue<const int> m_face_number; #warning check that m_edge_number is filled correctly - EdgeValue<const int> m_edge_number; - NodeValue<const int> m_node_number; + WeakEdgeValue<const int> m_edge_number; + WeakNodeValue<const int> m_node_number; + + WeakCellValue<const int> m_cell_owner; + WeakCellValue<const bool> m_cell_is_owned; - CellValue<const int> m_cell_owner; - CellValue<const bool> m_cell_is_owned; + WeakFaceValue<const int> m_face_owner; + WeakFaceValue<const bool> m_face_is_owned; - FaceValue<const int> m_face_owner; - FaceValue<const bool> m_face_is_owned; -#warning Missing EdgeValue<const int> m_edge_owner and m_edge_is_owned; - NodeValue<const int> m_node_owner; - NodeValue<const bool> m_node_is_owned; + WeakEdgeValue<const int> m_edge_owner; + WeakEdgeValue<const bool> m_edge_is_owned; + + WeakNodeValue<const int> m_node_owner; + WeakNodeValue<const bool> m_node_is_owned; FaceValuePerCell<const bool> m_cell_face_is_reversed; @@ -135,7 +148,7 @@ class Connectivity final std::vector<RefFaceList> m_ref_face_list; std::vector<RefNodeList> m_ref_node_list; - CellValue<const double> m_inv_cell_nb_nodes; + WeakCellValue<const double> m_inv_cell_nb_nodes; void _computeCellFaceAndFaceNodeConnectivities(); @@ -172,67 +185,101 @@ class Connectivity final public: PASTIS_INLINE - const CellValue<const CellType>& cellType() const + CellValue<const CellType> cellType() const { return m_cell_type; } PASTIS_INLINE - const CellValue<const int>& cellNumber() const + CellValue<const int> cellNumber() const { return m_cell_number; } PASTIS_INLINE - const FaceValue<const int>& faceNumber() const + FaceValue<const int> faceNumber() const { return m_face_number; } PASTIS_INLINE - const EdgeValue<const int>& edgeNumber() const + EdgeValue<const int> edgeNumber() const { return m_edge_number; } PASTIS_INLINE - const NodeValue<const int>& nodeNumber() const + NodeValue<const int> nodeNumber() const { return m_node_number; } PASTIS_INLINE - const CellValue<const int>& cellOwner() const + CellValue<const int> cellOwner() const { return m_cell_owner; } PASTIS_INLINE - const FaceValue<const int>& faceOwner() const + FaceValue<const int> faceOwner() const { return m_face_owner; } PASTIS_INLINE - const NodeValue<const int>& nodeOwner() const + EdgeValue<const int> edgeOwner() const + { + perr() << __FILE__ << ':' << __LINE__ << ": edge owner not built\n"; + std::terminate(); + return m_edge_owner; + } + + PASTIS_INLINE + NodeValue<const int> nodeOwner() const { return m_node_owner; } + template <ItemType item_type> PASTIS_INLINE - const CellValue<const bool>& cellIsOwned() const + ItemValue<const bool, item_type> isOwned() const + { + if constexpr(item_type == ItemType::cell) { + return m_cell_is_owned; + } else if constexpr(item_type == ItemType::face) { + return m_face_is_owned; + } else if constexpr(item_type == ItemType::edge) { + return m_edge_is_owned; + } else if constexpr(item_type == ItemType::node) { + return m_node_is_owned; + } else { + static_assert(item_type == ItemType::cell, "unknown ItemType"); + return {}; + } + } + + PASTIS_INLINE + CellValue<const bool> cellIsOwned() const { return m_cell_is_owned; } PASTIS_INLINE - const FaceValue<const bool>& faceIsOwned() const + FaceValue<const bool> faceIsOwned() const { return m_face_is_owned; } PASTIS_INLINE - const NodeValue<const bool>& nodeIsOwned() const + EdgeValue<const bool> edgeIsOwned() const + { + perr() << __FILE__ << ':' << __LINE__ << ": edge is owned not built\n"; + std::terminate(); + return m_edge_is_owned; + } + + PASTIS_INLINE + NodeValue<const bool> nodeIsOwned() const { return m_node_is_owned; } @@ -331,7 +378,7 @@ class Connectivity final PASTIS_INLINE const auto& cellFaceIsReversed() const { - static_assert(dimension>1, "reversed faces makes no sense in dimension 1"); + static_assert(Dimension>1, "reversed faces makes no sense in dimension 1"); return m_cell_face_is_reversed; } @@ -344,7 +391,7 @@ class Connectivity final PASTIS_INLINE const auto& cellLocalNumbersInTheirEdges() const { - if constexpr (dimension>2) { + if constexpr (Dimension>2) { return _lazzyBuildItemNumberInTheirChild(m_cell_local_numbers_in_their_edges); } else { return cellLocalNumbersInTheirFaces(); @@ -354,7 +401,7 @@ class Connectivity final PASTIS_INLINE const auto& cellLocalNumbersInTheirFaces() const { - if constexpr (dimension>1) { + if constexpr (Dimension>1) { return _lazzyBuildItemNumberInTheirChild(m_cell_local_numbers_in_their_faces); } else { return cellLocalNumbersInTheirNodes(); @@ -364,7 +411,7 @@ class Connectivity final PASTIS_INLINE const auto& faceLocalNumbersInTheirCells() const { - if constexpr(dimension>1) { + if constexpr(Dimension>1) { return _lazzyBuildItemNumberInTheirChild(m_face_local_numbers_in_their_cells); } else { return nodeLocalNumbersInTheirCells(); @@ -374,21 +421,21 @@ class Connectivity final PASTIS_INLINE const auto& faceLocalNumbersInTheirEdges() const { - static_assert(dimension>2,"this function has no meaning in 1d or 2d"); + static_assert(Dimension>2,"this function has no meaning in 1d or 2d"); return _lazzyBuildItemNumberInTheirChild(m_face_local_numbers_in_their_edges); } PASTIS_INLINE const auto& faceLocalNumbersInTheirNodes() const { - static_assert(dimension>1,"this function has no meaning in 1d"); + static_assert(Dimension>1,"this function has no meaning in 1d"); return _lazzyBuildItemNumberInTheirChild(m_face_local_numbers_in_their_nodes); } PASTIS_INLINE const auto& edgeLocalNumbersInTheirCells() const { - if constexpr (dimension>2) { + if constexpr (Dimension>2) { return _lazzyBuildItemNumberInTheirChild(m_edge_local_numbers_in_their_cells); } else { return faceLocalNumbersInTheirCells(); @@ -398,14 +445,14 @@ class Connectivity final PASTIS_INLINE const auto& edgeLocalNumbersInTheirFaces() const { - static_assert(dimension>2, "this function has no meaning in 1d or 2d"); + static_assert(Dimension>2, "this function has no meaning in 1d or 2d"); return _lazzyBuildItemNumberInTheirChild(m_edge_local_numbers_in_their_faces); } PASTIS_INLINE const auto& edgeLocalNumbersInTheirNodes() const { - static_assert(dimension>2, "this function has no meaning in 1d or 2d"); + static_assert(Dimension>2, "this function has no meaning in 1d or 2d"); return _lazzyBuildItemNumberInTheirChild(m_edge_local_numbers_in_their_nodes); } @@ -418,14 +465,14 @@ class Connectivity final PASTIS_INLINE const auto& nodeLocalNumbersInTheirEdges() const { - static_assert(dimension>2, "this function has no meaning in 1d or 2d"); + static_assert(Dimension>2, "this function has no meaning in 1d or 2d"); return _lazzyBuildItemNumberInTheirChild(m_node_local_numbers_in_their_edges); } PASTIS_INLINE const auto& nodeLocalNumbersInTheirFaces() const { - static_assert(dimension>1,"this function has no meaning in 1d"); + static_assert(Dimension>1,"this function has no meaning in 1d"); return _lazzyBuildItemNumberInTheirChild(m_node_local_numbers_in_their_faces); } @@ -539,7 +586,7 @@ class Connectivity final return cell_to_node_matrix.numRows(); } - const CellValue<const double>& invCellNbNodes() const + CellValue<const double> invCellNbNodes() const { #warning add calculation on demand when variables will be defined return m_inv_cell_nb_nodes; @@ -547,14 +594,28 @@ class Connectivity final Connectivity(const Connectivity&) = delete; - Connectivity(const ConnectivityDescriptor& descriptor); + private: + Connectivity(); + void _buildFrom(const ConnectivityDescriptor& descriptor); + public: ~Connectivity() { ; } }; +template <size_t Dim> +PASTIS_INLINE +std::shared_ptr<Connectivity<Dim>> +Connectivity<Dim>::build(const ConnectivityDescriptor& descriptor) +{ + std::shared_ptr<Connectivity<Dim>> connectivity_ptr(new Connectivity<Dim>); + connectivity_ptr->_buildFrom(descriptor); + + return connectivity_ptr; +} + using Connectivity3D = Connectivity<3>; using Connectivity2D = Connectivity<2>; using Connectivity1D = Connectivity<1>; diff --git a/src/mesh/GmshReader.cpp b/src/mesh/GmshReader.cpp index f3aba4e18f18c308e0f16654221393a0238c3927..98ce886b3deb56494e70cb24d705a7a44e3010f0 100644 --- a/src/mesh/GmshReader.cpp +++ b/src/mesh/GmshReader.cpp @@ -529,7 +529,7 @@ void GmshReader::_dispatch() if (not m_mesh) { ConnectivityDescriptor descriptor; - std::shared_ptr connectivity = std::make_shared<ConnectivityType>(descriptor); + std::shared_ptr connectivity = ConnectivityType::build(descriptor); NodeValue<Rd> xr; m_mesh = std::make_shared<MeshType>(connectivity, xr); } @@ -1109,7 +1109,7 @@ void GmshReader::_dispatch() using ConnectivityType = Connectivity<Dimension>; std::shared_ptr p_connectivity - = std::make_shared<ConnectivityType>(new_descriptor); + = ConnectivityType::build(new_descriptor); const NodeValue<const Rd>& xr = mesh.xr(); std::vector<Array<const Rd>> send_node_coord_by_proc(parallel::size()); @@ -1321,7 +1321,7 @@ GmshReader::GmshReader(const std::string& filename) = [&]() { int mesh_dimension = -1; // unknown mesh dimension if (m_mesh) { - mesh_dimension = m_mesh->meshDimension(); + mesh_dimension = m_mesh->dimension(); } Array<int> dimensions = parallel::allGather(mesh_dimension); @@ -2202,7 +2202,7 @@ GmshReader::__proceedData() descriptor.addRefFaceList(RefFaceList(physical_ref_id.refId(), face_list)); } - std::shared_ptr p_connectivity = std::make_shared<Connectivity3D>(descriptor); + std::shared_ptr p_connectivity = Connectivity3D::build(descriptor); Connectivity3D& connectivity = *p_connectivity; using MeshType = Mesh<Connectivity3D>; @@ -2329,7 +2329,7 @@ GmshReader::__proceedData() descriptor.addRefFaceList(RefFaceList(physical_ref_id.refId(), face_list)); } - std::shared_ptr p_connectivity = std::make_shared<Connectivity2D>(descriptor); + std::shared_ptr p_connectivity = Connectivity2D::build(descriptor); Connectivity2D& connectivity = *p_connectivity; std::map<unsigned int, std::vector<unsigned int>> ref_points_map; @@ -2388,7 +2388,7 @@ GmshReader::__proceedData() parallel::rank()); - std::shared_ptr p_connectivity = std::make_shared<Connectivity1D>(descriptor); + std::shared_ptr p_connectivity = Connectivity1D::build(descriptor); Connectivity1D& connectivity = *p_connectivity; std::map<unsigned int, std::vector<unsigned int>> ref_points_map; diff --git a/src/mesh/IConnectivity.hpp b/src/mesh/IConnectivity.hpp index f60db961d80f939c64139f8746a674e4ce4de91e..3c87370669f7d00f575d7e83889fa8fcb10a4b1c 100644 --- a/src/mesh/IConnectivity.hpp +++ b/src/mesh/IConnectivity.hpp @@ -4,7 +4,7 @@ #include <ItemType.hpp> #include <ConnectivityMatrix.hpp> -class IConnectivity +class IConnectivity : public std::enable_shared_from_this<IConnectivity> { protected: template <typename DataType, @@ -19,6 +19,13 @@ class IConnectivity const ItemType& item_type_1) const = 0; public: + virtual size_t dimension() const = 0; + + std::shared_ptr<const IConnectivity> shared_ptr() const + { + return this->shared_from_this(); + } + virtual size_t numberOfNodes() const = 0; virtual size_t numberOfEdges() const = 0; virtual size_t numberOfFaces() const = 0; diff --git a/src/mesh/ItemValue.hpp b/src/mesh/ItemValue.hpp index 4b71bb9cc1a660bb82eb431cdb9e88ff9274b5b2..9ded397605b6e92814e71474dbbcef03e868a482 100644 --- a/src/mesh/ItemValue.hpp +++ b/src/mesh/ItemValue.hpp @@ -12,7 +12,8 @@ #include <IConnectivity.hpp> template <typename DataType, - ItemType item_type> + ItemType item_type, + typename ConnectivityPtr = std::shared_ptr<const IConnectivity>> class ItemValue { public: @@ -23,35 +24,57 @@ class ItemValue using index_type = ItemId; private: + using ConnectivitySharedPtr = std::shared_ptr<const IConnectivity>; + using ConnectivityWeakPtr = std::weak_ptr<const IConnectivity>; + + static_assert(std::is_same_v<ConnectivityPtr, ConnectivitySharedPtr> or + std::is_same_v<ConnectivityPtr, ConnectivityWeakPtr>); + + ConnectivityPtr m_connectivity_ptr; + bool m_is_built{false}; Array<DataType> m_values; - // Allows const version to access our data - friend ItemValue<std::add_const_t<DataType>, - item_type>; + // Allow const std:shared_ptr version to access our data + friend ItemValue<std::add_const_t<DataType>, item_type, + ConnectivitySharedPtr>; + + // Allow const std:weak_ptr version to access our data + friend ItemValue<std::add_const_t<DataType>, item_type, + ConnectivityWeakPtr>; friend PASTIS_INLINE - ItemValue<std::remove_const_t<DataType>,item_type> - copy(const ItemValue<DataType, item_type>& source) + ItemValue<std::remove_const_t<DataType>,item_type, ConnectivityPtr> + copy(const ItemValue<DataType, item_type, ConnectivityPtr>& source) { - ItemValue<std::remove_const_t<DataType>, item_type> image(source); + ItemValue<std::remove_const_t<DataType>, item_type, ConnectivityPtr> image(source); image.m_values = copy(source.m_values); return image; } public: - PASTIS_FORCEINLINE + PASTIS_INLINE bool isBuilt() const { - return m_is_built; + return m_connectivity_ptr.use_count() != 0; + } + + PASTIS_INLINE + std::shared_ptr<const IConnectivity> connectivity_ptr() const + { + if constexpr (std::is_same_v<ConnectivityPtr, ConnectivitySharedPtr>) { + return m_connectivity_ptr; + } else { + return m_connectivity_ptr.lock(); + } } PASTIS_INLINE size_t size() const { - Assert(m_is_built); + Assert(this->isBuilt()); return m_values.size(); } @@ -68,7 +91,7 @@ class ItemValue PASTIS_FORCEINLINE DataType& operator[](const ItemId& i) const noexcept(NO_ASSERT) { - Assert(m_is_built); + Assert(this->isBuilt()); return m_values[i]; } @@ -85,7 +108,7 @@ class ItemValue PASTIS_INLINE size_t numberOfItems() const { - Assert(m_is_built); + Assert(this->isBuilt()); return m_values.size(); } @@ -100,7 +123,12 @@ class ItemValue // ensures that const is not lost through copy static_assert(((std::is_const<DataType2>() and std::is_const<DataType>()) or not std::is_const<DataType2>()), - "Cannot assign ItemValue of const to ItemValue of non-const"); + "Cannot assign ItemValue of const to ItemValue of non-const"); + + if (not this->isBuilt()) { + perr() << "Cannot assign an array of values to a non-built ItemValue\n"; + std::terminate(); + } if (m_values.size() != values.size()) { perr() << "Cannot assign an array of values of a different size\n"; @@ -108,7 +136,7 @@ class ItemValue } if (values.size() > 0) { - if (not m_is_built) { + if (not this->isBuilt()) { perr() << "Cannot assign array of values to a non-built ItemValue\n"; std::terminate(); } @@ -119,10 +147,11 @@ class ItemValue return *this; } - template <typename DataType2> + template <typename DataType2, + typename ConnectivityPtr2> PASTIS_INLINE ItemValue& - operator=(const ItemValue<DataType2, item_type>& value_per_item) + operator=(const ItemValue<DataType2, item_type, ConnectivityPtr2>& value_per_item) { // ensures that DataType is the same as source DataType2 static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(), @@ -130,17 +159,23 @@ class ItemValue // ensures that const is not lost through copy static_assert(((std::is_const<DataType2>() and std::is_const<DataType>()) or not std::is_const<DataType2>()), - "Cannot assign ItemValue of const to ItemValue of non-const"); + "Cannot assign ItemValue of const to ItemValue of non-const"); m_values = value_per_item.m_values; - m_is_built = value_per_item.m_is_built; + if constexpr (std::is_same_v<ConnectivityPtr, ConnectivitySharedPtr> and + std::is_same_v<ConnectivityPtr2, ConnectivityWeakPtr>) { + m_connectivity_ptr = value_per_item.m_connectivity_ptr.lock(); + } else { + m_connectivity_ptr = value_per_item.m_connectivity_ptr; + } return *this; } - template <typename DataType2> + template <typename DataType2, + typename ConnectivityPtr2> PASTIS_INLINE - ItemValue(const ItemValue<DataType2, item_type>& value_per_item) + ItemValue(const ItemValue<DataType2, item_type, ConnectivityPtr2>& value_per_item) { this->operator=(value_per_item); } @@ -150,8 +185,8 @@ class ItemValue PASTIS_INLINE ItemValue(const IConnectivity& connectivity) - : m_is_built{true}, - m_values(connectivity.numberOf<item_type>()) + : m_connectivity_ptr{connectivity.shared_ptr()}, + m_values{connectivity.numberOf<item_type>()} { static_assert(not std::is_const<DataType>(), "Cannot allocate ItemValue of const data: only view is supported"); ; @@ -173,4 +208,16 @@ using FaceValue = ItemValue<DataType, ItemType::face>; template <typename DataType> using CellValue = ItemValue<DataType, ItemType::cell>; +template <typename DataType> +using WeakNodeValue = ItemValue<DataType, ItemType::node, std::weak_ptr<const IConnectivity>>; + +template <typename DataType> +using WeakEdgeValue = ItemValue<DataType, ItemType::edge, std::weak_ptr<const IConnectivity>>; + +template <typename DataType> +using WeakFaceValue = ItemValue<DataType, ItemType::face, std::weak_ptr<const IConnectivity>>; + +template <typename DataType> +using WeakCellValue = ItemValue<DataType, ItemType::cell, std::weak_ptr<const IConnectivity>>; + #endif // ITEM_VALUE_HPP diff --git a/src/mesh/ItemValueUtils.hpp b/src/mesh/ItemValueUtils.hpp index 1381373975405a1bb669b6eb717501d0847b303b..26d05a37bf975eb39544002c89bbaafcac30d0f5 100644 --- a/src/mesh/ItemValueUtils.hpp +++ b/src/mesh/ItemValueUtils.hpp @@ -4,6 +4,8 @@ #include <Messenger.hpp> #include <ItemValue.hpp> +#include <Connectivity.hpp> + template <typename DataType, ItemType item_type> std::remove_const_t<DataType> @@ -13,10 +15,38 @@ min(const ItemValue<DataType, item_type>& item_value) using data_type = std::remove_const_t<typename ItemValueType::data_type>; using index_type = typename ItemValueType::index_type; + ItemValue<const bool, item_type> is_owned + = [&] (const IConnectivity& connectivity) { + + switch (connectivity.dimension()) { + case 1: { + const auto& connectivity_1d = static_cast<const Connectivity1D&>(connectivity); + return connectivity_1d.isOwned<item_type>(); + break; + } + case 2: { + const auto& connectivity_2d = static_cast<const Connectivity2D&>(connectivity); + return connectivity_2d.isOwned<item_type>(); + break; + } + case 3: { + const auto& connectivity_3d = static_cast<const Connectivity3D&>(connectivity); + return connectivity_3d.isOwned<item_type>(); + break; + } + default: { + Assert((connectivity.dimension()>0) and (connectivity.dimension()<=3), + "unexpected connectivity dimension"); + return ItemValue<const bool, item_type>{}; + } + } + } (*item_value.connectivity_ptr()); + class ItemValueMin { private: const ItemValueType& m_item_value; + const ItemValue<const bool, item_type>& m_is_owned; public: PASTIS_INLINE @@ -30,7 +60,7 @@ min(const ItemValue<DataType, item_type>& item_value) PASTIS_INLINE void operator()(const index_type& i, data_type& data) const { - if (m_item_value[i] < data) { + if ((m_is_owned[i]) and (m_item_value[i] < data)) { data = m_item_value[i]; } } @@ -51,8 +81,10 @@ min(const ItemValue<DataType, item_type>& item_value) } PASTIS_INLINE - ItemValueMin(const ItemValueType& item_value) - : m_item_value(item_value) + ItemValueMin(const ItemValueType& item_value, + const ItemValue<const bool, item_type>& is_owned) + : m_item_value(item_value), + m_is_owned(is_owned) { ; } @@ -61,7 +93,7 @@ min(const ItemValue<DataType, item_type>& item_value) ~ItemValueMin() = default; }; - const DataType local_min = ItemValueMin{item_value}; + const DataType local_min = ItemValueMin{item_value, is_owned}; return parallel::allReduceMin(local_min); } @@ -74,10 +106,38 @@ max(const ItemValue<DataType, item_type>& item_value) using data_type = std::remove_const_t<typename ItemValueType::data_type>; using index_type = typename ItemValueType::index_type; + ItemValue<const bool, item_type> is_owned + = [&] (const IConnectivity& connectivity) { + + switch (connectivity.dimension()) { + case 1: { + const auto& connectivity_1d = static_cast<const Connectivity1D&>(connectivity); + return connectivity_1d.isOwned<item_type>(); + break; + } + case 2: { + const auto& connectivity_2d = static_cast<const Connectivity2D&>(connectivity); + return connectivity_2d.isOwned<item_type>(); + break; + } + case 3: { + const auto& connectivity_3d = static_cast<const Connectivity3D&>(connectivity); + return connectivity_3d.isOwned<item_type>(); + break; + } + default: { + Assert((connectivity.dimension()>0) and (connectivity.dimension()<=3), + "unexpected connectivity dimension"); + return ItemValue<const bool, item_type>{}; + } + } + } (*item_value.connectivity_ptr()); + class ItemValueMax { private: const ItemValueType& m_item_value; + const ItemValue<const bool, item_type>& m_is_owned; public: PASTIS_INLINE @@ -91,7 +151,7 @@ max(const ItemValue<DataType, item_type>& item_value) PASTIS_INLINE void operator()(const index_type& i, data_type& data) const { - if (m_item_value[i] > data) { + if ((m_is_owned[i]) and (m_item_value[i] > data)) { data = m_item_value[i]; } } @@ -112,8 +172,10 @@ max(const ItemValue<DataType, item_type>& item_value) } PASTIS_INLINE - ItemValueMax(const ItemValueType& item_value) - : m_item_value(item_value) + ItemValueMax(const ItemValueType& item_value, + const ItemValue<const bool, item_type>& is_owned) + : m_item_value(item_value), + m_is_owned(is_owned) { ; } @@ -136,10 +198,38 @@ sum(const ItemValue<DataType, item_type>& item_value) using data_type = std::remove_const_t<typename ItemValueType::data_type>; using index_type = typename ItemValueType::index_type; + ItemValue<const bool, item_type> is_owned + = [&] (const IConnectivity& connectivity) { + + switch (connectivity.dimension()) { + case 1: { + const auto& connectivity_1d = static_cast<const Connectivity1D&>(connectivity); + return connectivity_1d.isOwned<item_type>(); + break; + } + case 2: { + const auto& connectivity_2d = static_cast<const Connectivity2D&>(connectivity); + return connectivity_2d.isOwned<item_type>(); + break; + } + case 3: { + const auto& connectivity_3d = static_cast<const Connectivity3D&>(connectivity); + return connectivity_3d.isOwned<item_type>(); + break; + } + default: { + Assert((connectivity.dimension()>0) and (connectivity.dimension()<=3), + "unexpected connectivity dimension"); + return ItemValue<const bool, item_type>{}; + } + } + } (*item_value.connectivity_ptr()); + class ItemValueSum { private: const ItemValueType& m_item_value; + const ItemValue<const bool, item_type>& m_is_owned; public: PASTIS_INLINE @@ -153,7 +243,9 @@ sum(const ItemValue<DataType, item_type>& item_value) PASTIS_INLINE void operator()(const index_type& i, data_type& data) const { - data += m_item_value[i]; + if (m_is_owned[i]) { + data += m_item_value[i]; + } } PASTIS_INLINE @@ -174,8 +266,10 @@ sum(const ItemValue<DataType, item_type>& item_value) } PASTIS_INLINE - ItemValueSum(const ItemValueType& item_value) - : m_item_value(item_value) + ItemValueSum(const ItemValueType& item_value, + const ItemValue<const bool, item_type>& is_owned) + : m_item_value(item_value), + m_is_owned(is_owned) { ; } @@ -184,7 +278,7 @@ sum(const ItemValue<DataType, item_type>& item_value) ~ItemValueSum() = default; }; - const DataType local_sum = ItemValueSum{item_value}; + const DataType local_sum = ItemValueSum{item_value, is_owned}; return parallel::allReduceSum(local_sum); } diff --git a/src/mesh/Mesh.hpp b/src/mesh/Mesh.hpp index 31745ddf54546e243e0a1deccf956db29d5ce057..3136f90ba28c94c1ce3f8d2a8d81c28409c8f657 100644 --- a/src/mesh/Mesh.hpp +++ b/src/mesh/Mesh.hpp @@ -10,7 +10,7 @@ struct IMesh { - virtual size_t meshDimension() const = 0; + virtual size_t dimension() const = 0; virtual CSRGraph cellToCellGraph() const = 0; ~IMesh() = default; }; @@ -21,11 +21,11 @@ class Mesh final : public IMesh public: using Connectivity = ConnectivityType; - static constexpr size_t dimension = ConnectivityType::dimension; - using Rd = TinyVector<dimension>; + static constexpr size_t Dimension = ConnectivityType::Dimension; + using Rd = TinyVector<Dimension>; private: - const std::shared_ptr<Connectivity> m_connectivity; + const std::shared_ptr<const Connectivity> m_connectivity; NodeValue<const Rd> m_xr; NodeValue<Rd> m_mutable_xr; @@ -37,9 +37,9 @@ public: } PASTIS_INLINE - size_t meshDimension() const + size_t dimension() const { - return dimension; + return Dimension; } PASTIS_INLINE diff --git a/src/mesh/MeshData.hpp b/src/mesh/MeshData.hpp index f9a2aeb575ac5265073d47ce2c4025da39db7d51..dd8f33deac3c96a4fc97e7900cf4dbabd9e31ece 100644 --- a/src/mesh/MeshData.hpp +++ b/src/mesh/MeshData.hpp @@ -15,12 +15,12 @@ class MeshData public: using MeshType = M; - static constexpr size_t dimension = MeshType::dimension; - static_assert(dimension>0, "dimension must be strictly positive"); + static constexpr size_t Dimension = MeshType::Dimension; + static_assert(Dimension>0, "dimension must be strictly positive"); - using Rd = TinyVector<dimension>; + using Rd = TinyVector<Dimension>; - static constexpr double inv_dimension = 1./dimension; + static constexpr double inv_Dimension = 1./Dimension; private: const MeshType& m_mesh; @@ -33,7 +33,7 @@ class MeshData PASTIS_INLINE void _updateCenter() { // Computes vertices isobarycenter - if constexpr (dimension == 1) { + if constexpr (Dimension == 1) { const NodeValue<const Rd>& xr = m_mesh.xr(); const auto& cell_to_node_matrix @@ -81,17 +81,17 @@ class MeshData for (size_t R=0; R<cell_nodes.size(); ++R) { sum_cjr_xr += (xr[cell_nodes[R]], m_Cjr(j,R)); } - Vj[j] = inv_dimension * sum_cjr_xr; + Vj[j] = inv_Dimension * sum_cjr_xr; }); m_Vj = Vj; } PASTIS_INLINE void _updateCjr() { - if constexpr (dimension == 1) { + if constexpr (Dimension == 1) { // Cjr/njr/ljr are constant overtime } - else if constexpr (dimension == 2) { + else if constexpr (Dimension == 2) { const NodeValue<const Rd>& xr = m_mesh.xr(); const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix(); @@ -125,7 +125,7 @@ class MeshData }); m_njr = njr; } - } else if (dimension ==3) { + } else if (Dimension ==3) { const NodeValue<const Rd>& xr = m_mesh.xr(); NodeValuePerFace<Rd> Nlr(m_mesh.connectivity()); @@ -223,7 +223,7 @@ class MeshData m_njr = njr; } } - static_assert((dimension<=3), "only 1d, 2d and 3d are implemented"); + static_assert((Dimension<=3), "only 1d, 2d and 3d are implemented"); } public: @@ -267,7 +267,7 @@ class MeshData MeshData(const MeshType& mesh) : m_mesh(mesh) { - if constexpr (dimension==1) { + if constexpr (Dimension==1) { // in 1d Cjr are computed once for all { NodeValuePerCell<Rd> Cjr(m_mesh.connectivity()); diff --git a/src/mesh/MeshNodeBoundary.hpp b/src/mesh/MeshNodeBoundary.hpp index dbb3f6fec539dbb1f5e0e8b9cffa3889ef2f5183..2f83c82dc3a1c5451d51bcd0b293e7d6090a7d4a 100644 --- a/src/mesh/MeshNodeBoundary.hpp +++ b/src/mesh/MeshNodeBoundary.hpp @@ -16,7 +16,7 @@ #include <Messenger.hpp> -template <size_t dimension> +template <size_t Dimension> class MeshNodeBoundary { protected: @@ -34,7 +34,7 @@ class MeshNodeBoundary MeshNodeBoundary(const MeshType& mesh, const RefFaceList& ref_face_list) { - static_assert(dimension == MeshType::dimension); + static_assert(Dimension == MeshType::Dimension); const auto& face_to_cell_matrix = mesh.connectivity().faceToCellMatrix(); @@ -49,7 +49,7 @@ class MeshNodeBoundary Kokkos::vector<unsigned int> node_ids; // not enough but should reduce significantly the number of resizing - node_ids.reserve(dimension*face_list.size()); + node_ids.reserve(Dimension*face_list.size()); const auto& face_to_node_matrix = mesh.connectivity().faceToNodeMatrix(); @@ -76,7 +76,7 @@ class MeshNodeBoundary MeshNodeBoundary(const MeshType&, const RefNodeList& ref_node_list) : m_node_list(ref_node_list.nodeList()) { - static_assert(dimension == MeshType::dimension); + static_assert(Dimension == MeshType::Dimension); } MeshNodeBoundary() = default; @@ -86,12 +86,12 @@ class MeshNodeBoundary }; -template <size_t dimension> +template <size_t Dimension> class MeshFlatNodeBoundary - : public MeshNodeBoundary<dimension> + : public MeshNodeBoundary<Dimension> { public: - using Rd = TinyVector<dimension, double>; + using Rd = TinyVector<Dimension, double>; private: const Rd m_outgoing_normal; @@ -122,7 +122,7 @@ class MeshFlatNodeBoundary template <typename MeshType> MeshFlatNodeBoundary(const MeshType& mesh, const RefFaceList& ref_face_list) - : MeshNodeBoundary<dimension>(mesh, ref_face_list), + : MeshNodeBoundary<Dimension>(mesh, ref_face_list), m_outgoing_normal(_getOutgoingNormal(mesh)) { ; @@ -131,7 +131,7 @@ class MeshFlatNodeBoundary template <typename MeshType> MeshFlatNodeBoundary(const MeshType& mesh, const RefNodeList& ref_node_list) - : MeshNodeBoundary<dimension>(mesh, ref_node_list), + : MeshNodeBoundary<Dimension>(mesh, ref_node_list), m_outgoing_normal(_getOutgoingNormal(mesh)) { ; @@ -151,7 +151,7 @@ _checkBoundaryIsFlat(const TinyVector<2,double>& normal, const TinyVector<2,double>& xmax, const MeshType& mesh) const { - static_assert(MeshType::dimension == 2); + static_assert(MeshType::Dimension == 2); using R2 = TinyVector<2,double>; const R2 origin = 0.5*(xmin+xmax); @@ -175,7 +175,7 @@ TinyVector<1,double> MeshFlatNodeBoundary<1>:: _getNormal(const MeshType&) { - static_assert(MeshType::dimension == 1); + static_assert(MeshType::Dimension == 1); using R = TinyVector<1,double>; if (m_node_list.size() != 1) { @@ -193,7 +193,7 @@ TinyVector<2,double> MeshFlatNodeBoundary<2>:: _getNormal(const MeshType& mesh) { - static_assert(MeshType::dimension == 2); + static_assert(MeshType::Dimension == 2); using R2 = TinyVector<2,double>; const NodeValue<const R2>& xr = mesh.xr(); @@ -251,7 +251,7 @@ TinyVector<3,double> MeshFlatNodeBoundary<3>:: _getNormal(const MeshType& mesh) { - static_assert(MeshType::dimension == 3); + static_assert(MeshType::Dimension == 3); using R3 = TinyVector<3,double>; @@ -369,7 +369,7 @@ TinyVector<1,double> MeshFlatNodeBoundary<1>:: _getOutgoingNormal(const MeshType& mesh) { - static_assert(MeshType::dimension == 1); + static_assert(MeshType::Dimension == 1); using R = TinyVector<1,double>; const R normal = this->_getNormal(mesh); @@ -418,7 +418,7 @@ TinyVector<2,double> MeshFlatNodeBoundary<2>:: _getOutgoingNormal(const MeshType& mesh) { - static_assert(MeshType::dimension == 2); + static_assert(MeshType::Dimension == 2); using R2 = TinyVector<2,double>; const R2 normal = this->_getNormal(mesh); @@ -466,7 +466,7 @@ TinyVector<3,double> MeshFlatNodeBoundary<3>:: _getOutgoingNormal(const MeshType& mesh) { - static_assert(MeshType::dimension == 3); + static_assert(MeshType::Dimension == 3); using R3 = TinyVector<3,double>; const R3 normal = this->_getNormal(mesh); diff --git a/src/output/VTKWriter.hpp b/src/output/VTKWriter.hpp index dc3d1bb12018184c6b75a3997a1f889d1f3cc3d3..16f49cc414693b0a6610e7adea861322f419912e 100644 --- a/src/output/VTKWriter.hpp +++ b/src/output/VTKWriter.hpp @@ -276,14 +276,14 @@ class VTKWriter fout << "</PointData>\n"; fout << "<Points>\n"; { - using Rd = TinyVector<MeshType::dimension>; + using Rd = TinyVector<MeshType::Dimension>; const NodeValue<const Rd>& xr = mesh.xr(); Array<TinyVector<3>> positions(mesh.numberOfNodes()); parallel_for(mesh.numberOfNodes(), PASTIS_LAMBDA(NodeId r) { - for (unsigned short i=0; i<MeshType::dimension; ++i) { + for (unsigned short i=0; i<MeshType::Dimension; ++i) { positions[r][i] = xr[r][i]; } - for (unsigned short i=MeshType::dimension; i<3; ++i) { + for (unsigned short i=MeshType::Dimension; i<3; ++i) { positions[r][i] = 0; } }); diff --git a/src/scheme/AcousticSolver.hpp b/src/scheme/AcousticSolver.hpp index fe5e8bf182d0ba02c08acb9f8d406e96244c7fad..142b6f03b973b96037e50a42fb12c175213c912a 100644 --- a/src/scheme/AcousticSolver.hpp +++ b/src/scheme/AcousticSolver.hpp @@ -30,10 +30,10 @@ class AcousticSolver const typename MeshType::Connectivity& m_connectivity; const std::vector<BoundaryConditionHandler>& m_boundary_condition_list; - constexpr static size_t dimension = MeshType::dimension; + constexpr static size_t Dimension = MeshType::Dimension; - using Rd = TinyVector<dimension>; - using Rdd = TinyMatrix<dimension>; + using Rd = TinyVector<Dimension>; + using Rdd = TinyMatrix<Dimension>; private: PASTIS_INLINE @@ -136,8 +136,8 @@ class AcousticSolver break; } case BoundaryCondition::symmetry: { - const SymmetryBoundaryCondition<dimension>& symmetry_bc - = dynamic_cast<const SymmetryBoundaryCondition<dimension>&>(handler.boundaryCondition()); + const SymmetryBoundaryCondition<Dimension>& symmetry_bc + = dynamic_cast<const SymmetryBoundaryCondition<Dimension>&>(handler.boundaryCondition()); const Rd& n = symmetry_bc.outgoingNormal(); const Rdd I = identity; diff --git a/src/scheme/FiniteVolumesEulerUnknowns.hpp b/src/scheme/FiniteVolumesEulerUnknowns.hpp index 109e93d5a13acd2ce5c23b2b83576af21665c3c7..390e1c91c95beec72d96ec1389d564971a25026b 100644 --- a/src/scheme/FiniteVolumesEulerUnknowns.hpp +++ b/src/scheme/FiniteVolumesEulerUnknowns.hpp @@ -11,8 +11,8 @@ public: using MeshDataType = TMeshData; using MeshType = typename MeshDataType::MeshType; - static constexpr size_t dimension = MeshType::dimension; - using Rd = TinyVector<dimension>; + static constexpr size_t Dimension = MeshType::Dimension; + using Rd = TinyVector<Dimension>; private: const MeshDataType& m_mesh_data; diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 328be198bdf8e52ebf2f67ba41fbed6c6162cb01..1575c44410881fe76341c54c1346392321d14fd6 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -334,19 +334,25 @@ class Messenger { #ifdef PASTIS_HAS_MPI static_assert(not std::is_const_v<DataType>); - static_assert(std::is_arithmetic_v<DataType>); - - MPI_Datatype mpi_datatype - = Messenger::helper::mpiType<DataType>(); + if constexpr(std::is_arithmetic_v<DataType>) { + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<DataType>(); - perr() << __FILE__ << ':' << __LINE__ << ": Implementation not finished\n"; - std::terminate(); + DataType data_sum = data; + MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, MPI_COMM_WORLD); + return data_sum; + } else if (is_trivially_castable<DataType>){ + using InnerDataType = typename DataType::data_type; - DataType data_sum = data; - MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, MPI_COMM_WORLD); + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<InnerDataType>(); + const int size = sizeof(DataType)/sizeof(InnerDataType); + DataType data_sum = data; + MPI_Allreduce(&data, &data_sum, size, mpi_datatype, MPI_SUM, MPI_COMM_WORLD); - return data_sum; + return data_sum; + } #else // PASTIS_HAS_MPI return data; #endif // PASTIS_HAS_MPI