[layer] revise attention layers to apply softmax as inplace
authorhyeonseok lee <hs89.lee@samsung.com>
Wed, 7 Sep 2022 04:31:24 +0000 (13:31 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 9 Feb 2023 08:11:05 +0000 (17:11 +0900)
 - Remove attention_score tensor in attention/multi_head_attention layer to apply softmax as inplace
 - Modify tensor lifespan of fc_out to FORWARD_FUNC_LIFESPAN
 - remove unused enum updated_state

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/attention_layer.cpp
nntrainer/layers/attention_layer.h
nntrainer/layers/mol_attention_layer.cpp
nntrainer/layers/multi_head_attention_layer.cpp
nntrainer/layers/multi_head_attention_layer.h

index 07117a2..a2767cd 100644 (file)
@@ -18,7 +18,7 @@
 
 namespace nntrainer {
 
-AttentionLayer::AttentionLayer() {
+AttentionLayer::AttentionLayer() : sm(ActivationType::ACT_SOFTMAX) {
   wt_idx.fill(std::numeric_limits<unsigned>::max());
 }
 
@@ -26,7 +26,7 @@ AttentionLayer::~AttentionLayer() {}
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
-enum AttentionParams { query = 0, value = 1, key = 2, score, weights };
+enum AttentionParams { query = 0, value = 1, key = 2, weights };
 
 void AttentionLayer::finalizeCommon(InitLayerContext &context) {
   if (context.getNumInputs() < 2 || context.getNumInputs() > 3)
@@ -55,8 +55,6 @@ void AttentionLayer::finalizeCommon(InitLayerContext &context) {
 void AttentionLayer::finalize(InitLayerContext &context) {
   finalizeCommon(context);
 
-  sm.setActiFunc(ActivationType::ACT_SOFTMAX);
-
   auto const &all_dims = context.getInputDimensions();
   auto const &query_dim = all_dims[AttentionParams::query];
   auto const &value_dim = all_dims[AttentionParams::value];
@@ -67,10 +65,6 @@ void AttentionLayer::finalize(InitLayerContext &context) {
     context.requestTensor(weights_dim, "weights", Tensor::Initializer::NONE,
                           false, TensorLifespan::ITERATION_LIFESPAN);
 
-  wt_idx[AttentionParams::score] =
-    context.requestTensor(weights_dim, "score", Tensor::Initializer::NONE,
-                          false, TensorLifespan::FORWARD_FUNC_LIFESPAN);
-
   context.setOutputDimensions({query_dim});
 }
 
@@ -81,11 +75,10 @@ void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
 
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
   Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
-  Tensor &score = context.getTensor(wt_idx[AttentionParams::score]);
 
-  query.dotBatched(key, score, false, true); /** dot 1 */
-  sm.run_fn(score, weights);                 /** softmax */
-  weights.dotBatched(value, output);         /** dot 2 */
+  query.dotBatched(key, weights, false, true); /** dot 1 */
+  sm.run_fn(weights, weights);                 /** softmax */
+  weights.dotBatched(value, output);           /** dot 2 */
 }
 
 void AttentionLayer::calcDerivative(RunLayerContext &context) {
@@ -111,12 +104,11 @@ void AttentionLayer::calcDerivative(RunLayerContext &context) {
   weights.dot_batched_deriv_wrt_2(dvalue, derivative);
 
   /** derivative for softmax */
-  Tensor dscore;
-  sm.run_prime_fn(weights, dscore, dweight);
+  sm.run_prime_fn(weights, dweight, dweight);
 
   /** derivative for dot 1 */
-  dquery.dot_batched_deriv_wrt_1(key, dscore, false, true);
-  query.dot_batched_deriv_wrt_2(dkey, dscore, false, true,
+  dquery.dot_batched_deriv_wrt_1(key, dweight, false, true);
+  query.dot_batched_deriv_wrt_2(dkey, dweight, false, true,
                                 context.getNumInputs() == 2);
 }
 
@@ -129,7 +121,6 @@ void AttentionLayer::setProperty(const std::vector<std::string> &values) {
 }
 
 void AttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
-  context.updateTensor(wt_idx[AttentionParams::score], batch);
   context.updateTensor(wt_idx[AttentionParams::weights], batch);
 }
 
index dcfb6c2..d1d727a 100644 (file)
@@ -104,7 +104,7 @@ protected:
 
 private:
   ActiFunc sm;                        /** softmax activation operation */
-  std::array<unsigned int, 5> wt_idx; /**< indices of the weights and tensors */
+  std::array<unsigned int, 4> wt_idx; /**< indices of the weights and tensors */
 };
 
 } // namespace nntrainer
index 42e1110..a5b2d51 100644 (file)
@@ -51,7 +51,6 @@ enum AttentionParams {
   u_neg_div,
   u_pos_div,
   dstate,
-  updated_state
 };
 
 void MoLAttentionLayer::finalize(InitLayerContext &context) {
@@ -113,7 +112,7 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
   fc_out_dim.width(fc_w_dim.width());
   wt_idx[AttentionParams::fc_out] =
     context.requestTensor(fc_out_dim, "fc_out", Tensor::Initializer::NONE,
-                          false, TensorLifespan::ITERATION_LIFESPAN);
+                          false, TensorLifespan::FORWARD_FUNC_LIFESPAN);
 
   wt_idx[AttentionParams::fc_tanh] =
     context.requestTensor(fc_out_dim, "fc_tanh", Tensor::Initializer::NONE,
@@ -363,7 +362,6 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
 
   Tensor &fc_w = context.getWeight(wt_idx[AttentionParams::fc_w]);
   Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
-  Tensor &fc_out = context.getTensor(wt_idx[AttentionParams::fc_out]);
   Tensor &fc_tanh = context.getTensor(wt_idx[AttentionParams::fc_tanh]);
   Tensor &dfc_proj_out =
     context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
@@ -376,7 +374,7 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
   else
     dstate.copyData(dstate_local);
 
-  Tensor dfc_tanh = Tensor(fc_out.getDim());
+  Tensor dfc_tanh = Tensor(fc_tanh.getDim());
   dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out, false, false);
 
   Tensor dfc_out;
@@ -393,7 +391,6 @@ void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
   Tensor &dfc_bias = context.getWeightGrad(wt_idx[AttentionParams::fc_bias]);
   Tensor &dfc_proj_w =
     context.getWeightGrad(wt_idx[AttentionParams::fc_proj_w]);
-  Tensor &fc_out = context.getTensor(wt_idx[AttentionParams::fc_out]);
   Tensor &fc_tanh = context.getTensor(wt_idx[AttentionParams::fc_tanh]);
   Tensor &dfc_proj_out =
     context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
@@ -401,7 +398,7 @@ void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
   if (!helper_exec)
     calcDerivativeHelper(context, dstate);
 
-  Tensor dfc_tanh = Tensor(fc_out.getDim());
+  Tensor dfc_tanh = Tensor(fc_tanh.getDim());
   fc_tanh.dot_deriv_wrt_2(
     dfc_proj_w, dfc_proj_out, false, false,
     !context.isGradientFirstAccess(wt_idx[AttentionParams::fc_proj_w]));
index f91f799..3b6cc37 100644 (file)
@@ -27,6 +27,7 @@ MultiHeadAttentionLayer::MultiHeadAttentionLayer() :
     props::NumHeads(), props::ProjectedKeyDim(), props::ProjectedValueDim(),
     props::OutputShape(), props::DropOutRate(), props::ReturnAttentionWeight(),
     props::AverageAttentionWeight()),
+  sm(ActivationType::ACT_SOFTMAX),
   epsilon(1e-3) {
   weight_idx.fill(std::numeric_limits<unsigned>::max());
 }
@@ -56,8 +57,6 @@ enum AttentionParams {
   projected_query,
   projected_key,
   projected_value,
-  attention_score,
-  d_attention_score,
   /** intended comment for later use of attention_mask */
   // attention_mask,
   attention_weight,
@@ -258,15 +257,6 @@ void MultiHeadAttentionLayer::finalize(InitLayerContext &context) {
     projected_value_dim, "projected_value", Tensor::Initializer::NONE, true,
     TensorLifespan::ITERATION_LIFESPAN);
 
-  /** tensor for attention score */
-  TensorDim attention_score_dim(
-    {batch_size, num_heads, query_height, key_height});
-  weight_idx[AttentionParams::attention_score] = context.requestTensor(
-    attention_score_dim, "attention_score", Tensor::Initializer::NONE, false,
-    TensorLifespan::FORWARD_FUNC_LIFESPAN);
-  weight_idx[AttentionParams::d_attention_score] = context.requestTensor(
-    attention_score_dim, "d_attention_score", Tensor::Initializer::NONE, false,
-    TensorLifespan::BACKWARD_FUNC_LIFESPAN);
   if (provide_attention_mask) {
     /** Intended comment for bool type mask */
     // TensorDim attention_mask_dim(
@@ -376,8 +366,6 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
   Tensor &projected_value =
     context.getTensor(weight_idx[AttentionParams::projected_value]);
 
-  Tensor &attention_score =
-    context.getTensor(weight_idx[AttentionParams::attention_score]);
   Tensor &attention_weight =
     context.getTensor(weight_idx[AttentionParams::attention_weight]);
   Tensor &attention_output =
@@ -431,16 +419,14 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
   projected_value.reshape(TensorDim(
     {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
 
-  attention_score.reshape(
-    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   attention_weight.reshape(
     TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   attention_output.reshape(TensorDim(
     {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
 
   /** scaled dot product attention */
-  projected_query.dotBatched(projected_key, attention_score, false, true);
-  attention_score.multiply_i(1 / sqrt((float)projected_query_dim_prop));
+  projected_query.dotBatched(projected_key, attention_weight, false, true);
+  attention_weight.multiply_i(1 / sqrt((float)projected_query_dim_prop));
 
   if (provide_attention_mask) {
     // Tensor &attention_mask =
@@ -456,16 +442,16 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
     //   attention_mask.multiply_i(-1e9);
     // }
     // attention_mask.multiply_i(mask);
-    // attention_score.add_i(attention_mask);
+    // attention_weight.add_i(attention_mask);
 
-    attention_score.reshape(
+    attention_weight.reshape(
       TensorDim({batch_size, num_heads, query_height, key_height}));
-    attention_score.add_i(mask);
-    attention_score.reshape(
+    attention_weight.add_i(mask);
+    attention_weight.reshape(
       TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   }
 
-  sm.run_fn(attention_score, attention_weight);
+  sm.run_fn(attention_weight, attention_weight);
 
   if (return_attention_weight ==
       props::ReturnAttentionWeightInfo::Enum::before) {
@@ -520,8 +506,6 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
   projected_value.reshape(TensorDim(
     {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
 
-  attention_score.reshape(
-    TensorDim({batch_size, num_heads, query_height, key_height}));
   attention_weight.reshape(
     TensorDim({batch_size, num_heads, query_height, key_height}));
   attention_output.reshape(TensorDim(
@@ -577,8 +561,6 @@ void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
   Tensor &d_projected_value =
     context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
 
-  Tensor &d_attention_score =
-    context.getTensor(weight_idx[AttentionParams::d_attention_score]);
   Tensor &attention_weight =
     context.getTensor(weight_idx[AttentionParams::attention_weight]);
   Tensor &d_attention_weight =
@@ -621,8 +603,6 @@ void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
   d_projected_value.reshape(TensorDim(
     {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
 
-  d_attention_score.reshape(
-    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   attention_weight.reshape(
     TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   d_attention_weight.reshape(
@@ -652,17 +632,17 @@ void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
     d_attention_weight.add_i(d_ret_attention_weight);
   }
 
-  sm.run_prime_fn(attention_weight, d_attention_score, d_attention_weight);
+  sm.run_prime_fn(attention_weight, d_attention_weight, d_attention_weight);
   if (provide_attention_mask) {
     Tensor &d_mask = context.getOutgoingDerivative(INOUT_INDEX::MASK);
-    d_mask.copyData(d_attention_score);
+    d_mask.copyData(d_attention_weight);
   }
-  d_attention_score.multiply_i(
+  d_attention_weight.multiply_i(
     1 / sqrt((float)projected_query_dim_prop)); /** scale */
 
-  d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_score,
+  d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_weight,
                                             false, true);
-  projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_score,
+  projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_weight,
                                           false, true);
 
   d_projected_query.reshape(
@@ -696,8 +676,6 @@ void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
   d_projected_value.reshape(TensorDim(
     {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
 
-  d_attention_score.reshape(
-    TensorDim({batch_size, num_heads, query_height, key_height}));
   attention_weight.reshape(
     TensorDim({batch_size, num_heads, query_height, key_height}));
   d_attention_weight.reshape(
@@ -895,8 +873,6 @@ void MultiHeadAttentionLayer::setBatch(RunLayerContext &context,
   context.updateTensor(weight_idx[AttentionParams::projected_query], batch);
   context.updateTensor(weight_idx[AttentionParams::projected_key], batch);
   context.updateTensor(weight_idx[AttentionParams::projected_value], batch);
-  context.updateTensor(weight_idx[AttentionParams::attention_score], batch);
-  context.updateTensor(weight_idx[AttentionParams::d_attention_score], batch);
   context.updateTensor(weight_idx[AttentionParams::attention_weight], batch);
   if (dropout_rate > epsilon) {
     context.updateTensor(weight_idx[AttentionParams::dropout_mask], batch);
index 65141ea..4085aaf 100644 (file)
@@ -108,7 +108,7 @@ private:
     multi_head_attention_props; /**< multi_head_attention layer properties */
 
   ActiFunc sm; /** softmax activation operation */
-  std::array<unsigned int, 16>
+  std::array<unsigned int, 14>
     weight_idx; /**< indices of the weights and tensors */
 
   /**