From 1690e21c3e30a2000a6e4f6acc72a946b9c82d19 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Mon, 19 Apr 2021 17:36:37 +0200
Subject: [PATCH] Add copy_to utilities

While copy returns a new Array/Table from copied values, copy_to deeps
copies values from an Array/Table to an existing one.

Source and destination must have the same size
---
 src/utils/Array.hpp  |  7 +++++++
 src/utils/Table.hpp  |  8 ++++++++
 tests/test_Array.cpp | 23 +++++++++++++++++++++++
 tests/test_Table.cpp | 32 ++++++++++++++++++++++++++++++++
 4 files changed, 70 insertions(+)

diff --git a/src/utils/Array.hpp b/src/utils/Array.hpp
index 1fe0aacaa..5eb834b85 100644
--- a/src/utils/Array.hpp
+++ b/src/utils/Array.hpp
@@ -36,6 +36,13 @@ class [[nodiscard]] Array
     return image;
   }
 
+  friend PUGS_INLINE void copy_to(const Array<DataType>& source,
+                                  const Array<std::remove_const_t<DataType>>& destination)
+  {
+    Assert(source.size() == destination.size());
+    Kokkos::deep_copy(destination.m_values, source.m_values);
+  }
+
   template <typename DataType2, typename... RT>
   friend PUGS_INLINE Array<DataType2> encapsulate(const Kokkos::View<DataType2*, RT...>& values);
 
diff --git a/src/utils/Table.hpp b/src/utils/Table.hpp
index 67f544b99..eedc5e7d7 100644
--- a/src/utils/Table.hpp
+++ b/src/utils/Table.hpp
@@ -46,6 +46,14 @@ class [[nodiscard]] Table
     return image;
   }
 
+  friend PUGS_INLINE void copy_to(const Table<DataType>& source,
+                                  const Table<std::remove_const_t<DataType>>& destination)
+  {
+    Assert(source.nbRows() == destination.nbRows());
+    Assert(source.nbColumns() == destination.nbColumns());
+    Kokkos::deep_copy(destination.m_values, source.m_values);
+  }
+
   template <typename DataType2, typename... RT>
   friend PUGS_INLINE Table<DataType2> encapsulate(const Kokkos::View<DataType2**, RT...>& values);
 
diff --git a/tests/test_Array.cpp b/tests/test_Array.cpp
index ac61be590..23d63a4a2 100644
--- a/tests/test_Array.cpp
+++ b/tests/test_Array.cpp
@@ -105,6 +105,23 @@ TEST_CASE("Array", "[utils]")
 
     REQUIRE(((c[0] == 2) and (c[1] == 2) and (c[2] == 2) and (c[3] == 2) and (c[4] == 2) and (c[5] == 2) and
              (c[6] == 2) and (c[7] == 2) and (c[8] == 2) and (c[9] == 2)));
+
+    Array<int> d{a.size()};
+    copy_to(a, d);
+
+    REQUIRE(((d[0] == 0) and (d[1] == 2) and (d[2] == 4) and (d[3] == 6) and (d[4] == 8) and (d[5] == 10) and
+             (d[6] == 12) and (d[7] == 14) and (d[8] == 16) and (d[9] == 18)));
+
+    REQUIRE(((c[0] == 2) and (c[1] == 2) and (c[2] == 2) and (c[3] == 2) and (c[4] == 2) and (c[5] == 2) and
+             (c[6] == 2) and (c[7] == 2) and (c[8] == 2) and (c[9] == 2)));
+
+    copy_to(c, d);
+
+    REQUIRE(((a[0] == 0) and (a[1] == 2) and (a[2] == 4) and (a[3] == 6) and (a[4] == 8) and (a[5] == 10) and
+             (a[6] == 12) and (a[7] == 14) and (a[8] == 16) and (a[9] == 18)));
+
+    REQUIRE(((d[0] == 2) and (d[1] == 2) and (d[2] == 2) and (d[3] == 2) and (d[4] == 2) and (d[5] == 2) and
+             (d[6] == 2) and (d[7] == 2) and (d[8] == 2) and (d[9] == 2)));
   }
 
   SECTION("checking for std container conversion")
@@ -220,5 +237,11 @@ TEST_CASE("Array", "[utils]")
   {
     REQUIRE_THROWS_AS(a[10], AssertError);
   }
+
+  SECTION("invalid copy_to")
+  {
+    Array<int> b{2 * a.size()};
+    REQUIRE_THROWS_AS(copy_to(a, b), AssertError);
+  }
 #endif   // NDEBUG
 }
diff --git a/tests/test_Table.cpp b/tests/test_Table.cpp
index 7b736f3d7..3521476c3 100644
--- a/tests/test_Table.cpp
+++ b/tests/test_Table.cpp
@@ -137,6 +137,23 @@ TEST_CASE("Table", "[utils]")
     REQUIRE(((c(0, 0) == 2) and (c(1, 0) == 2) and (c(2, 0) == 2) and (c(3, 0) == 2) and   //
              (c(0, 1) == 2) and (c(1, 1) == 2) and (c(2, 1) == 2) and (c(3, 1) == 2) and   //
              (c(0, 2) == 2) and (c(1, 2) == 2) and (c(2, 2) == 2) and (c(3, 2) == 2)));
+
+    Table<int> d{a.nbRows(), a.nbColumns()};
+    copy_to(a, d);
+
+    REQUIRE(((a(0, 0) == 0) and (a(1, 0) == 2) and (a(2, 0) == 4) and (a(3, 0) == 6) and   //
+             (a(0, 1) == 1) and (a(1, 1) == 3) and (a(2, 1) == 5) and (a(3, 1) == 7) and   //
+             (a(0, 2) == 2) and (a(1, 2) == 4) and (a(2, 2) == 6) and (a(3, 2) == 8)));
+
+    REQUIRE(((d(0, 0) == 0) and (d(1, 0) == 2) and (d(2, 0) == 4) and (d(3, 0) == 6) and   //
+             (d(0, 1) == 1) and (d(1, 1) == 3) and (d(2, 1) == 5) and (d(3, 1) == 7) and   //
+             (d(0, 2) == 2) and (d(1, 2) == 4) and (d(2, 2) == 6) and (d(3, 2) == 8)));
+
+    copy_to(c, d);
+
+    REQUIRE(((d(0, 0) == 2) and (d(1, 0) == 2) and (d(2, 0) == 2) and (d(3, 0) == 2) and   //
+             (d(0, 1) == 2) and (d(1, 1) == 2) and (d(2, 1) == 2) and (d(3, 1) == 2) and   //
+             (d(0, 2) == 2) and (d(1, 2) == 2) and (d(2, 2) == 2) and (d(3, 2) == 2)));
   }
 
   SECTION("checking for Kokkos::View encaspulation")
@@ -167,5 +184,20 @@ TEST_CASE("Table", "[utils]")
     REQUIRE_THROWS_AS(a(4, 0), AssertError);
     REQUIRE_THROWS_AS(a(0, 3), AssertError);
   }
+
+  SECTION("invalid copy_to")
+  {
+    SECTION("wrong row number")
+    {
+      Table<int> b{2 * a.nbRows(), a.nbColumns()};
+      REQUIRE_THROWS_AS(copy_to(a, b), AssertError);
+    }
+
+    SECTION("wrong column number")
+    {
+      Table<int> c{a.nbRows(), 2 * a.nbColumns()};
+      REQUIRE_THROWS_AS(copy_to(a, c), AssertError);
+    }
+  }
 #endif   // NDEBUG
 }
-- 
GitLab