From: hyeonseok lee
Date: Wed, 5 Jan 2022 03:42:45 +0000 (+0900)
Subject: [grucell] remove temporary tensor
X-Git-Tag: accepted/tizen/6.0/unified/20220610.012600~38
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c53dfbeacea2b8c1b85ba8805a2027389552ca39;p=platform%2Fcore%2Fml%2Fnntrainer.git

[grucell] remove temporary tensor

 - Reduce temporary tensor

Signed-off-by: hyeonseok lee
---

diff --git a/nntrainer/layers/grucell.cpp b/nntrainer/layers/grucell.cpp
index d50b181..66879d6 100644
--- a/nntrainer/layers/grucell.cpp
+++ b/nntrainer/layers/grucell.cpp
@@ -39,14 +39,17 @@ namespace nntrainer {
 
-void grucell_forwarding(const unsigned int unit, const unsigned int batch_size,
-                        const bool disable_bias, const bool integrate_bias,
-                        const bool reset_after, ActiFunc &acti_func,
-                        ActiFunc &recurrent_acti_func, const Tensor &input,
-                        const Tensor &prev_hidden_state, Tensor &hidden_state,
-                        const Tensor &weight_ih, const Tensor &weight_hh,
-                        const Tensor &bias_h, const Tensor &bias_ih,
-                        const Tensor &bias_hh, Tensor &zrg) {
+/**
+ * @brief gru forwarding
+ *
+ */
+static void grucell_forwarding(
+  const unsigned int unit, const unsigned int batch_size,
+  const bool disable_bias, const bool integrate_bias, const bool reset_after,
+  ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
+  const Tensor &prev_hidden_state, Tensor &hidden_state,
+  const Tensor &weight_ih, const Tensor &weight_hh, const Tensor &bias_h,
+  const Tensor &bias_ih, const Tensor &bias_hh, Tensor &zrg) {
   input.dot(weight_ih, zrg);
 
   Tensor update_reset_gate =
@@ -120,11 +123,14 @@ void grucell_forwarding(const unsigned int unit, const unsigned int batch_size,
 
   update_gate.multiply_strided(prev_hidden_state, hidden_state);
   temp = update_gate.multiply(-1.0).add(1.0);
-
-  hidden_state.add_i(memory_cell.multiply_strided(temp));
+  memory_cell.multiply_strided(temp, hidden_state, 1.0f);
 }
 
-void grucell_calcGradient(
+/**
+ * @brief gru calcGradient
+ *
+ */
+static void grucell_calcGradient(
   const unsigned int unit, const unsigned int batch_size,
   const bool disable_bias, const bool integrate_bias, const bool reset_after,
   ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
@@ -402,7 +408,8 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
   const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
   const Tensor &prev_hidden_state =
     context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
-  Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
+  // hidden_state == output in grucell
+  Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT);
 
   const unsigned int batch_size = input.getDim().batch();
 
@@ -421,8 +428,6 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
 
   Tensor &zrg = context.getTensor(wt_idx[GRUCellParams::zrg]);
 
-  Tensor hidden_state;
-
   grucell_forwarding(unit, batch_size, disable_bias, integrate_bias,
                      reset_after, acti_func, recurrent_acti_func, input,
                      prev_hidden_state, hidden_state, weight_ih, weight_hh,
@@ -433,8 +438,6 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
     mask.dropout_mask(dropout_rate);
     hidden_state.multiply_i(mask);
   }
-
-  output.copyData(hidden_state);
 }
 
 void GRUCellLayer::calcDerivative(RunLayerContext &context) {
@@ -508,25 +511,26 @@ void GRUCellLayer::calcGradient(RunLayerContext &context) {
     }
  }
 
-  Tensor d_hidden_state(batch_size, 1, 1, unit);
-  d_hidden_state.copyData(incoming_derivative);
-
+  Tensor incoming_derivative_masked(batch_size, 1, 1, unit);
   if (dropout_rate > epsilon) {
-    d_hidden_state.multiply_i(
-      context.getTensor(wt_idx[GRUCellParams::dropout_mask]));
+    incoming_derivative.multiply_strided(
+      context.getTensor(wt_idx[GRUCellParams::dropout_mask]),
+      incoming_derivative_masked);
   }
 
-  grucell_calcGradient(unit, batch_size, disable_bias, integrate_bias,
-                       reset_after, acti_func, recurrent_acti_func, input,
-                       prev_hidden_state, d_prev_hidden_state, d_hidden_state,
-                       d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih,
-                       bias_hh, d_bias_hh, zrg, d_zrg);
+  grucell_calcGradient(
+    unit, batch_size, disable_bias, integrate_bias, reset_after, acti_func,
+    recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
+    dropout_rate > epsilon ? incoming_derivative_masked : incoming_derivative,
+    d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih, bias_hh,
+    d_bias_hh, zrg, d_zrg);
 }
 
 void GRUCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
   const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
   context.updateTensor(wt_idx[GRUCellParams::zrg], batch);
+
   if (dropout_rate > epsilon) {
     context.updateTensor(wt_idx[GRUCellParams::dropout_mask], batch);
   }
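
Note on the pattern used in the forwarding change: judging from the add_i() call it replaces, the three-argument form memory_cell.multiply_strided(temp, hidden_state, 1.0f) computes the element-wise product and accumulates it into the existing hidden_state buffer (the trailing 1.0f keeping the previous contents), so the intermediate tensor that multiply_strided(temp) used to return is never allocated. Below is a minimal stand-alone sketch of that multiply-and-accumulate idea; Buffer and multiply_accumulate are illustrative stand-ins built on std::vector<float>, not the real nntrainer Tensor API.

#include <cstddef>
#include <iostream>
#include <vector>

/* Plain stand-in for a tensor buffer; the real nntrainer Tensor differs. */
using Buffer = std::vector<float>;

/* Old pattern: materialize the element-wise product in a temporary buffer. */
Buffer multiply(const Buffer &a, const Buffer &b) {
  Buffer out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i)
    out[i] = a[i] * b[i];
  return out; /* extra allocation the patch avoids */
}

/* ...then accumulate that temporary into the destination. */
void add_i(Buffer &dst, const Buffer &src) {
  for (std::size_t i = 0; i < dst.size(); ++i)
    dst[i] += src[i];
}

/* New pattern: multiply element-wise straight into an existing output,
 * scaling the previous output contents by beta (1.0f keeps them, matching
 * the patch), so no intermediate buffer is created. */
void multiply_accumulate(const Buffer &a, const Buffer &b, Buffer &out,
                         float beta) {
  for (std::size_t i = 0; i < a.size(); ++i)
    out[i] = beta * out[i] + a[i] * b[i];
}

int main() {
  const Buffer memory_cell{1.0f, 2.0f, 3.0f};
  const Buffer temp{0.5f, 0.25f, 0.75f};

  /* hidden_state already holds update_gate * prev_hidden_state. */
  Buffer with_temporary{1.0f, 1.0f, 1.0f};
  Buffer in_place = with_temporary;

  add_i(with_temporary, multiply(memory_cell, temp));    /* old: temporary */
  multiply_accumulate(memory_cell, temp, in_place, 1.0f); /* new: in place */

  for (std::size_t i = 0; i < in_place.size(); ++i)
    std::cout << with_temporary[i] << " == " << in_place[i] << '\n';
  return 0;
}

The calcGradient change follows the same idea: instead of copying incoming_derivative into a fresh d_hidden_state and masking it in place, the dropout mask is multiplied directly into the preallocated incoming_derivative_masked tensor, and the unmasked incoming_derivative is passed straight to grucell_calcGradient when dropout is disabled.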