[layer] Fixes for MoL attention layer
author     Parichay Kapoor <pk.kapoor@samsung.com>
           Mon, 6 Dec 2021 17:04:04 +0000 (02:04 +0900)
committer  Jijoong Moon <jijoong.moon@samsung.com>
           Tue, 7 Dec 2021 06:15:53 +0000 (15:15 +0900)
- accumulate weight gradients instead of overwriting them (sketched below)
- support a single output when the updated-state output is not consumed
- bug fix in backwarding: scores derivative was taken w.r.t. the wrong dot operand
- run the activation functions out of place
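
The gradient accumulation follows the usual overwrite-on-first-access, add-on-later-access
pattern; in the diff below it is driven by RunLayerContext::isGradientFirstAccess(), either
forwarded to dot_deriv_wrt_2() or handled with an explicit branch for the bias. A minimal
standalone sketch of that pattern (the vector-based buffer and the accumulate() helper are
illustrative only, not nntrainer API):

  #include <cstddef>
  #include <cstdio>
  #include <vector>

  // Toy illustration only: nntrainer makes the same decision through
  // RunLayerContext::isGradientFirstAccess() and Tensor::add_i().
  static void accumulate(std::vector<float> &grad,
                         const std::vector<float> &contrib, bool first_access) {
    if (first_access)
      grad = contrib; /* first touch this iteration: overwrite the buffer */
    else
      for (std::size_t i = 0; i < grad.size(); ++i)
        grad[i] += contrib[i]; /* later touches: sum into the existing gradient */
  }

  int main() {
    std::vector<float> dfc_bias(3, 0.0f);
    accumulate(dfc_bias, {1.0f, 2.0f, 3.0f}, /*first_access=*/true);
    accumulate(dfc_bias, {0.5f, 0.5f, 0.5f}, /*first_access=*/false);
    std::printf("%.1f %.1f %.1f\n", dfc_bias[0], dfc_bias[1], dfc_bias[2]); // 1.5 2.5 3.5
    return 0;
  }

For the weights the flag is simply passed as the last argument of dot_deriv_wrt_2() so the
dot-product derivative accumulates in place; only the bias path needs the explicit branch
until Tensor::sum() supports a beta parameter (see the @todo in the diff).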

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/mol_attention_layer.cpp
nntrainer/layers/mol_attention_layer.h

nntrainer/layers/mol_attention_layer.cpp
index a694212..ea56d98 100644
 
 namespace nntrainer {
 
-MoLAttentionLayer::MoLAttentionLayer() : helper_exec(false), wt_idx({0}) {}
+MoLAttentionLayer::MoLAttentionLayer() :
+  helper_exec(false),
+  softmax(ActivationType::ACT_SOFTMAX, false),
+  tanh(ActivationType::ACT_TANH, false),
+  sigmoid(ActivationType::ACT_SIGMOID, false),
+  wt_idx({std::numeric_limits<unsigned>::max()}) {}
 
 MoLAttentionLayer::~MoLAttentionLayer() {}
 
@@ -44,7 +49,8 @@ enum AttentionParams {
   prob_right,
   u_neg_div,
   u_pos_div,
-  dstate
+  dstate,
+  updated_state
 };
 
 void MoLAttentionLayer::finalize(InitLayerContext &context) {
@@ -61,10 +67,6 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
   wt_idx[AttentionParams::state] = AttentionParams::state;
   wt_idx[AttentionParams::mask_len] = AttentionParams::mask_len;
 
-  softmax.setActiFunc(ActivationType::ACT_SOFTMAX);
-  tanh.setActiFunc(ActivationType::ACT_TANH);
-  sigmoid.setActiFunc(ActivationType::ACT_SIGMOID);
-
   NNTR_THROW_IF(query_dim.width() != value_dim.width(), std::invalid_argument)
     << "Query and Value dimension mismatch for layer " << context.getName();
 
@@ -109,6 +111,10 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
     context.requestTensor(fc_out_dim, "fc_out", Tensor::Initializer::NONE,
                           false, TensorLifespan::ITERATION_LIFESPAN);
 
+  wt_idx[AttentionParams::fc_tanh] =
+    context.requestTensor(fc_out_dim, "fc_tanh", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+
   TensorDim fc_proj_out_dim = fc_out_dim;
   fc_proj_out_dim.width(fc_proj_w_dim.width());
   wt_idx[AttentionParams::fc_proj_out] = context.requestTensor(
@@ -142,7 +148,10 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
     context.requestTensor(state_dim, "dstate", Tensor::Initializer::NONE, false,
                           TensorLifespan::BACKWARD_FUNC_LIFESPAN);
 
-  context.setOutputDimensions({query_dim, state_dim});
+  if (context.getNumOutputs() == 2)
+    context.setOutputDimensions({query_dim, state_dim});
+  else
+    context.setOutputDimensions({query_dim});
 }
 
 void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
@@ -151,7 +160,6 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &state = context.getInput(wt_idx[AttentionParams::state]);
 
   Tensor &output = context.getOutput(0);
-  Tensor &updated_state = context.getOutput(1);
   Tensor &fc_w = context.getWeight(wt_idx[AttentionParams::fc_w]);
   Tensor &fc_bias = context.getWeight(wt_idx[AttentionParams::fc_bias]);
   Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
@@ -202,8 +210,6 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
     .copy_with_stride(alpha);
 
-  state.add(kappa, updated_state);
-
   /** @todo cache u_base, u_pos, u_neg */
   Tensor u_base = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
   for (unsigned int b = 0; b < batch; b++) {
@@ -219,11 +225,21 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
 
   Tensor beta_eps = beta.add(1e-8f);
 
-  Tensor u_pos_m = u_pos.subtract(updated_state);
+  Tensor u_pos_m, u_neg_m;
+  if (context.getNumOutputs() == 2) {
+    Tensor &updated_state = context.getOutput(1);
+    state.add(kappa, updated_state);
+    u_pos_m = u_pos.subtract(updated_state);
+    u_neg_m = u_neg.subtract(updated_state);
+  } else {
+    Tensor updated_state = state.add(kappa);
+    u_pos_m = u_pos.subtract(updated_state);
+    u_neg_m = u_neg.subtract(updated_state);
+  }
+
   u_pos_m.divide(beta_eps, u_pos_div);
   sigmoid.run_fn(u_pos_div, prob_left);
 
-  Tensor u_neg_m = u_neg.subtract(updated_state);
   u_neg_m.divide(beta_eps, u_neg_div);
   sigmoid.run_fn(u_neg_div, prob_right);
 
@@ -249,7 +265,6 @@ void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
   Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
 
   Tensor &derivative = context.getIncomingDerivative(0);
-  Tensor &derivative_state = context.getIncomingDerivative(1);
 
   Tensor &fc_proj_out = context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
   Tensor &dfc_proj_out =
@@ -306,7 +321,10 @@ void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
 
   Tensor dbeta_eps = dbeta_eps_neg.add(dbeta_eps_pos);
   dm_neg.add(dm_pos, dstate);
-  dstate.add_i(derivative_state);
+  if (context.getNumOutputs() == 2) {
+    Tensor &derivative_state = context.getIncomingDerivative(1);
+    dstate.add_i(derivative_state);
+  }
   Tensor dkappa = dstate;
   Tensor dbeta = dbeta_eps;
 
@@ -347,7 +365,7 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
     context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
   Tensor &scores = context.getTensor(wt_idx[AttentionParams::scores]);
 
-  scores.dot_batched_deriv_wrt_1(dvalue, derivative);
+  scores.dot_batched_deriv_wrt_2(dvalue, derivative);
 
   if (!helper_exec)
     calcDerivativeHelper(context, dstate);
@@ -355,11 +373,11 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
     dstate.copyData(dstate_local);
 
   Tensor dfc_tanh = Tensor(fc_out.getDim());
-  dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
+  dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out, false, false);
 
   Tensor dfc_out;
   tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
-  dquery.dot_deriv_wrt_1(fc_w, dfc_out);
+  dquery.dot_deriv_wrt_1(fc_w, dfc_out, false, false);
 }
 
 void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
@@ -380,13 +398,24 @@ void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
     calcDerivativeHelper(context, dstate);
 
   Tensor dfc_tanh = Tensor(fc_out.getDim());
-  fc_tanh.dot_deriv_wrt_2(dfc_proj_w, dfc_proj_out);
+  fc_tanh.dot_deriv_wrt_2(
+    dfc_proj_w, dfc_proj_out, false, false,
+    !context.isGradientFirstAccess(wt_idx[AttentionParams::fc_proj_w]));
   dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
 
   Tensor dfc_out;
   tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
-  query.dot_deriv_wrt_2(dfc_w, dfc_out);
-  dfc_out.sum({0, 1, 2}, dfc_bias);
+  query.dot_deriv_wrt_2(
+    dfc_w, dfc_out, false, false,
+    !context.isGradientFirstAccess(wt_idx[AttentionParams::fc_w]));
+
+  if (context.isGradientFirstAccess(wt_idx[AttentionParams::fc_bias])) {
+    dfc_out.sum({0, 1, 2}, dfc_bias);
+  } else {
+    /// @todo optimize below by adding beta to Tensor::sum
+    Tensor t = dfc_out.sum({0, 1, 2});
+    dfc_bias.add_i(t);
+  }
 }
 
 void MoLAttentionLayer::setProperty(const std::vector<std::string> &values) {
@@ -396,6 +425,7 @@ void MoLAttentionLayer::setProperty(const std::vector<std::string> &values) {
 
 void MoLAttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
   context.updateTensor(wt_idx[AttentionParams::fc_out], batch);
+  context.updateTensor(wt_idx[AttentionParams::fc_tanh], batch);
   context.updateTensor(wt_idx[AttentionParams::fc_proj_out], batch);
   context.updateTensor(wt_idx[AttentionParams::scores], batch);
   context.updateTensor(wt_idx[AttentionParams::prob], batch);
nntrainer/layers/mol_attention_layer.h
index 141c7c7..8e28c22 100644
@@ -106,7 +106,7 @@ private:
   ActiFunc softmax; /** softmax activation operation */
   ActiFunc tanh;    /** tanh activation operation */
   ActiFunc sigmoid; /** sigmoid activation operation */
-  std::array<unsigned int, 16>
+  std::array<unsigned int, 17>
     wt_idx; /**< indices of the weights and tensors */
 
   /**