[layer] Attention support for separate key and value
author:    Parichay Kapoor <pk.kapoor@samsung.com>
           Fri, 8 Oct 2021 08:33:12 +0000 (17:33 +0900)
committer: Jijoong Moon <jijoong.moon@samsung.com>
           Fri, 8 Oct 2021 11:25:06 +0000 (20:25 +0900)
This patch adds support for passing a key input that is distinct from the
value input to the attention layer. When only two inputs are given, the key
is shared with the value, as before.
Corresponding unittests are also added.
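
For reference, the layer computes plain (unscaled) dot-product attention.
Below is a minimal numpy sketch of the computation, not the nntrainer API;
when no key is given it falls back to sharing the key with the value,
matching the 2-input case:

    import numpy as np

    def attention(query, value, key=None):
        # 2-input case: key is shared with value
        if key is None:
            key = value
        score = query @ key.T                          # (q_len, kv_len)
        e = np.exp(score - score.max(axis=-1, keepdims=True))
        weights = e / e.sum(axis=-1, keepdims=True)    # row-wise softmax
        return weights @ value                         # (q_len, dim)

This mirrors the per-batch loop in forwarding() below: score from query and
key, softmax into weights, then weights applied to value.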

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/attention_layer.cpp
nntrainer/layers/attention_layer.h
packaging/unittest_layers_v2.tar.gz
test/input_gen/genLayerTests.py
test/unittest/layers/unittest_layers_attention.cpp

diff --git a/nntrainer/layers/attention_layer.cpp b/nntrainer/layers/attention_layer.cpp
index 58edc10..4212754 100644
@@ -19,12 +19,11 @@ namespace nntrainer {
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
-enum AttentionParams { query = 0, value = 1, score, weights };
+enum AttentionParams { query = 0, value = 1, key = 2, score, weights };
 
 void AttentionLayer::finalize(InitLayerContext &context) {
-  if (context.getNumInputs() != 2)
-    throw std::runtime_error(
-      "Attention layer does not support exclusive keys.");
+  if (context.getNumInputs() < 2 || context.getNumInputs() > 3)
+    throw std::runtime_error("Attention layer needs 2-3 inputs.");
 
   sm.setActiFunc(ActivationType::ACT_SOFTMAX);
 
@@ -32,8 +31,17 @@ void AttentionLayer::finalize(InitLayerContext &context) {
   auto const &query_dim = all_dims[AttentionParams::query];
   auto const &value_dim = all_dims[AttentionParams::value];
 
-  wt_idx[AttentionParams::query] = query;
-  wt_idx[AttentionParams::value] = value;
+  wt_idx[AttentionParams::query] = AttentionParams::query;
+  wt_idx[AttentionParams::value] = AttentionParams::value;
+  wt_idx[AttentionParams::key] = AttentionParams::value;
+
+  if (context.getNumInputs() == 3) {
+    auto const &key_dim = all_dims[AttentionParams::key];
+    if (key_dim != value_dim)
+      throw std::invalid_argument("Key and value must have same shape");
+
+    wt_idx[AttentionParams::key] = AttentionParams::key;
+  }
 
   auto weights_dim = query_dim;
   weights_dim.width(value_dim.height());
@@ -49,8 +57,9 @@ void AttentionLayer::finalize(InitLayerContext &context) {
 }
 
 void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
-  Tensor &query = context.getInput(AttentionParams::query);
-  Tensor &value = context.getInput(AttentionParams::value);
+  Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
+  Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
+  Tensor &key = context.getInput(wt_idx[AttentionParams::key]);
 
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
   Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
@@ -60,11 +69,12 @@ void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
     /** @todo try using transpose to speedup the operation */
     Tensor query_b = query.getBatchSlice(b, 1);
     Tensor value_b = value.getBatchSlice(b, 1);
+    Tensor key_b = key.getBatchSlice(b, 1);
     Tensor score_b = score.getBatchSlice(b, 1);
     Tensor weights_b = weights.getBatchSlice(b, 1);
     Tensor output_b = output.getBatchSlice(b, 1);
 
-    query_b.dot(value_b, score_b, false, true);
+    query_b.dot(key_b, score_b, false, true);
     sm.run_fn(score_b, weights_b);
     weights_b.dot(value_b, output_b);
   }
@@ -72,31 +82,42 @@ void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
 
 void AttentionLayer::calcDerivative(RunLayerContext &context) {
   Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
-  Tensor &query = context.getInput(AttentionParams::query);
-  Tensor &value = context.getInput(AttentionParams::value);
 
-  Tensor &dquery = context.getOutgoingDerivative(AttentionParams::query);
-  Tensor &dvalue = context.getOutgoingDerivative(AttentionParams::value);
+  Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
+  Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
+  Tensor &key = context.getInput(wt_idx[AttentionParams::key]);
+
+  Tensor &dquery =
+    context.getOutgoingDerivative(wt_idx[AttentionParams::query]);
+  Tensor &dvalue =
+    context.getOutgoingDerivative(wt_idx[AttentionParams::value]);
+  Tensor &dkey = context.getOutgoingDerivative(wt_idx[AttentionParams::key]);
+
   Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
 
   for (unsigned int b = 0; b < query.batch(); b++) {
     /** @todo try using transpose to speedup the operation */
     Tensor query_b = query.getBatchSlice(b, 1);
     Tensor value_b = value.getBatchSlice(b, 1);
+    Tensor key_b = key.getBatchSlice(b, 1);
     Tensor weights_b = weights.getBatchSlice(b, 1);
 
     Tensor dquery_b = dquery.getBatchSlice(b, 1);
     Tensor dvalue_b = dvalue.getBatchSlice(b, 1);
+    Tensor dkey_b = dkey.getBatchSlice(b, 1);
     Tensor deriv_b = derivative.getBatchSlice(b, 1);
 
     Tensor dweight = deriv_b.dot(value_b, false, true);
 
     Tensor t1;
     sm.run_prime_fn(weights_b, t1, dweight);
-    t1.dot(value_b, dquery_b);
+    t1.dot(key_b, dquery_b);
 
     weights_b.dot(deriv_b, dvalue_b, true, false);
-    t1.dot(query_b, dvalue_b, true, false, 1.0);
+    if (context.getNumInputs() == 2)
+      t1.dot(query_b, dvalue_b, true, false, 1.0);
+    else
+      t1.dot(query_b, dkey_b, true, false);
   }
 }
 
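The calcDerivative() hunk above implements the following gradients. Here is a
minimal numpy sketch under the same per-batch shapes (a hypothetical helper,
not the nntrainer API, and assuming sm.run_prime_fn computes the usual softmax
Jacobian-vector product). In the 2-input case key and value share a single
outgoing derivative, so the key-side term is accumulated into dvalue with
beta 1.0; with 3 inputs it goes to a separate dkey:

    import numpy as np

    def attention_grad(query, value, key, weights, dout, shared_kv=False):
        dweight = dout @ value.T                       # from output = weights @ value
        # softmax backward: dscore = w * (dw - sum(w * dw)), row-wise
        t1 = weights * (dweight - (weights * dweight).sum(axis=-1, keepdims=True))
        dquery = t1 @ key                              # from score = query @ key^T
        dvalue = weights.T @ dout
        if shared_kv:
            dvalue = dvalue + t1.T @ query             # beta 1.0: accumulate into dvalue
            dkey = dvalue                              # key aliases value
        else:
            dkey = t1.T @ query
        return dquery, dvalue, dkey
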
diff --git a/nntrainer/layers/attention_layer.h b/nntrainer/layers/attention_layer.h
index eab01c4..6fe2e80 100644
@@ -88,7 +88,7 @@ public:
 
 private:
   ActiFunc sm;                        /** softmax activation operation */
-  std::array<unsigned int, 4> wt_idx; /**< indices of the weights and tensors */
+  std::array<unsigned int, 5> wt_idx; /**< indices of the weights and tensors */
 };
 
 } // namespace nntrainer
diff --git a/packaging/unittest_layers_v2.tar.gz b/packaging/unittest_layers_v2.tar.gz
index ef93d65..04eeaff 100644
Binary files a/packaging/unittest_layers_v2.tar.gz and b/packaging/unittest_layers_v2.tar.gz differ
diff --git a/test/input_gen/genLayerTests.py b/test/input_gen/genLayerTests.py
index beec6f6..3c70b67 100644
@@ -86,6 +86,9 @@ if __name__ == "__main__":
     attention = K.layers.Attention()
     record_single(attention, [(2, 5, 7), (2, 3, 7)],
                  "attention_shared_kv_batched", {}, input_type='float')
+    attention = K.layers.Attention()
+    record_single(attention, [(2, 5, 7), (2, 3, 7), (2, 3, 7)],
+                 "attention_batched", {}, input_type='float')
 
 inspect_file("conv_sb_no_overlap.nnlayergolden")
 
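The golden data above is generated with Keras, whose Attention layer takes its
inputs in the order [query, value, key], the same convention this patch adopts.
A standalone sketch with random data (assuming TensorFlow 2.x; shapes match the
new test case):

    import numpy as np
    import tensorflow as tf

    query = np.random.rand(2, 5, 7).astype(np.float32)
    value = np.random.rand(2, 3, 7).astype(np.float32)
    key = np.random.rand(2, 3, 7).astype(np.float32)

    out = tf.keras.layers.Attention()([query, value, key])
    print(out.shape)                                   # (2, 5, 7)
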
diff --git a/test/unittest/layers/unittest_layers_attention.cpp b/test/unittest/layers/unittest_layers_attention.cpp
index 634c3d7..3fc3140 100644
@@ -32,6 +32,12 @@ auto attention_shared_kv_batched = LayerGoldenTestParamType(
   "attention_shared_kv_batched.nnlayergolden",
   LayerGoldenTestParamOptions::DEFAULT);
 
+auto attention_batched = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AttentionLayer>, {},
+  "2:1:5:7,2:1:3:7,2:1:3:7", "attention_batched.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT);
+
 INSTANTIATE_TEST_CASE_P(Attention, LayerGoldenTest,
                         ::testing::Values(attention_shared_kv,
-                                          attention_shared_kv_batched));
+                                          attention_shared_kv_batched,
+                                          attention_batched));