*
* @file lstmcell_core.cpp
* @date 25 November 2021
- * @brief This is LSTMCellCore Layer Class of Neural Network
+ * @brief These are lstm core functions.
* @see https://github.com/nnstreamer/nntrainer
* @author hyeonseok lee <hs89.lee@samsung.com>
* @bug No known bugs except for NYI items
*
*/
-#include <layer_context.h>
#include <lstmcell_core.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
-#include <node_exporter.h>
-
-// ENABLE_SHARING_WT_IDX implies does the wt_idx of lstm_core can be shared
-// with lstm_cell variant layer.
-// Todo: remove this if sharing wt_idx with other lstm variant is enabled
-#define ENABLE_SHARING_WT_IDX 0
namespace nntrainer {
-namespace init_lstm_context {
-void fillLayerInitContext(InitLayerContext &context,
- const InitLayerContext &core_context) {
- /** real set the input flags */
- auto const &input_dims = context.getInputDimensions();
- for (unsigned int idx = 0; idx < core_context.getNumInputs(); idx++) {
- context.setDynDimFlagInputDimension(idx, input_dims[idx].getDynDimFlag());
- context.setEffDimFlagInputDimension(idx, input_dims[idx].getEffDimFlag());
- }
-
- /** real request of tensors */
- for (auto const &ts : core_context.getTensorsSpec())
- context.requestTensor(ts);
-
- /** real request of weights */
- for (auto const &ws : core_context.getWeightsSpec())
- context.requestWeight(ws);
-}
-
-void fillWeights(std::vector<Weight> &weights, const RunLayerContext &context,
- bool training, const std::vector<unsigned int> &wt_idx,
- const unsigned int max_timestep, const unsigned int timestep,
- bool test) {
- weights.resize(context.getNumWeights());
- for (unsigned int i = 0; i < context.getNumWeights(); ++i) {
- if (training && (!test || i < context.getNumWeights() - 2)) {
- weights[i] =
- Weight(context.getWeight(wt_idx[i]), context.getWeightGrad(wt_idx[i]),
- context.getWeightName(wt_idx[i]));
- } else {
- weights[i] = Weight(context.getWeight(wt_idx[i]), Tensor(),
- context.getWeightName(wt_idx[i]));
- }
- }
-}
-
-const std::vector<Weight *> getWeights(std::vector<Weight> &weights) {
- std::vector<Weight *> ret(weights.size());
- for (unsigned int i = 0; i < weights.size(); ++i) {
- ret[i] = &weights[i];
- }
- return ret;
-}
-
-void fillInputs(std::vector<Var_Grad> &inputs, RunLayerContext &context,
- bool training, const std::vector<unsigned int> &wt_idx,
- const unsigned int max_timestep, const unsigned int timestep) {
- inputs.resize(3);
- Tensor empty;
- const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim();
- const unsigned int batch_size = output_dim.batch();
- const unsigned int unit = output_dim.width();
-
- const Tensor &input = context.getInput(wt_idx[0]);
- const Tensor &outgoing_derivative =
- training ? context.getOutgoingDerivative(0) : empty;
-
- Tensor &hidden_state = context.getTensor(wt_idx[1]);
- hidden_state.reshape({max_timestep, 1, batch_size, unit});
- Tensor &hidden_state_derivative =
- training ? context.getTensorGrad(wt_idx[1]) : empty;
- if (training) {
- hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
- }
-
- Tensor &cell_state = context.getTensor(wt_idx[2]);
- cell_state.reshape({max_timestep, 1, batch_size, unit});
- Tensor &cell_state_derivative =
- training ? context.getTensorGrad(wt_idx[2]) : empty;
- if (training) {
- cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
- }
-
- Tensor prev_hidden_state;
- Tensor prev_hidden_state_derivative;
- Tensor prev_cell_state;
- Tensor prev_cell_state_derivative;
- if (!timestep) {
- prev_hidden_state = Tensor(batch_size, 1, 1, unit);
- prev_hidden_state.setZero();
- prev_hidden_state_derivative = Tensor(batch_size, 1, 1, unit);
- prev_hidden_state_derivative.setZero();
- prev_cell_state = Tensor(batch_size, 1, 1, unit);
- prev_cell_state.setZero();
- prev_cell_state_derivative = Tensor(batch_size, 1, 1, unit);
- prev_cell_state_derivative.setZero();
- } else {
- prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1);
- prev_hidden_state.reshape({batch_size, 1, 1, unit});
- if (training) {
- prev_hidden_state_derivative =
- hidden_state_derivative.getBatchSlice(timestep - 1, 1);
- prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
- }
- prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1);
- prev_cell_state.reshape({batch_size, 1, 1, unit});
- if (training) {
- prev_cell_state_derivative =
- cell_state_derivative.getBatchSlice(timestep - 1, 1);
- prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
- }
- }
-
- inputs[0] = Var_Grad(input, outgoing_derivative, "lstmcell_core input");
- inputs[1] = Var_Grad(prev_hidden_state, prev_hidden_state_derivative,
- context.getTensorName(wt_idx[1]));
- inputs[2] = Var_Grad(prev_cell_state, prev_cell_state_derivative,
- context.getTensorName(wt_idx[2]));
-}
-
-const std::vector<Var_Grad *> getInputs(std::vector<Var_Grad> &inputs) {
- std::vector<Var_Grad *> ret(inputs.size());
- for (unsigned int i = 0; i < inputs.size(); ++i) {
- ret[i] = &inputs[i];
- }
- return ret;
-}
-
-void fillOutputs(std::vector<Var_Grad> &outputs, RunLayerContext &context,
- bool training, const std::vector<unsigned int> &wt_idx,
- const unsigned int max_timestep, const unsigned int timestep) {
- outputs.resize(2);
- Tensor empty;
- const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim();
- const unsigned int batch_size = output_dim.batch();
- const unsigned int unit = output_dim.width();
-
- Tensor &hidden_state = context.getTensor(wt_idx[1]);
- hidden_state.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1);
- next_hidden_state.reshape({batch_size, 1, 1, unit});
- Tensor next_hidden_state_derivative;
- if (training) {
- Tensor &hidden_state_derivative = context.getTensorGrad(wt_idx[1]);
- hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
- next_hidden_state_derivative =
- hidden_state_derivative.getBatchSlice(timestep, 1);
- next_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
- }
-
- Tensor &cell_state = context.getTensor(wt_idx[2]);
- cell_state.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1);
- next_cell_state.reshape({batch_size, 1, 1, unit});
- Tensor next_cell_state_derivative;
- if (training) {
- Tensor &cell_state_derivative = context.getTensorGrad(wt_idx[2]);
- cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
- next_cell_state_derivative =
- cell_state_derivative.getBatchSlice(timestep, 1);
- next_cell_state_derivative.reshape({batch_size, 1, 1, unit});
- }
-
- outputs[0] = Var_Grad(next_hidden_state, next_hidden_state_derivative,
- context.getTensorName(wt_idx[1]));
- outputs[1] = Var_Grad(next_cell_state, next_cell_state_derivative,
- context.getTensorName(wt_idx[2]));
-}
-
-const std::vector<Var_Grad *> getOutputs(std::vector<Var_Grad> &outputs) {
- std::vector<Var_Grad *> ret(outputs.size());
- for (unsigned int i = 0; i < outputs.size(); ++i) {
- ret[i] = &outputs[i];
- }
- return ret;
-}
-
-void fillTensors(std::vector<Var_Grad> &tensors, RunLayerContext &context,
- bool training, const std::vector<unsigned int> &wt_idx,
- const unsigned int max_timestep, const unsigned int timestep) {
- tensors.resize(1);
-
-#if ENABLE_SHARING_WT_IDX
- const Tensor &ifgo = context.getTensor(wt_idx[0]);
- Tensor empty;
- const Tensor &ifgo_derivative =
- training ? context.getTensorGrad(wt_idx[0]) : empty;
- tensors[0] =
- Var_Grad(ifgo, ifgo_derivative, context.getTensorName(wt_idx[0]));
-#else
- const TensorDim &output_dim = context.getOutput(0).getDim();
- const unsigned int batch_size = output_dim.batch();
- const unsigned int unit = output_dim.width();
-
- Tensor &ifgo = context.getTensor(wt_idx[0]);
- const unsigned int NUM_GATE = ifgo.width() / unit;
- ifgo.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
- Tensor ifgo_t = ifgo.getBatchSlice(timestep, 1);
- ifgo_t.reshape({batch_size, 1, 1, NUM_GATE * unit});
- Tensor ifgo_derivative_t;
- if (training) {
- Tensor &ifgo_derivative = context.getTensorGrad(wt_idx[0]);
- ifgo_derivative.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
- ifgo_derivative_t = ifgo_derivative.getBatchSlice(timestep, 1);
- ifgo_derivative_t.reshape({batch_size, 1, 1, NUM_GATE * unit});
- }
- tensors[0] =
- Var_Grad(ifgo_t, ifgo_derivative_t, context.getTensorName(wt_idx[0]));
-#endif
-}
-
-const std::vector<Var_Grad *> getTensors(std::vector<Var_Grad> &tensors) {
- std::vector<Var_Grad *> ret(tensors.size());
- for (unsigned int i = 0; i < tensors.size(); ++i) {
- ret[i] = &tensors[i];
- }
- return ret;
-}
-
-} // namespace init_lstm_context
-
-enum INDEX {
- INPUT = 0,
- HIDDEN_STATE_IN = 1,
- CELL_STATE_IN = 2,
- HIDDEN_STATE_OUT = 0,
- CELL_STATE_OUT = 1
-};
-
-enum LSTMCellCoreParams {
- weight_ih,
- weight_hh,
- bias_h,
- bias_ih,
- bias_hh,
- ifgo,
-};
-
-LSTMCellCoreLayer::LSTMCellCoreLayer() :
- LayerImpl(),
- lstmcell_core_props(
- props::Unit(), props::HiddenStateActivation() = ActivationType::ACT_TANH,
- props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
- props::IntegrateBias()),
- acti_func(ActivationType::ACT_NONE, true),
- recurrent_acti_func(ActivationType::ACT_NONE, true) {
- wt_idx.fill(std::numeric_limits<unsigned>::max());
-}
-
-void LSTMCellCoreLayer::finalize(InitLayerContext &context) {
-#if ENBABLE_SHARING_WEIGHT
- const Tensor::Initializer weight_initializer =
- std::get<props::WeightInitializer>(*layer_impl_props).get();
- const Tensor::Initializer bias_initializer =
- std::get<props::BiasInitializer>(*layer_impl_props).get();
- const WeightRegularizer weight_regularizer =
- std::get<props::WeightRegularizer>(*layer_impl_props).get();
- const float weight_regularizer_constant =
- std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
- const bool disable_bias =
- std::get<props::DisableBias>(*layer_impl_props).get();
-#endif
-
- NNTR_THROW_IF(std::get<props::Unit>(lstmcell_core_props).empty(),
- std::invalid_argument)
- << "unit property missing for lstmcell_core layer";
- const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
- const ActivationType hidden_state_activation_type =
- std::get<props::HiddenStateActivation>(lstmcell_core_props).get();
- const ActivationType recurrent_activation_type =
- std::get<props::RecurrentActivation>(lstmcell_core_props).get();
-#if ENBABLE_SHARING_WEIGHT
- const bool integrate_bias =
- std::get<props::IntegrateBias>(lstmcell_core_props).get();
-#endif
-
- if (context.getNumInputs() != 3)
- throw std::invalid_argument("LSTMCellCore layer should takes 3 input");
-
- // input_dim = [ batch, 1, 1, feature_size ]
- const TensorDim &input_dim = context.getInputDimensions()[0];
- if (input_dim.height() != 1 || input_dim.channel() != 1)
- throw std::invalid_argument(
- "Input must be single time dimension for LSTMCellCore");
- // input_hidden_state_dim = [ batch, 1, 1, unit ]
- const TensorDim &input_hidden_state_dim = context.getInputDimensions()[1];
- if (input_hidden_state_dim.channel() != 1 ||
- input_hidden_state_dim.height() != 1)
- throw std::invalid_argument("Input hidden state's dimension should be "
- "[batch, 1, 1, unit] for LSTMCellCore");
- // input_cell_state_dim = [ batch, 1, 1, unit ]
- const TensorDim &input_cell_state_dim = context.getInputDimensions()[2];
- if (input_cell_state_dim.channel() != 1 || input_cell_state_dim.height() != 1)
- throw std::invalid_argument("Input cell state's dimension should be "
- "[batch, 1, 1, unit] for LSTMCellCore");
-
- const unsigned int batch_size = input_dim.batch();
-#if ENABLE_SHARING_WT_IDX
- const unsigned int feature_size = input_dim.width();
-#endif
-
- const TensorDim output_dim(batch_size, 1, 1, unit);
- const TensorDim output_hidden_state_dim = input_hidden_state_dim;
- const TensorDim output_cell_state_dim = input_cell_state_dim;
-
- context.setOutputDimensions(
- {output_dim, output_hidden_state_dim, output_cell_state_dim});
-
-#if ENABLE_SHARING_WT_IDX
- // weight_initializer can be set seperately. weight_ih initializer,
- // weight_hh initializer kernel initializer & recurrent_initializer in keras
- // for now, it is set same way.
-
- // - weight_ih ( input to hidden )
- // : [1, 1, feature_size, NUM_GATE x unit] -> i, f, g, o
- TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
- wt_idx[LSTMCellCoreParams::weight_ih] =
- context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
- weight_regularizer_constant, "weight_ih", true);
- // - weight_hh ( hidden to hidden )
- // : [1, 1, unit, NUM_GATE x unit] -> i, f, g, o
- TensorDim weight_hh_dim({unit, NUM_GATE * unit});
- wt_idx[LSTMCellCoreParams::weight_hh] =
- context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
- weight_regularizer_constant, "weight_hh", true);
- if (!disable_bias) {
- if (integrate_bias) {
- // - bias_h ( input bias, hidden bias are integrate to 1 bias )
- // : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
- TensorDim bias_h_dim({NUM_GATE * unit});
- wt_idx[LSTMCellCoreParams::bias_h] =
- context.requestWeight(bias_h_dim, bias_initializer,
- WeightRegularizer::NONE, 1.0f, "bias_h", true);
- } else {
- // - bias_ih ( input bias )
- // : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
- TensorDim bias_ih_dim({NUM_GATE * unit});
- wt_idx[LSTMCellCoreParams::bias_ih] =
- context.requestWeight(bias_ih_dim, bias_initializer,
- WeightRegularizer::NONE, 1.0f, "bias_ih", true);
- // - bias_hh ( hidden bias )
- // : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
- TensorDim bias_hh_dim({NUM_GATE * unit});
- wt_idx[LSTMCellCoreParams::bias_hh] =
- context.requestWeight(bias_hh_dim, bias_initializer,
- WeightRegularizer::NONE, 1.0f, "bias_hh", true);
- }
- }
-#endif
-
-#if ENABLE_SHARING_WT_IDX
- TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
- wt_idx[LSTMCellCoreParams::ifgo] =
- context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
- TensorLifespan::ITERATION_LIFESPAN);
-#endif
- acti_func.setActiFunc(hidden_state_activation_type);
- recurrent_acti_func.setActiFunc(recurrent_activation_type);
-}
-
-void LSTMCellCoreLayer::setProperty(const std::vector<std::string> &values) {
- std::vector<std::string> remain_props =
- loadProperties(values, lstmcell_core_props);
- LayerImpl::setProperty(remain_props);
-}
-
-void LSTMCellCoreLayer::exportTo(Exporter &exporter,
- const ExportMethods &method) const {
-#if ENABLE_SHARING_WT_IDX
- LayerImpl::exportTo(exporter, method);
-#endif
- exporter.saveResult(lstmcell_core_props, method, this);
-}
-
-void LSTMCellCoreLayer::forwarding(RunLayerContext &context, bool training) {
- const bool disable_bias =
- std::get<props::DisableBias>(*layer_impl_props).get();
-
- const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
- const bool integrate_bias =
- std::get<props::IntegrateBias>(lstmcell_core_props).get();
-
- const Tensor &input = context.getInput(INDEX::INPUT);
- const unsigned int batch_size = input.getDim().batch();
-
- const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN);
- const Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN);
- Tensor &next_hidden_state = context.getOutput(INDEX::HIDDEN_STATE_OUT);
- Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT);
-
-#if ENABLE_SHARING_WT_IDX
- const Tensor &weight_ih =
- context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]);
- const Tensor &weight_hh =
- context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]);
- Tensor empty;
- Tensor &bias_h = !disable_bias && integrate_bias
- ? context.getWeight(wt_idx[LSTMCellCoreParams::bias_h])
- : empty;
- Tensor &bias_ih = !disable_bias && !integrate_bias
- ? context.getWeight(wt_idx[LSTMCellCoreParams::bias_ih])
- : empty;
- Tensor &bias_hh = !disable_bias && !integrate_bias
- ? context.getWeight(wt_idx[LSTMCellCoreParams::bias_hh])
- : empty;
-#else
- const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih);
- const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh);
- Tensor empty;
- Tensor &bias_h = !disable_bias && integrate_bias
- ? context.getWeight(LSTMCellCoreParams::bias_h)
- : empty;
- // subtract index by 1 cause there is no bias_h
- Tensor &bias_ih = !disable_bias && !integrate_bias
- ? context.getWeight(LSTMCellCoreParams::bias_ih - 1)
- : empty;
- Tensor &bias_hh = !disable_bias && !integrate_bias
- ? context.getWeight(LSTMCellCoreParams::bias_hh - 1)
- : empty;
-#endif
-
-#if ENABLE_SHARING_WT_IDX
- Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]);
-#else
- Tensor &ifgo = context.getTensor(0);
-#endif
-
- input.dot(weight_ih, ifgo);
- prev_hidden_state.dot(weight_hh, ifgo, false, false, 1.0);
- if (!disable_bias) {
- if (integrate_bias) {
- ifgo.add_i(bias_h);
- } else {
- ifgo.add_i(bias_ih);
- ifgo.add_i(bias_hh);
- }
- }
-
- Tensor input_forget_gate =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
- Tensor input_gate =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
- Tensor forget_gate =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
- Tensor memory_cell =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
- Tensor output_gate =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
-
- recurrent_acti_func.run_fn(input_forget_gate, input_forget_gate);
- recurrent_acti_func.run_fn(output_gate, output_gate);
- acti_func.run_fn(memory_cell, memory_cell);
-
- forget_gate.multiply_strided(prev_cell_state, next_cell_state);
- memory_cell.multiply_strided(input_gate, next_cell_state, 1.0f);
-
- acti_func.run_fn(next_cell_state, next_hidden_state);
- next_hidden_state.multiply_i_strided(output_gate);
-}
-
-void LSTMCellCoreLayer::calcDerivative(RunLayerContext &context) {
-#if ENABLE_SHARING_WT_IDX
- Tensor &ifgo_derivative =
- context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]);
- const Tensor &weight_ih =
- context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]);
-#else
- const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih);
- Tensor &ifgo_derivative = context.getTensorGrad(0);
-#endif
- Tensor &outgoing_derivative = context.getOutgoingDerivative(INDEX::INPUT);
-
- ifgo_derivative.dot(weight_ih, outgoing_derivative, false, true);
-}
-
-void LSTMCellCoreLayer::calcGradient(RunLayerContext &context) {
- const bool disable_bias =
- std::get<props::DisableBias>(*layer_impl_props).get();
-
- const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
- const bool integrate_bias =
- std::get<props::IntegrateBias>(lstmcell_core_props).get();
-
- const Tensor &input = context.getInput(INDEX::INPUT);
- const unsigned int batch_size = input.getDim().batch();
-
-#if ENABLE_SHARING_WT_IDX
- Tensor &djdweight_ih =
- context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_ih]);
- const Tensor &weight_hh =
- context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]);
- Tensor &djdweight_hh =
- context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_hh]);
- Tensor empty;
- Tensor &djdbias_h =
- !disable_bias && integrate_bias
- ? context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_h])
- : empty;
- Tensor &djdbias_ih =
- !disable_bias && !integrate_bias
- ? context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_ih])
- : empty;
- Tensor &djdbias_hh =
- !disable_bias && !integrate_bias
- ? context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_hh])
- : empty;
-#else
- Tensor &djdweight_ih = context.getWeightGrad(LSTMCellCoreParams::weight_ih);
- const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh);
- Tensor &djdweight_hh = context.getWeightGrad(LSTMCellCoreParams::weight_hh);
- Tensor empty;
- Tensor &djdbias_h = !disable_bias && integrate_bias
- ? context.getWeightGrad(LSTMCellCoreParams::bias_h)
- : empty;
- // subtract index by 1 cause there is no bias_h(and also djdbias_h)
- Tensor &djdbias_ih =
- !disable_bias && !integrate_bias
- ? context.getWeightGrad(LSTMCellCoreParams::bias_ih - 1)
- : empty;
- Tensor &djdbias_hh =
- !disable_bias && !integrate_bias
- ? context.getWeightGrad(LSTMCellCoreParams::bias_hh - 1)
- : empty;
-#endif
-
- const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN);
- Tensor &prev_hidden_state_derivative =
- context.getOutgoingDerivative(INDEX::HIDDEN_STATE_IN);
- Tensor &next_hidden_state_derivative =
- context.getIncomingDerivative(INDEX::HIDDEN_STATE_OUT);
-
- Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN);
- Tensor &prev_cell_state_derivative =
- context.getOutgoingDerivative(INDEX::CELL_STATE_IN);
- Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT);
- Tensor &next_cell_state_derivative =
- context.getIncomingDerivative(INDEX::CELL_STATE_OUT);
-
-#if ENABLE_SHARING_WT_IDX
- Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]);
- Tensor &ifgo_derivative =
- context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]);
-#else
- Tensor &ifgo = context.getTensor(0);
- Tensor &ifgo_derivative = context.getTensorGrad(0);
-#endif
-
- Tensor input_forget_gate =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
- Tensor input_gate =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
- Tensor forget_gate =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
- Tensor memory_cell =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
- Tensor output_gate =
- ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
-
- Tensor input_forget_gate_derivative =
- ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
- Tensor input_gate_derivative =
- ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
- Tensor forget_gate_derivative =
- ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
- Tensor memory_cell_derivative = ifgo_derivative.getSharedDataTensor(
- {batch_size, 1, 1, unit}, unit * 2, false);
- Tensor output_gate_derivative = ifgo_derivative.getSharedDataTensor(
- {batch_size, 1, 1, unit}, unit * 3, false);
-
- acti_func.run_fn(next_cell_state, next_cell_state);
- next_hidden_state_derivative.multiply_strided(next_cell_state,
- output_gate_derivative);
-
- acti_func.run_prime_fn(next_cell_state, prev_cell_state_derivative,
- next_hidden_state_derivative);
- prev_cell_state_derivative.multiply_i_strided(output_gate);
- prev_cell_state_derivative.add_i(next_cell_state_derivative);
-
- prev_cell_state_derivative.multiply_strided(input_gate,
- memory_cell_derivative);
- prev_cell_state_derivative.multiply_strided(memory_cell,
- input_gate_derivative);
-
- prev_cell_state_derivative.multiply_strided(prev_cell_state,
- forget_gate_derivative);
- prev_cell_state_derivative.multiply_i_strided(forget_gate);
-
- recurrent_acti_func.run_prime_fn(output_gate, output_gate_derivative,
- output_gate_derivative);
- recurrent_acti_func.run_prime_fn(input_forget_gate,
- input_forget_gate_derivative,
- input_forget_gate_derivative);
- acti_func.run_prime_fn(memory_cell, memory_cell_derivative,
- memory_cell_derivative);
-
- if (!disable_bias) {
- if (integrate_bias) {
- ifgo_derivative.sum(0, djdbias_h, 1.0f, 1.0f);
- } else {
- ifgo_derivative.sum(0, djdbias_ih, 1.0f, 1.0f);
- ifgo_derivative.sum(0, djdbias_hh, 1.0f, 1.0f);
- }
- }
- input.dot(ifgo_derivative, djdweight_ih, true, false, 1.0f);
- prev_hidden_state.dot(ifgo_derivative, djdweight_hh, true, false, 1.0f);
- ifgo_derivative.dot(weight_hh, prev_hidden_state_derivative, false, true);
-}
-
-void LSTMCellCoreLayer::setBatch(RunLayerContext &context, unsigned int batch) {
- context.updateTensor(wt_idx[LSTMCellCoreParams::ifgo], batch);
-}
-
void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
const bool disable_bias, const bool integrate_bias,
ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
hidden_state,
cell_state,
ifgo,
+ lstm_cell_state,
hidden_state_zoneout_mask,
cell_state_zoneout_mask,
};
-unsigned int hidden_state_origin_idx = 0, cell_state_origin_idx = 0;
-
-const std::vector<unsigned int>
-getWeightIdx(std::array<unsigned int, 10> &wt_idx, const bool disable_bias,
- const bool integrate_bias, const bool test) {
- std::vector<unsigned int> ret;
- ret.push_back(wt_idx[ZoneoutLSTMParams::weight_ih]);
- ret.push_back(wt_idx[ZoneoutLSTMParams::weight_hh]);
- if (!disable_bias) {
- if (integrate_bias) {
- ret.push_back(wt_idx[ZoneoutLSTMParams::bias_h]);
- } else {
- ret.push_back(wt_idx[ZoneoutLSTMParams::bias_ih]);
- ret.push_back(wt_idx[ZoneoutLSTMParams::bias_hh]);
- }
- }
- if (test) {
- ret.push_back(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
- ret.push_back(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
- }
- return ret;
-}
-
-const std::vector<unsigned int>
-getInputIdx(std::array<unsigned int, 10> &wt_idx) {
- std::vector<unsigned int> ret(3);
- ret[0] = SINGLE_INOUT_IDX;
- ret[1] = wt_idx[ZoneoutLSTMParams::hidden_state];
- ret[2] = wt_idx[ZoneoutLSTMParams::cell_state];
- return ret;
-}
-
-const std::vector<unsigned int>
-getOutputIdx(std::array<unsigned int, 10> &wt_idx) {
- std::vector<unsigned int> ret(3);
- ret[0] = SINGLE_INOUT_IDX;
- ret[1] = hidden_state_origin_idx;
- ret[2] = cell_state_origin_idx;
- return ret;
-}
-
-const std::vector<unsigned int>
-getTensorIdx(std::array<unsigned int, 10> &wt_idx) {
- std::vector<unsigned int> ret(1);
- ret[0] = wt_idx[ZoneoutLSTMParams::ifgo];
- return ret;
-}
-
ZoneoutLSTMCellLayer::ZoneoutLSTMCellLayer() :
LayerImpl(),
- zoneout_lstmcell_props(props::Unit(), HiddenStateZoneOutRate(),
- CellStateZoneOutRate(), props::IntegrateBias(), Test(),
- props::MaxTimestep(), props::Timestep()),
+ zoneout_lstmcell_props(
+ props::Unit(), props::IntegrateBias(),
+ props::HiddenStateActivation() = ActivationType::ACT_TANH,
+ props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
+ HiddenStateZoneOutRate(), CellStateZoneOutRate(), Test(),
+ props::MaxTimestep(), props::Timestep()),
+ acti_func(ActivationType::ACT_NONE, true),
+ recurrent_acti_func(ActivationType::ACT_NONE, true),
epsilon(1e-3) {
wt_idx.fill(std::numeric_limits<unsigned>::max());
}
}
void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
-#if !ENABLE_SHARING_WT_IDX
const Tensor::Initializer weight_initializer =
std::get<props::WeightInitializer>(*layer_impl_props).get();
const Tensor::Initializer bias_initializer =
std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
const bool disable_bias =
std::get<props::DisableBias>(*layer_impl_props).get();
-#endif
NNTR_THROW_IF(std::get<props::Unit>(zoneout_lstmcell_props).empty(),
std::invalid_argument)
const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
const bool integrate_bias =
std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
+ const ActivationType hidden_state_activation_type =
+ std::get<props::HiddenStateActivation>(zoneout_lstmcell_props).get();
+ const ActivationType recurrent_activation_type =
+ std::get<props::RecurrentActivation>(zoneout_lstmcell_props).get();
const bool test = std::get<Test>(zoneout_lstmcell_props).get();
const unsigned int max_timestep =
std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
// input_dim = [ batch_size, 1, 1, feature_size ]
const TensorDim &input_dim = context.getInputDimensions()[0];
- if (input_dim.height() != 1 || input_dim.channel() != 1)
+ if (input_dim.channel() != 1 || input_dim.height() != 1)
throw std::invalid_argument("Input must be single time dimension for "
"ZoneoutLSTMCell (shape should be "
"[batch_size, 1, 1, feature_size])");
const unsigned int batch_size = input_dim.batch();
-#if !ENABLE_SHARING_WT_IDX
const unsigned int feature_size = input_dim.width();
-#endif
// output_dim = [ batch_size, 1, 1, unit ]
const TensorDim output_dim(batch_size, 1, 1, unit);
context.setOutputDimensions({output_dim});
-#if !ENABLE_SHARING_WT_IDX
// weight_initializer can be set seperately. weight_ih initializer,
// weight_hh initializer kernel initializer & recurrent_initializer in keras
// for now, it is set same way.
WeightRegularizer::NONE, 1.0f, "bias_hh", true);
}
}
-#endif
/**
* TODO: hidden_state is only used from the previous timestep. Once it is
cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN, false);
- hidden_state_origin_idx = context.requestTensor(
- hidden_state_dim, "hidden_state_origin", Tensor::Initializer::NONE, true,
- TensorLifespan::ITERATION_LIFESPAN, false);
- cell_state_origin_idx = context.requestTensor(
- cell_state_dim, "cell_state_origin", Tensor::Initializer::NONE, true,
- TensorLifespan::ITERATION_LIFESPAN, false);
-
-#if !ENABLE_SHARING_WT_IDX
- /**
- * TODO: make this independent of time dimension once recurrent realizer
- * supports requesting tensors which are not always shared
- */
- /** ifgo_dim = [ max_timestep * batch_size, 1, 1, NUM_GATE * unit ] */
- const TensorDim ifgo_dim(max_timestep * batch_size, 1, 1, NUM_GATE * unit);
+ /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit ] */
+ const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
wt_idx[ZoneoutLSTMParams::ifgo] =
context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
- TensorLifespan::ITERATION_LIFESPAN, false);
-#endif
+ TensorLifespan::ITERATION_LIFESPAN);
+
+ /** lstm_cell_state_dim = [ batch_size, 1, 1, unit ] */
+ const TensorDim lstm_cell_state_dim(batch_size, 1, 1, unit);
+ wt_idx[ZoneoutLSTMParams::lstm_cell_state] = context.requestTensor(
+ lstm_cell_state_dim, "lstm_cell_state", Tensor::Initializer::NONE, true,
+ TensorLifespan::ITERATION_LIFESPAN);
// hidden_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
const TensorDim hidden_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
}
- TensorDim hidden_state_t_dim({batch_size, 1, 1, unit});
- TensorDim cell_state_t_dim({batch_size, 1, 1, unit});
- InitLayerContext core_context(
- {input_dim, hidden_state_t_dim, cell_state_t_dim}, 3,
- context.executeInPlace(), context.getName());
- lstmcellcorelayer.finalize(core_context);
- init_lstm_context::fillLayerInitContext(context, core_context);
+ acti_func.setActiFunc(hidden_state_activation_type);
+ recurrent_acti_func.setActiFunc(recurrent_activation_type);
}
void ZoneoutLSTMCellLayer::setProperty(const std::vector<std::string> &values) {
- std::vector<std::string> remain_props =
+ const std::vector<std::string> &remain_props =
loadProperties(values, zoneout_lstmcell_props);
-
- // Note: In current implementation the lstmcellcorelayer also has
- // a properties related to weight. But it is not exported or used anywhere.
- lstmcellcorelayer.setProperty(remain_props);
- if (!std::get<props::Unit>(zoneout_lstmcell_props).empty()) {
- lstmcellcorelayer.setProperty(
- {"unit=" + to_string(std::get<props::Unit>(zoneout_lstmcell_props))});
- }
- lstmcellcorelayer.setProperty(
- {"integrate_bias=" +
- to_string(std::get<props::IntegrateBias>(zoneout_lstmcell_props))});
-
-#if !ENABLE_SHARING_WT_IDX
- // To remove lstmcell core layer's properties
- std::tuple<props::HiddenStateActivation, props::RecurrentActivation>
- lstmcell_core_props;
- std::vector<std::string> impl_props =
- loadProperties(remain_props, lstmcell_core_props);
-
- LayerImpl::setProperty(impl_props);
-#endif
+ LayerImpl::setProperty(remain_props);
}
void ZoneoutLSTMCellLayer::exportTo(Exporter &exporter,
const ExportMethods &method) const {
-#if !ENABLE_SHARING_WT_IDX
LayerImpl::exportTo(exporter, method);
-#endif
- exporter.saveResult(
- std::forward_as_tuple(
- std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props),
- std::get<CellStateZoneOutRate>(zoneout_lstmcell_props),
- std::get<Test>(zoneout_lstmcell_props),
- std::get<props::MaxTimestep>(zoneout_lstmcell_props),
- std::get<props::Timestep>(zoneout_lstmcell_props)),
- method, this);
- lstmcellcorelayer.exportTo(exporter, method);
+ exporter.saveResult(zoneout_lstmcell_props, method, this);
}
void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
std::get<props::DisableBias>(*layer_impl_props).get();
const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
+ const bool integrate_bias =
+ std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
const float hidden_state_zoneout_rate =
std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props).get();
const float cell_state_zoneout_rate =
std::get<CellStateZoneOutRate>(zoneout_lstmcell_props).get();
- const bool integrate_bias =
- std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
const bool test = std::get<Test>(zoneout_lstmcell_props).get();
const unsigned int max_timestep =
std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
const unsigned int timestep =
std::get<props::Timestep>(zoneout_lstmcell_props).get();
- const unsigned int batch_size =
- context.getInput(SINGLE_INOUT_IDX).getDim().batch();
-
+ const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
- Tensor &hidden_state =
- context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]);
- hidden_state.reshape({max_timestep, 1, batch_size, unit});
+ const unsigned int batch_size = input.getDim().batch();
+
+ const Tensor &weight_ih =
+ context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
+ const Tensor &weight_hh =
+ context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]);
+ Tensor empty;
+ Tensor &bias_h = !disable_bias && integrate_bias
+ ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_h])
+ : empty;
+ Tensor &bias_ih = !disable_bias && !integrate_bias
+ ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_ih])
+ : empty;
+ Tensor &bias_hh = !disable_bias && !integrate_bias
+ ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_hh])
+ : empty;
+
+ Tensor &hs = context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]);
+ hs.reshape({max_timestep, 1, batch_size, unit});
Tensor prev_hidden_state;
if (!timestep) {
prev_hidden_state = Tensor(batch_size, 1, 1, unit);
prev_hidden_state.setZero();
} else {
- prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1);
+ prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
prev_hidden_state.reshape({batch_size, 1, 1, unit});
}
- Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1);
- next_hidden_state.reshape({batch_size, 1, 1, unit});
+ Tensor hidden_state = hs.getBatchSlice(timestep, 1);
+ hidden_state.reshape({batch_size, 1, 1, unit});
- Tensor &cell_state = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]);
- cell_state.reshape({max_timestep, 1, batch_size, unit});
+ Tensor &cs = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]);
+ cs.reshape({max_timestep, 1, batch_size, unit});
Tensor prev_cell_state;
if (!timestep) {
prev_cell_state = Tensor(batch_size, 1, 1, unit);
prev_cell_state.setZero();
} else {
- prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1);
+ prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
prev_cell_state.reshape({batch_size, 1, 1, unit});
}
- Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1);
- next_cell_state.reshape({batch_size, 1, 1, unit});
+ Tensor cell_state = cs.getBatchSlice(timestep, 1);
+ cell_state.reshape({batch_size, 1, 1, unit});
- if (!timestep) {
- hidden_state.setZero();
- cell_state.setZero();
- }
+ Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
- init_lstm_context::fillWeights(
- weights, context, training,
- getWeightIdx(wt_idx, disable_bias, integrate_bias, test), max_timestep,
- timestep, test);
- init_lstm_context::fillInputs(inputs, context, training, getInputIdx(wt_idx),
- max_timestep, timestep);
- init_lstm_context::fillOutputs(outputs, context, training,
- getOutputIdx(wt_idx), max_timestep, timestep);
- init_lstm_context::fillTensors(tensors, context, training,
- getTensorIdx(wt_idx), max_timestep, timestep);
- RunLayerContext core_context(context.getName(), context.getTrainable(),
- context.getLoss(), context.executeInPlace(),
- init_lstm_context::getWeights(weights),
- init_lstm_context::getInputs(inputs),
- init_lstm_context::getOutputs(outputs),
- init_lstm_context::getTensors(tensors));
- lstmcellcorelayer.forwarding(core_context, training);
+ Tensor &lstm_cell_state =
+ context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
+
+ lstmcell_forwarding(unit, batch_size, disable_bias, integrate_bias, acti_func,
+ recurrent_acti_func, input, prev_hidden_state,
+ prev_cell_state, hidden_state, lstm_cell_state, weight_ih,
+ weight_hh, bias_h, bias_ih, bias_hh, ifgo);
if (training) {
- Tensor &hidden_state_zoneout_mask =
+ Tensor &hs_zoneout_mask =
test ? context.getWeight(
wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
: context.getTensor(
wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
- hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state_zoneout_mask =
- hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
- next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+ hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor hidden_state_zoneout_mask =
+ hs_zoneout_mask.getBatchSlice(timestep, 1);
+ hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
Tensor prev_hidden_state_zoneout_mask;
if (!test) {
prev_hidden_state_zoneout_mask =
- next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
+ hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
} else {
- next_hidden_state_zoneout_mask.multiply(-1.0f,
- prev_hidden_state_zoneout_mask);
+ hidden_state_zoneout_mask.multiply(-1.0f, prev_hidden_state_zoneout_mask);
prev_hidden_state_zoneout_mask.add_i(1.0f);
}
- Tensor &hidden_state_origin = context.getTensor(hidden_state_origin_idx);
- hidden_state_origin.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state_origin =
- hidden_state_origin.getBatchSlice(timestep, 1);
- next_hidden_state_origin.reshape({batch_size, 1, 1, unit});
+ hidden_state.multiply_i(hidden_state_zoneout_mask);
+ prev_hidden_state.multiply(prev_hidden_state_zoneout_mask, hidden_state,
+ 1.0f);
- next_hidden_state_origin.multiply(next_hidden_state_zoneout_mask,
- next_hidden_state);
- prev_hidden_state.multiply(prev_hidden_state_zoneout_mask,
- next_hidden_state, 1.0f);
- }
-
- if (training) {
- Tensor &cell_state_zoneout_mask =
+ Tensor &cs_zoneout_mask =
test
? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
: context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
- cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state_zoneout_mask =
- cell_state_zoneout_mask.getBatchSlice(timestep, 1);
- next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+ cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1);
+ cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
Tensor prev_cell_state_zoneout_mask;
if (!test) {
prev_cell_state_zoneout_mask =
- next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+ cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
} else {
- next_cell_state_zoneout_mask.multiply(-1.0f,
- prev_cell_state_zoneout_mask);
+ cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
prev_cell_state_zoneout_mask.add_i(1.0f);
}
- Tensor &cell_state_origin = context.getTensor(cell_state_origin_idx);
- cell_state_origin.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state_origin =
- cell_state_origin.getBatchSlice(timestep, 1);
- next_cell_state_origin.reshape({batch_size, 1, 1, unit});
-
- next_cell_state_origin.multiply(next_cell_state_zoneout_mask,
- next_cell_state);
- prev_cell_state.multiply(prev_cell_state_zoneout_mask, next_cell_state,
- 1.0f);
+ lstm_cell_state.multiply(cell_state_zoneout_mask, cell_state);
+ prev_cell_state.multiply(prev_cell_state_zoneout_mask, cell_state, 1.0f);
}
// Todo: zoneout at inference
- output.copyData(next_hidden_state);
+ output.copyData(hidden_state);
}
void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
- const bool disable_bias =
- std::get<props::DisableBias>(*layer_impl_props).get();
+ Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
+ const Tensor &weight_ih =
+ context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
+ Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
- const bool integrate_bias =
- std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
- const bool test = std::get<Test>(zoneout_lstmcell_props).get();
- const unsigned int max_timestep =
- std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
- const unsigned int timestep =
- std::get<props::Timestep>(zoneout_lstmcell_props).get();
-
- init_lstm_context::fillWeights(
- weights, context, true,
- getWeightIdx(wt_idx, disable_bias, integrate_bias, test), max_timestep,
- timestep, test);
- init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx),
- max_timestep, timestep);
- init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx),
- max_timestep, timestep);
- init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
- max_timestep, timestep);
- RunLayerContext core_context(context.getName(), context.getTrainable(),
- context.getLoss(), context.executeInPlace(),
- init_lstm_context::getWeights(weights),
- init_lstm_context::getInputs(inputs),
- init_lstm_context::getOutputs(outputs),
- init_lstm_context::getTensors(tensors));
- lstmcellcorelayer.calcDerivative(core_context);
+ lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
}
void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
std::get<props::DisableBias>(*layer_impl_props).get();
const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
- const float hidden_state_zoneout_rate =
- std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
- const float cell_state_zoneout_rate =
- std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
const bool integrate_bias =
std::get<props::IntegrateBias>(zoneout_lstmcell_props).get();
- const bool test = std::get<Test>(zoneout_lstmcell_props);
+ const float hidden_state_zoneout_rate =
+ std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props).get();
+ const float cell_state_zoneout_rate =
+ std::get<CellStateZoneOutRate>(zoneout_lstmcell_props).get();
+ const bool test = std::get<Test>(zoneout_lstmcell_props).get();
const unsigned int max_timestep =
std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
const unsigned int timestep =
std::get<props::Timestep>(zoneout_lstmcell_props).get();
- unsigned int batch_size = context.getInput(SINGLE_INOUT_IDX).getDim().batch();
-
+ const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
const Tensor &incoming_derivative =
context.getIncomingDerivative(SINGLE_INOUT_IDX);
- Tensor &hidden_state_derivative =
- context.getTensorGrad(wt_idx[ZoneoutLSTMParams::hidden_state]);
- hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state_derivative =
- hidden_state_derivative.getBatchSlice(timestep, 1);
- next_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+ unsigned int batch_size = input.getDim().batch();
+
+ Tensor &d_weight_ih =
+ context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]);
+ const Tensor &weight_hh =
+ context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]);
+ Tensor &d_weight_hh =
+ context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]);
+ Tensor empty;
+ Tensor &d_bias_h =
+ !disable_bias && integrate_bias
+ ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_h])
+ : empty;
+ Tensor &d_bias_ih =
+ !disable_bias && !integrate_bias
+ ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih])
+ : empty;
+ Tensor &d_bias_hh =
+ !disable_bias && !integrate_bias
+ ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_hh])
+ : empty;
+
+ Tensor &hs = context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state]);
+ hs.reshape({max_timestep, 1, batch_size, unit});
+ Tensor prev_hidden_state;
+ if (!timestep) {
+ prev_hidden_state = Tensor(batch_size, 1, 1, unit);
+ prev_hidden_state.setZero();
+ } else {
+ prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
+ prev_hidden_state.reshape({batch_size, 1, 1, unit});
+ }
+
+ Tensor &d_hs = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::hidden_state]);
+ d_hs.reshape({max_timestep, 1, batch_size, unit});
+ Tensor d_prev_hidden_state;
+ if (!timestep) {
+ d_prev_hidden_state = Tensor(batch_size, 1, 1, unit);
+ d_prev_hidden_state.setZero();
+ } else {
+ d_prev_hidden_state = d_hs.getBatchSlice(timestep - 1, 1);
+ d_prev_hidden_state.reshape({batch_size, 1, 1, unit});
+ }
+ Tensor d_hidden_state = d_hs.getBatchSlice(timestep, 1);
+ d_hidden_state.reshape({batch_size, 1, 1, unit});
+
+ Tensor &cs = context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state]);
+ cs.reshape({max_timestep, 1, batch_size, unit});
+ Tensor prev_cell_state;
+ if (!timestep) {
+ prev_cell_state = Tensor(batch_size, 1, 1, unit);
+ prev_cell_state.setZero();
+ } else {
+ prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
+ prev_cell_state.reshape({batch_size, 1, 1, unit});
+ }
+ Tensor cell_state = cs.getBatchSlice(timestep, 1);
+ cell_state.reshape({batch_size, 1, 1, unit});
- Tensor &cell_state_derivative =
- context.getTensorGrad(wt_idx[ZoneoutLSTMParams::cell_state]);
- cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state_derivative =
- cell_state_derivative.getBatchSlice(timestep, 1);
- next_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+ Tensor &d_cs = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::cell_state]);
+ d_cs.reshape({max_timestep, 1, batch_size, unit});
+ Tensor d_prev_cell_state;
+ if (!timestep) {
+ d_prev_cell_state = Tensor(batch_size, 1, 1, unit);
+ d_prev_cell_state.setZero();
+ } else {
+ d_prev_cell_state = d_cs.getBatchSlice(timestep - 1, 1);
+ d_prev_cell_state.reshape({batch_size, 1, 1, unit});
+ }
+ Tensor d_cell_state = d_cs.getBatchSlice(timestep, 1);
+ d_cell_state.reshape({batch_size, 1, 1, unit});
+
+ Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
+ Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
+
+ const Tensor &lstm_cell_state =
+ context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
+ Tensor &d_lstm_cell_state =
+ context.getTensorGrad(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
if (timestep + 1 == max_timestep) {
- Tensor &djdweight_ih =
- context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]);
- Tensor &djdweight_hh =
- context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]);
- djdweight_ih.setZero();
- djdweight_hh.setZero();
+ d_weight_ih.setZero();
+ d_weight_hh.setZero();
if (!disable_bias) {
if (integrate_bias) {
- Tensor &djdbias_h =
- context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_h]);
- djdbias_h.setZero();
+ d_bias_h.setZero();
} else {
- Tensor &djdbias_ih =
- context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih]);
- djdbias_ih.setZero();
- Tensor &djdbias_hh =
- context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_hh]);
- djdbias_hh.setZero();
+ d_bias_ih.setZero();
+ d_bias_hh.setZero();
}
}
-
- hidden_state_derivative.setZero();
- cell_state_derivative.setZero();
+ d_hidden_state.setZero();
+ d_cell_state.setZero();
}
- next_hidden_state_derivative.add_i(incoming_derivative);
+ d_hidden_state.add_i(incoming_derivative);
- Tensor prev_hidden_state_derivative;
- Tensor prev_cell_state_derivative;
- Tensor prev_hidden_state_derivative_residual;
- Tensor prev_cell_state_derivative_residual;
+ Tensor d_prev_hidden_state_residual;
- Tensor &hidden_state_zoneout_mask =
+ Tensor &hs_zoneout_mask =
test
? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
: context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
- hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state_zoneout_mask =
- hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
- next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+ hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor hidden_state_zoneout_mask = hs_zoneout_mask.getBatchSlice(timestep, 1);
+ hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
Tensor prev_hidden_state_zoneout_mask;
if (!test) {
prev_hidden_state_zoneout_mask =
- next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
+ hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
} else {
- next_hidden_state_zoneout_mask.multiply(-1.0f,
- prev_hidden_state_zoneout_mask);
+ hidden_state_zoneout_mask.multiply(-1.0f, prev_hidden_state_zoneout_mask);
prev_hidden_state_zoneout_mask.add_i(1.0f);
}
- if (timestep) {
- prev_hidden_state_derivative =
- hidden_state_derivative.getBatchSlice(timestep - 1, 1);
- prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
- next_hidden_state_derivative.multiply(
- prev_hidden_state_zoneout_mask, prev_hidden_state_derivative_residual);
- }
-
- Tensor &hidden_state_origin_derivative =
- context.getTensorGrad(hidden_state_origin_idx);
- hidden_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state_origin_derivative =
- hidden_state_origin_derivative.getBatchSlice(timestep, 1);
- next_hidden_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+ d_hidden_state.multiply(prev_hidden_state_zoneout_mask,
+ d_prev_hidden_state_residual);
+ d_hidden_state.multiply_i(hidden_state_zoneout_mask);
- next_hidden_state_derivative.multiply(next_hidden_state_zoneout_mask,
- next_hidden_state_origin_derivative);
+ Tensor d_prev_cell_state_residual;
- Tensor &cell_state_zoneout_mask =
+ Tensor &cs_zoneout_mask =
test
? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
: context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
- cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state_zoneout_mask =
- cell_state_zoneout_mask.getBatchSlice(timestep, 1);
- next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+ cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1);
+ cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
Tensor prev_cell_state_zoneout_mask;
if (!test) {
prev_cell_state_zoneout_mask =
- next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+ cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
} else {
- next_cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
+ cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
prev_cell_state_zoneout_mask.add_i(1.0f);
}
- if (timestep) {
- prev_cell_state_derivative =
- cell_state_derivative.getBatchSlice(timestep - 1, 1);
- prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
- next_cell_state_derivative.multiply(prev_cell_state_zoneout_mask,
- prev_cell_state_derivative_residual);
- }
+ d_cell_state.multiply(prev_cell_state_zoneout_mask,
+ d_prev_cell_state_residual);
+ d_cell_state.multiply(cell_state_zoneout_mask, d_lstm_cell_state);
- Tensor &cell_state_origin_derivative =
- context.getTensorGrad(cell_state_origin_idx);
- cell_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state_origin_derivative =
- cell_state_origin_derivative.getBatchSlice(timestep, 1);
- next_cell_state_origin_derivative.reshape({batch_size, 1, 1, unit});
-
- next_cell_state_derivative.multiply(next_cell_state_zoneout_mask,
- next_cell_state_origin_derivative);
-
- init_lstm_context::fillWeights(
- weights, context, true,
- getWeightIdx(wt_idx, disable_bias, integrate_bias, test), max_timestep,
- timestep, test);
- init_lstm_context::fillInputs(inputs, context, true, getInputIdx(wt_idx),
- max_timestep, timestep);
- init_lstm_context::fillOutputs(outputs, context, true, getOutputIdx(wt_idx),
- max_timestep, timestep);
- init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
- max_timestep, timestep);
- RunLayerContext core_context(context.getName(), context.getTrainable(),
- context.getLoss(), context.executeInPlace(),
- init_lstm_context::getWeights(weights),
- init_lstm_context::getInputs(inputs),
- init_lstm_context::getOutputs(outputs),
- init_lstm_context::getTensors(tensors));
- lstmcellcorelayer.calcGradient(core_context);
-
- if (timestep) {
- prev_hidden_state_derivative.add_i(prev_hidden_state_derivative_residual);
- prev_cell_state_derivative.add_i(prev_cell_state_derivative_residual);
- }
+ lstmcell_calcGradient(unit, batch_size, disable_bias, integrate_bias,
+ acti_func, recurrent_acti_func, input,
+ prev_hidden_state, d_prev_hidden_state, prev_cell_state,
+ d_prev_cell_state, d_hidden_state, lstm_cell_state,
+ d_lstm_cell_state, d_weight_ih, weight_hh, d_weight_hh,
+ d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
+
+ d_prev_hidden_state.add_i(d_prev_hidden_state_residual);
+ d_prev_cell_state.add_i(d_prev_cell_state_residual);
}
void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
unsigned int batch) {
const unsigned int max_timestep =
std::get<props::MaxTimestep>(zoneout_lstmcell_props);
+
context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state],
max_timestep * batch);
context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state],
max_timestep * batch);
- context.updateTensor(hidden_state_origin_idx, max_timestep * batch);
- context.updateTensor(cell_state_origin_idx, max_timestep * batch);
- context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], max_timestep * batch);
+ context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], batch);
+ context.updateTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state], batch);
context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask],
max_timestep * batch);