[unittest] generate positional encoding unittest
author	hyeonseok lee <hs89.lee@samsung.com>
Thu, 25 Aug 2022 14:07:27 +0000 (23:07 +0900)
committer	Jijoong Moon <jijoong.moon@samsung.com>
Wed, 7 Sep 2022 13:23:09 +0000 (22:23 +0900)
 - Generate positional encoding layer/model unittests along with their golden data.
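 - For reference, the golden data is expected to follow the standard sinusoidal
   encoding, which the TF/PyTorch generators in this patch implement:
     PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
     PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))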

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/positional_encoding_layer.cpp
nntrainer/layers/positional_encoding_layer.h
packaging/unittest_layers_v2.tar.gz
packaging/unittest_models_v2.tar.gz
test/input_gen/genLayerTests.py
test/input_gen/genModelTests_v2.py
test/input_gen/recorder.py
test/unittest/layers/meson.build
test/unittest/layers/unittest_layers_positional_encoding.cpp [new file with mode: 0644]
test/unittest/models/unittest_models.cpp

diff --git a/nntrainer/layers/positional_encoding_layer.cpp b/nntrainer/layers/positional_encoding_layer.cpp
index 724e983..a5db15d 100644 (file)
@@ -27,6 +27,7 @@ enum PositionalEncodingParams {
 };
 
 PositionalEncodingLayer::PositionalEncodingLayer() :
+  isPEcalculated(false),
   positional_encoding_props(props::MaxTimestep()) {
   weight_idx.fill(std::numeric_limits<unsigned>::max());
 }
diff --git a/nntrainer/layers/positional_encoding_layer.h b/nntrainer/layers/positional_encoding_layer.h
index 1c58c83..139bbbf 100644 (file)
@@ -94,10 +94,10 @@ public:
   inline static const std::string type = "positional_encoding";
 
 private:
-  std::tuple<props::MaxTimestep> positional_encoding_props;
-  std::array<unsigned int, 1> weight_idx;
   bool isPEcalculated; // bool value to check positional encoding is already
                        // calculated
+  std::tuple<props::MaxTimestep> positional_encoding_props;
+  std::array<unsigned int, 1> weight_idx;
 
   /**
    * @brief calculate positional encoding
index 1e5b099..3ff332f 100644 (file)
Binary files a/packaging/unittest_layers_v2.tar.gz and b/packaging/unittest_layers_v2.tar.gz differ
index ced1a23..34af99e 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
diff --git a/test/input_gen/genLayerTests.py b/test/input_gen/genLayerTests.py
index 405bb2c..74e5c7d 100644 (file)
@@ -42,6 +42,33 @@ def inspect_file(file_name):
             print("size: ", sz)
             print(np.fromfile(f, dtype="float32", count=sz))
 
+class PositionalEncoding(tf.keras.layers.Layer):
+    def __init__(self, position, d_model):
+        super(PositionalEncoding, self).__init__()
+        self.position = position
+        self.d_model = d_model
+
+    def get_angles(self, pos, i, d_model):
+        angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
+        return pos * angle_rates
+
+    def build(self, input_shape):
+        angle_rads = self.get_angles(np.arange(self.position)[:, np.newaxis],
+                        np.arange(self.d_model)[np.newaxis, :], self.d_model)
+
+        # apply sin to even indices in the array; 2i
+        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
+
+        # apply cos to odd indices in the array; 2i+1
+        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
+
+        self.pos_encoding = angle_rads[np.newaxis, ...]
+
+        self.pos_encoding = tf.cast(self.pos_encoding, dtype=tf.float32)
+
+    def call(self, inputs):
+        inputs += self.pos_encoding[:, :tf.shape(inputs[0])[-2], :]
+        return inputs
 
 if __name__ == "__main__":
     fc = K.layers.Dense(5)
@@ -286,5 +313,9 @@ if __name__ == "__main__":
     concat = K.layers.Concatenate(axis=1)
     record_single(concat, [(2,2,3,3), (2, 3, 3, 3)], "concat_dim1")
 
+    positional_encoding = PositionalEncoding(10, 6)
+    record_single(positional_encoding, [(3, 1, 7, 6)], "positional_encoding_partial")
+    record_single(positional_encoding, [(3, 1, 10, 6)], "positional_encoding")
+
 inspect_file("dropout_20_training.nnlayergolden")
 
diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py
index 842d036..361d2c4 100644 (file)
@@ -8,6 +8,7 @@
 # @brief Generate model tcs
 # @author Parichay Kapoor <pk.kapoor@samsung.com>
 
+import math
 from recorder_v2 import record_v2, inspect_file, _rand_like
 import torch
 
@@ -119,6 +120,25 @@ class MultiHeadAttention(torch.nn.Module):
         labels = _rand_like(label_dims, dtype=float)
         return inputs, labels
 
+class PositionalEncoding(torch.nn.Module):
+    def __init__(self, d_model: int, max_len):
+        super().__init__()
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+        pe = torch.zeros(1, max_len, d_model)
+        pe[0, :, 0::2] = torch.sin(position * div_term)
+        pe[0, :, 1::2] = torch.cos(position * div_term)
+        self.register_buffer('pe', pe)
+        self.multi_head_attention = torch.nn.MultiheadAttention(d_model, 2, batch_first=True)
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, inputs, labels):
+        output = inputs[0]
+        output += self.pe[:,:output.size(1),:]
+        output = self.multi_head_attention(output, output, output)
+        loss = self.loss(output[0], labels[0])
+        return output, loss
+
 class FCRelu(torch.nn.Module):
     def __init__(self, decay=False):
         super().__init__()
@@ -148,7 +168,6 @@ class FCRelu(torch.nn.Module):
                 {'params': non_decay_params},
                 {'params': decay_params, 'weight_decay': 0.9}], lr=0.1)
 
-
 if __name__ == "__main__":
     record_v2(
         ReduceMeanLast(),
@@ -233,6 +252,15 @@ if __name__ == "__main__":
         name="multi_head_attention_self_attention",
     )
 
+    record_v2(
+        PositionalEncoding(d_model=6, max_len=7),
+        iteration=1,
+        input_dims=[(3,5,6)],
+        input_dtype=[float],
+        label_dims=[(3,5,6)],
+        name="positional_encoding",
+    )
+
     fc_relu_decay = FCRelu(decay=True)
     record_v2(
         fc_relu_decay,
diff --git a/test/input_gen/recorder.py b/test/input_gen/recorder.py
index 95372f4..cf84a60 100644 (file)
@@ -436,7 +436,6 @@ def record_single(layer, input_shape, test_name, call_args={}, input_type='int')
             if not isinstance(tensors, list):
                 tensors = [tensors]
             for tensor in tensors:
-                print(tf.size(tensor))
                 writer(tf.size(tensor), tensor)
 
         ## @todo inputs outputs derivatives can be more than one
diff --git a/test/unittest/layers/meson.build b/test/unittest/layers/meson.build
index b6f43c5..9864ff6 100644 (file)
@@ -61,6 +61,7 @@ test_target = [
   'unittest_layers_reshape.cpp',
   # 'unittest_layers_mol_attention.cpp',
   'unittest_layers_multi_head_attention.cpp',
+  'unittest_layers_positional_encoding.cpp',
 ]
 
 if get_option('enable-tflite-backbone')
diff --git a/test/unittest/layers/unittest_layers_positional_encoding.cpp b/test/unittest/layers/unittest_layers_positional_encoding.cpp
new file mode 100644 (file)
index 0000000..4bf5544
--- /dev/null
@@ -0,0 +1,39 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2022 Hyeonseok Lee <hs89.lee@samsung.com>
+ *
+ * @file unittest_layers_positional_encoding.cpp
+ * @date 24 August 2022
+ * @brief PositionalEncodingLayer Test
+ * @see        https://github.com/nnstreamer/nntrainer
+ * @author Hyeonseok Lee <hs89.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <layers_common_tests.h>
+#include <positional_encoding_layer.h>
+
+auto semantic_positional_encoding = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::PositionalEncodingLayer>,
+  nntrainer::PositionalEncodingLayer::type, {"max_timestep=10"}, 0, false, 1);
+
+INSTANTIATE_TEST_CASE_P(PositionalEncoding, LayerSemantics,
+                        ::testing::Values(semantic_positional_encoding));
+
+auto positional_encoding_partial = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::PositionalEncodingLayer>,
+  {"max_timestep=10"}, "3:1:7:6", "positional_encoding_partial.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT);
+
+auto positional_encoding = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::PositionalEncodingLayer>,
+  {"max_timestep=10"}, "3:1:10:6", "positional_encoding.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT);
+
+INSTANTIATE_TEST_CASE_P(PositionalEncoding, LayerGoldenTest,
+                        ::testing::Values(positional_encoding_partial,
+                                          positional_encoding));
diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp
index 2e3ab1b..209e209 100644 (file)
@@ -225,6 +225,30 @@ static std::unique_ptr<NeuralNetwork> makeMultiHeadAttention_self_attention() {
   return nn;
 }
 
+static std::unique_ptr<NeuralNetwork> makePositionalEncoding() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=5:1:6"}},
+    {"reshape", {"name=reshape", "target_shape=1:5:6"}},
+    {"positional_encoding", {"name=positional_encoding", "max_timestep=7"}},
+    {"multi_head_attention",
+     {"name=multi_head_attention",
+      "input_layers=positional_encoding, positional_encoding, "
+      "positional_encoding",
+      "num_heads=2"}},
+    {"mse", {"name=loss", "input_layers=multi_head_attention(0)"}},
+  });
+
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
 GTEST_PARAMETER_TEST(
   model, nntrainerModelTest,
   ::testing::ValuesIn({
@@ -251,6 +275,8 @@ GTEST_PARAMETER_TEST(
     mkModelTc_V2(makeMultiHeadAttention_self_attention,
                  "multi_head_attention_self_attention",
                  ModelTestOption::ALL_V2),
+    mkModelTc_V2(makePositionalEncoding, "positional_encoding",
+                 ModelTestOption::ALL_V2),
     mkModelIniTc(fc_relu_decay, DIM_UNUSED, NOT_USED_,
                  ModelTestOption::COMPARE_V2),
   }),