From ce54a7870287e857e8b4562f9c3b56f9a8508fbf Mon Sep 17 00:00:00 2001 From: hyeonseok lee Date: Wed, 1 Dec 2021 03:51:18 +0900 Subject: [PATCH] [zoneout lstmcell] Implement zoneout lstm cell - Zoneout lstmcell is based on the paper and the github repo which is mentioned in paper. - Todo: Zoneout at inference time is not implemented yet. refer: https://arxiv.org/pdf/1606.01305.pdf https://github.com/teganmaharaj/zoneout Self evaluation: Build test: [X]Passed [ ]Failed [ ]Skipped Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: hyeonseok lee --- api/ccapi/include/layer.h | 9 + jni/Android.mk | 1 + nntrainer/app_context.cpp | 4 + nntrainer/compiler/recurrent_realizer.cpp | 2 + nntrainer/layers/lstmcell.cpp | 2 + nntrainer/layers/lstmcell_core.cpp | 3 +- nntrainer/layers/meson.build | 1 + nntrainer/layers/zoneout_lstmcell.cpp | 613 ++++++++++++++++++++++++++++++ nntrainer/layers/zoneout_lstmcell.h | 201 ++++++++++ nntrainer/tensor/tensor.cpp | 28 +- nntrainer/tensor/tensor.h | 26 ++ 11 files changed, 886 insertions(+), 4 deletions(-) create mode 100644 nntrainer/layers/zoneout_lstmcell.cpp create mode 100644 nntrainer/layers/zoneout_lstmcell.h diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index d49ca6f..56b6ac3 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -71,6 +71,7 @@ enum LayerType { LAYER_RESHAPE, /**< Reshape Layer type */ LAYER_RNNCELL, /**< RNN Cell Layer type */ LAYER_LSTMCELL, /**< LSTM Cell Layer type */ + LAYER_ZONEOUT_LSTMCELL, /**< Zoneout LSTM Cell Layer type */ LAYER_GRUCELL, /**< GRU Cell Layer type */ LAYER_REDUCE_MEAN, /**< Reduce mean Layer type */ LAYER_LOSS_MSE = 500, /**< Mean Squared Error Loss Layer type */ @@ -339,6 +340,14 @@ LSTMCell(const std::vector &properties = {}) { } /** + * @brief Helper function to create ZoneoutLSTMCell layer + */ +inline std::unique_ptr +ZoneoutLSTMCell(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_ZONEOUT_LSTMCELL, properties); 
+} + +/** * @brief Helper function to create GRU layer */ inline std::unique_ptr diff --git a/jni/Android.mk b/jni/Android.mk index 997bee3..bb99159 100644 --- a/jni/Android.mk +++ b/jni/Android.mk @@ -175,6 +175,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \ $(NNTRAINER_ROOT)/nntrainer/layers/lstm.cpp \ $(NNTRAINER_ROOT)/nntrainer/layers/lstmcell.cpp \ $(NNTRAINER_ROOT)/nntrainer/layers/lstmcell_core.cpp \ + $(NNTRAINER_ROOT)/nntrainer/layers/zoneout_lstmcell.cpp \ $(NNTRAINER_ROOT)/nntrainer/layers/gru.cpp \ $(NNTRAINER_ROOT)/nntrainer/layers/grucell.cpp \ $(NNTRAINER_ROOT)/nntrainer/layers/time_dist.cpp \ diff --git a/nntrainer/app_context.cpp b/nntrainer/app_context.cpp index 66c930f..2e44a9a 100644 --- a/nntrainer/app_context.cpp +++ b/nntrainer/app_context.cpp @@ -66,6 +66,7 @@ #include #include #include +#include #ifdef ENABLE_TFLITE_BACKBONE #include @@ -251,6 +252,9 @@ static void add_default_object(AppContext &ac) { LayerType::LAYER_LSTM); ac.registerFactory(nntrainer::createLayer, LSTMCellLayer::type, LayerType::LAYER_LSTMCELL); + ac.registerFactory(nntrainer::createLayer, + ZoneoutLSTMCellLayer::type, + LayerType::LAYER_ZONEOUT_LSTMCELL); ac.registerFactory(nntrainer::createLayer, SplitLayer::type, LayerType::LAYER_SPLIT); ac.registerFactory(nntrainer::createLayer, GRULayer::type, diff --git a/nntrainer/compiler/recurrent_realizer.cpp b/nntrainer/compiler/recurrent_realizer.cpp index 21c1b9e..4d866c6 100644 --- a/nntrainer/compiler/recurrent_realizer.cpp +++ b/nntrainer/compiler/recurrent_realizer.cpp @@ -24,6 +24,7 @@ #include #include #include +#include namespace nntrainer { @@ -134,6 +135,7 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step, node->getType() == LSTMLayer::type || node->getType() == LSTMCellLayer::type || node->getType() == LSTMCellCoreLayer::type || + node->getType() == ZoneoutLSTMCellLayer::type || node->getType() == GRUCellLayer::type; }; diff --git a/nntrainer/layers/lstmcell.cpp 
b/nntrainer/layers/lstmcell.cpp index c9905a1..e2c4244 100644 --- a/nntrainer/layers/lstmcell.cpp +++ b/nntrainer/layers/lstmcell.cpp @@ -6,6 +6,8 @@ * @date 17 March 2021 * @brief This is LSTMCell Layer Class of Neural Network * @see https://github.com/nnstreamer/nntrainer + * https://arxiv.org/pdf/1606.01305.pdf + * https://github.com/teganmaharaj/zoneout * @author Parichay Kapoor * @bug No known bugs except for NYI items * diff --git a/nntrainer/layers/lstmcell_core.cpp b/nntrainer/layers/lstmcell_core.cpp index 3930cc6..5e73c7d 100644 --- a/nntrainer/layers/lstmcell_core.cpp +++ b/nntrainer/layers/lstmcell_core.cpp @@ -125,7 +125,7 @@ void fillInputs(std::vector &inputs, RunLayerContext &context, } } - inputs[0] = Var_Grad(input, outgoing_derivative); + inputs[0] = Var_Grad(input, outgoing_derivative, "lstmcell_core input"); inputs[1] = Var_Grad(prev_hidden_state, prev_hidden_state_derivative, context.getTensorName(wt_idx[1])); inputs[2] = Var_Grad(prev_cell_state, prev_cell_state_derivative, @@ -220,7 +220,6 @@ void fillTensors(std::vector &tensors, RunLayerContext &context, } tensors[0] = Var_Grad(ifgo_t, ifgo_derivative_t, context.getTensorName(wt_idx[0])); - context.getTensorName(wt_idx[0]); #endif } diff --git a/nntrainer/layers/meson.build b/nntrainer/layers/meson.build index e863dbc..b3cf7ce 100644 --- a/nntrainer/layers/meson.build +++ b/nntrainer/layers/meson.build @@ -27,6 +27,7 @@ layer_sources = [ 'lstm.cpp', 'lstmcell.cpp', 'lstmcell_core.cpp', + 'zoneout_lstmcell.cpp', 'time_dist.cpp', 'common_properties.cpp', 'split_layer.cpp', diff --git a/nntrainer/layers/zoneout_lstmcell.cpp b/nntrainer/layers/zoneout_lstmcell.cpp new file mode 100644 index 0000000..921b1e5 --- /dev/null +++ b/nntrainer/layers/zoneout_lstmcell.cpp @@ -0,0 +1,613 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2021 hyeonseok lee + * + * @file zoneout_lstmcell.cpp + * @date 30 November 2021 + * @brief This is ZoneoutLSTMCell Layer Class of Neural Network + * 
@see https://github.com/nnstreamer/nntrainer + * @author hyeonseok lee + * @bug No known bugs except for NYI items + * + */ + +#include +#include +#include +#include +#include + +namespace nntrainer { + +static constexpr size_t SINGLE_INOUT_IDX = 0; + +enum ZoneoutLSTMParams { + weight_ih, + weight_hh, + bias_ih, + hidden_state, + cell_state, + ifgo, + hidden_state_zoneout_mask, + cell_state_zoneout_mask, +}; + +unsigned int hidden_state_origin_idx = 0, cell_state_origin_idx = 0; + +const std::vector +getInputIdx(std::array &wt_idx) { + std::vector ret(3); + ret[0] = SINGLE_INOUT_IDX; + ret[1] = wt_idx[ZoneoutLSTMParams::hidden_state]; + ret[2] = wt_idx[ZoneoutLSTMParams::cell_state]; + return ret; +} + +const std::vector +getOutputIdx(std::array &wt_idx) { + std::vector ret(3); + ret[0] = SINGLE_INOUT_IDX; + ret[1] = hidden_state_origin_idx; + ret[2] = cell_state_origin_idx; + return ret; +} + +const std::vector +getTensorIdx(std::array &wt_idx) { + std::vector ret(1); + ret[0] = wt_idx[ZoneoutLSTMParams::ifgo]; + return ret; +} + +ZoneoutLSTMCellLayer::ZoneoutLSTMCellLayer() : + LayerImpl(), + zoneout_lstmcell_props(props::Unit(), HiddenStateZoneOutRate(), + CellStateZoneOutRate(), Test(), props::MaxTimestep(), + props::Timestep()), + wt_idx({0}), + epsilon(1e-3) {} + +bool ZoneoutLSTMCellLayer::HiddenStateZoneOutRate::isValid( + const float &value) const { + if (value < 0.0f || value > 1.0f) { + return false; + } else { + return true; + } +} + +bool ZoneoutLSTMCellLayer::CellStateZoneOutRate::isValid( + const float &value) const { + if (value < 0.0f || value > 1.0f) { + return false; + } else { + return true; + } +} + +void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) { + NNTR_THROW_IF(std::get(zoneout_lstmcell_props).empty(), + std::invalid_argument) + << "unit property missing for zoneout_lstmcell layer"; + const unsigned int unit = std::get(zoneout_lstmcell_props).get(); + const float hidden_state_zoneout_rate = + 
std::get(zoneout_lstmcell_props); + const float cell_state_zoneout_rate = + std::get(zoneout_lstmcell_props); + const bool test = std::get(zoneout_lstmcell_props); + const unsigned int max_timestep = + std::get(zoneout_lstmcell_props); + +#if !ENABLE_SHARING_WT_IDX + const Tensor::Initializer weight_initializer = + std::get(*layer_impl_props); + const Tensor::Initializer bias_initializer = + std::get(*layer_impl_props); + const nntrainer::WeightRegularizer weight_regularizer = + std::get(*layer_impl_props); + const float weight_regularizer_constant = + std::get(*layer_impl_props); +#endif + + if (context.getNumInputs() != 1) + throw std::invalid_argument("ZoneoutLSTMCellLayer takes only one input"); + if (std::get(zoneout_lstmcell_props).empty()) + throw std::invalid_argument("Number of unroll steps(max timestep) must be " + "provided to zoneout LSTM cells"); + if (std::get(zoneout_lstmcell_props).empty()) + throw std::invalid_argument( + "Current timestep must be provided to zoneout LSTM cell"); + + // input_dim = [ batch_size, 1, 1, feature_size ] + const TensorDim &input_dim = context.getInputDimensions()[0]; + if (input_dim.height() != 1 || input_dim.channel() != 1) + throw std::invalid_argument("Input must be single time dimension for " + "ZoneoutLSTMCell (shape should be " + "[batch_size, 1, 1, feature_size])"); + const unsigned int batch_size = input_dim.batch(); +#if !ENABLE_SHARING_WT_IDX + const unsigned int feature_size = input_dim.width(); +#endif + + // output_dim = [ batch_size, 1, 1, unit ] + const TensorDim output_dim(batch_size, 1, 1, unit); + context.setOutputDimensions({output_dim}); + +#if !ENABLE_SHARING_WT_IDX + // weight_initializer can be set separately. weight_ih initializer, + // weight_hh initializer kernel initializer & recurrent_initializer in keras + // for now, it is set the same way. 
+ + // - weight_ih ( input to hidden ) + // : [1, 1, feature_size, NUM_GATE x unit] -> i, f, g, o + TensorDim weight_ih_dim({feature_size, NUM_GATE * unit}); + wt_idx[ZoneoutLSTMParams::weight_ih] = + context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer, + weight_regularizer_constant, "weight_ih", true); + // - weight_hh ( hidden to hidden ) + // : [1, 1, unit, NUM_GATE x unit] -> i, f, g, o + TensorDim weight_hh_dim({unit, NUM_GATE * unit}); + wt_idx[ZoneoutLSTMParams::weight_hh] = + context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer, + weight_regularizer_constant, "weight_hh", true); + // - bias_ih ( input bias ) + // : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o + TensorDim bias_ih_dim({NUM_GATE * unit}); + wt_idx[ZoneoutLSTMParams::bias_ih] = + context.requestWeight(bias_ih_dim, bias_initializer, + WeightRegularizer::NONE, 1.0f, "bias_ih", true); +#endif + + // hidden_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ] + const TensorDim hidden_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1, + unit); + if (test) { + wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] = + context.requestWeight(hidden_state_zoneout_mask_dim, + Tensor::Initializer::NONE, WeightRegularizer::NONE, + 1.0f, "hidden_state_zoneout_mask", false); + } else if (hidden_state_zoneout_rate > epsilon) { + wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] = + context.requestTensor( + hidden_state_zoneout_mask_dim, "hidden_state_zoneout_mask", + Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); + } + // cell_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ] + const TensorDim cell_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1, + unit); + if (test) { + wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestWeight( + cell_state_zoneout_mask_dim, Tensor::Initializer::NONE, + WeightRegularizer::NONE, 1.0f, "cell_state_zoneout_mask", false); + } else if 
(cell_state_zoneout_rate > epsilon) { + wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestTensor( + cell_state_zoneout_mask_dim, "cell_state_zoneout_mask", + Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); + } + + /** + * TODO: hidden_state is only used from the previous timestep. Once it is + * supported as input, no need to cache the hidden_state itself + */ + /** hidden_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */ + const TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit); + wt_idx[ZoneoutLSTMParams::hidden_state] = context.requestTensor( + hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true, + TensorLifespan::ITERATION_LIFESPAN, false); + /** cell_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */ + const TensorDim cell_state_dim(max_timestep * batch_size, 1, 1, unit); + wt_idx[ZoneoutLSTMParams::cell_state] = context.requestTensor( + cell_state_dim, "cell_state", Tensor::Initializer::NONE, true, + TensorLifespan::ITERATION_LIFESPAN, false); + + hidden_state_origin_idx = context.requestTensor( + hidden_state_dim, "hidden_state_origin", Tensor::Initializer::NONE, true, + TensorLifespan::ITERATION_LIFESPAN, false); + cell_state_origin_idx = context.requestTensor( + cell_state_dim, "cell_state_origin", Tensor::Initializer::NONE, true, + TensorLifespan::ITERATION_LIFESPAN, false); + +#if !ENABLE_SHARING_WT_IDX + /** + * TODO: make this independent of time dimension once recurrent realizer + * supports requesting tensors which are not always shared + */ + /** ifgo_dim = [ max_timestep * batch_size, 1, 1, NUM_GATE * unit ] */ + const TensorDim ifgo_dim(max_timestep * batch_size, 1, 1, NUM_GATE * unit); + wt_idx[ZoneoutLSTMParams::ifgo] = + context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true, + TensorLifespan::ITERATION_LIFESPAN, false); +#endif + + TensorDim hidden_state_t_dim({batch_size, 1, 1, unit}); + TensorDim cell_state_t_dim({batch_size, 1, 1, unit}); + 
InitLayerContext core_context( + {input_dim, hidden_state_t_dim, cell_state_t_dim}, 3, + context.executeInPlace(), context.getName()); + lstmcellcorelayer.finalize(core_context); + init_lstm_context::fillLayerInitContext(context, core_context); +} + +void ZoneoutLSTMCellLayer::setProperty(const std::vector &values) { + std::vector remain_props = + loadProperties(values, zoneout_lstmcell_props); + + // Note: In current implementation the lstmcellcorelayer also has + // a properties related to weight. But it is not exported or used anywhere. + lstmcellcorelayer.setProperty(remain_props); + if (!std::get(zoneout_lstmcell_props).empty()) { + lstmcellcorelayer.setProperty( + {"unit=" + to_string(std::get(zoneout_lstmcell_props))}); + } + +#if !ENABLE_SHARING_WT_IDX + // To remove lstmcell core layer's properties + std::tuple + lstmcell_core_props; + std::vector impl_props = + loadProperties(remain_props, lstmcell_core_props); + + LayerImpl::setProperty(impl_props); +#endif +} + +void ZoneoutLSTMCellLayer::exportTo(Exporter &exporter, + const ExportMethods &method) const { +#if !ENABLE_SHARING_WT_IDX + LayerImpl::exportTo(exporter, method); +#endif + exporter.saveResult( + std::forward_as_tuple( + std::get(zoneout_lstmcell_props), + std::get(zoneout_lstmcell_props), + std::get(zoneout_lstmcell_props), + std::get(zoneout_lstmcell_props), + std::get(zoneout_lstmcell_props)), + method, this); + lstmcellcorelayer.exportTo(exporter, method); +} + +void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) { + const unsigned int unit = std::get(zoneout_lstmcell_props).get(); + const float hidden_state_zoneout_rate = + std::get(zoneout_lstmcell_props); + const float cell_state_zoneout_rate = + std::get(zoneout_lstmcell_props); + const bool test = std::get(zoneout_lstmcell_props); + const unsigned int max_timestep = + std::get(zoneout_lstmcell_props); + const unsigned int timestep = + std::get(zoneout_lstmcell_props); + + const Tensor &input = 
context.getInput(SINGLE_INOUT_IDX); + const TensorDim &input_dim = input.getDim(); + const unsigned int batch_size = input_dim.batch(); + + Tensor &hidden_state = + context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]); + hidden_state.reshape({max_timestep, 1, batch_size, unit}); + Tensor prev_hidden_state; + if (!timestep) { + prev_hidden_state = Tensor(batch_size, 1, 1, unit); + prev_hidden_state.setZero(); + } else { + prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1); + prev_hidden_state.reshape({batch_size, 1, 1, unit}); + } + Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1); + next_hidden_state.reshape({batch_size, 1, 1, unit}); + + Tensor &cell_state = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]); + cell_state.reshape({max_timestep, 1, batch_size, unit}); + Tensor prev_cell_state; + if (!timestep) { + prev_cell_state = Tensor(batch_size, 1, 1, unit); + prev_cell_state.setZero(); + } else { + prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1); + prev_cell_state.reshape({batch_size, 1, 1, unit}); + } + Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1); + next_cell_state.reshape({batch_size, 1, 1, unit}); + + if (!timestep) { + hidden_state.setZero(); + cell_state.setZero(); + } + + init_lstm_context::fillWeights(weights, context, training, max_timestep, + timestep, test); + init_lstm_context::fillInputs(inputs, context, training, getInputIdx(wt_idx), + max_timestep, timestep); + init_lstm_context::fillOutputs(outputs, context, training, + getOutputIdx(wt_idx), max_timestep, timestep); + init_lstm_context::fillTensors(tensors, context, training, + getTensorIdx(wt_idx), max_timestep, timestep); + RunLayerContext core_context(context.getName(), context.getTrainable(), + context.getLoss(), context.executeInPlace(), + init_lstm_context::getWeights(weights), + init_lstm_context::getInputs(inputs), + init_lstm_context::getOutputs(outputs), + init_lstm_context::getTensors(tensors)); + 
lstmcellcorelayer.forwarding(core_context, training); + + if (hidden_state_zoneout_rate > epsilon) { + if (training) { + Tensor &hidden_state_zoneout_mask = + test ? context.getWeight( + wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]) + : context.getTensor( + wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]); + hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_hidden_state_zoneout_mask = + hidden_state_zoneout_mask.getBatchSlice(timestep, 1); + next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); + Tensor prev_hidden_state_zoneout_mask; + if (!test) { + prev_hidden_state_zoneout_mask = + next_hidden_state_zoneout_mask.zoneout_mask( + hidden_state_zoneout_rate); + } else { + next_hidden_state_zoneout_mask.multiply(-1.0f, + prev_hidden_state_zoneout_mask); + prev_hidden_state_zoneout_mask.add_i(1.0f); + } + + Tensor &hidden_state_origin = context.getTensor(hidden_state_origin_idx); + hidden_state_origin.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_hidden_state_origin = + hidden_state_origin.getBatchSlice(timestep, 1); + next_hidden_state_origin.reshape({batch_size, 1, 1, unit}); + + next_hidden_state_origin.multiply(next_hidden_state_zoneout_mask, + next_hidden_state); + prev_hidden_state.multiply(prev_hidden_state_zoneout_mask, + next_hidden_state, 1.0f); + } + // Todo: zoneout at inference + } + if (cell_state_zoneout_rate > epsilon) { + if (training) { + Tensor &cell_state_zoneout_mask = + test ? 
context.getWeight( + wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]) + : context.getTensor( + wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]); + cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_cell_state_zoneout_mask = + cell_state_zoneout_mask.getBatchSlice(timestep, 1); + next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); + Tensor prev_cell_state_zoneout_mask; + if (!test) { + prev_cell_state_zoneout_mask = + next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate); + } else { + next_cell_state_zoneout_mask.multiply(-1.0f, + prev_cell_state_zoneout_mask); + prev_cell_state_zoneout_mask.add_i(1.0f); + } + + Tensor &cell_state_origin = context.getTensor(cell_state_origin_idx); + cell_state_origin.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_cell_state_origin = + cell_state_origin.getBatchSlice(timestep, 1); + next_cell_state_origin.reshape({batch_size, 1, 1, unit}); + + next_cell_state_origin.multiply(next_cell_state_zoneout_mask, + next_cell_state); + prev_cell_state.multiply(prev_cell_state_zoneout_mask, next_cell_state, + 1.0f); + } + // Todo: zoneout at inference + } + + Tensor &output = context.getOutput(SINGLE_INOUT_IDX); + output.copyData(next_hidden_state); +} + +void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) { + const bool test = std::get(zoneout_lstmcell_props); + const unsigned int max_timestep = + std::get(zoneout_lstmcell_props); + const unsigned int timestep = + std::get(zoneout_lstmcell_props); + + init_lstm_context::fillWeights(weights, context, true, max_timestep, timestep, + test); + init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx), + max_timestep, timestep); + init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx), + max_timestep, timestep); + init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx), + max_timestep, timestep); + RunLayerContext core_context(context.getName(), 
context.getTrainable(), + context.getLoss(), context.executeInPlace(), + init_lstm_context::getWeights(weights), + init_lstm_context::getInputs(inputs), + init_lstm_context::getOutputs(outputs), + init_lstm_context::getTensors(tensors)); + lstmcellcorelayer.calcDerivative(core_context); +} + +void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) { + const unsigned int unit = std::get(zoneout_lstmcell_props).get(); + const float hidden_state_zoneout_rate = + std::get(zoneout_lstmcell_props); + const float cell_state_zoneout_rate = + std::get(zoneout_lstmcell_props); + const bool test = std::get(zoneout_lstmcell_props); + const unsigned int max_timestep = + std::get(zoneout_lstmcell_props); + const unsigned int timestep = + std::get(zoneout_lstmcell_props); + + unsigned int batch_size = context.getInput(SINGLE_INOUT_IDX).getDim().batch(); + + const Tensor &incoming_derivative = + context.getIncomingDerivative(SINGLE_INOUT_IDX); + + Tensor &hidden_state_derivative = + context.getTensorGrad(wt_idx[ZoneoutLSTMParams::hidden_state]); + hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_hidden_state_derivative = + hidden_state_derivative.getBatchSlice(timestep, 1); + next_hidden_state_derivative.reshape({batch_size, 1, 1, unit}); + + Tensor &cell_state_derivative = + context.getTensorGrad(wt_idx[ZoneoutLSTMParams::cell_state]); + cell_state_derivative.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_cell_state_derivative = + cell_state_derivative.getBatchSlice(timestep, 1); + next_cell_state_derivative.reshape({batch_size, 1, 1, unit}); + + if (timestep + 1 == max_timestep) { + Tensor &djdweight_ih = + context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]); + Tensor &djdweight_hh = + context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]); + Tensor &djdbias_ih = + context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih]); + djdweight_ih.setZero(); + djdweight_hh.setZero(); + djdbias_ih.setZero(); + + 
hidden_state_derivative.setZero(); + cell_state_derivative.setZero(); + } + + next_hidden_state_derivative.add_i(incoming_derivative); + + Tensor prev_hidden_state_derivative; + Tensor prev_cell_state_derivative; + Tensor prev_hidden_state_derivative_residual; + Tensor prev_cell_state_derivative_residual; + if (hidden_state_zoneout_rate > epsilon) { + Tensor &hidden_state_zoneout_mask = + test ? context.getWeight( + wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]) + : context.getTensor( + wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]); + hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_hidden_state_zoneout_mask = + hidden_state_zoneout_mask.getBatchSlice(timestep, 1); + next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); + Tensor prev_hidden_state_zoneout_mask; + if (!test) { + prev_hidden_state_zoneout_mask = + next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate); + } else { + next_hidden_state_zoneout_mask.multiply(-1.0f, + prev_hidden_state_zoneout_mask); + prev_hidden_state_zoneout_mask.add_i(1.0f); + } + + if (timestep) { + prev_hidden_state_derivative = + hidden_state_derivative.getBatchSlice(timestep - 1, 1); + prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit}); + next_hidden_state_derivative.multiply( + prev_hidden_state_zoneout_mask, prev_hidden_state_derivative_residual); + } + + Tensor &hidden_state_origin_derivative = + context.getTensorGrad(hidden_state_origin_idx); + hidden_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_hidden_state_origin_derivative = + hidden_state_origin_derivative.getBatchSlice(timestep, 1); + next_hidden_state_origin_derivative.reshape({batch_size, 1, 1, unit}); + + next_hidden_state_derivative.multiply(next_hidden_state_zoneout_mask, + next_hidden_state_origin_derivative); + } + if (cell_state_zoneout_rate > epsilon) { + Tensor &cell_state_zoneout_mask = + test + ? 
context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]) + : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]); + cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_cell_state_zoneout_mask = + cell_state_zoneout_mask.getBatchSlice(timestep, 1); + next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit}); + Tensor prev_cell_state_zoneout_mask; + if (!test) { + prev_cell_state_zoneout_mask = + next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate); + } else { + next_cell_state_zoneout_mask.multiply(-1.0f, + prev_cell_state_zoneout_mask); + prev_cell_state_zoneout_mask.add_i(1.0f); + } + + if (timestep) { + prev_cell_state_derivative = + cell_state_derivative.getBatchSlice(timestep - 1, 1); + prev_cell_state_derivative.reshape({batch_size, 1, 1, unit}); + next_cell_state_derivative.multiply(prev_cell_state_zoneout_mask, + prev_cell_state_derivative_residual); + } + + Tensor &cell_state_origin_derivative = + context.getTensorGrad(cell_state_origin_idx); + cell_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit}); + Tensor next_cell_state_origin_derivative = + cell_state_origin_derivative.getBatchSlice(timestep, 1); + next_cell_state_origin_derivative.reshape({batch_size, 1, 1, unit}); + + next_cell_state_derivative.multiply(next_cell_state_zoneout_mask, + next_cell_state_origin_derivative); + } + + init_lstm_context::fillWeights(weights, context, true, max_timestep, timestep, + test); + init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx), + max_timestep, timestep); + init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx), + max_timestep, timestep); + init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx), + max_timestep, timestep); + RunLayerContext core_context(context.getName(), context.getTrainable(), + context.getLoss(), context.executeInPlace(), + init_lstm_context::getWeights(weights), + 
init_lstm_context::getInputs(inputs), + init_lstm_context::getOutputs(outputs), + init_lstm_context::getTensors(tensors)); + lstmcellcorelayer.calcGradient(core_context); + + if (timestep) { + if (hidden_state_zoneout_rate > epsilon) { + prev_hidden_state_derivative.add_i(prev_hidden_state_derivative_residual); + } + if (cell_state_zoneout_rate > epsilon) { + prev_cell_state_derivative.add_i(prev_cell_state_derivative_residual); + } + } +} + +void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context, + unsigned int batch) { + const float hidden_state_zoneout_rate = + std::get(zoneout_lstmcell_props); + const float cell_state_zoneout_rate = + std::get(zoneout_lstmcell_props); + const bool test = std::get(zoneout_lstmcell_props); + const unsigned int max_timestep = + std::get(zoneout_lstmcell_props); + context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state], + max_timestep * batch); + context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state], + max_timestep * batch); + context.updateTensor(hidden_state_origin_idx, max_timestep * batch); + context.updateTensor(cell_state_origin_idx, max_timestep * batch); + context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], max_timestep * batch); + + if (hidden_state_zoneout_rate > epsilon && !test) { + context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask], + max_timestep * batch); + } + if (cell_state_zoneout_rate > epsilon && !test) { + context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask], + max_timestep * batch); + } +} + +} // namespace nntrainer diff --git a/nntrainer/layers/zoneout_lstmcell.h b/nntrainer/layers/zoneout_lstmcell.h new file mode 100644 index 0000000..19f9c3b --- /dev/null +++ b/nntrainer/layers/zoneout_lstmcell.h @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2021 hyeonseok lee + * + * @file zoneout_lstmcell.h + * @date 30 November 2021 + * @brief This is ZoneoutLSTMCell Layer Class of Neural Network + * @see 
https://github.com/nnstreamer/nntrainer + * @author hyeonseok lee + * @bug No known bugs except for NYI items + * + */ + +#ifndef __ZONEOUTLSTMCELL_H__ +#define __ZONEOUTLSTMCELL_H__ +#ifdef __cplusplus + +#include +#include +#include +#include + +namespace nntrainer { + +/** + * @class ZoneoutLSTMCellLayer + * @brief ZoneoutLSTMCellLayer + */ +class ZoneoutLSTMCellLayer : public LayerImpl { +public: + /** + * @brief HiddenStateZoneOutRate property, this defines zone out rate for + * hidden state + * + */ + class HiddenStateZoneOutRate : public nntrainer::Property { + + public: + /** + * @brief Construct a new HiddenStateZoneOutRate object with a default value + * 0.0 + * + */ + HiddenStateZoneOutRate(float value = 0.0) : + nntrainer::Property(value) {} + static constexpr const char *key = + "hidden_state_zoneout_rate"; /**< unique key to access */ + using prop_tag = float_prop_tag; /**< property type */ + + /** + * @brief HiddenStateZoneOutRate validator + * + * @param value float to validate + * @retval true if it is greater than or equal to 0.0 and smaller than or + * equal to 1.0 + * @retval false if it is smaller than 0.0 or greater than 1.0 + */ + bool isValid(const float &value) const override; + }; + + /** + * @brief CellStateZoneOutRate property, this defines zone out rate for cell + * state + * + */ + class CellStateZoneOutRate : public nntrainer::Property { + + public: + /** + * @brief Construct a new CellStateZoneOutRate object with a default value + * 0.0 + * + */ + CellStateZoneOutRate(float value = 0.0) : + nntrainer::Property(value) {} + static constexpr const char *key = + "cell_state_zoneout_rate"; /**< unique key to access */ + using prop_tag = float_prop_tag; /**< property type */ + + /** + * @brief CellStateZoneOutRate validator + * + * @param value float to validate + * @retval true if it is greater than or equal to 0.0 and smaller than or + * equal to 1.0 + * @retval false if it is smaller than 0.0 or greater than 1.0 + */ + bool isValid(const float 
&value) const override; + }; + + /** + * @brief Test property, this property is set to true when testing the zoneout + * lstmcell in unit tests + * + */ + class Test : public nntrainer::Property { + + public: + /** + * @brief Construct a new Test object with a default value false + * + */ + Test(bool value = false) : nntrainer::Property(value) {} + static constexpr const char *key = "test"; /**< unique key to access */ + using prop_tag = bool_prop_tag; /**< property type */ + }; + + /** + * @brief Constructor of ZoneoutLSTMCellLayer + */ + ZoneoutLSTMCellLayer(); + + /** + * @brief Destructor of ZoneoutLSTMCellLayer + */ + ~ZoneoutLSTMCellLayer() = default; + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(InitLayerContext &context) override; + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(RunLayerContext &context) override; + + /** + * @copydoc Layer::calcGradient(RunLayerContext &context) + */ + void calcGradient(RunLayerContext &context) override; + /** + * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) + */ + void exportTo(Exporter &exporter, const ExportMethods &method) const override; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { + return ZoneoutLSTMCellLayer::type; + }; + + /** + * @copydoc Layer::supportBackwarding() + */ + bool supportBackwarding() const override { return true; } + + /** + * @copydoc Layer::setProperty(const PropertyType type, const std::string + * &value) + */ + void setProperty(const std::vector &values) override; + + /** + * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch) + */ + void setBatch(RunLayerContext &context, unsigned int batch) override; + + inline static const std::string type = "zoneout_lstmcell"; + +private: + 
static constexpr unsigned int NUM_GATE = 4; + + LSTMCellCoreLayer lstmcellcorelayer; + + /** + * Unit: number of output neurons + * HiddenStateZoneOutRate: zoneout rate applied to the hidden state + * CellStateZoneOutRate: zoneout rate applied to the cell state + * Test: set to true only when run from the unittest + * MaxTimestep: maximum timestep for the zoneout lstmcell + * TimeStep: timestep at which the lstm cell should operate + * + * */ + std::tuple + zoneout_lstmcell_props; + std::array wt_idx; /**< indices of the weights */ + + /** + * @brief Epsilon value used to protect against overflow + */ + float epsilon; + + // These weights, inputs, outputs, tensors are all for the lstm_core + // Todo: remove this + std::vector weights; + std::vector inputs; + std::vector outputs; + std::vector tensors; +}; +} // namespace nntrainer + +#endif /* __cplusplus */ +#endif /* __ZONEOUTLSTMCELL_H__ */ diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index d0cf702..e13b4e8 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -120,8 +120,7 @@ public: SrcSharedTensor() : src(nullptr), off(0) {} SrcSharedTensor(const Tensor *tensor, unsigned int offset) : - src(tensor), - off(offset) {} + src(tensor), off(offset) {} /** * @brief Get the allocated src tensor @@ -252,6 +251,11 @@ void Tensor::setRandUniform(float min, float max) { std::uniform_real_distribution(min, max)); } +void Tensor::setRandBernoulli(float probability) { + setDist( + std::bernoulli_distribution(probability)); +} + void Tensor::initialize() { if (empty() || !isAllocated()) return; @@ -1248,6 +1252,26 @@ void Tensor::filter_mask(const Tensor &mask_len, bool reverse) { } } +Tensor Tensor::zoneout_mask(float zoneout) { + Tensor ret(getDim()); + zoneout_mask(ret, zoneout); + return ret; +} + +void Tensor::zoneout_mask(Tensor &opposite, float zoneout) { + opposite.setRandBernoulli(zoneout); + float *data = getData(); + float *opposite_data = opposite.getData(); + + for (unsigned int i = 0; i < size(); ++i) { + if (opposite_data[i] > epsilon) { + 
data[i] = 0.0f; + } else { + data[i] = 1.0f; + } + } +} + int Tensor::apply_i(std::function f) { Tensor result = *this; apply(f, result); diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h index 499ef6b..0cf1c4e 100644 --- a/nntrainer/tensor/tensor.h +++ b/nntrainer/tensor/tensor.h @@ -705,6 +705,26 @@ public: void filter_mask(const Tensor &mask_len, bool reverse = false); /** + * @brief Calculate a complementary pair of zone out masks + * @details A new mask is sampled with setRandBernoulli(@a zoneout) and + * returned; this tensor is overwritten in place with its complement + * (0 where the returned mask is above epsilon, 1 elsewhere). + * @param zoneout zone out rate, used as the bernoulli probability + * @retval Tensor the sampled (opposite) zone out mask + */ + Tensor zoneout_mask(float zoneout); + + /** + * @brief Calculate a complementary pair of zone out masks + * @details @a opposite is filled by setRandBernoulli(@a zoneout); this + * tensor is overwritten in place with its complement (0 where @a opposite + * is above epsilon, 1 elsewhere). + * @param opposite tensor that receives the sampled zone out mask + * @param zoneout zone out rate, used as the bernoulli probability + */ + void zoneout_mask(Tensor &opposite, float zoneout); + + /** * @brief sum all the Tensor elements according to the batch * @retval Calculated Tensor(batch, 1, 1, 1) */ @@ -974,6 +994,12 @@ public: void setRandUniform(float min = -0.05f, float max = 0.05f); /** + * @brief Set the tensor with random bernoulli distribution + * @param[in] probability probability value for the distribution + */ + void setRandBernoulli(float probability = 0.5f); + + /** * @brief Initialize the memory of the given tensor */ void initialize(); -- 2.7.4