[multi head attention] implement calcCommonDerivative
author hyeonseok lee <hs89.lee@samsung.com>
Wed, 13 Jul 2022 02:25:08 +0000 (11:25 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 7 Sep 2022 11:44:07 +0000 (20:44 +0900)
 - implement multi head attention calcCommonDerivative
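
   A sketch of the per-head gradients the new function computes, assuming the
   standard scaled dot-product attention used in forwarding (Q, K, V denote the
   projected query/key/value, A the attention weight, O the attention output,
   and d_k = projected_query_dim_prop; the mask, dropout, and returned
   attention weight paths are handled in the code but omitted here):

     S = \frac{Q K^\top}{\sqrt{d_k}}, \qquad A = \operatorname{softmax}(S), \qquad O = A V

     \frac{\partial L}{\partial A} = \frac{\partial L}{\partial O} V^\top, \qquad
     \frac{\partial L}{\partial V} = A^\top \frac{\partial L}{\partial O}

     \frac{\partial L}{\partial S} = J_{\operatorname{softmax}}(A)^\top \frac{\partial L}{\partial A}, \qquad
     \frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial S} K, \qquad
     \frac{\partial L}{\partial K} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial S}^\top Q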

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/multi_head_attention_layer.cpp

index 8579dc7..68ba1e1 100644
@@ -528,7 +528,183 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
     {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
 }
 
-void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {}
+void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
+  const unsigned int num_heads =
+    std::get<props::NumHeads>(multi_head_attention_props).get();
+  const unsigned int projected_key_dim_prop =
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+  const unsigned int projected_value_dim_prop =
+    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+  const unsigned int output_shape =
+    std::get<props::OutputShape>(multi_head_attention_props).get();
+  const float dropout_rate =
+    std::get<props::DropOutRate>(multi_head_attention_props).get();
+  const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+    std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+  const bool average_attention_weight =
+    std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+  const bool provide_attention_mask = context.getNumInputs() == 4;
+  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+
+  Tensor empty_tensor;
+
+  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+  Tensor &d_query = context.getOutgoingDerivative(INOUT_INDEX::QUERY);
+  Tensor &key = context.getInput(INOUT_INDEX::KEY);
+  Tensor &d_key = context.getOutgoingDerivative(INOUT_INDEX::KEY);
+  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+  Tensor &d_value = context.getOutgoingDerivative(INOUT_INDEX::VALUE);
+  const Tensor &incoming_derivative =
+    context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
+  const Tensor &d_ret_attention_weight =
+    return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
+      ? context.getIncomingDerivative(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
+      : empty_tensor;
+
+  Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
+
+  Tensor &projected_query =
+    context.getTensor(weight_idx[AttentionParams::projected_query]);
+  Tensor &d_projected_query =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
+  Tensor &projected_key =
+    context.getTensor(weight_idx[AttentionParams::projected_key]);
+  Tensor &d_projected_key =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
+  Tensor &projected_value =
+    context.getTensor(weight_idx[AttentionParams::projected_value]);
+  Tensor &d_projected_value =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
+
+  Tensor &d_attention_score =
+    context.getTensor(weight_idx[AttentionParams::d_attention_score]);
+  Tensor &attention_weight =
+    context.getTensor(weight_idx[AttentionParams::attention_weight]);
+  Tensor &d_attention_weight =
+    context.getTensorGrad(weight_idx[AttentionParams::attention_weight]);
+  Tensor &d_attention_output =
+    context.getTensorGrad(weight_idx[AttentionParams::attention_output]);
+
+  const TensorDim query_dim = query.getDim();
+  const unsigned int batch_size = query_dim.batch();
+  const unsigned int query_height = query_dim.height();
+  const unsigned int query_width = query_dim.width();
+  const TensorDim key_dim = key.getDim();
+  const unsigned int key_height = key_dim.height();
+  const unsigned int input_key_width_size = key_dim.width();
+  const TensorDim value_dim = value.getDim();
+  const unsigned int value_height = value_dim.height();
+  const unsigned int input_value_width_size = value_dim.width();
+
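+  /** backprop through the output projection
+   * (output = attention_output dot fc_weight) to get d_attention_output */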
+  d_attention_output.dot_deriv_wrt_1(fc_weight, incoming_derivative);
+
+  d_attention_output.reshape(
+    TensorDim({batch_size, query_height, num_heads, projected_value_dim_prop}));
+
+  d_attention_output = d_attention_output.transpose("1:0:2");
+
+  /** restore the original tensor name, since it was dropped during the
+   * transpose above */
+  d_attention_output.setName("multi_head_attention:attention_output:grad");
+
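+  /** reshape to the batched per-head layout used in forwarding so the
+   * gradients can be computed with batched dot products */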
+  projected_query.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
+  d_projected_query.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
+  projected_key.reshape(
+    TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
+  d_projected_key.reshape(
+    TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
+  projected_value.reshape(TensorDim(
+    {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
+  d_projected_value.reshape(TensorDim(
+    {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
+
+  d_attention_score.reshape(
+    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  attention_weight.reshape(
+    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  d_attention_weight.reshape(
+    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  d_attention_output.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
+
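+  /** attention_output = attention_weight dot projected_value, so compute the
+   * gradients w.r.t. both operands */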
+  d_attention_weight.dot_batched_deriv_wrt_1(projected_value,
+                                             d_attention_output);
+  attention_weight.dot_batched_deriv_wrt_2(d_projected_value,
+                                           d_attention_output);
+
+  if (return_attention_weight ==
+      props::ReturnAttentionWeightInfo::Enum::after) {
+    const float scale = average_attention_weight ? 1 / (float)num_heads : 1;
+    d_attention_weight.add_i(d_ret_attention_weight, scale);
+  }
+
+  if (dropout_rate > epsilon) {
+    Tensor &dropout_mask =
+      context.getTensor(weight_idx[AttentionParams::dropout_mask]);
+    d_attention_weight.multiply_i(dropout_mask);
+  }
+
+  if (return_attention_weight ==
+      props::ReturnAttentionWeightInfo::Enum::before) {
+    d_attention_weight.add_i(d_ret_attention_weight);
+  }
+
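+  /** backprop through the softmax to get d_attention_score */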
+  sm.run_prime_fn(attention_weight, d_attention_score, d_attention_weight);
+  if (provide_attention_mask) {
+    Tensor &d_mask = context.getOutgoingDerivative(INOUT_INDEX::MASK);
+    d_mask.copyData(d_attention_score);
+  }
+  d_attention_score.multiply_i(
+    1 / sqrt((float)projected_query_dim_prop)); /** scale */
+
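+  /** attention_score = projected_query dot projected_key^T, so compute the
+   * gradients w.r.t. projected_query and projected_key */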
+  d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_score,
+                                            false, true);
+  projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_score,
+                                          false, true);
+
+  d_projected_query.reshape(
+    TensorDim({batch_size, num_heads, query_height, projected_query_dim_prop}));
+  d_projected_key.reshape(
+    TensorDim({batch_size, num_heads, key_height, projected_key_dim_prop}));
+  d_projected_value.reshape(
+    TensorDim({batch_size, num_heads, value_height, projected_value_dim_prop}));
+
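+  /** move the head axis back before restoring the original gradient layouts */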
+  d_projected_query = d_projected_query.transpose("1:0:2");
+  d_projected_key = d_projected_key.transpose("1:0:2");
+  d_projected_value = d_projected_value.transpose("1:0:2");
+
+  /** restore the original tensor names, since they were dropped during the
+   * transposes above */
+  d_projected_query.setName("multi_head_attention:projected_query:grad");
+  d_projected_key.setName("multi_head_attention:projected_key:grad");
+  d_projected_value.setName("multi_head_attention:projected_value:grad");
+
+  /** restore shape */
+  projected_query.reshape(TensorDim(
+    {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
+  d_projected_query.reshape(TensorDim(
+    {batch_size * query_height, 1, 1, num_heads * projected_query_dim_prop}));
+  projected_key.reshape(
+    TensorDim({batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
+  d_projected_key.reshape(TensorDim(
+    {batch_size * key_height, 1, 1, num_heads * projected_key_dim_prop}));
+  projected_value.reshape(TensorDim(
+    {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
+  d_projected_value.reshape(TensorDim(
+    {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
+
+  d_attention_score.reshape(
+    TensorDim({batch_size, num_heads, query_height, key_height}));
+  attention_weight.reshape(
+    TensorDim({batch_size, num_heads, query_height, key_height}));
+  d_attention_weight.reshape(
+    TensorDim({batch_size, num_heads, query_height, key_height}));
+  d_attention_output.reshape(TensorDim(
+    {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
+}
 
 void MultiHeadAttentionLayer::calcDerivative(RunLayerContext &context) {
   if (!context.getTrainable()) {