namespace nntrainer {
-void grucell_forwarding(const unsigned int unit, const unsigned int batch_size,
- const bool disable_bias, const bool integrate_bias,
- const bool reset_after, ActiFunc &acti_func,
- ActiFunc &recurrent_acti_func, const Tensor &input,
- const Tensor &prev_hidden_state, Tensor &hidden_state,
- const Tensor &weight_ih, const Tensor &weight_hh,
- const Tensor &bias_h, const Tensor &bias_ih,
- const Tensor &bias_hh, Tensor &zrg) {
+/**
+ * @brief Forward pass of one GRU cell step (file-local helper, now static)
+ *
+ * Only part of the body is visible in this change. The visible lines show:
+ *  - input.dot(weight_ih, zrg) projects the input through weight_ih, with
+ *    zrg passed as the second argument (presumably the destination tensor).
+ *  - update_gate.multiply_strided(prev_hidden_state, hidden_state) puts the
+ *    product of the update gate and the previous hidden state into
+ *    hidden_state.
+ *  - temp = update_gate.multiply(-1.0).add(1.0), i.e. (1 - update_gate).
+ *  - memory_cell.multiply_strided(temp, hidden_state, 1.0f) replaces the
+ *    former "hidden_state.add_i(memory_cell.multiply_strided(temp))", so
+ *    the product is folded into hidden_state without a temporary tensor.
+ *
+ * NOTE(review): the trailing 1.0f is presumably an accumulate (beta) factor
+ * applied to the existing contents of the output tensor; confirm against the
+ * Tensor::multiply_strided declaration.
+ * NOTE(review): how disable_bias, integrate_bias, reset_after, the bias
+ * tensors, weight_hh and the activation functions are used is not visible
+ * in this change.
+ *
+ * @param unit presumably the number of hidden units (confirm with caller)
+ * @param batch_size presumably the batch dimension of input (confirm)
+ * @param disable_bias bias switch; usage not visible here
+ * @param integrate_bias bias switch; usage not visible here
+ * @param reset_after GRU variant switch; usage not visible here
+ * @param acti_func activation function object; usage not visible here
+ * @param recurrent_acti_func recurrent activation function object; usage
+ *        not visible here
+ * @param input read-only input tensor, multiplied with weight_ih
+ * @param prev_hidden_state read-only hidden state of the previous step
+ * @param hidden_state output tensor that receives the new hidden state
+ * @param weight_ih read-only input-to-hidden weight
+ * @param weight_hh read-only hidden-to-hidden weight
+ * @param bias_h read-only bias tensor; usage not visible here
+ * @param bias_ih read-only bias tensor; usage not visible here
+ * @param bias_hh read-only bias tensor; usage not visible here
+ * @param zrg non-const gate tensor, handed to input.dot() as second argument
+ */
+static void grucell_forwarding(
+ const unsigned int unit, const unsigned int batch_size,
+ const bool disable_bias, const bool integrate_bias, const bool reset_after,
+ ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
+ const Tensor &prev_hidden_state, Tensor &hidden_state,
+ const Tensor &weight_ih, const Tensor &weight_hh, const Tensor &bias_h,
+ const Tensor &bias_ih, const Tensor &bias_hh, Tensor &zrg) {
input.dot(weight_ih, zrg);
Tensor update_reset_gate =
update_gate.multiply_strided(prev_hidden_state, hidden_state);
temp = update_gate.multiply(-1.0).add(1.0);
-
- hidden_state.add_i(memory_cell.multiply_strided(temp));
+ memory_cell.multiply_strided(temp, hidden_state, 1.0f);
}
-void grucell_calcGradient(
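+/**
+ * @brief Gradient computation for one GRU cell step (file-local helper)
+ *
+ * Declared static, so it is only reachable from this translation unit. The
+ * layer-level gradient code passes it either the raw incoming derivative or
+ * the dropout-masked copy, depending on whether dropout is enabled.
+ */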
+static void grucell_calcGradient(
const unsigned int unit, const unsigned int batch_size,
const bool disable_bias, const bool integrate_bias, const bool reset_after,
ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
const Tensor &prev_hidden_state =
context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
- Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
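+ // In GRUCell the output tensor is the hidden state itself, so it is bound
+ // directly and written in place instead of being copied from a temporary.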
+ Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT);
const unsigned int batch_size = input.getDim().batch();
Tensor &zrg = context.getTensor(wt_idx[GRUCellParams::zrg]);
- Tensor hidden_state;
-
grucell_forwarding(unit, batch_size, disable_bias, integrate_bias,
reset_after, acti_func, recurrent_acti_func, input,
prev_hidden_state, hidden_state, weight_ih, weight_hh,
mask.dropout_mask(dropout_rate);
hidden_state.multiply_i(mask);
}
-
- output.copyData(hidden_state);
}
void GRUCellLayer::calcDerivative(RunLayerContext &context) {
}
}
- Tensor d_hidden_state(batch_size, 1, 1, unit);
- d_hidden_state.copyData(incoming_derivative);
-
+ Tensor incoming_derivative_masked(batch_size, 1, 1, unit);
if (dropout_rate > epsilon) {
- d_hidden_state.multiply_i(
- context.getTensor(wt_idx[GRUCellParams::dropout_mask]));
+ incoming_derivative.multiply_strided(
+ context.getTensor(wt_idx[GRUCellParams::dropout_mask]),
+ incoming_derivative_masked);
}
- grucell_calcGradient(unit, batch_size, disable_bias, integrate_bias,
- reset_after, acti_func, recurrent_acti_func, input,
- prev_hidden_state, d_prev_hidden_state, d_hidden_state,
- d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih,
- bias_hh, d_bias_hh, zrg, d_zrg);
+ grucell_calcGradient(
+ unit, batch_size, disable_bias, integrate_bias, reset_after, acti_func,
+ recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
+ dropout_rate > epsilon ? incoming_derivative_masked : incoming_derivative,
+ d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih, bias_hh,
+ d_bias_hh, zrg, d_zrg);
}
void GRUCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
context.updateTensor(wt_idx[GRUCellParams::zrg], batch);
+
if (dropout_rate > epsilon) {
context.updateTensor(wt_idx[GRUCellParams::dropout_mask], batch);
}