From 61b191bacc0b9e2eb354e393e6ac398dd166fff4 Mon Sep 17 00:00:00 2001 From: hyeonseok lee Date: Sat, 18 Dec 2021 10:21:29 +0900 Subject: [PATCH] [zoneout lstmcell] refactoring zoneout lstmcell layer - Refactoring zoneout lstmcell layer to use lstmcore functions. - Preserve lstm_cell_state tensor for calcGradient. - Remove lstmcell core layer Self evaluation: Build test: [X]Passed [ ]Failed [ ]Skipped Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: hyeonseok lee --- nntrainer/compiler/recurrent_realizer.cpp | 2 - nntrainer/layers/lstmcell_core.cpp | 611 +----------------------------- nntrainer/layers/lstmcell_core.h | 117 +----- nntrainer/layers/zoneout_lstmcell.cpp | 546 +++++++++++--------------- nntrainer/layers/zoneout_lstmcell.h | 30 +- 5 files changed, 233 insertions(+), 1073 deletions(-) diff --git a/nntrainer/compiler/recurrent_realizer.cpp b/nntrainer/compiler/recurrent_realizer.cpp index cf9c116..d750100 100644 --- a/nntrainer/compiler/recurrent_realizer.cpp +++ b/nntrainer/compiler/recurrent_realizer.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -180,7 +179,6 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step, return node->getType() == RNNCellLayer::type || node->getType() == LSTMLayer::type || node->getType() == LSTMCellLayer::type || - node->getType() == LSTMCellCoreLayer::type || node->getType() == ZoneoutLSTMCellLayer::type || node->getType() == GRUCellLayer::type; }; diff --git a/nntrainer/layers/lstmcell_core.cpp b/nntrainer/layers/lstmcell_core.cpp index 9bf6fe7..c6d31bd 100644 --- a/nntrainer/layers/lstmcell_core.cpp +++ b/nntrainer/layers/lstmcell_core.cpp @@ -4,628 +4,19 @@ * * @file lstmcell_core.cpp * @date 25 November 2021 - * @brief This is LSTMCellCore Layer Class of Neural Network + * @brief These are lstm core functions. * @see https://github.com/nnstreamer/nntrainer * @author hyeonseok lee * @bug No known bugs except for NYI items * */ -#include #include #include #include -#include - -// ENABLE_SHARING_WT_IDX implies does the wt_idx of lstm_core can be shared -// with lstm_cell variant layer. -// Todo: remove this if sharing wt_idx with other lstm variant is enabled -#define ENABLE_SHARING_WT_IDX 0 namespace nntrainer { -namespace init_lstm_context { -void fillLayerInitContext(InitLayerContext &context, - const InitLayerContext &core_context) { - /** real set the input flags */ - auto const &input_dims = context.getInputDimensions(); - for (unsigned int idx = 0; idx < core_context.getNumInputs(); idx++) { - context.setDynDimFlagInputDimension(idx, input_dims[idx].getDynDimFlag()); - context.setEffDimFlagInputDimension(idx, input_dims[idx].getEffDimFlag()); - } - - /** real request of tensors */ - for (auto const &ts : core_context.getTensorsSpec()) - context.requestTensor(ts); - - /** real request of weights */ - for (auto const &ws : core_context.getWeightsSpec()) - context.requestWeight(ws); -} - -void fillWeights(std::vector &weights, const RunLayerContext &context, - bool training, const std::vector &wt_idx, - const unsigned int max_timestep, const unsigned int timestep, - bool test) { - weights.resize(context.getNumWeights()); - for (unsigned int i = 0; i < context.getNumWeights(); ++i) { - if (training && (!test || i < context.getNumWeights() - 2)) { - weights[i] = - Weight(context.getWeight(wt_idx[i]), context.getWeightGrad(wt_idx[i]), - context.getWeightName(wt_idx[i])); - } else { - weights[i] = Weight(context.getWeight(wt_idx[i]), Tensor(), - context.getWeightName(wt_idx[i])); - } - } -} - -const std::vector getWeights(std::vector &weights) { - std::vector ret(weights.size()); - for (unsigned int i = 0; i < weights.size(); ++i) { - ret[i] = &weights[i]; - } - return ret; -} - -void fillInputs(std::vector &inputs, RunLayerContext &context, - bool training, const std::vector &wt_idx, - const unsigned int max_timestep, const unsigned int timestep) { - inputs.resize(3); - Tensor empty; - const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim(); - const unsigned int batch_size = output_dim.batch(); - const unsigned int unit = output_dim.width(); - - const Tensor &input = context.getInput(wt_idx[0]); - const Tensor &outgoing_derivative = - training ? context.getOutgoingDerivative(0) : empty; - - Tensor &hidden_state = context.getTensor(wt_idx[1]); - hidden_state.reshape({max_timestep, 1, batch_size, unit}); - Tensor &hidden_state_derivative = - training ? context.getTensorGrad(wt_idx[1]) : empty; - if (training) { - hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit}); - } - - Tensor &cell_state = context.getTensor(wt_idx[2]); - cell_state.reshape({max_timestep, 1, batch_size, unit}); - Tensor &cell_state_derivative = - training ? context.getTensorGrad(wt_idx[2]) : empty; - if (training) { - cell_state_derivative.reshape({max_timestep, 1, batch_size, unit}); - } - - Tensor prev_hidden_state; - Tensor prev_hidden_state_derivative; - Tensor prev_cell_state; - Tensor prev_cell_state_derivative; - if (!timestep) { - prev_hidden_state = Tensor(batch_size, 1, 1, unit); - prev_hidden_state.setZero(); - prev_hidden_state_derivative = Tensor(batch_size, 1, 1, unit); - prev_hidden_state_derivative.setZero(); - prev_cell_state = Tensor(batch_size, 1, 1, unit); - prev_cell_state.setZero(); - prev_cell_state_derivative = Tensor(batch_size, 1, 1, unit); - prev_cell_state_derivative.setZero(); - } else { - prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1); - prev_hidden_state.reshape({batch_size, 1, 1, unit}); - if (training) { - prev_hidden_state_derivative = - hidden_state_derivative.getBatchSlice(timestep - 1, 1); - prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit}); - } - prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1); - prev_cell_state.reshape({batch_size, 1, 1, unit}); - if (training) { - prev_cell_state_derivative = - cell_state_derivative.getBatchSlice(timestep - 1, 1); - prev_cell_state_derivative.reshape({batch_size, 1, 1, unit}); - } - } - - inputs[0] = Var_Grad(input, outgoing_derivative, "lstmcell_core input"); - inputs[1] = Var_Grad(prev_hidden_state, prev_hidden_state_derivative, - context.getTensorName(wt_idx[1])); - inputs[2] = Var_Grad(prev_cell_state, prev_cell_state_derivative, - context.getTensorName(wt_idx[2])); -} - -const std::vector getInputs(std::vector &inputs) { - std::vector ret(inputs.size()); - for (unsigned int i = 0; i < inputs.size(); ++i) { - ret[i] = &inputs[i]; - } - return ret; -} - -void fillOutputs(std::vector &outputs, RunLayerContext &context, - bool training, const std::vector &wt_idx, - const unsigned int max_timestep, const unsigned int timestep) { - outputs.resize(2); - Tensor empty; - const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim(); - const unsigned int batch_size = output_dim.batch(); - const unsigned int unit = output_dim.width(); - - Tensor &hidden_state = context.getTensor(wt_idx[1]); - hidden_state.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1); - next_hidden_state.reshape({batch_size, 1, 1, unit}); - Tensor next_hidden_state_derivative; - if (training) { - Tensor &hidden_state_derivative = context.getTensorGrad(wt_idx[1]); - hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit}); - next_hidden_state_derivative = - hidden_state_derivative.getBatchSlice(timestep, 1); - next_hidden_state_derivative.reshape({batch_size, 1, 1, unit}); - } - - Tensor &cell_state = context.getTensor(wt_idx[2]); - cell_state.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1); - next_cell_state.reshape({batch_size, 1, 1, unit}); - Tensor next_cell_state_derivative; - if (training) { - Tensor &cell_state_derivative = context.getTensorGrad(wt_idx[2]); - cell_state_derivative.reshape({max_timestep, 1, batch_size, unit}); - next_cell_state_derivative = - cell_state_derivative.getBatchSlice(timestep, 1); - next_cell_state_derivative.reshape({batch_size, 1, 1, unit}); - } - - outputs[0] = Var_Grad(next_hidden_state, next_hidden_state_derivative, - context.getTensorName(wt_idx[1])); - outputs[1] = Var_Grad(next_cell_state, next_cell_state_derivative, - context.getTensorName(wt_idx[2])); -} - -const std::vector getOutputs(std::vector &outputs) { - std::vector ret(outputs.size()); - for (unsigned int i = 0; i < outputs.size(); ++i) { - ret[i] = &outputs[i]; - } - return ret; -} - -void fillTensors(std::vector &tensors, RunLayerContext &context, - bool training, const std::vector &wt_idx, - const unsigned int max_timestep, const unsigned int timestep) { - tensors.resize(1); - -#if ENABLE_SHARING_WT_IDX - const Tensor &ifgo = context.getTensor(wt_idx[0]); - Tensor empty; - const Tensor &ifgo_derivative = - training ? context.getTensorGrad(wt_idx[0]) : empty; - tensors[0] = - Var_Grad(ifgo, ifgo_derivative, context.getTensorName(wt_idx[0])); -#else - const TensorDim &output_dim = context.getOutput(0).getDim(); - const unsigned int batch_size = output_dim.batch(); - const unsigned int unit = output_dim.width(); - - Tensor &ifgo = context.getTensor(wt_idx[0]); - const unsigned int NUM_GATE = ifgo.width() / unit; - ifgo.reshape({max_timestep, 1, batch_size, NUM_GATE * unit}); - Tensor ifgo_t = ifgo.getBatchSlice(timestep, 1); - ifgo_t.reshape({batch_size, 1, 1, NUM_GATE * unit}); - Tensor ifgo_derivative_t; - if (training) { - Tensor &ifgo_derivative = context.getTensorGrad(wt_idx[0]); - ifgo_derivative.reshape({max_timestep, 1, batch_size, NUM_GATE * unit}); - ifgo_derivative_t = ifgo_derivative.getBatchSlice(timestep, 1); - ifgo_derivative_t.reshape({batch_size, 1, 1, NUM_GATE * unit}); - } - tensors[0] = - Var_Grad(ifgo_t, ifgo_derivative_t, context.getTensorName(wt_idx[0])); -#endif -} - -const std::vector getTensors(std::vector &tensors) { - std::vector ret(tensors.size()); - for (unsigned int i = 0; i < tensors.size(); ++i) { - ret[i] = &tensors[i]; - } - return ret; -} - -} // namespace init_lstm_context - -enum INDEX { - INPUT = 0, - HIDDEN_STATE_IN = 1, - CELL_STATE_IN = 2, - HIDDEN_STATE_OUT = 0, - CELL_STATE_OUT = 1 -}; - -enum LSTMCellCoreParams { - weight_ih, - weight_hh, - bias_h, - bias_ih, - bias_hh, - ifgo, -}; - -LSTMCellCoreLayer::LSTMCellCoreLayer() : - LayerImpl(), - lstmcell_core_props( - props::Unit(), props::HiddenStateActivation() = ActivationType::ACT_TANH, - props::RecurrentActivation() = ActivationType::ACT_SIGMOID, - props::IntegrateBias()), - acti_func(ActivationType::ACT_NONE, true), - recurrent_acti_func(ActivationType::ACT_NONE, true) { - wt_idx.fill(std::numeric_limits::max()); -} - -void LSTMCellCoreLayer::finalize(InitLayerContext &context) { -#if ENBABLE_SHARING_WEIGHT - const Tensor::Initializer weight_initializer = - std::get(*layer_impl_props).get(); - const Tensor::Initializer bias_initializer = - std::get(*layer_impl_props).get(); - const WeightRegularizer weight_regularizer = - std::get(*layer_impl_props).get(); - const float weight_regularizer_constant = - std::get(*layer_impl_props).get(); - const bool disable_bias = - std::get(*layer_impl_props).get(); -#endif - - NNTR_THROW_IF(std::get(lstmcell_core_props).empty(), - std::invalid_argument) - << "unit property missing for lstmcell_core layer"; - const unsigned int unit = std::get(lstmcell_core_props).get(); - const ActivationType hidden_state_activation_type = - std::get(lstmcell_core_props).get(); - const ActivationType recurrent_activation_type = - std::get(lstmcell_core_props).get(); -#if ENBABLE_SHARING_WEIGHT - const bool integrate_bias = - std::get(lstmcell_core_props).get(); -#endif - - if (context.getNumInputs() != 3) - throw std::invalid_argument("LSTMCellCore layer should takes 3 input"); - - // input_dim = [ batch, 1, 1, feature_size ] - const TensorDim &input_dim = context.getInputDimensions()[0]; - if (input_dim.height() != 1 || input_dim.channel() != 1) - throw std::invalid_argument( - "Input must be single time dimension for LSTMCellCore"); - // input_hidden_state_dim = [ batch, 1, 1, unit ] - const TensorDim &input_hidden_state_dim = context.getInputDimensions()[1]; - if (input_hidden_state_dim.channel() != 1 || - input_hidden_state_dim.height() != 1) - throw std::invalid_argument("Input hidden state's dimension should be " - "[batch, 1, 1, unit] for LSTMCellCore"); - // input_cell_state_dim = [ batch, 1, 1, unit ] - const TensorDim &input_cell_state_dim = context.getInputDimensions()[2]; - if (input_cell_state_dim.channel() != 1 || input_cell_state_dim.height() != 1) - throw std::invalid_argument("Input cell state's dimension should be " - "[batch, 1, 1, unit] for LSTMCellCore"); - - const unsigned int batch_size = input_dim.batch(); -#if ENABLE_SHARING_WT_IDX - const unsigned int feature_size = input_dim.width(); -#endif - - const TensorDim output_dim(batch_size, 1, 1, unit); - const TensorDim output_hidden_state_dim = input_hidden_state_dim; - const TensorDim output_cell_state_dim = input_cell_state_dim; - - context.setOutputDimensions( - {output_dim, output_hidden_state_dim, output_cell_state_dim}); - -#if ENABLE_SHARING_WT_IDX - // weight_initializer can be set seperately. weight_ih initializer, - // weight_hh initializer kernel initializer & recurrent_initializer in keras - // for now, it is set same way. - - // - weight_ih ( input to hidden ) - // : [1, 1, feature_size, NUM_GATE x unit] -> i, f, g, o - TensorDim weight_ih_dim({feature_size, NUM_GATE * unit}); - wt_idx[LSTMCellCoreParams::weight_ih] = - context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, "weight_ih", true); - // - weight_hh ( hidden to hidden ) - // : [1, 1, unit, NUM_GATE x unit] -> i, f, g, o - TensorDim weight_hh_dim({unit, NUM_GATE * unit}); - wt_idx[LSTMCellCoreParams::weight_hh] = - context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer, - weight_regularizer_constant, "weight_hh", true); - if (!disable_bias) { - if (integrate_bias) { - // - bias_h ( input bias, hidden bias are integrate to 1 bias ) - // : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o - TensorDim bias_h_dim({NUM_GATE * unit}); - wt_idx[LSTMCellCoreParams::bias_h] = - context.requestWeight(bias_h_dim, bias_initializer, - WeightRegularizer::NONE, 1.0f, "bias_h", true); - } else { - // - bias_ih ( input bias ) - // : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o - TensorDim bias_ih_dim({NUM_GATE * unit}); - wt_idx[LSTMCellCoreParams::bias_ih] = - context.requestWeight(bias_ih_dim, bias_initializer, - WeightRegularizer::NONE, 1.0f, "bias_ih", true); - // - bias_hh ( hidden bias ) - // : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o - TensorDim bias_hh_dim({NUM_GATE * unit}); - wt_idx[LSTMCellCoreParams::bias_hh] = - context.requestWeight(bias_hh_dim, bias_initializer, - WeightRegularizer::NONE, 1.0f, "bias_hh", true); - } - } -#endif - -#if ENABLE_SHARING_WT_IDX - TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit); - wt_idx[LSTMCellCoreParams::ifgo] = - context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true, - TensorLifespan::ITERATION_LIFESPAN); -#endif - acti_func.setActiFunc(hidden_state_activation_type); - recurrent_acti_func.setActiFunc(recurrent_activation_type); -} - -void LSTMCellCoreLayer::setProperty(const std::vector &values) { - std::vector remain_props = - loadProperties(values, lstmcell_core_props); - LayerImpl::setProperty(remain_props); -} - -void LSTMCellCoreLayer::exportTo(Exporter &exporter, - const ExportMethods &method) const { -#if ENABLE_SHARING_WT_IDX - LayerImpl::exportTo(exporter, method); -#endif - exporter.saveResult(lstmcell_core_props, method, this); -} - -void LSTMCellCoreLayer::forwarding(RunLayerContext &context, bool training) { - const bool disable_bias = - std::get(*layer_impl_props).get(); - - const unsigned int unit = std::get(lstmcell_core_props).get(); - const bool integrate_bias = - std::get(lstmcell_core_props).get(); - - const Tensor &input = context.getInput(INDEX::INPUT); - const unsigned int batch_size = input.getDim().batch(); - - const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN); - const Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN); - Tensor &next_hidden_state = context.getOutput(INDEX::HIDDEN_STATE_OUT); - Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT); - -#if ENABLE_SHARING_WT_IDX - const Tensor &weight_ih = - context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]); - const Tensor &weight_hh = - context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]); - Tensor empty; - Tensor &bias_h = !disable_bias && integrate_bias - ? context.getWeight(wt_idx[LSTMCellCoreParams::bias_h]) - : empty; - Tensor &bias_ih = !disable_bias && !integrate_bias - ? context.getWeight(wt_idx[LSTMCellCoreParams::bias_ih]) - : empty; - Tensor &bias_hh = !disable_bias && !integrate_bias - ? context.getWeight(wt_idx[LSTMCellCoreParams::bias_hh]) - : empty; -#else - const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih); - const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh); - Tensor empty; - Tensor &bias_h = !disable_bias && integrate_bias - ? context.getWeight(LSTMCellCoreParams::bias_h) - : empty; - // subtract index by 1 cause there is no bias_h - Tensor &bias_ih = !disable_bias && !integrate_bias - ? context.getWeight(LSTMCellCoreParams::bias_ih - 1) - : empty; - Tensor &bias_hh = !disable_bias && !integrate_bias - ? context.getWeight(LSTMCellCoreParams::bias_hh - 1) - : empty; -#endif - -#if ENABLE_SHARING_WT_IDX - Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]); -#else - Tensor &ifgo = context.getTensor(0); -#endif - - input.dot(weight_ih, ifgo); - prev_hidden_state.dot(weight_hh, ifgo, false, false, 1.0); - if (!disable_bias) { - if (integrate_bias) { - ifgo.add_i(bias_h); - } else { - ifgo.add_i(bias_ih); - ifgo.add_i(bias_hh); - } - } - - Tensor input_forget_gate = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false); - Tensor input_gate = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false); - Tensor forget_gate = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false); - Tensor memory_cell = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false); - Tensor output_gate = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false); - - recurrent_acti_func.run_fn(input_forget_gate, input_forget_gate); - recurrent_acti_func.run_fn(output_gate, output_gate); - acti_func.run_fn(memory_cell, memory_cell); - - forget_gate.multiply_strided(prev_cell_state, next_cell_state); - memory_cell.multiply_strided(input_gate, next_cell_state, 1.0f); - - acti_func.run_fn(next_cell_state, next_hidden_state); - next_hidden_state.multiply_i_strided(output_gate); -} - -void LSTMCellCoreLayer::calcDerivative(RunLayerContext &context) { -#if ENABLE_SHARING_WT_IDX - Tensor &ifgo_derivative = - context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]); - const Tensor &weight_ih = - context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]); -#else - const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih); - Tensor &ifgo_derivative = context.getTensorGrad(0); -#endif - Tensor &outgoing_derivative = context.getOutgoingDerivative(INDEX::INPUT); - - ifgo_derivative.dot(weight_ih, outgoing_derivative, false, true); -} - -void LSTMCellCoreLayer::calcGradient(RunLayerContext &context) { - const bool disable_bias = - std::get(*layer_impl_props).get(); - - const unsigned int unit = std::get(lstmcell_core_props).get(); - const bool integrate_bias = - std::get(lstmcell_core_props).get(); - - const Tensor &input = context.getInput(INDEX::INPUT); - const unsigned int batch_size = input.getDim().batch(); - -#if ENABLE_SHARING_WT_IDX - Tensor &djdweight_ih = - context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_ih]); - const Tensor &weight_hh = - context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]); - Tensor &djdweight_hh = - context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_hh]); - Tensor empty; - Tensor &djdbias_h = - !disable_bias && integrate_bias - ? context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_h]) - : empty; - Tensor &djdbias_ih = - !disable_bias && !integrate_bias - ? context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_ih]) - : empty; - Tensor &djdbias_hh = - !disable_bias && !integrate_bias - ? context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_hh]) - : empty; -#else - Tensor &djdweight_ih = context.getWeightGrad(LSTMCellCoreParams::weight_ih); - const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh); - Tensor &djdweight_hh = context.getWeightGrad(LSTMCellCoreParams::weight_hh); - Tensor empty; - Tensor &djdbias_h = !disable_bias && integrate_bias - ? context.getWeightGrad(LSTMCellCoreParams::bias_h) - : empty; - // subtract index by 1 cause there is no bias_h(and also djdbias_h) - Tensor &djdbias_ih = - !disable_bias && !integrate_bias - ? context.getWeightGrad(LSTMCellCoreParams::bias_ih - 1) - : empty; - Tensor &djdbias_hh = - !disable_bias && !integrate_bias - ? context.getWeightGrad(LSTMCellCoreParams::bias_hh - 1) - : empty; -#endif - - const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN); - Tensor &prev_hidden_state_derivative = - context.getOutgoingDerivative(INDEX::HIDDEN_STATE_IN); - Tensor &next_hidden_state_derivative = - context.getIncomingDerivative(INDEX::HIDDEN_STATE_OUT); - - Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN); - Tensor &prev_cell_state_derivative = - context.getOutgoingDerivative(INDEX::CELL_STATE_IN); - Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT); - Tensor &next_cell_state_derivative = - context.getIncomingDerivative(INDEX::CELL_STATE_OUT); - -#if ENABLE_SHARING_WT_IDX - Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]); - Tensor &ifgo_derivative = - context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]); -#else - Tensor &ifgo = context.getTensor(0); - Tensor &ifgo_derivative = context.getTensorGrad(0); -#endif - - Tensor input_forget_gate = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false); - Tensor input_gate = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false); - Tensor forget_gate = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false); - Tensor memory_cell = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false); - Tensor output_gate = - ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false); - - Tensor input_forget_gate_derivative = - ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false); - Tensor input_gate_derivative = - ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false); - Tensor forget_gate_derivative = - ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false); - Tensor memory_cell_derivative = ifgo_derivative.getSharedDataTensor( - {batch_size, 1, 1, unit}, unit * 2, false); - Tensor output_gate_derivative = ifgo_derivative.getSharedDataTensor( - {batch_size, 1, 1, unit}, unit * 3, false); - - acti_func.run_fn(next_cell_state, next_cell_state); - next_hidden_state_derivative.multiply_strided(next_cell_state, - output_gate_derivative); - - acti_func.run_prime_fn(next_cell_state, prev_cell_state_derivative, - next_hidden_state_derivative); - prev_cell_state_derivative.multiply_i_strided(output_gate); - prev_cell_state_derivative.add_i(next_cell_state_derivative); - - prev_cell_state_derivative.multiply_strided(input_gate, - memory_cell_derivative); - prev_cell_state_derivative.multiply_strided(memory_cell, - input_gate_derivative); - - prev_cell_state_derivative.multiply_strided(prev_cell_state, - forget_gate_derivative); - prev_cell_state_derivative.multiply_i_strided(forget_gate); - - recurrent_acti_func.run_prime_fn(output_gate, output_gate_derivative, - output_gate_derivative); - recurrent_acti_func.run_prime_fn(input_forget_gate, - input_forget_gate_derivative, - input_forget_gate_derivative); - acti_func.run_prime_fn(memory_cell, memory_cell_derivative, - memory_cell_derivative); - - if (!disable_bias) { - if (integrate_bias) { - ifgo_derivative.sum(0, djdbias_h, 1.0f, 1.0f); - } else { - ifgo_derivative.sum(0, djdbias_ih, 1.0f, 1.0f); - ifgo_derivative.sum(0, djdbias_hh, 1.0f, 1.0f); - } - } - input.dot(ifgo_derivative, djdweight_ih, true, false, 1.0f); - prev_hidden_state.dot(ifgo_derivative, djdweight_hh, true, false, 1.0f); - ifgo_derivative.dot(weight_hh, prev_hidden_state_derivative, false, true); -} - -void LSTMCellCoreLayer::setBatch(RunLayerContext &context, unsigned int batch) { - context.updateTensor(wt_idx[LSTMCellCoreParams::ifgo], batch); -} - void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size, const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func, ActiFunc &recurrent_acti_func, diff --git a/nntrainer/layers/lstmcell_core.h b/nntrainer/layers/lstmcell_core.h index 4f02398..cdb33ed 100644 --- a/nntrainer/layers/lstmcell_core.h +++ b/nntrainer/layers/lstmcell_core.h @@ -4,7 +4,7 @@ * * @file lstmcell_core.h * @date 25 November 2021 - * @brief This is LSTMCellCore Layer Class of Neural Network + * @brief These are lstm core functions. * @see https://github.com/nnstreamer/nntrainer * @author hyeonseok lee * @bug No known bugs except for NYI items @@ -16,124 +16,9 @@ #ifdef __cplusplus #include -#include -#include namespace nntrainer { -namespace init_lstm_context { -void fillLayerInitContext(InitLayerContext &context, - const InitLayerContext &core_context); -void fillWeights(std::vector &weights, const RunLayerContext &context, - bool training, const std::vector &wt_idx, - const unsigned int max_timestep, const unsigned int timestep, - bool test = false); -const std::vector getWeights(std::vector &weights); -void fillInputs(std::vector &inputs, RunLayerContext &context, - bool training, const std::vector &wt_idx, - const unsigned int max_timestep, const unsigned int timestep); -const std::vector getInputs(std::vector &inputs); -void fillOutputs(std::vector &outputs, RunLayerContext &context, - bool training, const std::vector &wt_idx, - const unsigned int max_timestep, const unsigned int timestep); -const std::vector getOutputs(std::vector &outputs); -void fillTensors(std::vector &tensors, RunLayerContext &context, - bool training, const std::vector &wt_idx, - const unsigned int max_timestep, const unsigned int timestep); -const std::vector getTensors(std::vector &tensors); -} // namespace init_lstm_context - -/** - * @class LSTMCellCoreLayer - * @brief LSTMCellCoreLayer - */ -class LSTMCellCoreLayer : public LayerImpl { -public: - /** - * @brief Constructor of LSTMCellLayer - */ - LSTMCellCoreLayer(); - - /** - * @brief Destructor of LSTMCellLayer - */ - ~LSTMCellCoreLayer() = default; - - /** - * @copydoc Layer::finalize(InitLayerContext &context) - */ - void finalize(InitLayerContext &context) override; - - /** - * @copydoc Layer::forwarding(RunLayerContext &context, bool training) - */ - void forwarding(RunLayerContext &context, bool training) override; - - /** - * @copydoc Layer::calcDerivative(RunLayerContext &context) - */ - void calcDerivative(RunLayerContext &context) override; - - /** - * @copydoc Layer::calcGradient(RunLayerContext &context) - */ - void calcGradient(RunLayerContext &context) override; - /** - * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) - */ - void exportTo(Exporter &exporter, const ExportMethods &method) const override; - - /** - * @copydoc Layer::getType() - */ - const std::string getType() const override { - return LSTMCellCoreLayer::type; - }; - - /** - * @copydoc Layer::supportBackwarding() - */ - bool supportBackwarding() const override { return true; } - - /** - * @copydoc Layer::setProperty(const PropertyType type, const std::string - * &value) - */ - void setProperty(const std::vector &values) override; - - /** - * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch) - */ - void setBatch(RunLayerContext &context, unsigned int batch) override; - - inline static const std::string type = "lstmcell_core"; - -private: - static constexpr unsigned int NUM_GATE = 4; - - /** - * Unit: number of output neurons - * HiddenStateActivation: activation type for hidden state. default is tanh - * RecurrentActivation: activation type for recurrent. default is sigmoid - * IntegrateBias: integrate bias_ih, bias_hh to bias_h - * - * */ - std::tuple - lstmcell_core_props; - std::array wt_idx; /**< indices of the weights */ - - /** - * @brief activation function for h_t : default is tanh - */ - ActiFunc acti_func; - - /** - * @brief activation function for recurrent : default is sigmoid - */ - ActiFunc recurrent_acti_func; -}; - void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size, const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func, ActiFunc &recurrent_acti_func, diff --git a/nntrainer/layers/zoneout_lstmcell.cpp b/nntrainer/layers/zoneout_lstmcell.cpp index 5c0c3d6..1f5e01a 100644 --- a/nntrainer/layers/zoneout_lstmcell.cpp +++ b/nntrainer/layers/zoneout_lstmcell.cpp @@ -32,63 +32,21 @@ enum ZoneoutLSTMParams { hidden_state, cell_state, ifgo, + lstm_cell_state, hidden_state_zoneout_mask, cell_state_zoneout_mask, }; -unsigned int hidden_state_origin_idx = 0, cell_state_origin_idx = 0; - -const std::vector -getWeightIdx(std::array &wt_idx, const bool disable_bias, - const bool integrate_bias, const bool test) { - std::vector ret; - ret.push_back(wt_idx[ZoneoutLSTMParams::weight_ih]); - ret.push_back(wt_idx[ZoneoutLSTMParams::weight_hh]); - if (!disable_bias) { - if (integrate_bias) { - ret.push_back(wt_idx[ZoneoutLSTMParams::bias_h]); - } else { - ret.push_back(wt_idx[ZoneoutLSTMParams::bias_ih]); - ret.push_back(wt_idx[ZoneoutLSTMParams::bias_hh]); - } - } - if (test) { - ret.push_back(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]); - ret.push_back(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]); - } - return ret; -} - -const std::vector -getInputIdx(std::array &wt_idx) { - std::vector ret(3); - ret[0] = SINGLE_INOUT_IDX; - ret[1] = wt_idx[ZoneoutLSTMParams::hidden_state]; - ret[2] = wt_idx[ZoneoutLSTMParams::cell_state]; - return ret; -} - -const std::vector -getOutputIdx(std::array &wt_idx) { - std::vector ret(3); - ret[0] = SINGLE_INOUT_IDX; - ret[1] = hidden_state_origin_idx; - ret[2] = cell_state_origin_idx; - return ret; -} - -const std::vector -getTensorIdx(std::array &wt_idx) { - std::vector ret(1); - ret[0] = wt_idx[ZoneoutLSTMParams::ifgo]; - return ret; -} - ZoneoutLSTMCellLayer::ZoneoutLSTMCellLayer() : LayerImpl(), - zoneout_lstmcell_props(props::Unit(), HiddenStateZoneOutRate(), - CellStateZoneOutRate(), props::IntegrateBias(), Test(), - props::MaxTimestep(), props::Timestep()), + zoneout_lstmcell_props( + props::Unit(), props::IntegrateBias(), + props::HiddenStateActivation() = ActivationType::ACT_TANH, + props::RecurrentActivation() = ActivationType::ACT_SIGMOID, + HiddenStateZoneOutRate(), CellStateZoneOutRate(), Test(), + props::MaxTimestep(), props::Timestep()), + acti_func(ActivationType::ACT_NONE, true), + recurrent_acti_func(ActivationType::ACT_NONE, true), epsilon(1e-3) { wt_idx.fill(std::numeric_limits::max()); } @@ -112,7 +70,6 @@ bool ZoneoutLSTMCellLayer::CellStateZoneOutRate::isValid( } void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) { -#if !ENABLE_SHARING_WT_IDX const Tensor::Initializer weight_initializer = std::get(*layer_impl_props).get(); const Tensor::Initializer bias_initializer = @@ -123,7 +80,6 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) { std::get(*layer_impl_props).get(); const bool disable_bias = std::get(*layer_impl_props).get(); -#endif NNTR_THROW_IF(std::get(zoneout_lstmcell_props).empty(), std::invalid_argument) @@ -131,6 +87,10 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) { const unsigned int unit = std::get(zoneout_lstmcell_props).get(); const bool integrate_bias = std::get(zoneout_lstmcell_props).get(); + const ActivationType hidden_state_activation_type = + std::get(zoneout_lstmcell_props).get(); + const ActivationType recurrent_activation_type = + std::get(zoneout_lstmcell_props).get(); const bool test = std::get(zoneout_lstmcell_props).get(); const unsigned int max_timestep = std::get(zoneout_lstmcell_props).get(); @@ -146,20 +106,17 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) { // input_dim = [ batch_size, 1, 1, feature_size ] const TensorDim &input_dim = context.getInputDimensions()[0]; - if (input_dim.height() != 1 || input_dim.channel() != 1) + if (input_dim.channel() != 1 || input_dim.height() != 1) throw std::invalid_argument("Input must be single time dimension for " "ZoneoutLSTMCell (shape should be " "[batch_size, 1, 1, feature_size])"); const unsigned int batch_size = input_dim.batch(); -#if !ENABLE_SHARING_WT_IDX const unsigned int feature_size = input_dim.width(); -#endif // output_dim = [ batch_size, 1, 1, unit ] const TensorDim output_dim(batch_size, 1, 1, unit); context.setOutputDimensions({output_dim}); -#if !ENABLE_SHARING_WT_IDX // weight_initializer can be set seperately. weight_ih initializer, // weight_hh initializer kernel initializer & recurrent_initializer in keras // for now, it is set same way. @@ -199,7 +156,6 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) { WeightRegularizer::NONE, 1.0f, "bias_hh", true); } } -#endif /** * TODO: hidden_state is only used from the previous timestep. Once it is @@ -216,24 +172,17 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) { cell_state_dim, "cell_state", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN, false); - hidden_state_origin_idx = context.requestTensor( - hidden_state_dim, "hidden_state_origin", Tensor::Initializer::NONE, true, - TensorLifespan::ITERATION_LIFESPAN, false); - cell_state_origin_idx = context.requestTensor( - cell_state_dim, "cell_state_origin", Tensor::Initializer::NONE, true, - TensorLifespan::ITERATION_LIFESPAN, false); - -#if !ENABLE_SHARING_WT_IDX - /** - * TODO: make this independent of time dimension once recurrent realizer - * supports requesting tensors which are not always shared - */ - /** ifgo_dim = [ max_timestep * batch_size, 1, 1, NUM_GATE * unit ] */ - const TensorDim ifgo_dim(max_timestep * batch_size, 1, 1, NUM_GATE * unit); + /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit ] */ + const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit); wt_idx[ZoneoutLSTMParams::ifgo] = context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true, - TensorLifespan::ITERATION_LIFESPAN, false); -#endif + TensorLifespan::ITERATION_LIFESPAN); + + /** lstm_cell_state_dim = [ batch_size, 1, 1, unit ] */ + const TensorDim lstm_cell_state_dim(batch_size, 1, 1, unit); + wt_idx[ZoneoutLSTMParams::lstm_cell_state] = context.requestTensor( + lstm_cell_state_dim, "lstm_cell_state", Tensor::Initializer::NONE, true, + TensorLifespan::ITERATION_LIFESPAN); // hidden_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ] const TensorDim hidden_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1, @@ -262,55 +211,20 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) { Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); } - TensorDim hidden_state_t_dim({batch_size, 1, 1, unit}); - TensorDim cell_state_t_dim({batch_size, 1, 1, unit}); - InitLayerContext core_context( - {input_dim, hidden_state_t_dim, cell_state_t_dim}, 3, - context.executeInPlace(), context.getName()); - lstmcellcorelayer.finalize(core_context); - init_lstm_context::fillLayerInitContext(context, core_context); + acti_func.setActiFunc(hidden_state_activation_type); + recurrent_acti_func.setActiFunc(recurrent_activation_type); } void ZoneoutLSTMCellLayer::setProperty(const std::vector &values) { - std::vector remain_props = + const std::vector &remain_props = loadProperties(values, zoneout_lstmcell_props); - - // Note: In current implementation the lstmcellcorelayer also has - // a properties related to weight. But it is not exported or used anywhere. - lstmcellcorelayer.setProperty(remain_props); - if (!std::get(zoneout_lstmcell_props).empty()) { - lstmcellcorelayer.setProperty( - {"unit=" + to_string(std::get(zoneout_lstmcell_props))}); - } - lstmcellcorelayer.setProperty( - {"integrate_bias=" + - to_string(std::get(zoneout_lstmcell_props))}); - -#if !ENABLE_SHARING_WT_IDX - // To remove lstmcell core layer's properties - std::tuple - lstmcell_core_props; - std::vector impl_props = - loadProperties(remain_props, lstmcell_core_props); - - LayerImpl::setProperty(impl_props); -#endif + LayerImpl::setProperty(remain_props); } void ZoneoutLSTMCellLayer::exportTo(Exporter &exporter, const ExportMethods &method) const { -#if !ENABLE_SHARING_WT_IDX LayerImpl::exportTo(exporter, method); -#endif - exporter.saveResult( - std::forward_as_tuple( - std::get(zoneout_lstmcell_props), - std::get(zoneout_lstmcell_props), - std::get(zoneout_lstmcell_props), - std::get(zoneout_lstmcell_props), - std::get(zoneout_lstmcell_props)), - method, this); - lstmcellcorelayer.exportTo(exporter, method); + exporter.saveResult(zoneout_lstmcell_props, method, this); } void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) { @@ -318,169 +232,128 @@ void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) { std::get(*layer_impl_props).get(); const unsigned int unit = std::get(zoneout_lstmcell_props).get(); + const bool integrate_bias = + std::get(zoneout_lstmcell_props).get(); const float hidden_state_zoneout_rate = std::get(zoneout_lstmcell_props).get(); const float cell_state_zoneout_rate = std::get(zoneout_lstmcell_props).get(); - const bool integrate_bias = - std::get(zoneout_lstmcell_props).get(); const bool test = std::get(zoneout_lstmcell_props).get(); const unsigned int max_timestep = std::get(zoneout_lstmcell_props).get(); const unsigned int timestep = std::get(zoneout_lstmcell_props).get(); - const unsigned int batch_size = - context.getInput(SINGLE_INOUT_IDX).getDim().batch(); - + const Tensor &input = context.getInput(SINGLE_INOUT_IDX); Tensor &output = context.getOutput(SINGLE_INOUT_IDX); - Tensor &hidden_state = - context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]); - hidden_state.reshape({max_timestep, 1, batch_size, unit}); + const unsigned int batch_size = input.getDim().batch(); + + const Tensor &weight_ih = + context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]); + const Tensor &weight_hh = + context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]); + Tensor empty; + Tensor &bias_h = !disable_bias && integrate_bias + ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_h]) + : empty; + Tensor &bias_ih = !disable_bias && !integrate_bias + ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_ih]) + : empty; + Tensor &bias_hh = !disable_bias && !integrate_bias + ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_hh]) + : empty; + + Tensor &hs = context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]); + hs.reshape({max_timestep, 1, batch_size, unit}); Tensor prev_hidden_state; if (!timestep) { prev_hidden_state = Tensor(batch_size, 1, 1, unit); prev_hidden_state.setZero(); } else { - prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1); + prev_hidden_state = hs.getBatchSlice(timestep - 1, 1); prev_hidden_state.reshape({batch_size, 1, 1, unit}); } - Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1); - next_hidden_state.reshape({batch_size, 1, 1, unit}); + Tensor hidden_state = hs.getBatchSlice(timestep, 1); + hidden_state.reshape({batch_size, 1, 1, unit}); - Tensor &cell_state = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]); - cell_state.reshape({max_timestep, 1, batch_size, unit}); + Tensor &cs = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]); + cs.reshape({max_timestep, 1, batch_size, unit}); Tensor prev_cell_state; if (!timestep) { prev_cell_state = Tensor(batch_size, 1, 1, unit); prev_cell_state.setZero(); } else { - prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1); + prev_cell_state = cs.getBatchSlice(timestep - 1, 1); prev_cell_state.reshape({batch_size, 1, 1, unit}); } - Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1); - next_cell_state.reshape({batch_size, 1, 1, unit}); + Tensor cell_state = cs.getBatchSlice(timestep, 1); + cell_state.reshape({batch_size, 1, 1, unit}); - if (!timestep) { - hidden_state.setZero(); - cell_state.setZero(); - } + Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]); - init_lstm_context::fillWeights( - weights, context, training, - getWeightIdx(wt_idx, disable_bias, integrate_bias, test), max_timestep, - timestep, test); - init_lstm_context::fillInputs(inputs, context, training, getInputIdx(wt_idx), - max_timestep, timestep); - init_lstm_context::fillOutputs(outputs, context, training, - getOutputIdx(wt_idx), max_timestep, timestep); - init_lstm_context::fillTensors(tensors, context, training, - getTensorIdx(wt_idx), max_timestep, timestep); - RunLayerContext core_context(context.getName(), context.getTrainable(), - context.getLoss(), context.executeInPlace(), - init_lstm_context::getWeights(weights), - init_lstm_context::getInputs(inputs), - init_lstm_context::getOutputs(outputs), - init_lstm_context::getTensors(tensors)); - lstmcellcorelayer.forwarding(core_context, training); + Tensor &lstm_cell_state = + context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]); + + lstmcell_forwarding(unit, batch_size, disable_bias, integrate_bias, acti_func, + recurrent_acti_func, input, prev_hidden_state, + prev_cell_state, hidden_state, lstm_cell_state, weight_ih, + weight_hh, bias_h, bias_ih, bias_hh, ifgo); if (training) { - Tensor &hidden_state_zoneout_mask = + Tensor &hs_zoneout_mask = test ? context.getWeight( wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]) : context.getTensor( wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]); - hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_hidden_state_zoneout_mask = - hidden_state_zoneout_mask.getBatchSlice(timestep, 1); - next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); + hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); + Tensor hidden_state_zoneout_mask = + hs_zoneout_mask.getBatchSlice(timestep, 1); + hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); Tensor prev_hidden_state_zoneout_mask; if (!test) { prev_hidden_state_zoneout_mask = - next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate); + hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate); } else { - next_hidden_state_zoneout_mask.multiply(-1.0f, - prev_hidden_state_zoneout_mask); + hidden_state_zoneout_mask.multiply(-1.0f, prev_hidden_state_zoneout_mask); prev_hidden_state_zoneout_mask.add_i(1.0f); } - Tensor &hidden_state_origin = context.getTensor(hidden_state_origin_idx); - hidden_state_origin.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_hidden_state_origin = - hidden_state_origin.getBatchSlice(timestep, 1); - next_hidden_state_origin.reshape({batch_size, 1, 1, unit}); + hidden_state.multiply_i(hidden_state_zoneout_mask); + prev_hidden_state.multiply(prev_hidden_state_zoneout_mask, hidden_state, + 1.0f); - next_hidden_state_origin.multiply(next_hidden_state_zoneout_mask, - next_hidden_state); - prev_hidden_state.multiply(prev_hidden_state_zoneout_mask, - next_hidden_state, 1.0f); - } - - if (training) { - Tensor &cell_state_zoneout_mask = + Tensor &cs_zoneout_mask = test ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]) : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]); - cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_cell_state_zoneout_mask = - cell_state_zoneout_mask.getBatchSlice(timestep, 1); - next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); + cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); + Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1); + cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); Tensor prev_cell_state_zoneout_mask; if (!test) { prev_cell_state_zoneout_mask = - next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate); + cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate); } else { - next_cell_state_zoneout_mask.multiply(-1.0f, - prev_cell_state_zoneout_mask); + cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask); prev_cell_state_zoneout_mask.add_i(1.0f); } - Tensor &cell_state_origin = context.getTensor(cell_state_origin_idx); - cell_state_origin.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_cell_state_origin = - cell_state_origin.getBatchSlice(timestep, 1); - next_cell_state_origin.reshape({batch_size, 1, 1, unit}); - - next_cell_state_origin.multiply(next_cell_state_zoneout_mask, - next_cell_state); - prev_cell_state.multiply(prev_cell_state_zoneout_mask, next_cell_state, - 1.0f); + lstm_cell_state.multiply(cell_state_zoneout_mask, cell_state); + prev_cell_state.multiply(prev_cell_state_zoneout_mask, cell_state, 1.0f); } // Todo: zoneout at inference - output.copyData(next_hidden_state); + output.copyData(hidden_state); } void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) { - const bool disable_bias = - std::get(*layer_impl_props).get(); + Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]); + const Tensor &weight_ih = + context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]); + Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX); - const bool integrate_bias = - std::get(zoneout_lstmcell_props).get(); - const bool test = std::get(zoneout_lstmcell_props).get(); - const unsigned int max_timestep = - std::get(zoneout_lstmcell_props).get(); - const unsigned int timestep = - std::get(zoneout_lstmcell_props).get(); - - init_lstm_context::fillWeights( - weights, context, true, - getWeightIdx(wt_idx, disable_bias, integrate_bias, test), max_timestep, - timestep, test); - init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx), - max_timestep, timestep); - init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx), - max_timestep, timestep); - init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx), - max_timestep, timestep); - RunLayerContext core_context(context.getName(), context.getTrainable(), - context.getLoss(), context.executeInPlace(), - init_lstm_context::getWeights(weights), - init_lstm_context::getInputs(inputs), - init_lstm_context::getOutputs(outputs), - init_lstm_context::getTensors(tensors)); - lstmcellcorelayer.calcDerivative(core_context); + lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative); } void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) { @@ -488,176 +361,185 @@ void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) { std::get(*layer_impl_props).get(); const unsigned int unit = std::get(zoneout_lstmcell_props).get(); - const float hidden_state_zoneout_rate = - std::get(zoneout_lstmcell_props); - const float cell_state_zoneout_rate = - std::get(zoneout_lstmcell_props); const bool integrate_bias = std::get(zoneout_lstmcell_props).get(); - const bool test = std::get(zoneout_lstmcell_props); + const float hidden_state_zoneout_rate = + std::get(zoneout_lstmcell_props).get(); + const float cell_state_zoneout_rate = + std::get(zoneout_lstmcell_props).get(); + const bool test = std::get(zoneout_lstmcell_props).get(); const unsigned int max_timestep = std::get(zoneout_lstmcell_props).get(); const unsigned int timestep = std::get(zoneout_lstmcell_props).get(); - unsigned int batch_size = context.getInput(SINGLE_INOUT_IDX).getDim().batch(); - + const Tensor &input = context.getInput(SINGLE_INOUT_IDX); const Tensor &incoming_derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX); - Tensor &hidden_state_derivative = - context.getTensorGrad(wt_idx[ZoneoutLSTMParams::hidden_state]); - hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_hidden_state_derivative = - hidden_state_derivative.getBatchSlice(timestep, 1); - next_hidden_state_derivative.reshape({batch_size, 1, 1, unit}); + unsigned int batch_size = input.getDim().batch(); + + Tensor &d_weight_ih = + context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]); + const Tensor &weight_hh = + context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]); + Tensor &d_weight_hh = + context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]); + Tensor empty; + Tensor &d_bias_h = + !disable_bias && integrate_bias + ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_h]) + : empty; + Tensor &d_bias_ih = + !disable_bias && !integrate_bias + ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih]) + : empty; + Tensor &d_bias_hh = + !disable_bias && !integrate_bias + ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_hh]) + : empty; + + Tensor &hs = context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]); + hs.reshape({max_timestep, 1, batch_size, unit}); + Tensor prev_hidden_state; + if (!timestep) { + prev_hidden_state = Tensor(batch_size, 1, 1, unit); + prev_hidden_state.setZero(); + } else { + prev_hidden_state = hs.getBatchSlice(timestep - 1, 1); + prev_hidden_state.reshape({batch_size, 1, 1, unit}); + } + + Tensor &d_hs = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::hidden_state]); + d_hs.reshape({max_timestep, 1, batch_size, unit}); + Tensor d_prev_hidden_state; + if (!timestep) { + d_prev_hidden_state = Tensor(batch_size, 1, 1, unit); + d_prev_hidden_state.setZero(); + } else { + d_prev_hidden_state = d_hs.getBatchSlice(timestep - 1, 1); + d_prev_hidden_state.reshape({batch_size, 1, 1, unit}); + } + Tensor d_hidden_state = d_hs.getBatchSlice(timestep, 1); + d_hidden_state.reshape({batch_size, 1, 1, unit}); + + Tensor &cs = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]); + cs.reshape({max_timestep, 1, batch_size, unit}); + Tensor prev_cell_state; + if (!timestep) { + prev_cell_state = Tensor(batch_size, 1, 1, unit); + prev_cell_state.setZero(); + } else { + prev_cell_state = cs.getBatchSlice(timestep - 1, 1); + prev_cell_state.reshape({batch_size, 1, 1, unit}); + } + Tensor cell_state = cs.getBatchSlice(timestep, 1); + cell_state.reshape({batch_size, 1, 1, unit}); - Tensor &cell_state_derivative = - context.getTensorGrad(wt_idx[ZoneoutLSTMParams::cell_state]); - cell_state_derivative.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_cell_state_derivative = - cell_state_derivative.getBatchSlice(timestep, 1); - next_cell_state_derivative.reshape({batch_size, 1, 1, unit}); + Tensor &d_cs = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::cell_state]); + d_cs.reshape({max_timestep, 1, batch_size, unit}); + Tensor d_prev_cell_state; + if (!timestep) { + d_prev_cell_state = Tensor(batch_size, 1, 1, unit); + d_prev_cell_state.setZero(); + } else { + d_prev_cell_state = d_cs.getBatchSlice(timestep - 1, 1); + d_prev_cell_state.reshape({batch_size, 1, 1, unit}); + } + Tensor d_cell_state = d_cs.getBatchSlice(timestep, 1); + d_cell_state.reshape({batch_size, 1, 1, unit}); + + Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]); + Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]); + + const Tensor &lstm_cell_state = + context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]); + Tensor &d_lstm_cell_state = + context.getTensorGrad(wt_idx[ZoneoutLSTMParams::lstm_cell_state]); if (timestep + 1 == max_timestep) { - Tensor &djdweight_ih = - context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]); - Tensor &djdweight_hh = - context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]); - djdweight_ih.setZero(); - djdweight_hh.setZero(); + d_weight_ih.setZero(); + d_weight_hh.setZero(); if (!disable_bias) { if (integrate_bias) { - Tensor &djdbias_h = - context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_h]); - djdbias_h.setZero(); + d_bias_h.setZero(); } else { - Tensor &djdbias_ih = - context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih]); - djdbias_ih.setZero(); - Tensor &djdbias_hh = - context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_hh]); - djdbias_hh.setZero(); + d_bias_ih.setZero(); + d_bias_hh.setZero(); } } - - hidden_state_derivative.setZero(); - cell_state_derivative.setZero(); + d_hidden_state.setZero(); + d_cell_state.setZero(); } - next_hidden_state_derivative.add_i(incoming_derivative); + d_hidden_state.add_i(incoming_derivative); - Tensor prev_hidden_state_derivative; - Tensor prev_cell_state_derivative; - Tensor prev_hidden_state_derivative_residual; - Tensor prev_cell_state_derivative_residual; + Tensor d_prev_hidden_state_residual; - Tensor &hidden_state_zoneout_mask = + Tensor &hs_zoneout_mask = test ? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]) : context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]); - hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_hidden_state_zoneout_mask = - hidden_state_zoneout_mask.getBatchSlice(timestep, 1); - next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); + hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); + Tensor hidden_state_zoneout_mask = hs_zoneout_mask.getBatchSlice(timestep, 1); + hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); Tensor prev_hidden_state_zoneout_mask; if (!test) { prev_hidden_state_zoneout_mask = - next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate); + hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate); } else { - next_hidden_state_zoneout_mask.multiply(-1.0f, - prev_hidden_state_zoneout_mask); + hidden_state_zoneout_mask.multiply(-1.0f, prev_hidden_state_zoneout_mask); prev_hidden_state_zoneout_mask.add_i(1.0f); } - if (timestep) { - prev_hidden_state_derivative = - hidden_state_derivative.getBatchSlice(timestep - 1, 1); - prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit}); - next_hidden_state_derivative.multiply( - prev_hidden_state_zoneout_mask, prev_hidden_state_derivative_residual); - } - - Tensor &hidden_state_origin_derivative = - context.getTensorGrad(hidden_state_origin_idx); - hidden_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_hidden_state_origin_derivative = - hidden_state_origin_derivative.getBatchSlice(timestep, 1); - next_hidden_state_origin_derivative.reshape({batch_size, 1, 1, unit}); + d_hidden_state.multiply(prev_hidden_state_zoneout_mask, + d_prev_hidden_state_residual); + d_hidden_state.multiply_i(hidden_state_zoneout_mask); - next_hidden_state_derivative.multiply(next_hidden_state_zoneout_mask, - next_hidden_state_origin_derivative); + Tensor d_prev_cell_state_residual; - Tensor &cell_state_zoneout_mask = + Tensor &cs_zoneout_mask = test ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]) : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]); - cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_cell_state_zoneout_mask = - cell_state_zoneout_mask.getBatchSlice(timestep, 1); - next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); + cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); + Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1); + cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); Tensor prev_cell_state_zoneout_mask; if (!test) { prev_cell_state_zoneout_mask = - next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate); + cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate); } else { - next_cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask); + cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask); prev_cell_state_zoneout_mask.add_i(1.0f); } - if (timestep) { - prev_cell_state_derivative = - cell_state_derivative.getBatchSlice(timestep - 1, 1); - prev_cell_state_derivative.reshape({batch_size, 1, 1, unit}); - next_cell_state_derivative.multiply(prev_cell_state_zoneout_mask, - prev_cell_state_derivative_residual); - } + d_cell_state.multiply(prev_cell_state_zoneout_mask, + d_prev_cell_state_residual); + d_cell_state.multiply(cell_state_zoneout_mask, d_lstm_cell_state); - Tensor &cell_state_origin_derivative = - context.getTensorGrad(cell_state_origin_idx); - cell_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit}); - Tensor next_cell_state_origin_derivative = - cell_state_origin_derivative.getBatchSlice(timestep, 1); - next_cell_state_origin_derivative.reshape({batch_size, 1, 1, unit}); - - next_cell_state_derivative.multiply(next_cell_state_zoneout_mask, - next_cell_state_origin_derivative); - - init_lstm_context::fillWeights( - weights, context, true, - getWeightIdx(wt_idx, disable_bias, integrate_bias, test), max_timestep, - timestep, test); - init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx), - max_timestep, timestep); - init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx), - max_timestep, timestep); - init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx), - max_timestep, timestep); - RunLayerContext core_context(context.getName(), context.getTrainable(), - context.getLoss(), context.executeInPlace(), - init_lstm_context::getWeights(weights), - init_lstm_context::getInputs(inputs), - init_lstm_context::getOutputs(outputs), - init_lstm_context::getTensors(tensors)); - lstmcellcorelayer.calcGradient(core_context); - - if (timestep) { - prev_hidden_state_derivative.add_i(prev_hidden_state_derivative_residual); - prev_cell_state_derivative.add_i(prev_cell_state_derivative_residual); - } + lstmcell_calcGradient(unit, batch_size, disable_bias, integrate_bias, + acti_func, recurrent_acti_func, input, + prev_hidden_state, d_prev_hidden_state, prev_cell_state, + d_prev_cell_state, d_hidden_state, lstm_cell_state, + d_lstm_cell_state, d_weight_ih, weight_hh, d_weight_hh, + d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo); + + d_prev_hidden_state.add_i(d_prev_hidden_state_residual); + d_prev_cell_state.add_i(d_prev_cell_state_residual); } void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context, unsigned int batch) { const unsigned int max_timestep = std::get(zoneout_lstmcell_props); + context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state], max_timestep * batch); context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state], max_timestep * batch); - context.updateTensor(hidden_state_origin_idx, max_timestep * batch); - context.updateTensor(cell_state_origin_idx, max_timestep * batch); - context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], max_timestep * batch); + context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], batch); + context.updateTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state], batch); context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask], max_timestep * batch); diff --git a/nntrainer/layers/zoneout_lstmcell.h b/nntrainer/layers/zoneout_lstmcell.h index 3895b83..515f55e 100644 --- a/nntrainer/layers/zoneout_lstmcell.h +++ b/nntrainer/layers/zoneout_lstmcell.h @@ -169,34 +169,38 @@ public: private: static constexpr unsigned int NUM_GATE = 4; - LSTMCellCoreLayer lstmcellcorelayer; - /** * Unit: number of output neurons + * IntegrateBias: integrate bias_ih, bias_hh to bias_h + * HiddenStateActivation: activation type for hidden state. default is tanh + * RecurrentActivation: activation type for recurrent. default is sigmoid * HiddenStateZoneOutRate: zoneout rate for hidden_state * CellStateZoneOutRate: zoneout rate for cell_state - * IntegrateBias: integrate bias_ih, bias_hh to bias_h * Test: property for test mode * MaxTimestep: maximum timestep for zoneout lstmcell * TimeStep: timestep for which lstm should operate * * */ - std::tuple + std::tuple zoneout_lstmcell_props; - std::array wt_idx; /**< indices of the weights */ + std::array wt_idx; /**< indices of the weights */ + + /** + * @brief activation function for h_t : default is tanh + */ + ActiFunc acti_func; + + /** + * @brief activation function for recurrent : default is sigmoid + */ + ActiFunc recurrent_acti_func; /** * @brief Protect overflow */ float epsilon; - - // These weights, inputs, outputs, tensors are all for the lstm_core - // Todo: remove this - std::vector weights; - std::vector inputs; - std::vector outputs; - std::vector tensors; }; } // namespace nntrainer -- 2.7.4