From 88a786c45d77bbdd6ce49df9b4109a1f71421cd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com> Date: Tue, 24 Nov 2020 17:16:20 +0100 Subject: [PATCH] Add support for TinyMatrix-TinyVector product --- .../utils/BinaryOperatorRegisterForRnxn.cpp | 14 ++++++++++++++ .../utils/BinaryOperatorRegisterForRnxn.hpp | 1 + 2 files changed, 15 insertions(+) diff --git a/src/language/utils/BinaryOperatorRegisterForRnxn.cpp b/src/language/utils/BinaryOperatorRegisterForRnxn.cpp index b44f42e91..3d8850e79 100644 --- a/src/language/utils/BinaryOperatorRegisterForRnxn.cpp +++ b/src/language/utils/BinaryOperatorRegisterForRnxn.cpp @@ -39,6 +39,19 @@ BinaryOperatorRegisterForRnxn<Dimension>::_register_product_by_a_scalar() std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, double, Rnxn>>()); } +template <size_t Dimension> +void +BinaryOperatorRegisterForRnxn<Dimension>::_register_product_by_a_vector() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rnxn = TinyMatrix<Dimension>; + using Rn = TinyVector<Dimension>; + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rn, Rnxn, Rn>>()); +} + template <size_t Dimension> template <typename OperatorT> void @@ -58,6 +71,7 @@ BinaryOperatorRegisterForRnxn<Dimension>::BinaryOperatorRegisterForRnxn() this->_register_comparisons(); this->_register_product_by_a_scalar(); + this->_register_product_by_a_vector(); this->_register_arithmetic<language::plus_op>(); this->_register_arithmetic<language::minus_op>(); diff --git a/src/language/utils/BinaryOperatorRegisterForRnxn.hpp b/src/language/utils/BinaryOperatorRegisterForRnxn.hpp index 7efb4cbb8..594740b62 100644 --- a/src/language/utils/BinaryOperatorRegisterForRnxn.hpp +++ b/src/language/utils/BinaryOperatorRegisterForRnxn.hpp @@ -10,6 +10,7 @@ class BinaryOperatorRegisterForRnxn void _register_comparisons(); void _register_product_by_a_scalar(); + void _register_product_by_a_vector(); template <typename OperatorT> void _register_arithmetic(); -- GitLab