Skip to content
Snippets Groups Projects
Commit 42fd7137 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add TinyMatrix's double-dot product

Also add a bunch of non discard qualifiers for TinyMatrix and TinyVector
and forbids Frobenius norm for non floating point TinyMatrix
parent afc43868
No related branches found
No related tags found
1 merge request!198Add TinyMatrix's double-dot product
This commit is part of merge request !198. Comments created here will be created in the context of that merge request.
...@@ -56,15 +56,13 @@ class [[nodiscard]] TinyMatrix ...@@ -56,15 +56,13 @@ class [[nodiscard]] TinyMatrix
} }
public: public:
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr bool
constexpr bool
isSquare() const noexcept isSquare() const noexcept
{ {
return M == N; return M == N;
} }
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr friend TinyMatrix<N, M, T>
constexpr friend TinyMatrix<N, M, T>
transpose(const TinyMatrix& A) transpose(const TinyMatrix& A)
{ {
TinyMatrix<N, M, T> tA; TinyMatrix<N, M, T> tA;
...@@ -76,36 +74,31 @@ class [[nodiscard]] TinyMatrix ...@@ -76,36 +74,31 @@ class [[nodiscard]] TinyMatrix
return tA; return tA;
} }
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr size_t
constexpr size_t
dimension() const dimension() const
{ {
return M * N; return M * N;
} }
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr size_t
constexpr size_t
numberOfValues() const numberOfValues() const
{ {
return this->dimension(); return this->dimension();
} }
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr size_t
constexpr size_t
numberOfRows() const numberOfRows() const
{ {
return M; return M;
} }
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr size_t
constexpr size_t
numberOfColumns() const numberOfColumns() const
{ {
return N; return N;
} }
PUGS_INLINE PUGS_INLINE constexpr TinyMatrix
constexpr TinyMatrix
operator-() const operator-() const
{ {
TinyMatrix opposite; TinyMatrix opposite;
...@@ -140,6 +133,16 @@ class [[nodiscard]] TinyMatrix ...@@ -140,6 +133,16 @@ class [[nodiscard]] TinyMatrix
return *this; return *this;
} }
[[nodiscard]] PUGS_INLINE constexpr friend T
doubleDot(const TinyMatrix& A, const TinyMatrix& B)
{
T t = A.m_values[0] * B.m_values[0];
for (size_t i = 1; i < M * N; ++i) {
t += A.m_values[i] * B.m_values[i];
}
return t;
}
template <size_t P> template <size_t P>
PUGS_INLINE constexpr TinyMatrix<M, P, T> PUGS_INLINE constexpr TinyMatrix<M, P, T>
operator*(const TinyMatrix<N, P, T>& B) const operator*(const TinyMatrix<N, P, T>& B) const
...@@ -194,8 +197,7 @@ class [[nodiscard]] TinyMatrix ...@@ -194,8 +197,7 @@ class [[nodiscard]] TinyMatrix
return os; return os;
} }
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr bool
constexpr bool
operator==(const TinyMatrix& A) const operator==(const TinyMatrix& A) const
{ {
for (size_t i = 0; i < M * N; ++i) { for (size_t i = 0; i < M * N; ++i) {
...@@ -205,8 +207,7 @@ class [[nodiscard]] TinyMatrix ...@@ -205,8 +207,7 @@ class [[nodiscard]] TinyMatrix
return true; return true;
} }
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr bool
constexpr bool
operator!=(const TinyMatrix& A) const operator!=(const TinyMatrix& A) const
{ {
return not this->operator==(A); return not this->operator==(A);
...@@ -272,24 +273,21 @@ class [[nodiscard]] TinyMatrix ...@@ -272,24 +273,21 @@ class [[nodiscard]] TinyMatrix
return *this; return *this;
} }
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr T&
constexpr T&
operator()(size_t i, size_t j) noexcept(NO_ASSERT) operator()(size_t i, size_t j) noexcept(NO_ASSERT)
{ {
Assert((i < M) and (j < N)); Assert((i < M) and (j < N));
return m_values[_index(i, j)]; return m_values[_index(i, j)];
} }
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr const T&
constexpr const T&
operator()(size_t i, size_t j) const noexcept(NO_ASSERT) operator()(size_t i, size_t j) const noexcept(NO_ASSERT)
{ {
Assert((i < M) and (j < N)); Assert((i < M) and (j < N));
return m_values[_index(i, j)]; return m_values[_index(i, j)];
} }
PUGS_INLINE PUGS_INLINE constexpr TinyMatrix&
constexpr TinyMatrix&
operator=(ZeroType) noexcept operator=(ZeroType) noexcept
{ {
static_assert(std::is_arithmetic<T>(), "Cannot assign 'zero' value for non-arithmetic types"); static_assert(std::is_arithmetic<T>(), "Cannot assign 'zero' value for non-arithmetic types");
...@@ -377,7 +375,7 @@ class [[nodiscard]] TinyMatrix ...@@ -377,7 +375,7 @@ class [[nodiscard]] TinyMatrix
}; };
template <size_t M, size_t N, typename T> template <size_t M, size_t N, typename T>
PUGS_INLINE constexpr TinyMatrix<M, N, T> [[nodiscard]] PUGS_INLINE constexpr TinyMatrix<M, N, T>
tensorProduct(const TinyVector<M, T>& x, const TinyVector<N, T>& y) tensorProduct(const TinyVector<M, T>& x, const TinyVector<N, T>& y)
{ {
TinyMatrix<M, N, T> A; TinyMatrix<M, N, T> A;
...@@ -390,7 +388,7 @@ tensorProduct(const TinyVector<M, T>& x, const TinyVector<N, T>& y) ...@@ -390,7 +388,7 @@ tensorProduct(const TinyVector<M, T>& x, const TinyVector<N, T>& y)
} }
template <size_t N, typename T> template <size_t N, typename T>
PUGS_INLINE constexpr T [[nodiscard]] PUGS_INLINE constexpr T
det(const TinyMatrix<N, N, T>& A) det(const TinyMatrix<N, N, T>& A)
{ {
static_assert(std::is_arithmetic<T>::value, "determinant is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "determinant is not defined for non-arithmetic types");
...@@ -441,7 +439,7 @@ det(const TinyMatrix<N, N, T>& A) ...@@ -441,7 +439,7 @@ det(const TinyMatrix<N, N, T>& A)
} }
template <typename T> template <typename T>
PUGS_INLINE constexpr T [[nodiscard]] PUGS_INLINE constexpr T
det(const TinyMatrix<1, 1, T>& A) det(const TinyMatrix<1, 1, T>& A)
{ {
static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types");
...@@ -449,7 +447,7 @@ det(const TinyMatrix<1, 1, T>& A) ...@@ -449,7 +447,7 @@ det(const TinyMatrix<1, 1, T>& A)
} }
template <typename T> template <typename T>
PUGS_INLINE constexpr T [[nodiscard]] PUGS_INLINE constexpr T
det(const TinyMatrix<2, 2, T>& A) det(const TinyMatrix<2, 2, T>& A)
{ {
static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types");
...@@ -457,7 +455,7 @@ det(const TinyMatrix<2, 2, T>& A) ...@@ -457,7 +455,7 @@ det(const TinyMatrix<2, 2, T>& A)
} }
template <typename T> template <typename T>
PUGS_INLINE constexpr T [[nodiscard]] PUGS_INLINE constexpr T
det(const TinyMatrix<3, 3, T>& A) det(const TinyMatrix<3, 3, T>& A)
{ {
static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types");
...@@ -466,7 +464,7 @@ det(const TinyMatrix<3, 3, T>& A) ...@@ -466,7 +464,7 @@ det(const TinyMatrix<3, 3, T>& A)
} }
template <size_t M, size_t N, typename T> template <size_t M, size_t N, typename T>
PUGS_INLINE constexpr TinyMatrix<M - 1, N - 1, T> [[nodiscard]] PUGS_INLINE constexpr TinyMatrix<M - 1, N - 1, T>
getMinor(const TinyMatrix<M, N, T>& A, size_t I, size_t J) getMinor(const TinyMatrix<M, N, T>& A, size_t I, size_t J)
{ {
static_assert(M >= 2 and N >= 2, "minor calculation requires at least 2x2 matrices"); static_assert(M >= 2 and N >= 2, "minor calculation requires at least 2x2 matrices");
...@@ -492,7 +490,7 @@ getMinor(const TinyMatrix<M, N, T>& A, size_t I, size_t J) ...@@ -492,7 +490,7 @@ getMinor(const TinyMatrix<M, N, T>& A, size_t I, size_t J)
} }
template <size_t N, typename T> template <size_t N, typename T>
PUGS_INLINE T [[nodiscard]] PUGS_INLINE T
trace(const TinyMatrix<N, N, T>& A) trace(const TinyMatrix<N, N, T>& A)
{ {
static_assert(std::is_arithmetic<T>::value, "trace is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "trace is not defined for non-arithmetic types");
...@@ -504,26 +502,21 @@ trace(const TinyMatrix<N, N, T>& A) ...@@ -504,26 +502,21 @@ trace(const TinyMatrix<N, N, T>& A)
return t; return t;
} }
template <size_t N, typename T> template <size_t M, size_t N, typename T>
PUGS_INLINE T [[nodiscard]] PUGS_INLINE constexpr decltype(std::sqrt(std::declval<T>()))
frobeniusNorm(const TinyMatrix<N, N, T>& A) frobeniusNorm(const TinyMatrix<M, N, T>& A)
{ {
static_assert(std::is_arithmetic<T>::value, "norm is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "norm is not defined for non-arithmetic types");
static_assert(std::is_floating_point<T>::value, "Frobenius norm is defined for floating point types only");
T t = 0; return std::sqrt(doubleDot(A, A));
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < N; ++j) {
t += A(i, j) * A(i, j);
}
}
return std::sqrt(t);
} }
template <size_t N, typename T> template <size_t N, typename T>
PUGS_INLINE constexpr TinyMatrix<N, N, T> inverse(const TinyMatrix<N, N, T>& A); [[nodiscard]] PUGS_INLINE constexpr TinyMatrix<N, N, T> inverse(const TinyMatrix<N, N, T>& A);
template <typename T> template <typename T>
PUGS_INLINE constexpr TinyMatrix<1, 1, T> [[nodiscard]] PUGS_INLINE constexpr TinyMatrix<1, 1, T>
inverse(const TinyMatrix<1, 1, T>& A) inverse(const TinyMatrix<1, 1, T>& A)
{ {
static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non-arithmetic types");
...@@ -534,7 +527,7 @@ inverse(const TinyMatrix<1, 1, T>& A) ...@@ -534,7 +527,7 @@ inverse(const TinyMatrix<1, 1, T>& A)
} }
template <size_t N, typename T> template <size_t N, typename T>
PUGS_INLINE constexpr T [[nodiscard]] PUGS_INLINE constexpr T
cofactor(const TinyMatrix<N, N, T>& A, size_t i, size_t j) cofactor(const TinyMatrix<N, N, T>& A, size_t i, size_t j)
{ {
static_assert(std::is_arithmetic<T>::value, "cofactor is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "cofactor is not defined for non-arithmetic types");
...@@ -544,7 +537,7 @@ cofactor(const TinyMatrix<N, N, T>& A, size_t i, size_t j) ...@@ -544,7 +537,7 @@ cofactor(const TinyMatrix<N, N, T>& A, size_t i, size_t j)
} }
template <typename T> template <typename T>
PUGS_INLINE constexpr TinyMatrix<2, 2, T> [[nodiscard]] PUGS_INLINE constexpr TinyMatrix<2, 2, T>
inverse(const TinyMatrix<2, 2, T>& A) inverse(const TinyMatrix<2, 2, T>& A)
{ {
static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non-arithmetic types");
...@@ -558,7 +551,7 @@ inverse(const TinyMatrix<2, 2, T>& A) ...@@ -558,7 +551,7 @@ inverse(const TinyMatrix<2, 2, T>& A)
} }
template <typename T> template <typename T>
PUGS_INLINE constexpr TinyMatrix<3, 3, T> [[nodiscard]] PUGS_INLINE constexpr TinyMatrix<3, 3, T>
inverse(const TinyMatrix<3, 3, T>& A) inverse(const TinyMatrix<3, 3, T>& A)
{ {
static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non-arithmetic types"); static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non-arithmetic types");
......
...@@ -33,8 +33,7 @@ class [[nodiscard]] TinyVector ...@@ -33,8 +33,7 @@ class [[nodiscard]] TinyVector
} }
public: public:
PUGS_INLINE [[nodiscard]] PUGS_INLINE constexpr TinyVector
constexpr TinyVector
operator-() const operator-() const
{ {
TinyVector opposite; TinyVector opposite;
...@@ -249,7 +248,7 @@ class [[nodiscard]] TinyVector ...@@ -249,7 +248,7 @@ class [[nodiscard]] TinyVector
}; };
template <size_t N, typename T> template <size_t N, typename T>
[[nodiscard]] PUGS_INLINE constexpr T [[nodiscard]] PUGS_INLINE constexpr decltype(std::sqrt(std::declval<T>()))
l2Norm(const TinyVector<N, T>& x) l2Norm(const TinyVector<N, T>& x)
{ {
static_assert(std::is_arithmetic<T>(), "Cannot compute L2 norm for non-arithmetic types"); static_assert(std::is_arithmetic<T>(), "Cannot compute L2 norm for non-arithmetic types");
......
...@@ -209,9 +209,21 @@ TEST_CASE("TinyMatrix", "[algebra]") ...@@ -209,9 +209,21 @@ TEST_CASE("TinyMatrix", "[algebra]")
REQUIRE(trace(TinyMatrix<4>(1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 2, 0, 0, 2, 2)) == 1 + 0 + 1 + 2); REQUIRE(trace(TinyMatrix<4>(1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 2, 0, 0, 2, 2)) == 1 + 0 + 1 + 2);
} }
SECTION("checking for doubleDot calculations")
{
REQUIRE(doubleDot(TinyMatrix<1, 1, int>(6), TinyMatrix<1, 1, int>(4)) == 24);
REQUIRE(doubleDot(TinyMatrix<2, 2>(5, 1, -3, 6), TinyMatrix<2, 2>(-2, 3, -5, 1)) ==
Catch::Approx(-10 + 3 + 15 + 6).epsilon(1E-14));
REQUIRE(doubleDot(TinyMatrix<4>(1, 2.3, 7, -6.2, 3, 4, 9, 1, 4.1, 5, 2, -3, 2, 27, 3, 17.5),
TinyMatrix<4>(1, 2.3, 7, -6.2, 3, 4, 9, 1, 4.1, 5, 2, -3, 2, 27, 3, 17.5)) ==
Catch::Approx(1 + 2.3 * 2.3 + 7 * 7 + 6.2 * 6.2 + 9 + 16 + 81 + 1 + 4.1 * 4.1 + 25 + 4 + 9 + 4 + 27 * 27 +
9 + 17.5 * 17.5)
.epsilon(1E-14));
}
SECTION("checking for norm calculations") SECTION("checking for norm calculations")
{ {
REQUIRE(frobeniusNorm(TinyMatrix<1, 1, int>(6)) == 6); REQUIRE(frobeniusNorm(TinyMatrix<1, 1>(6)) == Catch::Approx(6));
REQUIRE(frobeniusNorm(TinyMatrix<2, 2>(5, 1, -3, 6)) == Catch::Approx(std::sqrt(25 + 1 + 9 + 36)).epsilon(1E-14)); REQUIRE(frobeniusNorm(TinyMatrix<2, 2>(5, 1, -3, 6)) == Catch::Approx(std::sqrt(25 + 1 + 9 + 36)).epsilon(1E-14));
REQUIRE(frobeniusNorm(TinyMatrix<4>(1, 2.3, 7, -6.2, 3, 4, 9, 1, 4.1, 5, 2, -3, 2, 27, 3, 17.5)) == REQUIRE(frobeniusNorm(TinyMatrix<4>(1, 2.3, 7, -6.2, 3, 4, 9, 1, 4.1, 5, 2, -3, 2, 27, 3, 17.5)) ==
Catch::Approx(std::sqrt(1 + 2.3 * 2.3 + 7 * 7 + 6.2 * 6.2 + 9 + 16 + 81 + 1 + 4.1 * 4.1 + 25 + 4 + 9 + 4 + Catch::Approx(std::sqrt(1 + 2.3 * 2.3 + 7 * 7 + 6.2 * 6.2 + 9 + 16 + 81 + 1 + 4.1 * 4.1 + 25 + 4 + 9 + 4 +
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment