[rnncell] enable 2 bias
authorhyeonseok lee <hs89.lee@samsung.com>
Wed, 8 Dec 2021 03:52:02 +0000 (12:52 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 13 Dec 2021 02:50:23 +0000 (11:50 +0900)
 - Make a integrate bias property. It decide whether integrate 2 bias to 1
   or not. It will be used in rnn variant for now.
 - Added bias_hh in rnncell

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/common_properties.h
nntrainer/layers/rnncell.cpp
nntrainer/layers/rnncell.h
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelsRecurrent_v2.py
test/input_gen/transLayer_v2.py
test/unittest/layers/unittest_layers_rnncell.cpp
test/unittest/models/unittest_models_recurrent.cpp

index cb21cf7..0f46a46 100644 (file)
@@ -122,6 +122,22 @@ public:
 };
 
 /**
+ * @brief Integrate bias_ih and bias_hh to bias_h to use only 1 bias (Used in
+ * rnn variant)
+ *
+ */
+class IntegrateBias : public nntrainer::Property<bool> {
+public:
+  /**
+   * @brief Construct a IntegrateBias object
+   *
+   */
+  IntegrateBias(bool val = false) : nntrainer::Property<bool>(val) {}
+  using prop_tag = bool_prop_tag;
+  static constexpr const char *key = "integrate_bias";
+};
+
+/**
  * @brief Normalization property, normalize the input to be in range [0, 1] if
  * true
  *
index df3d392..a4c4562 100644 (file)
@@ -25,15 +25,27 @@ namespace nntrainer {
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
-// - weight_xh ( weights of input to hidden )
+// - weight_ih ( weights of input to hidden )
 // - weight_hh ( weights of hidden to hidden )
-// - bias_h ( hidden bias )
-enum RNNCellParams { weight_xh, weight_hh, bias_h, hidden_state, dropout_mask };
+// - bias_h ( input bias, hidden_bias )
+// - bias_ih ( input bias )
+// - bias_hh ( hidden bias )
+enum RNNCellParams {
+  weight_ih,
+  weight_hh,
+  bias_h,
+  bias_ih,
+  bias_hh,
+  hidden_state,
+  dropout_mask
+};
 
 RNNCellLayer::RNNCellLayer() :
   LayerImpl(),
-  rnncell_props(props::Unit(), props::HiddenStateActivation(),
-                props::DropOutRate(), props::MaxTimestep(), props::Timestep()),
+  rnncell_props(props::Unit(),
+                props::HiddenStateActivation() = ActivationType::ACT_TANH,
+                props::DropOutRate(), props::IntegrateBias(),
+                props::MaxTimestep(), props::Timestep()),
   acti_func(ActivationType::ACT_NONE, true),
   epsilon(1e-3) {
   wt_idx.fill(std::numeric_limits<unsigned>::max());
@@ -48,69 +60,90 @@ void RNNCellLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightInitializer>(*layer_impl_props);
   const Tensor::Initializer bias_initializer =
     std::get<props::BiasInitializer>(*layer_impl_props);
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
 
   const unsigned int unit = std::get<props::Unit>(rnncell_props).get();
-  nntrainer::props::HiddenStateActivation hidden_state_activation_type =
-    std::get<props::HiddenStateActivation>(rnncell_props);
-  const float dropout_rate = std::get<props::DropOutRate>(rnncell_props);
-  const unsigned int max_timestep = std::get<props::MaxTimestep>(rnncell_props);
+  const nntrainer::ActivationType hidden_state_activation_type =
+    std::get<props::HiddenStateActivation>(rnncell_props).get();
+  const float dropout_rate = std::get<props::DropOutRate>(rnncell_props).get();
+  const bool integrate_bias =
+    std::get<props::IntegrateBias>(rnncell_props).get();
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(rnncell_props).get();
 
   if (context.getNumInputs() != 1) {
     throw std::invalid_argument("RNNCell layer takes only one input");
   }
 
   // input_dim = [ batch, 1, 1, feature_size ]
-  const TensorDim &input_dim = context.getInputDimensions()[0];
-  if (input_dim.channel() != 1 && input_dim.height() != 1) {
+  const TensorDim &input_dim = context.getInputDimensions()[SINGLE_INOUT_IDX];
+  if (input_dim.channel() != 1 || input_dim.height() != 1) {
     throw std::invalid_argument(
       "Input must be single time dimension for RNNCell");
   }
   const unsigned int batch_size = input_dim.batch();
   const unsigned int feature_size = input_dim.width();
 
-  // outut_dim = [ batch, 1, 1, hidden_size ( unit ) ]
+  // output_dim = [ batch, 1, 1, unit ]
   TensorDim output_dim(batch_size, 1, 1, unit);
 
-  if (dropout_rate > epsilon) {
-    wt_idx[RNNCellParams::dropout_mask] = context.requestTensor(
-      output_dim, "dropout_mask", Tensor::Initializer::NONE, false,
-      TensorLifespan::ITERATION_LIFESPAN);
-  }
-
   context.setOutputDimensions({output_dim});
 
-  // weight_xh_dim : [1, 1, input_size, unit]
-  const TensorDim weight_xh_dim({feature_size, unit});
-  // weight_hh_dim : [1, 1, unit, unit]
-  const TensorDim weight_hh_dim({unit, unit});
-  // bias_h_dim : [1, 1, 1, unit]
-  const TensorDim bias_h_dim({unit});
-
-  // weight_initializer can be set seperately. weight_xh initializer,
+  // weight_initializer can be set seperately. weight_ih initializer,
   // weight_hh initializer kernel initializer & recurrent_initializer in keras
   // for now, it is set same way.
-  wt_idx[RNNCellParams::weight_xh] =
-    context.requestWeight(weight_xh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_xh", true);
+
+  // weight_ih_dim : [ 1, 1, feature_size, unit ]
+  const TensorDim weight_ih_dim({feature_size, unit});
+  wt_idx[RNNCellParams::weight_ih] =
+    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "weight_ih", true);
+  // weight_hh_dim : [ 1, 1, unit, unit ]
+  const TensorDim weight_hh_dim({unit, unit});
   wt_idx[RNNCellParams::weight_hh] =
     context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
                           weight_regularizer_constant, "weight_hh", true);
-  wt_idx[RNNCellParams::bias_h] =
-    context.requestWeight(bias_h_dim, bias_initializer, WeightRegularizer::NONE,
-                          1.0f, "bias_h", true);
+  if (!disable_bias) {
+    if (integrate_bias) {
+      // bias_h_dim : [ 1, 1, 1, unit ]
+      const TensorDim bias_h_dim({unit});
+      wt_idx[RNNCellParams::bias_h] =
+        context.requestWeight(bias_h_dim, bias_initializer,
+                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
+    } else {
+      // bias_ih_dim : [ 1, 1, 1, unit ]
+      const TensorDim bias_ih_dim({unit});
+      wt_idx[RNNCellParams::bias_ih] =
+        context.requestWeight(bias_ih_dim, bias_initializer,
+                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+      // bias_hh_dim : [ 1, 1, 1, unit ]
+      const TensorDim bias_hh_dim({unit});
+      wt_idx[RNNCellParams::bias_hh] =
+        context.requestWeight(bias_hh_dim, bias_initializer,
+                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+    }
+  }
 
   // We do not need this if we reuse net_hidden[0]. But if we do, then the unit
   // test will fail. Becuase it modifies the data during gradient calculation
   // TODO : We could control with something like #define test to save memory
-  const TensorDim dim(batch_size * max_timestep, 1, 1, unit);
-  wt_idx[RNNCellParams::hidden_state] =
-    context.requestTensor(dim, "hidden_state", Tensor::Initializer::NONE, true,
-                          TensorLifespan::ITERATION_LIFESPAN, false);
 
-  if (hidden_state_activation_type.get() == ActivationType::ACT_NONE) {
-    hidden_state_activation_type.set(ActivationType::ACT_TANH);
+  // hidden_state_dim = [ max_timestep * batch, 1, 1, unit ]
+  const TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
+  wt_idx[RNNCellParams::hidden_state] = context.requestTensor(
+    hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN, false);
+
+  if (dropout_rate > epsilon) {
+    // dropout_mask_dim = [ max_timestep * batch, 1, 1, unit ]
+    const TensorDim dropout_mask_dim(max_timestep * batch_size, 1, 1, unit);
+    wt_idx[RNNCellParams::dropout_mask] = context.requestTensor(
+      dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false,
+      TensorLifespan::ITERATION_LIFESPAN);
   }
-  acti_func.setActiFunc(hidden_state_activation_type.get());
+
+  acti_func.setActiFunc(hidden_state_activation_type);
 
   if (!acti_func.supportInPlace()) {
     throw exception::not_supported(
@@ -119,7 +152,8 @@ void RNNCellLayer::finalize(InitLayerContext &context) {
 }
 
 void RNNCellLayer::setProperty(const std::vector<std::string> &values) {
-  auto remain_props = loadProperties(values, rnncell_props);
+  const std::vector<std::string> &remain_props =
+    loadProperties(values, rnncell_props);
   LayerImpl::setProperty(remain_props);
 }
 
@@ -130,131 +164,205 @@ void RNNCellLayer::exportTo(Exporter &exporter,
 }
 
 void RNNCellLayer::forwarding(RunLayerContext &context, bool training) {
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
+
   const unsigned int unit = std::get<props::Unit>(rnncell_props).get();
-  const float dropout_rate = std::get<props::DropOutRate>(rnncell_props);
-  const unsigned int max_timestep = std::get<props::MaxTimestep>(rnncell_props);
-  const unsigned int timestep = std::get<props::Timestep>(rnncell_props);
+  const float dropout_rate = std::get<props::DropOutRate>(rnncell_props).get();
+  const bool integrate_bias =
+    std::get<props::IntegrateBias>(rnncell_props).get();
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(rnncell_props).get();
+  const unsigned int timestep = std::get<props::Timestep>(rnncell_props).get();
 
-  Tensor &weight_xh = context.getWeight(wt_idx[RNNCellParams::weight_xh]);
+  Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+  const unsigned int batch_size = input.getDim().batch();
+  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+
+  Tensor &weight_ih = context.getWeight(wt_idx[RNNCellParams::weight_ih]);
   Tensor &weight_hh = context.getWeight(wt_idx[RNNCellParams::weight_hh]);
-  Tensor &bias_h = context.getWeight(wt_idx[RNNCellParams::bias_h]);
+  Tensor empty;
+  Tensor &bias_h = !disable_bias && integrate_bias
+                     ? context.getWeight(wt_idx[RNNCellParams::bias_h])
+                     : empty;
+  Tensor &bias_ih = !disable_bias && !integrate_bias
+                      ? context.getWeight(wt_idx[RNNCellParams::bias_ih])
+                      : empty;
+  Tensor &bias_hh = !disable_bias && !integrate_bias
+                      ? context.getWeight(wt_idx[RNNCellParams::bias_hh])
+                      : empty;
 
-  Tensor &input = context.getInput(SINGLE_INOUT_IDX);
-  const TensorDim &input_dim = input.getDim();
-  const unsigned int batch_size = input_dim[0];
   Tensor &hidden_states =
     context.getTensor(wt_idx[RNNCellParams::hidden_state]);
   hidden_states.reshape({max_timestep, 1, batch_size, unit});
+  Tensor prev_hidden_state;
+  if (!timestep) {
+    prev_hidden_state = Tensor(batch_size, unit);
+    prev_hidden_state.setZero();
+  } else {
+    prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
+  }
   Tensor hidden_state = hidden_states.getBatchSlice(timestep, 1);
 
-  input.dot(weight_xh, hidden_state);
-  if (timestep) {
-    Tensor prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
-    prev_hidden_state.dot(weight_hh, hidden_state, false, false, 1.0f);
+  input.dot(weight_ih, hidden_state);
+  prev_hidden_state.dot(weight_hh, hidden_state, false, false, 1.0f);
+  if (!disable_bias) {
+    if (integrate_bias) {
+      hidden_state.add_i(bias_h);
+    } else {
+      hidden_state.add_i(bias_ih);
+      hidden_state.add_i(bias_hh);
+    }
   }
-  hidden_state.add_i(bias_h);
+
   acti_func.run_fn(hidden_state, hidden_state);
+
   if (dropout_rate > epsilon && training) {
-    Tensor &mask = context.getTensor(wt_idx[RNNCellParams::dropout_mask]);
-    mask.dropout_mask(dropout_rate);
-    hidden_state.multiply_i(mask);
+    Tensor &dropout_mask =
+      context.getTensor(wt_idx[RNNCellParams::dropout_mask]);
+    dropout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor dropout_mask_t = dropout_mask.getBatchSlice(timestep, 1);
+    dropout_mask_t.dropout_mask(dropout_rate);
+    hidden_state.multiply_i(dropout_mask_t);
   }
 
-  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
   output.copy(hidden_state);
 }
 
 void RNNCellLayer::calcDerivative(RunLayerContext &context) {
   const unsigned int unit = std::get<props::Unit>(rnncell_props).get();
-  const unsigned int max_timestep = std::get<props::MaxTimestep>(rnncell_props);
-  const unsigned int timestep = std::get<props::Timestep>(rnncell_props);
-  const TensorDim &input_dim = context.getInput(SINGLE_INOUT_IDX).getDim();
-  const unsigned int batch_size = input_dim.batch();
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(rnncell_props).get();
+  const unsigned int timestep = std::get<props::Timestep>(rnncell_props).get();
+
+  const unsigned int batch_size =
+    context.getInput(SINGLE_INOUT_IDX).getDim().batch();
+
   Tensor &hidden_states_derivatives =
     context.getTensorGrad(wt_idx[RNNCellParams::hidden_state]);
   hidden_states_derivatives.reshape({max_timestep, 1, batch_size, unit});
   Tensor hidden_state_derivative =
     hidden_states_derivatives.getBatchSlice(timestep, 1);
-
-  Tensor &weight_xh = context.getWeight(wt_idx[RNNCellParams::weight_xh]);
+  Tensor &weight_ih = context.getWeight(wt_idx[RNNCellParams::weight_ih]);
   Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
-  hidden_state_derivative.dot(weight_xh, outgoing_derivative, false, true);
+
+  hidden_state_derivative.dot(weight_ih, outgoing_derivative, false, true);
 }
 
 void RNNCellLayer::calcGradient(RunLayerContext &context) {
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
+
   const unsigned int unit = std::get<props::Unit>(rnncell_props).get();
-  const float dropout_rate = std::get<props::DropOutRate>(rnncell_props);
-  const unsigned int max_timestep = std::get<props::MaxTimestep>(rnncell_props);
-  const unsigned int timestep = std::get<props::Timestep>(rnncell_props);
+  const float dropout_rate = std::get<props::DropOutRate>(rnncell_props).get();
+  const bool integrate_bias =
+    std::get<props::IntegrateBias>(rnncell_props).get();
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(rnncell_props).get();
+  const unsigned int timestep = std::get<props::Timestep>(rnncell_props).get();
 
   Tensor &input = context.getInput(SINGLE_INOUT_IDX);
   Tensor &incoming_derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
-  const TensorDim &input_dim = input.getDim();
-  const unsigned int batch_size = input_dim.batch();
+  const unsigned int batch_size = input.getDim().batch();
 
-  Tensor &djdweight_xh =
-    context.getWeightGrad(wt_idx[RNNCellParams::weight_xh]);
+  Tensor &djdweight_ih =
+    context.getWeightGrad(wt_idx[RNNCellParams::weight_ih]);
+  Tensor &weight_hh = context.getWeight(wt_idx[RNNCellParams::weight_hh]);
   Tensor &djdweight_hh =
     context.getWeightGrad(wt_idx[RNNCellParams::weight_hh]);
-  Tensor &djdbias_h = context.getWeightGrad(wt_idx[RNNCellParams::bias_h]);
-  Tensor &weight_hh = context.getWeight(wt_idx[RNNCellParams::weight_hh]);
+  Tensor empty;
+  Tensor &djdbias_h = !disable_bias && integrate_bias
+                        ? context.getWeightGrad(wt_idx[RNNCellParams::bias_h])
+                        : empty;
+  Tensor &djdbias_ih = !disable_bias && !integrate_bias
+                         ? context.getWeightGrad(wt_idx[RNNCellParams::bias_ih])
+                         : empty;
+  Tensor &djdbias_hh = !disable_bias && !integrate_bias
+                         ? context.getWeightGrad(wt_idx[RNNCellParams::bias_hh])
+                         : empty;
 
   Tensor &hidden_states =
     context.getTensor(wt_idx[RNNCellParams::hidden_state]);
   hidden_states.reshape({max_timestep, 1, batch_size, unit});
+  Tensor prev_hidden_state;
+  if (!timestep) {
+    prev_hidden_state = Tensor(batch_size, unit);
+    prev_hidden_state.setZero();
+  } else {
+    prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
+  }
   Tensor hidden_state = hidden_states.getBatchSlice(timestep, 1);
   Tensor &hidden_states_derivatives =
     context.getTensorGrad(wt_idx[RNNCellParams::hidden_state]);
   hidden_states_derivatives.reshape({max_timestep, 1, batch_size, unit});
+  Tensor prev_hidden_state_derivative;
+  if (!timestep) {
+    prev_hidden_state_derivative = Tensor(batch_size, unit);
+  } else {
+    prev_hidden_state_derivative =
+      hidden_states_derivatives.getBatchSlice(timestep - 1, 1);
+  }
   Tensor hidden_state_derivative =
     hidden_states_derivatives.getBatchSlice(timestep, 1);
 
   if (timestep + 1 == max_timestep) {
-    djdweight_xh.setZero();
+    djdweight_ih.setZero();
     djdweight_hh.setZero();
-    djdbias_h.setZero();
+    if (!disable_bias) {
+      if (integrate_bias) {
+        djdbias_h.setZero();
+      } else {
+        djdbias_ih.setZero();
+        djdbias_hh.setZero();
+      }
+    }
+    hidden_state_derivative.setZero();
   }
 
-  hidden_state_derivative.reshape(incoming_derivative.getDim());
-  if (timestep + 1 == max_timestep) {
-    hidden_state_derivative.copyData(incoming_derivative);
-  } else {
-    hidden_state_derivative.add_i(incoming_derivative);
-  }
-  // restore the dimension
-  hidden_state_derivative.reshape({1, 1, batch_size, unit});
+  hidden_state_derivative.reshape(
+    {batch_size, 1, 1, unit}); // reshape to incoming_derivative dim
+  hidden_state_derivative.add_i(incoming_derivative);
+  hidden_state_derivative.reshape({batch_size, unit}); // restore dimension
 
   if (dropout_rate > epsilon) {
-    hidden_state_derivative.multiply_i(
-      context.getTensor(wt_idx[RNNCellParams::dropout_mask]));
+    Tensor &dropout_mask =
+      context.getTensor(wt_idx[RNNCellParams::dropout_mask]);
+    dropout_mask.reshape({max_timestep, 1, batch_size, unit});
+    Tensor dropout_mask_t = dropout_mask.getBatchSlice(timestep, 1);
+    hidden_state_derivative.multiply_i(dropout_mask_t);
   }
 
   acti_func.run_prime_fn(hidden_state, hidden_state_derivative,
                          hidden_state_derivative);
 
-  input.dot(hidden_state_derivative, djdweight_xh, true, false, 1.0);
-  hidden_state_derivative.sum(2, djdbias_h, 1.0, 1.0);
-
-  if (timestep) {
-    Tensor prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
-    prev_hidden_state.dot(hidden_state_derivative, djdweight_hh, true, false,
-                          1.0);
-    Tensor prev_hidden_state_derivative =
-      hidden_states_derivatives.getBatchSlice(timestep - 1, 1);
-    hidden_state_derivative.dot(weight_hh, prev_hidden_state_derivative, false,
-                                true);
+  input.dot(hidden_state_derivative, djdweight_ih, true, false, 1.0);
+
+  hidden_state_derivative.dot(weight_hh, prev_hidden_state_derivative, false,
+                              true);
+  prev_hidden_state.dot(hidden_state_derivative, djdweight_hh, true, false,
+                        1.0);
+  if (!disable_bias) {
+    if (integrate_bias) {
+      hidden_state_derivative.sum(2, djdbias_h, 1.0, 1.0);
+    } else {
+      hidden_state_derivative.sum(2, djdbias_ih, 1.0, 1.0);
+      hidden_state_derivative.sum(2, djdbias_hh, 1.0, 1.0);
+    }
   }
 }
 
 void RNNCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
-  const unsigned int max_timestep = std::get<props::MaxTimestep>(rnncell_props);
+  const float dropout_rate = std::get<props::DropOutRate>(rnncell_props).get();
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(rnncell_props).get();
+
   context.updateTensor(wt_idx[RNNCellParams::hidden_state],
-                       batch * max_timestep);
+                       max_timestep * batch);
 
-  const float dropout_rate = std::get<props::DropOutRate>(rnncell_props);
   if (dropout_rate > epsilon) {
     /// @note default value of wt_idx[dropout_mask] is 0
-    context.updateTensor(wt_idx[RNNCellParams::dropout_mask], batch);
+    context.updateTensor(wt_idx[RNNCellParams::dropout_mask],
+                         max_timestep * batch);
   }
 }
 
index 2777c29..d2a3b2b 100644 (file)
@@ -102,14 +102,15 @@ private:
    * Unit: number of output neurons
    * HiddenStateActivation: activation type for hidden state. default is tanh
    * DropOutRate: dropout rate
+   * IntegrateBias: Integrate bias_ih, bias_hh to bias_h
    * MaxTimestep: maximum timestep for rnncell
    * TimeStep: timestep for which rnncell should operate
    *
    * */
   std::tuple<props::Unit, props::HiddenStateActivation, props::DropOutRate,
