From: Parichay Kapoor
Date: Fri, 8 Oct 2021 08:33:12 +0000 (+0900)
Subject: [layer] Attention support for different key value
X-Git-Tag: accepted/tizen/unified/20220323.062643~348
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3113df80e3821d848df6b7074861afb8fb4526ec;p=platform%2Fcore%2Fml%2Fnntrainer.git

[layer] Attention support for different key value

This patch adds support for different values of key and value to be
given to the attention layer.
Corresponding unittests are also added.

Signed-off-by: Parichay Kapoor
---

diff --git a/nntrainer/layers/attention_layer.cpp b/nntrainer/layers/attention_layer.cpp
index 58edc10..4212754 100644
--- a/nntrainer/layers/attention_layer.cpp
+++ b/nntrainer/layers/attention_layer.cpp
@@ -19,12 +19,11 @@ namespace nntrainer {
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
-enum AttentionParams { query = 0, value = 1, score, weights };
+enum AttentionParams { query = 0, value = 1, key = 2, score, weights };
 
 void AttentionLayer::finalize(InitLayerContext &context) {
-  if (context.getNumInputs() != 2)
-    throw std::runtime_error(
-      "Attention layer does not support exclusive keys.");
+  if (context.getNumInputs() < 2 || context.getNumInputs() > 3)
+    throw std::runtime_error("Attention layer needs 2-3 inputs.");
 
   sm.setActiFunc(ActivationType::ACT_SOFTMAX);
 
@@ -32,8 +31,17 @@ void AttentionLayer::finalize(InitLayerContext &context) {
   auto const &query_dim = all_dims[AttentionParams::query];
   auto const &value_dim = all_dims[AttentionParams::value];
 
-  wt_idx[AttentionParams::query] = query;
-  wt_idx[AttentionParams::value] = value;
+  wt_idx[AttentionParams::query] = AttentionParams::query;
+  wt_idx[AttentionParams::value] = AttentionParams::value;
+  wt_idx[AttentionParams::key] = AttentionParams::value;
+
+  if (context.getNumInputs() == 3) {
+    auto const &key_dim = all_dims[AttentionParams::key];
+    if (key_dim != value_dim)
+      throw std::invalid_argument("Key and value must have same shape");
+
+    wt_idx[AttentionParams::key] = AttentionParams::key;
+  }
 
   auto weights_dim = query_dim;
   weights_dim.width(value_dim.height());
@@ -49,8 +57,9 @@ void AttentionLayer::finalize(InitLayerContext &context) {
 }
 
 void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
-  Tensor &query = context.getInput(AttentionParams::query);
-  Tensor &value = context.getInput(AttentionParams::value);
+  Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
+  Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
+  Tensor &key = context.getInput(wt_idx[AttentionParams::key]);
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
 
   Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
@@ -60,11 +69,12 @@ void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
     /** @todo try using transpose to speedup the operation */
     Tensor query_b = query.getBatchSlice(b, 1);
     Tensor value_b = value.getBatchSlice(b, 1);
+    Tensor key_b = key.getBatchSlice(b, 1);
     Tensor score_b = score.getBatchSlice(b, 1);
     Tensor weights_b = weights.getBatchSlice(b, 1);
     Tensor output_b = output.getBatchSlice(b, 1);
 
-    query_b.dot(value_b, score_b, false, true);
+    query_b.dot(key_b, score_b, false, true);
     sm.run_fn(score_b, weights_b);
     weights_b.dot(value_b, output_b);
   }
@@ -72,31 +82,42 @@ void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
 void AttentionLayer::calcDerivative(RunLayerContext &context) {
   Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
 
-  Tensor &query = context.getInput(AttentionParams::query);
-  Tensor &value = context.getInput(AttentionParams::value);
-  Tensor &dquery = context.getOutgoingDerivative(AttentionParams::query);
-  Tensor &dvalue = context.getOutgoingDerivative(AttentionParams::value);
+  Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
+  Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
+  Tensor &key = context.getInput(wt_idx[AttentionParams::key]);
+
+  Tensor &dquery =
+    context.getOutgoingDerivative(wt_idx[AttentionParams::query]);
+  Tensor &dvalue =
+    context.getOutgoingDerivative(wt_idx[AttentionParams::value]);
+  Tensor &dkey = context.getOutgoingDerivative(wt_idx[AttentionParams::key]);
+
   Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
 
   for (unsigned int b = 0; b < query.batch(); b++) {
     /** @todo try using transpose to speedup the operation */
     Tensor query_b = query.getBatchSlice(b, 1);
     Tensor value_b = value.getBatchSlice(b, 1);
+    Tensor key_b = key.getBatchSlice(b, 1);
     Tensor weights_b = weights.getBatchSlice(b, 1);
     Tensor dquery_b = dquery.getBatchSlice(b, 1);
     Tensor dvalue_b = dvalue.getBatchSlice(b, 1);
+    Tensor dkey_b = dkey.getBatchSlice(b, 1);
     Tensor deriv_b = derivative.getBatchSlice(b, 1);
 
     Tensor dweight = deriv_b.dot(value_b, false, true);
     Tensor t1;
     sm.run_prime_fn(weights_b, t1, dweight);
-    t1.dot(value_b, dquery_b);
+    t1.dot(key_b, dquery_b);
 
     weights_b.dot(deriv_b, dvalue_b, true, false);
-    t1.dot(query_b, dvalue_b, true, false, 1.0);
+    if (context.getNumInputs() == 2)
+      t1.dot(query_b, dvalue_b, true, false, 1.0);
+    else
+      t1.dot(query_b, dkey_b, true, false);
   }
 }
diff --git a/nntrainer/layers/attention_layer.h b/nntrainer/layers/attention_layer.h
index eab01c4..6fe2e80 100644
--- a/nntrainer/layers/attention_layer.h
+++ b/nntrainer/layers/attention_layer.h
@@ -88,7 +88,7 @@ public:
 
 private:
   ActiFunc sm; /** softmax activation operation */
-  std::array<unsigned int, 4> wt_idx; /**< indices of the weights and tensors */
+  std::array<unsigned int, 5> wt_idx; /**< indices of the weights and tensors */
 };
 
 } // namespace nntrainer
diff --git a/packaging/unittest_layers_v2.tar.gz b/packaging/unittest_layers_v2.tar.gz
index ef93d65..04eeaff 100644
Binary files a/packaging/unittest_layers_v2.tar.gz and b/packaging/unittest_layers_v2.tar.gz differ
diff --git a/test/input_gen/genLayerTests.py b/test/input_gen/genLayerTests.py
index beec6f6..3c70b67 100644
--- a/test/input_gen/genLayerTests.py
+++ b/test/input_gen/genLayerTests.py
@@ -86,6 +86,9 @@ if __name__ == "__main__":
     attention = K.layers.Attention()
     record_single(attention, [(2, 5, 7), (2, 3, 7)], "attention_shared_kv_batched", {}, input_type='float')
+    attention = K.layers.Attention()
+    record_single(attention, [(2, 5, 7), (2, 3, 7), (2, 3, 7)],
+                  "attention_batched", {}, input_type='float')
 
     inspect_file("conv_sb_no_overlap.nnlayergolden")
 
diff --git a/test/unittest/layers/unittest_layers_attention.cpp b/test/unittest/layers/unittest_layers_attention.cpp
index 634c3d7..3fc3140 100644
--- a/test/unittest/layers/unittest_layers_attention.cpp
+++ b/test/unittest/layers/unittest_layers_attention.cpp
@@ -32,6 +32,12 @@ auto attention_shared_kv_batched = LayerGoldenTestParamType(
   "attention_shared_kv_batched.nnlayergolden",
   LayerGoldenTestParamOptions::DEFAULT);
 
+auto attention_batched = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AttentionLayer>, {},
+  "2:1:5:7,2:1:3:7,2:1:3:7", "attention_batched.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT);
+
 INSTANTIATE_TEST_CASE_P(Attention, LayerGoldenTest,
                         ::testing::Values(attention_shared_kv,
-                                          attention_shared_kv_batched));
+                                          attention_shared_kv_batched,
+                                          attention_batched));
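
Editor's note: for reference, the per-batch computation that AttentionLayer::forwarding() performs is plain unscaled dot-product attention (score = Q.K^T, softmax, then .V), which is also what K.layers.Attention() computes with its default use_scale=False when generating the golden data above. The NumPy sketch below is illustrative only and is not part of the patch; the dot_attention and softmax helpers and the random inputs are assumptions, with shapes taken from the new golden test.

# Illustrative sketch only -- not part of the commit.
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dot_attention(query, value, key=None):
    # With only two inputs the key falls back to the value, matching the
    # shared key/value path that the layer keeps supporting.
    if key is None:
        key = value
    score = query @ key.transpose(0, 2, 1)  # query_b.dot(key_b, score_b, false, true)
    weights = softmax(score)                # sm.run_fn(score_b, weights_b)
    return weights @ value                  # weights_b.dot(value_b, output_b)

# Shapes follow the new golden test: query (2, 5, 7), key/value (2, 3, 7).
rng = np.random.default_rng(0)
query = rng.random((2, 5, 7), dtype=np.float32)
value = rng.random((2, 3, 7), dtype=np.float32)
key = rng.random((2, 3, 7), dtype=np.float32)

out = dot_attention(query, value, key)      # three-input case added by this patch
print(out.shape)                            # (2, 5, 7)

Passing only (query, value) reproduces the existing attention_shared_kv golden cases, while the third tensor exercises the separate key input introduced here.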