From: hyeonseok lee
Date: Wed, 13 Jul 2022 02:25:08 +0000 (+0900)
Subject: [multi head attention] implement calcCommonDerivative
X-Git-Tag: accepted/tizen/unified/20220919.021604~8
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a0a3ee08ee30d2d166564c21c0bbdb319ccf91d2;p=platform%2Fcore%2Fml%2Fnntrainer.git

[multi head attention] implement calcCommonDerivative

 - implement multi head attention calcCommonDerivative

Signed-off-by: hyeonseok lee
---
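The new calcCommonDerivative is the shared backward pass of scaled dot-product
attention. The sketch below is notation only, written per head with
d_k = projected_query_dim_prop (in the code the head dimension is folded into
the batch dimension by the reshape/transpose calls):

  S = Q K^\top / \sqrt{d_k}, \qquad A = \mathrm{softmax}(S), \qquad O = A V

  dV = A^\top\, dO
  dA = dO\, V^\top
  dS = A \odot \big( dA - \mathrm{rowsum}(A \odot dA) \big)
  dQ = dS\, K / \sqrt{d_k}
  dK = dS^\top\, Q / \sqrt{d_k}

In the code, dO is recovered from the incoming derivative through the fc_weight
projection (dot_deriv_wrt_1); dV and dA come from dot_batched_deriv_wrt_2 /
dot_batched_deriv_wrt_1 against d_attention_output; dropout and the returned
attention-weight gradient are folded into dA before the softmax backward
(sm.run_prime_fn); the 1/\sqrt{d_k} factor is the multiply_i on
d_attention_score; and dQ / dK are produced by the final
dot_batched_deriv_wrt_1 / dot_batched_deriv_wrt_2 pair.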
diff --git a/nntrainer/layers/multi_head_attention_layer.cpp b/nntrainer/layers/multi_head_attention_layer.cpp
index 8579dc7..68ba1e1 100644
--- a/nntrainer/layers/multi_head_attention_layer.cpp
+++ b/nntrainer/layers/multi_head_attention_layer.cpp
@@ -528,7 +528,183 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
     {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
 }
 
-void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {}
+void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
+  const unsigned int num_heads =
+    std::get<props::NumHeads>(multi_head_attention_props).get();
+  const unsigned int projected_key_dim_prop =
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+  const unsigned int projected_value_dim_prop =
+    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+  const unsigned int output_shape =
+    std::get<props::OutputShape>(multi_head_attention_props).get();
+  const float dropout_rate =
+    std::get<props::DropOutRate>(multi_head_attention_props).get();
+  const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+    std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+  const bool average_attention_weight =
+    std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+  const bool provide_attention_mask = context.getNumInputs() == 4;
+  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+
+  Tensor empty_tensor;
+
+  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+  Tensor &d_query = context.getOutgoingDerivative(INOUT_INDEX::QUERY);
+  Tensor &key = context.getInput(INOUT_INDEX::KEY);
+  Tensor &d_key = context.getOutgoingDerivative(INOUT_INDEX::KEY);
+  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+  Tensor &d_value = context.getOutgoingDerivative(INOUT_INDEX::VALUE);
+  const Tensor &incoming_derivative =
+    context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
+  const Tensor &d_ret_attention_weight =
+    return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
+      ? context.getIncomingDerivative(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
+      : empty_tensor;
+
+  Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
+
+  Tensor &projected_query =
+    context.getTensor(weight_idx[AttentionParams::projected_query]);
+  Tensor &d_projected_query =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
+  Tensor &projected_key =
+    context.getTensor(weight_idx[AttentionParams::projected_key]);
+  Tensor &d_projected_key =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
+  Tensor &projected_value =
+    context.getTensor(weight_idx[AttentionParams::projected_value]);
+  Tensor &d_projected_value =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
+
+  Tensor &d_attention_score =
+    context.getTensor(weight_idx[AttentionParams::d_attention_score]);
+  Tensor &attention_weight =
+    context.getTensor(weight_idx[AttentionParams::attention_weight]);
+  Tensor &d_attention_weight =
+    context.getTensorGrad(weight_idx[AttentionParams::attention_weight]);
+  Tensor &d_attention_output =
+    context.getTensorGrad(weight_idx[AttentionParams::attention_output]);
+
+  const TensorDim query_dim = query.getDim();
+  const unsigned int batch_size = query_dim.batch();
+  const unsigned int query_height = query_dim.height();
+  const unsigned int query_width = query_dim.width();
+  const TensorDim key_dim = key.getDim();
+  const unsigned int key_height = key_dim.height();
+  const unsigned int input_key_width_size = key_dim.width();
+  const TensorDim value_dim = value.getDim();
+  const unsigned int value_height = value_dim.height();
+  const unsigned int input_value_width_size = value_dim.width();
+
+  d_attention_output.dot_deriv_wrt_1(fc_weight, incoming_derivative);
+
+  d_attention_output.reshape(
+    TensorDim({batch_size, query_height, num_heads, projected_value_dim_prop}));
+
+  d_attention_output = d_attention_output.transpose("1:0:2");
+
+  /** set tensor name to restore the original name, which was removed during
+   * transpose */
+  d_attention_output.setName("multi_head_attention:attention_output:grad");
+
+  projected_query.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
+  d_projected_query.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
+  projected_key.reshape(
+    TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
+  d_projected_key.reshape(
+    TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
+  projected_value.reshape(TensorDim(
+    {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
+  d_projected_value.reshape(TensorDim(
+    {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
+
+  d_attention_score.reshape(
+    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  attention_weight.reshape(
+    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  d_attention_weight.reshape(
+    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  d_attention_output.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
+
+  d_attention_weight.dot_batched_deriv_wrt_1(projected_value,
+                                             d_attention_output);
+  attention_weight.dot_batched_deriv_wrt_2(d_projected_value,
+                                           d_attention_output);
+
+  if (return_attention_weight ==
+      props::ReturnAttentionWeightInfo::Enum::after) {
+    const float scale = average_attention_weight ? 1 / (float)num_heads : 1;
+    d_attention_weight.add_i(d_ret_attention_weight, scale);
+  }
+
+  if (dropout_rate > epsilon) {
+    Tensor &dropout_mask =
+      context.getTensor(weight_idx[AttentionParams::dropout_mask]);
+    d_attention_weight.multiply_i(dropout_mask);
+  }
+
+  if (return_attention_weight ==
+      props::ReturnAttentionWeightInfo::Enum::before) {
+    d_attention_weight.add_i(d_ret_attention_weight);
+  }
+
+  sm.run_prime_fn(attention_weight, d_attention_score, d_attention_weight);
+  if (provide_attention_mask) {
+    Tensor &d_mask = context.getOutgoingDerivative(INOUT_INDEX::MASK);
+    d_mask.copyData(d_attention_score);
+  }
+  d_attention_score.multiply_i(
+    1 / sqrt((float)projected_query_dim_prop)); /** scale */
+
+  d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_score,
+                                            false, true);
+  projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_score,
+                                          false, true);
+
+  d_projected_query.reshape(
+    TensorDim({batch_size, num_heads, query_height, projected_query_dim_prop}));
+  d_projected_key.reshape(
+    TensorDim({batch_size, num_heads, key_height, projected_key_dim_prop}));
+  d_projected_value.reshape(
+    TensorDim({batch_size, num_heads, value_height, projected_value_dim_prop}));
+
+  d_projected_query = d_projected_query.transpose("1:0:2");
+  d_projected_key = d_projected_key.transpose("1:0:2");
+  d_projected_value = d_projected_value.transpose("1:0:2");
+
+  /** set tensor name to restore the original name, which was removed during
+   * transpose */
+  d_projected_query.setName("multi_head_attention:projected_query:grad");
+  d_projected_key.setName("multi_head_attention:projected_key:grad");
+  d_projected_value.setName("multi_head_attention:projected_value:grad");
+
+  /** restore shape */
+  projected_query.reshape(TensorDim(
+    {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
+  d_projected_query.reshape(TensorDim(
+    {batch_size * query_height, 1, 1, num_heads * projected_query_dim_prop}));
+  projected_key.reshape(
+    TensorDim({batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
+  d_projected_key.reshape(TensorDim(
+    {batch_size * key_height, 1, 1, num_heads * projected_key_dim_prop}));
+  projected_value.reshape(TensorDim(
+    {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
+  d_projected_value.reshape(TensorDim(
+    {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
+
+  d_attention_score.reshape(
+    TensorDim({batch_size, num_heads, query_height, key_height}));
+  attention_weight.reshape(
+    TensorDim({batch_size, num_heads, query_height, key_height}));
+  d_attention_weight.reshape(
+    TensorDim({batch_size, num_heads, query_height, key_height}));
+  d_attention_output.reshape(TensorDim(
+    {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
+}
 
 void MultiHeadAttentionLayer::calcDerivative(RunLayerContext &context) {
   if (!context.getTrainable()) {