diff --git a/src/algebra/TinyMatrix.hpp b/src/algebra/TinyMatrix.hpp index 9a9e9aa2239de844fb8e01d38c73cec514b32687..0e6cdb8222c9373c8bd2c81f94becd7a8d99ee3b 100644 --- a/src/algebra/TinyMatrix.hpp +++ b/src/algebra/TinyMatrix.hpp @@ -2,11 +2,13 @@ #define TINY_MATRIX_HPP #include <PastisAssert.hpp> -#include <iostream> +#include <Kokkos_Core.hpp> #include <Types.hpp> #include <TinyVector.hpp> +#include <iostream> + template <size_t N, typename T=double> class TinyMatrix { @@ -33,6 +35,15 @@ private: } public: + KOKKOS_INLINE_FUNCTION + TinyMatrix operator-() const + { + TinyMatrix opposed; + for (size_t i=0; i<N*N; ++i) { + opposed.m_values[i] = -m_values[i]; + } + return std::move(opposed); + } KOKKOS_INLINE_FUNCTION friend TinyMatrix operator*(const T& t, const TinyMatrix& A) @@ -44,7 +55,6 @@ public: return std::move(tA); } - KOKKOS_INLINE_FUNCTION TinyMatrix operator*(const TinyMatrix& B) const { @@ -62,7 +72,6 @@ public: return std::move(AB); } - KOKKOS_INLINE_FUNCTION TinyVector<N,T> operator*(const TinyVector<N,T>& x) const { @@ -124,6 +133,15 @@ public: return std::move(sum); } + KOKKOS_INLINE_FUNCTION + TinyMatrix operator+(TinyMatrix&& A) const + { + for (size_t i=0; i<N*N; ++i) { + A.m_values[i] += m_values[i]; + } + return std::move(A); + } + KOKKOS_INLINE_FUNCTION TinyMatrix operator-(const TinyMatrix& A) const { @@ -134,6 +152,15 @@ public: return std::move(difference); } + KOKKOS_INLINE_FUNCTION + TinyMatrix operator-(TinyMatrix&& A) const + { + for (size_t i=0; i<N*N; ++i) { + A.m_values[i] = m_values[i]-A.m_values[i]; + } + return std::move(A); + } + KOKKOS_INLINE_FUNCTION TinyMatrix& operator+=(const TinyMatrix& A) { diff --git a/src/algebra/TinyVector.hpp b/src/algebra/TinyVector.hpp index f95d456395427efebbe48802bdf65fc638205bda..038741231cb32963330d8a061f54ff41a7b05238 100644 --- a/src/algebra/TinyVector.hpp +++ b/src/algebra/TinyVector.hpp @@ -27,6 +27,17 @@ private: } public: + KOKKOS_INLINE_FUNCTION + TinyVector operator-() const + { + TinyVector opposed; + for (size_t i=0; i<N; ++i) { + opposed.m_values[i] =-m_values[i]; + } + return std::move(opposed); + } + + KOKKOS_INLINE_FUNCTION constexpr size_t dimension() const { return N; @@ -67,6 +78,15 @@ public: return std::move(tv); } + KOKKOS_INLINE_FUNCTION + friend TinyVector operator*(const T& t, TinyVector&& v) + { + for (size_t i=0; i<N; ++i) { + v.m_values[i] *= t; + } + return std::move(v); + } + KOKKOS_INLINE_FUNCTION friend std::ostream& operator<<(std::ostream& os, const TinyVector& v) { @@ -83,21 +103,39 @@ public: { TinyVector sum; for (size_t i=0; i<N; ++i) { - sum[i] = m_values[i]+v.m_values[i]; + sum.m_values[i] = m_values[i]+v.m_values[i]; } return std::move(sum); } + KOKKOS_INLINE_FUNCTION + TinyVector operator+(TinyVector&& v) const + { + for (size_t i=0; i<N; ++i) { + v.m_values[i] += m_values[i]; + } + return std::move(v); + } + KOKKOS_INLINE_FUNCTION TinyVector operator-(const TinyVector& v) const { TinyVector difference; for (size_t i=0; i<N; ++i) { - difference[i] = m_values[i]-v.m_values[i]; + difference.m_values[i] = m_values[i]-v.m_values[i]; } return std::move(difference); } + KOKKOS_INLINE_FUNCTION + TinyVector operator-(TinyVector&& v) const + { + for (size_t i=0; i<N; ++i) { + v.m_values[i] = m_values[i]-v.m_values[i]; + } + return std::move(v); + } + KOKKOS_INLINE_FUNCTION TinyVector& operator+=(const TinyVector& v) { diff --git a/tests/test_TinyMatrix.cpp b/tests/test_TinyMatrix.cpp index e0bc8e843304dd734cb566ede45f2b49e60f391d..25dc03c35bde7daf137ac6b545a575b94f315dfc 100644 --- a/tests/test_TinyMatrix.cpp +++ b/tests/test_TinyMatrix.cpp @@ -2,6 +2,14 @@ #include <TinyMatrix.hpp> #include <PastisAssert.hpp> +#include <Kokkos_Core.hpp> +#include <Types.hpp> + +// Instantiate to ensure full coverage is performed +template class TinyMatrix<1,int>; +template class TinyMatrix<2,int>; +template class TinyMatrix<3,int>; +template class TinyMatrix<4,double>; TEST_CASE("TinyMatrix", "[algebra]") { TinyMatrix<3, int> A(1,2,3, @@ -14,7 +22,12 @@ TEST_CASE("TinyMatrix", "[algebra]") { TinyMatrix<3,int> B(6,5,3, 8,34,6, 35,6,7); - + SECTION("checking for opposed matrix") { + const TinyMatrix<3, int> minus_A = -A; + REQUIRE(((minus_A(0,0)==-1) and (minus_A(0,1)==-2) and (minus_A(0,2)==-3) and + (minus_A(1,0)==-4) and (minus_A(1,1)==-5) and (minus_A(1,2)==-6) and + (minus_A(2,0)==-7) and (minus_A(2,1)==-8) and (minus_A(2,2)==-9))); + } SECTION("checking for equality and difference tests") { const TinyMatrix<3, int> copy_A = A; REQUIRE(((copy_A(0,0)==1) and (copy_A(0,1)==2) and (copy_A(0,2)==3) and @@ -69,6 +82,9 @@ TEST_CASE("TinyMatrix", "[algebra]") { TinyMatrix<3, int> ApB = A; ApB += B; REQUIRE(ApB==A+B); + + TinyMatrix<3, int> Ap2B = A+2*B; + REQUIRE(Ap2B==ApB+B); } SECTION("checking for matrices difference ") { @@ -79,6 +95,9 @@ TEST_CASE("TinyMatrix", "[algebra]") { TinyMatrix<3, int> AmB = A; AmB -= B; REQUIRE(AmB==A-B); + + TinyMatrix<3, int> Am2B = A-2*B; + REQUIRE(Am2B == AmB-B); } SECTION("checking for matrices product") { diff --git a/tests/test_TinyVector.cpp b/tests/test_TinyVector.cpp index 38588a6b4bed4a4440bf1a320ac8e4c9c23a38cc..3dd124cd3f95964893daf5621635e4503f0a0687 100644 --- a/tests/test_TinyVector.cpp +++ b/tests/test_TinyVector.cpp @@ -3,20 +3,28 @@ #include <TinyVector.hpp> #include <PastisAssert.hpp> +// Instantiate to ensure full coverage is performed +template class TinyVector<1,int>; +template class TinyVector<3,int>; + TEST_CASE("TinyVector", "[algebra]") { TinyVector<3,int> v(1,2,3); REQUIRE(((v[0] == 1) and (v[1] == 2) and (v[2] == 3))); + REQUIRE(-v == TinyVector<3,int>(-1,-2,-3)); + REQUIRE(v.dimension() == 3); const TinyVector<3,int> z = zero; REQUIRE(((z[0] == 0) and (z[1] == 0) and (z[2] == 0))); v = TinyVector<3,int>(3,2,4); + REQUIRE(Catch::Detail::stringify(v) == "(3,2,4)"); + TinyVector<3,int> w(1,2,6); REQUIRE((v,w)==31); w = 2*v; - REQUIRE(((w[0] == 2*v[0]) and (w[1]==2*v[1]) and (w[2]==2*v[2]))); + REQUIRE(w == TinyVector<3,int>(6,4,8)); TinyVector<3,int> x = v; REQUIRE(x==v); @@ -36,17 +44,27 @@ TEST_CASE("TinyVector", "[algebra]") { x=v+w; REQUIRE(x==3*v); - x-=w; + x=2*(v+v); + REQUIRE(x==4*v); + + x=v+(w+x); + REQUIRE(x==7*v); + + x=x-(v+w); + REQUIRE(x==4*v); + + x-=v+w; REQUIRE(x==v); x=w-v; REQUIRE(x==v); + x=v-2*v; + REQUIRE(x==-v); + TinyVector<3,int> z1; z1 = zero; REQUIRE(((z1[0] == 0) and (z1[1] == 0) and (z1[2] == 0))); - REQUIRE(Catch::Detail::stringify(x) == "(3,2,4)"); - #ifndef NDEBUG REQUIRE_THROWS_AS(x[4]=0, AssertError); const TinyVector<3,int>& const_x = x;