[Test] Update layer golden test format
author Jihoon Lee <jhoon.it.lee@samsung.com>
Thu, 23 Sep 2021 07:19:07 +0000 (16:19 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Tue, 28 Sep 2021 11:24:01 +0000 (20:24 +0900)
**Changes proposed in this PR:**
- Add initial_weights to the layer golden data (see the sketch after this list)
- Add a layer::build() entry point to the transLayer classes
- Properly call layer.call() instead of __call__ in the transLayer
- Run formatter (black)
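
For quick reference, here is a minimal reader sketch (illustrative only, not part of this patch) of the new `*.nnlayergolden` layout: each record is a 4-byte little-endian element count followed by that many float32 values, stored in the order initial weights, inputs, outputs, gradients, weights, derivatives.

```python
# Hypothetical reader for the new *.nnlayergolden layout (sketch only).
import struct
import numpy as np

def read_golden(path):
    """Yield each recorded tensor as a flat float32 array, in file order:
    initial weights, inputs, outputs, gradients, weights, derivatives."""
    with open(path, "rb") as f:
        while True:
            raw = f.read(4)
            if len(raw) < 4:  # end of file
                break
            (count,) = struct.unpack("<I", raw)  # element count, little-endian
            yield np.fromfile(f, dtype="float32", count=count)
```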

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
packaging/unittest_layers_v2.tar.gz
test/input_gen/genLayerTests.py
test/input_gen/recorder.py
test/input_gen/transLayer.py
test/unittest/layers/layers_golden_tests.cpp

index cfc4b07..58de15e 100644 (file)
Binary files a/packaging/unittest_layers_v2.tar.gz and b/packaging/unittest_layers_v2.tar.gz differ
index c248598..76a1449 100644 (file)
@@ -7,10 +7,14 @@
 # @date 13 Sep 2020
 # @brief Generate *.nnlayergolden file
 # *.nnlayergolden file is expected to contain the following information **in order**
-# ## TBA ##
+# - Initial Weights
+# - inputs
+# - outputs
+# - *gradients
+# - weights
+# - derivatives
 #
-#
-# @author Jihoon lee <jhoon.it.lee@samsung.com>
+# @author Jihoon Lee <jhoon.it.lee@samsung.com>
 
 from multiprocessing.sharedctypes import Value
 import warnings
@@ -25,28 +29,25 @@ with warnings.catch_warnings():
     import tensorflow as tf
     from tensorflow.python import keras as K
 
-from transLayer import attach_trans_layer as TL
-
-
 ##
 # @brief inspect if file is created correctly
 # @note this just checks if the offset is correctly set, the result has to be inspected
 # manually
 def inspect_file(file_name):
-    import struct
     with open(file_name, "rb") as f:
         while True:
-            sz = int.from_bytes(f.read(4), byteorder='little')
+            sz = int.from_bytes(f.read(4), byteorder="little")
             if not sz:
                 break
             print("size: ", sz)
-            print(np.fromfile(f, dtype='float32', count=sz))
+            print(np.fromfile(f, dtype="float32", count=sz))
+
 
 if __name__ == "__main__":
     fc = K.layers.Dense(5)
-    record_single(fc, (3, 1, 1, 10), "fc_golden_plain.nnlayergolden")
+    record_single(fc, (3, 1, 1, 10), "fc_golden_plain")
     fc = K.layers.Dense(4)
-    record_single(fc, (1, 1, 1, 10), "fc_golden_single_batch.nnlayergolden")
+    record_single(fc, (1, 1, 1, 10), "fc_golden_single_batch")
 
 # inspect_file("fc_golden.nnlayergolden")
 
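To make the documented ordering concrete, this is what `fc_golden_plain.nnlayergolden` should now contain for the `Dense(5)` layer recorded above with input shape (3, 1, 1, 10) (record names and shapes inferred for illustration, not taken from the patch):

```python
# Illustrative record order/shapes for fc_golden_plain.nnlayergolden.
expected_records = [
    ("initial kernel", (10, 5)),          # initial weights
    ("initial bias", (5,)),
    ("input", (3, 1, 1, 10)),             # inputs
    ("output", (3, 1, 1, 5)),             # outputs
    ("kernel gradient", (10, 5)),         # gradients
    ("bias gradient", (5,)),
    ("kernel", (10, 5)),                  # weights after the forward pass
    ("bias", (5,)),
    ("input derivative", (3, 1, 1, 10)),  # derivatives
]
```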
index eb2b0c5..15b7270 100644 (file)
@@ -99,7 +99,7 @@ def _debug_print(
     debug=None,
     print_option={"end": "\n"},
     print_format=_debug_default_formatter,
-    **data
+    **data,
 ):
     if not debug:
         return
@@ -124,7 +124,7 @@ def prepare_data(model, input_shape, label_shape, writer_fn, is_onehot, **kwargs
             depth=label_shape[1],
         )
     else:
-        label=_rand_like(label_shape) / 10
+        label = _rand_like(label_shape) / 10
 
     initial_weights = []
     for layer in iter_model(model):
@@ -139,7 +139,7 @@ def prepare_data(model, input_shape, label_shape, writer_fn, is_onehot, **kwargs
         initial_input=initial_input,
         label=label,
         initial_weights=initial_weights,
-        **kwargs
+        **kwargs,
     )
 
     return initial_input, label
@@ -175,8 +175,8 @@ def train_step(model, optimizer, loss_fn, initial_input, label, writer_fn, **kwa
 
         # loss = loss_fn(label, outp[-1])
         loss = []
-        if kwargs.get('multi_out', None) != None:
-            multi_out = kwargs.get('multi_out', [])
+        if kwargs.get("multi_out", None) != None:
+            multi_out = kwargs.get("multi_out", [])
         else:
             multi_out = [-1]
         for i in multi_out:
@@ -216,7 +216,7 @@ def train_step(model, optimizer, loss_fn, initial_input, label, writer_fn, **kwa
             *layer_output,  # output of forward
             *dx,  # output of backward
             *gradients,  # weight gradient output from backward
-            *weights  # updated weight after optimization
+            *weights,  # updated weight after optimization
         )
 
         _debug_print(name=layer.name, print_format=value_only_formatter, **kwargs)
@@ -235,7 +235,7 @@ def train_step(model, optimizer, loss_fn, initial_input, label, writer_fn, **kwa
             weights=weights,
             gradients=gradients,
             dx_shape=[i.shape for i in dx],
-            **kwargs
+            **kwargs,
         )
 
     for l in loss:
@@ -367,13 +367,15 @@ def record(
     inputs=None,
     outputs=None,
     is_onehot=True,
-    **kwargs
+    **kwargs,
 ):
     if os.path.isfile(file_name):
         print("Warning: the file %s is being truncated and overwritten" % file_name)
 
     loss_fn = _get_loss_fn(loss_fn_str)
-    model = generate_recordable_model(loss_fn_str, model, inputs, outputs, is_onehot, **kwargs)
+    model = generate_recordable_model(
+        loss_fn_str, model, inputs, outputs, is_onehot, **kwargs
+    )
 
     with open(file_name, "wb") as f:
         write = _get_writer(f)
@@ -385,7 +387,7 @@ def record(
             _debug_print(
                 iteration="\033[1;33m[%d/%d]\033[0m" % (_ + 1, iteration),
                 print_format=value_only_formatter,
-                **kwargs
+                **kwargs,
             )
             train_step(model, optimizer, loss_fn, initial_input, label, write, **kwargs)
 
@@ -394,14 +396,16 @@ def record(
 
 ##
 # @brief record a single layer
-def record_single(layer, input_shape, file_name):
+def record_single(layer, input_shape, test_name, call_args={}):
     layer = attach_trans_layer(layer)
+    layer.build(input_shape)
     inputs = _rand_like(input_shape)
 
+    initial_weights = [tf.Variable(i) for i in layer.weights]
     with tf.GradientTape(persistent=True) as tape:
         tape.watch(inputs)
-        outputs = layer(inputs)
-        dy_constant = outputs * 2 # set incoming derivative to 2 instead of 1
+        outputs = layer.call(inputs, **call_args)
+        dy_constant = outputs * 2  # set incoming derivative to 2 instead of 1
 
     weights = layer.weights.copy()
     gradients = tape.gradient(dy_constant, layer.trainable_weights)
@@ -412,7 +416,7 @@ def record_single(layer, input_shape, file_name):
     except AttributeError:
         pass
 
-    with open(file_name, "wb") as f:
+    with open(test_name + ".nnlayergolden", "wb") as f:
         writer = _get_writer(f)
 
         def write_tensor(*tensors):
@@ -421,10 +425,11 @@ def record_single(layer, input_shape, file_name):
                 writer(tf.size(tensor), tensor)
 
         ## @todo inputs outputs derivatives can be more than one
-        write_tensor(*weights)
+        ## @note please update genLayerTests.py comments when updating below
+        write_tensor(*initial_weights)
         write_tensor(inputs)
         write_tensor(outputs)
         write_tensor(*gradients)
+        write_tensor(*weights)
         write_tensor(derivatives)
 
-
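A hedged usage sketch of the new `record_single()` signature (the layer choice, input shape, and import path are illustrative assumptions, not part of this PR):

```python
from tensorflow.python import keras as K
from recorder import record_single  # assumed import path, mirroring genLayerTests.py

# test_name now omits the ".nnlayergolden" suffix (appended internally) and
# call_args is forwarded to layer.call(); e.g. a batch-normalization layer
# recorded in training mode:
bn = K.layers.BatchNormalization()
record_single(bn, (3, 2, 4, 4), "bn_golden_training", call_args={"training": True})
```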
index b0a29db..4e8e659 100644 (file)
@@ -31,6 +31,11 @@ class AbstractTransLayer(K.layers.Layer):
         self.call.__func__.__signature__ = signature(self.tf_layer.call)
         self.has_training = "training" in inspect.getfullargspec(self.call).args
 
+    def build(self, input_shape):
+        if not self.built:
+            self.tf_layer.build(input_shape)
+            super().build(input_shape)
+
     ##
     # @brief call function
     # @param nntr_input input with nntrainer layout
@@ -41,7 +46,7 @@ class AbstractTransLayer(K.layers.Layer):
         if self.has_training:
             additional_args["training"] = training
 
-        tf_output = self.tf_layer(tf_input, **additional_args)
+        tf_output = self.tf_layer.call(tf_input, **additional_args)
         return self.to_nntr_tensor(tf_output)
 
     ##
@@ -108,6 +113,22 @@ class ChannelLastTransLayer(AbstractTransLayer):
         self.to_tf_layer_ = K.layers.Permute(ChannelLastTransLayer.TO_CHANNELS_LAST)
         self.to_nntr_layer_ = K.layers.Permute(ChannelLastTransLayer.TO_CHANNELS_FIRST)
 
+    def build(self, input_shape):
+        if self.built:
+            return
+
+        if isinstance(input_shape, tf.TensorShape):
+            input_shape_list_ = input_shape.as_list()
+        else:
+            input_shape_list_ = input_shape
+        transposed_list_ = [None] * 4
+
+        for idx, i in enumerate((0,) + ChannelLastTransLayer.TO_CHANNELS_LAST):
+            transposed_list_[idx] = input_shape_list_[i]
+
+        transposed_input_shape = tf.TensorShape(transposed_list_)
+        super().build(transposed_input_shape)
+
     def to_tf_tensor(self, tensor):
         return self.to_tf_layer_(tensor)
 
@@ -139,12 +160,16 @@ CHANNEL_LAST_LAYERS = (
 # @brief Translayer for batch normalization layer
 class BatchNormTransLayer(IdentityTransLayer):
     def build(self, input_shape):
+        if self.built:
+            return
+
         if len(input_shape) > 3:
             self.tf_layer = ChannelLastTransLayer(self.tf_layer)
-        self.tf_layer.build(input_shape)
+
+        super().build(input_shape)
 
     def call(self, input, training=None):
-        return self.tf_layer(input, training)
+        return self.tf_layer.call(input, training)
 
     def to_nntr_weights(self, tensorOrList):
         x = tensorOrList
@@ -179,7 +204,7 @@ class MultiOutLayer(IdentityTransLayer):
         if self.has_training:
             additional_args["training"] = training
 
-        tf_output = self.tf_layer(x, **additional_args)
+        tf_output = self.tf_layer.call(x, **additional_args)
 
         return [layer(tf_output) for layer in self.stub_layers]
 
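The explicit `build()` entry points above go hand in hand with calling `tf_layer.call()` instead of `tf_layer(...)`: going through `__call__` would auto-build the wrapped TF layer with the channels-first (nntrainer-layout) shape, whereas `ChannelLastTransLayer.build()` first permutes the shape to channels-last. A minimal sketch of that permutation, assuming `TO_CHANNELS_LAST == (2, 3, 1)` (the 1-based NCHW-to-NHWC axes used by the `K.layers.Permute` members; not shown in this hunk):

```python
# Illustrative shape transpose applied before building the wrapped TF layer.
def to_channels_last_shape(nchw_shape, to_channels_last=(2, 3, 1)):
    # keep the batch axis (index 0), then reorder C, H, W -> H, W, C
    return tuple(nchw_shape[i] for i in (0,) + to_channels_last)

assert to_channels_last_shape((3, 2, 4, 5)) == (3, 4, 5, 2)  # N,C,H,W -> N,H,W,C
```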
index f54ba89..0f9fae1 100644 (file)
@@ -147,16 +147,20 @@ static RunLayerContext prepareRunContext(const TensorPacks &packs) {
 
 static void compareRunContext(RunLayerContext &rc, std::ifstream &file) {
   file.seekg(0, std::ios::beg);
-  auto compare_tensors = [&file](unsigned length, auto tensor_getter,
-                                 auto pred) {
+  auto compare_tensors = [&file](unsigned length, auto tensor_getter, auto pred,
+                                 const std::string &name) {
     for (unsigned i = 0; i < length; ++i) {
       if (!pred(i)) {
         continue;
       }
       const auto &tensor = tensor_getter(i);
       auto answer = tensor.clone();
-      sizeCheckedReadTensor(answer, file);
-      EXPECT_EQ(tensor, answer);
+      sizeCheckedReadTensor(answer, file, name + " at " + std::to_string(i));
+
+      if (name == "initial_weights") {
+        continue;
+      }
+      EXPECT_EQ(tensor, answer) << name << " at " << std::to_string(i);
     }
   };
 
@@ -166,17 +170,23 @@ static void compareRunContext(RunLayerContext &rc, std::ifstream &file) {
   };
 
   compare_tensors(rc.getNumWeights(),
-                  [&rc](unsigned idx) { return rc.getWeight(idx); }, always);
+                  [&rc](unsigned idx) { return rc.getWeight(idx); }, always,
+                  "initial_weights");
   compare_tensors(rc.getNumInputs(),
-                  [&rc](unsigned idx) { return rc.getInput(idx); }, always);
+                  [&rc](unsigned idx) { return rc.getInput(idx); }, always,
+                  "inputs");
   compare_tensors(rc.getNumOutputs(),
-                  [&rc](unsigned idx) { return rc.getOutput(idx); }, always);
+                  [&rc](unsigned idx) { return rc.getOutput(idx); }, always,
+                  "outputs");
   compare_tensors(rc.getNumWeights(),
                   [&rc](unsigned idx) { return rc.getWeightGrad(idx); },
-                  only_trainable);
+                  only_trainable, "gradients");
+  compare_tensors(rc.getNumWeights(),
+                  [&rc](unsigned idx) { return rc.getWeight(idx); }, always,
+                  "weights");
   compare_tensors(rc.getNumInputs(),
                   [&rc](unsigned idx) { return rc.getOutgoingDerivative(idx); },
-                  always);
+                  always, "derivatives");
 }
 
 LayerGoldenTest::~LayerGoldenTest() {}