From cb2a96894aa99285b569cf3573fee5ceae51a47a Mon Sep 17 00:00:00 2001
From: hyeonseok lee
Date: Wed, 7 Sep 2022 13:31:24 +0900
Subject: [PATCH] [layer] revise attention layers to apply softmax in-place

- Remove the attention_score tensor in the attention/multi_head_attention
  layers so that softmax is applied in-place
- Modify the tensor lifespan of fc_out to FORWARD_FUNC_LIFESPAN
- Remove the unused enum value updated_state

Signed-off-by: hyeonseok lee
---
 nntrainer/layers/attention_layer.cpp            | 25 ++++---------
 nntrainer/layers/attention_layer.h              |  2 +-
 nntrainer/layers/mol_attention_layer.cpp        |  9 ++---
 nntrainer/layers/multi_head_attention_layer.cpp | 50 +++++++------------------
 nntrainer/layers/multi_head_attention_layer.h   |  2 +-
 5 files changed, 26 insertions(+), 62 deletions(-)

diff --git a/nntrainer/layers/attention_layer.cpp b/nntrainer/layers/attention_layer.cpp
index 07117a2..a2767cd 100644
--- a/nntrainer/layers/attention_layer.cpp
+++ b/nntrainer/layers/attention_layer.cpp
@@ -18,7 +18,7 @@
 
 namespace nntrainer {
 
-AttentionLayer::AttentionLayer() {
+AttentionLayer::AttentionLayer() : sm(ActivationType::ACT_SOFTMAX) {
   wt_idx.fill(std::numeric_limits::max());
 }
 
@@ -26,7 +26,7 @@ AttentionLayer::~AttentionLayer() {}
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
-enum AttentionParams { query = 0, value = 1, key = 2, score, weights };
+enum AttentionParams { query = 0, value = 1, key = 2, weights };
 
 void AttentionLayer::finalizeCommon(InitLayerContext &context) {
   if (context.getNumInputs() < 2 || context.getNumInputs() > 3)
@@ -55,8 +55,6 @@ void AttentionLayer::finalizeCommon(InitLayerContext &context) {
 void AttentionLayer::finalize(InitLayerContext &context) {
   finalizeCommon(context);
 
-  sm.setActiFunc(ActivationType::ACT_SOFTMAX);
-
   auto const &all_dims = context.getInputDimensions();
   auto const &query_dim = all_dims[AttentionParams::query];
   auto const &value_dim = all_dims[AttentionParams::value];
@@ -67,10 +65,6 @@ void AttentionLayer::finalize(InitLayerContext &context) {
     context.requestTensor(weights_dim, "weights", Tensor::Initializer::NONE,
                           false, TensorLifespan::ITERATION_LIFESPAN);
 
-  wt_idx[AttentionParams::score] =
-    context.requestTensor(weights_dim, "score", Tensor::Initializer::NONE,
-                          false, TensorLifespan::FORWARD_FUNC_LIFESPAN);
-
   context.setOutputDimensions({query_dim});
 }
 
@@ -81,11 +75,10 @@ void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
 
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
   Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
-  Tensor &score = context.getTensor(wt_idx[AttentionParams::score]);
 
-  query.dotBatched(key, score, false, true); /** dot 1 */
-  sm.run_fn(score, weights);                 /** softmax */
-  weights.dotBatched(value, output);         /** dot 2 */
+  query.dotBatched(key, weights, false, true); /** dot 1 */
+  sm.run_fn(weights, weights);                 /** softmax */
+  weights.dotBatched(value, output);           /** dot 2 */
 }
 
 void AttentionLayer::calcDerivative(RunLayerContext &context) {
@@ -111,12 +104,11 @@ void AttentionLayer::calcDerivative(RunLayerContext &context) {
   weights.dot_batched_deriv_wrt_2(dvalue, derivative);
 
   /** derivative for softmax */
-  Tensor dscore;
-  sm.run_prime_fn(weights, dscore, dweight);
+  sm.run_prime_fn(weights, dweight, dweight);
 
   /** derivative for dot 1 */
-  dquery.dot_batched_deriv_wrt_1(key, dscore, false, true);
-  query.dot_batched_deriv_wrt_2(dkey, dscore, false, true,
+  dquery.dot_batched_deriv_wrt_1(key, dweight, false, true);
+  query.dot_batched_deriv_wrt_2(dkey, dweight, false, true,
                                 context.getNumInputs() == 2);
 }
 
@@ -129,7 +121,6 @@ void AttentionLayer::setProperty(const std::vector<std::string> &values) {
 }
 
 void AttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
-  context.updateTensor(wt_idx[AttentionParams::score], batch);
   context.updateTensor(wt_idx[AttentionParams::weights], batch);
 }
 
diff --git a/nntrainer/layers/attention_layer.h b/nntrainer/layers/attention_layer.h
index dcfb6c2..d1d727a 100644
--- a/nntrainer/layers/attention_layer.h
+++ b/nntrainer/layers/attention_layer.h
@@ -104,7 +104,7 @@ protected:
 private:
   ActiFunc sm; /** softmax activation operation */
 
-  std::array wt_idx; /**< indices of the weights and tensors */
+  std::array wt_idx; /**< indices of the weights and tensors */
 };
 
 } // namespace nntrainer
diff --git a/nntrainer/layers/mol_attention_layer.cpp b/nntrainer/layers/mol_attention_layer.cpp
index 42e1110..a5b2d51 100644
--- a/nntrainer/layers/mol_attention_layer.cpp
+++ b/nntrainer/layers/mol_attention_layer.cpp
@@ -51,7 +51,6 @@ enum AttentionParams {
   u_neg_div,
   u_pos_div,
   dstate,
-  updated_state
 };
 
 void MoLAttentionLayer::finalize(InitLayerContext &context) {
@@ -113,7 +112,7 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
   fc_out_dim.width(fc_w_dim.width());
   wt_idx[AttentionParams::fc_out] =
     context.requestTensor(fc_out_dim, "fc_out", Tensor::Initializer::NONE,
-                          false, TensorLifespan::ITERATION_LIFESPAN);
+                          false, TensorLifespan::FORWARD_FUNC_LIFESPAN);
 
   wt_idx[AttentionParams::fc_tanh] =
     context.requestTensor(fc_out_dim, "fc_tanh", Tensor::Initializer::NONE,
@@ -363,7 +362,6 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
 
   Tensor &fc_w = context.getWeight(wt_idx[AttentionParams::fc_w]);
   Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
-  Tensor &fc_out = context.getTensor(wt_idx[AttentionParams::fc_out]);
   Tensor &fc_tanh = context.getTensor(wt_idx[AttentionParams::fc_tanh]);
   Tensor &dfc_proj_out =
     context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
@@ -376,7 +374,7 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
   else
     dstate.copyData(dstate_local);
 
-  Tensor dfc_tanh = Tensor(fc_out.getDim());
+  Tensor dfc_tanh = Tensor(fc_tanh.getDim());
   dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out, false, false);
 
   Tensor dfc_out;
@@ -393,7 +391,6 @@ void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
   Tensor &dfc_bias = context.getWeightGrad(wt_idx[AttentionParams::fc_bias]);
   Tensor &dfc_proj_w =
     context.getWeightGrad(wt_idx[AttentionParams::fc_proj_w]);
-  Tensor &fc_out = context.getTensor(wt_idx[AttentionParams::fc_out]);
   Tensor &fc_tanh = context.getTensor(wt_idx[AttentionParams::fc_tanh]);
   Tensor &dfc_proj_out =
     context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
@@ -401,7 +398,7 @@ void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
   if (!helper_exec)
     calcDerivativeHelper(context, dstate);
 
-  Tensor dfc_tanh = Tensor(fc_out.getDim());
+  Tensor dfc_tanh = Tensor(fc_tanh.getDim());
   fc_tanh.dot_deriv_wrt_2(
     dfc_proj_w, dfc_proj_out, false, false,
     !context.isGradientFirstAccess(wt_idx[AttentionParams::fc_proj_w]));
diff --git a/nntrainer/layers/multi_head_attention_layer.cpp b/nntrainer/layers/multi_head_attention_layer.cpp
index f91f799..3b6cc37 100644
--- a/nntrainer/layers/multi_head_attention_layer.cpp
+++ b/nntrainer/layers/multi_head_attention_layer.cpp
@@ -27,6 +27,7 @@ MultiHeadAttentionLayer::MultiHeadAttentionLayer() :
     props::NumHeads(), props::ProjectedKeyDim(), props::ProjectedValueDim(),
     props::OutputShape(), props::DropOutRate(),
     props::ReturnAttentionWeight(), props::AverageAttentionWeight()),
+  sm(ActivationType::ACT_SOFTMAX),
   epsilon(1e-3) {
   weight_idx.fill(std::numeric_limits::max());
 }
 
@@ -56,8 +57,6 @@ enum AttentionParams {
   projected_query,
   projected_key,
   projected_value,
-  attention_score,
-  d_attention_score,
   /** intended comment for later use of attention_mask */
   // attention_mask,
   attention_weight,
@@ -258,15 +257,6 @@ void MultiHeadAttentionLayer::finalize(InitLayerContext &context) {
     projected_value_dim, "projected_value", Tensor::Initializer::NONE, true,
     TensorLifespan::ITERATION_LIFESPAN);
 
-  /** tensor for attention score */
-  TensorDim attention_score_dim(
-    {batch_size, num_heads, query_height, key_height});
-  weight_idx[AttentionParams::attention_score] = context.requestTensor(
-    attention_score_dim, "attention_score", Tensor::Initializer::NONE, false,
-    TensorLifespan::FORWARD_FUNC_LIFESPAN);
-  weight_idx[AttentionParams::d_attention_score] = context.requestTensor(
-    attention_score_dim, "d_attention_score", Tensor::Initializer::NONE, false,
-    TensorLifespan::BACKWARD_FUNC_LIFESPAN);
   if (provide_attention_mask) {
     /** Intended comment for bool type mask */
     // TensorDim attention_mask_dim(
@@ -376,8 +366,6 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
   Tensor &projected_value =
     context.getTensor(weight_idx[AttentionParams::projected_value]);
 
-  Tensor &attention_score =
-    context.getTensor(weight_idx[AttentionParams::attention_score]);
   Tensor &attention_weight =
     context.getTensor(weight_idx[AttentionParams::attention_weight]);
   Tensor &attention_output =
@@ -431,16 +419,14 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
   projected_value.reshape(TensorDim(
     {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
 
-  attention_score.reshape(
-    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   attention_weight.reshape(
     TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   attention_output.reshape(TensorDim(
     {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
 
   /** scaled dot product attention */
-  projected_query.dotBatched(projected_key, attention_score, false, true);
-  attention_score.multiply_i(1 / sqrt((float)projected_query_dim_prop));
+  projected_query.dotBatched(projected_key, attention_weight, false, true);
+  attention_weight.multiply_i(1 / sqrt((float)projected_query_dim_prop));
 
   if (provide_attention_mask) {
     // Tensor &attention_mask =
@@ -456,16 +442,16 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
     // attention_mask.multiply_i(-1e9);
     // }
     // attention_mask.multiply_i(mask);
-    // attention_score.add_i(attention_mask);
+    // attention_weight.add_i(attention_mask);
 
-    attention_score.reshape(
+    attention_weight.reshape(
       TensorDim({batch_size, num_heads, query_height, key_height}));
-    attention_score.add_i(mask);
-    attention_score.reshape(
+    attention_weight.add_i(mask);
+    attention_weight.reshape(
       TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   }
 
-  sm.run_fn(attention_score, attention_weight);
+  sm.run_fn(attention_weight, attention_weight);
 
   if (return_attention_weight ==
       props::ReturnAttentionWeightInfo::Enum::before) {
@@ -520,8 +506,6 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
   projected_value.reshape(TensorDim(
     {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
 
-  attention_score.reshape(
-    TensorDim({batch_size, num_heads, query_height, key_height}));
   attention_weight.reshape(
     TensorDim({batch_size, num_heads, query_height, key_height}));
   attention_output.reshape(TensorDim(
@@ -577,8 +561,6 @@ void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
   Tensor &d_projected_value =
     context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
 
-  Tensor &d_attention_score =
-    context.getTensor(weight_idx[AttentionParams::d_attention_score]);
   Tensor &attention_weight =
     context.getTensor(weight_idx[AttentionParams::attention_weight]);
   Tensor &d_attention_weight =
@@ -621,8 +603,6 @@ void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
   d_projected_value.reshape(TensorDim(
     {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
 
-  d_attention_score.reshape(
-    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   attention_weight.reshape(
     TensorDim({batch_size * num_heads, 1, query_height, key_height}));
   d_attention_weight.reshape(
@@ -652,17 +632,17 @@ void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
     d_attention_weight.add_i(d_ret_attention_weight);
   }
 
-  sm.run_prime_fn(attention_weight, d_attention_score, d_attention_weight);
+  sm.run_prime_fn(attention_weight, d_attention_weight, d_attention_weight);
   if (provide_attention_mask) {
     Tensor &d_mask = context.getOutgoingDerivative(INOUT_INDEX::MASK);
-    d_mask.copyData(d_attention_score);
+    d_mask.copyData(d_attention_weight);
   }
 
-  d_attention_score.multiply_i(
+  d_attention_weight.multiply_i(
     1 / sqrt((float)projected_query_dim_prop)); /** scale */
-  d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_score,
+  d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_weight,
                                             false, true);
-  projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_score,
+  projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_weight,
                                           false, true);
 
   d_projected_query.reshape(
@@ -696,8 +676,6 @@ void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
   d_projected_value.reshape(TensorDim(
     {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
 
-  d_attention_score.reshape(
-    TensorDim({batch_size, num_heads, query_height, key_height}));
   attention_weight.reshape(
     TensorDim({batch_size, num_heads, query_height, key_height}));
   d_attention_weight.reshape(
@@ -895,8 +873,6 @@ void MultiHeadAttentionLayer::setBatch(RunLayerContext &context,
   context.updateTensor(weight_idx[AttentionParams::projected_query], batch);
   context.updateTensor(weight_idx[AttentionParams::projected_key], batch);
   context.updateTensor(weight_idx[AttentionParams::projected_value], batch);
-  context.updateTensor(weight_idx[AttentionParams::attention_score], batch);
-  context.updateTensor(weight_idx[AttentionParams::d_attention_score], batch);
   context.updateTensor(weight_idx[AttentionParams::attention_weight], batch);
   if (dropout_rate > epsilon) {
     context.updateTensor(weight_idx[AttentionParams::dropout_mask], batch);
diff --git a/nntrainer/layers/multi_head_attention_layer.h b/nntrainer/layers/multi_head_attention_layer.h
index 65141ea..4085aaf 100644
--- a/nntrainer/layers/multi_head_attention_layer.h
+++ b/nntrainer/layers/multi_head_attention_layer.h
@@ -108,7 +108,7 @@ private:
     multi_head_attention_props; /**< multi_head_attention layer properties */
   ActiFunc sm; /** softmax activation operation */
 
-  std::array
+  std::array
     weight_idx; /**< indices of the weights and tensors */
 
   /**
-- 
2.7.4
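
Note on the pattern the patch adopts: the separate score/d_attention_score tensors
can be dropped because softmax may overwrite its own input, and its backward pass
needs only the softmax output, never the pre-softmax scores. That is what makes
sm.run_fn(weights, weights) and sm.run_prime_fn(weights, dweight, dweight) valid.
The standalone C++ sketch below illustrates the same idea; it uses plain
std::vector rather than nntrainer's Tensor/ActiFunc API, and all names and sizes
are illustrative only.

// Standalone sketch: softmax forward and backward computed in place on one
// buffer per row, mirroring sm.run_fn(weights, weights) and
// sm.run_prime_fn(weights, dweight, dweight) in the patch.
// Data is row-major [rows x cols] in a std::vector<float>.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Overwrites each row of `x` with its softmax. The pre-softmax scores are not
// needed afterwards, which is what allows the separate "score" tensor to go.
void softmax_inplace(std::vector<float> &x, std::size_t rows,
                     std::size_t cols) {
  for (std::size_t r = 0; r < rows; ++r) {
    float *row = x.data() + r * cols;
    const float mx = *std::max_element(row, row + cols);
    float sum = 0.0f;
    for (std::size_t c = 0; c < cols; ++c) {
      row[c] = std::exp(row[c] - mx);
      sum += row[c];
    }
    for (std::size_t c = 0; c < cols; ++c)
      row[c] /= sum;
  }
}

// Backward pass: dL/dx_i = y_i * (dL/dy_i - sum_j y_j * dL/dy_j).
// Only the softmax output `y` is required, and `dy` is overwritten with dx,
// so no extra score/derivative buffer is needed.
void softmax_backward_inplace(const std::vector<float> &y,
                              std::vector<float> &dy, std::size_t rows,
                              std::size_t cols) {
  for (std::size_t r = 0; r < rows; ++r) {
    const float *yr = y.data() + r * cols;
    float *dr = dy.data() + r * cols;
    float dot = 0.0f;
    for (std::size_t c = 0; c < cols; ++c)
      dot += yr[c] * dr[c];
    for (std::size_t c = 0; c < cols; ++c)
      dr[c] = yr[c] * (dr[c] - dot);
  }
}

int main() {
  // Two rows of three attention scores (illustrative values).
  std::vector<float> scores = {1.0f, 2.0f, 3.0f, 0.5f, 0.5f, 0.5f};
  softmax_inplace(scores, 2, 3); // `scores` now holds the attention weights

  std::vector<float> d = {0.1f, 0.2f, 0.3f, 0.0f, 1.0f, 0.0f};
  softmax_backward_inplace(scores, d, 2, 3); // `d` now holds d(scores)

  for (float v : scores)
    std::printf("%.4f ", v);
  std::printf("\n");
  for (float v : d)
    std::printf("%.4f ", v);
  std::printf("\n");
  return 0;
}

Because the backward pass consumes only the softmax output, keeping the weights
tensor with ITERATION_LIFESPAN is enough, and the score tensor's forward-only
allocation disappears entirely.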