From bed74b801d686314964ca51ef6efbce1c370e5e8 Mon Sep 17 00:00:00 2001
From: Parichay Kapoor
Date: Tue, 7 Dec 2021 02:04:04 +0900
Subject: [PATCH] [layer] Fixes for mol attention layer

- gradient accumulation
- support only 1 output in case state update is not used
- bug fix in backwarding
- activation to work out of place

Signed-off-by: Parichay Kapoor
---
 nntrainer/layers/mol_attention_layer.cpp | 70 +++++++++++++++++++++++---------
 nntrainer/layers/mol_attention_layer.h   |  2 +-
 2 files changed, 51 insertions(+), 21 deletions(-)

diff --git a/nntrainer/layers/mol_attention_layer.cpp b/nntrainer/layers/mol_attention_layer.cpp
index a694212..ea56d98 100644
--- a/nntrainer/layers/mol_attention_layer.cpp
+++ b/nntrainer/layers/mol_attention_layer.cpp
@@ -21,7 +21,12 @@
 
 namespace nntrainer {
 
-MoLAttentionLayer::MoLAttentionLayer() : helper_exec(false), wt_idx({0}) {}
+MoLAttentionLayer::MoLAttentionLayer() :
+  helper_exec(false),
+  softmax(ActivationType::ACT_SOFTMAX, false),
+  tanh(ActivationType::ACT_TANH, false),
+  sigmoid(ActivationType::ACT_SIGMOID, false),
+  wt_idx({std::numeric_limits<unsigned>::max()}) {}
 
 MoLAttentionLayer::~MoLAttentionLayer() {}
 
@@ -44,7 +49,8 @@ enum AttentionParams {
   prob_right,
   u_neg_div,
   u_pos_div,
-  dstate
+  dstate,
+  updated_state
 };
 
 void MoLAttentionLayer::finalize(InitLayerContext &context) {
@@ -61,10 +67,6 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
   wt_idx[AttentionParams::state] = AttentionParams::state;
   wt_idx[AttentionParams::mask_len] = AttentionParams::mask_len;
 
-  softmax.setActiFunc(ActivationType::ACT_SOFTMAX);
-  tanh.setActiFunc(ActivationType::ACT_TANH);
-  sigmoid.setActiFunc(ActivationType::ACT_SIGMOID);
-
   NNTR_THROW_IF(query_dim.width() != value_dim.width(), std::invalid_argument)
     << "Query and Value dimension mismatch for layer " << context.getName();
 
@@ -109,6 +111,10 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
     context.requestTensor(fc_out_dim, "fc_out", Tensor::Initializer::NONE,
                           false, TensorLifespan::ITERATION_LIFESPAN);
 
+  wt_idx[AttentionParams::fc_tanh] =
+    context.requestTensor(fc_out_dim, "fc_tanh", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+
   TensorDim fc_proj_out_dim = fc_out_dim;
   fc_proj_out_dim.width(fc_proj_w_dim.width());
   wt_idx[AttentionParams::fc_proj_out] = context.requestTensor(
@@ -142,7 +148,10 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
     context.requestTensor(state_dim, "dstate", Tensor::Initializer::NONE, false,
                           TensorLifespan::BACKWARD_FUNC_LIFESPAN);
 
-  context.setOutputDimensions({query_dim, state_dim});
+  if (context.getNumOutputs() == 2)
+    context.setOutputDimensions({query_dim, state_dim});
+  else
+    context.setOutputDimensions({query_dim});
 }
 
 void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
@@ -151,7 +160,6 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &state = context.getInput(wt_idx[AttentionParams::state]);
   Tensor &output = context.getOutput(0);
-  Tensor &updated_state = context.getOutput(1);
 
   Tensor &fc_w = context.getWeight(wt_idx[AttentionParams::fc_w]);
   Tensor &fc_bias = context.getWeight(wt_idx[AttentionParams::fc_bias]);
   Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
@@ -202,8 +210,6 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
     .copy_with_stride(alpha);
 
-  state.add(kappa, updated_state);
-
   /** @todo cache u_base, u_pos, u_neg */
   Tensor u_base = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
   for (unsigned int b = 0; b < batch; b++) {
@@ -219,11 +225,21 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
 
   Tensor beta_eps = beta.add(1e-8f);
 
-  Tensor u_pos_m = u_pos.subtract(updated_state);
+  Tensor u_pos_m, u_neg_m;
+  if (context.getNumOutputs() == 2) {
+    Tensor &updated_state = context.getOutput(1);
+    state.add(kappa, updated_state);
+    u_pos_m = u_pos.subtract(updated_state);
+    u_neg_m = u_neg.subtract(updated_state);
+  } else {
+    Tensor updated_state = state.add(kappa);
+    u_pos_m = u_pos.subtract(updated_state);
+    u_neg_m = u_neg.subtract(updated_state);
+  }
+
   u_pos_m.divide(beta_eps, u_pos_div);
   sigmoid.run_fn(u_pos_div, prob_left);
 
-  Tensor u_neg_m = u_neg.subtract(updated_state);
   u_neg_m.divide(beta_eps, u_neg_div);
   sigmoid.run_fn(u_neg_div, prob_right);
 
@@ -249,7 +265,6 @@ void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
   Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
 
   Tensor &derivative = context.getIncomingDerivative(0);
-  Tensor &derivative_state = context.getIncomingDerivative(1);
 
   Tensor &fc_proj_out = context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
   Tensor &dfc_proj_out =
@@ -306,7 +321,10 @@ void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
   Tensor dbeta_eps = dbeta_eps_neg.add(dbeta_eps_pos);
 
   dm_neg.add(dm_pos, dstate);
-  dstate.add_i(derivative_state);
+  if (context.getNumOutputs() == 2) {
+    Tensor &derivative_state = context.getIncomingDerivative(1);
+    dstate.add_i(derivative_state);
+  }
 
   Tensor dkappa = dstate;
   Tensor dbeta = dbeta_eps;
@@ -347,7 +365,7 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
     context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
   Tensor &scores = context.getTensor(wt_idx[AttentionParams::scores]);
 
-  scores.dot_batched_deriv_wrt_1(dvalue, derivative);
+  scores.dot_batched_deriv_wrt_2(dvalue, derivative);
 
   if (!helper_exec)
     calcDerivativeHelper(context, dstate);
@@ -355,11 +373,11 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
     dstate.copyData(dstate_local);
 
   Tensor dfc_tanh = Tensor(fc_out.getDim());
-  dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
+  dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out, false, false);
 
   Tensor dfc_out;
   tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
-  dquery.dot_deriv_wrt_1(fc_w, dfc_out);
+  dquery.dot_deriv_wrt_1(fc_w, dfc_out, false, false);
 }
 
 void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
@@ -380,13 +398,24 @@ void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
     calcDerivativeHelper(context, dstate);
 
   Tensor dfc_tanh = Tensor(fc_out.getDim());
-  fc_tanh.dot_deriv_wrt_2(dfc_proj_w, dfc_proj_out);
+  fc_tanh.dot_deriv_wrt_2(
+    dfc_proj_w, dfc_proj_out, false, false,
+    !context.isGradientFirstAccess(wt_idx[AttentionParams::fc_proj_w]));
   dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
 
   Tensor dfc_out;
   tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
-  query.dot_deriv_wrt_2(dfc_w, dfc_out);
-  dfc_out.sum({0, 1, 2}, dfc_bias);
+  query.dot_deriv_wrt_2(
+    dfc_w, dfc_out, false, false,
+    !context.isGradientFirstAccess(wt_idx[AttentionParams::fc_w]));
+
+  if (context.isGradientFirstAccess(wt_idx[AttentionParams::fc_bias])) {
+    dfc_out.sum({0, 1, 2}, dfc_bias);
+  } else {
+    /// @todo optimize below by adding beta to Tensor::sum
+    Tensor t = dfc_out.sum({0, 1, 2});
+    dfc_bias.add_i(t);
+  }
 }
 
 void MoLAttentionLayer::setProperty(const std::vector<std::string> &values) {
@@ -396,6 +425,7 @@ void MoLAttentionLayer::setProperty(const std::vector<std::string> &values) {
 
 void MoLAttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
   context.updateTensor(wt_idx[AttentionParams::fc_out], batch);
+  context.updateTensor(wt_idx[AttentionParams::fc_tanh], batch);
   context.updateTensor(wt_idx[AttentionParams::fc_proj_out], batch);
   context.updateTensor(wt_idx[AttentionParams::scores], batch);
   context.updateTensor(wt_idx[AttentionParams::prob], batch);
diff --git a/nntrainer/layers/mol_attention_layer.h b/nntrainer/layers/mol_attention_layer.h
index 141c7c7..8e28c22 100644
--- a/nntrainer/layers/mol_attention_layer.h
+++ b/nntrainer/layers/mol_attention_layer.h
@@ -106,7 +106,7 @@ private:
   ActiFunc softmax; /** softmax activation operation */
   ActiFunc tanh;    /** softmax activation operation */
   ActiFunc sigmoid; /** softmax activation operation */
-  std::array<unsigned int, 17>
+  std::array<unsigned int, 18>
     wt_idx; /**< indices of the weights and tensors */
 
   /**
-- 
2.7.4
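
Notes on the gradient accumulation change (a reviewer sketch, not part of the
patch): nntrainer may have several backward passes write the same gradient
buffer within one iteration, so the first writer must overwrite stale data and
later writers must accumulate. The trailing flag the patch passes to
dot_deriv_wrt_2(), !context.isGradientFirstAccess(...), selects that
accumulate path. A minimal sketch of the idiom, assuming RunLayerContext's
getWeightGrad() / isGradientFirstAccess() API; weight_idx and
compute_local_gradient() are illustrative names only:

    void calcGradientSketch(RunLayerContext &context, unsigned int weight_idx) {
      // gradient buffer, possibly shared with other writers this iteration
      Tensor &djdw = context.getWeightGrad(weight_idx);
      Tensor grad = compute_local_gradient(context); // illustrative helper

      if (context.isGradientFirstAccess(weight_idx)) {
        djdw.copyData(grad); // first writer this iteration: overwrite
      } else {
        djdw.add_i(grad);    // later writers: accumulate in place
      }
    }

The bias gradient takes the explicit branch in calcGradient() because
Tensor::sum() has no accumulate (beta) parameter yet (see the @todo): on
repeated access the result is summed into a temporary and added in place with
add_i().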
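A similar sketch for the out-of-place activation change, assuming the second
ActiFunc constructor argument toggles in-place execution (which is what the
new constructor calls suggest) and that forwarding pairs fc_out with the newly
requested fc_tanh tensor:

    // forward: tanh writes to its own ITERATION_LIFESPAN buffer instead of
    // overwriting fc_out, so the activation output survives until backward
    tanh.run_fn(fc_out, fc_tanh);

    // backward (as in calcDerivative above): run_prime_fn re-reads the saved
    // forward output fc_tanh and the incoming derivative dfc_tanh to fill dfc_out
    tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);

Keeping the activation out of place is what makes the extra fc_tanh request in
finalize() and the matching updateTensor() call in setBatch() necessary.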