diff --git a/src/utils/checkpointing/ReadTable.hpp b/src/utils/checkpointing/ReadTable.hpp index f57a689a8ee976b44e82ff06f635fc8816363de7..074dc5e2578287e75c881f9b0262f2796ac4ab8f 100644 --- a/src/utils/checkpointing/ReadTable.hpp +++ b/src/utils/checkpointing/ReadTable.hpp @@ -24,11 +24,11 @@ readTable(const HighFive::Group& group, const std::string& name) const std::vector<size_t> number_of_rows_per_rank = dataset.getAttribute("number_of_rows_per_rank").read<std::vector<size_t>>(); - std::vector<size_t> offset{0, 0ul}; + std::vector<size_t> offset = {0, 0ul}; for (size_t i = 0; i < parallel::rank(); ++i) { - offset[0] += number_of_rows_per_rank[i] * number_of_columns; + offset[0] += number_of_rows_per_rank[i] * number_of_columns; // LCOV_EXCL_LINE } - std::vector<size_t> count{number_of_rows_per_rank[parallel::rank()]}; + std::vector<size_t> count = {number_of_rows_per_rank[parallel::rank()] * number_of_columns}; Table<DataType> table(number_of_rows_per_rank[parallel::rank()], number_of_columns); if constexpr (std::is_same_v<CellType, data_type>) { diff --git a/src/utils/checkpointing/WriteArray.hpp b/src/utils/checkpointing/WriteArray.hpp index 1902378d43a5cf8fce7d5dffcf4c0a0bd9be0dd4..495fc5fcd916e3cfc8d5dad0862f701968cb6fda 100644 --- a/src/utils/checkpointing/WriteArray.hpp +++ b/src/utils/checkpointing/WriteArray.hpp @@ -22,7 +22,7 @@ write(HighFive::Group& group, const std::string& name, const Array<DataType>& ar size_t current_offset = 0; for (size_t i = 0; i < parallel::rank(); ++i) { - current_offset += size_per_rank[i]; + current_offset += size_per_rank[i]; // LCOV_EXCL_LINE } std::vector<size_t> offset{current_offset, 0ul}; std::vector<size_t> count{array.size()}; diff --git a/src/utils/checkpointing/WriteTable.hpp b/src/utils/checkpointing/WriteTable.hpp index fb11b5f962b1a7199850bd0bdd3f1206da428683..ad1336a6f7d90479fbbcf28c7622470778dfbe8f 100644 --- a/src/utils/checkpointing/WriteTable.hpp +++ b/src/utils/checkpointing/WriteTable.hpp @@ -3,6 +3,8 @@ #include <utils/HighFivePugsUtils.hpp> +#include <mesh/CellType.hpp> +#include <mesh/ItemId.hpp> #include <utils/Messenger.hpp> #include <utils/Table.hpp> @@ -14,9 +16,11 @@ PUGS_INLINE void write(HighFive::Group& group, const std::string& name, const Table<DataType>& table) { const size_t number_of_columns = parallel::allReduceMax(table.numberOfColumns()); + // LCOV_EXCL_START if ((table.numberOfColumns() != number_of_columns) and (table.numberOfRows() > 0)) { throw UnexpectedError("table must have same number of columns in parallel"); } + // LCOV_EXCL_STOP auto get_address = [](auto& t) { return (t.numberOfRows() * t.numberOfColumns() > 0) ? &(t(0, 0)) : nullptr; }; @@ -25,7 +29,7 @@ write(HighFive::Group& group, const std::string& name, const Table<DataType>& ta size_t current_offset = 0; for (size_t i = 0; i < parallel::rank(); ++i) { - current_offset += number_of_rows_per_rank[i] * table.numberOfColumns(); + current_offset += number_of_rows_per_rank[i] * table.numberOfColumns(); // LCOV_EXCL_LINE } std::vector<size_t> offset{current_offset, 0ul}; std::vector<size_t> count{table.numberOfRows() * table.numberOfColumns()}; @@ -39,9 +43,24 @@ write(HighFive::Group& group, const std::string& name, const Table<DataType>& ta auto xfer_props = HighFive::DataTransferProps{}; xfer_props.add(HighFive::UseCollectiveIO{}); - HighFive::DataSet dataset = - group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties); - dataset.select(offset, count).template write_raw<data_type>(get_address(table), xfer_props); + HighFive::DataSet dataset; + + if constexpr (std::is_same_v<CellType, data_type>) { + using base_type = std::underlying_type_t<CellType>; + dataset = group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties); + dataset.select(offset, count) + .template write_raw<base_type>(reinterpret_cast<const base_type*>(get_address(table)), xfer_props); + } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or + (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) { + using base_type = typename data_type::base_type; + + dataset = group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties); + dataset.select(offset, count) + .template write_raw<base_type>(reinterpret_cast<const base_type*>(get_address(table)), xfer_props); + } else { + dataset = group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties); + dataset.select(offset, count).template write_raw<data_type>(get_address(table), xfer_props); + } std::vector<size_t> number_of_rows_per_rank_vector; for (size_t i = 0; i < number_of_rows_per_rank.size(); ++i) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4de129eca304537540016ff3ec89133b9c2740d3..4462049eef7a05ce96e8b7eea67e87939a5402df 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -171,6 +171,7 @@ if(PUGS_HAS_HDF5) test_checkpointing_ItemType.cpp test_checkpointing_IWriter.cpp test_checkpointing_IZoneDescriptor.cpp + test_checkpointing_Table.cpp ) endif(PUGS_HAS_HDF5) diff --git a/tests/test_checkpointing_Table.cpp b/tests/test_checkpointing_Table.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dd4223555dec37cbc2efc610af6f87154e94ac7c --- /dev/null +++ b/tests/test_checkpointing_Table.cpp @@ -0,0 +1,146 @@ +#include <catch2/catch_test_macros.hpp> +#include <catch2/matchers/catch_matchers_all.hpp> + +#include <utils/Messenger.hpp> + +#include <language/utils/DataHandler.hpp> +#include <language/utils/EmbeddedData.hpp> +#include <utils/checkpointing/ReadTable.hpp> +#include <utils/checkpointing/WriteTable.hpp> + +#include <filesystem> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("checkpointing_Table", "[utils/checkpointing]") +{ + std::string tmp_dirname; + { + { + if (parallel::rank() == 0) { + tmp_dirname = [&]() -> std::string { + std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX"; + return std::string{mkdtemp(&temp_filename[0])}; + }(); + } + parallel::broadcast(tmp_dirname, 0); + } + std::filesystem::path path = tmp_dirname; + const std::string filename = path / "checkpoint.h5"; + + HighFive::FileAccessProps fapl; + fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL}); + fapl.add(HighFive::MPIOCollectiveMetadata{}); + HighFive::File file = HighFive::File(filename, HighFive::File::Truncate, fapl); + + SECTION("Table") + { + HighFive::Group checkpoint_group = file.createGroup("checkpoint_group"); + HighFive::Group useless_group; + + Table<CellType> cell_type_table{19 + 3 * parallel::rank(), 3}; + cell_type_table.fill(CellType::Line); + for (size_t i = 0; i < 10; ++i) { + for (size_t j = 0; j < cell_type_table.numberOfColumns(); ++j) { + cell_type_table[std::rand() / (RAND_MAX / cell_type_table.numberOfRows())][j] = CellType::Line; + cell_type_table[std::rand() / (RAND_MAX / cell_type_table.numberOfRows())][j] = CellType::Triangle; + cell_type_table[std::rand() / (RAND_MAX / cell_type_table.numberOfRows())][j] = CellType::Quadrangle; + cell_type_table[std::rand() / (RAND_MAX / cell_type_table.numberOfRows())][j] = CellType::Polygon; + cell_type_table[std::rand() / (RAND_MAX / cell_type_table.numberOfRows())][j] = CellType::Tetrahedron; + cell_type_table[std::rand() / (RAND_MAX / cell_type_table.numberOfRows())][j] = CellType::Diamond; + cell_type_table[std::rand() / (RAND_MAX / cell_type_table.numberOfRows())][j] = CellType::Hexahedron; + cell_type_table[std::rand() / (RAND_MAX / cell_type_table.numberOfRows())][j] = CellType::Prism; + cell_type_table[std::rand() / (RAND_MAX / cell_type_table.numberOfRows())][j] = CellType::Pyramid; + } + } + checkpointing::write(checkpoint_group, "cell_type_table", cell_type_table); + + Table<CellId> cell_id_table{27 + 2 * parallel::rank(), 3}; + cell_id_table.fill(0); + for (size_t i = 0; i < 20; ++i) { + for (size_t j = 0; j < cell_id_table.numberOfColumns(); ++j) { + cell_id_table[std::rand() / (RAND_MAX / cell_id_table.numberOfRows())][j] = + std::rand() / (RAND_MAX / cell_id_table.numberOfRows()); + } + } + checkpointing::write(checkpoint_group, "cell_id_table", cell_id_table); + + Table<FaceId> face_id_table{29 + 2 * parallel::rank(), 2}; + face_id_table.fill(0); + for (size_t i = 0; i < 20; ++i) { + for (size_t j = 0; j < face_id_table.numberOfColumns(); ++j) { + face_id_table[std::rand() / (RAND_MAX / face_id_table.numberOfRows())][j] = + std::rand() / (RAND_MAX / face_id_table.numberOfRows()); + } + } + checkpointing::write(checkpoint_group, "face_id_table", face_id_table); + + Table<EdgeId> edge_id_table{13 + 2 * parallel::rank(), 4}; + edge_id_table.fill(0); + for (size_t i = 0; i < 20; ++i) { + for (size_t j = 0; j < edge_id_table.numberOfColumns(); ++j) { + edge_id_table[std::rand() / (RAND_MAX / edge_id_table.numberOfRows())][j] = + std::rand() / (RAND_MAX / edge_id_table.numberOfRows()); + } + } + checkpointing::write(checkpoint_group, "edge_id_table", edge_id_table); + + Table<NodeId> node_id_table{22 + 2 * parallel::rank(), 3}; + node_id_table.fill(0); + for (size_t i = 0; i < 20; ++i) { + for (size_t j = 0; j < node_id_table.numberOfColumns(); ++j) { + node_id_table[std::rand() / (RAND_MAX / node_id_table.numberOfRows())][j] = + std::rand() / (RAND_MAX / node_id_table.numberOfRows()); + } + } + checkpointing::write(checkpoint_group, "node_id_table", node_id_table); + + Table<double> double_table{16 + 3 * parallel::rank(), 5}; + double_table.fill(0); + for (size_t i = 0; i < 20; ++i) { + for (size_t j = 0; j < double_table.numberOfColumns(); ++j) { + double_table[std::rand() / (RAND_MAX / double_table.numberOfRows())][j] = + (1. * std::rand()) / (1. * RAND_MAX / double_table.numberOfRows()); + } + } + checkpointing::write(checkpoint_group, "double_table", double_table); + + file.flush(); + + auto is_same = [](const auto& a, const auto& b) { + bool same = true; + for (size_t i = 0; i < a.numberOfRows(); ++i) { + for (size_t j = 0; j < a.numberOfColumns(); ++j) { + if (a(i, j) != b(i, j)) { + same = false; + } + } + } + return parallel::allReduceAnd(same); + }; + + Table read_cell_type_table = checkpointing::readTable<CellType>(checkpoint_group, "cell_type_table"); + REQUIRE(is_same(cell_type_table, read_cell_type_table)); + + Table read_cell_id_table = checkpointing::readTable<CellId>(checkpoint_group, "cell_id_table"); + REQUIRE(is_same(cell_id_table, read_cell_id_table)); + + Table read_face_id_table = checkpointing::readTable<FaceId>(checkpoint_group, "face_id_table"); + REQUIRE(is_same(face_id_table, read_face_id_table)); + + Table read_edge_id_table = checkpointing::readTable<EdgeId>(checkpoint_group, "edge_id_table"); + REQUIRE(is_same(edge_id_table, read_edge_id_table)); + + Table read_node_id_table = checkpointing::readTable<NodeId>(checkpoint_group, "node_id_table"); + REQUIRE(is_same(node_id_table, read_node_id_table)); + + Table read_double_table = checkpointing::readTable<double>(checkpoint_group, "double_table"); + REQUIRE(is_same(double_table, read_double_table)); + } + } + + parallel::barrier(); + if (parallel::rank() == 0) { + std::filesystem::remove_all(std::filesystem::path{tmp_dirname}); + } +}