1.0f, 0.0f, "hidden_state_zoneout_mask", false);
} else {
wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
- context.requestTensor(
- hidden_state_zoneout_mask_dim, "hidden_state_zoneout_mask",
- Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
+ context.requestTensor(hidden_state_zoneout_mask_dim,
+ "hidden_state_zoneout_mask",
+ Tensor::Initializer::NONE, false,
+ TensorLifespan::ITERATION_LIFESPAN, false);
}
// cell_state_zoneout_mask_dim = [ max_timestep *
// batch_size, 1, 1, unit ]
} else {
wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestTensor(
cell_state_zoneout_mask_dim, "cell_state_zoneout_mask",
- Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
+ Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN,
+ false);
}
acti_func.setActiFunc(hidden_state_activation_type);
hidden_state.multiply_i(hidden_state_zoneout_mask);
prev_hidden_state.multiply(prev_hidden_state_zoneout_mask, hidden_state,
1.0f);
-
Tensor &cs_zoneout_mask =
test
? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
}
void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
- Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
- const Tensor &weight_ih =
- context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
Tensor &outgoing_derivative =
context.getOutgoingDerivative(INOUT_INDEX::INPUT);
+ const Tensor &weight_ih =
+ context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
+ const Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
}
}
Tensor d_prev_hidden_state_residual;
+ Tensor d_hidden_state_masked;
Tensor &hs_zoneout_mask =
test
? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
d_hidden_state.multiply(prev_hidden_state_zoneout_mask,
d_prev_hidden_state_residual);
- Tensor d_hidden_state_masked;
d_hidden_state.multiply(hidden_state_zoneout_mask, d_hidden_state_masked);
Tensor d_prev_cell_state_residual;
void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
unsigned int batch) {
const unsigned int max_timestep =
- std::get<props::MaxTimestep>(zoneout_lstmcell_props);
+ std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], batch);
context.updateTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state], batch);