prob_left,
prob_right,
u_neg_div,
- u_pos_div
+ u_pos_div,
+ dstate
};
void MoLAttentionLayer::finalize(InitLayerContext &context) {
wt_idx[AttentionParams::u_pos_div] =
context.requestTensor(prob_dim, "u_pos_div", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
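+ /** cached copy of the state derivative, shared between calcGradient and calcDerivative */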
+ wt_idx[AttentionParams::dstate] =
+ context.requestTensor(state_dim, "dstate", Tensor::Initializer::NONE, false,
+ TensorLifespan::BACKWARD_FUNC_LIFESPAN);
context.setOutputDimensions({query_dim, state_dim});
}
context.getOutgoingDerivative(wt_idx[AttentionParams::value]);
Tensor &dstate =
context.getOutgoingDerivative(wt_idx[AttentionParams::state]);
+ Tensor &dstate_local = context.getTensor(wt_idx[AttentionParams::dstate]);
Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
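+ /** if the derivative helper already ran, reuse the cached state derivative instead of recomputing it */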
if (!helper_exec)
calcDerivativeHelper(context, dstate);
+ else
+ dstate.copyData(dstate_local);
Tensor dfc_tanh = Tensor(fc_out.getDim());
dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
- Tensor &dstate =
- context.getOutgoingDerivative(wt_idx[AttentionParams::state]);
+ Tensor &dstate = context.getTensor(wt_idx[AttentionParams::dstate]);
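+ /** use the cached dstate tensor here; calcDerivative later copies it into the outgoing state derivative */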
Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
Tensor &dfc_w = context.getWeightGrad(wt_idx[AttentionParams::fc_w]);
ActiFunc softmax; /** softmax activation operation */
ActiFunc tanh;    /** tanh activation operation */
ActiFunc sigmoid; /** sigmoid activation operation */
- std::array<unsigned int, 15>
+ std::array<unsigned int, 16>
wt_idx; /**< indices of the weights and tensors */
/**