/**
 * Index of the single input / output of this layer (the code below uses
 * exactly one input and one output via this index).
 */
static constexpr size_t SINGLE_INOUT_IDX = 0;

/**
 * Order of the weights / tensors requested by GRUCellLayer; values are used
 * as indices into wt_idx.
 */
enum GRUCellParams {
  weight_ih,    // input-to-hidden weight
  weight_hh,    // hidden-to-hidden weight
  bias_h,       // single bias, used when input/hidden biases are integrated
  bias_ih,      // input bias, used when biases are kept separate
  bias_hh,      // hidden bias, used when biases are kept separate
  hidden_state, // hidden-state tensor across timesteps
  zrg,          // concatenated z/r/g gate tensor
  dropout_mask  // dropout mask, requested only when dropout_rate > epsilon
};
-#define ENABLE_BIAS_IH 1
-// Todo: enable bias_hh
-#define ENABLE_BIAS_HH 0
-
// Todo: handle with strided tensor more efficiently and reduce temporary
// tensors
GRUCellLayer::GRUCellLayer() :
LayerImpl(),
- grucell_props(props::Unit(), props::HiddenStateActivation(),
- props::RecurrentActivation(), props::DropOutRate(),
+ grucell_props(props::Unit(),
+ props::HiddenStateActivation() = ActivationType::ACT_TANH,
+ props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
+ props::DropOutRate(), props::IntegrateBias(),
props::MaxTimestep(), props::Timestep()),
acti_func(ActivationType::ACT_NONE, true),
recurrent_acti_func(ActivationType::ACT_NONE, true),
wt_idx.fill(std::numeric_limits<unsigned>::max());
}
-// - weight_xh ( input to hidden )
-// : [1, 1, input_size, unit (hidden_size) x NUM_GATE] -> z, r, g
-// - weight_hh ( hidden to hidden )
-// : [1, 1, unit (hidden_size) , unit (hidden_size) x NUM_GATE] -> z, r, g
-// - bias_h ( hidden bias )
-// : [1, 1, 1, unit (hidden_size) x NUM_GATE] -> z, r, g
void GRUCellLayer::finalize(InitLayerContext &context) {
- auto &weight_regularizer =
- std::get<props::WeightRegularizer>(*layer_impl_props);
- auto &weight_regularizer_constant =
- std::get<props::WeightRegularizerConstant>(*layer_impl_props);
- auto &weight_initializer =
- std::get<props::WeightInitializer>(*layer_impl_props);
- auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
+ const Tensor::Initializer weight_initializer =
+ std::get<props::WeightInitializer>(*layer_impl_props).get();
+ const Tensor::Initializer bias_initializer =
+ std::get<props::BiasInitializer>(*layer_impl_props).get();
+ const WeightRegularizer weight_regularizer =
+ std::get<props::WeightRegularizer>(*layer_impl_props).get();
+ const float weight_regularizer_constant =
+ std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
+ const bool disable_bias =
+ std::get<props::DisableBias>(*layer_impl_props).get();
const unsigned int unit = std::get<props::Unit>(grucell_props).get();
- auto &hidden_state_activation_type =
- std::get<props::HiddenStateActivation>(grucell_props);
- auto &recurrent_activation_type =
- std::get<props::RecurrentActivation>(grucell_props);
- const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
- const unsigned int max_timestep = std::get<props::MaxTimestep>(grucell_props);
+ const ActivationType hidden_state_activation_type =
+ std::get<props::HiddenStateActivation>(grucell_props).get();
+ const ActivationType recurrent_activation_type =
+ std::get<props::RecurrentActivation>(grucell_props).get();
+ const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
+ const bool integrate_bias =
+ std::get<props::IntegrateBias>(grucell_props).get();
+ const unsigned int max_timestep =
+ std::get<props::MaxTimestep>(grucell_props).get();
if (context.getNumInputs() != 1) {
throw std::invalid_argument("GRUCell layer takes only one input");
const unsigned int batch_size = input_dim.batch();
const unsigned int feature_size = input_dim.width();
- // output_dim = [ batch, 1, 1, hidden_size (unit)]
+ // output_dim = [ batch, 1, 1, unit ]
TensorDim output_dim(batch_size, 1, 1, unit);
context.setOutputDimensions({output_dim});
- if (dropout_rate > epsilon) {
- wt_idx[GRUCellParams::dropout_mask] = context.requestTensor(
- output_dim, "dropout_mask", Tensor::Initializer::NONE, false,
- TensorLifespan::ITERATION_LIFESPAN);
- }
-
- TensorDim weight_xh_dim({feature_size, NUM_GATE * unit});
- TensorDim weight_hh_dim({unit, NUM_GATE * unit});
- TensorDim bias_dim({NUM_GATE * unit});
-
- // weight_initializer can be set seperately. weight_xh initializer,
+ // weight_initializer can be set seperately. weight_ih initializer,
// weight_hh initializer kernel initializer & recurrent_initializer in keras
// for now, it is set same way.
- wt_idx[GRUCellParams::weight_xh] =
- context.requestWeight(weight_xh_dim, weight_initializer, weight_regularizer,
- weight_regularizer_constant, "weight_xh", true);
+
+ // - weight_ih ( input to hidden )
+ // weight_ih_dim : [ 1, 1, feature_size, NUMGATE * unit ] -> z, r, g
+ TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
+ wt_idx[GRUCellParams::weight_ih] =
+ context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
+ weight_regularizer_constant, "weight_ih", true);
+ // - weight_hh ( hidden to hidden )
+ // weight_hh_dim : [ 1, 1, unit, NUM_GATE * unit ] -> z, r, g
+ TensorDim weight_hh_dim({unit, NUM_GATE * unit});
wt_idx[GRUCellParams::weight_hh] =
context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
weight_regularizer_constant, "weight_hh", true);
- wt_idx[GRUCellParams::bias_h] = context.requestWeight(
- bias_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, "bias_h", true);
+ if (!disable_bias) {
+ if (integrate_bias) {
+ // - bias_h ( input bias, hidden bias are integrate to 1 bias )
+ // bias_h_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
+ TensorDim bias_h_dim({NUM_GATE * unit});
+ wt_idx[GRUCellParams::bias_h] =
+ context.requestWeight(bias_h_dim, bias_initializer,
+ WeightRegularizer::NONE, 1.0f, "bias_h", true);
+ } else {
+ // - bias_ih ( input bias )
+ // bias_ih_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
+ TensorDim bias_ih_dim({NUM_GATE * unit});
+ wt_idx[GRUCellParams::bias_ih] =
+ context.requestWeight(bias_ih_dim, bias_initializer,
+ WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+ // - bias_hh ( hidden bias )
+ // bias_hh_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
+ TensorDim bias_hh_dim({NUM_GATE * unit});
+ wt_idx[GRUCellParams::bias_hh] =
+ context.requestWeight(bias_hh_dim, bias_initializer,
+ WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+ }
+ }
+ // hidden_state_dim = [ max_timestep * batch, 1, 1, unit ]
TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
wt_idx[GRUCellParams::hidden_state] = context.requestTensor(
hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN, false);
- TensorDim zrg_dim(max_timestep * batch_size, 1, 1, unit * NUM_GATE);
+ // zrg_dim = [ max_timestep * batch, 1, 1, NUM_GATE * unit ]
+ TensorDim zrg_dim(max_timestep * batch_size, 1, 1, NUM_GATE * unit);
wt_idx[GRUCellParams::zrg] =
context.requestTensor(zrg_dim, "zrg", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN, false);
- if (hidden_state_activation_type.get() == ActivationType::ACT_NONE) {
- hidden_state_activation_type.set(ActivationType::ACT_TANH);
+ if (dropout_rate > epsilon) {
+ // dropout_mask_dim = [ max_timestep * batch, 1, 1, unit ]
+ TensorDim dropout_mask_dim(max_timestep * batch_size, 1, 1, unit);
+ wt_idx[GRUCellParams::dropout_mask] = context.requestTensor(
+ dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false,
+ TensorLifespan::ITERATION_LIFESPAN);
}
- acti_func.setActiFunc(hidden_state_activation_type.get());
- if (recurrent_activation_type.get() == ActivationType::ACT_NONE) {
- recurrent_activation_type.set(ActivationType::ACT_SIGMOID);
- }
- recurrent_acti_func.setActiFunc(recurrent_activation_type.get());
+ acti_func.setActiFunc(hidden_state_activation_type);
+ recurrent_acti_func.setActiFunc(recurrent_activation_type);
}
void GRUCellLayer::setProperty(const std::vector<std::string> &values) {
  // NOTE(review): the body appears to be elided from this chunk -- presumably
  // it loads `values` into grucell_props and forwards the remainder to
  // LayerImpl::setProperty; confirm against the full file.
}
void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
- const unsigned int unit = std::get<props::Unit>(grucell_props).get();
- const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
- const unsigned int max_timestep = std::get<props::MaxTimestep>(grucell_props);
- const unsigned int timestep = std::get<props::Timestep>(grucell_props);
+ const bool disable_bias =
+ std::get<props::DisableBias>(*layer_impl_props).get();
- Tensor &weight_xh = context.getWeight(wt_idx[GRUCellParams::weight_xh]);
- Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
- Tensor &bias_ih = context.getWeight(wt_idx[GRUCellParams::bias_h]);
+ const unsigned int unit = std::get<props::Unit>(grucell_props).get();
+ const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
+ const bool integrate_bias =
+ std::get<props::IntegrateBias>(grucell_props).get();
+ const unsigned int max_timestep =
+ std::get<props::MaxTimestep>(grucell_props).get();
+ const unsigned int timestep = std::get<props::Timestep>(grucell_props).get();
Tensor &input = context.getInput(SINGLE_INOUT_IDX);
- Tensor &hidden_states =
- context.getTensor(wt_idx[GRUCellParams::hidden_state]);
- Tensor &zrg_gates = context.getTensor(wt_idx[GRUCellParams::zrg]);
- Tensor prev_hidden_state;
-
+ Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
const TensorDim &input_dim = input.getDim();
const unsigned int batch_size = input_dim.batch();
- hidden_states.reshape({max_timestep, 1, batch_size, unit});
- zrg_gates.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
+ Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
+ Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
+ Tensor empty;
+ Tensor &bias_h = !disable_bias && integrate_bias
+ ? context.getWeight(wt_idx[GRUCellParams::bias_h])
+ : empty;
+ Tensor &bias_ih = !disable_bias && !integrate_bias
+ ? context.getWeight(wt_idx[GRUCellParams::bias_ih])
+ : empty;
+ Tensor &bias_hh = !disable_bias && !integrate_bias
+ ? context.getWeight(wt_idx[GRUCellParams::bias_hh])
+ : empty;
- Tensor hidden_state = hidden_states.getBatchSlice(timestep, 1);
+ Tensor &hidden_states =
+ context.getTensor(wt_idx[GRUCellParams::hidden_state]);
+ hidden_states.reshape({max_timestep, 1, batch_size, unit});
+ Tensor prev_hidden_state;
if (!timestep) {
prev_hidden_state = Tensor(batch_size, unit);
prev_hidden_state.setZero();
} else {
prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
}
+ Tensor hidden_state = hidden_states.getBatchSlice(timestep, 1);
+
+ Tensor &zrg_gates = context.getTensor(wt_idx[GRUCellParams::zrg]);
+ zrg_gates.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
Tensor zrg_gate = zrg_gates.getBatchSlice(timestep, 1);
- input.dot(weight_xh, zrg_gate); // x_z, x_r, x_g
+ input.dot(weight_ih, zrg_gate); // x_z, x_r, x_g
Tensor zr_gate =
zrg_gate.getSharedDataTensor({batch_size, 2 * unit}, 0, false);
weight_hh_g.copy_with_stride(
weight_hh.getSharedDataTensor({1, 1, unit, unit}, unit * 2, false));
- if (timestep) {
- zr_gate.add_i_strided(prev_hidden_state.dot(weight_hh_zr));
+ zr_gate.add_i_strided(prev_hidden_state.dot(weight_hh_zr));
+ if (!disable_bias) {
+ if (integrate_bias) {
+ Tensor bias_h_zr = bias_h.getSharedDataTensor({2 * unit}, 0);
+ zr_gate.add_i(bias_h_zr);
+ } else {
+ Tensor bias_ih_zr = bias_ih.getSharedDataTensor({2 * unit}, 0);
+ zr_gate.add_i(bias_ih_zr);
+ Tensor bias_hh_zr = bias_hh.getSharedDataTensor({2 * unit}, 0);
+ zr_gate.add_i(bias_hh_zr);
+ }
}
- Tensor bias_ih_zr = bias_ih.getSharedDataTensor({2 * unit}, 0);
- zr_gate.add_i(bias_ih_zr);
recurrent_acti_func.run_fn(zr_gate, zr_gate);
Tensor r_gate = zr_gate.getSharedDataTensor({batch_size, unit}, unit, false);
Tensor temp;
- prev_hidden_state.dot(weight_hh_g, temp, false, false);
-#if ENABLE_BIAS_HH
- // Todo: fix this to get the bias_hh_g from bias_hh
- Tensor bias_hh_g = bias_ih.getSharedDataTensor({unit}, 2 * unit);
- temp.add_i(bias_hh_g);
-#endif
+ prev_hidden_state.dot(weight_hh_g, temp);
+ if (!disable_bias && !integrate_bias) {
+ Tensor bias_hh_g = bias_hh.getSharedDataTensor({unit}, 2 * unit);
+ temp.add_i(bias_hh_g);
+ }
temp.multiply_i_strided(r_gate);
g_gate.add_i_strided(temp);
-#if ENABLE_BIAS_IH
- Tensor bias_ih_g = bias_ih.getSharedDataTensor({unit}, 2 * unit);
- g_gate.add_i(bias_ih_g);
-#endif
+ if (!disable_bias) {
+ if (integrate_bias) {
+ Tensor bias_h_g = bias_h.getSharedDataTensor({unit}, 2 * unit);
+ g_gate.add_i(bias_h_g);
+ } else {
+ Tensor bias_ih_g = bias_ih.getSharedDataTensor({unit}, 2 * unit);
+ g_gate.add_i(bias_ih_g);
+ }
+ }
acti_func.run_fn(g_gate, g_gate);
hidden_state.multiply_i(mask);
}
- Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
output.copy(hidden_state);
}
/**
 * @brief Compute the derivative w.r.t. the input: outgoing_derivative =
 * zrg_gate_derivative . weight_ih^T for the current timestep's batch slice.
 *
 * @param context run context providing tensors, gradients and weights
 */
void GRUCellLayer::calcDerivative(RunLayerContext &context) {
  const unsigned int unit = std::get<props::Unit>(grucell_props).get();
  const unsigned int max_timestep =
    std::get<props::MaxTimestep>(grucell_props).get();
  const unsigned int timestep = std::get<props::Timestep>(grucell_props).get();

  const unsigned int batch_size =
    context.getInput(SINGLE_INOUT_IDX).getDim().batch();

  Tensor &zrg_gates_derivatives =
    context.getTensorGrad(wt_idx[GRUCellParams::zrg]);
  zrg_gates_derivatives.reshape({max_timestep, 1, batch_size, NUM_GATE * unit});
  Tensor zrg_gate_derivative = zrg_gates_derivatives.getBatchSlice(timestep, 1);
  Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
  Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
  // transpose the weight (dot(..., false, true)) to map gate-space gradients
  // back to input-space.
  zrg_gate_derivative.dot(weight_ih, outgoing_derivative, false, true);
}
/**
 * @brief Accumulate gradients for weight_ih, weight_hh and the biases from
 * the current timestep's gate derivatives, and propagate the derivative to
 * the previous hidden state (dh_nx). Gradients are zeroed at the last
 * timestep (which is processed first in BPTT) and accumulated afterwards.
 *
 * @param context run context providing tensors, gradients and weights
 */
void GRUCellLayer::calcGradient(RunLayerContext &context) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(grucell_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(grucell_props).get();
  const unsigned int max_timestep =
    std::get<props::MaxTimestep>(grucell_props).get();
  const unsigned int timestep = std::get<props::Timestep>(grucell_props).get();

  Tensor &input = context.getInput(SINGLE_INOUT_IDX);
  const unsigned int batch_size = input.getDim().batch();

  Tensor &djdweight_ih =
    context.getWeightGrad(wt_idx[GRUCellParams::weight_ih]);
  Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
  Tensor &djdweight_hh =
    context.getWeightGrad(wt_idx[GRUCellParams::weight_hh]);
  Tensor empty;
  // Bias gradient references resolve to `empty` when disabled / not used.
  Tensor &djdbias_h = !disable_bias && integrate_bias
                        ? context.getWeightGrad(wt_idx[GRUCellParams::bias_h])
                        : empty;
  Tensor &djdbias_ih = !disable_bias && !integrate_bias
                         ? context.getWeightGrad(wt_idx[GRUCellParams::bias_ih])
                         : empty;
  Tensor &bias_hh = !disable_bias && !integrate_bias
                      ? context.getWeight(wt_idx[GRUCellParams::bias_hh])
                      : empty;
  Tensor &djdbias_hh = !disable_bias && !integrate_bias
                         ? context.getWeightGrad(wt_idx[GRUCellParams::bias_hh])
                         : empty;

  Tensor djdweight_hh_zr =
    djdweight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false);
  Tensor djdweight_hh_g =
    djdweight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false);
  Tensor &hidden_states =
    context.getTensor(wt_idx[GRUCellParams::hidden_state]);
  // NOTE(review): declarations of hidden_states_derivatives and
  // incoming_derivative are elided from this chunk -- confirm.
  hidden_states_derivatives.reshape({max_timestep, 1, batch_size, unit});
  Tensor hidden_state_derivative =
    hidden_states_derivatives.getBatchSlice(timestep, 1);

  // BPTT visits the last timestep first: clear the accumulators there.
  if (timestep + 1 == max_timestep) {
    djdweight_ih.setZero();
    djdweight_hh.setZero();
    if (!disable_bias) {
      if (integrate_bias) {
        djdbias_h.setZero();
      } else {
        djdbias_ih.setZero();
        djdbias_hh.setZero();
      }
    }
    hidden_state_derivative.setZero();
  }
  hidden_state_derivative.reshape(
    incoming_derivative.getDim()); // reshape to incoming_derivative dim
  hidden_state_derivative.add_i(incoming_derivative);
  hidden_state_derivative.reshape(
    {1, 1, batch_size, unit}); // restore the dimension

  Tensor prev_hidden_state;
  Tensor dh_nx;
  if (timestep) {
    prev_hidden_state = hidden_states.getBatchSlice(timestep - 1, 1);
    dh_nx = hidden_states_derivatives.getBatchSlice(timestep - 1, 1);
  } else {
    dh_nx = Tensor(batch_size, unit);
    prev_hidden_state = Tensor(batch_size, unit);
    prev_hidden_state.setZero();
  }

  if (dropout_rate > epsilon) {
  // NOTE(review): the dropout branch's body and closing brace, plus the
  // declarations of zrg_gate, zrg_gate_derivative, zt, dhz, dhg, dhr, dhzr,
  // wzr_hh and wg_hh, are elided from this chunk -- confirm.
  Tensor rt = zrg_gate.getSharedDataTensor({batch_size, unit}, unit, false);
  Tensor gt =
    zrg_gate.getSharedDataTensor({batch_size, unit}, unit * 2, false);

  hidden_state_derivative.multiply_strided(zt, dh_nx);              // dh_nx = d1
  hidden_state_derivative.multiply_strided(prev_hidden_state, dhz); // dhz = d2
  dhz.add_i_strided(hidden_state_derivative.multiply_strided(gt),
                    -1.0f); // dhz = d5
  zt.multiply(-1.0, dhg);

  Tensor temp = Tensor(batch_size, unit);
  Tensor dhg_;
  dhg_.copy_with_stride(dhg);
  prev_hidden_state.dot(wg_hh, temp);
  if (!disable_bias && !integrate_bias) {
    Tensor bias_hh_g = bias_hh.getSharedDataTensor({unit}, 2 * unit);
    temp.add_i(bias_hh_g);
  }
  dhg_.multiply_strided(temp, dhr); // dhr = d15
  // reset temp : prev_hidden_state * rt for djdbias_hh_g and dh_nx
  dhg_.multiply_strided(rt, temp);
  if (!disable_bias && !integrate_bias) {
    Tensor djdbias_hh_g = djdbias_hh.getSharedDataTensor({unit}, 2 * unit);
    temp.sum(2, djdbias_hh_g, 1.0, 1.0);
  }
  temp.dot(wg_hh, dh_nx, false, true, 1.0);       // dh_nx = d1 + d14
  recurrent_acti_func.run_prime_fn(rt, dhr, dhr); // dhr = d16

  if (!disable_bias) {
    if (integrate_bias) {
      zrg_gate_derivative.sum(2, djdbias_h, 1.0, 1.0);
    } else {
      zrg_gate_derivative.sum(2, djdbias_ih, 1.0, 1.0);
      Tensor djdbias_hh_zr = djdbias_hh.getSharedDataTensor({2 * unit}, 0);
      djdbias_hh_zr.add_i(
        zrg_gate_derivative.sum(2).getSharedDataTensor({2 * unit}, 0));
    }
  }
  djdweight_ih.add_i(input.dot(zrg_gate_derivative, true, false));
  Tensor dhzr_;
  dhzr_.copy_with_stride(dhzr);
  djdweight_hh_zr.add_i_strided(prev_hidden_state.dot(dhzr_, true, false));
  djdweight_hh_g.add_i_strided(prev_hidden_state.dot(temp, true, false));
  dhzr_.dot(wzr_hh, dh_nx, false, true, 1.0); // dh_nx = d1 + d14 + d12 + d17
}
const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
if (dropout_rate > epsilon) {
- context.updateTensor(wt_idx[GRUCellParams::dropout_mask], batch);
+ context.updateTensor(wt_idx[GRUCellParams::dropout_mask],
+ max_timestep * batch);
}
}