diff --git a/src/utils/Array.hpp b/src/utils/Array.hpp index 49993d36651a745f1596d1b9c33dc872fbbead18..9c65e52ecb8c40c33a33211484e2c67ad8f3882d 100644 --- a/src/utils/Array.hpp +++ b/src/utils/Array.hpp @@ -6,6 +6,8 @@ #include <PastisAssert.hpp> +#include <Kokkos_CopyViews.hpp> + template <typename DataType> class Array { @@ -26,6 +28,15 @@ class Array return m_values.extent(0); } + friend PASTIS_INLINE + Array<std::remove_const_t<DataType>> copy(const Array<DataType>& source) + { + Array<std::remove_const_t<DataType>> image(source.size()); + Kokkos::deep_copy(image.m_values, source.m_values); + + return image; + } + PASTIS_INLINE DataType& operator[](const index_type& i) const { @@ -39,6 +50,7 @@ class Array static_assert(not std::is_const<DataType>(), "Cannot modify Array of const"); + // could consider to use std::fill parallel_for(this->size(), PASTIS_LAMBDA(const index_type& i){ m_values[i] = data; }); @@ -93,17 +105,4 @@ class Array ~Array() = default; }; -template <typename DataType> -PASTIS_INLINE -Array<std::remove_const_t<DataType>> copy(const Array<DataType>& array) -{ - Array<std::remove_const_t<DataType>> image(array.size()); - using index_type = typename Array<DataType>::index_type; - - parallel_for(array.size(), PASTIS_LAMBDA(const index_type& i){ - image[i] = array[i]; - }); - return image; -} - #endif // ARRAY_HPP diff --git a/tests/test_Array.cpp b/tests/test_Array.cpp index 6cd02146c8775dca7a061672b8efed1fbdcad236..29257385c3ae761628a640c2772fd2a90fd22ac5 100644 --- a/tests/test_Array.cpp +++ b/tests/test_Array.cpp @@ -50,7 +50,19 @@ TEST_CASE("Array", "[utils]") { } - SECTION("checking for affectations") { + SECTION("checking for fill") { + Array<int> b(10); + b.fill(3); + + REQUIRE(((b[0] == 3) and (b[1] == 3) and + (b[2] == 3) and (b[3] == 3) and + (b[4] == 3) and (b[5] == 3) and + (b[6] == 3) and (b[7] == 3) and + (b[8] == 3) and (b[9] == 3))); + + } + + SECTION("checking for affectations (shallow copy)") { Array<const int> b; b = a; @@ -78,7 +90,56 @@ TEST_CASE("Array", "[utils]") { (d[4] == 8) and (d[5] ==10) and (d[6] ==12) and (d[7] ==14) and (d[8] ==16) and (d[9] ==18))); + } + + SECTION("checking for affectations (deep copy)") { + Array<int> b(copy(a)); + + REQUIRE(((b[0] == 0) and (b[1] == 2) and + (b[2] == 4) and (b[3] == 6) and + (b[4] == 8) and (b[5] ==10) and + (b[6] ==12) and (b[7] ==14) and + (b[8] ==16) and (b[9] ==18))); + + b.fill(2); + + REQUIRE(((a[0] == 0) and (a[1] == 2) and + (a[2] == 4) and (a[3] == 6) and + (a[4] == 8) and (a[5] ==10) and + (a[6] ==12) and (a[7] ==14) and + (a[8] ==16) and (a[9] ==18))); + + REQUIRE(((b[0] == 2) and (b[1] == 2) and + (b[2] == 2) and (b[3] == 2) and + (b[4] == 2) and (b[5] == 2) and + (b[6] == 2) and (b[7] == 2) and + (b[8] == 2) and (b[9] == 2))); + + Array<int> c; + c = a; + + REQUIRE(((c[0] == 0) and (c[1] == 2) and + (c[2] == 4) and (c[3] == 6) and + (c[4] == 8) and (c[5] ==10) and + (c[6] ==12) and (c[7] ==14) and + (c[8] ==16) and (c[9] ==18))); + + c = copy(b); + + REQUIRE(((a[0] == 0) and (a[1] == 2) and + (a[2] == 4) and (a[3] == 6) and + (a[4] == 8) and (a[5] ==10) and + (a[6] ==12) and (a[7] ==14) and + (a[8] ==16) and (a[9] ==18))); + + REQUIRE(((c[0] == 2) and (c[1] == 2) and + (c[2] == 2) and (c[3] == 2) and + (c[4] == 2) and (c[5] == 2) and + (c[6] == 2) and (c[7] == 2) and + (c[8] == 2) and (c[9] == 2))); + } + #ifndef NDEBUG SECTION("checking for bounds violation") { REQUIRE_THROWS_AS(a[10], AssertError);