std::invalid_argument)
<< "unit property missing for zoneout_lstmcell layer";
const unsigned int unit = std::get<props::Unit>(zoneout_lstmcell_props).get();
- const float hidden_state_zoneout_rate =
- std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
- const float cell_state_zoneout_rate =
- std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
const bool test = std::get<Test>(zoneout_lstmcell_props);
const unsigned int max_timestep =
std::get<props::MaxTimestep>(zoneout_lstmcell_props);
context.requestWeight(hidden_state_zoneout_mask_dim,
Tensor::Initializer::NONE, WeightRegularizer::NONE,
1.0f, "hidden_state_zoneout_mask", false);
- } else if (hidden_state_zoneout_rate > epsilon) {
+ } else {
wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
context.requestTensor(
hidden_state_zoneout_mask_dim, "hidden_state_zoneout_mask",
wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestWeight(
cell_state_zoneout_mask_dim, Tensor::Initializer::NONE,
WeightRegularizer::NONE, 1.0f, "cell_state_zoneout_mask", false);
- } else if (cell_state_zoneout_rate > epsilon) {
+ } else {
wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestTensor(
cell_state_zoneout_mask_dim, "cell_state_zoneout_mask",
Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
init_lstm_context::getTensors(tensors));
lstmcellcorelayer.forwarding(core_context, training);
- if (hidden_state_zoneout_rate > epsilon) {
- if (training) {
- Tensor &hidden_state_zoneout_mask =
- test ? context.getWeight(
- wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
- : context.getTensor(
- wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
- hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state_zoneout_mask =
- hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
- next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
- Tensor prev_hidden_state_zoneout_mask;
- if (!test) {
- prev_hidden_state_zoneout_mask =
- next_hidden_state_zoneout_mask.zoneout_mask(
- hidden_state_zoneout_rate);
- } else {
- next_hidden_state_zoneout_mask.multiply(-1.0f,
- prev_hidden_state_zoneout_mask);
- prev_hidden_state_zoneout_mask.add_i(1.0f);
- }
-
- Tensor &hidden_state_origin = context.getTensor(hidden_state_origin_idx);
- hidden_state_origin.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state_origin =
- hidden_state_origin.getBatchSlice(timestep, 1);
- next_hidden_state_origin.reshape({batch_size, 1, 1, unit});
-
- next_hidden_state_origin.multiply(next_hidden_state_zoneout_mask,
- next_hidden_state);
- prev_hidden_state.multiply(prev_hidden_state_zoneout_mask,
- next_hidden_state, 1.0f);
+ if (training) {
+ Tensor &hidden_state_zoneout_mask =
+ test ? context.getWeight(
+ wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
+ : context.getTensor(
+ wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
+ hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_hidden_state_zoneout_mask =
+ hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
+ next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+ Tensor prev_hidden_state_zoneout_mask;
+ if (!test) {
+ prev_hidden_state_zoneout_mask =
+ next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
+ } else {
+ next_hidden_state_zoneout_mask.multiply(-1.0f,
+ prev_hidden_state_zoneout_mask);
+ prev_hidden_state_zoneout_mask.add_i(1.0f);
}
- // Todo: zoneout at inference
+
+ Tensor &hidden_state_origin = context.getTensor(hidden_state_origin_idx);
+ hidden_state_origin.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_hidden_state_origin =
+ hidden_state_origin.getBatchSlice(timestep, 1);
+ next_hidden_state_origin.reshape({batch_size, 1, 1, unit});
+
+ next_hidden_state_origin.multiply(next_hidden_state_zoneout_mask,
+ next_hidden_state);
+ prev_hidden_state.multiply(prev_hidden_state_zoneout_mask,
+ next_hidden_state, 1.0f);
}
- if (cell_state_zoneout_rate > epsilon) {
- if (training) {
- Tensor &cell_state_zoneout_mask =
- test ? context.getWeight(
- wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
- : context.getTensor(
- wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
- cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state_zoneout_mask =
- cell_state_zoneout_mask.getBatchSlice(timestep, 1);
- next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
- Tensor prev_cell_state_zoneout_mask;
- if (!test) {
- prev_cell_state_zoneout_mask =
- next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
- } else {
- next_cell_state_zoneout_mask.multiply(-1.0f,
- prev_cell_state_zoneout_mask);
- prev_cell_state_zoneout_mask.add_i(1.0f);
- }
-
- Tensor &cell_state_origin = context.getTensor(cell_state_origin_idx);
- cell_state_origin.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state_origin =
- cell_state_origin.getBatchSlice(timestep, 1);
- next_cell_state_origin.reshape({batch_size, 1, 1, unit});
-
- next_cell_state_origin.multiply(next_cell_state_zoneout_mask,
- next_cell_state);
- prev_cell_state.multiply(prev_cell_state_zoneout_mask, next_cell_state,
- 1.0f);
+
+ if (training) {
+ Tensor &cell_state_zoneout_mask =
+ test
+ ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
+ : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
+ cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_cell_state_zoneout_mask =
+ cell_state_zoneout_mask.getBatchSlice(timestep, 1);
+ next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+ Tensor prev_cell_state_zoneout_mask;
+ if (!test) {
+ prev_cell_state_zoneout_mask =
+ next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+ } else {
+ next_cell_state_zoneout_mask.multiply(-1.0f,
+ prev_cell_state_zoneout_mask);
+ prev_cell_state_zoneout_mask.add_i(1.0f);
}
- // Todo: zoneout at inference
+
+ Tensor &cell_state_origin = context.getTensor(cell_state_origin_idx);
+ cell_state_origin.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_cell_state_origin =
+ cell_state_origin.getBatchSlice(timestep, 1);
+ next_cell_state_origin.reshape({batch_size, 1, 1, unit});
+
+ next_cell_state_origin.multiply(next_cell_state_zoneout_mask,
+ next_cell_state);
+ prev_cell_state.multiply(prev_cell_state_zoneout_mask, next_cell_state,
+ 1.0f);
}
+ // Todo: zoneout at inference
Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
output.copyData(next_hidden_state);
Tensor prev_cell_state_derivative;
Tensor prev_hidden_state_derivative_residual;
Tensor prev_cell_state_derivative_residual;
- if (hidden_state_zoneout_rate > epsilon) {
- Tensor &hidden_state_zoneout_mask =
- test ? context.getWeight(
- wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
- : context.getTensor(
- wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
- hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state_zoneout_mask =
- hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
- next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
- Tensor prev_hidden_state_zoneout_mask;
- if (!test) {
- prev_hidden_state_zoneout_mask =
- next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
- } else {
- next_hidden_state_zoneout_mask.multiply(-1.0f,
- prev_hidden_state_zoneout_mask);
- prev_hidden_state_zoneout_mask.add_i(1.0f);
- }
- if (timestep) {
- prev_hidden_state_derivative =
- hidden_state_derivative.getBatchSlice(timestep - 1, 1);
- prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
- next_hidden_state_derivative.multiply(
- prev_hidden_state_zoneout_mask, prev_hidden_state_derivative_residual);
- }
+ Tensor &hidden_state_zoneout_mask =
+ test
+ ? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
+ : context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
+ hidden_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_hidden_state_zoneout_mask =
+ hidden_state_zoneout_mask.getBatchSlice(timestep, 1);
+ next_hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+ Tensor prev_hidden_state_zoneout_mask;
+ if (!test) {
+ prev_hidden_state_zoneout_mask =
+ next_hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
+ } else {
+ next_hidden_state_zoneout_mask.multiply(-1.0f,
+ prev_hidden_state_zoneout_mask);
+ prev_hidden_state_zoneout_mask.add_i(1.0f);
+ }
- Tensor &hidden_state_origin_derivative =
- context.getTensorGrad(hidden_state_origin_idx);
- hidden_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_hidden_state_origin_derivative =
- hidden_state_origin_derivative.getBatchSlice(timestep, 1);
- next_hidden_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+ if (timestep) {
+ prev_hidden_state_derivative =
+ hidden_state_derivative.getBatchSlice(timestep - 1, 1);
+ prev_hidden_state_derivative.reshape({batch_size, 1, 1, unit});
+ next_hidden_state_derivative.multiply(
+ prev_hidden_state_zoneout_mask, prev_hidden_state_derivative_residual);
+ }
- next_hidden_state_derivative.multiply(next_hidden_state_zoneout_mask,
- next_hidden_state_origin_derivative);
+ Tensor &hidden_state_origin_derivative =
+ context.getTensorGrad(hidden_state_origin_idx);
+ hidden_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_hidden_state_origin_derivative =
+ hidden_state_origin_derivative.getBatchSlice(timestep, 1);
+ next_hidden_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+
+ next_hidden_state_derivative.multiply(next_hidden_state_zoneout_mask,
+ next_hidden_state_origin_derivative);
+
+ Tensor &cell_state_zoneout_mask =
+ test
+ ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
+ : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
+ cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_cell_state_zoneout_mask =
+ cell_state_zoneout_mask.getBatchSlice(timestep, 1);
+ next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
+ Tensor prev_cell_state_zoneout_mask;
+ if (!test) {
+ prev_cell_state_zoneout_mask =
+ next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
+ } else {
+ next_cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
+ prev_cell_state_zoneout_mask.add_i(1.0f);
}
- if (cell_state_zoneout_rate > epsilon) {
- Tensor &cell_state_zoneout_mask =
- test
- ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
- : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
- cell_state_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state_zoneout_mask =
- cell_state_zoneout_mask.getBatchSlice(timestep, 1);
- next_cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
- Tensor prev_cell_state_zoneout_mask;
- if (!test) {
- prev_cell_state_zoneout_mask =
- next_cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
- } else {
- next_cell_state_zoneout_mask.multiply(-1.0f,
- prev_cell_state_zoneout_mask);
- prev_cell_state_zoneout_mask.add_i(1.0f);
- }
- if (timestep) {
- prev_cell_state_derivative =
- cell_state_derivative.getBatchSlice(timestep - 1, 1);
- prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
- next_cell_state_derivative.multiply(prev_cell_state_zoneout_mask,
- prev_cell_state_derivative_residual);
- }
+ if (timestep) {
+ prev_cell_state_derivative =
+ cell_state_derivative.getBatchSlice(timestep - 1, 1);
+ prev_cell_state_derivative.reshape({batch_size, 1, 1, unit});
+ next_cell_state_derivative.multiply(prev_cell_state_zoneout_mask,
+ prev_cell_state_derivative_residual);
+ }
- Tensor &cell_state_origin_derivative =
- context.getTensorGrad(cell_state_origin_idx);
- cell_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
- Tensor next_cell_state_origin_derivative =
- cell_state_origin_derivative.getBatchSlice(timestep, 1);
- next_cell_state_origin_derivative.reshape({batch_size, 1, 1, unit});
+ Tensor &cell_state_origin_derivative =
+ context.getTensorGrad(cell_state_origin_idx);
+ cell_state_origin_derivative.reshape({max_timestep, 1, batch_size, unit});
+ Tensor next_cell_state_origin_derivative =
+ cell_state_origin_derivative.getBatchSlice(timestep, 1);
+ next_cell_state_origin_derivative.reshape({batch_size, 1, 1, unit});
- next_cell_state_derivative.multiply(next_cell_state_zoneout_mask,
- next_cell_state_origin_derivative);
- }
+ next_cell_state_derivative.multiply(next_cell_state_zoneout_mask,
+ next_cell_state_origin_derivative);
init_lstm_context::fillWeights(weights, context, true, max_timestep, timestep,
test);
lstmcellcorelayer.calcGradient(core_context);
if (timestep) {
- if (hidden_state_zoneout_rate > epsilon) {
- prev_hidden_state_derivative.add_i(prev_hidden_state_derivative_residual);
- }
- if (cell_state_zoneout_rate > epsilon) {
- prev_cell_state_derivative.add_i(prev_cell_state_derivative_residual);
- }
+ prev_hidden_state_derivative.add_i(prev_hidden_state_derivative_residual);
+ prev_cell_state_derivative.add_i(prev_cell_state_derivative_residual);
}
}
void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
unsigned int batch) {
- const float hidden_state_zoneout_rate =
- std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props);
- const float cell_state_zoneout_rate =
- std::get<CellStateZoneOutRate>(zoneout_lstmcell_props);
- const bool test = std::get<Test>(zoneout_lstmcell_props);
const unsigned int max_timestep =
std::get<props::MaxTimestep>(zoneout_lstmcell_props);
context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state],
context.updateTensor(cell_state_origin_idx, max_timestep * batch);
context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], max_timestep * batch);
- if (hidden_state_zoneout_rate > epsilon && !test) {
- context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask],
- max_timestep * batch);
- }
- if (cell_state_zoneout_rate > epsilon && !test) {
- context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask],
- max_timestep * batch);
- }
+ context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask],
+ max_timestep * batch);
+ context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask],
+ max_timestep * batch);
}
} // namespace nntrainer