From 8d32d8a9bd984c3ef406ede54b3515ccced54e39 Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Mon, 29 Nov 2021 15:20:32 +0900 Subject: [PATCH] [tensor] Add derivatives for dot operation Add derivatives for the dot operation as an easier interface to calculate the derivative of a dot operation, used in the forward pass, with respect to both of its inputs. Add the corresponding interface for dot_batched as well. See Also #1721 Signed-off-by: Parichay Kapoor --- nntrainer/tensor/tensor.cpp | 62 +++++++++++++++++++++++++++++++++++++++++++++ nntrainer/tensor/tensor.h | 44 ++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index 11ea36d..ae9254c 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -978,6 +978,68 @@ Tensor Tensor::dot(Tensor const &m, bool trans, bool trans_m) const { return output; } +/** + * @brief compute the derivative of this in the current tensor + * @todo will have to see if beta affects this computation + */ +Tensor &Tensor::dot_deriv_wrt_1(Tensor const &m, Tensor const &output_deriv, + bool trans, bool trans_m, float beta) { + bool deriv_trans_m = true; + bool deriv_trans = false; + /** @todo handle all cases of trans and trans_m */ + if (!trans && trans_m) { + deriv_trans_m = false; + } + + return output_deriv.dot(m, *this, deriv_trans, deriv_trans_m, beta); +} + +/** + * @brief compute the derivative wrt m in the m tensor + * @note The caller tensor must be the same tensor as the one which called the + * dot() product. 
+ */ +Tensor &Tensor::dot_deriv_wrt_2(Tensor &m_deriv, Tensor const &output_deriv, + bool trans, bool trans_m, float beta) const { + bool deriv_trans_m = false; + bool deriv_trans = true; + /** @todo handle all cases of trans and trans_m */ + + if (!trans && trans_m) { + output_deriv.dot(*this, m_deriv, deriv_trans, deriv_trans_m, beta); + return m_deriv; + } else { + return dot(output_deriv, m_deriv, deriv_trans, deriv_trans_m, beta); + } +} + +Tensor &Tensor::dot_batched_deriv_wrt_1(Tensor const &m, + Tensor const &output_deriv, bool trans, + bool trans_m, float beta) { + bool deriv_trans_m = true; + bool deriv_trans = false; + /** @todo handle all cases of trans and trans_m */ + if (!trans && trans_m) { + deriv_trans_m = false; + } + + return output_deriv.dotBatched(m, *this, deriv_trans, deriv_trans_m, beta); +} + +Tensor &Tensor::dot_batched_deriv_wrt_2(Tensor &m_deriv, + Tensor const &output_deriv, bool trans, + bool trans_m, float beta) const { + bool deriv_trans_m = false; + bool deriv_trans = true; + /** @todo handle all cases of trans and trans_m */ + + if (!trans && trans_m) { + output_deriv.dotBatched(*this, m_deriv, deriv_trans, deriv_trans_m, beta); + return m_deriv; + } else { + return dotBatched(output_deriv, m_deriv, deriv_trans, deriv_trans_m, beta); + } +} /** * @note: This dot product flattens the fist 3 axis for the purpose of diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h index 52b4be3..53e5064 100644 --- a/nntrainer/tensor/tensor.h +++ b/nntrainer/tensor/tensor.h @@ -618,6 +618,34 @@ public: bool trans_m = false, float beta = 0.0f) const; /** + * @brief compute the derivative of this in the current tensor + * @param m same as given to the dot() + * @param output_deriv the derivative of the output + * @param[in] trans same as given to the dot() + * @param[in] trans_m same as given to the dot() + * @param[in] beta same as given to the dot() + * @note This will compute the derivative in-place and will overwrite existing + * 
data in the tensor + */ + Tensor &dot_deriv_wrt_1(Tensor const &m, Tensor const &output_deriv, + bool trans = false, bool trans_m = false, + float beta = 0.0f); + + /** + * @brief compute the derivative wrt m in the m tensor + * @param m_deriv tensor where derivative wrt m will be stored + * @param output_deriv the derivative of the output + * @param[in] trans same as given to the dot() + * @param[in] trans_m same as given to the dot() + * @param[in] beta same as given to the dot() + * @note The caller tensor must be the same tensor as the one which called the + * dot() product. + */ + Tensor &dot_deriv_wrt_2(Tensor &m_deriv, Tensor const &output_deriv, + bool trans = false, bool trans_m = false, + float beta = 0.0f) const; + + /** * @copydoc Tensor::dot(Tensor const &m, Tensor &output, bool trans, bool trans_m, float beta) const * @details performs dot operation over a batch of inputs @@ -626,6 +654,22 @@ public: bool trans_m = false, float beta = 0.0f) const; /** + * @copydoc Tensor::dot_deriv_wrt_1(Tensor const &m, Tensor const + &output_deriv, bool trans, bool trans_m, float beta) + */ + Tensor &dot_batched_deriv_wrt_1(Tensor const &m, Tensor const &output_deriv, + bool trans = false, bool trans_m = false, + float beta = 0.0f); + + /** + * @copydoc Tensor::dot_deriv_wrt_2(Tensor &m_deriv, Tensor const + &output_deriv, bool trans, bool trans_m, float beta) const + */ + Tensor &dot_batched_deriv_wrt_2(Tensor &m_deriv, Tensor const &output_deriv, + bool trans = false, bool trans_m = false, + float beta = 0.0f) const; + + /** * @brief Transpose Tensor * @param[in] direction to transpose ex) 0:2:1 * @retval Calculated Tensor -- 2.7.4