[tensor] Add derivatives for dot operation
authorParichay Kapoor <pk.kapoor@samsung.com>
Mon, 29 Nov 2021 06:20:32 +0000 (15:20 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 3 Dec 2021 05:46:00 +0000 (14:46 +0900)
Add derivatives for the dot operation as an easier interface to
calculate derivative for a dot operation used in the forward for both
the inputs.
Add the corresponding interface for dot_batched as well.

See Also #1721

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h

index 11ea36d..ae9254c 100644 (file)
@@ -978,6 +978,68 @@ Tensor Tensor::dot(Tensor const &m, bool trans, bool trans_m) const {
 
   return output;
 }
+/**
+ * @brief compute the derivative of this in the current tensor
+ * @todo will have to see if beta effects this computation
+ */
+Tensor &Tensor::dot_deriv_wrt_1(Tensor const &m, Tensor const &output_deriv,
+                                bool trans, bool trans_m, float beta) {
+  bool deriv_trans_m = true;
+  bool deriv_trans = false;
+  /** @todo handle all cases of trans and trans_m */
+  if (!trans && trans_m) {
+    deriv_trans_m = false;
+  }
+
+  return output_deriv.dot(m, *this, deriv_trans, deriv_trans_m, beta);
+}
+
+/**
+ * @brief compute the derivative wrt m in the m tensor
+ * @note The caller tensor must be the same tensor as the one which called the
+ * dot() product.
+ */
+Tensor &Tensor::dot_deriv_wrt_2(Tensor &m_deriv, Tensor const &output_deriv,
+                                bool trans, bool trans_m, float beta) const {
+  bool deriv_trans_m = false;
+  bool deriv_trans = true;
+  /** @todo handle all cases of trans and trans_m */
+
+  if (!trans && trans_m) {
+    output_deriv.dot(*this, m_deriv, deriv_trans, deriv_trans_m, beta);
+    return m_deriv;
+  } else {
+    return dot(output_deriv, m_deriv, deriv_trans, deriv_trans_m, beta);
+  }
+}
+
+Tensor &Tensor::dot_batched_deriv_wrt_1(Tensor const &m,
+                                        Tensor const &output_deriv, bool trans,
+                                        bool trans_m, float beta) {
+  bool deriv_trans_m = true;
+  bool deriv_trans = false;
+  /** @todo handle all cases of trans and trans_m */
+  if (!trans && trans_m) {
+    deriv_trans_m = false;
+  }
+
+  return output_deriv.dotBatched(m, *this, deriv_trans, deriv_trans_m, beta);
+}
+
+Tensor &Tensor::dot_batched_deriv_wrt_2(Tensor &m_deriv,
+                                        Tensor const &output_deriv, bool trans,
+                                        bool trans_m, float beta) const {
+  bool deriv_trans_m = false;
+  bool deriv_trans = true;
+  /** @todo handle all cases of trans and trans_m */
+
+  if (!trans && trans_m) {
+    output_deriv.dotBatched(*this, m_deriv, deriv_trans, deriv_trans_m, beta);
+    return m_deriv;
+  } else {
+    return dotBatched(output_deriv, m_deriv, deriv_trans, deriv_trans_m, beta);
+  }
+}
 
 /**
  * @note: This dot product flattens the fist 3 axis for the purpose of
index 52b4be3..53e5064 100644 (file)
@@ -618,6 +618,34 @@ public:
               bool trans_m = false, float beta = 0.0f) const;
 
   /**
+   * @brief compute the derivative of this in the current tensor
+   * @param m same as given to the dot()
+   * @param output_deriv the derivative of the output
+   * @param[in] trans same as given to the dot()
+   * @param[in] trans_m same as given to the dot()
+   * @param[in] beta same as given to the dot()
+   * @note This will compute the derivative in-place and will overwrite existing
+   * data in the tensor
+   */
+  Tensor &dot_deriv_wrt_1(Tensor const &m, Tensor const &output_deriv,
+                          bool trans = false, bool trans_m = false,
+                          float beta = 0.0f);
+
+  /**
+   * @brief compute the derivative wrt m in the m tensor
+   * @param m_deriv tensor where derivative wrt m will be stored
+   * @param output_deriv the derivative of the output
+   * @param[in] trans same as given to the dot()
+   * @param[in] trans_m same as given to the dot()
+   * @param[in] beta same as given to the dot()
+   * @note The caller tensor must be the same tensor as the one which called the
+   * dot() product.
+   */
+  Tensor &dot_deriv_wrt_2(Tensor &m_deriv, Tensor const &output_deriv,
+                          bool trans = false, bool trans_m = false,
+                          float beta = 0.0f) const;
+
+  /**
    * @copydoc Tensor::dot(Tensor const &m, Tensor &output, bool trans,
               bool trans_m, float beta) const
    * @details performs dot operation over a batch of inputs
@@ -626,6 +654,22 @@ public:
                      bool trans_m = false, float beta = 0.0f) const;
 
   /**
+   * @copydoc Tensor::dot_deriv_wrt_1(Tensor const &m, Tensor const
+   &output_deriv, bool trans, bool trans_m, float beta)
+   */
+  Tensor &dot_batched_deriv_wrt_1(Tensor const &m, Tensor const &output_deriv,
+                                  bool trans = false, bool trans_m = false,
+                                  float beta = 0.0f);
+
+  /**
+   * @brief Tensor::dot_deriv_wrt_2(Tensor const &m_deriv, Tensor const
+   &output_deriv, bool trans, bool trans_m, float beta) const
+   */
+  Tensor &dot_batched_deriv_wrt_2(Tensor &m_deriv, Tensor const &output_deriv,
+                                  bool trans = false, bool trans_m = false,
+                                  float beta = 0.0f) const;
+
+  /**
    * @brief     Transpose Tensor
    * @param[in] direction to transpose ex) 0:2:1
    * @retval    Calculated Tensor