From 88a786c45d77bbdd6ce49df9b4109a1f71421cd6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Tue, 24 Nov 2020 17:16:20 +0100
Subject: [PATCH] Add support for TinyMatrix-TinyVector product

---
 .../utils/BinaryOperatorRegisterForRnxn.cpp        | 14 ++++++++++++++
 .../utils/BinaryOperatorRegisterForRnxn.hpp        |  1 +
 2 files changed, 15 insertions(+)

diff --git a/src/language/utils/BinaryOperatorRegisterForRnxn.cpp b/src/language/utils/BinaryOperatorRegisterForRnxn.cpp
index b44f42e91..3d8850e79 100644
--- a/src/language/utils/BinaryOperatorRegisterForRnxn.cpp
+++ b/src/language/utils/BinaryOperatorRegisterForRnxn.cpp
@@ -39,6 +39,19 @@ BinaryOperatorRegisterForRnxn<Dimension>::_register_product_by_a_scalar()
     std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, double, Rnxn>>());
 }
 
+template <size_t Dimension>
+void
+BinaryOperatorRegisterForRnxn<Dimension>::_register_product_by_a_vector()
+{
+  OperatorRepository& repository = OperatorRepository::instance();
+
+  using Rnxn = TinyMatrix<Dimension>;
+  using Rn   = TinyVector<Dimension>;
+
+  repository.addBinaryOperator<language::multiply_op>(
+    std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rn, Rnxn, Rn>>());
+}
+
 template <size_t Dimension>
 template <typename OperatorT>
 void
@@ -58,6 +71,7 @@ BinaryOperatorRegisterForRnxn<Dimension>::BinaryOperatorRegisterForRnxn()
   this->_register_comparisons();
 
   this->_register_product_by_a_scalar();
+  this->_register_product_by_a_vector();
 
   this->_register_arithmetic<language::plus_op>();
   this->_register_arithmetic<language::minus_op>();
diff --git a/src/language/utils/BinaryOperatorRegisterForRnxn.hpp b/src/language/utils/BinaryOperatorRegisterForRnxn.hpp
index 7efb4cbb8..594740b62 100644
--- a/src/language/utils/BinaryOperatorRegisterForRnxn.hpp
+++ b/src/language/utils/BinaryOperatorRegisterForRnxn.hpp
@@ -10,6 +10,7 @@ class BinaryOperatorRegisterForRnxn
   void _register_comparisons();
 
   void _register_product_by_a_scalar();
+  void _register_product_by_a_vector();
 
   template <typename OperatorT>
   void _register_arithmetic();
-- 
GitLab