[grucell] remove temporary tensor
author hyeonseok lee <hs89.lee@samsung.com>
Wed, 5 Jan 2022 03:42:45 +0000 (12:42 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Thu, 21 Apr 2022 01:05:48 +0000 (10:05 +0900)
 - Reduce temporary tensors in grucell forwarding and calcGradient

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/grucell.cpp

index d50b181..66879d6 100644 (file)
 
 namespace nntrainer {
 
-void grucell_forwarding(const unsigned int unit, const unsigned int batch_size,
-                        const bool disable_bias, const bool integrate_bias,
-                        const bool reset_after, ActiFunc &acti_func,
-                        ActiFunc &recurrent_acti_func, const Tensor &input,
-                        const Tensor &prev_hidden_state, Tensor &hidden_state,
-                        const Tensor &weight_ih, const Tensor &weight_hh,
-                        const Tensor &bias_h, const Tensor &bias_ih,
-                        const Tensor &bias_hh, Tensor &zrg) {
+/**
+ * @brief  gru forwarding
+ *
+ */
+static void grucell_forwarding(
+  const unsigned int unit, const unsigned int batch_size,
+  const bool disable_bias, const bool integrate_bias, const bool reset_after,
+  ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
+  const Tensor &prev_hidden_state, Tensor &hidden_state,
+  const Tensor &weight_ih, const Tensor &weight_hh, const Tensor &bias_h,
+  const Tensor &bias_ih, const Tensor &bias_hh, Tensor &zrg) {
   input.dot(weight_ih, zrg);
 
   Tensor update_reset_gate =
@@ -120,11 +123,14 @@ void grucell_forwarding(const unsigned int unit, const unsigned int batch_size,
 
   update_gate.multiply_strided(prev_hidden_state, hidden_state);
   temp = update_gate.multiply(-1.0).add(1.0);
-
-  hidden_state.add_i(memory_cell.multiply_strided(temp));
+  memory_cell.multiply_strided(temp, hidden_state, 1.0f);
 }
 
-void grucell_calcGradient(
+/**
+ * @brief  gru calcGradient
+ *
+ */
+static void grucell_calcGradient(
   const unsigned int unit, const unsigned int batch_size,
   const bool disable_bias, const bool integrate_bias, const bool reset_after,
   ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
@@ -402,7 +408,8 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
   const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
   const Tensor &prev_hidden_state =
     context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
-  Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
+  // hidden_state == output in grucell
+  Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT);
 
   const unsigned int batch_size = input.getDim().batch();
 
@@ -421,8 +428,6 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
 
   Tensor &zrg = context.getTensor(wt_idx[GRUCellParams::zrg]);
 
-  Tensor hidden_state;
-
   grucell_forwarding(unit, batch_size, disable_bias, integrate_bias,
                      reset_after, acti_func, recurrent_acti_func, input,
                      prev_hidden_state, hidden_state, weight_ih, weight_hh,
@@ -433,8 +438,6 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
     mask.dropout_mask(dropout_rate);
     hidden_state.multiply_i(mask);
   }
-
-  output.copyData(hidden_state);
 }
 
 void GRUCellLayer::calcDerivative(RunLayerContext &context) {
@@ -508,25 +511,26 @@ void GRUCellLayer::calcGradient(RunLayerContext &context) {
     }
   }
 
-  Tensor d_hidden_state(batch_size, 1, 1, unit);
-  d_hidden_state.copyData(incoming_derivative);
-
+  Tensor incoming_derivative_masked(batch_size, 1, 1, unit);
   if (dropout_rate > epsilon) {
-    d_hidden_state.multiply_i(
-      context.getTensor(wt_idx[GRUCellParams::dropout_mask]));
+    incoming_derivative.multiply_strided(
+      context.getTensor(wt_idx[GRUCellParams::dropout_mask]),
+      incoming_derivative_masked);
   }
 
-  grucell_calcGradient(unit, batch_size, disable_bias, integrate_bias,
-                       reset_after, acti_func, recurrent_acti_func, input,
-                       prev_hidden_state, d_prev_hidden_state, d_hidden_state,
-                       d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih,
-                       bias_hh, d_bias_hh, zrg, d_zrg);
+  grucell_calcGradient(
+    unit, batch_size, disable_bias, integrate_bias, reset_after, acti_func,
+    recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
+    dropout_rate > epsilon ? incoming_derivative_masked : incoming_derivative,
+    d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih, bias_hh,
+    d_bias_hh, zrg, d_zrg);
 }
 
 void GRUCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
   const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
 
   context.updateTensor(wt_idx[GRUCellParams::zrg], batch);
+
   if (dropout_rate > epsilon) {
     context.updateTensor(wt_idx[GRUCellParams::dropout_mask], batch);
   }