namespace nntrainer {
-MoLAttentionLayer::MoLAttentionLayer() : helper_exec(false), wt_idx({0}) {}
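+/**
+ * The activation helpers are configured at construction (the second ActiFunc
+ * argument is assumed to disable in-place execution), and wt_idx is seeded
+ * with an out-of-range sentinel so an index that is never assigned in
+ * finalize() fails fast instead of silently aliasing tensor index 0.
+ * <limits> must be included for std::numeric_limits.
+ */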
+MoLAttentionLayer::MoLAttentionLayer() :
+ helper_exec(false),
+ softmax(ActivationType::ACT_SOFTMAX, false),
+ tanh(ActivationType::ACT_TANH, false),
+ sigmoid(ActivationType::ACT_SIGMOID, false),
+ wt_idx({std::numeric_limits<unsigned>::max()}) {}
MoLAttentionLayer::~MoLAttentionLayer() {}
prob_right,
u_neg_div,
u_pos_div,
- dstate
+ dstate,
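+ /** reserved for the optional updated-state output */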
+ updated_state
};
void MoLAttentionLayer::finalize(InitLayerContext &context) {
wt_idx[AttentionParams::state] = AttentionParams::state;
wt_idx[AttentionParams::mask_len] = AttentionParams::mask_len;
- softmax.setActiFunc(ActivationType::ACT_SOFTMAX);
- tanh.setActiFunc(ActivationType::ACT_TANH);
- sigmoid.setActiFunc(ActivationType::ACT_SIGMOID);
-
NNTR_THROW_IF(query_dim.width() != value_dim.width(), std::invalid_argument)
<< "Query and Value dimension mismatch for layer " << context.getName();
wt_idx[AttentionParams::fc_out] =
 context.requestTensor(fc_out_dim, "fc_out", Tensor::Initializer::NONE,
 false, TensorLifespan::ITERATION_LIFESPAN);
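+ /** fc_tanh caches the tanh activation of fc_out for reuse in the backward
+ * pass (tanh.run_prime_fn consumes it in calcDerivative and calcGradient) */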
+ wt_idx[AttentionParams::fc_tanh] =
+ context.requestTensor(fc_out_dim, "fc_tanh", Tensor::Initializer::NONE,
+ false, TensorLifespan::ITERATION_LIFESPAN);
+
TensorDim fc_proj_out_dim = fc_out_dim;
fc_proj_out_dim.width(fc_proj_w_dim.width());
wt_idx[AttentionParams::fc_proj_out] = context.requestTensor(
wt_idx[AttentionParams::dstate] =
 context.requestTensor(state_dim, "dstate", Tensor::Initializer::NONE, false,
 TensorLifespan::BACKWARD_FUNC_LIFESPAN);
- context.setOutputDimensions({query_dim, state_dim});
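+ /** the updated state is emitted as a second output only when the layer is
+ * connected with two outputs */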
+ if (context.getNumOutputs() == 2)
+ context.setOutputDimensions({query_dim, state_dim});
+ else
+ context.setOutputDimensions({query_dim});
}
void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
Tensor &state = context.getInput(wt_idx[AttentionParams::state]);
Tensor &output = context.getOutput(0);
- Tensor &updated_state = context.getOutput(1);
Tensor &fc_w = context.getWeight(wt_idx[AttentionParams::fc_w]);
Tensor &fc_bias = context.getWeight(wt_idx[AttentionParams::fc_bias]);
Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
.copy_with_stride(alpha);
- state.add(kappa, updated_state);
-
/** @todo cache u_base, u_pos, u_neg */
Tensor u_base = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
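+ /** u_pos/u_neg are assumed to hold the positions shifted by +/- 0.5, so the
+ * attention weight of each position is the logistic CDF difference
+ * sigmoid((u + 0.5 - state) / beta) - sigmoid((u - 0.5 - state) / beta) */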
for (unsigned int b = 0; b < batch; b++) {
Tensor beta_eps = beta.add(1e-8f);
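+ /** with two outputs the updated state (state + kappa) is written straight
+ * into the second output tensor; with a single output it only lives in a
+ * local temporary, since no consumer needs it afterwards */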
- Tensor u_pos_m = u_pos.subtract(updated_state);
+ Tensor u_pos_m, u_neg_m;
+ if (context.getNumOutputs() == 2) {
+ Tensor &updated_state = context.getOutput(1);
+ state.add(kappa, updated_state);
+ u_pos_m = u_pos.subtract(updated_state);
+ u_neg_m = u_neg.subtract(updated_state);
+ } else {
+ Tensor updated_state = state.add(kappa);
+ u_pos_m = u_pos.subtract(updated_state);
+ u_neg_m = u_neg.subtract(updated_state);
+ }
+
u_pos_m.divide(beta_eps, u_pos_div);
sigmoid.run_fn(u_pos_div, prob_left);
- Tensor u_neg_m = u_neg.subtract(updated_state);
u_neg_m.divide(beta_eps, u_neg_div);
sigmoid.run_fn(u_neg_div, prob_right);
Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
Tensor &derivative = context.getIncomingDerivative(0);
- Tensor &derivative_state = context.getIncomingDerivative(1);
Tensor &fc_proj_out = context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
Tensor &dfc_proj_out =
Tensor dbeta_eps = dbeta_eps_neg.add(dbeta_eps_pos);
dm_neg.add(dm_pos, dstate);
- dstate.add_i(derivative_state);
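+ /** an incoming derivative for the state exists only when the optional
+ * second output is connected */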
+ if (context.getNumOutputs() == 2) {
+ Tensor &derivative_state = context.getIncomingDerivative(1);
+ dstate.add_i(derivative_state);
+ }
Tensor dkappa = dstate;
Tensor dbeta = dbeta_eps;
context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
Tensor &scores = context.getTensor(wt_idx[AttentionParams::scores]);
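+ /** value is the second operand of the batched dot product in forwarding,
+ * so its derivative must use the wrt_2 variant */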
- scores.dot_batched_deriv_wrt_1(dvalue, derivative);
+ scores.dot_batched_deriv_wrt_2(dvalue, derivative);
if (!helper_exec)
 calcDerivativeHelper(context, dstate);
else
 dstate.copyData(dstate_local);
Tensor dfc_tanh = Tensor(fc_out.getDim());
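+ /** the transpose flags are spelled out explicitly, presumably for
+ * consistency with the accumulating call sites in calcGradient */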
- dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
+ dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out, false, false);
Tensor dfc_out;
tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
- dquery.dot_deriv_wrt_1(fc_w, dfc_out);
+ dquery.dot_deriv_wrt_1(fc_w, dfc_out, false, false);
}
void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
calcDerivativeHelper(context, dstate);
Tensor dfc_tanh = Tensor(fc_out.getDim());
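+ /** when this is not the first access to the weight gradient within the
+ * iteration, the new contribution is accumulated on top of the existing
+ * gradient rather than overwriting it */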
- fc_tanh.dot_deriv_wrt_2(dfc_proj_w, dfc_proj_out);
+ fc_tanh.dot_deriv_wrt_2(
+ dfc_proj_w, dfc_proj_out, false, false,
+ !context.isGradientFirstAccess(wt_idx[AttentionParams::fc_proj_w]));
dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
Tensor dfc_out;
tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
- query.dot_deriv_wrt_2(dfc_w, dfc_out);
- dfc_out.sum({0, 1, 2}, dfc_bias);
+ query.dot_deriv_wrt_2(
+ dfc_w, dfc_out, false, false,
+ !context.isGradientFirstAccess(wt_idx[AttentionParams::fc_w]));
+
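+ /** the bias gradient follows the same accumulate-on-reuse rule; since
+ * Tensor::sum overwrites its output, the non-first access sums into a
+ * temporary and adds it in */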
+ if (context.isGradientFirstAccess(wt_idx[AttentionParams::fc_bias])) {
+ dfc_out.sum({0, 1, 2}, dfc_bias);
+ } else {
+ /// @todo optimize below by adding beta to Tensor::sum
+ Tensor t = dfc_out.sum({0, 1, 2});
+ dfc_bias.add_i(t);
+ }
}
void MoLAttentionLayer::setProperty(const std::vector<std::string> &values) {
void MoLAttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
context.updateTensor(wt_idx[AttentionParams::fc_out], batch);
+ context.updateTensor(wt_idx[AttentionParams::fc_tanh], batch);
context.updateTensor(wt_idx[AttentionParams::fc_proj_out], batch);
context.updateTensor(wt_idx[AttentionParams::scores], batch);
context.updateTensor(wt_idx[AttentionParams::prob], batch);