From 2ab73cd8cc523eeaed773feabbb5b50cd69025ba Mon Sep 17 00:00:00 2001 From: Parichay Kapoor <pk.kapoor@samsung.com> Date: Tue, 26 Oct 2021 16:06:11 +0900 Subject: [PATCH] [layer] LSTM cell batchresize bug fix As LSTM cell uses the batch values not in the batch dimension but in the inner dimension, it creates a bug with the existing layer context interface as it only allows updating the batch along the batch dimension. This patch provides a quick fix to combine the time and batch values in the batch dimension. Once recurrent realizer allows multiple inputs, the time factor will be removed and the reshapes can be removed. Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com> --- nntrainer/layers/lstmcell.cpp | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/nntrainer/layers/lstmcell.cpp b/nntrainer/layers/lstmcell.cpp index c7680ee..1acac65 100644 --- a/nntrainer/layers/lstmcell.cpp +++ b/nntrainer/layers/lstmcell.cpp @@ -126,8 +126,9 @@ void LSTMCellLayer::finalize(InitLayerContext &context) { unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props); TensorDim d = input_dim; - d.height(d.batch()); - d.batch(max_timestep); + // d.height(d.batch()); + d.height(1); + d.batch(max_timestep * d.batch()); d.width(unit); /** hidden dim = [ UnrollLength, 1, Batch, Units ] */ @@ -194,6 +195,11 @@ void LSTMCellLayer::forwarding(RunLayerContext &context, bool training) { cell_.setZero(); } + unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props); + hidden_.reshape({max_timestep, 1, batch, hidden_.width()}); + cell_.reshape({max_timestep, 1, batch, cell_.width()}); + fgio.reshape({max_timestep, 1, batch, fgio.width()}); + /** * @note when the recurrent realization happens, different instances of lstm * will share the weights, hidden state, cell and fgio memory. 
However, they @@ -247,6 +253,9 @@ void LSTMCellLayer::calcDerivative(RunLayerContext &context) { Tensor &weight = context.getWeight(wt_idx[LSTMParams::weight_xh]); Tensor &ret_ = context.getOutgoingDerivative(SINGLE_INOUT_IDX); + unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props); + derivative_.reshape({max_timestep, 1, ret_.batch(), derivative_.width()}); + /** get the timestep values */ unsigned int start_timestep = std::get<props::Timestep>(lstm_props); @@ -264,6 +273,10 @@ void LSTMCellLayer::calcGradient(RunLayerContext &context) { Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]); Tensor &derivative_ = context.getTensorGrad(wt_idx[LSTMParams::hidden_state]); + /** + * TODO: hidden_ is only used from the previous timestep. Once it is supported + * as input, no need to cache the hidden_ itself + */ Tensor &hidden_ = context.getTensor(wt_idx[LSTMParams::hidden_state]); Tensor &incoming_deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX); Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); @@ -278,6 +291,13 @@ void LSTMCellLayer::calcGradient(RunLayerContext &context) { unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props); unsigned int start_timestep = std::get<props::Timestep>(lstm_props); + derivative_.reshape({max_timestep, 1, batch, derivative_.width()}); + hidden_.reshape({max_timestep, 1, batch, hidden_.width()}); + m_cell_.reshape({max_timestep, 1, batch, m_cell_.width()}); + dm_cell_.reshape({max_timestep, 1, batch, dm_cell_.width()}); + fgio.reshape({max_timestep, 1, batch, fgio.width()}); + d_fgio.reshape({max_timestep, 1, batch, d_fgio.width()}); + if (start_timestep + 1 == max_timestep) { djdw_x.setZero(); djdw_h.setZero(); @@ -361,9 +381,10 @@ void LSTMCellLayer::calcGradient(RunLayerContext &context) { } void LSTMCellLayer::setBatch(RunLayerContext &context, unsigned int batch) { - context.updateTensor(wt_idx[LSTMParams::hidden_state], batch); - context.updateTensor(wt_idx[LSTMParams::mem_cell], batch); - context.updateTensor(wt_idx[LSTMParams::fgio], batch); + 
unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props); + context.updateTensor(wt_idx[LSTMParams::hidden_state], batch * max_timestep); + context.updateTensor(wt_idx[LSTMParams::mem_cell], batch * max_timestep); + context.updateTensor(wt_idx[LSTMParams::fgio], batch * max_timestep); context.updateTensor(wt_idx[LSTMParams::dropout_mask], batch); } -- 2.7.4