def _get_writer(file):
- def write_fn(*items):
+ def write_fn(items):
+ if not isinstance(items, (list, tuple)):
+ items = [items]
+
for item in items:
np.array([item.numel()], dtype="int32").tofile(file)
item.detach().cpu().numpy().tofile(file)
inputs = _rand_like(*input_dims, rand="float")
labels = _rand_like(*label_dims, rand="float")
output, loss = model(inputs, labels)
- write_fn(*inputs)
- write_fn(*labels)
- write_fn(*(t for _, t in params_translated(model)))
+ write_fn(inputs)
+ write_fn(labels)
+ write_fn(list(t for _, t in params_translated(model)))
write_fn(output)
optimizer.zero_grad()
# @brief inpsect if file is created correctly
# @note this just checks if offset is corretly set, The result have to inspected
# manually
-def inspect_file(file_name):
+def inspect_file(file_name, show_content=True):
with open(file_name, "rb") as f:
sz = int.from_bytes(f.read(4), byteorder="little")
if not sz:
if not sz:
break
print("size: ", sz)
- print(np.fromfile(f, dtype="float32", count=sz))
+ t = np.fromfile(f, dtype="float32", count=sz)
+ if show_content:
+ print(t)
# @author Jihoon lee <jhoon.it.lee@samsung.com>
import torch
+from collections.abc import Iterable
__all__ = ["params_translated"]
# This is to imitate function overloadding
def register_for_(classes):
for already_registered_classes, _ in handler_book:
+ if not isinstance(classes, Iterable):
+ classes = (classes, )
+
for cls_ in classes:
if isinstance(cls_, already_registered_classes):
raise ValueError("class is already registered %s" % cls_.__name__)
def default_translate_(model):
yield from model.named_parameters(recurse=False)
+@register_for_(torch.nn.Linear)
+def fc_translate(model):
+ params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
+ def transpose_(weight):
+ return (weight[0], weight[1].transpose(1, 0))
+ new_params = [transpose_(params[0]), params[1]]
+ yield from new_params
@register_for_(torch.nn.LSTMCell)
def lstm_translate(model):
new_params = [transpose_(params[0]), transpose_(params[1]), bias]
yield from new_params
-
def translate(model):
for child in model.children():
for registered_classes, fn in handler_book:
if isinstance(child, registered_classes):
yield from fn(child)
- else:
- yield from translate(child)
+ break
+ else: # default case
+ yield from translate(child)
yield from default_translate_(model)
-
params_translated = translate