From 4e25cdc43b071b7775a7cbbf16d1e853ef33d996 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Fri, 5 Oct 2018 11:59:43 +0200
Subject: [PATCH] Use Kokkos::deep_copy instead of parallel-for-loop

---
 src/utils/Array.hpp  | 25 +++++++++---------
 tests/test_Array.cpp | 63 +++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 74 insertions(+), 14 deletions(-)

diff --git a/src/utils/Array.hpp b/src/utils/Array.hpp
index 49993d366..9c65e52ec 100644
--- a/src/utils/Array.hpp
+++ b/src/utils/Array.hpp
@@ -6,6 +6,8 @@
 
 #include <PastisAssert.hpp>
 
+#include <Kokkos_CopyViews.hpp>
+
 template <typename DataType>
 class Array
 {
@@ -26,6 +28,15 @@ class Array
     return m_values.extent(0);
   }
 
+  friend PASTIS_INLINE
+  Array<std::remove_const_t<DataType>> copy(const Array<DataType>& source)
+  {
+    Array<std::remove_const_t<DataType>> image(source.size());
+    Kokkos::deep_copy(image.m_values, source.m_values);
+
+    return image;
+  }
+
   PASTIS_INLINE
   DataType& operator[](const index_type& i) const
   {
@@ -39,6 +50,7 @@ class Array
     static_assert(not std::is_const<DataType>(),
                   "Cannot modify Array of const");
 
+    // could consider to use std::fill
     parallel_for(this->size(), PASTIS_LAMBDA(const index_type& i){
         m_values[i] = data;
       });
@@ -93,17 +105,4 @@ class Array
   ~Array() = default;
 };
 
-template <typename DataType>
-PASTIS_INLINE
-Array<std::remove_const_t<DataType>> copy(const Array<DataType>& array)
-{
-  Array<std::remove_const_t<DataType>> image(array.size());
-  using index_type = typename Array<DataType>::index_type;
-
-  parallel_for(array.size(), PASTIS_LAMBDA(const index_type& i){
-        image[i] = array[i];
-    });
-  return image;
-}
-
 #endif // ARRAY_HPP
diff --git a/tests/test_Array.cpp b/tests/test_Array.cpp
index 6cd02146c..29257385c 100644
--- a/tests/test_Array.cpp
+++ b/tests/test_Array.cpp
@@ -50,7 +50,19 @@ TEST_CASE("Array", "[utils]") {
 
   }
 
-  SECTION("checking for affectations") {
+  SECTION("checking for fill") {
+    Array<int> b(10);
+    b.fill(3);
+
+    REQUIRE(((b[0] == 3) and (b[1] == 3) and
+             (b[2] == 3) and (b[3] == 3) and
+             (b[4] == 3) and (b[5] == 3) and
+             (b[6] == 3) and (b[7] == 3) and
+             (b[8] == 3) and (b[9] == 3)));
+
+  }
+
+  SECTION("checking for affectations (shallow copy)") {
     Array<const int> b;
     b = a;
 
@@ -78,7 +90,56 @@ TEST_CASE("Array", "[utils]") {
              (d[4] == 8) and (d[5] ==10) and
              (d[6] ==12) and (d[7] ==14) and
              (d[8] ==16) and (d[9] ==18)));
+
   }
+
+  SECTION("checking for affectations (deep copy)") {
+    Array<int> b(copy(a));
+
+    REQUIRE(((b[0] == 0) and (b[1] == 2) and
+             (b[2] == 4) and (b[3] == 6) and
+             (b[4] == 8) and (b[5] ==10) and
+             (b[6] ==12) and (b[7] ==14) and
+             (b[8] ==16) and (b[9] ==18)));
+
+    b.fill(2);
+
+    REQUIRE(((a[0] == 0) and (a[1] == 2) and
+             (a[2] == 4) and (a[3] == 6) and
+             (a[4] == 8) and (a[5] ==10) and
+             (a[6] ==12) and (a[7] ==14) and
+             (a[8] ==16) and (a[9] ==18)));
+
+    REQUIRE(((b[0] == 2) and (b[1] == 2) and
+             (b[2] == 2) and (b[3] == 2) and
+             (b[4] == 2) and (b[5] == 2) and
+             (b[6] == 2) and (b[7] == 2) and
+             (b[8] == 2) and (b[9] == 2)));
+
+    Array<int> c;
+    c = a;
+
+    REQUIRE(((c[0] == 0) and (c[1] == 2) and
+             (c[2] == 4) and (c[3] == 6) and
+             (c[4] == 8) and (c[5] ==10) and
+             (c[6] ==12) and (c[7] ==14) and
+             (c[8] ==16) and (c[9] ==18)));
+
+    c = copy(b);
+
+    REQUIRE(((a[0] == 0) and (a[1] == 2) and
+             (a[2] == 4) and (a[3] == 6) and
+             (a[4] == 8) and (a[5] ==10) and
+             (a[6] ==12) and (a[7] ==14) and
+             (a[8] ==16) and (a[9] ==18)));
+
+    REQUIRE(((c[0] == 2) and (c[1] == 2) and
+             (c[2] == 2) and (c[3] == 2) and
+             (c[4] == 2) and (c[5] == 2) and
+             (c[6] == 2) and (c[7] == 2) and
+             (c[8] == 2) and (c[9] == 2)));
+  }
+
 #ifndef NDEBUG
   SECTION("checking for bounds violation") {
     REQUIRE_THROWS_AS(a[10], AssertError);
-- 
GitLab