*
*/
-#include <cmath>
#include <layer_context.h>
#include <lstmcell.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
-#include <util_func.h>
namespace nntrainer {
static constexpr size_t SINGLE_INOUT_IDX = 0;
-enum LSTMParams {
- weight_xh,
+enum LSTMCellParams {
+ weight_ih,
weight_hh,
- bias_h,
+ bias_ih,
hidden_state,
- mem_cell,
- fgio,
+ cell_state,
+ ifgo,
dropout_mask
};
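+// Helper to collect the indices handed to the lstmcell core context:
+// the single input/output slot plus the cached hidden and cell state tensors.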
+const std::vector<unsigned int>
+getInOutIdx(std::array<unsigned int, 7> &wt_idx) {
+ std::vector<unsigned int> ret(3);
+ ret[0] = SINGLE_INOUT_IDX;
+ ret[1] = wt_idx[LSTMCellParams::hidden_state];
+ ret[2] = wt_idx[LSTMCellParams::cell_state];
+ return ret;
+}
+
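+// Helper to collect the extra tensor indices (currently only the ifgo gate
+// buffer) handed to the lstmcell core context.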
+const std::vector<unsigned int>
+getTensorIdx(std::array<unsigned int, 7> &wt_idx) {
+ std::vector<unsigned int> ret(1);
+ ret[0] = wt_idx[LSTMCellParams::ifgo];
+ return ret;
+}
+
LSTMCellLayer::LSTMCellLayer() :
LayerImpl(),
- lstm_props(props::Unit(), props::HiddenStateActivation(),
- props::RecurrentActivation(), props::DropOutRate(),
- props::MaxTimestep(), props::Timestep()),
+ lstmcell_props(props::Unit(), props::DropOutRate(), props::MaxTimestep(),
+ props::Timestep()),
wt_idx({0}),
- acti_func(ActivationType::ACT_NONE, true),
- recurrent_acti_func(ActivationType::ACT_NONE, true),
epsilon(1e-3) {}
-// - weight_xh ( input to hidden )
-// : [1, 1, input_size, unit (hidden_size) x NUM_GATE] -> f, g, i, o
-// - weight_hh ( hidden to hidden )
-// : [1, 1, unit (hidden_size) , unit (hidden_size) x NUM_GATE] -> f, g, i, o
-// - bias_h ( hidden bias )
-// : [1, 1, 1, unit (hidden_size) x NUM_GATE] -> f, g, i, o
void LSTMCellLayer::finalize(InitLayerContext &context) {
- auto &weight_regularizer =
+ NNTR_THROW_IF(std::get<props::Unit>(lstmcell_props).empty(),
+ std::invalid_argument)
+ << "unit property missing for lstmcell layer";
+ const unsigned int unit = std::get<props::Unit>(lstmcell_props).get();
+ const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
+ const unsigned int max_timestep =
+ std::get<props::MaxTimestep>(lstmcell_props);
+
+#if !ENABLE_SHARING_WT_IDX
+ const Tensor::Initializer weight_initializer =
+ std::get<props::WeightInitializer>(*layer_impl_props);
+ const Tensor::Initializer bias_initializer =
+ std::get<props::BiasInitializer>(*layer_impl_props);
+ const nntrainer::WeightRegularizer weight_regularizer =
std::get<props::WeightRegularizer>(*layer_impl_props);
- auto &weight_regularizer_constant =
+ const float weight_regularizer_constant =
std::get<props::WeightRegularizerConstant>(*layer_impl_props);
- auto &weight_initializer =
- std::get<props::WeightInitializer>(*layer_impl_props);
- auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
-
- NNTR_THROW_IF(std::get<props::Unit>(lstm_props).empty(),
- std::invalid_argument)
- << "unit property missing for lstm layer";
- auto unit = std::get<props::Unit>(lstm_props).get();
- auto &hidden_state_activation_type =
- std::get<props::HiddenStateActivation>(lstm_props);
- auto &recurrent_activation_type =
- std::get<props::RecurrentActivation>(lstm_props);
- float dropout_rate = std::get<props::DropOutRate>(lstm_props);
+#endif
if (context.getNumInputs() != 1)
- throw std::invalid_argument("LSTM layer takes only one input");
- if (std::get<props::MaxTimestep>(lstm_props).empty())
+ throw std::invalid_argument("LSTMCell layer takes only one input");
+ if (std::get<props::MaxTimestep>(lstmcell_props).empty())
throw std::invalid_argument(
- "Number of unroll steps must be provided to LSTM cells");
- if (std::get<props::Timestep>(lstm_props).empty())
+ "Number of unroll steps(max timestep) must be provided to LSTM cell");
+ if (std::get<props::Timestep>(lstmcell_props).empty())
throw std::invalid_argument(
"Current Timestep must be provided to LSTM cell");
- // input_dim = [ batch, 1, 1, feature_size ]
- TensorDim output_dim;
+ // input_dim = [ batch_size, 1, 1, feature_size ]
const TensorDim &input_dim = context.getInputDimensions()[0];
- if (input_dim.height() != 1 && input_dim.channel() != 1)
+ if (input_dim.height() != 1 || input_dim.channel() != 1)
throw std::invalid_argument(
- "Input must be single time dimension for LSTMCell");
- // output_dim = [ batch, 1, 1, hidden_size (unit)]
- output_dim = input_dim;
- output_dim.width(unit);
-
- if (dropout_rate > epsilon) {
- wt_idx[LSTMParams::dropout_mask] = context.requestTensor(
- output_dim, "dropout_mask", Tensor::Initializer::NONE, false,
- TensorLifespan::ITERATION_LIFESPAN);
- }
-
+ "Input must be single time dimension for LSTMCell (shape should be "
+ "[batch_size, 1, 1, feature_size])");
+ const unsigned int batch_size = input_dim.batch();
+#if !ENABLE_SHARING_WT_IDX
+ const unsigned int feature_size = input_dim.width();
+#endif
+
+ // output_dim = [ batch_size, 1, 1, unit ]
+ const TensorDim output_dim(batch_size, 1, 1, unit);
context.setOutputDimensions({output_dim});
- TensorDim bias_dim = TensorDim();
- bias_dim.setTensorDim(3, unit * NUM_GATE);
-
- TensorDim dim_xh = output_dim;
- dim_xh.height(input_dim.width());
- dim_xh.width(unit * NUM_GATE);
- dim_xh.batch(1);
-
- TensorDim dim_hh = output_dim;
- dim_hh.height(unit);
- dim_hh.width(unit * NUM_GATE);
- dim_hh.batch(1);
-
- // weight_initializer can be set seperately. weight_xh initializer,
+#if !ENABLE_SHARING_WT_IDX
+  // weight_initializer can be set separately. weight_ih initializer,
// weight_hh initializer kernel initializer & recurrent_initializer in keras
// for now, it is set same way.
- wt_idx[LSTMParams::weight_xh] =
- context.requestWeight(dim_xh, weight_initializer, weight_regularizer,
- weight_regularizer_constant, "weight_xh", true);
- wt_idx[LSTMParams::weight_hh] =
- context.requestWeight(dim_hh, weight_initializer, weight_regularizer,
- weight_regularizer_constant, "weight_hh", true);
- wt_idx[LSTMParams::bias_h] = context.requestWeight(
- bias_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, "bias_h", true);
-
- unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
-
- TensorDim d = input_dim;
- // d.height(d.batch());
- d.height(1);
- d.batch(max_timestep * d.batch());
- d.width(unit);
- /** hidden dim = [ UnrollLength, 1, Batch, Units ] */
- wt_idx[LSTMParams::hidden_state] =
- context.requestTensor(d, "hidden_state", Tensor::Initializer::NONE, true,
- TensorLifespan::ITERATION_LIFESPAN, false);
- wt_idx[LSTMParams::mem_cell] =
- context.requestTensor(d, "mem_cell", Tensor::Initializer::NONE, true,
- TensorLifespan::ITERATION_LIFESPAN, false);
+ // - weight_ih ( input to hidden )
+ // : [1, 1, feature_size, NUM_GATE x unit] -> i, f, g, o
+ TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
+ wt_idx[LSTMCellParams::weight_ih] =
+ context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
+ weight_regularizer_constant, "weight_ih", true);
+ // - weight_hh ( hidden to hidden )
+ // : [1, 1, unit, NUM_GATE x unit] -> i, f, g, o
+ TensorDim weight_hh_dim({unit, NUM_GATE * unit});
+ wt_idx[LSTMCellParams::weight_hh] =
+ context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
+ weight_regularizer_constant, "weight_hh", true);
+ // - bias_ih ( input bias )
+ // : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
+ TensorDim bias_ih_dim({NUM_GATE * unit});
+ wt_idx[LSTMCellParams::bias_ih] =
+ context.requestWeight(bias_ih_dim, bias_initializer,
+ WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+#endif
+
+ // dropout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
+ const TensorDim dropout_mask_dim(max_timestep * batch_size, 1, 1, unit);
+ if (dropout_rate > epsilon) {
+ wt_idx[LSTMCellParams::dropout_mask] = context.requestTensor(
+ dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false,
+ TensorLifespan::ITERATION_LIFESPAN);
+ }
+ /**
+ * TODO: hidden_state is only used from the previous timestep. Once it is
+ * supported as input, no need to cache the hidden_state itself
+ */
+ /** hidden_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
+ const TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
+ wt_idx[LSTMCellParams::hidden_state] = context.requestTensor(
+ hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
+ TensorLifespan::ITERATION_LIFESPAN, false);
+ /** cell_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
+ const TensorDim cell_state_dim(max_timestep * batch_size, 1, 1, unit);
+ wt_idx[LSTMCellParams::cell_state] = context.requestTensor(
+ cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
+ TensorLifespan::ITERATION_LIFESPAN, false);
+#if !ENABLE_SHARING_WT_IDX
/**
* TODO: make this independent of time dimension once recurrent realizer
* supports requesting tensors which are not always shared
- *
- * TODO: reorder to ifgo for better performance. This will require change in
- * stored weights in the test
*/
- d.width(unit * NUM_GATE);
- wt_idx[LSTMParams::fgio] =
- context.requestTensor(d, "fgio", Tensor::Initializer::NONE, true,
- TensorLifespan::ITERATION_LIFESPAN, false);
+ /** ifgo_dim = [ max_timestep * batch_size, 1, 1, NUM_GATE * unit ] */
+ const TensorDim ifgo_dim(max_timestep * batch_size, 1, 1, NUM_GATE * unit);
+ wt_idx[LSTMCellParams::ifgo] =
+ context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
+ TensorLifespan::ITERATION_LIFESPAN);
+#endif
+
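+  // Finalize the wrapped core layer with per-timestep dimensions and forward
+  // its weight/tensor requests into this layer's context.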
+ TensorDim hidden_state_t_dim({batch_size, 1, 1, unit});
+ TensorDim cell_state_t_dim({batch_size, 1, 1, unit});
+ InitLayerContext core_context(
+ {input_dim, hidden_state_t_dim, cell_state_t_dim}, 3,
+ context.executeInPlace(), context.getName());
+ lstmcellcorelayer.finalize(core_context);
+ init_lstm_context::fillLayerInitContext(context, core_context);
+}
- if (hidden_state_activation_type.get() == ActivationType::ACT_NONE) {
- hidden_state_activation_type.set(ActivationType::ACT_TANH);
+void LSTMCellLayer::setProperty(const std::vector<std::string> &values) {
+ std::vector<std::string> remain_props =
+ loadProperties(values, lstmcell_props);
+
+  // Note: In the current implementation the lstmcellcorelayer also has
+  // properties related to weights, but they are not exported or used anywhere.
+ lstmcellcorelayer.setProperty(remain_props);
+ if (!std::get<props::Unit>(lstmcell_props).empty()) {
+ lstmcellcorelayer.setProperty(
+ {"unit=" + to_string(std::get<props::Unit>(lstmcell_props))});
}
- acti_func.setActiFunc(hidden_state_activation_type.get());
- if (recurrent_activation_type.get() == ActivationType::ACT_NONE) {
- recurrent_activation_type.set(ActivationType::ACT_SIGMOID);
- }
- recurrent_acti_func.setActiFunc(recurrent_activation_type.get());
-}
+#if !ENABLE_SHARING_WT_IDX
+  // Strip the lstmcell core layer's properties so only LayerImpl properties remain
+ std::tuple<props::HiddenStateActivation, props::RecurrentActivation>
+ lstmcell_core_props;
+ std::vector<std::string> impl_props =
+ loadProperties(remain_props, lstmcell_core_props);
-void LSTMCellLayer::setProperty(const std::vector<std::string> &values) {
- auto remain_props = loadProperties(values, lstm_props);
- LayerImpl::setProperty(remain_props);
+ LayerImpl::setProperty(impl_props);
+#endif
}
void LSTMCellLayer::exportTo(Exporter &exporter,
const ExportMethods &method) const {
+#if !ENABLE_SHARING_WT_IDX
LayerImpl::exportTo(exporter, method);
- exporter.saveResult(lstm_props, method, this);
+#endif
+ exporter.saveResult(
+ std::forward_as_tuple(std::get<props::DropOutRate>(lstmcell_props),
+ std::get<props::MaxTimestep>(lstmcell_props),
+ std::get<props::Timestep>(lstmcell_props)),
+ method, this);
+ lstmcellcorelayer.exportTo(exporter, method);
}
void LSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
- auto unit = std::get<props::Unit>(lstm_props).get();
- float dropout_rate = std::get<props::DropOutRate>(lstm_props);
-
- Tensor &weight_xh = context.getWeight(wt_idx[LSTMParams::weight_xh]);
- Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]);
- Tensor &bias_h = context.getWeight(wt_idx[LSTMParams::bias_h]);
-
- Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
- Tensor &hidden_ = context.getTensor(wt_idx[LSTMParams::hidden_state]);
- Tensor &cell_ = context.getTensor(wt_idx[LSTMParams::mem_cell]);
- Tensor &fgio = context.getTensor(wt_idx[LSTMParams::fgio]);
- const TensorDim &input_dim = input_.getDim();
- unsigned int batch = input_dim.batch();
-
- unsigned int start_timestep = std::get<props::Timestep>(lstm_props);
-
- if (start_timestep == 0) {
- hidden_.setZero();
- cell_.setZero();
- }
-
- unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
- hidden_.reshape({max_timestep, 1, batch, hidden_.width()});
- cell_.reshape({max_timestep, 1, batch, cell_.width()});
- fgio.reshape({max_timestep, 1, batch, fgio.width()});
-
- /**
- * @note when the recurrent realization happens, different instances of lstm
- * will share the weights, hidden state, cell and fgio memory. However, they
- * do not share the input, output and derivatives memory. The input/output
- * will be contain a single timestep data only.
- */
- Tensor hs = hidden_.getBatchSlice(start_timestep, 1);
- Tensor cs = cell_.getBatchSlice(start_timestep, 1);
- Tensor fgio_t = fgio.getBatchSlice(start_timestep, 1);
-
- input_.dot(weight_xh, fgio_t);
-
- if (start_timestep > 0) {
- Tensor hs_prev = hidden_.getBatchSlice(start_timestep - 1, 1);
- hs_prev.dot(weight_hh, fgio_t, false, false, 1.0);
- }
-
- fgio_t.add_i(bias_h);
- Tensor hif = fgio_t.getSharedDataTensor({batch, unit * 2}, 0, false);
- Tensor hi = fgio_t.getSharedDataTensor({batch, unit}, 0, false);
- Tensor hf = fgio_t.getSharedDataTensor({batch, unit}, unit, false);
- Tensor hg = fgio_t.getSharedDataTensor({batch, unit}, unit * 2, false);
- Tensor ho = fgio_t.getSharedDataTensor({batch, unit}, unit * 3, false);
-
- recurrent_acti_func.run_fn(hif, hif);
- recurrent_acti_func.run_fn(ho, ho);
- acti_func.run_fn(hg, hg);
-
- if (start_timestep > 0) {
- Tensor cs_prev = cell_.getBatchSlice(start_timestep - 1, 1);
- hf.multiply_strided(cs_prev, cs);
+ const unsigned int unit = std::get<props::Unit>(lstmcell_props).get();
+ const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
+ const unsigned int max_timestep =
+ std::get<props::MaxTimestep>(lstmcell_props);
+ const unsigned int timestep = std::get<props::Timestep>(lstmcell_props);
+
+ const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+ const TensorDim &input_dim = input.getDim();
+ const unsigned int batch_size = input_dim.batch();
+
+ Tensor &hidden_state =
+ context.getTensor(wt_idx[LSTMCellParams::hidden_state]);
+ hidden_state.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1);
+ next_hidden_state.reshape({batch_size, 1, 1, unit});
+
+ Tensor &cell_state = context.getTensor(wt_idx[LSTMCellParams::cell_state]);
+
+ if (!timestep) {
+ hidden_state.setZero();
+ cell_state.setZero();
}
- hg.multiply_strided(hi, cs, 1.0);
- acti_func.run_fn(cs, hs);
- hs.multiply_i_strided(ho);
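+  // Assemble a per-timestep RunLayerContext from the cached states and
+  // delegate the actual forwarding to the core layer.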
+ init_lstm_context::fillWeights(weights, context, training, max_timestep,
+ timestep);
+ init_lstm_context::fillInputs(inputs, context, training, getInOutIdx(wt_idx),
+ max_timestep, timestep);
+ init_lstm_context::fillOutputs(outputs, context, training,
+ getInOutIdx(wt_idx), max_timestep, timestep);
+ init_lstm_context::fillTensors(tensors, context, training,
+ getTensorIdx(wt_idx), max_timestep, timestep);
+ RunLayerContext core_context(context.getName(), context.getTrainable(),
+ context.getLoss(), context.executeInPlace(),
+ init_lstm_context::getWeights(weights),
+ init_lstm_context::getInputs(inputs),
+ init_lstm_context::getOutputs(outputs),
+ init_lstm_context::getTensors(tensors));
+ lstmcellcorelayer.forwarding(core_context, training);
if (dropout_rate > epsilon && training) {
- Tensor &mask_ = context.getTensor(wt_idx[LSTMParams::dropout_mask]);
- hs.dropout_mask(dropout_rate);
- hs.multiply_i(mask_);
+ Tensor &dropout_mask =
+ context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
+ dropout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor dropout_mask_t = dropout_mask.getBatchSlice(timestep, 1);
+ dropout_mask_t.dropout_mask(dropout_rate);
+ next_hidden_state.multiply_i(dropout_mask_t);
}
Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
- std::copy(hs.getData(), hs.getData() + hs.size(), output.getData());
+ output.copyData(next_hidden_state);
}
void LSTMCellLayer::calcDerivative(RunLayerContext &context) {
- Tensor &derivative_ = context.getTensorGrad(wt_idx[LSTMParams::fgio]);
- Tensor &weight = context.getWeight(wt_idx[LSTMParams::weight_xh]);
- Tensor &ret_ = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
-
- unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
- derivative_.reshape({max_timestep, 1, ret_.batch(), derivative_.width()});
-
- /** get the timestep values */
- unsigned int start_timestep = std::get<props::Timestep>(lstm_props);
-
- Tensor deriv_t = derivative_.getBatchSlice(start_timestep, 1);
- deriv_t.dot(weight, ret_, false, true);
+ const unsigned int max_timestep =
+ std::get<props::MaxTimestep>(lstmcell_props);
+ const unsigned int timestep = std::get<props::Timestep>(lstmcell_props);
+
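+  // Rebuild the per-timestep core context and let the core layer compute the
+  // derivative with respect to the input.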
+ init_lstm_context::fillWeights(weights, context, true, max_timestep,
+ timestep);
+ init_lstm_context::fillInputs(inputs, context, true, getInOutIdx(wt_idx),
+ max_timestep, timestep);
+ init_lstm_context::fillOutputs(outputs, context, true, getInOutIdx(wt_idx),
+ max_timestep, timestep);
+ init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
+ max_timestep, timestep);
+ RunLayerContext core_context(context.getName(), context.getTrainable(),
+ context.getLoss(), context.executeInPlace(),
+ init_lstm_context::getWeights(weights),
+ init_lstm_context::getInputs(inputs),
+ init_lstm_context::getOutputs(outputs),
+ init_lstm_context::getTensors(tensors));
+ lstmcellcorelayer.calcDerivative(core_context);
}
void LSTMCellLayer::calcGradient(RunLayerContext &context) {
- auto unit = std::get<props::Unit>(lstm_props).get();
- float dropout_rate = std::get<props::DropOutRate>(lstm_props);
-
- Tensor &djdw_x = context.getWeightGrad(wt_idx[LSTMParams::weight_xh]);
- Tensor &djdw_h = context.getWeightGrad(wt_idx[LSTMParams::weight_hh]);
- Tensor &djdb_h = context.getWeightGrad(wt_idx[LSTMParams::bias_h]);
- Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]);
-
- Tensor &derivative_ = context.getTensorGrad(wt_idx[LSTMParams::hidden_state]);
- /**
- * TODO: hidden_ is only used from the previous timestep. Once it is supported
- * as input, no need to cache the hidden_ itself
- */
- Tensor &hidden_ = context.getTensor(wt_idx[LSTMParams::hidden_state]);
- Tensor &incoming_deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
- Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
- Tensor &m_cell_ = context.getTensor(wt_idx[LSTMParams::mem_cell]);
- Tensor &dm_cell_ = context.getTensorGrad(wt_idx[LSTMParams::mem_cell]);
- Tensor &fgio = context.getTensor(wt_idx[LSTMParams::fgio]);
- Tensor &d_fgio = context.getTensorGrad(wt_idx[LSTMParams::fgio]);
- const TensorDim &input_dim = input_.getDim();
- unsigned int batch = input_dim.batch();
-
- /** get the timestep values */
- unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
- unsigned int start_timestep = std::get<props::Timestep>(lstm_props);
-
- derivative_.reshape({max_timestep, 1, batch, derivative_.width()});
- hidden_.reshape({max_timestep, 1, batch, hidden_.width()});
- m_cell_.reshape({max_timestep, 1, batch, m_cell_.width()});
- dm_cell_.reshape({max_timestep, 1, batch, dm_cell_.width()});
- fgio.reshape({max_timestep, 1, batch, fgio.width()});
- d_fgio.reshape({max_timestep, 1, batch, d_fgio.width()});
-
- if (start_timestep + 1 == max_timestep) {
- djdw_x.setZero();
- djdw_h.setZero();
- djdb_h.setZero();
- }
-
- Tensor dh = derivative_.getBatchSlice(start_timestep, 1);
- dh.reshape(incoming_deriv.getDim());
- if (start_timestep + 1 == max_timestep) {
- dh.copyData(incoming_deriv);
- } else {
- dh.add_i(incoming_deriv);
+ const unsigned int unit = std::get<props::Unit>(lstmcell_props).get();
+ const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
+ const unsigned int max_timestep =
+ std::get<props::MaxTimestep>(lstmcell_props);
+ const unsigned int timestep = std::get<props::Timestep>(lstmcell_props);
+
+ unsigned int batch_size = context.getInput(SINGLE_INOUT_IDX).getDim().batch();
+
+ const Tensor &incoming_derivative =
+ context.getIncomingDerivative(SINGLE_INOUT_IDX);
+
+ Tensor &hidden_state_derivative =
+ context.getTensorGrad(wt_idx[LSTMCellParams::hidden_state]);
+ hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_hidden_state_derivative =
+ hidden_state_derivative.getBatchSlice(timestep, 1);
+ next_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+
+ Tensor &cell_state_derivative =
+ context.getTensorGrad(wt_idx[LSTMCellParams::cell_state]);
+ cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_cell_state_derivative =
+ cell_state_derivative.getBatchSlice(timestep, 1);
+ next_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+
+ if (timestep + 1 == max_timestep) {
+ Tensor &djdweight_ih =
+ context.getWeightGrad(wt_idx[LSTMCellParams::weight_ih]);
+ Tensor &djdweight_hh =
+ context.getWeightGrad(wt_idx[LSTMCellParams::weight_hh]);
+ Tensor &djdbias_ih = context.getWeightGrad(wt_idx[LSTMCellParams::bias_ih]);
+ djdweight_ih.setZero();
+ djdweight_hh.setZero();
+ djdbias_ih.setZero();
+
+ next_hidden_state_derivative.setZero();
+ next_cell_state_derivative.setZero();
}
- dh = derivative_.getBatchSlice(start_timestep, 1);
if (dropout_rate > epsilon) {
- derivative_.multiply_i(context.getTensor(wt_idx[LSTMParams::dropout_mask]));
+ Tensor &dropout_mask =
+ context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
+ dropout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor dropout_mask_t = dropout_mask.getBatchSlice(timestep, 1);
+ next_hidden_state_derivative.multiply_i(dropout_mask_t);
}
- Tensor dc = dm_cell_.getBatchSlice(start_timestep, 1);
- Tensor xs = input_;
- Tensor hs_t = hidden_.getBatchSlice(start_timestep, 1);
- Tensor cs = m_cell_.getBatchSlice(start_timestep, 1);
-
- Tensor dfgio_t = d_fgio.getBatchSlice(start_timestep, 1);
- Tensor fgio_t = fgio.getBatchSlice(start_timestep, 1);
-
- Tensor dhif = dfgio_t.getSharedDataTensor({batch, unit * 2}, 0, false);
- Tensor dhi = dfgio_t.getSharedDataTensor({batch, unit}, 0, false);
- Tensor dhf = dfgio_t.getSharedDataTensor({batch, unit}, unit, false);
- Tensor dhg = dfgio_t.getSharedDataTensor({batch, unit}, unit * 2, false);
- Tensor dho = dfgio_t.getSharedDataTensor({batch, unit}, unit * 3, false);
-
- Tensor hif = fgio_t.getSharedDataTensor({batch, unit * 2}, 0, false);
- Tensor hi = fgio_t.getSharedDataTensor({batch, unit}, 0, false);
- Tensor hf = fgio_t.getSharedDataTensor({batch, unit}, unit, false);
- Tensor hg = fgio_t.getSharedDataTensor({batch, unit}, unit * 2, false);
- Tensor ho = fgio_t.getSharedDataTensor({batch, unit}, unit * 3, false);
-
- acti_func.run_fn(cs, cs);
- cs.multiply_strided(dh, dho);
-
- if (start_timestep + 1 == max_timestep) {
- acti_func.run_prime_fn(cs, dc, dh);
- dc.multiply_i_strided(ho);
- } else {
- /// @todo optimize this by updating run_prime_fn to accumulate or make
- /// it inplace somehow
- Tensor dc_temp(dc.getDim());
- acti_func.run_prime_fn(cs, dc_temp, dh);
- dc_temp.multiply_strided(ho, dc, 1.0);
- }
-
- if (start_timestep > 0) {
- Tensor dc_nx = dm_cell_.getBatchSlice(start_timestep - 1, 1);
- dc.multiply_strided(hf, dc_nx);
- Tensor cs_prev = m_cell_.getBatchSlice(start_timestep - 1, 1);
- dc.multiply_strided(cs_prev, dhf);
- } else {
- dhf.setZero();
- }
-
- dc.multiply_strided(hg, dhi);
- dc.multiply_strided(hi, dhg);
-
- recurrent_acti_func.run_prime_fn(ho, dho, dho);
- recurrent_acti_func.run_prime_fn(hif, dhif, dhif);
- acti_func.run_prime_fn(hg, dhg, dhg);
- dfgio_t.sum(2, djdb_h, 1.0, 1.0);
-
- xs.dot(dfgio_t, djdw_x, true, false, 1.0f);
- if (start_timestep != 0) {
- Tensor hs_prev = hidden_.getBatchSlice(start_timestep - 1, 1);
- hs_prev.dot(dfgio_t, djdw_h, true, false, 1.0f);
- Tensor dh_nx = derivative_.getBatchSlice(start_timestep - 1, 1);
- dfgio_t.dot(weight_hh, dh_nx, false, true, 1.0f);
- }
+ next_hidden_state_derivative.add_i(incoming_derivative);
+
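+  // Rebuild the per-timestep core context and delegate the weight gradient
+  // computation to the core layer.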
+ init_lstm_context::fillWeights(weights, context, true, max_timestep,
+ timestep);
+ init_lstm_context::fillInputs(inputs, context, true, getInOutIdx(wt_idx),
+ max_timestep, timestep);
+ init_lstm_context::fillOutputs(outputs, context, true, getInOutIdx(wt_idx),
+ max_timestep, timestep);
+ init_lstm_context::fillTensors(tensors, context, true, getTensorIdx(wt_idx),
+ max_timestep, timestep);
+ RunLayerContext core_context(context.getName(), context.getTrainable(),
+ context.getLoss(), context.executeInPlace(),
+ init_lstm_context::getWeights(weights),
+ init_lstm_context::getInputs(inputs),
+ init_lstm_context::getOutputs(outputs),
+ init_lstm_context::getTensors(tensors));
+ lstmcellcorelayer.calcGradient(core_context);
}
void LSTMCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
- unsigned int max_timestep = std::get<props::MaxTimestep>(lstm_props);
- context.updateTensor(wt_idx[LSTMParams::hidden_state], batch * max_timestep);
- context.updateTensor(wt_idx[LSTMParams::mem_cell], batch * max_timestep);
- context.updateTensor(wt_idx[LSTMParams::fgio], batch * max_timestep);
-
- const float dropout_rate = std::get<props::DropOutRate>(lstm_props);
+ const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
+ const unsigned int max_timestep =
+ std::get<props::MaxTimestep>(lstmcell_props);
+ context.updateTensor(wt_idx[LSTMCellParams::hidden_state],
+ max_timestep * batch);
+ context.updateTensor(wt_idx[LSTMCellParams::cell_state],
+ max_timestep * batch);
+ context.updateTensor(wt_idx[LSTMCellParams::ifgo], max_timestep * batch);
if (dropout_rate > epsilon) {
- context.updateTensor(wt_idx[LSTMParams::dropout_mask], batch);
+ context.updateTensor(wt_idx[LSTMCellParams::dropout_mask],
+ max_timestep * batch);
}
}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
+ *
+ * @file lstmcell_core.cpp
+ * @date 25 November 2021
+ * @brief This is LSTMCellCore Layer Class of Neural Network
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author hyeonseok lee <hs89.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#include <layer_context.h>
+#include <lstmcell_core.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+
+// ENABLE_SHARING_WT_IDX indicates whether the wt_idx of lstm_core can be
+// shared with the lstm_cell variant layers.
+// Todo: remove this once sharing wt_idx with the other lstm variants is enabled
+#define ENABLE_SHARING_WT_IDX 0
+
+namespace nntrainer {
+
+namespace init_lstm_context {
+void fillLayerInitContext(InitLayerContext &context,
+ const InitLayerContext &core_context) {
+  /** actually set the input flags */
+ auto const &input_dims = context.getInputDimensions();
+ for (unsigned int idx = 0; idx < core_context.getNumInputs(); idx++) {
+ context.setDynDimFlagInputDimension(idx, input_dims[idx].getDynDimFlag());
+ context.setEffDimFlagInputDimension(idx, input_dims[idx].getEffDimFlag());
+ }
+
+  /** actually request the tensors */
+ for (auto const &ts : core_context.getTensorsSpec())
+ context.requestTensor(ts);
+
+  /** actually request the weights */
+ for (auto const &ws : core_context.getWeightsSpec())
+ context.requestWeight(ws);
+}
+
+void fillWeights(std::vector<Weight> &weights, const RunLayerContext &context,
+ bool training, const unsigned int max_timestep,
+ const unsigned int timestep, bool test) {
+ weights.resize(context.getNumWeights());
+ for (unsigned int i = 0; i < context.getNumWeights(); ++i) {
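+    // Attach gradients only while training; when `test` is set, presumably
+    // only the first three weights (weight_ih, weight_hh, bias_ih) keep them.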
+ if (training && (!test || i < 3u)) {
+ weights[i] = Weight(context.getWeight(i), context.getWeightGrad(i),
+ context.getWeightName(i));
+ } else {
+ weights[i] =
+ Weight(context.getWeight(i), Tensor(), context.getWeightName(i));
+ }
+ }
+}
+
+const std::vector<Weight *> getWeights(std::vector<Weight> &weights) {
+ std::vector<Weight *> ret(weights.size());
+ for (unsigned int i = 0; i < weights.size(); ++i) {
+ ret[i] = &weights[i];
+ }
+ return ret;
+}
+
+void fillInputs(std::vector<Var_Grad> &inputs, RunLayerContext &context,
+ bool training, const std::vector<unsigned int> &wt_idx,
+ const unsigned int max_timestep, const unsigned int timestep) {
+ inputs.resize(3);
+ Tensor empty;
+ const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim();
+ const unsigned int batch_size = output_dim.batch();
+ const unsigned int unit = output_dim.width();
+
+ const Tensor &input = context.getInput(wt_idx[0]);
+ const Tensor &outgoing_derivative =
+ training ? context.getOutgoingDerivative(0) : empty;
+
+ Tensor &hidden_state = context.getTensor(wt_idx[1]);
+ hidden_state.reshape({max_timestep, 1, batch_size, unit});
+ Tensor &hidden_state_derivative =
+ training ? context.getTensorGrad(wt_idx[1]) : empty;
+ if (training) {
+ hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+ }
+
+ Tensor &cell_state = context.getTensor(wt_idx[2]);
+ cell_state.reshape({max_timestep, 1, batch_size, unit});
+ Tensor &cell_state_derivative =
+ training ? context.getTensorGrad(wt_idx[2]) : empty;
+ if (training) {
+ cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+ }
+
+ Tensor prev_hidden_state;
+ Tensor prev_hidden_state_derivative;
+ Tensor prev_cell_state;
+ Tensor prev_cell_state_derivative;
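+  // At timestep 0 there is no previous state, so zero-filled temporaries stand
+  // in for it; otherwise slice the cached state at timestep - 1.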
+ if (!timestep) {
+ prev_hidden_state = Tensor(batch_size, 1, 1, unit);
+ prev_hidden_state.setZero();
+ prev_hidden_state_derivative = Tensor(batch_size, 1, 1, unit);
+ prev_hidden_state_derivative.setZero();
+ prev_cell_state = Tensor(batch_size, 1, 1, unit);
+ prev_cell_state.setZero();
+ prev_cell_state_derivative = Tensor(batch_size, 1, 1, unit);
+ prev_cell_state_derivative.setZero();
+ } else {
+ prev_hidden_state = hidden_state.getBatchSlice(timestep - 1, 1);
+ prev_hidden_state.reshape({batch_size, 1, 1, unit});
+ if (training) {
+ prev_hidden_state_derivative =
+ hidden_state_derivative.getBatchSlice(timestep - 1, 1);
+ prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+ }
+ prev_cell_state = cell_state.getBatchSlice(timestep - 1, 1);
+ prev_cell_state.reshape({batch_size, 1, 1, unit});
+ if (training) {
+ prev_cell_state_derivative =
+ cell_state_derivative.getBatchSlice(timestep - 1, 1);
+ prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+ }
+ }
+
+ inputs[0] = Var_Grad(input, outgoing_derivative);
+ inputs[1] = Var_Grad(prev_hidden_state, prev_hidden_state_derivative,
+ context.getTensorName(wt_idx[1]));
+ inputs[2] = Var_Grad(prev_cell_state, prev_cell_state_derivative,
+ context.getTensorName(wt_idx[2]));
+}
+
+const std::vector<Var_Grad *> getInputs(std::vector<Var_Grad> &inputs) {
+ std::vector<Var_Grad *> ret(inputs.size());
+ for (unsigned int i = 0; i < inputs.size(); ++i) {
+ ret[i] = &inputs[i];
+ }
+ return ret;
+}
+
+void fillOutputs(std::vector<Var_Grad> &outputs, RunLayerContext &context,
+ bool training, const std::vector<unsigned int> &wt_idx,
+ const unsigned int max_timestep, const unsigned int timestep) {
+ outputs.resize(2);
+ Tensor empty;
+ const TensorDim &output_dim = context.getOutput(wt_idx[0]).getDim();
+ const unsigned int batch_size = output_dim.batch();
+ const unsigned int unit = output_dim.width();
+
+ Tensor &hidden_state = context.getTensor(wt_idx[1]);
+ hidden_state.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_hidden_state = hidden_state.getBatchSlice(timestep, 1);
+ next_hidden_state.reshape({batch_size, 1, 1, unit});
+ Tensor next_hidden_state_derivative;
+ if (training) {
+ Tensor &hidden_state_derivative = context.getTensorGrad(wt_idx[1]);
+ hidden_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+ next_hidden_state_derivative =
+ hidden_state_derivative.getBatchSlice(timestep, 1);
+ next_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+ }
+
+ Tensor &cell_state = context.getTensor(wt_idx[2]);
+ cell_state.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_cell_state = cell_state.getBatchSlice(timestep, 1);
+ next_cell_state.reshape({batch_size, 1, 1, unit});
+ Tensor next_cell_state_derivative;
+ if (training) {
+ Tensor &cell_state_derivative = context.getTensorGrad(wt_idx[2]);
+ cell_state_derivative.reshape({max_timestep, 1, batch_size, unit});
+ next_cell_state_derivative =
+ cell_state_derivative.getBatchSlice(timestep, 1);
+ next_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+ }
+
+ outputs[0] = Var_Grad(next_hidden_state, next_hidden_state_derivative,
+ context.getTensorName(wt_idx[1]));
+ outputs[1] = Var_Grad(next_cell_state, next_cell_state_derivative,
+ context.getTensorName(wt_idx[2]));
+}
+
+const std::vector<Var_Grad *> getOutputs(std::vector<Var_Grad> &outputs) {
+ std::vector<Var_Grad *> ret(outputs.size());
+ for (unsigned int i = 0; i < outputs.size(); ++i) {
+ ret[i] = &outputs[i];
+ }
+ return ret;
+}
+
+void fillTensors(std::vector<Var_Grad> &tensors, RunLayerContext &context,
+ bool training, const std::vector<unsigned int> &wt_idx,
+ const unsigned int max_timestep, const unsigned int timestep) {
+ tensors.resize(1);
+
+#if ENABLE_SHARING_WT_IDX
+ const Tensor &ifgo = context.getTensor(wt_idx[0]);
+ Tensor empty;
+ const Tensor &ifgo_derivative =
+ training ? context.getTensorGrad(wt_idx[0]) : empty;
+ tensors[0] =
+ Var_Grad(ifgo, ifgo_derivative, context.getTensorName(wt_idx[0]));
+#else
+ const TensorDim &output_dim = context.getOutput(0).getDim();
+ const unsigned int batch_size = output_dim.batch();
+ const unsigned int unit = output_dim.width();
+
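+  // Without a shared wt_idx the ifgo buffer is cached across timesteps, so
+  // slice out this timestep's ifgo (and its gradient while training).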
+ Tensor &ifgo = context.getTensor(wt_idx[0]);
+ const unsigned int NUM_GATE = ifgo.width() / unit;
+ ifgo.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
+ Tensor ifgo_t = ifgo.getBatchSlice(timestep, 1);
+ ifgo_t.reshape({batch_size, 1, 1, NUM_GATE * unit});
+ Tensor ifgo_derivative_t;
+ if (training) {
+ Tensor &ifgo_derivative = context.getTensorGrad(wt_idx[0]);
+ ifgo_derivative.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
+ ifgo_derivative_t = ifgo_derivative.getBatchSlice(timestep, 1);
+ ifgo_derivative_t.reshape({batch_size, 1, 1, NUM_GATE * unit});
+ }
+ tensors[0] =
+ Var_Grad(ifgo_t, ifgo_derivative_t, context.getTensorName(wt_idx[0]));
+#endif
+}
+
+const std::vector<Var_Grad *> getTensors(std::vector<Var_Grad> &tensors) {
+ std::vector<Var_Grad *> ret(tensors.size());
+ for (unsigned int i = 0; i < tensors.size(); ++i) {
+ ret[i] = &tensors[i];
+ }
+ return ret;
+}
+
+} // namespace init_lstm_context
+
+enum INDEX {
+ INPUT = 0,
+ HIDDEN_STATE_IN = 1,
+ CELL_STATE_IN = 2,
+ HIDDEN_STATE_OUT = 0,
+ CELL_STATE_OUT = 1
+};
+
+enum LSTMCellCoreParams {
+ weight_ih,
+ weight_hh,
+ bias_ih,
+ ifgo,
+};
+
+LSTMCellCoreLayer::LSTMCellCoreLayer() :
+ LayerImpl(),
+ lstmcell_core_props(
+ props::Unit(), props::HiddenStateActivation() = ActivationType::ACT_TANH,
+ props::RecurrentActivation() = ActivationType::ACT_SIGMOID),
+ wt_idx({0}),
+ acti_func(ActivationType::ACT_NONE, true),
+ recurrent_acti_func(ActivationType::ACT_NONE, true) {}
+
+void LSTMCellCoreLayer::finalize(InitLayerContext &context) {
+ NNTR_THROW_IF(std::get<props::Unit>(lstmcell_core_props).empty(),
+ std::invalid_argument)
+ << "unit property missing for lstmcell_core layer";
+ const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
+ const nntrainer::props::HiddenStateActivation hidden_state_activation_type =
+ std::get<props::HiddenStateActivation>(lstmcell_core_props);
+ const nntrainer::props::RecurrentActivation recurrent_activation_type =
+ std::get<props::RecurrentActivation>(lstmcell_core_props);
+
+#if ENABLE_SHARING_WT_IDX
+ const Tensor::Initializer weight_initializer =
+ std::get<props::WeightInitializer>(*layer_impl_props);
+ const Tensor::Initializer bias_initializer =
+ std::get<props::BiasInitializer>(*layer_impl_props);
+ const nntrainer::WeightRegularizer weight_regularizer =
+ std::get<props::WeightRegularizer>(*layer_impl_props);
+ const float weight_regularizer_constant =
+ std::get<props::WeightRegularizerConstant>(*layer_impl_props);
+#endif
+
+ if (context.getNumInputs() != 3)
+ throw std::invalid_argument("LSTMCellCore layer should takes 3 input");
+
+ // input_dim = [ batch, 1, 1, feature_size ]
+ const TensorDim &input_dim = context.getInputDimensions()[0];
+ if (input_dim.height() != 1 || input_dim.channel() != 1)
+ throw std::invalid_argument(
+ "Input must be single time dimension for LSTMCellCore");
+ // input_hidden_state_dim = [ batch, 1, 1, unit ]
+ const TensorDim &input_hidden_state_dim = context.getInputDimensions()[1];
+ if (input_hidden_state_dim.channel() != 1 ||
+ input_hidden_state_dim.height() != 1)
+ throw std::invalid_argument("Input hidden state's dimension should be "
+ "[batch, 1, 1, unit] for LSTMCellCore");
+ // input_cell_state_dim = [ batch, 1, 1, unit ]
+ const TensorDim &input_cell_state_dim = context.getInputDimensions()[2];
+ if (input_cell_state_dim.channel() != 1 || input_cell_state_dim.height() != 1)
+ throw std::invalid_argument("Input cell state's dimension should be "
+ "[batch, 1, 1, unit] for LSTMCellCore");
+
+ const unsigned int batch_size = input_dim.batch();
+#if ENABLE_SHARING_WT_IDX
+ const unsigned int feature_size = input_dim.width();
+#endif
+
+ const TensorDim output_dim(batch_size, 1, 1, unit);
+ const TensorDim output_hidden_state_dim = input_hidden_state_dim;
+ const TensorDim output_cell_state_dim = input_cell_state_dim;
+
+ context.setOutputDimensions(
+ {output_dim, output_hidden_state_dim, output_cell_state_dim});
+
+#if ENABLE_SHARING_WT_IDX
+  // weight_initializer can be set separately: weight_ih and weight_hh
+  // initializers correspond to the kernel and recurrent initializers in keras.
+  // For now, both are set the same way.
+
+ // - weight_ih ( input to hidden )
+ // : [1, 1, feature_size, NUM_GATE x unit] -> i, f, g, o
+ TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
+ wt_idx[LSTMCellCoreParams::weight_ih] =
+ context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
+ weight_regularizer_constant, "weight_ih", true);
+ // - weight_hh ( hidden to hidden )
+ // : [1, 1, unit, NUM_GATE x unit] -> i, f, g, o
+ TensorDim weight_hh_dim({unit, NUM_GATE * unit});
+ wt_idx[LSTMCellCoreParams::weight_hh] =
+ context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
+ weight_regularizer_constant, "weight_hh", true);
+ // - bias_ih ( input bias )
+ // : [1, 1, 1, NUM_GATE x unit] -> i, f, g, o
+ TensorDim bias_ih_dim({NUM_GATE * unit});
+ wt_idx[LSTMCellCoreParams::bias_ih] =
+ context.requestWeight(bias_ih_dim, bias_initializer,
+ WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+#endif
+
+#if ENABLE_SHARING_WT_IDX
+ TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
+ wt_idx[LSTMCellCoreParams::ifgo] =
+ context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
+ TensorLifespan::ITERATION_LIFESPAN);
+#endif
+ acti_func.setActiFunc(hidden_state_activation_type.get());
+ recurrent_acti_func.setActiFunc(recurrent_activation_type.get());
+}
+
+void LSTMCellCoreLayer::setProperty(const std::vector<std::string> &values) {
+ std::vector<std::string> remain_props =
+ loadProperties(values, lstmcell_core_props);
+ LayerImpl::setProperty(remain_props);
+}
+
+void LSTMCellCoreLayer::exportTo(Exporter &exporter,
+ const ExportMethods &method) const {
+#if ENABLE_SHARING_WT_IDX
+ LayerImpl::exportTo(exporter, method);
+#endif
+ exporter.saveResult(lstmcell_core_props, method, this);
+}
+
+void LSTMCellCoreLayer::forwarding(RunLayerContext &context, bool training) {
+ const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
+
+ const Tensor &input = context.getInput(INDEX::INPUT);
+ const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN);
+ const Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN);
+ const TensorDim &input_dim = input.getDim();
+ const unsigned int batch_size = input_dim.batch();
+
+ Tensor &next_hidden_state = context.getOutput(INDEX::HIDDEN_STATE_OUT);
+ Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT);
+
+#if ENABLE_SHARING_WT_IDX
+ const Tensor &weight_ih =
+ context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]);
+ const Tensor &weight_hh =
+ context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]);
+ const Tensor &bias_ih =
+ context.getWeight(wt_idx[LSTMCellCoreParams::bias_ih]);
+#else
+ const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih);
+ const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh);
+ const Tensor &bias_ih = context.getWeight(LSTMCellCoreParams::bias_ih);
+#endif
+
+#if ENABLE_SHARING_WT_IDX
+ Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]);
+#else
+ Tensor &ifgo = context.getTensor(0);
+#endif
+
+ input.dot(weight_ih, ifgo);
+ prev_hidden_state.dot(weight_hh, ifgo, false, false, 1.0);
+ ifgo.add_i(bias_ih);
+
+ Tensor input_forget_gate =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
+ Tensor input_gate =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+ Tensor forget_gate =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+ Tensor memory_cell =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
+ Tensor output_gate =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
+
+ recurrent_acti_func.run_fn(input_forget_gate, input_forget_gate);
+ recurrent_acti_func.run_fn(output_gate, output_gate);
+ acti_func.run_fn(memory_cell, memory_cell);
+
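+  // cell_state = forget_gate * prev_cell_state + input_gate * memory_cell
+  // hidden_state = output_gate * act(cell_state)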
+ forget_gate.multiply_strided(prev_cell_state, next_cell_state);
+ memory_cell.multiply_strided(input_gate, next_cell_state, 1.0f);
+
+ acti_func.run_fn(next_cell_state, next_hidden_state);
+ next_hidden_state.multiply_i_strided(output_gate);
+}
+
+void LSTMCellCoreLayer::calcDerivative(RunLayerContext &context) {
+#if ENABLE_SHARING_WT_IDX
+ Tensor &ifgo_derivative =
+ context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]);
+ const Tensor &weight_ih =
+ context.getWeight(wt_idx[LSTMCellCoreParams::weight_ih]);
+#else
+ const Tensor &weight_ih = context.getWeight(LSTMCellCoreParams::weight_ih);
+ Tensor &ifgo_derivative = context.getTensorGrad(0);
+#endif
+ Tensor &outgoing_derivative = context.getOutgoingDerivative(INDEX::INPUT);
+
+ ifgo_derivative.dot(weight_ih, outgoing_derivative, false, true);
+}
+
+void LSTMCellCoreLayer::calcGradient(RunLayerContext &context) {
+ const unsigned int unit = std::get<props::Unit>(lstmcell_core_props).get();
+
+ const Tensor &input = context.getInput(INDEX::INPUT);
+ const TensorDim &input_dim = input.getDim();
+ const unsigned int batch_size = input_dim.batch();
+
+#if ENABLE_SHARING_WT_IDX
+ Tensor &djdweight_ih =
+ context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_ih]);
+ const Tensor &weight_hh =
+ context.getWeight(wt_idx[LSTMCellCoreParams::weight_hh]);
+ Tensor &djdweight_hh =
+ context.getWeightGrad(wt_idx[LSTMCellCoreParams::weight_hh]);
+ Tensor &djdbias_ih =
+ context.getWeightGrad(wt_idx[LSTMCellCoreParams::bias_ih]);
+#else
+ Tensor &djdweight_ih = context.getWeightGrad(LSTMCellCoreParams::weight_ih);
+ const Tensor &weight_hh = context.getWeight(LSTMCellCoreParams::weight_hh);
+ Tensor &djdweight_hh = context.getWeightGrad(LSTMCellCoreParams::weight_hh);
+ Tensor &djdbias_ih = context.getWeightGrad(LSTMCellCoreParams::bias_ih);
+#endif
+
+ const Tensor &prev_hidden_state = context.getInput(INDEX::HIDDEN_STATE_IN);
+ Tensor &prev_hidden_state_derivative =
+ context.getOutgoingDerivative(INDEX::HIDDEN_STATE_IN);
+ Tensor &next_hidden_state_derivative =
+ context.getIncomingDerivative(INDEX::HIDDEN_STATE_OUT);
+
+ Tensor &prev_cell_state = context.getInput(INDEX::CELL_STATE_IN);
+ Tensor &prev_cell_state_derivative =
+ context.getOutgoingDerivative(INDEX::CELL_STATE_IN);
+ Tensor &next_cell_state = context.getOutput(INDEX::CELL_STATE_OUT);
+ Tensor &next_cell_state_derivative =
+ context.getIncomingDerivative(INDEX::CELL_STATE_OUT);
+
+#if ENABLE_SHARING_WT_IDX
+ Tensor &ifgo = context.getTensor(wt_idx[LSTMCellCoreParams::ifgo]);
+ Tensor &ifgo_derivative =
+ context.getTensorGrad(wt_idx[LSTMCellCoreParams::ifgo]);
+#else
+ Tensor &ifgo = context.getTensor(0);
+ Tensor &ifgo_derivative = context.getTensorGrad(0);
+#endif
+
+ Tensor input_forget_gate =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
+ Tensor input_gate =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+ Tensor forget_gate =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+ Tensor memory_cell =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
+ Tensor output_gate =
+ ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
+
+ Tensor input_forget_gate_derivative =
+ ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
+ Tensor input_gate_derivative =
+ ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+ Tensor forget_gate_derivative =
+ ifgo_derivative.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+ Tensor memory_cell_derivative = ifgo_derivative.getSharedDataTensor(
+ {batch_size, 1, 1, unit}, unit * 2, false);
+ Tensor output_gate_derivative = ifgo_derivative.getSharedDataTensor(
+ {batch_size, 1, 1, unit}, unit * 3, false);
+
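+  // Backpropagate through h = o * act(c) and c = f * c_prev + i * g to fill
+  // the gate derivatives, then through the gate activations below.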
+ acti_func.run_fn(next_cell_state, next_cell_state);
+ next_hidden_state_derivative.multiply_strided(next_cell_state,
+ output_gate_derivative);
+
+ acti_func.run_prime_fn(next_cell_state, prev_cell_state_derivative,
+ next_hidden_state_derivative);
+ prev_cell_state_derivative.multiply_i_strided(output_gate);
+ prev_cell_state_derivative.add_i(next_cell_state_derivative);
+
+ prev_cell_state_derivative.multiply_strided(input_gate,
+ memory_cell_derivative);
+ prev_cell_state_derivative.multiply_strided(memory_cell,
+ input_gate_derivative);
+
+ prev_cell_state_derivative.multiply_strided(prev_cell_state,
+ forget_gate_derivative);
+ prev_cell_state_derivative.multiply_i_strided(forget_gate);
+
+ recurrent_acti_func.run_prime_fn(output_gate, output_gate_derivative,
+ output_gate_derivative);
+ recurrent_acti_func.run_prime_fn(input_forget_gate,
+ input_forget_gate_derivative,
+ input_forget_gate_derivative);
+ acti_func.run_prime_fn(memory_cell, memory_cell_derivative,
+ memory_cell_derivative);
+
+ ifgo_derivative.sum(0, djdbias_ih, 1.0f, 1.0f);
+ input.dot(ifgo_derivative, djdweight_ih, true, false, 1.0f);
+ prev_hidden_state.dot(ifgo_derivative, djdweight_hh, true, false, 1.0f);
+ ifgo_derivative.dot(weight_hh, prev_hidden_state_derivative, false, true);
+}
+
+void LSTMCellCoreLayer::setBatch(RunLayerContext &context, unsigned int batch) {
+ context.updateTensor(wt_idx[LSTMCellCoreParams::ifgo], batch);
+}
+
+} // namespace nntrainer