#include <layer_context.h>
#include <lstm.h>
#include <lstmcell_core.h>
+#include <nntr_threads.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
d_hidden_state_.multiply_i(mask_);
}
- for (unsigned int batch = 0; batch < batch_size; ++batch) {
- const Tensor input_sample = input_.getBatchSlice(batch, 1);
+ auto workers = ParallelBatch(batch_size);
+
+ if (workers.getNumWorkers() > 1) {
+
+ TensorDim weight_ih_d = d_weight_ih.getDim();
+ TensorDim weight_hh_d = d_weight_hh.getDim();
+
+ TensorDim bias_ih_d = d_bias_ih.getDim();
+ TensorDim bias_hh_d = d_bias_hh.getDim();
+ TensorDim bias_h_d = d_bias_h.getDim();
+
+ weight_ih_d.batch(workers.getNumWorkers());
+ weight_hh_d.batch(workers.getNumWorkers());
+ bias_ih_d.batch(workers.getNumWorkers());
+ bias_hh_d.batch(workers.getNumWorkers());
+ bias_h_d.batch(workers.getNumWorkers());
+
+ Tensor sub_d_weight_ih = Tensor(weight_ih_d);
+ Tensor sub_d_weight_hh = Tensor(weight_hh_d);
+ Tensor sub_d_bias_ih = Tensor(bias_ih_d);
+ Tensor sub_d_bias_hh = Tensor(bias_hh_d);
+ Tensor sub_d_bias_h = Tensor(bias_h_d);
+
+ sub_d_weight_ih.setZero();
+ sub_d_weight_hh.setZero();
+ sub_d_bias_ih.setZero();
+ sub_d_bias_hh.setZero();
+ sub_d_bias_h.setZero();
+
+ auto batch_job = [&](unsigned int s, unsigned int e, unsigned int pid,
+ void *user_data) {
+ for (unsigned int batch = s; batch < e; ++batch) {
+ const Tensor input_sample = input_.getBatchSlice(batch, 1);
+
+ const Tensor hidden_state_sample =
+ hidden_state_.getBatchSlice(batch, 1);
+ Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
+ const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
+ Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);
+
+ const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
+ Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);
+
+ Tensor input;
+ Tensor prev_hidden_state;
+ Tensor d_prev_hidden_state;
+ Tensor prev_cell_state;
+ Tensor d_prev_cell_state;
+ Tensor d_hidden_state;
+ Tensor cell_state;
+ Tensor d_cell_state;
+
+ Tensor p_d_weight_ih = sub_d_weight_ih.getBatchSlice(pid, 1);
+ Tensor p_d_weight_hh = sub_d_weight_hh.getBatchSlice(pid, 1);
+ Tensor p_d_bias_ih = sub_d_bias_ih.getBatchSlice(pid, 1);
+ Tensor p_d_bias_hh = sub_d_bias_hh.getBatchSlice(pid, 1);
+ Tensor p_d_bias_h = sub_d_bias_h.getBatchSlice(pid, 1);
+
+ for (int t = max_timestep - 1; t > -1; t--) {
+ input = input_sample.getSharedDataTensor(
+ {feature_size},
+ (reverse ? max_timestep - 1 - t : t) * feature_size);
+
+ if (!t) {
+ prev_hidden_state = Tensor(unit);
+ prev_hidden_state.setZero();
+ d_prev_hidden_state = Tensor(unit);
+ d_prev_hidden_state.setZero();
+ } else {
+ prev_hidden_state = hidden_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+ d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+ }
+ d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+ if (!t) {
+ prev_cell_state = Tensor(unit);
+ prev_cell_state.setZero();
+ d_prev_cell_state = Tensor(unit);
+ d_prev_cell_state.setZero();
+ } else {
+ prev_cell_state = cell_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+ d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+ }
+ cell_state = cell_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+ d_cell_state = d_cell_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+ Tensor ifgo = ifgo_sample.getSharedDataTensor(
+ {NUM_GATE * unit},
+ (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+ Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
+ {NUM_GATE * unit},
+ (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+
+ // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
+ // already has precalculated values from incoming derivatives
+ Tensor d_prev_hidden_state_temp;
+
+ lstmcell_calcGradient(
+ 1, unit, disable_bias, integrate_bias, acti_func,
+ recurrent_acti_func, input, prev_hidden_state,
+ d_prev_hidden_state_temp, prev_cell_state, d_prev_cell_state,
+ d_hidden_state, cell_state, d_cell_state, p_d_weight_ih, weight_hh,
+ p_d_weight_hh, p_d_bias_h, p_d_bias_ih, p_d_bias_hh, ifgo, d_ifgo);
+
+ d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
+ }
+ }
+ };
- const Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
- Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
- const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
- Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);
-
- const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
- Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);
-
- Tensor input;
- Tensor prev_hidden_state;
- Tensor d_prev_hidden_state;
- Tensor prev_cell_state;
- Tensor d_prev_cell_state;
- Tensor d_hidden_state;
- Tensor cell_state;
- Tensor d_cell_state;
-
- for (int t = max_timestep - 1; t > -1; t--) {
- input = input_sample.getSharedDataTensor(
- {feature_size}, (reverse ? max_timestep - 1 - t : t) * feature_size);
+ workers.setCallback(batch_job, nullptr);
+ workers.run();
- if (!t) {
- prev_hidden_state = Tensor(unit);
- prev_hidden_state.setZero();
- d_prev_hidden_state = Tensor(unit);
- d_prev_hidden_state.setZero();
- } else {
- prev_hidden_state = hidden_state_sample.getSharedDataTensor(
- {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
- d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
- {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
- }
- d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
- {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+ for (unsigned int b = 0; b < workers.getNumWorkers(); ++b) {
- if (!t) {
- prev_cell_state = Tensor(unit);
- prev_cell_state.setZero();
- d_prev_cell_state = Tensor(unit);
- d_prev_cell_state.setZero();
- } else {
- prev_cell_state = cell_state_sample.getSharedDataTensor(
- {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
- d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
- {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
- }
- cell_state = cell_state_sample.getSharedDataTensor(
- {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
- d_cell_state = d_cell_state_sample.getSharedDataTensor(
- {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+ Tensor p_d_weight_ih = sub_d_weight_ih.getBatchSlice(b, 1);
+ Tensor p_d_weight_hh = sub_d_weight_hh.getBatchSlice(b, 1);
+ Tensor p_d_bias_ih = sub_d_bias_ih.getBatchSlice(b, 1);
+ Tensor p_d_bias_hh = sub_d_bias_hh.getBatchSlice(b, 1);
+ Tensor p_d_bias_h = sub_d_bias_h.getBatchSlice(b, 1);
- Tensor ifgo = ifgo_sample.getSharedDataTensor(
- {NUM_GATE * unit},
- (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
- Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
- {NUM_GATE * unit},
- (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+ d_weight_ih.add_i(p_d_weight_ih);
+ d_weight_hh.add_i(p_d_weight_hh);
+ d_bias_ih.add_i(p_d_bias_ih);
+ d_bias_hh.add_i(p_d_bias_hh);
+ d_bias_h.add_i(p_d_bias_h);
+ }
- // Temporary variable for d_prev_hidden_state. d_prev_hidden_state already
- // have precalculated values from incomming derivatives
- Tensor d_prev_hidden_state_temp;
-
- lstmcell_calcGradient(1, unit, disable_bias, integrate_bias, acti_func,
- recurrent_acti_func, input, prev_hidden_state,
- d_prev_hidden_state_temp, prev_cell_state,
- d_prev_cell_state, d_hidden_state, cell_state,
- d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
- d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
- d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
+ } else {
+ for (unsigned int batch = 0; batch < batch_size; ++batch) {
+ const Tensor input_sample = input_.getBatchSlice(batch, 1);
+
+ const Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
+ Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
+ const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
+ Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);
+
+ const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
+ Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);
+
+ Tensor input;
+ Tensor prev_hidden_state;
+ Tensor d_prev_hidden_state;
+ Tensor prev_cell_state;
+ Tensor d_prev_cell_state;
+ Tensor d_hidden_state;
+ Tensor cell_state;
+ Tensor d_cell_state;
+
+ for (int t = max_timestep - 1; t > -1; t--) {
+ input = input_sample.getSharedDataTensor(
+ {feature_size}, (reverse ? max_timestep - 1 - t : t) * feature_size);
+
+ if (!t) {
+ prev_hidden_state = Tensor(unit);
+ prev_hidden_state.setZero();
+ d_prev_hidden_state = Tensor(unit);
+ d_prev_hidden_state.setZero();
+ } else {
+ prev_hidden_state = hidden_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+ d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+ }
+ d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+ if (!t) {
+ prev_cell_state = Tensor(unit);
+ prev_cell_state.setZero();
+ d_prev_cell_state = Tensor(unit);
+ d_prev_cell_state.setZero();
+ } else {
+ prev_cell_state = cell_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+ d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+ }
+ cell_state = cell_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+ d_cell_state = d_cell_state_sample.getSharedDataTensor(
+ {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+
+ Tensor ifgo = ifgo_sample.getSharedDataTensor(
+ {NUM_GATE * unit},
+ (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+ Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
+ {NUM_GATE * unit},
+ (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
+
+ // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
+ // already has precalculated values from incoming derivatives
+ Tensor d_prev_hidden_state_temp;
+
+ lstmcell_calcGradient(1, unit, disable_bias, integrate_bias, acti_func,
+ recurrent_acti_func, input, prev_hidden_state,
+ d_prev_hidden_state_temp, prev_cell_state,
+ d_prev_cell_state, d_hidden_state, cell_state,
+ d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
+ d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
+ d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
+ }
}
}
}