[grucell] enable 2 bias
author: hyeonseok lee <hs89.lee@samsung.com>
Thu, 9 Dec 2021 16:18:57 +0000 (01:18 +0900)
committer: Jijoong Moon <jijoong.moon@samsung.com>
Mon, 13 Dec 2021 02:50:23 +0000 (11:50 +0900)
 - Enable bias_hh in grucell

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/grucell.cpp
nntrainer/layers/grucell.h
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelsRecurrent_v2.py
test/input_gen/transLayer_v2.py
test/unittest/layers/unittest_layers_grucell.cpp
test/unittest/models/unittest_models_recurrent.cpp

index 069c103..1c1d5e0 100644 (file)
@@ -42,24 +42,24 @@ namespace nntrainer {
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 enum GRUCellParams {
-  weight_xh,
+  weight_ih,
   weight_hh,
   bias_h,
+  bias_ih,
+  bias_hh,
   hidden_state,
   zrg,
   dropout_mask
 };
 
-#define ENABLE_BIAS_IH 1
-// Todo: enable bias_hh
-#define ENABLE_BIAS_HH 0
-
 // Todo: handle with strided tensor more efficiently and reduce temporary
 // tensors
 GRUCellLayer::GRUCellLayer() :
   LayerImpl(),
-  grucell_props(props::Unit(), props::HiddenStateActivation(),
-                props::RecurrentActivation(), props::DropOutRate(),
+  grucell_props(props::Unit(),
+                props::HiddenStateActivation() = ActivationType::ACT_TANH,
+                props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
+                props::DropOutRate(), props::IntegrateBias(),
                 props::MaxTimestep(), props::Timestep()),
   acti_func(ActivationType::ACT_NONE, true),
   recurrent_acti_func(ActivationType::ACT_NONE, true),
@@ -67,28 +67,28 @@ GRUCellLayer::GRUCellLayer() :
   wt_idx.fill(std::numeric_limits<unsigned>::max());
 }
 
-// - weight_xh ( input to hidden )
-//  : [1, 1, input_size, unit (hidden_size) x NUM_GATE] -> z, r, g
-// - weight_hh ( hidden to hidden )
-//  : [1, 1, unit (hidden_size) , unit (hidden_size) x NUM_GATE] -> z, r, g
-// - bias_h ( hidden bias )
-//  : [1, 1, 1, unit (hidden_size) x NUM_GATE] -> z, r, g
 void GRUCellLayer::finalize(InitLayerContext &context) {
-  auto &weight_regularizer =
-    std::get<props::WeightRegularizer>(*layer_impl_props);
-  auto &weight_regularizer_constant =
-    std::get<props::WeightRegularizerConstant>(*layer_impl_props);
-  auto &weight_initializer =
-    std::get<props::WeightInitializer>(*layer_impl_props);
-  auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
+  const Tensor::Initializer weight_initializer =
+    std::get<props::WeightInitializer>(*layer_impl_props).get();
+  const Tensor::Initializer bias_initializer =
+    std::get<props::BiasInitializer>(*layer_impl_props).get();
+  const WeightRegularizer weight_regularizer =
+    std::get<props::WeightRegularizer>(*layer_impl_props).get();
+  const float weight_regularizer_constant =
+    std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
 
   const unsigned int unit = std::get<props::Unit>(grucell_props).get();
-  auto &hidden_state_activation_type =
-    std::get<props::HiddenStateActivation>(grucell_props);
-  auto &recurrent_activation_type =
-    std::get<props::RecurrentActivation>(grucell_props);
-  const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
-  const unsigned int max_timestep = std::get<props::MaxTimestep>(grucell_props);
+  const ActivationType hidden_state_activation_type =
+    std::get<props::HiddenStateActivation>(grucell_props).get();
+  const ActivationType recurrent_activation_type =
+    std::get<props::RecurrentActivation>(grucell_props).get();
+  const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
+  const bool integrate_bias =
+    std::get<props::IntegrateBias>(grucell_props).get();
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(grucell_props).get();
 
   if (context.getNumInputs() != 1) {
     throw std::invalid_argument("GRUCell layer takes only one input");
@@ -104,51 +104,72 @@ void GRUCellLayer::finalize(InitLayerContext &context) {
   const unsigned int batch_size = input_dim.batch();
   const unsigned int feature_size = input_dim.width();
 
-  // output_dim = [ batch, 1, 1, hidden_size (unit)]
+  // output_dim = [ batch, 1, 1, unit ]
   TensorDim output_dim(batch_size, 1, 1, unit);
   context.setOutputDimensions({output_dim});
 
-  if (dropout_rate > epsilon) {
-    wt_idx[GRUCellParams::dropout_mask] = context.requestTensor(
-      output_dim, "dropout_mask", Tensor::Initializer::NONE, false,
-      TensorLifespan::ITERATION_LIFESPAN);
-  }
-
-  TensorDim weight_xh_dim({feature_size, NUM_GATE * unit});
-  TensorDim weight_hh_dim({unit, NUM_GATE * unit});
-  TensorDim bias_dim({NUM_GATE * unit});
-
-  // weight_initializer can be set seperately. weight_xh initializer,
-  // weight_initializer can be set separately. weight_ih initializer,
   // weight_hh initializer kernel initializer & recurrent_initializer in keras
   // for now, it is set same way.
-  wt_idx[GRUCellParams::weight_xh] =
-    context.requestWeight(weight_xh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_xh", true);
+
+  // - weight_ih ( input to hidden )
+  // weight_ih_dim : [ 1, 1, feature_size, NUM_GATE * unit ] -> z, r, g
+  TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
+  wt_idx[GRUCellParams::weight_ih] =
+    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "weight_ih", true);
+  // - weight_hh ( hidden to hidden )
+  // weight_hh_dim : [ 1, 1, unit, NUM_GATE * unit ] -> z, r, g
+  TensorDim weight_hh_dim({unit, NUM_GATE * unit});
   wt_idx[GRUCellParams::weight_hh] =
     context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
                           weight_regularizer_constant, "weight_hh", true);
-  wt_idx[GRUCellParams::bias_h] = context.requestWeight(
-    bias_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, "bias_h", true);
+  if (!disable_bias) {
+    if (integrate_bias) {
+      // - bias_h ( input bias, hidden bias are integrate to 1 bias )
+      // bias_h_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
+      TensorDim bias_h_dim({NUM_GATE * unit});
+      wt_idx[GRUCellParams::bias_h] =
+        context.requestWeight(bias_h_dim, bias_initializer,
+                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
+    } else {
+      // - bias_ih ( input bias )
+      // bias_ih_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
+      TensorDim bias_ih_dim({NUM_GATE * unit});
+      wt_idx[GRUCellParams::bias_ih] =
+        context.requestWeight(bias_ih_dim, bias_initializer,
+                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+      // - bias_hh ( hidden bias )
+      // bias_hh_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
+      TensorDim bias_hh_dim({NUM_GATE * unit});
+      wt_idx[GRUCellParams::bias_hh] =
+        context.requestWeight(bias_hh_dim, bias_initializer,
+                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+    }
+  }
 
+  // hidden_state_dim = [ max_timestep * batch, 1, 1, unit ]
   TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
   wt_idx[GRUCellParams::hidden_state] = context.requestTensor(
     hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
     TensorLifespan::ITERATION_LIFESPAN, false);
 
-  TensorDim zrg_dim(max_timestep * batch_size, 1, 1, unit * NUM_GATE);
+  // zrg_dim = [ max_timestep * batch, 1, 1, NUM_GATE * unit ]
+  TensorDim zrg_dim(max_timestep * batch_size, 1, 1, NUM_GATE * unit);
   wt_idx[GRUCellParams::zrg] =
     context.requestTensor(zrg_dim, "zrg", Tensor::Initializer::NONE, true,
                           TensorLifespan::ITERATION_LIFESPAN, false);
 
-  if (hidden_state_activation_type.get() == ActivationType::ACT_NONE) {
-    hidden_state_activation_type.set(ActivationType::ACT_TANH);
+  if (dropout_rate > epsilon) {
+    // dropout_mask_dim = [ max_timestep * batch, 1, 1, unit ]
+    TensorDim dropout_mask_dim(max_timestep * batch_size, 1, 1, unit);
+    wt_idx[GRUCellParams::dropout_mask] = context.requestTensor(
+      dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false,
+      TensorLifespan::ITERATION_LIFESPAN);
   }
-  acti_func.setActiFunc(hidden_state_activation_type.get());
 
-  if (recurrent_activation_type.get() == ActivationType::ACT_NONE) {
-    recurrent_activation_type.set(ActivationType::ACT_SIGMOID);
-  }
-  recurrent_acti_func.setActiFunc(recurrent_activation_type.get());
+  acti_func.setActiFunc(hidden_state_activation_type);
+  recurrent_acti_func.setActiFunc(recurrent_activation_type);
 }
 
 void GRUCellLayer::setProperty(const std::vector<std::string> &values) {
@@ -163,37 +184,52 @@ void GRUCellLayer::exportTo(Exporter &exporter,
 }
 
 void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
-  const unsigned int unit = std::get<props::Unit>(grucell_props).get();
-  const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
-  const unsigned int max_timestep = std::get<props::MaxTimestep>(grucell_props);
-  const unsigned int timestep = std::get<props::Timestep>(grucell_props);
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
 
-  Tensor &weight_xh = context.getWeight(wt_idx[GRUCellParams::weight_xh]);
-  Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
-  Tensor &bias_ih = context.getWeight(wt_idx[GRUCellParams::bias_h]);
+  const unsigned int unit = std::get<props::Unit>(grucell_props).get();
+  const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
+  const bool integrate_bias =
+    std::get<props::IntegrateBias>(grucell_props).get();
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(grucell_props).get();
+  const unsigned int timestep = std::get<props::Timestep>(grucell_props).get();
 
   Tensor &input = context.getInput(SINGLE_INOUT_IDX);
-  Tensor &hidden_states =
-    context.getTensor(wt_idx[GRUCellParams::hidden_state]);
-  Tensor &zrg_gates = context.getTensor(wt_idx[GRUCellParams::zrg]);
-  Tensor prev_hidden_state;
-
+  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
   const TensorDim &input_dim = input.getDim();
   const unsigned int batch_size = input_dim.batch();
 
-  hidden_states.reshape({max_timestep, 1, batch_size, unit});
-  zrg_gates.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
+  Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
+  Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
+  Tensor empty;
+  Tensor &bias_h = !disable_bias && integrate_bias
+                     ? context.getWeight(wt_idx[GRUCellParams::bias_h])
+                     : empty;
+  Tensor &bias_ih = !disable_bias && !integrate_bias
+                      ? context.getWeight(wt_idx[GRUCellParams::bias_ih])
+                      : empty;
+  Tensor &bias_hh = !disable_bias && !integrate_bias
+                      ? context.getWeight(wt_idx[GRUCellParams::bias_hh])
+                      : empty;
 
-  Tensor hidden_state = hidden_states.getBatchSlice(timestep, 1);
+  Tensor &hidden_states =
+    context.getTensor(wt_idx[GRUCellParams::hidden_state]);
+  hidden_states.reshape({max_timestep, 1, batch_size, unit});
+  Tensor prev_hidden_state;
   if (!timestep) {
     prev_hidden_state = Tensor(batch_size, unit);
     prev_hidden_state.setZero();
   } else {
     prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
   }
+  Tensor hidden_state = hidden_states.getBatchSlice(timestep, 1);
+
+  Tensor &zrg_gates = context.getTensor(wt_idx[GRUCellParams::zrg]);
+  zrg_gates.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
   Tensor zrg_gate = zrg_gates.getBatchSlice(timestep, 1);
 
-  input.dot(weight_xh, zrg_gate); // x_z, x_r, x_g
+  input.dot(weight_ih, zrg_gate); // x_z, x_r, x_g
 
   Tensor zr_gate =
     zrg_gate.getSharedDataTensor({batch_size, 2 * unit}, 0, false);
@@ -207,11 +243,18 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
   weight_hh_g.copy_with_stride(
     weight_hh.getSharedDataTensor({1, 1, unit, unit}, unit * 2, false));
 
-  if (timestep) {
-    zr_gate.add_i_strided(prev_hidden_state.dot(weight_hh_zr));
+  zr_gate.add_i_strided(prev_hidden_state.dot(weight_hh_zr));
+  if (!disable_bias) {
+    if (integrate_bias) {
+      Tensor bias_h_zr = bias_h.getSharedDataTensor({2 * unit}, 0);
+      zr_gate.add_i(bias_h_zr);
+    } else {
+      Tensor bias_ih_zr = bias_ih.getSharedDataTensor({2 * unit}, 0);
+      zr_gate.add_i(bias_ih_zr);
+      Tensor bias_hh_zr = bias_hh.getSharedDataTensor({2 * unit}, 0);
+      zr_gate.add_i(bias_hh_zr);
+    }
   }
-  Tensor bias_ih_zr = bias_ih.getSharedDataTensor({2 * unit}, 0);
-  zr_gate.add_i(bias_ih_zr);
 
   recurrent_acti_func.run_fn(zr_gate, zr_gate);
 
@@ -219,18 +262,22 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor r_gate = zr_gate.getSharedDataTensor({batch_size, unit}, unit, false);
 
   Tensor temp;
-  prev_hidden_state.dot(weight_hh_g, temp, false, false);
-#if ENABLE_BIAS_HH
-  // Todo: fix this to get the bias_hh_g from bias_hh
-  Tensor bias_hh_g = bias_ih.getSharedDataTensor({unit}, 2 * unit);
-  temp.add_i(bias_hh_g);
-#endif
+  prev_hidden_state.dot(weight_hh_g, temp);
+  if (!disable_bias && !integrate_bias) {
+    Tensor bias_hh_g = bias_hh.getSharedDataTensor({unit}, 2 * unit);
+    temp.add_i(bias_hh_g);
+  }
   temp.multiply_i_strided(r_gate);
   g_gate.add_i_strided(temp);
-#if ENABLE_BIAS_IH
-  Tensor bias_ih_g = bias_ih.getSharedDataTensor({unit}, 2 * unit);
-  g_gate.add_i(bias_ih_g);
-#endif
+  if (!disable_bias) {
+    if (integrate_bias) {
+      Tensor bias_h_g = bias_h.getSharedDataTensor({unit}, 2 * unit);
+      g_gate.add_i(bias_h_g);
+    } else {
+      Tensor bias_ih_g = bias_ih.getSharedDataTensor({unit}, 2 * unit);
+      g_gate.add_i(bias_ih_g);
+    }
+  }
 
   acti_func.run_fn(g_gate, g_gate);
 
@@ -244,49 +291,66 @@ void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
     hidden_state.multiply_i(mask);
   }
 
-  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
   output.copy(hidden_state);
 }
 
 void GRUCellLayer::calcDerivative(RunLayerContext &context) {
   const unsigned int unit = std::get<props::Unit>(grucell_props).get();
-  const unsigned int max_timestep = std::get<props::MaxTimestep>(grucell_props);
-  const unsigned int timestep = std::get<props::Timestep>(grucell_props);
-  const TensorDim &input_dim = context.getInput(SINGLE_INOUT_IDX).getDim();
-  const unsigned int batch_size = input_dim.batch();
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(grucell_props).get();
+  const unsigned int timestep = std::get<props::Timestep>(grucell_props).get();
+
+  const unsigned int batch_size =
+    context.getInput(SINGLE_INOUT_IDX).getDim().batch();
 
   Tensor &zrg_gates_derivatives =
     context.getTensorGrad(wt_idx[GRUCellParams::zrg]);
   zrg_gates_derivatives.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
   Tensor zrg_gate_derivative = zrg_gates_derivatives.getBatchSlice(timestep, 1);
-  Tensor &weight_xh = context.getWeight(wt_idx[GRUCellParams::weight_xh]);
+  Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
   Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
 
-  zrg_gate_derivative.dot(weight_xh, outgoing_derivative, false, true);
+  zrg_gate_derivative.dot(weight_ih, outgoing_derivative, false, true);
 }
 
 void GRUCellLayer::calcGradient(RunLayerContext &context) {
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
+
   const unsigned int unit = std::get<props::Unit>(grucell_props).get();
-  const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
-  const unsigned int max_timestep = std::get<props::MaxTimestep>(grucell_props);
-  const unsigned int timestep = std::get<props::Timestep>(grucell_props);
+  const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
+  const bool integrate_bias =
+    std::get<props::IntegrateBias>(grucell_props).get();
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(grucell_props).get();
+  const unsigned int timestep = std::get<props::Timestep>(grucell_props).get();
 
   Tensor &input = context.getInput(SINGLE_INOUT_IDX);
-  const TensorDim &input_dim = input.getDim();
-  const unsigned int batch_size = input_dim.batch();
+  const unsigned int batch_size = input.getDim().batch();
 
-  Tensor &djdweight_xh =
-    context.getWeightGrad(wt_idx[GRUCellParams::weight_xh]);
+  Tensor &djdweight_ih =
+    context.getWeightGrad(wt_idx[GRUCellParams::weight_ih]);
+  Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
   Tensor &djdweight_hh =
     context.getWeightGrad(wt_idx[GRUCellParams::weight_hh]);
-  Tensor &djdbias_ih = context.getWeightGrad(wt_idx[GRUCellParams::bias_h]);
-  Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
-  Tensor &bias_h = context.getWeight(wt_idx[GRUCellParams::bias_h]);
-  Tensor bias_h_g = bias_h.getSharedDataTensor({unit}, 2 * unit);
 
-  Tensor djdw_zr_h =
+  Tensor empty;
+  Tensor &djdbias_h = !disable_bias && integrate_bias
+                        ? context.getWeightGrad(wt_idx[GRUCellParams::bias_h])
+                        : empty;
+  Tensor &djdbias_ih = !disable_bias && !integrate_bias
+                         ? context.getWeightGrad(wt_idx[GRUCellParams::bias_ih])
+                         : empty;
+  Tensor &bias_hh = !disable_bias && !integrate_bias
+                      ? context.getWeight(wt_idx[GRUCellParams::bias_hh])
+                      : empty;
+  Tensor &djdbias_hh = !disable_bias && !integrate_bias
+                         ? context.getWeightGrad(wt_idx[GRUCellParams::bias_hh])
+                         : empty;
+
+  Tensor djdweight_hh_zr =
     djdweight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false);
-  Tensor djdw_g_h =
+  Tensor djdweight_hh_g =
     djdweight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false);
   Tensor &hidden_states =
     context.getTensor(wt_idx[GRUCellParams::hidden_state]);
@@ -306,27 +370,35 @@ void GRUCellLayer::calcGradient(RunLayerContext &context) {
   hidden_states_derivatives.reshape({max_timestep, 1, batch_size, unit});
   Tensor hidden_state_derivative =
     hidden_states_derivatives.getBatchSlice(timestep, 1);
-  hidden_state_derivative.reshape(incoming_derivative.getDim());
   if (timestep + 1 == max_timestep) {
-    djdweight_xh.setZero();
+    djdweight_ih.setZero();
     djdweight_hh.setZero();
-    djdbias_ih.setZero();
-    hidden_state_derivative.copyData(incoming_derivative);
-  } else {
-    hidden_state_derivative.add_i(incoming_derivative);
+    if (!disable_bias) {
+      if (integrate_bias) {
+        djdbias_h.setZero();
+      } else {
+        djdbias_ih.setZero();
+        djdbias_hh.setZero();
+      }
+    }
+    hidden_state_derivative.setZero();
   }
-  // restore the dimension
-  hidden_state_derivative.reshape({1, 1, batch_size, unit});
 
-  Tensor hs_prev;
+  hidden_state_derivative.reshape(
+    incoming_derivative.getDim()); // reshape to incoming_derivative dim
+  hidden_state_derivative.add_i(incoming_derivative);
+  hidden_state_derivative.reshape(
+    {1, 1, batch_size, unit}); // restore the dimension
+
+  Tensor prev_hidden_state;
   Tensor dh_nx;
   if (timestep) {
-    hs_prev = hidden_states.getBatchSlice(timestep - 1, 1);
+    prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
     dh_nx = hidden_states_derivatives.getBatchSlice(timestep - 1, 1);
   } else {
     dh_nx = Tensor(batch_size, unit);
-    hs_prev = Tensor(batch_size, unit);
-    hs_prev.setZero();
+    prev_hidden_state = Tensor(batch_size, unit);
+    prev_hidden_state.setZero();
   }
 
   if (dropout_rate > epsilon) {
@@ -345,8 +417,8 @@ void GRUCellLayer::calcGradient(RunLayerContext &context) {
   Tensor rt = zrg_gate.getSharedDataTensor({batch_size, unit}, unit, false);
   Tensor gt = zrg_gate.getSharedDataTensor({batch_size, unit}, unit * 2, false);
 
-  hidden_state_derivative.multiply_strided(zt, dh_nx);    // dh_nx = d1
-  hidden_state_derivative.multiply_strided(hs_prev, dhz); // dhz = d2
+  hidden_state_derivative.multiply_strided(zt, dh_nx); // dh_nx = d1
+  hidden_state_derivative.multiply_strided(prev_hidden_state, dhz); // dhz = d2
   dhz.add_i_strided(hidden_state_derivative.multiply_strided(gt),
                     -1.0f); // dhz = d5
   zt.multiply(-1.0, dhg);
@@ -369,39 +441,40 @@ void GRUCellLayer::calcGradient(RunLayerContext &context) {
   Tensor temp = Tensor(batch_size, unit);
   Tensor dhg_;
   dhg_.copy_with_stride(dhg);
-  hs_prev.dot(wg_hh, temp);
-#if ENABLE_BIAS_HH
-  temp.add_i(bias_h_g);
-#endif
+  prev_hidden_state.dot(wg_hh, temp);
+  if (!disable_bias && !integrate_bias) {
+    Tensor bias_hh_g = bias_hh.getSharedDataTensor({unit}, 2 * unit);
+    temp.add_i(bias_hh_g);
+  }
   dhg_.multiply_strided(temp, dhr); // dhr = d15
 
-  // reset temp : hs_prev * rt for djdbias_hh_g and dh_nx
+  // reset temp : prev_hidden_state * rt for djdbias_hh_g and dh_nx
   dhg_.multiply_strided(rt, temp);
-#if ENABLE_BIAS_HH
-  // Todo: fix this to get the djdbias_hh_g from djdbias_hh
-  Tensor djdbias_hh_g = djdbias_ih.getSharedDataTensor({unit}, 2 * unit);
-  temp.sum(2, djdbias_hh_g, 1.0, 1.0);
-#endif
+  if (!disable_bias && !integrate_bias) {
+    Tensor djdbias_hh_g = djdbias_hh.getSharedDataTensor({unit}, 2 * unit);
+    temp.sum(2, djdbias_hh_g, 1.0, 1.0);
+  }
   temp.dot(wg_hh, dh_nx, false, true, 1.0); // dh_nx = d1 + d14
 
   recurrent_acti_func.run_prime_fn(rt, dhr, dhr); // dhr = d16
 
-#if ENABLE_BIAS_HH
-  // Todo: fix this to get the djdbias_hh_zr from djdbias_hh
-  Tensor djdbias_hh_zr = djdbias_ih.getSharedDataTensor({2 * unit}, 0);
-  djdbias_hh_zr.add_i(
-    zrg_gate_derivative.sum(2).getSharedDataTensor({2 * unit}, 0));
-#endif
-#if ENABLE_BIAS_IH
-  zrg_gate_derivative.sum(2, djdbias_ih, 1.0, 1.0);
-#endif
+  if (!disable_bias) {
+    if (integrate_bias) {
+      zrg_gate_derivative.sum(2, djdbias_h, 1.0, 1.0);
+    } else {
+      zrg_gate_derivative.sum(2, djdbias_ih, 1.0, 1.0);
+      Tensor djdbias_hh_zr = djdbias_hh.getSharedDataTensor({2 * unit}, 0);
+      djdbias_hh_zr.add_i(
+        zrg_gate_derivative.sum(2).getSharedDataTensor({2 * unit}, 0));
+    }
+  }
 
-  djdweight_xh.add_i(input.dot(zrg_gate_derivative, true, false));
+  djdweight_ih.add_i(input.dot(zrg_gate_derivative, true, false));
 
   Tensor dhzr_;
   dhzr_.copy_with_stride(dhzr);
-  djdw_zr_h.add_i_strided(hs_prev.dot(dhzr_, true, false));
-  djdw_g_h.add_i_strided(hs_prev.dot(temp, true, false));
+  djdweight_hh_zr.add_i_strided(prev_hidden_state.dot(dhzr_, true, false));
+  djdweight_hh_g.add_i_strided(prev_hidden_state.dot(temp, true, false));
   dhzr_.dot(wzr_hh, dh_nx, false, true, 1.0); // dh_nx = d1 + d14 + d12 + d17
 }
 
@@ -413,7 +486,8 @@ void GRUCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
 
   const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
   if (dropout_rate > epsilon) {
-    context.updateTensor(wt_idx[GRUCellParams::dropout_mask], batch);
+    context.updateTensor(wt_idx[GRUCellParams::dropout_mask],
+                         max_timestep * batch);
   }
 }
 
index ef3ca3c..b3c30f0 100644 (file)
@@ -105,15 +105,16 @@ private:
    * HiddenStateActivation: activation type for hidden state. default is tanh
    * RecurrentActivation: activation type for recurrent. default is sigmoid
    * DropOutRate: dropout rate
+   * IntegrateBias: integrate bias_ih, bias_hh to bias_h
    * MaxTimeStep: Maximum timestep of gru
    * TimeStep: timestep for which gru should operate
    *
    * */
   std::tuple<props::Unit, props::HiddenStateActivation,
-             props::RecurrentActivation, props::DropOutRate, props::MaxTimestep,
-             props::Timestep>
+             props::RecurrentActivation, props::DropOutRate,
+             props::IntegrateBias, props::MaxTimestep, props::Timestep>
     grucell_props;
-  std::array<unsigned int, 7> wt_idx; /**< indices of the weights */
+  std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
 
   /**
    * @brief     activation function for h_t : default is sigmoid
index 012259c..5e67378 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index e3b3240..dd2d852 100644 (file)
@@ -122,9 +122,6 @@ class GRUCellStacked(torch.nn.Module):
                 for _ in range(num_gru)
             ]
         )
-        for gru in self.grus:
-            gru.bias_hh.data.fill_(0.0)
-            gru.bias_hh.requires_grad=False
         self.unroll_for = unroll_for
         self.loss = torch.nn.MSELoss()
 
index 9184d43..10f14ad 100644 (file)
@@ -83,24 +83,26 @@ def rnn_lstm_translate(model):
 @register_for_((torch.nn.GRUCell))
 def gru_translate(model):
     params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
-    bias = ("bias", params[2][1] + params[3][1])
 
     # [hidden, input] -> [input, hidden]
     def transpose_(weight):
         return (weight[0], weight[1].transpose(1, 0))
 
     # resetgate, inputgate, newgate -> inputgate, resetgate, newgate
-    def reorder_weight(param):
-        if (param[1].dim() == 2):
-            hidden_size = int(param[1].shape[1] / 3)
-        else:
-            hidden_size = int(param[1].shape[0] / 3)
-
-        weight = param[1].hsplit(3)
-        return (param[0], torch.hstack((weight[1], weight[0], weight[2])))
-
-    transposed_params = [transpose_(params[0]), transpose_(params[1]), bias]
-    new_params = [reorder_weight(transposed_params[0]), reorder_weight(transposed_params[1]), reorder_weight(transposed_params[2])]
+    def reorder_weights(params):
+        reordered_weights = []
+        for param in params: # param = ("name", weight)
+            if (param[1].dim() == 2): # weight
+                hidden_size = int(param[1].shape[1] / 3)
+            else: # bias
+                hidden_size = int(param[1].shape[0] / 3)
+
+            weight = param[1].hsplit(3)
+            reordered_weights.append((param[0], torch.hstack((weight[1], weight[0], weight[2])))) # reorder
+        return reordered_weights
+
+    transposed_params = [transpose_(params[0]), transpose_(params[1]), params[2], params[3]]
+    new_params = reorder_weights(transposed_params)
 
     yield from new_params
 
index 23a2361..92e75b3 100644 (file)
 
 auto semantic_grucell = LayerSemanticsParamType(
   nntrainer::createLayer<nntrainer::GRUCellLayer>,
-  nntrainer::GRUCellLayer::type, {"unit=1", "max_timestep=1", "timestep=0"}, 0,
-  false, 1);
+  nntrainer::GRUCellLayer::type,
+  {"unit=1", "max_timestep=1", "timestep=0", "integrate_bias=true"}, 0, false,
+  1);
 
 INSTANTIATE_TEST_CASE_P(GRUCell, LayerSemantics,
                         ::testing::Values(semantic_grucell));
 
 auto grucell_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRUCellLayer>,
-  {"unit=5", "max_timestep=1", "timestep=0"}, "3:1:1:7",
+  {"unit=5", "max_timestep=1", "timestep=0", "integrate_bias=true"}, "3:1:1:7",
   "gru_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
 
 INSTANTIATE_TEST_CASE_P(GRUCell, LayerGoldenTest,
index c358614..4dd0365 100644 (file)
@@ -406,7 +406,7 @@ static std::unique_ptr<NeuralNetwork> makeSingleGRUCell() {
   }
 
   auto grucell = makeGraph({
-    {"grucell", {"name=a1", "unit=2"}},
+    {"grucell", {"name=a1", "unit=2", "integrate_bias=false"}},
   });
 
   nn->addWithReferenceLayers(grucell, "grucell_scope", {"input"}, {"a1"},
@@ -436,8 +436,9 @@ static std::unique_ptr<NeuralNetwork> makeStackedGRUCell() {
   }
 
   auto grucell = makeGraph({
-    {"grucell", {"name=a1", "unit=2"}},
-    {"grucell", {"name=a2", "unit=2", "input_layers=a1"}},
+    {"grucell", {"name=a1", "unit=2", "integrate_bias=false"}},
+    {"grucell",
+     {"name=a2", "unit=2", "integrate_bias=false", "input_layers=a1"}},
   });
 
   nn->addWithReferenceLayers(grucell, "grucell_scope", {"input"}, {"a1"},