[lstm] implement bidirectional lstm forward
author    hyeonseok lee <hs89.lee@samsung.com>
Fri, 14 Jan 2022 04:29:57 +0000 (13:29 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Tue, 8 Feb 2022 04:41:23 +0000 (13:41 +0900)
 - Make batch_first_forwarding function
 - For now, only forwarding is supported for bidirectional lstm (see the usage sketch below the changed-file list)

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/common_properties.cpp
nntrainer/layers/common_properties.h
nntrainer/layers/lstm.cpp
nntrainer/layers/lstm.h
test/input_gen/genModelsRecurrent_v2.py
test/input_gen/transLayer_v2.py
test/unittest/models/unittest_models_recurrent.cpp
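
A minimal sketch of enabling the new property, assuming the public ml::train C++ API and the property strings used by the commented-out model tests in this change; "model" stands for an ml::train::Model created elsewhere, and the layer name and sizes are illustrative only.

    // lstm layer that runs both time directions and concatenates the forward
    // and reverse hidden states, so the output width becomes 2 * unit
    auto lstm = ml::train::createLayer(
      "lstm", {"name=a1", "unit=2", "integrate_bias=false",
               "return_sequences=true", "bidirectional=true"});
    model->addLayer(lstm);

With bidirectional=true the output dimension becomes [ batch_size, 1, return_sequences ? max_timestep : 1, 2 * unit ], matching the output_dim computation in lstm.cpp below.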

index a17021cbc846748f0da62d80364408e39a950a59..c0d6c2c8630b8d897f6a434b78825974ed8f173b 100644 (file)
@@ -75,6 +75,8 @@ std::ifstream::pos_type FilePath::file_size() { return cached_pos_size; }
 
 ReturnSequences::ReturnSequences(bool value) { set(value); }
 
+Bidirectional::Bidirectional(bool value) { set(value); }
+
 bool NumClass::isValid(const unsigned int &v) const { return v > 0; }
 
 InputConnection::InputConnection() : nntrainer::Property<Connection>() {}
index a12eda7bccbe4cd5a9d08f28548f8fedd2d02843..13cf6a50dc3eaf8dcdbd39f59e589522e630fb0d 100644 (file)
@@ -573,6 +573,21 @@ public:
   using prop_tag = bool_prop_tag;
 };
 
+/**
+ * @brief bidirectional property, used to make bidirectional layers
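+ * (set as a layer property, e.g. bidirectional = true; default is false)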
+ *
+ */
+class Bidirectional : public nntrainer::Property<bool> {
+public:
+  /**
+   * @brief Construct a new Bidirectional object
+   *
+   */
+  Bidirectional(bool value = false);
+  static constexpr const char *key = "bidirectional";
+  using prop_tag = bool_prop_tag;
+};
+
 /**
  * @brief Identifiers to locate a connection which should be returned as whole
  * used in recurrent realizer
index 6afca495aba96e30c8fff0e372b29fb9b6e7c0cf..b0a35abd7212e2bfbe18e970edc6fd1c18654db2 100644 (file)
@@ -31,16 +31,109 @@ enum LSTMParams {
   hidden_state,
   cell_state,
   ifgo,
+  reverse_weight_ih,
+  reverse_weight_hh,
+  reverse_bias_h,
+  reverse_bias_ih,
+  reverse_bias_hh,
+  reverse_hidden_state,
+  reverse_cell_state,
+  reverse_ifgo,
   dropout_mask
 };
 
+/**
+ * @brief run lstm forwarding for batch_first input
+ *
+ * @param NUM_GATE number of gates, which is 4 for lstm
+ * @param unit number of output neurons
+ * @param batch_size batch size
+ * @param max_timestep maximum timestep for lstm
+ * @param feature_size feature size of the input
+ * @param disable_bias whether the bias is disabled
+ * @param integrate_bias integrate bias_ih, bias_hh to bias_h
+ * @param acti_func activation function for memory cell, cell state
+ * @param recurrent_acti_func activation function for input/output/forget
+ * gate
+ * @param reverse if true, process the input in reverse time order (used for
+ * the reverse direction of bidirectional lstm)
+ * @param enable_dropout whether to apply dropout
+ * @param dropout_rate dropout rate
+ * @param input_ input
+ * @param weight_ih weight_ih. weight for input to hidden
+ * @param weight_hh weight_hh. weight for hidden to hidden
+ * @param bias_h bias_h. bias for input and hidden.
+ * @param bias_ih bias_ih. bias for input
+ * @param bias_hh bias_hh. bias for hidden
+ * @param hidden_state_ hidden state
+ * @param cell_state_ cell state
+ * @param ifgo_ input gate, forget gate, memory cell, output gate
+ * @param mask_ dropout mask
+ */
+static void batch_first_forwarding(
+  unsigned int NUM_GATE, const unsigned int unit, const unsigned int batch_size,
+  const unsigned int max_timestep, const unsigned int feature_size,
+  const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
+  ActiFunc &recurrent_acti_func, const bool reverse, const bool enable_dropout,
+  const float dropout_rate, const Tensor &input_, const Tensor &weight_ih,
+  const Tensor &weight_hh, const Tensor &bias_h, const Tensor &bias_ih,
+  const Tensor &bias_hh, Tensor &hidden_state_, Tensor &cell_state_,
+  Tensor &ifgo_, const Tensor &mask_) {
+  hidden_state_.setZero();
+  cell_state_.setZero();
+
+  for (unsigned int batch = 0; batch < batch_size; ++batch) {
+    const Tensor input_sample = input_.getBatchSlice(batch, 1);
+    Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
+    Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
+    Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
+
+    for (unsigned int t = 0; t < max_timestep; ++t) {
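+      // when reverse is set (reverse direction of bidirectional lstm), the
+      // sample is consumed from the last timestep to the first: iteration t
+      // reads the input slice and writes the hidden state slice at offset
+      // (max_timestep - 1 - t), while the cell state is stored in iteration
+      // order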
+      Tensor input = input_sample.getSharedDataTensor(
+        {feature_size}, (reverse ? max_timestep - 1 - t : t) * feature_size);
+      Tensor prev_hidden_state;
+
+      if (!t) {
+        prev_hidden_state = Tensor(unit);
+        prev_hidden_state.setZero();
+      } else {
+        prev_hidden_state = hidden_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+      }
+      Tensor hidden_state = hidden_state_sample.getSharedDataTensor(
+        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+      Tensor prev_cell_state;
+      if (!t) {
+        prev_cell_state = Tensor(unit);
+        prev_cell_state.setZero();
+      } else {
+        prev_cell_state =
+          cell_state_sample.getSharedDataTensor({unit}, (t - 1) * unit);
+      }
+      Tensor cell_state =
+        cell_state_sample.getSharedDataTensor({unit}, t * unit);
+      Tensor ifgo =
+        ifgo_sample.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
+
+      lstmcell_forwarding(unit, 1, disable_bias, integrate_bias, acti_func,
+                          recurrent_acti_func, input, prev_hidden_state,
+                          prev_cell_state, hidden_state, cell_state, weight_ih,
+                          weight_hh, bias_h, bias_ih, bias_hh, ifgo);
+
+      if (enable_dropout) {
+        Tensor mask_sample = mask_.getBatchSlice(batch, 1);
+        Tensor mask = mask_sample.getSharedDataTensor({unit}, t * unit);
+        mask.dropout_mask(dropout_rate);
+        hidden_state.multiply_i(mask);
+      }
+    }
+  }
+}
+
 LSTMLayer::LSTMLayer() :
   LayerImpl(),
   lstm_props(props::Unit(), props::IntegrateBias(),
              props::HiddenStateActivation() = ActivationType::ACT_TANH,
              props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
-             props::ReturnSequences(), props::DropOutRate(),
-             props::MaxTimestep()),
+             props::ReturnSequences(), props::Bidirectional(),
+             props::DropOutRate(), props::MaxTimestep()),
   acti_func(ActivationType::ACT_NONE, true),
   recurrent_acti_func(ActivationType::ACT_NONE, true),
   epsilon(1e-3) {
@@ -70,6 +163,7 @@ void LSTMLayer::finalize(InitLayerContext &context) {
     std::get<props::RecurrentActivation>(lstm_props).get();
   const bool return_sequences =
     std::get<props::ReturnSequences>(lstm_props).get();
+  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
 
   if (context.getNumInputs() != 1) {
@@ -91,34 +185,32 @@ void LSTMLayer::finalize(InitLayerContext &context) {
   std::get<props::MaxTimestep>(lstm_props).set(max_timestep);
   const unsigned int feature_size = input_dim.width();
 
-  // if return_sequences == false :
-  //      output_dim = [ batch_size, 1, 1, unit ]
-  // else:
-  //      output_dim = [ batch_size, 1, time_iteration, unit ]
+  // output_dim = [ batch_size, 1, return_sequences ? time_iteration : 1,
+  // bidirectional ? 2 * unit : unit ]
   const TensorDim output_dim(batch_size, 1, return_sequences ? max_timestep : 1,
-                             unit);
+                             bidirectional ? 2 * unit : unit);
   context.setOutputDimensions({output_dim});
 
   // weight_initializer can be set seperately. weight_ih initializer,
-  // weight_hh initializer kernel initializer & recurrent_initializer in keras
-  // for now, it is set same way.
+  // weight_hh initializer kernel initializer & recurrent_initializer in
+  // keras for now, it is set same way.
 
-  // weight_ih ( input to hidden ) : [ 1, 1, feature_size, NUM_GATE * unit ] ->
-  // i, f, g, o
+  // weight_ih ( input to hidden ) : [ 1, 1, feature_size, NUM_GATE * unit ]
+  // -> i, f, g, o
   const TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
   wt_idx[LSTMParams::weight_ih] =
     context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
                           weight_regularizer_constant, "weight_ih", true);
-  // weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ] -> i, f,
-  // g, o
+  // weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ] -> i,
+  // f, g, o
   const TensorDim weight_hh_dim({unit, NUM_GATE * unit});
   wt_idx[LSTMParams::weight_hh] =
     context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
                           weight_regularizer_constant, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
-      // bias_h ( input bias, hidden bias are integrate to 1 bias ) : [ 1, 1, 1,
-      // NUM_GATE * unit ] -> i, f, g, o
+      // bias_h ( input bias, hidden bias are integrate to 1 bias ) : [ 1,
+      // 1, 1, NUM_GATE * unit ] -> i, f, g, o
       const TensorDim bias_h_dim({NUM_GATE * unit});
       wt_idx[LSTMParams::bias_h] =
         context.requestWeight(bias_h_dim, bias_initializer,
@@ -129,7 +221,8 @@ void LSTMLayer::finalize(InitLayerContext &context) {
       wt_idx[LSTMParams::bias_ih] =
         context.requestWeight(bias_ih_dim, bias_initializer,
                               WeightRegularizer::NONE, 1.0f, "bias_ih", true);
-      // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
+      // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g,
+      // o
       const TensorDim bias_hh_dim({NUM_GATE * unit});
       wt_idx[LSTMParams::bias_hh] =
         context.requestWeight(bias_hh_dim, bias_initializer,
@@ -154,6 +247,67 @@ void LSTMLayer::finalize(InitLayerContext &context) {
     context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
                           TensorLifespan::ITERATION_LIFESPAN);
 
+  if (bidirectional) {
+    // weight_initializer can be set separately: weight_ih initializer and
+    // weight_hh initializer correspond to kernel_initializer &
+    // recurrent_initializer in keras. for now, both are set the same way.
+
+    // reverse_weight_ih ( input to hidden ) : [ 1, 1, feature_size,
+    // NUM_GATE * unit ] -> i, f, g, o
+    const TensorDim reverse_weight_ih_dim({feature_size, NUM_GATE * unit});
+    wt_idx[LSTMParams::reverse_weight_ih] = context.requestWeight(
+      reverse_weight_ih_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, "reverse_weight_ih", true);
+    // reverse_weight_hh ( hidden to hidden ) :
+    // [ 1, 1, unit, NUM_GATE * unit ] -> i, f, g, o
+    const TensorDim reverse_weight_hh_dim({unit, NUM_GATE * unit});
+    wt_idx[LSTMParams::reverse_weight_hh] = context.requestWeight(
+      reverse_weight_hh_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, "reverse_weight_hh", true);
+    if (!disable_bias) {
+      if (integrate_bias) {
+        // reverse_bias_h ( input bias and hidden bias integrated into one
+        // bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
+        const TensorDim reverse_bias_h_dim({NUM_GATE * unit});
+        wt_idx[LSTMParams::reverse_bias_h] = context.requestWeight(
+          reverse_bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+          "reverse_bias_h", true);
+      } else {
+        // reverse_bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
+        // i, f, g, o
+        const TensorDim reverse_bias_ih_dim({NUM_GATE * unit});
+        wt_idx[LSTMParams::reverse_bias_ih] = context.requestWeight(
+          reverse_bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+          "reverse_bias_ih", true);
+        // reverse_bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
+        // i, f, g, o
+        const TensorDim reverse_bias_hh_dim({NUM_GATE * unit});
+        wt_idx[LSTMParams::reverse_bias_hh] = context.requestWeight(
+          reverse_bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+          "reverse_bias_hh", true);
+      }
+    }
+
+    // reverse_hidden_state_dim : [ batch_size, 1, max_timestep, unit ]
+    const TensorDim reverse_hidden_state_dim(batch_size, 1, max_timestep, unit);
+    wt_idx[LSTMParams::reverse_hidden_state] = context.requestTensor(
+      reverse_hidden_state_dim, "reverse_hidden_state",
+      Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN);
+    // reverse_cell_state_dim : [ batch_size, 1, max_timestep, unit ]
+    const TensorDim reverse_cell_state_dim(batch_size, 1, max_timestep, unit);
+    wt_idx[LSTMParams::reverse_cell_state] = context.requestTensor(
+      reverse_cell_state_dim, "reverse_cell_state", Tensor::Initializer::NONE,
+      true, TensorLifespan::ITERATION_LIFESPAN);
+
+    // reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
+    const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep,
+                                     NUM_GATE * unit);
+    wt_idx[LSTMParams::reverse_ifgo] = context.requestTensor(
+      reverse_ifgo_dim, "reverse_ifgo", Tensor::Initializer::NONE, true,
+      TensorLifespan::ITERATION_LIFESPAN);
+  }
+
   if (dropout_rate > epsilon) {
     // dropout_mask_dim = [ batch, 1, time_iteration, unit ]
     const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit);
@@ -186,12 +340,16 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
   const bool integrate_bias = std::get<props::IntegrateBias>(lstm_props).get();
   const bool return_sequences =
     std::get<props::ReturnSequences>(lstm_props).get();
+  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(lstm_props).get();
 
-  const Tensor &inputs = context.getInput(SINGLE_INOUT_IDX);
-  const TensorDim input_dim = inputs.getDim();
+  unsigned int bidirectional_constant = bidirectional ? 2 : 1;
+  bool enable_dropout = dropout_rate > epsilon && training;
+
+  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+  const TensorDim input_dim = input.getDim();
   const unsigned int batch_size = input_dim.batch();
   const unsigned int feature_size = input_dim.width();
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
@@ -209,70 +367,81 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
                             ? context.getWeight(wt_idx[LSTMParams::bias_hh])
                             : empty;
 
-  Tensor &hs = context.getTensor(wt_idx[LSTMParams::hidden_state]);
-  Tensor &cs = context.getTensor(wt_idx[LSTMParams::cell_state]);
-  Tensor &ifgos = context.getTensor(wt_idx[LSTMParams::ifgo]);
-
-  hs.setZero();
-  cs.setZero();
-
-  for (unsigned int batch = 0; batch < batch_size; ++batch) {
-    const Tensor input_batch = inputs.getBatchSlice(batch, 1);
-    Tensor hs_batch = hs.getBatchSlice(batch, 1);
-    Tensor cs_batch = cs.getBatchSlice(batch, 1);
-    Tensor ifgo_batch = ifgos.getBatchSlice(batch, 1);
-
-    for (unsigned int t = 0; t < max_timestep; ++t) {
-      Tensor input;
-      if (input_batch.height() != 1)
-        input =
-          input_batch.getSharedDataTensor({feature_size}, t * feature_size);
-      else
-        input = input_batch;
-
-      Tensor prev_hidden_state;
-      if (!t) {
-        prev_hidden_state = Tensor(unit);
-        prev_hidden_state.setZero();
-      } else {
-        prev_hidden_state =
-          hs_batch.getSharedDataTensor({unit}, (t - 1) * unit);
-      }
-      Tensor hidden_state = hs_batch.getSharedDataTensor({unit}, t * unit);
-      Tensor prev_cell_state;
-      if (!t) {
-        prev_cell_state = Tensor(unit);
-        prev_cell_state.setZero();
-      } else {
-        prev_cell_state = cs_batch.getSharedDataTensor({unit}, (t - 1) * unit);
-      }
-      Tensor cell_state = cs_batch.getSharedDataTensor({unit}, t * unit);
-      Tensor ifgo =
-        ifgo_batch.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
-
-      lstmcell_forwarding(unit, 1, disable_bias, integrate_bias, acti_func,
-                          recurrent_acti_func, input, prev_hidden_state,
-                          prev_cell_state, hidden_state, cell_state, weight_ih,
-                          weight_hh, bias_h, bias_ih, bias_hh, ifgo);
-
-      if (dropout_rate > epsilon && training) {
-        Tensor masks = context.getTensor(wt_idx[LSTMParams::dropout_mask])
-                         .getBatchSlice(batch, 1);
-        Tensor mask = masks.getSharedDataTensor({unit}, t * unit);
-        mask.dropout_mask(dropout_rate);
-        hidden_state.multiply_i(mask);
-      }
-    }
+  Tensor &hidden_state = context.getTensor(wt_idx[LSTMParams::hidden_state]);
+  Tensor &cell_state = context.getTensor(wt_idx[LSTMParams::cell_state]);
+  Tensor &ifgo = context.getTensor(wt_idx[LSTMParams::ifgo]);
+
+  Tensor &mask = enable_dropout
+                   ? context.getTensor(wt_idx[LSTMParams::dropout_mask])
+                   : empty;
+
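+  // forward direction pass (reverse = false)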
+  batch_first_forwarding(
+    NUM_GATE, unit, batch_size, max_timestep, feature_size, disable_bias,
+    integrate_bias, acti_func, recurrent_acti_func, false, enable_dropout,
+    dropout_rate, input, weight_ih, weight_hh, bias_h, bias_ih, bias_hh,
+    hidden_state, cell_state, ifgo, mask);
+
+  if (bidirectional) {
+    const Tensor &reverse_weight_ih =
+      context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]);
+    const Tensor &reverse_weight_hh =
+      context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]);
+    const Tensor &reverse_bias_h =
+      !disable_bias && integrate_bias
+        ? context.getWeight(wt_idx[LSTMParams::reverse_bias_h])
+        : empty;
+    const Tensor &reverse_bias_ih =
+      !disable_bias && !integrate_bias
+        ? context.getWeight(wt_idx[LSTMParams::reverse_bias_ih])
+        : empty;
+    const Tensor &reverse_bias_hh =
+      !disable_bias && !integrate_bias
+        ? context.getWeight(wt_idx[LSTMParams::reverse_bias_hh])
+        : empty;
+
+    Tensor &reverse_hidden_state =
+      context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
+    Tensor &reverse_cell_state =
+      context.getTensor(wt_idx[LSTMParams::reverse_cell_state]);
+    Tensor &reverse_ifgo = context.getTensor(wt_idx[LSTMParams::reverse_ifgo]);
+
+    batch_first_forwarding(
+      NUM_GATE, unit, batch_size, max_timestep, feature_size, disable_bias,
+      integrate_bias, acti_func, recurrent_acti_func, true, enable_dropout,
+      dropout_rate, input, reverse_weight_ih, reverse_weight_hh, reverse_bias_h,
+      reverse_bias_ih, reverse_bias_hh, reverse_hidden_state,
+      reverse_cell_state, reverse_ifgo, mask);
   }
 
-  if (return_sequences) {
-    std::copy(hs.getData(), hs.getData() + hs.size(), output.getData());
+  if (return_sequences && !bidirectional) {
+    std::copy(hidden_state.getData(),
+              hidden_state.getData() + hidden_state.size(), output.getData());
   } else {
+    unsigned int start_timestep = 0;
+    unsigned int end_timestep = return_sequences ? max_timestep : 1;
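+    // for bidirectional lstm, each timestep of the output holds the forward
+    // hidden state followed by the reverse hidden state, i.e. one output row
+    // has bidirectional_constant * unit elements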
     for (unsigned int batch = 0; batch < batch_size; ++batch) {
-      float *hidden_state_data =
-        hs.getAddress(batch * max_timestep * unit + (max_timestep - 1) * unit);
-      float *output_data = output.getAddress(batch * unit);
-      std::copy(hidden_state_data, hidden_state_data + unit, output_data);
+      for (unsigned int timestep = start_timestep; timestep < end_timestep;
+           ++timestep) {
+        float *hidden_state_data = hidden_state.getAddress(
+          batch * max_timestep * unit +
+          (return_sequences ? 0 : (max_timestep - 1) * unit) + timestep * unit);
+        float *output_data =
+          output.getAddress(batch * (return_sequences ? max_timestep : 1) *
+                              bidirectional_constant * unit +
+                            timestep * bidirectional_constant * unit);
+        std::copy(hidden_state_data, hidden_state_data + unit, output_data);
+
+        if (bidirectional) {
+          Tensor &reverse_hidden_state =
+            context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
+          float *reverse_hidden_state_data = reverse_hidden_state.getAddress(
+            batch * max_timestep * unit +
+            (return_sequences ? 0 : (max_timestep - 1) * unit) +
+            timestep * unit);
+          std::copy(reverse_hidden_state_data, reverse_hidden_state_data + unit,
+                    output_data + unit);
+        }
+      }
     }
   }
 }
@@ -354,8 +523,10 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
 
       Tensor rdata =
         incoming_derivative.getSharedDataTensor({unit}, batch * unit);
-      /// @note this is not copying from start ~ end but only start time step
-      // This is copying for self rolling as well as last recurrent unrolled.
+      /// @note this is not copying from start ~ end but only start time
+      /// step
+      // This is copying for self rolling as well as last recurrent
+      // unrolled.
       if ((unsigned)start_timestep + 1 == max_timestep) {
         data.fill(rdata);
       } else {
@@ -427,8 +598,8 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
       Tensor d_ifgo = d_ifgo_batch.getSharedDataTensor({unit * NUM_GATE},
                                                        unit * t * NUM_GATE);
 
-      // Temporary variable for d_prev_hidden_state. d_prev_hidden_state already
-      // have precalculated values from incomming derivatives
+      // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
+      // already has precalculated values from incoming derivatives
       Tensor d_prev_hidden_state_temp;
 
       lstmcell_calcGradient(unit, 1, disable_bias, integrate_bias, acti_func,
@@ -443,10 +614,18 @@ void LSTMLayer::calcGradient(RunLayerContext &context) {
 }
 
 void LSTMLayer::setBatch(RunLayerContext &context, unsigned int batch) {
+  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
+
   context.updateTensor(wt_idx[LSTMParams::hidden_state], batch);
   context.updateTensor(wt_idx[LSTMParams::cell_state], batch);
   context.updateTensor(wt_idx[LSTMParams::ifgo], batch);
 
+  if (bidirectional) {
+    context.updateTensor(wt_idx[LSTMParams::reverse_hidden_state], batch);
+    context.updateTensor(wt_idx[LSTMParams::reverse_cell_state], batch);
+    context.updateTensor(wt_idx[LSTMParams::reverse_ifgo], batch);
+  }
+
   if (std::get<props::DropOutRate>(lstm_props).get() > epsilon) {
     context.updateTensor(wt_idx[LSTMParams::dropout_mask], batch);
   }
index 804ba0e02a386421f5b0257c5bcffa76cb892d3b..e68369d5b10af6c63c96458287c2b5cee6897005 100644 (file)
@@ -106,15 +106,16 @@ private:
    * HiddenStateActivation: activation type for hidden state. default is tanh
    * RecurrentActivation: activation type for recurrent. default is sigmoid
    * ReturnSequence: option for return sequence
+   * Bidirectional: option for bidirectional
    * DropOutRate: dropout rate
    * MaxTimestep: maximum timestep for lstm
    *
    * */
   std::tuple<props::Unit, props::IntegrateBias, props::HiddenStateActivation,
              props::RecurrentActivation, props::ReturnSequences,
-             props::DropOutRate, props::MaxTimestep>
+             props::Bidirectional, props::DropOutRate, props::MaxTimestep>
     lstm_props;
-  std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
+  std::array<unsigned int, 15> wt_idx; /**< indices of the weights */
 
   /**
    * @brief     activation function for h_t : default is tanh
index 6197a19db7792557e398626517cc4d56efa7a1f1..f704d8c0caad083f0db74121655533699125a844 100644 (file)
@@ -56,15 +56,17 @@ class RNNCellStacked(torch.nn.Module):
         return ret, loss
 
 class LSTMStacked(torch.nn.Module):
-    def __init__(self, num_lstm=1):
+    def __init__(self, num_lstm=1, bidirectional=False):
         super().__init__()
         self.input_size = self.hidden_size = 2
         self.num_lstm = num_lstm
+        self.bidirectional = bidirectional
         self.lstms = torch.nn.ModuleList(
             [
-                torch.nn.LSTM(self.input_size, self.hidden_size, batch_first=True)
-                # torch.nn.LSTM(self.input_size, self.hidden_size, num_layers=num_lstm, batch_first=True)
-                for _ in range(num_lstm)
+                torch.nn.LSTM(self.input_size if not self.bidirectional or i == 0 else 2 * self.input_size, self.hidden_size, batch_first=True, bidirectional=bidirectional)
+                # Intended comment
+                # torch.nn.LSTM(self.input_size if self.bidirectional == False or i == 0 else 2 * self.input_size, self.hidden_size, num_layers=num_lstm, batch_first=True, bidirectional=bidirectional)
+                for i in range(num_lstm)
             ]
         )
         self.loss = torch.nn.MSELoss()
@@ -73,12 +75,12 @@ class LSTMStacked(torch.nn.Module):
         out = inputs[0]
         states = inputs[1:]
         # hs = [states[2 * i] for i in range(self.num_lstm)]
-        hs = [torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
+        hs = [torch.zeros((2, 3, 2)) if self.bidirectional else torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
         # cs = [states[2 * i + 1] for i in range(self.num_lstm)]
-        cs = [torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
+        cs = [torch.zeros((2, 3, 2)) if self.bidirectional else torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
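+        # torch stacks the directions along dim 0 of the initial states, so a
+        # bidirectional lstm expects h0/c0 of shape
+        # (num_layers * num_directions, batch, hidden_size)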
         for i, (lstm, h, c) in enumerate(zip(self.lstms, hs, cs)):
             out, (hs[i], cs[i]) = lstm(out, (h, c))
 
         loss = self.loss(out, labels[0])
         return out, loss
 
@@ -212,9 +214,9 @@ if __name__ == "__main__":
         name="rnncell_stacked",
     )
 
-    unroll_for, num_lstm, batch_size, unit, feature_size, iteration = [2, 1, 3, 2, 2, 2]
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 1, 3, 2, 2, 2, False]
     record_v2(
-        LSTMStacked(num_lstm=num_lstm),
+        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
         iteration=iteration,
         input_dims=[(batch_size, unroll_for, feature_size)],
         # input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
@@ -222,9 +224,9 @@ if __name__ == "__main__":
         name="lstm_single",
     )
 
-    unroll_for, num_lstm, batch_size, unit, feature_size, iteration = [2, 2, 3, 2, 2, 2]
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 2, 3, 2, 2, 2, False]
     record_v2(
-        LSTMStacked(num_lstm=num_lstm),
+        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
         iteration=iteration,
         input_dims=[(batch_size, unroll_for, feature_size)],
         # input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
@@ -232,6 +234,26 @@ if __name__ == "__main__":
         name="lstm_stacked",
     )
 
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 1, 3, 2, 2, 2, True]
+    record_v2(
+        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
+        iteration=iteration,
+        input_dims=[(batch_size, unroll_for, feature_size)],
+        # input_dims=[(batch_size, unroll_for, feature_size)] + [(2, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims=[(batch_size, unroll_for, 2 * unit)],
+        name="bidirectional_lstm_single",
+    )
+
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 2, 3, 2, 2, 2, True]
+    record_v2(
+        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
+        iteration=iteration,
+        input_dims=[(batch_size, unroll_for, feature_size)],
+        # input_dims=[(batch_size, unroll_for, feature_size)] + [(2, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims=[(batch_size, unroll_for, 2 * unit)],
+        name="bidirectional_lstm_stacked",
+    )
+
     unroll_for, num_lstmcell, state_num, batch_size, unit, feature_size, iteration = [2, 1, 2, 3, 2, 2, 2]
     record_v2(
         LSTMCellStacked(unroll_for=unroll_for, num_lstmcell=num_lstmcell),
index 9373d6739d8bd848f941b85e71229f34e1a9bb8c..ca0b621500bc6f262f0a27fa81bb6ed4c30d84a4 100644 (file)
@@ -70,7 +70,21 @@ def zoneout_translate(model):
     new_params = [transpose_(params[0]), transpose_(params[1]), params[2], params[3], hidden_state, cell_state]
     yield from new_params
 
-@register_for_((torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell))
+@register_for_((torch.nn.LSTM))
+def lstm_translate(model):
+    params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
+    # [hidden, input] -> [input, hidden]
+    def transpose_(weight):
+        return (weight[0], weight[1].transpose(1, 0))
+
+    new_params = [transpose_(params[0]), transpose_(params[1]), params[2], params[3]]
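+    # for a bidirectional lstm, torch lists the reverse-direction parameters
+    # (weight_ih_l0_reverse, weight_hh_l0_reverse, bias_ih_l0_reverse,
+    # bias_hh_l0_reverse) right after the forward-direction ones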
+    if model.bidirectional:
+        reverse_params = [transpose_(params[4]), transpose_(params[5]), params[6], params[7]]
+        new_params += reverse_params
+
+    yield from new_params
+
+@register_for_((torch.nn.RNNCell, torch.nn.LSTMCell))
 def rnn_lstm_translate(model):
     params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
     # [hidden, input] -> [input, hidden]
index 9aa336445cbcdf0c19839b575a2f1774ca516eb5..ae812f952c7e2715f072c1676f117f45a222f19f 100644 (file)
@@ -175,6 +175,47 @@ static std::unique_ptr<NeuralNetwork> makeStackedLSTM() {
   return nn;
 }
 
+// static std::unique_ptr<NeuralNetwork> makeSingleBidirectionalLSTM() {
+//   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+//   nn->setProperty({"batch_size=3"});
+
+//   auto outer_graph = makeGraph({
+//     {"input", {"name=input", "input_shape=1:2:2"}},
+//     {"lstm",
+//      {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true",
+//       "bidirectional=true"}},
+//     {"mse", {"name=loss", "input_layers=a1"}},
+//   });
+//   for (auto &node : outer_graph) {
+//     nn->addLayer(node);
+//   }
+
+//   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate =
+//   0.1"})); return nn;
+// }
+
+// static std::unique_ptr<NeuralNetwork> makeStackedBidirectionalLSTM() {
+//   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+//   nn->setProperty({"batch_size=3"});
+
+//   auto outer_graph = makeGraph({
+//     {"input", {"name=input", "input_shape=1:2:2"}},
+//     {"lstm",
+//      {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true",
+//       "bidirectional=true"}},
+//     {"lstm",
+//      {"name=a2", "unit=2", "integrate_bias=false", "return_sequences=true",
+//       "bidirectional=true"}},
+//     {"mse", {"name=loss"}},
+//   });
+//   for (auto &node : outer_graph) {
+//     nn->addLayer(node);
+//   }
+
+//   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate =
+//   0.1"})); return nn;
+// }
+
 static std::unique_ptr<NeuralNetwork> makeSingleLSTMCell() {
   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
   nn->setProperty({"batch_size=3"});
@@ -526,6 +567,10 @@ INSTANTIATE_TEST_CASE_P(
                  ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeSingleLSTM, "lstm_single", ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::ALL_V2),
+    // mkModelTc_V2(makeSingleBidirectionalLSTM, "bidirectional_lstm_single",
+    //              ModelTestOption::ALL_V2),
+    // mkModelTc_V2(makeStackedBidirectionalLSTM, "bidirectional_lstm_stacked",
+    //              ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleLSTMCell, "lstmcell_single",
                  ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedLSTMCell, "lstmcell_stacked",