[lstmcell core] prepare refactoring of lstmcell core layer
authorhyeonseok lee <hs89.lee@samsung.com>
Sat, 18 Dec 2021 00:29:31 +0000 (09:29 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 29 Dec 2021 06:20:14 +0000 (15:20 +0900)
 - The LSTM cell core layer will be refactored from a layer-based to a function-based implementation.
   This commit prepares the core functions, which are not used yet.
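   As a rough sketch of the intended direction (not part of this commit's diff),
   a cell's forwarding could call the new free function directly instead of
   delegating to an LSTMCellCoreLayer object; the local tensor names below are
   illustrative only, and only lstmcell_forwarding() itself is introduced here:

     // hypothetical call site inside a cell layer's forwarding()
     lstmcell_forwarding(unit, batch_size, disable_bias, integrate_bias,
                         acti_func, recurrent_acti_func,
                         input, prev_hidden_state, prev_cell_state,
                         hidden_state, cell_state,
                         weight_ih, weight_hh, bias_h, bias_ih, bias_hh, ifgo);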

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/lstmcell.cpp
nntrainer/layers/lstmcell_core.cpp
nntrainer/layers/lstmcell_core.h
nntrainer/layers/zoneout_lstmcell.cpp
nntrainer/layers/zoneout_lstmcell.h

index 62b4530becfd0b616cc55a4c606be79fccbeeb1a..c059a133d7347b6319dafe44811d6efbd3725221 100644 (file)
@@ -6,8 +6,6 @@
  * @date   17 March 2021
  * @brief  This is LSTMCell Layer Class of Neural Network
  * @see    https://github.com/nnstreamer/nntrainer
- *         https://arxiv.org/pdf/1606.01305.pdf
- *         https://github.com/teganmaharaj/zoneout
  * @author Parichay Kapoor <pk.kapoor@samsung.com>
  * @bug    No known bugs except for NYI items
  *
index f8fe882ea70946418fd19ed11d9a715a17273784..9bf6fe73d26e7a2d344a38b2224292e7b55a9c3d 100644 (file)
@@ -626,4 +626,116 @@ void LSTMCellCoreLayer::setBatch(RunLayerContext &context, unsigned int batch) {
   context.updateTensor(wt_idx[LSTMCellCoreParams::ifgo], batch);
 }
 
+void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
+                         const bool disable_bias, const bool integrate_bias,
+                         ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
+                         const Tensor &input, const Tensor &prev_hidden_state,
+                         const Tensor &prev_cell_state, Tensor &hidden_state,
+                         Tensor &cell_state, const Tensor &weight_ih,
+                         const Tensor &weight_hh, const Tensor &bias_h,
+                         const Tensor &bias_ih, const Tensor &bias_hh,
+                         Tensor &ifgo) {
+  input.dot(weight_ih, ifgo);
+  prev_hidden_state.dot(weight_hh, ifgo, false, false, 1.0);
+  if (!disable_bias) {
+    if (integrate_bias) {
+      ifgo.add_i(bias_h);
+    } else {
+      ifgo.add_i(bias_ih);
+      ifgo.add_i(bias_hh);
+    }
+  }
+
+  Tensor input_forget_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
+  Tensor input_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+  Tensor forget_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+  Tensor memory_cell =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
+  Tensor output_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
+
+  recurrent_acti_func.run_fn(input_forget_gate, input_forget_gate);
+  recurrent_acti_func.run_fn(output_gate, output_gate);
+  acti_func.run_fn(memory_cell, memory_cell);
+
+  prev_cell_state.multiply_strided(forget_gate, cell_state);
+  memory_cell.multiply_strided(input_gate, cell_state, 1.0f);
+
+  acti_func.run_fn(cell_state, hidden_state);
+  hidden_state.multiply_i_strided(output_gate);
+}
+
+void lstmcell_calcDerivative(const Tensor &d_ifgo, const Tensor &weight_ih,
+                             Tensor &outgoing_derivative) {
+  d_ifgo.dot(weight_ih, outgoing_derivative, false, true);
+}
+
+void lstmcell_calcGradient(
+  const unsigned int unit, const unsigned int batch_size,
+  const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
+  ActiFunc &recurrent_acti_func, const Tensor &input,
+  const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
+  const Tensor &prev_cell_state, Tensor &d_prev_cell_state,
+  Tensor &d_hidden_state, const Tensor &cell_state, Tensor &d_cell_state,
+  Tensor &d_weight_ih, const Tensor &weight_hh, Tensor &d_weight_hh,
+  Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &ifgo,
+  Tensor &d_ifgo) {
+  Tensor input_forget_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
+  Tensor input_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+  Tensor forget_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+  Tensor memory_cell =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
+  Tensor output_gate =
+    ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
+
+  Tensor d_input_forget_gate =
+    d_ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
+  Tensor d_input_gate =
+    d_ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
+  Tensor d_forget_gate =
+    d_ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
+  Tensor d_memory_cell =
+    d_ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 2, false);
+  Tensor d_output_gate =
+    d_ifgo.getSharedDataTensor({batch_size, 1, 1, unit}, unit * 3, false);
+
+  Tensor activated_cell_state;
+  acti_func.run_fn(cell_state, activated_cell_state);
+  d_hidden_state.multiply_strided(activated_cell_state, d_output_gate);
+  acti_func.run_prime_fn(activated_cell_state, d_prev_cell_state,
+                         d_hidden_state);
+  d_prev_cell_state.multiply_i_strided(output_gate);
+  d_prev_cell_state.add_i(d_cell_state);
+
+  d_prev_cell_state.multiply_strided(input_gate, d_memory_cell);
+  d_prev_cell_state.multiply_strided(memory_cell, d_input_gate);
+
+  d_prev_cell_state.multiply_strided(prev_cell_state, d_forget_gate);
+  d_prev_cell_state.multiply_i_strided(forget_gate);
+
+  recurrent_acti_func.run_prime_fn(output_gate, d_output_gate, d_output_gate);
+  recurrent_acti_func.run_prime_fn(input_forget_gate, d_input_forget_gate,
+                                   d_input_forget_gate);
+  acti_func.run_prime_fn(memory_cell, d_memory_cell, d_memory_cell);
+
+  if (!disable_bias) {
+    if (integrate_bias) {
+      d_ifgo.sum(0, d_bias_h, 1.0f, 1.0f);
+    } else {
+      d_ifgo.sum(0, d_bias_ih, 1.0f, 1.0f);
+      d_ifgo.sum(0, d_bias_hh, 1.0f, 1.0f);
+    }
+  }
+
+  input.dot(d_ifgo, d_weight_ih, true, false, 1.0f);
+  prev_hidden_state.dot(d_ifgo, d_weight_hh, true, false, 1.0f);
+  d_ifgo.dot(weight_hh, d_prev_hidden_state, false, true);
+}
+
 } // namespace nntrainer
index 4160f987f70d5d68beb3418480fbc718c006292e..4f02398ba2d3f2cfbb448fc8a0cf5678f9250c53 100644 (file)
@@ -133,6 +133,29 @@ private:
    */
   ActiFunc recurrent_acti_func;
 };
+
+void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
+                         const bool disable_bias, const bool integrate_bias,
+                         ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
+                         const Tensor &input, const Tensor &prev_hidden_state,
+                         const Tensor &prev_cell_state, Tensor &hidden_state,
+                         Tensor &cell_state, const Tensor &weight_ih,
+                         const Tensor &weight_hh, const Tensor &bias_h,
+                         const Tensor &bias_ih, const Tensor &bias_hh,
+                         Tensor &ifgo);
+void lstmcell_calcDerivative(const Tensor &d_ifgo, const Tensor &weight_ih,
+                             Tensor &outgoing_derivative);
+void lstmcell_calcGradient(
+  const unsigned int unit, const unsigned int batch_size,
+  const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
+  ActiFunc &recurrent_acti_func, const Tensor &input,
+  const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
+  const Tensor &prev_cell_state, Tensor &d_prev_cell_state,
+  Tensor &d_hidden_state, const Tensor &cell_state, Tensor &d_cell_state,
+  Tensor &d_weight_ih, const Tensor &weight_hh, Tensor &d_weight_hh,
+  Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &ifgo,
+  Tensor &d_ifgo);
+
 } // namespace nntrainer
 
 #endif /* __cplusplus */
index 5bc49bd7f8fafea817464fb14f5ac35e7acae57c..5c0c3d69a11b3ab352b620877a6ea24eff1b74ed 100644 (file)
@@ -6,6 +6,8 @@
  * @date   30 November 2021
  * @brief  This is ZoneoutLSTMCell Layer Class of Neural Network
  * @see    https://github.com/nnstreamer/nntrainer
+ *         https://arxiv.org/pdf/1606.01305.pdf
+ *         https://github.com/teganmaharaj/zoneout
  * @author hyeonseok lee <hs89.lee@samsung.com>
  * @bug    No known bugs except for NYI items
  *
index dd2996c03e58c3281564246b0d37b7984651e390..3895b83065bfcb00b4ec4b2f733c1a9cb1a077c0 100644 (file)
@@ -6,6 +6,8 @@
  * @date   30 November 2021
  * @brief  This is ZoneoutLSTMCell Layer Class of Neural Network
  * @see           https://github.com/nnstreamer/nntrainer
+ *         https://arxiv.org/pdf/1606.01305.pdf
+ *         https://github.com/teganmaharaj/zoneout
  * @author hyeonseok lee <hs89.lee@samsung.com>
  * @bug    No known bugs except for NYI items
  *