From 7ef77606c5a75480cf6fe203525ae77396e648fd Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Wed, 20 Oct 2021 21:21:12 +0900 Subject: [PATCH] [Fix] fix recorder registration This patch fix registration bug when there is more than two overloaded delgation **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- test/input_gen/recorder_v2.py | 17 +++++++++++------ test/input_gen/transLayer_v2.py | 18 ++++++++++++++---- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/test/input_gen/recorder_v2.py b/test/input_gen/recorder_v2.py index dfdc383..77b25d2 100644 --- a/test/input_gen/recorder_v2.py +++ b/test/input_gen/recorder_v2.py @@ -29,7 +29,10 @@ __all__ = ["record_v2", "inspect_file"] def _get_writer(file): - def write_fn(*items): + def write_fn(items): + if not isinstance(items, (list, tuple)): + items = [items] + for item in items: np.array([item.numel()], dtype="int32").tofile(file) item.detach().cpu().numpy().tofile(file) @@ -74,9 +77,9 @@ def record_v2(model, iteration, input_dims, label_dims, name): inputs = _rand_like(*input_dims, rand="float") labels = _rand_like(*label_dims, rand="float") output, loss = model(inputs, labels) - write_fn(*inputs) - write_fn(*labels) - write_fn(*(t for _, t in params_translated(model))) + write_fn(inputs) + write_fn(labels) + write_fn(list(t for _, t in params_translated(model))) write_fn(output) optimizer.zero_grad() @@ -96,7 +99,7 @@ def record_v2(model, iteration, input_dims, label_dims, name): # @brief inpsect if file is created correctly # @note this just checks if offset is corretly set, The result have to inspected # manually -def inspect_file(file_name): +def inspect_file(file_name, show_content=True): with open(file_name, "rb") as f: sz = int.from_bytes(f.read(4), byteorder="little") if not sz: @@ -107,5 +110,7 @@ def inspect_file(file_name): if not sz: break print("size: ", sz) - print(np.fromfile(f, dtype="float32", count=sz)) + t = np.fromfile(f, dtype="float32", count=sz) + if show_content: + print(t) diff --git a/test/input_gen/transLayer_v2.py b/test/input_gen/transLayer_v2.py index 70889c5..93865bd 100644 --- a/test/input_gen/transLayer_v2.py +++ b/test/input_gen/transLayer_v2.py @@ -9,6 +9,7 @@ # @author Jihoon lee import torch +from collections.abc import Iterable __all__ = ["params_translated"] @@ -20,6 +21,9 @@ handler_book = [] # This is to imitate function overloadding def register_for_(classes): for already_registered_classes, _ in handler_book: + if not isinstance(classes, Iterable): + classes = (classes, ) + for cls_ in classes: if isinstance(cls_, already_registered_classes): raise ValueError("class is already registered %s" % cls_.__name__) @@ -34,6 +38,13 @@ def register_for_(classes): def default_translate_(model): yield from model.named_parameters(recurse=False) +@register_for_(torch.nn.Linear) +def fc_translate(model): + params = [(name, tensor.detach()) for name, tensor in model.named_parameters()] + def transpose_(weight): + return (weight[0], weight[1].transpose(1, 0)) + new_params = [transpose_(params[0]), params[1]] + yield from new_params @register_for_(torch.nn.LSTMCell) def lstm_translate(model): @@ -46,16 +57,15 @@ def lstm_translate(model): new_params = [transpose_(params[0]), transpose_(params[1]), bias] yield from new_params - def translate(model): for child in model.children(): for registered_classes, fn in handler_book: if isinstance(child, registered_classes): yield from fn(child) - else: - yield from translate(child) + break + else: # default case + yield from translate(child) yield from default_translate_(model) - params_translated = translate -- 2.7.4