From 054456eb93705914d507f9334efdd448be775186 Mon Sep 17 00:00:00 2001 From: James Reed Date: Fri, 14 Dec 2018 15:05:24 -0800 Subject: [PATCH] Preserve module hierarchy on traced modules (#15101) Summary: We need this, for example, to properly call `_unpack` when we have a traced module in the hierarchy Pull Request resolved: https://github.com/pytorch/pytorch/pull/15101 Differential Revision: D13468467 Pulled By: jamesr66a fbshipit-source-id: c2b6740b12cde6e23395d12e42d4fc2c4c7ca3f2 --- test/test_jit.py | 57 ++++++++++++++++++++++++++++++++++++++++-- torch/csrc/jit/script/init.cpp | 3 ++- torch/csrc/jit/script/module.h | 21 ++++++++++++++++ torch/jit/__init__.py | 18 ++++++++++++- 4 files changed, 95 insertions(+), 4 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 732229a..e6197db 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -7141,6 +7141,59 @@ a") # Note: neg op from the traced function should be properly inlined self.assertExpected(canonical(tm.graph)) + def test_trace_hierarchy(self): + # Test that we preserve the module hierarchy for a ScriptModule + # submodule during tracing + + class AnotherScriptMod(torch.jit.ScriptModule): + def __init__(self): + super(AnotherScriptMod, self).__init__() + self.param = torch.nn.Parameter(torch.rand(1, 2, 3)) + + @torch.jit.script_method + def bar(self): + return torch.zeros(4, 5) + + class SomeScriptMod(torch.jit.ScriptModule): + def __init__(self): + super(SomeScriptMod, self).__init__() + self.asm = AnotherScriptMod() + + @torch.jit.script_method + def foo(self): + return torch.zeros(3, 4) + + @torch.jit.script_method + def bar(self): + return torch.zeros(4, 3) + + class TraceMe(torch.nn.Module): + def __init__(self): + super(TraceMe, self).__init__() + self.ssm = SomeScriptMod() + + def forward(self, x): + return self.ssm.bar() + x + + orig = TraceMe() + traced = torch.jit.trace(orig, (torch.rand(4, 3, dtype=torch.float),)) + # for each of these checks, check that *BOTH* the underlying + # _C.ScriptModule object has the expected method/param, as well as the + # Python object that wraps it. + self.assertTrue(traced.ssm._has_method('foo')) + self.assertTrue(hasattr(traced.ssm, 'foo')) + + imported = self.getExportImportCopy(traced) + + self.assertTrue(imported.ssm._has_method('foo')) + self.assertTrue(hasattr(imported.ssm, 'foo')) + + self.assertTrue(imported.ssm.asm._has_method('bar')) + self.assertTrue(hasattr(imported.ssm.asm, 'bar')) + + self.assertTrue(imported.ssm.asm._has_parameter('param')) + self.assertTrue(hasattr(imported.ssm.asm, 'param')) + def test_call_traced_module_from_traced_module(self): class TracedModule1(torch.nn.Module): def __init__(self): @@ -7186,11 +7239,11 @@ a") class ScriptMod(torch.jit.ScriptModule): def __init__(self): super(ScriptMod, self).__init__() - self.param = torch.nn.Parameter(torch.rand(5, 7)) + self.param_foo = torch.nn.Parameter(torch.rand(5, 7)) @torch.jit.script_method def forward(self, x): - return torch.mm(x, self.param) + return torch.mm(x, self.param_foo) class TracedModule(torch.nn.Module): def __init__(self): diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index eafe330..ecfa4f4 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -687,7 +687,8 @@ void initJitScriptBindings(PyObject* module) { PythonPrint(ss, self, tensors, false); return ss.str(); }) - .def("apply", &Module::apply); + .def("apply", &Module::apply) + .def("_copy_into", &Module::copy_into); py::class_(m, "ScriptMethod", py::dynamic_attr()) .def("graph", [&](Method& self) { diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index ed3fade..0b57080 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -460,6 +460,27 @@ struct Module { void save(const std::string& filename); + void copy_into(std::function(std::vector)> module_lookup, std::vector names = {}) const { + std::unordered_map parameter_remap; + auto curr = module_lookup(names); + for (auto &kv : parameters) { + curr->register_parameter(kv.key(), *kv.value().slot(), kv.value().is_buffer); + parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key()); + } + for (auto &kv : modules) { + names.push_back(kv.key()); + kv.value().module->copy_into(module_lookup, names); + names.pop_back(); + } + for (auto &kv : methods) { + std::vector params; + for (auto &p : kv.value()->params()) { + params.push_back(parameter_remap[p]); + } + curr->create_method(kv.key(), kv.value()->graph()->copy(), params); + } + } + private: void to_impl( const c10::optional& device, diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index f94bd2b..03dc011 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1141,6 +1141,19 @@ if _enabled: rcb = createResolutionCallback(frames_up=1) self._define(lang, rcb, True) + def copy(self): + m = ScriptModule() + + def module_lookup(names): + curr = m + for name in names: + if not hasattr(curr, name): + setattr(curr, name, ScriptModule()) + curr = getattr(curr, name) + return curr + self._copy_into(module_lookup, []) + return m + class WeakScriptModuleProxy(ScriptModule): def __init__(self, original, stubs): # Guards behavior of __setattr__ and __getattr__ so ScriptModule @@ -1305,7 +1318,10 @@ class TracedModule(ScriptModule): raise ValueError("Modules that have hooks assigned can't be compiled") for name, submodule in orig._modules.items(): - self._modules[name] = TracedModule(submodule, id_set, optimize=optimize) + if isinstance(submodule, ScriptModule) and not isinstance(submodule, TracedModule): + self._modules[name] = submodule.copy() + else: + self._modules[name] = TracedModule(submodule, id_set, optimize=optimize) self._freeze() -- 2.7.4