[ LAYERS ] LSTM : parallelization along batch direction (calcGradient)
author jijoong.moon <jijoong.moon@samsung.com>
Wed, 3 Aug 2022 06:44:36 +0000 (15:44 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Mon, 22 Aug 2022 01:44:40 +0000 (10:44 +0900)
This patch parallelizes the calculation of the LSTM gradient along the
batch direction.
The thread (worker) id is also passed as a parameter to the thread callback so it can be used inside the callback.

Resolves:

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
nntrainer/layers/conv2d_layer.cpp
nntrainer/layers/lstm.cpp
nntrainer/utils/nntr_threads.cpp
nntrainer/utils/nntr_threads.h

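For reference, here is a minimal standalone sketch of the callback contract this patch introduces (illustrative names; the actual typedef change is in nntr_threads.h below): every batch-parallel job now receives the id of the worker executing it (pid) in addition to its [start, end) batch range, and the single-worker fallback simply invokes the job inline with pid 0.

#include <cstdio>

// Illustrative stand-in for the updated loop_cb signature: the third
// parameter (pid) identifies the worker running this chunk of batches.
typedef void (*loop_cb)(unsigned int start, unsigned int end, unsigned int pid,
                        void *user_data);

static void example_job(unsigned int s, unsigned int e, unsigned int pid,
                        void *user_data) {
  (void)user_data;
  std::printf("worker %u processes batches [%u, %u)\n", pid, s, e);
}

int main() {
  loop_cb cb = example_job;
  cb(0, 4, /*pid=*/0, nullptr); // single-worker fallback: pid is simply 0
  return 0;
}
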
index bb6d77f5d48eba0df28af8a35203cfdb22ea7db5..3606a4522a46b2dc036a0c3116ac2a8e325d1193 100644 (file)
@@ -412,7 +412,8 @@ void Conv2DLayer::forwarding(RunLayerContext &context, bool training) {
    * Below sets the pad area values to zero
    * it is faster to do this way than seting selective area to zero
    */
-  auto forwarding_job = [&](unsigned int s, unsigned int e, void *user_data) {
+  auto forwarding_job = [&](unsigned int s, unsigned int e, unsigned int pid,
+                            void *user_data) {
     Tensor result = Tensor(calcCol2ImOutputDim(out_dim, filter_dim));
     result.setZero();
     for (unsigned int b = s; b < e; ++b) {
@@ -430,7 +431,7 @@ void Conv2DLayer::forwarding(RunLayerContext &context, bool training) {
   if (workers.getNumWorkers() > 1) {
     workers.run();
   } else {
-    forwarding_job(0, in_dim.batch(), nullptr);
+    forwarding_job(0, in_dim.batch(), 0, nullptr);
   }
 
   filter_kernel.reshape(filter_dim);
@@ -465,7 +466,7 @@ void Conv2DLayer::calcDerivative(RunLayerContext &context) {
   /// col2im(column matrix) to reconstruct the original image
 
   auto compute_derivative = [&](unsigned int s, unsigned int e,
-                                void *user_data) {
+                                unsigned int pid, void *user_data) {
     Tensor result =
       Tensor(calcCol2ImOutputDim(derivative.getDim(), filter_dim));
 
@@ -483,9 +484,8 @@ void Conv2DLayer::calcDerivative(RunLayerContext &context) {
 
   if (workers.getNumWorkers() > 1) {
     workers.run();
-
   } else {
-    compute_derivative(0, derivative.batch(), nullptr);
+    compute_derivative(0, derivative.batch(), 0, nullptr);
   }
 
   filter_kernel.reshape(filter_dim);
@@ -526,7 +526,8 @@ void Conv2DLayer::calcGradient(RunLayerContext &context) {
     Tensor delK_par = Tensor(delK_ext);
     delK_par.setZero();
 
-    auto calc_grad_job = [&](unsigned int s, unsigned int e, void *user_data) {
+    auto calc_grad_job = [&](unsigned int s, unsigned int e, unsigned int pid,
+                             void *user_data) {
       Tensor result = Tensor(im2col_result.getDim());
       result.setZero();
       for (unsigned int b = s; b < e; ++b) {
index fbaa7473bbad0cd599566900430a722f694c91f9..97b32bdfaf92dd1667cf32117b3f05935ebda102 100644 (file)
@@ -14,6 +14,7 @@
 #include <layer_context.h>
 #include <lstm.h>
 #include <lstmcell_core.h>
+#include <nntr_threads.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <node_exporter.h>
@@ -222,78 +223,214 @@ void batch_first_calcGradient(
     d_hidden_state_.multiply_i(mask_);
   }
 
-  for (unsigned int batch = 0; batch < batch_size; ++batch) {
-    const Tensor input_sample = input_.getBatchSlice(batch, 1);
+  auto workers = ParallelBatch(batch_size);
+
+  if (workers.getNumWorkers() > 1) {
+
+    TensorDim weight_ih_d = d_weight_ih.getDim();
+    TensorDim weight_hh_d = d_weight_hh.getDim();
+
+    TensorDim bias_ih_d = d_bias_ih.getDim();
+    TensorDim bias_hh_d = d_bias_hh.getDim();
+    TensorDim bias_h_d = d_bias_h.getDim();
+
+    weight_ih_d.batch(workers.getNumWorkers());
+    weight_hh_d.batch(workers.getNumWorkers());
+    bias_ih_d.batch(workers.getNumWorkers());
+    bias_hh_d.batch(workers.getNumWorkers());
+    bias_h_d.batch(workers.getNumWorkers());
+
+    Tensor sub_d_weight_ih = Tensor(weight_ih_d);
+    Tensor sub_d_weight_hh = Tensor(weight_hh_d);
+    Tensor sub_d_bias_ih = Tensor(bias_ih_d);
+    Tensor sub_d_bias_hh = Tensor(bias_hh_d);
+    Tensor sub_d_bias_h = Tensor(bias_h_d);
+
+    sub_d_weight_ih.setZero();
+    sub_d_weight_hh.setZero();
+    sub_d_bias_ih.setZero();
+    sub_d_bias_hh.setZero();
+    sub_d_bias_h.setZero();
+
+    auto batch_job = [&](unsigned int s, unsigned int e, unsigned int pid,
+                         void *user_data) {
+      for (unsigned int batch = s; batch < e; ++batch) {
+        const Tensor input_sample = input_.getBatchSlice(batch, 1);
+
+        const Tensor hidden_state_sample =
+          hidden_state_.getBatchSlice(batch, 1);
+        Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
+        const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
+        Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);
+
+        const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
+        Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);
+
+        Tensor input;
+        Tensor prev_hidden_state;
+        Tensor d_prev_hidden_state;
+        Tensor prev_cell_state;
+        Tensor d_prev_cell_state;
+        Tensor d_hidden_state;
+        Tensor cell_state;
+        Tensor d_cell_state;
+
+        Tensor p_d_weight_ih = sub_d_weight_ih.getBatchSlice(pid, 1);
+        Tensor p_d_weight_hh = sub_d_weight_hh.getBatchSlice(pid, 1);
+        Tensor p_d_bias_ih = sub_d_bias_ih.getBatchSlice(pid, 1);
+        Tensor p_d_bias_hh = sub_d_bias_hh.getBatchSlice(pid, 1);
+        Tensor p_d_bias_h = sub_d_bias_h.getBatchSlice(pid, 1);
+
+        for (int t = max_timestep - 1; t > -1; t--) {
+          input = input_sample.getSharedDataTensor(
+            {feature_size},
+            (reverse ? max_timestep - 1 - t : t) * feature_size);
+
+          if (!t) {
+            prev_hidden_state = Tensor(unit);
+            prev_hidden_state.setZero();
+            d_prev_hidden_state = Tensor(unit);
+            d_prev_hidden_state.setZero();
+          } else {
+            prev_hidden_state = hidden_state_sample.getSharedDataTensor(
+              {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+            d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+              {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+          }
+          d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+            {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+          if (!t) {
+            prev_cell_state = Tensor(unit);
+            prev_cell_state.setZero();
+            d_prev_cell_state = Tensor(unit);
+            d_prev_cell_state.setZero();
+          } else {
+            prev_cell_state = cell_state_sample.getSharedDataTensor(
+              {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+            d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
+              {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+          }
+          cell_state = cell_state_sample.getSharedDataTensor(
+            {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+          d_cell_state = d_cell_state_sample.getSharedDataTensor(
+            {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+          Tensor ifgo = ifgo_sample.getSharedDataTensor(
+            {NUM_GATE * unit},
+            (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+          Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
+            {NUM_GATE * unit},
+            (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+
+          // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
+          // already have precalculated values from incomming derivatives
+          Tensor d_prev_hidden_state_temp;
+
+          lstmcell_calcGradient(
+            1, unit, disable_bias, integrate_bias, acti_func,
+            recurrent_acti_func, input, prev_hidden_state,
+            d_prev_hidden_state_temp, prev_cell_state, d_prev_cell_state,
+            d_hidden_state, cell_state, d_cell_state, p_d_weight_ih, weight_hh,
+            p_d_weight_hh, p_d_bias_h, p_d_bias_ih, p_d_bias_hh, ifgo, d_ifgo);
+
+          d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
+        }
+      }
+    };
 
-    const Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
-    Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
-    const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
-    Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);
-
-    const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
-    Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);
-
-    Tensor input;
-    Tensor prev_hidden_state;
-    Tensor d_prev_hidden_state;
-    Tensor prev_cell_state;
-    Tensor d_prev_cell_state;
-    Tensor d_hidden_state;
-    Tensor cell_state;
-    Tensor d_cell_state;
-
-    for (int t = max_timestep - 1; t > -1; t--) {
-      input = input_sample.getSharedDataTensor(
-        {feature_size}, (reverse ? max_timestep - 1 - t : t) * feature_size);
+    workers.setCallback(batch_job, nullptr);
+    workers.run();
 
-      if (!t) {
-        prev_hidden_state = Tensor(unit);
-        prev_hidden_state.setZero();
-        d_prev_hidden_state = Tensor(unit);
-        d_prev_hidden_state.setZero();
-      } else {
-        prev_hidden_state = hidden_state_sample.getSharedDataTensor(
-          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
-        d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
-          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
-      }
-      d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
-        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+    for (unsigned int b = 0; b < workers.getNumWorkers(); ++b) {
 
-      if (!t) {
-        prev_cell_state = Tensor(unit);
-        prev_cell_state.setZero();
-        d_prev_cell_state = Tensor(unit);
-        d_prev_cell_state.setZero();
-      } else {
-        prev_cell_state = cell_state_sample.getSharedDataTensor(
-          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
-        d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
-          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
-      }
-      cell_state = cell_state_sample.getSharedDataTensor(
-        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
-      d_cell_state = d_cell_state_sample.getSharedDataTensor(
-        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+      Tensor p_d_weight_ih = sub_d_weight_ih.getBatchSlice(b, 1);
+      Tensor p_d_weight_hh = sub_d_weight_hh.getBatchSlice(b, 1);
+      Tensor p_d_bias_ih = sub_d_bias_ih.getBatchSlice(b, 1);
+      Tensor p_d_bias_hh = sub_d_bias_hh.getBatchSlice(b, 1);
+      Tensor p_d_bias_h = sub_d_bias_h.getBatchSlice(b, 1);
 
-      Tensor ifgo = ifgo_sample.getSharedDataTensor(
-        {NUM_GATE * unit},
-        (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
-      Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
-        {NUM_GATE * unit},
-        (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+      d_weight_ih.add_i(p_d_weight_ih);
+      d_weight_hh.add_i(p_d_weight_hh);
+      d_bias_ih.add_i(p_d_bias_ih);
+      d_bias_hh.add_i(p_d_bias_hh);
+      d_bias_h.add_i(p_d_bias_h);
+    }
 
-      // Temporary variable for d_prev_hidden_state. d_prev_hidden_state already
-      // have precalculated values from incomming derivatives
-      Tensor d_prev_hidden_state_temp;
-
-      lstmcell_calcGradient(1, unit, disable_bias, integrate_bias, acti_func,
-                            recurrent_acti_func, input, prev_hidden_state,
-                            d_prev_hidden_state_temp, prev_cell_state,
-                            d_prev_cell_state, d_hidden_state, cell_state,
-                            d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
-                            d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
-      d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
+  } else {
+    for (unsigned int batch = 0; batch < batch_size; ++batch) {
+      const Tensor input_sample = input_.getBatchSlice(batch, 1);
+
+      const Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
+      Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
+      const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
+      Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);
+
+      const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
+      Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);
+
+      Tensor input;
+      Tensor prev_hidden_state;
+      Tensor d_prev_hidden_state;
+      Tensor prev_cell_state;
+      Tensor d_prev_cell_state;
+      Tensor d_hidden_state;
+      Tensor cell_state;
+      Tensor d_cell_state;
+
+      for (int t = max_timestep - 1; t > -1; t--) {
+        input = input_sample.getSharedDataTensor(
+          {feature_size}, (reverse ? max_timestep - 1 - t : t) * feature_size);
+
+        if (!t) {
+          prev_hidden_state = Tensor(unit);
+          prev_hidden_state.setZero();
+          d_prev_hidden_state = Tensor(unit);
+          d_prev_hidden_state.setZero();
+        } else {
+          prev_hidden_state = hidden_state_sample.getSharedDataTensor(
+            {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+          d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+            {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+        }
+        d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+        if (!t) {
+          prev_cell_state = Tensor(unit);
+          prev_cell_state.setZero();
+          d_prev_cell_state = Tensor(unit);
+          d_prev_cell_state.setZero();
+        } else {
+          prev_cell_state = cell_state_sample.getSharedDataTensor(
+            {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+          d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
+            {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+        }
+        cell_state = cell_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+        d_cell_state = d_cell_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+        Tensor ifgo = ifgo_sample.getSharedDataTensor(
+          {NUM_GATE * unit},
+          (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+        Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
+          {NUM_GATE * unit},
+          (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+
+        // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
+        // already have precalculated values from incomming derivatives
+        Tensor d_prev_hidden_state_temp;
+
+        lstmcell_calcGradient(1, unit, disable_bias, integrate_bias, acti_func,
+                              recurrent_acti_func, input, prev_hidden_state,
+                              d_prev_hidden_state_temp, prev_cell_state,
+                              d_prev_cell_state, d_hidden_state, cell_state,
+                              d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
+                              d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
+        d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
+      }
     }
   }
 }
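
The key idea in the parallel branch above is lock-free gradient accumulation: each worker writes into its own slice of temporary gradient tensors (selected by pid via getBatchSlice, since the temporaries' batch dimension was set to the number of workers), and the per-worker slices are reduced into the real gradients after workers.run() returns. This trades extra memory (one gradient copy per worker) for the absence of synchronization. Below is a standalone sketch of that pattern, using plain vectors in place of Tensor (an illustration, not nntrainer code):

#include <cstdio>
#include <thread>
#include <vector>

int main() {
  const unsigned int batch_size = 8, num_workers = 4, grad_len = 3;

  // One private gradient buffer per worker, analogous to sub_d_weight_ih etc.
  // whose batch dimension was set to the number of workers.
  std::vector<std::vector<float>> sub_grad(num_workers,
                                           std::vector<float>(grad_len, 0.0f));

  auto batch_job = [&](unsigned int s, unsigned int e, unsigned int pid) {
    for (unsigned int b = s; b < e; ++b) {
      // Stand-in for lstmcell_calcGradient(): accumulate this batch's
      // contribution into the worker-private buffer selected by pid.
      for (unsigned int k = 0; k < grad_len; ++k)
        sub_grad[pid][k] += static_cast<float>(b + k);
    }
  };

  // Dispatch one chunk of the batch to each worker, passing the worker index.
  std::vector<std::thread> workers;
  unsigned int chunk = batch_size / num_workers;
  for (unsigned int i = 0; i < num_workers; ++i)
    workers.emplace_back(batch_job, i * chunk, (i + 1) * chunk, i);
  for (auto &w : workers)
    w.join();

  // Reduction: fold the per-worker buffers into the real gradient, mirroring
  // the add_i() loop over getNumWorkers() in the hunk above.
  std::vector<float> d_weight(grad_len, 0.0f);
  for (unsigned int i = 0; i < num_workers; ++i)
    for (unsigned int k = 0; k < grad_len; ++k)
      d_weight[k] += sub_grad[i][k];

  for (unsigned int k = 0; k < grad_len; ++k)
    std::printf("d_weight[%u] = %f\n", k, d_weight[k]);
  return 0;
}
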
index 4b5cb89220da564c28d5f905b2fef7b5261508d9..4ac9259e0335e00ca35481a812ba16df571877b0 100644 (file)
@@ -46,7 +46,7 @@ void ParallelBatch::run() {
     unsigned int e = s + chunk;
     if (e > end)
       e = end;
-    workers.push_back(std::thread(cb, s, e, user_data_prop->get()));
+    workers.push_back(std::thread(cb, s, e, i, user_data_prop->get()));
   }
 
   std::for_each(workers.begin(), workers.end(),
index ac1b6db56bc207038ce96b10b65dbcf523cfcfd4..2c490c17eabbfc19814da979dd48b1985ac5bc72 100644 (file)
@@ -20,7 +20,8 @@
 #include <nntrainer_error.h>
 #include <util_func.h>
 
-typedef void (*loop_cb)(unsigned int start, unsigned int end, void *user_data);
+typedef void (*loop_cb)(unsigned int start, unsigned int end, unsigned int pid,
+                        void *user_data);
 
 typedef std::function<std::remove_pointer<loop_cb>::type> threaded_cb;
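
On the dispatch side, the only change is that ParallelBatch::run() now forwards the loop index i of each spawned thread as the new pid argument, as shown in the nntr_threads.cpp hunk above. A standalone sketch of that chunked dispatch (an illustration of the pattern, not the actual class):

#include <algorithm>
#include <cstdio>
#include <thread>
#include <vector>

typedef void (*loop_cb)(unsigned int start, unsigned int end, unsigned int pid,
                        void *user_data);

static void print_job(unsigned int s, unsigned int e, unsigned int pid,
                      void *user_data) {
  (void)user_data;
  std::printf("worker %u -> [%u, %u)\n", pid, s, e);
}

// Splits [start, end) into num_workers chunks; each thread receives its own
// index as pid, mirroring workers.push_back(std::thread(cb, s, e, i, ...)).
static void run_chunks(unsigned int start, unsigned int end,
                       unsigned int num_workers, loop_cb cb, void *user_data) {
  std::vector<std::thread> workers;
  unsigned int chunk = (end - start + num_workers - 1) / num_workers;
  for (unsigned int i = 0; i < num_workers; ++i) {
    unsigned int s = start + i * chunk;
    unsigned int e = std::min(end, s + chunk);
    workers.emplace_back(cb, s, e, i, user_data);
  }
  for (auto &w : workers)
    w.join();
}

int main() {
  run_chunks(0, 10, 4, print_job, nullptr);
  return 0;
}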