diff --git a/src/utils/Array.hpp b/src/utils/Array.hpp index 9c65e52ecb8c40c33a33211484e2c67ad8f3882d..ae899cc3e2d916880473e6c549a73f43783ae030 100644 --- a/src/utils/Array.hpp +++ b/src/utils/Array.hpp @@ -7,6 +7,7 @@ #include <PastisAssert.hpp> #include <Kokkos_CopyViews.hpp> +#include <algorithm> template <typename DataType> class Array @@ -66,7 +67,7 @@ class Array // ensures that const is not lost through copy static_assert(((std::is_const<DataType2>() and std::is_const<DataType>()) or not std::is_const<DataType2>()), - "Cannot assign Array of const to Array of non-const"); + "Cannot assign Array of const to Array of non-const"); m_values = array.m_values; return *this; } @@ -105,4 +106,16 @@ class Array ~Array() = default; }; +template <typename Container> +PASTIS_INLINE +Array<typename Container::value_type> convert_to_array(const Container& given_vector) +{ + using DataType = typename Container::value_type; + Array<std::remove_const_t<DataType>> array(given_vector.size()); + if (given_vector.size()>0) { + std::copy(begin(given_vector), end(given_vector), &(array[0])); + } + return array; +} + #endif // ARRAY_HPP diff --git a/tests/test_Array.cpp b/tests/test_Array.cpp index 29257385c3ae761628a640c2772fd2a90fd22ac5..c2ec98c99db1774873e9178ebc2afa0065251a42 100644 --- a/tests/test_Array.cpp +++ b/tests/test_Array.cpp @@ -4,6 +4,9 @@ #include <Array.hpp> #include <Types.hpp> +#include <vector> +#include <set> + // Instantiate to ensure full coverage is performed template class Array<int>; @@ -140,6 +143,48 @@ TEST_CASE("Array", "[utils]") { (c[8] == 2) and (c[9] == 2))); } + SECTION("checking for std container conversion") { + { + std::vector<int> v{1,2,5,3}; + { + Array<int> v_array = convert_to_array(v); + + REQUIRE(v_array.size() == v.size()); + REQUIRE(((v_array[0] == 1) and (v_array[1] == 2) and + (v_array[2] == 5) and (v_array[3] == 3))); + } + + { + Array<const int> v_array = convert_to_array(v); + + REQUIRE(v_array.size() == v.size()); + REQUIRE(((v_array[0] == 1) and (v_array[1] == 2) and + (v_array[2] == 5) and (v_array[3] == 3))); + } + } + { + std::vector<int> w; + { + Array<int> w_array = convert_to_array(w); + REQUIRE(w_array.size() == 0); + } + { + Array<const int> w_array = convert_to_array(w); + REQUIRE(w_array.size() == 0); + } + } + + { + std::set<int> s{4,2,5,3,1}; + Array<int> s_array = convert_to_array(s); + + REQUIRE(s_array.size() == s.size()); + REQUIRE(((s_array[0] == 1) and (s_array[1] == 2) and + (s_array[2] == 3) and (s_array[3] == 4) and + (s_array[4] == 5))); + } + } + #ifndef NDEBUG SECTION("checking for bounds violation") { REQUIRE_THROWS_AS(a[10], AssertError);