[ GRU ] implement forwarding / backwarding of GRU Layer
author    jijoong.moon <jijoong.moon@samsung.com>
Thu, 10 Jun 2021 04:57:47 +0000 (13:57 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Thu, 17 Jun 2021 06:23:09 +0000 (15:23 +0900)
This commit includes:
  . forwarding implementation (the cell equations it realizes are sketched below)
  . calcGradient implementation
  . calcDerivative implementation
  . gru_basic unittest
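
For reference, the GRU cell computed by forwarding() (notation only, not identifiers
from the code): the fused zrg tensor packs the gate pre-activations in z/r/g order,
and the reset gate is applied to h_{t-1} before the recurrent matmul of the candidate.

    z_t = \sigma(x_t W_{xz} + h_{t-1} W_{hz} + b_z)
    r_t = \sigma(x_t W_{xr} + h_{t-1} W_{hr} + b_r)
    g_t = \tanh(x_t W_{xg} + (r_t \odot h_{t-1}) W_{hg} + b_g)
    h_t = z_t \odot h_{t-1} + (1 - z_t) \odot g_t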

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
nntrainer/layers/gru.cpp
nntrainer/layers/gru.h
test/input_gen/genModelTests.py
test/unittest/unittest_nntrainer_models.cpp

nntrainer/layers/gru.cpp
index 1608c31..cc8ccf3 100644
@@ -91,8 +91,10 @@ int GRULayer::initialize(Manager &manager) {
   TensorDim d = input_dim[0];
   d.width(unit);
 
-  hidden = std::make_shared<Var_Grad>(d, true, true, "GRU:temp_hidden");
+  hidden = std::make_shared<Var_Grad>(d, true, true, "GRU:output");
+
   d.width(unit * NUM_GATE);
+  zrg = std::make_shared<Var_Grad>(d, true, true, "GRU:zrg");
 
   TensorDim h_dim = TensorDim();
   h_dim.setTensorDim(3, unit);
@@ -124,10 +126,10 @@ void GRULayer::setProperty(const PropertyType type, const std::string &value) {
       output_dim[0].width(unit);
     }
     break;
-  case PropertyType::activation:
+  case PropertyType::hidden_state_activation:
     if (!value.empty()) {
       ActivationType acti_type = (ActivationType)parseType(value, TOKEN_ACTI);
-      LayerV1::activation_type = acti_type;
+      hidden_state_activation_type = acti_type;
       acti_func.setActiFunc(acti_type);
     }
     break;
@@ -159,7 +161,86 @@ void GRULayer::setRecurrentActivation(ActivationType activation) {
 }
 
 void GRULayer::forwarding(bool training) {
-  // NYI
+  Tensor &weight_xh =
+    weightAt(static_cast<int>(GRUParams::weight_xh)).getVariableRef();
+  Tensor &weight_hh =
+    weightAt(static_cast<int>(GRUParams::weight_hh)).getVariableRef();
+  Tensor &bias_h =
+    weightAt(static_cast<int>(GRUParams::bias_h)).getVariableRef();
+
+  hidden->getVariableRef().setZero();
+  zrg->getVariableRef().setZero();
+
+  h_prev.setZero();
+
+  Tensor &hidden_ = hidden->getVariableRef();
+  Tensor &input_ = net_input[0]->getVariableRef();
+
+  Tensor hs_prev;
+  Tensor hs;
+
+  for (unsigned int b = 0; b < input_dim[0].batch(); ++b) {
+    Tensor islice = input_.getBatchSlice(b, 1);
+    Tensor oslice = hidden_.getBatchSlice(b, 1);
+    Tensor zrg_ = zrg->getVariableRef().getBatchSlice(b, 1);
+
+    for (unsigned int t = 0; t < islice.height(); ++t) {
+      Tensor xs =
+        islice.getSharedDataTensor({islice.width()}, t * islice.width());
+      hs = oslice.getSharedDataTensor({oslice.width()}, t * oslice.width());
+      Tensor zrg_t =
+        zrg_.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
+
+      if (t > 0) {
+        hs_prev = oslice.getSharedDataTensor({oslice.width()},
+                                             (t - 1) * oslice.width());
+      } else {
+        hs_prev = h_prev.getBatchSlice(b, 1);
+      }
+
+      xs.dot(weight_xh, zrg_t);
+
+      Tensor ztrt = zrg_t.getSharedDataTensor({unit * 2}, 0);
+      Tensor ztrt_b = bias_h.getSharedDataTensor({unit * 2}, 0);
+
+      Tensor w_hh = weight_hh.getSharedDataTensor({unit * unit * 2}, 0);
+      Tensor w_g =
+        weight_hh.getSharedDataTensor({unit * unit}, unit * unit * 2);
+      Tensor gt = zrg_t.getSharedDataTensor({unit}, unit * 2);
+      Tensor gt_b = bias_h.getSharedDataTensor({unit}, unit * 2);
+
+      ztrt.add_i(hs_prev.dot(w_hh));
+      ztrt.add_i(ztrt_b);
+
+      Tensor zt = ztrt.getSharedDataTensor({unit}, 0);
+      Tensor rt = ztrt.getSharedDataTensor({unit}, unit);
+
+      recurrent_acti_func.run_fn(rt, rt);
+      recurrent_acti_func.run_fn(zt, zt);
+
+      gt.add_i(rt.multiply(hs_prev).dot(w_g));
+      gt.add_i(gt_b);
+      acti_func.run_fn(gt, gt);
+
+      zt.multiply(hs_prev, hs);
+      Tensor a = zt.multiply(-1.0).add(1.0);
+      hs.add_i(gt.multiply(a));
+    }
+    h_prev.getBatchSlice(b, 1).copy(hs);
+  }
+
+  if (!return_sequences) {
+    TensorDim d = hidden_.getDim();
+    for (unsigned int b = 0; b < input_dim[0].batch(); ++b) {
+      Tensor dest = net_hidden[0]->getVariableRef().getSharedDataTensor(
+        {d.width()}, b * d.width());
+      Tensor src = hidden_.getSharedDataTensor(
+        {d.width()}, b * d.width() * d.height() + (d.height() - 1) * d.width());
+      dest.copy(src);
+    }
+  } else {
+    net_hidden[0]->getVariableRef().copy(hidden_);
+  }
 }
 
 void GRULayer::copy(std::shared_ptr<LayerV1> l) {
@@ -167,6 +248,7 @@ void GRULayer::copy(std::shared_ptr<LayerV1> l) {
 
   std::shared_ptr<GRULayer> from = std::static_pointer_cast<GRULayer>(l);
   this->unit = from->unit;
+  this->hidden_state_activation_type = from->hidden_state_activation_type;
   this->acti_func = from->acti_func;
   this->recurrent_activation_type = from->recurrent_activation_type;
   this->recurrent_acti_func = from->recurrent_acti_func;
@@ -174,11 +256,124 @@ void GRULayer::copy(std::shared_ptr<LayerV1> l) {
 }
 
 void GRULayer::calcDerivative() {
-  // NYI
+  Tensor &derivative_ = zrg->getGradientRef();
+  Tensor &weight =
+    weightAt(static_cast<int>(GRUParams::weight_xh)).getVariableRef();
+  Tensor &ret_ = net_input[0]->getGradientRef();
+  derivative_.dot(weight, ret_, false, true);
 }
 
 void GRULayer::calcGradient() {
-  // NYI
+  Tensor &djdw_x =
+    weightAt(static_cast<int>(GRUParams::weight_xh)).getGradientRef();
+  Tensor &djdw_h =
+    weightAt(static_cast<int>(GRUParams::weight_hh)).getGradientRef();
+  Tensor &djdb_h =
+    weightAt(static_cast<int>(GRUParams::bias_h)).getGradientRef();
+  Tensor &weight_hh =
+    weightAt(static_cast<int>(GRUParams::weight_hh)).getVariableRef();
+
+  djdw_x.setZero();
+  djdw_h.setZero();
+  djdb_h.setZero();
+
+  hidden->getGradientRef().setZero();
+  zrg->getGradientRef().setZero();
+
+  Tensor derivative_ = hidden->getGradientRef();
+
+  if (!return_sequences) {
+    TensorDim d = derivative_.getDim();
+    for (unsigned int b = 0; b < input_dim[0].batch(); ++b) {
+      Tensor dest = derivative_.getSharedDataTensor(
+        {d.width()}, b * d.width() * d.height() + (d.height() - 1) * d.width());
+      Tensor src = net_hidden[0]->getGradientRef().getSharedDataTensor(
+        {d.width()}, b * d.width());
+      dest.copy(src);
+    }
+  } else {
+    derivative_.copy(net_hidden[0]->getGradientRef());
+  }
+
+  Tensor &hidden_ = hidden->getVariableRef();
+  Tensor &input_ = net_input[0]->getVariableRef();
+  Tensor dh_nx = Tensor({derivative_.width()});
+
+  for (unsigned int b = 0; b < input_dim[0].batch(); ++b) {
+    Tensor deriv_t = derivative_.getBatchSlice(b, 1);
+    Tensor xs_t = input_.getBatchSlice(b, 1);
+    Tensor hs_t = hidden_.getBatchSlice(b, 1);
+
+    dh_nx.setZero();
+
+    Tensor dh;
+    Tensor hs_prev;
+    Tensor hs;
+    Tensor xs;
+    Tensor dzrg_ = zrg->getGradientRef().getBatchSlice(b, 1);
+    Tensor zrg_ = zrg->getVariableRef().getBatchSlice(b, 1);
+
+    for (unsigned int t = deriv_t.height(); t-- > 0;) {
+      dh = deriv_t.getSharedDataTensor({deriv_t.width()}, t * deriv_t.width());
+      xs = xs_t.getSharedDataTensor({xs_t.width()}, t * xs_t.width());
+      hs = hs_t.getSharedDataTensor({hs_t.width()}, t * hs_t.width());
+
+      Tensor dzrg_t =
+        dzrg_.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
+      Tensor zrg_t =
+        zrg_.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
+
+      if (t == 0) {
+        hs_prev = Tensor({hs_t.width()});
+        hs_prev.setZero();
+      } else {
+        hs_prev =
+          hs_t.getSharedDataTensor({hs_t.width()}, (t - 1) * hs_t.width());
+      }
+      if (t < deriv_t.height() - 1) {
+        dh.add_i(dh_nx);
+      }
+
+      Tensor dhz = dzrg_t.getSharedDataTensor({unit}, 0);
+      Tensor dhr = dzrg_t.getSharedDataTensor({unit}, unit);
+      Tensor dhg = dzrg_t.getSharedDataTensor({unit}, unit * 2);
+
+      Tensor zt = zrg_t.getSharedDataTensor({unit}, 0);
+      Tensor rt = zrg_t.getSharedDataTensor({unit}, unit);
+      Tensor gt = zrg_t.getSharedDataTensor({unit}, unit * 2);
+
+      dh.multiply(hs_prev, dhz);
+      dhz.subtract_i(gt.multiply(dh));
+      zt.multiply(-1.0, dhg);
+      dhg.add_i(1.0);
+      dhg.multiply_i(dh);
+      recurrent_acti_func.run_prime_fn(zt, dhz, dhz);
+      acti_func.run_prime_fn(gt, dhg, dhg);
+
+      Tensor dhzr = dzrg_t.getSharedDataTensor({unit * 2}, 0);
+      Tensor djdw_zr_h = djdw_h.getSharedDataTensor({unit * unit * 2}, 0);
+      Tensor djdw_g_h =
+        djdw_h.getSharedDataTensor({unit * unit}, unit * unit * 2);
+
+      Tensor wg_hh =
+        weight_hh.getSharedDataTensor({unit * unit}, unit * unit * 2);
+      Tensor wzr_hh = weight_hh.getSharedDataTensor({unit * unit * 2}, 0);
+
+      dhg.multiply(wg_hh, dh_nx);
+      hs_prev.multiply(dh_nx, dhr);
+      dh_nx.multiply_i(rt);
+      recurrent_acti_func.run_prime_fn(rt, dhr, dhr);
+
+      djdb_h.add_i(dzrg_t);
+
+      djdw_x.add_i(xs.dot(dzrg_t, true, false));
+      djdw_zr_h.add_i(hs_prev.dot(dhzr, true, false));
+      djdw_g_h.add_i(hs_prev.multiply(rt).dot(dhg, true, false));
+
+      dhzr.dot(wzr_hh, dh_nx, false, true);
+      dh_nx.add_i(zt.multiply(dh));
+    }
+  }
 }
 
 } // namespace nntrainer
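
A minimal, self-contained NumPy sketch of one forwarding step under the same
conventions as above (z/r/g gate packing, reset gate multiplied into h_{t-1}
before the recurrent matmul, h_t = z * h_prev + (1 - z) * g, zero initial state).
The dense 2-D weight names and shapes here are illustrative only; GRULayer slices
flat buffers via getSharedDataTensor, so its memory layout need not be the
column-sliced layout assumed below.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def gru_step(x, h_prev, w_xh, w_hh, b_h, unit):
        """One GRU step; gates are packed as [z | r | g] along the last axis."""
        zrg = x @ w_xh                                        # input contribution, width 3*unit
        zr = zrg[:2 * unit] + h_prev @ w_hh[:, :2 * unit] + b_h[:2 * unit]
        z, r = sigmoid(zr[:unit]), sigmoid(zr[unit:])
        g = np.tanh(zrg[2 * unit:] + (r * h_prev) @ w_hh[:, 2 * unit:] + b_h[2 * unit:])
        return z * h_prev + (1.0 - z) * g                     # new hidden state

    # toy usage: feature width 3, unit 4, zero initial hidden state
    rng = np.random.default_rng(0)
    unit = 4
    w_xh = rng.standard_normal((3, 3 * unit))
    w_hh = rng.standard_normal((unit, 3 * unit))
    b_h = np.zeros(3 * unit)
    h = gru_step(rng.standard_normal(3), np.zeros(unit), w_xh, w_hh, b_h, unit)
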
nntrainer/layers/gru.h
index 165a00d..ae88a62 100644
@@ -121,6 +121,11 @@ private:
   unsigned int unit;
 
   /**
+   * @brief     activation type for hidden state : default is sigmoid
+   */
+  ActivationType hidden_state_activation_type;
+
+  /**
    * @brief     activation function for h_t : default is sigmoid
    */
   ActiFunc acti_func;
test/input_gen/genModelTests.py
index 2ecdc1e..8f2bfce 100644
@@ -426,3 +426,24 @@ if __name__ == "__main__":
     multi_rnn_layer_tc(1,2)(file_name="multi_rnn_return_sequence.info")
     multi_rnn_layer_tc(2,2)(file_name="multi_rnn_return_sequence_with_batch.info")
     
+    gru_layer_tc = lambda batch, time, return_sequences: partial(
+        record,
+        model=[
+            K.Input(batch_shape=(batch,time, 1)),
+            K.layers.GRU(
+                time,
+                recurrent_activation="sigmoid",
+                activation="tanh",
+                return_sequences=return_sequences,
+            ),
+            K.layers.Dense(1),
+        ],
+        optimizer=opt.SGD(learning_rate=0.1),
+        iteration=10,
+        input_shape=(batch, time, 1),
+        label_shape=(batch, 1),
+        is_onehot=False,
+        loss_fn_str="mse",
+    )
+
+    gru_layer_tc(1, 1, False)(file_name="gru_basic.info")
test/unittest/unittest_nntrainer_models.cpp
index 7397e00..7138e9a 100644
@@ -620,6 +620,7 @@ static std::string fc_base = "type = Fully_connected";
 static std::string conv_base = "type = conv2d | stride = 1,1 | padding = 0,0";
 static std::string rnn_base = "type = rnn";
 static std::string lstm_base = "type = lstm";
+static std::string gru_base = "type = gru";
 static std::string pooling_base = "type = pooling2d | padding = 0,0";
 static std::string preprocess_flip_base = "type = preprocess_flip";
 static std::string preprocess_translate_base = "type = preprocess_translate";
@@ -1207,6 +1208,18 @@ INI multi_rnn_return_sequence_with_batch(
   }
 );
 
+INI gru_basic(
+  "gru_basic",
+  {
+    nn_base + "loss=mse | batch_size=1",
+    sgd_base + "learning_rate = 0.1",
+    I("input") + input_base + "input_shape=1:1:1",
+    I("gru") + gru_base +
+      "unit = 1" + "input_layers=input",
+    I("outputlayer") + fc_base + "unit = 1" + "input_layers=gru"
+  }
+);
+
 INSTANTIATE_TEST_CASE_P(
   nntrainerModelAutoTests, nntrainerModelTest, ::testing::Values(
     mkModelTc(fc_sigmoid_mse, "3:1:1:10", 10),
@@ -1259,7 +1272,8 @@ INSTANTIATE_TEST_CASE_P(
     mkModelTc(rnn_return_sequences, "1:1:2:1", 10),
     mkModelTc(rnn_return_sequence_with_batch, "2:1:2:1", 10),
     mkModelTc(multi_rnn_return_sequence, "1:1:1:1", 10),
-    mkModelTc(multi_rnn_return_sequence_with_batch, "2:1:1:1", 10)
+    mkModelTc(multi_rnn_return_sequence_with_batch, "2:1:1:1", 10),
+    mkModelTc(gru_basic, "1:1:1:1", 1)
 ), [](const testing::TestParamInfo<nntrainerModelTest::ParamType>& info){
  return std::get<0>(info.param).getName();
 });