[layer] Add backwarding for attention layer
authorParichay Kapoor <pk.kapoor@samsung.com>
Fri, 1 Oct 2021 08:23:04 +0000 (17:23 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 6 Oct 2021 12:05:18 +0000 (21:05 +0900)
This patch adds backwarding for attention layer. Corresponding unittests
will be added in the next patch.
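
For reference, forwarding computes score = query . value^T,
weights = softmax(score) and output = weights . value. Under that
formulation the chain rule gives, as a sketch:

  dweights = derivative . value^T
  dscore   = softmax'(weights, dweights)
  dquery   = dscore . value
  dvalue   = weights^T . derivative + dscore^T . query

where derivative is the incoming derivative of the output; this is what
calcDerivative computes. The softmaxed weights are requested with
ITERATION_LIFESPAN so that calcDerivative can reuse them, while the raw
score is only needed within forwarding.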

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
jni/Android.mk
nntrainer/layers/attention_layer.cpp
nntrainer/layers/attention_layer.h
test/unittest/layers/unittest_layers_attention.cpp

index 65c6751..e4e2f94 100644 (file)
@@ -159,6 +159,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/activation_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/flatten_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/addition_layer.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/layers/attention_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/concat_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/preprocess_flip_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/preprocess_translate_layer.cpp \
index 9d6a2ed..b876e0e 100644 (file)
@@ -19,7 +19,7 @@ namespace nntrainer {
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
-enum AttentionParams { query = 0, value = 1 };
+enum AttentionParams { query = 0, value = 1, score, weights };
 
 void AttentionLayer::finalize(InitLayerContext &context) {
   if (context.getNumInputs() != 2)
@@ -28,10 +28,24 @@ void AttentionLayer::finalize(InitLayerContext &context) {
 
   sm.setActiFunc(ActivationType::ACT_SOFTMAX);
 
-  auto const &all_shapes = context.getInputDimensions();
-  auto const &query_shape = all_shapes[AttentionParams::query];
+  auto const &all_dims = context.getInputDimensions();
+  auto const &query_dim = all_dims[AttentionParams::query];
+  auto const &value_dim = all_dims[AttentionParams::value];
 
-  context.setOutputDimensions({query_shape});
+  wt_idx[AttentionParams::query] = query;
+  wt_idx[AttentionParams::value] = value;
+
+  auto weights_dim = query_dim;
+  /** weights: [query_height x value_height], the shape of query.dot(value^T) */
+  weights_dim.width(value_dim.height());
+  wt_idx[AttentionParams::weights] = context.requestTensor(
+    weights_dim, context.getName() + ":weights", Tensor::Initializer::NONE,
+    false, TensorLifespan::ITERATION_LIFESPAN);
+
+  wt_idx[AttentionParams::score] = context.requestTensor(
+    weights_dim, context.getName() + ":score", Tensor::Initializer::NONE, false,
+    TensorLifespan::FORWARD_FUNC_LIFESPAN);
+
+  context.setOutputDimensions({query_dim});
 }
 
 void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
@@ -39,17 +53,30 @@ void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &value = context.getInput(AttentionParams::value);
 
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
-  Tensor distribution;
+  Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
+  Tensor &score = context.getTensor(wt_idx[AttentionParams::score]);
 
-  Tensor score = query.dot(value, false, true);
-  sm.run_fn(score, distribution);
-  distribution.dot(value, output);
+  query.dot(value, score, false, true);
+  sm.run_fn(score, weights);
+  weights.dot(value, output);
 }
 
 void AttentionLayer::calcDerivative(RunLayerContext &context) {
-  /**
-   * Not yet implemented
-   */
+  Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  Tensor &query = context.getInput(AttentionParams::query);
+  Tensor &value = context.getInput(AttentionParams::value);
+
+  Tensor &dquery = context.getOutgoingDerivative(AttentionParams::query);
+  Tensor &dvalue = context.getOutgoingDerivative(AttentionParams::value);
+  Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
+
+  /** derivative of the output w.r.t. the softmaxed attention weights */
+  Tensor dweights = derivative_.dot(value, false, true);
+
+  /** derivative through the softmax */
+  Tensor dscore;
+  sm.run_prime_fn(weights, dscore, dweights);
+
+  /** dquery = dscore . value */
+  dscore.dot(value, dquery);
+
+  /** dvalue = weights^T . derivative + dscore^T . query */
+  weights.dot(derivative_, dvalue, true, false);
+  dscore.dot(query, dvalue, true, false, 1.0f);
 }
 
 void AttentionLayer::setProperty(const std::vector<std::string> &values) {
index 900330e..eab01c4 100644 (file)
@@ -87,7 +87,8 @@ public:
   inline static const std::string type = "attention";
 
 private:
-  ActiFunc sm; /** softmax activation operation */
+  ActiFunc sm;                        /**< softmax activation operation */
+  std::array<unsigned int, 4> wt_idx; /**< indices of the weights and tensors */
 };
 
 } // namespace nntrainer
index abe367f..e47f152 100644 (file)
@@ -2,7 +2,7 @@
 /**
  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
  *
- * @file unittest_layers_addition.cpp
+ * @file unittest_layers_attention.cpp
  * @date 1 October 2021
  * @brief Attention Layer Test
  * @see        https://github.com/nnstreamer/nntrainer
@@ -20,5 +20,5 @@ auto semantic_attention =
   LayerSemanticsParamType(nntrainer::createLayer<nntrainer::AttentionLayer>,
                           nntrainer::AttentionLayer::type, {}, 0, false, 2);
 
-INSTANTIATE_TEST_CASE_P(Addition, LayerSemantics,
+INSTANTIATE_TEST_CASE_P(Attention, LayerSemantics,
                         ::testing::Values(semantic_attention));