-             props::MaxTimestep, props::Timestep>
+             props::IntegrateBias, props::MaxTimestep, props::Timestep>
     rnncell_props;
-  std::array<unsigned int, 5> wt_idx; /**< indices of the weights */
+  std::array<unsigned int, 7> wt_idx; /**< indices of the weights */
 
   /**
    * @brief     activation function for h_t : default is tanh
index 21f49e1..358d59d 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index adb98ec..55b1701 100644 (file)
@@ -38,9 +38,6 @@ class RNNCellStacked(torch.nn.Module):
                 for _ in range(num_rnn)
             ]
         )
-        for rnn in self.rnns:
-            rnn.bias_hh.data.fill_(0.0)
-            rnn.bias_hh.requires_grad=False
         self.unroll_for = unroll_for
         self.loss = torch.nn.MSELoss()
 
index 6f79935..fd02108 100644 (file)
@@ -71,7 +71,17 @@ def zoneout_translate(model):
     new_params = [transpose_(params[0]), transpose_(params[1]), bias, hidden_state, cell_state]
     yield from new_params
 
-@register_for_((torch.nn.RNNCell, torch.nn.LSTMCell, torch.nn.LSTM))
+@register_for_((torch.nn.RNNCell))
+def rnn_lstm_translate(model):
+    params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
+    # [hidden, input] -> [input, hidden]
+    def transpose_(weight):
+        return (weight[0], weight[1].transpose(1, 0))
+
+    new_params = [transpose_(params[0]), transpose_(params[1]), params[2], params[3]]
+    yield from new_params
+
+@register_for_((torch.nn.LSTMCell, torch.nn.LSTM))
 def rnn_lstm_translate(model):
     params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
     bias = ("bias", params[2][1] + params[3][1])
index cf8594e..13bfebe 100644 (file)
@@ -26,7 +26,7 @@ INSTANTIATE_TEST_CASE_P(RNNCell, LayerSemantics,
 
 auto rnncell_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::RNNCellLayer>,
-  {"unit=5", "timestep=0", "max_timestep=1"}, "3:1:1:7",
+  {"unit=5", "integrate_bias=true", "timestep=0", "max_timestep=1"}, "3:1:1:7",
   "rnn_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
 
 INSTANTIATE_TEST_CASE_P(RNNCell, LayerGoldenTest,
index ba32e95..ba8bfe9 100644 (file)
@@ -342,7 +342,7 @@ static std::unique_ptr<NeuralNetwork> makeSingleRNNCell() {
   }
 
   auto rnncell = makeGraph({
-    {"rnncell", {"name=a1", "unit=2"}},
+    {"rnncell", {"name=a1", "unit=2", "integrate_bias=false"}},
   });
 
   nn->addWithReferenceLayers(rnncell, "rnncell_scope", {"input"}, {"a1"},
@@ -372,8 +372,9 @@ static std::unique_ptr<NeuralNetwork> makeStackedRNNCell() {
   }
 
   auto rnncell = makeGraph({
-    {"rnncell", {"name=a1", "unit=2"}},
-    {"rnncell", {"name=a2", "unit=2", "input_layers=a1"}},
+    {"rnncell", {"name=a1", "unit=2", "integrate_bias=false"}},
+    {"rnncell",
+     {"name=a2", "unit=2", "integrate_bias=false", "input_layers=a1"}},
   });
 
   nn->addWithReferenceLayers(rnncell, "rnncell_scope", {"input"}, {"a1"},