{batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
}
-void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {}
+void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
+ const unsigned int num_heads =
+ std::get<props::NumHeads>(multi_head_attention_props).get();
+ const unsigned int projected_key_dim_prop =
+ std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+ const unsigned int projected_value_dim_prop =
+ std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+ const unsigned int output_shape =
+ std::get<props::OutputShape>(multi_head_attention_props).get();
+ const float dropout_rate =
+ std::get<props::DropOutRate>(multi_head_attention_props).get();
+ const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+ std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+ const bool average_attention_weight =
+ std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+ const bool provide_attention_mask = context.getNumInputs() == 4;
+ const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+
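+ /** placeholder used when no derivative for the returned attention weight is
+ * provided */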
+ Tensor empty_tensor;
+
+ Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+ Tensor &d_query = context.getOutgoingDerivative(INOUT_INDEX::QUERY);
+ Tensor &key = context.getInput(INOUT_INDEX::KEY);
+ Tensor &d_key = context.getOutgoingDerivative(INOUT_INDEX::KEY);
+ Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+ Tensor &d_value = context.getOutgoingDerivative(INOUT_INDEX::VALUE);
+ const Tensor &incoming_derivative =
+ context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
+ const Tensor &d_ret_attention_weight =
+ return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
+ ? context.getIncomingDerivative(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
+ : empty_tensor;
+
+ Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
+
+ Tensor &projected_query =
+ context.getTensor(weight_idx[AttentionParams::projected_query]);
+ Tensor &d_projected_query =
+ context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
+ Tensor &projected_key =
+ context.getTensor(weight_idx[AttentionParams::projected_key]);
+ Tensor &d_projected_key =
+ context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
+ Tensor &projected_value =
+ context.getTensor(weight_idx[AttentionParams::projected_value]);
+ Tensor &d_projected_value =
+ context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
+
+ Tensor &d_attention_score =
+ context.getTensor(weight_idx[AttentionParams::d_attention_score]);
+ Tensor &attention_weight =
+ context.getTensor(weight_idx[AttentionParams::attention_weight]);
+ Tensor &d_attention_weight =
+ context.getTensorGrad(weight_idx[AttentionParams::attention_weight]);
+ Tensor &d_attention_output =
+ context.getTensorGrad(weight_idx[AttentionParams::attention_output]);
+
+ const TensorDim query_dim = query.getDim();
+ const unsigned int batch_size = query_dim.batch();
+ const unsigned int query_height = query_dim.height();
+ const unsigned int query_width = query_dim.width();
+ const TensorDim key_dim = key.getDim();
+ const unsigned int key_height = key_dim.height();
+ const unsigned int input_key_width_size = key_dim.width();
+ const TensorDim value_dim = value.getDim();
+ const unsigned int value_height = value_dim.height();
+ const unsigned int input_value_width_size = value_dim.width();
+
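+ /** backprop the incoming derivative through the final fc projection to get
+ * the attention output gradient */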
+ d_attention_output.dot_deriv_wrt_1(fc_weight, incoming_derivative);
+
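+ /** split the heads out of the width and transpose so the head axis precedes
+ * the query height, matching the per-head layout of the projected tensors */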
+ d_attention_output.reshape(
+ TensorDim({batch_size, query_height, num_heads, projected_value_dim_prop}));
+
+ d_attention_output = d_attention_output.transpose("1:0:2");
+
+ /** restore the original tensor name, which was removed during the transpose */
+ d_attention_output.setName("multi_head_attention:attention_output:grad");
+
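+ /** fold the head axis into the batch axis so the attention backprop can run
+ * as batched matrix products */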
+ projected_query.reshape(TensorDim(
+ {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
+ d_projected_query.reshape(TensorDim(
+ {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
+ projected_key.reshape(
+ TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
+ d_projected_key.reshape(
+ TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
+ projected_value.reshape(TensorDim(
+ {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
+ d_projected_value.reshape(TensorDim(
+ {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
+
+ d_attention_score.reshape(
+ TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+ attention_weight.reshape(
+ TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+ d_attention_weight.reshape(
+ TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+ d_attention_output.reshape(TensorDim(
+ {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
+
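+ /** backprop through attention_output = attention_weight * projected_value:
+ * derivatives w.r.t. the attention weight and the projected value */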
+ d_attention_weight.dot_batched_deriv_wrt_1(projected_value,
+ d_attention_output);
+ attention_weight.dot_batched_deriv_wrt_2(d_projected_value,
+ d_attention_output);
+
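+ /** add the derivative coming from the attention weight returned after
+ * dropout, scaled by 1/num_heads when the returned weight was averaged over
+ * heads */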
+ if (return_attention_weight ==
+ props::ReturnAttentionWeightInfo::Enum::after) {
+ const float scale = average_attention_weight ? 1 / (float)num_heads : 1;
+ d_attention_weight.add_i(d_ret_attention_weight, scale);
+ }
+
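+ /** backprop through dropout by reapplying the saved mask to the attention
+ * weight gradient */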
+ if (dropout_rate > epsilon) {
+ Tensor &dropout_mask =
+ context.getTensor(weight_idx[AttentionParams::dropout_mask]);
+ d_attention_weight.multiply_i(dropout_mask);
+ }
+
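+ /** the attention weight returned before dropout adds its derivative after
+ * the dropout backward, so it is not masked */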
+ if (return_attention_weight ==
+ props::ReturnAttentionWeightInfo::Enum::before) {
+ d_attention_weight.add_i(d_ret_attention_weight);
+ }
+
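+ /** softmax backward into the score gradient; a provided attention mask
+ * receives a copy of that gradient */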
+ sm.run_prime_fn(attention_weight, d_attention_score, d_attention_weight);
+ if (provide_attention_mask) {
+ Tensor &d_mask = context.getOutgoingDerivative(INOUT_INDEX::MASK);
+ d_mask.copyData(d_attention_score);
+ }
+ d_attention_score.multiply_i(
+ 1 / sqrt((float)projected_query_dim_prop)); /** scale */
+
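+ /** backprop through attention_score = projected_query * projected_key^T
+ * (transposed second operand) to get the projected query and key gradients */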
+ d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_score,
+ false, true);
+ projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_score,
+ false, true);
+
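+ /** unfold the heads from the batch axis and move the head axis back next to
+ * the projection width */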
+ d_projected_query.reshape(
+ TensorDim({batch_size, num_heads, query_height, projected_query_dim_prop}));
+ d_projected_key.reshape(
+ TensorDim({batch_size, num_heads, key_height, projected_key_dim_prop}));
+ d_projected_value.reshape(
+ TensorDim({batch_size, num_heads, value_height, projected_value_dim_prop}));
+
+ d_projected_query = d_projected_query.transpose("1:0:2");
+ d_projected_key = d_projected_key.transpose("1:0:2");
+ d_projected_value = d_projected_value.transpose("1:0:2");
+
+ /** restore the original tensor names, which were removed during the transpose */
+ d_projected_query.setName("multi_head_attention:projected_query:grad");
+ d_projected_key.setName("multi_head_attention:projected_key:grad");
+ d_projected_value.setName("multi_head_attention:projected_value:grad");
+
+ /** restore shape */
+ projected_query.reshape(TensorDim(
+ {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
+ d_projected_query.reshape(TensorDim(
+ {batch_size * query_height, 1, 1, num_heads * projected_query_dim_prop}));
+ projected_key.reshape(
+ TensorDim({batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
+ d_projected_key.reshape(TensorDim(
+ {batch_size * key_height, 1, 1, num_heads * projected_key_dim_prop}));
+ projected_value.reshape(TensorDim(
+ {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
+ d_projected_value.reshape(TensorDim(
+ {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
+
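+ /** restore the attention score / weight / output shapes as well */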
+ d_attention_score.reshape(
+ TensorDim({batch_size, num_heads, query_height, key_height}));
+ attention_weight.reshape(
+ TensorDim({batch_size, num_heads, query_height, key_height}));
+ d_attention_weight.reshape(
+ TensorDim({batch_size, num_heads, query_height, key_height}));
+ d_attention_output.reshape(TensorDim(
+ {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
+}
void MultiHeadAttentionLayer::calcDerivative(RunLayerContext &context) {
if (!context.getTrainable()) {