return output;
}
+/**
+ * @brief compute the derivative of the dot product with respect to this
+ * tensor and store the result in the current tensor
+ * @todo check whether beta affects this computation
+ */
+Tensor &Tensor::dot_deriv_wrt_1(Tensor const &m, Tensor const &output_deriv,
+ bool trans, bool trans_m, float beta) {
+ bool deriv_trans_m = true;
+ bool deriv_trans = false;
+ /** @todo handle all cases of trans and trans_m */
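+ // With output = (*this) * op(m) in the forward pass and output_deriv as
+ // the incoming gradient, the chain rule gives (illustrative notation):
+ //   output = this * m    =>  d(this) = output_deriv * m^T  (deriv_trans_m stays true)
+ //   output = this * m^T  =>  d(this) = output_deriv * m    (deriv_trans_m flipped below)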
+ if (!trans && trans_m) {
+ deriv_trans_m = false;
+ }
+
+ return output_deriv.dot(m, *this, deriv_trans, deriv_trans_m, beta);
+}
+
+/**
+ * @brief compute the derivative of the dot product with respect to m and
+ * store it in the m tensor
+ * @note The caller tensor must be the same tensor that performed the
+ * original dot() product.
+ */
+Tensor &Tensor::dot_deriv_wrt_2(Tensor &m_deriv, Tensor const &output_deriv,
+ bool trans, bool trans_m, float beta) const {
+ bool deriv_trans_m = false;
+ bool deriv_trans = true;
+ /** @todo handle all cases of trans and trans_m */
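+ // With output = (*this) * op(m) in the forward pass and output_deriv as
+ // the incoming gradient, the chain rule gives (illustrative notation):
+ //   output = this * m^T  =>  d(m) = output_deriv^T * this  (if branch)
+ //   output = this * m    =>  d(m) = this^T * output_deriv  (else branch)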
+
+ if (!trans && trans_m) {
+ output_deriv.dot(*this, m_deriv, deriv_trans, deriv_trans_m, beta);
+ return m_deriv;
+ } else {
+ return dot(output_deriv, m_deriv, deriv_trans, deriv_trans_m, beta);
+ }
+}
+
+Tensor &Tensor::dot_batched_deriv_wrt_1(Tensor const &m,
+ Tensor const &output_deriv, bool trans,
+ bool trans_m, float beta) {
+ bool deriv_trans_m = true;
+ bool deriv_trans = false;
+ /** @todo handle all cases of trans and trans_m */
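+ // Same identities as dot_deriv_wrt_1, applied independently to each
+ // batch slice through dotBatched().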
+ if (!trans && trans_m) {
+ deriv_trans_m = false;
+ }
+
+ return output_deriv.dotBatched(m, *this, deriv_trans, deriv_trans_m, beta);
+}
+
+Tensor &Tensor::dot_batched_deriv_wrt_2(Tensor &m_deriv,
+ Tensor const &output_deriv, bool trans,
+ bool trans_m, float beta) const {
+ bool deriv_trans_m = false;
+ bool deriv_trans = true;
+ /** @todo handle all cases of trans and trans_m */
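+ // Same identities as dot_deriv_wrt_2, applied independently to each
+ // batch slice through dotBatched().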
+
+ if (!trans && trans_m) {
+ output_deriv.dotBatched(*this, m_deriv, deriv_trans, deriv_trans_m, beta);
+ return m_deriv;
+ } else {
+ return dotBatched(output_deriv, m_deriv, deriv_trans, deriv_trans_m, beta);
+ }
+}
/**
* @note: This dot product flattens the first 3 axes for the purpose of
bool trans_m = false, float beta = 0.0f) const;
/**
+ * @brief compute the derivative of the dot product with respect to this
+ * tensor and store the result in the current tensor
+ * @param[in] m same as given to the dot()
+ * @param[in] output_deriv the derivative of the output
+ * @param[in] trans same as given to the dot()
+ * @param[in] trans_m same as given to the dot()
+ * @param[in] beta same as given to the dot()
+ * @note This will compute the derivative in-place and will overwrite existing
+ * data in the tensor
+ */
+ Tensor &dot_deriv_wrt_1(Tensor const &m, Tensor const &output_deriv,
+ bool trans = false, bool trans_m = false,
+ float beta = 0.0f);
+
+ /**
+ * @brief compute the derivative of the dot product with respect to m and
+ * store it in m_deriv
+ * @param[out] m_deriv tensor where the derivative with respect to m is stored
+ * @param[in] output_deriv the derivative of the output
+ * @param[in] trans same as given to the dot()
+ * @param[in] trans_m same as given to the dot()
+ * @param[in] beta same as given to the dot()
+ * @note The caller tensor must be the same tensor that performed the
+ * original dot() product.
+ */
+ Tensor &dot_deriv_wrt_2(Tensor &m_deriv, Tensor const &output_deriv,
+ bool trans = false, bool trans_m = false,
+ float beta = 0.0f) const;
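+
+ /*
+  * Usage sketch (illustrative only; the tensor names are hypothetical and
+  * shapes are assumed to be compatible):
+  *
+  *   Tensor out = in.dot(weight);                 // forward: out = in * weight
+  *   // ... compute out_grad = d(loss)/d(out) ...
+  *   in_grad.dot_deriv_wrt_1(weight, out_grad);   // in_grad = d(loss)/d(in)
+  *   in.dot_deriv_wrt_2(weight_grad, out_grad);   // weight_grad = d(loss)/d(weight)
+  */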
+
+ /**
* @copydoc Tensor::dot(Tensor const &m, Tensor &output, bool trans,
bool trans_m, float beta) const
* @details performs dot operation over a batch of inputs
bool trans_m = false, float beta = 0.0f) const;
/**
+ * @copydoc Tensor::dot_deriv_wrt_1(Tensor const &m, Tensor const
+ &output_deriv, bool trans, bool trans_m, float beta)
+ */
+ Tensor &dot_batched_deriv_wrt_1(Tensor const &m, Tensor const &output_deriv,
+ bool trans = false, bool trans_m = false,
+ float beta = 0.0f);
+
+ /**
+ * @copydoc Tensor::dot_deriv_wrt_2(Tensor &m_deriv, Tensor const
+ &output_deriv, bool trans, bool trans_m, float beta) const
+ */
+ Tensor &dot_batched_deriv_wrt_2(Tensor &m_deriv, Tensor const &output_deriv,
+ bool trans = false, bool trans_m = false,
+ float beta = 0.0f) const;
+
+ /**
* @brief Transpose Tensor
* @param[in] direction to transpose ex) 0:2:1
* @retval Calculated Tensor