[Fix] fix recorder registration
authorJihoon Lee <jhoon.it.lee@samsung.com>
Wed, 20 Oct 2021 12:21:12 +0000 (21:21 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 21 Oct 2021 06:06:37 +0000 (15:06 +0900)
This patch fix registration bug when there is more than two overloaded
delgation

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
test/input_gen/recorder_v2.py
test/input_gen/transLayer_v2.py

index dfdc383..77b25d2 100644 (file)
@@ -29,7 +29,10 @@ __all__ = ["record_v2", "inspect_file"]
 
 
 def _get_writer(file):
-    def write_fn(*items):
+    def write_fn(items):
+        if not isinstance(items, (list, tuple)):
+            items = [items]
+
         for item in items:
             np.array([item.numel()], dtype="int32").tofile(file)
             item.detach().cpu().numpy().tofile(file)
@@ -74,9 +77,9 @@ def record_v2(model, iteration, input_dims, label_dims, name):
         inputs = _rand_like(*input_dims, rand="float")
         labels = _rand_like(*label_dims, rand="float")
         output, loss = model(inputs, labels)
-        write_fn(*inputs)
-        write_fn(*labels)
-        write_fn(*(t for _, t in params_translated(model)))
+        write_fn(inputs)
+        write_fn(labels)
+        write_fn(list(t for _, t in params_translated(model)))
         write_fn(output)
 
         optimizer.zero_grad()
@@ -96,7 +99,7 @@ def record_v2(model, iteration, input_dims, label_dims, name):
 # @brief inpsect if file is created correctly
 # @note this just checks if offset is corretly set, The result have to inspected
 # manually
-def inspect_file(file_name):
+def inspect_file(file_name, show_content=True):
     with open(file_name, "rb") as f:
         sz = int.from_bytes(f.read(4), byteorder="little")
         if not sz:
@@ -107,5 +110,7 @@ def inspect_file(file_name):
             if not sz:
                 break
             print("size: ", sz)
-            print(np.fromfile(f, dtype="float32", count=sz))
+            t = np.fromfile(f, dtype="float32", count=sz)
+            if show_content:
+                print(t)
 
index 70889c5..93865bd 100644 (file)
@@ -9,6 +9,7 @@
 # @author Jihoon lee <jhoon.it.lee@samsung.com>
 
 import torch
+from collections.abc import Iterable
 
 __all__ = ["params_translated"]
 
@@ -20,6 +21,9 @@ handler_book = []
 # This is to imitate function overloadding
 def register_for_(classes):
     for already_registered_classes, _ in handler_book:
+        if not isinstance(classes, Iterable):
+            classes = (classes, )
+
         for cls_ in classes:
             if isinstance(cls_, already_registered_classes):
                 raise ValueError("class is already registered %s" % cls_.__name__)
@@ -34,6 +38,13 @@ def register_for_(classes):
 def default_translate_(model):
     yield from model.named_parameters(recurse=False)
 
+@register_for_(torch.nn.Linear)
+def fc_translate(model):
+    params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
+    def transpose_(weight):
+        return (weight[0], weight[1].transpose(1, 0))
+    new_params = [transpose_(params[0]), params[1]]
+    yield from new_params
 
 @register_for_(torch.nn.LSTMCell)
 def lstm_translate(model):
@@ -46,16 +57,15 @@ def lstm_translate(model):
     new_params = [transpose_(params[0]), transpose_(params[1]), bias]
     yield from new_params
 
-
 def translate(model):
     for child in model.children():
         for registered_classes, fn in handler_book:
             if isinstance(child, registered_classes):
                 yield from fn(child)
-            else:
-                yield from translate(child)
+                break
+        else: # default case
+            yield from translate(child)
     yield from default_translate_(model)
 
-
 params_translated = translate