Preserve module hierarchy on traced modules (#15101)
author James Reed <jamesreed@fb.com>
Fri, 14 Dec 2018 23:05:24 +0000 (15:05 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 14 Dec 2018 23:07:51 +0000 (15:07 -0800)
Summary:
We need this, for example, to properly call `_unpack` when a traced module appears in the hierarchy.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15101

Differential Revision: D13468467

Pulled By: jamesr66a

fbshipit-source-id: c2b6740b12cde6e23395d12e42d4fc2c4c7ca3f2
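
What this enables, in a minimal sketch modeled on the new test_trace_hierarchy
test below (class and attribute names here are illustrative, not part of any
public API): tracing an nn.Module that holds a ScriptModule submodule now
preserves that submodule's script methods and parameters on the traced result.

    import torch

    class SomeScriptMod(torch.jit.ScriptModule):
        @torch.jit.script_method
        def bar(self):
            return torch.zeros(4, 3)

    class TraceMe(torch.nn.Module):
        def __init__(self):
            super(TraceMe, self).__init__()
            self.ssm = SomeScriptMod()

        def forward(self, x):
            return self.ssm.bar() + x

    traced = torch.jit.trace(TraceMe(), (torch.rand(4, 3),))
    # The ScriptModule submodule is copied into the traced hierarchy rather
    # than being re-wrapped, so its script methods stay reachable.
    assert hasattr(traced.ssm, 'bar')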

test/test_jit.py
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.h
torch/jit/__init__.py

test/test_jit.py
index 732229a..e6197db 100644 (file)
@@ -7141,6 +7141,59 @@ a")
         # Note: neg op from the traced function should be properly inlined
         self.assertExpected(canonical(tm.graph))
 
+    def test_trace_hierarchy(self):
+        # Test that we preserve the module hierarchy for a ScriptModule
+        # submodule during tracing
+
+        class AnotherScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(AnotherScriptMod, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(1, 2, 3))
+
+            @torch.jit.script_method
+            def bar(self):
+                return torch.zeros(4, 5)
+
+        class SomeScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(SomeScriptMod, self).__init__()
+                self.asm = AnotherScriptMod()
+
+            @torch.jit.script_method
+            def foo(self):
+                return torch.zeros(3, 4)
+
+            @torch.jit.script_method
+            def bar(self):
+                return torch.zeros(4, 3)
+
+        class TraceMe(torch.nn.Module):
+            def __init__(self):
+                super(TraceMe, self).__init__()
+                self.ssm = SomeScriptMod()
+
+            def forward(self, x):
+                return self.ssm.bar() + x
+
+        orig = TraceMe()
+        traced = torch.jit.trace(orig, (torch.rand(4, 3, dtype=torch.float),))
+        # For each of these checks, verify that *BOTH* the underlying
+        # _C.ScriptModule object and the Python object that wraps it
+        # expose the expected method/param.
+        self.assertTrue(traced.ssm._has_method('foo'))
+        self.assertTrue(hasattr(traced.ssm, 'foo'))
+
+        imported = self.getExportImportCopy(traced)
+
+        self.assertTrue(imported.ssm._has_method('foo'))
+        self.assertTrue(hasattr(imported.ssm, 'foo'))
+
+        self.assertTrue(imported.ssm.asm._has_method('bar'))
+        self.assertTrue(hasattr(imported.ssm.asm, 'bar'))
+
+        self.assertTrue(imported.ssm.asm._has_parameter('param'))
+        self.assertTrue(hasattr(imported.ssm.asm, 'param'))
+
     def test_call_traced_module_from_traced_module(self):
         class TracedModule1(torch.nn.Module):
             def __init__(self):
@@ -7186,11 +7239,11 @@ a")
         class ScriptMod(torch.jit.ScriptModule):
             def __init__(self):
                 super(ScriptMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(5, 7))
+                self.param_foo = torch.nn.Parameter(torch.rand(5, 7))
 
             @torch.jit.script_method
             def forward(self, x):
-                return torch.mm(x, self.param)
+                return torch.mm(x, self.param_foo)
 
         class TracedModule(torch.nn.Module):
             def __init__(self):
torch/csrc/jit/script/init.cpp
index eafe330..ecfa4f4 100644 (file)
@@ -687,7 +687,8 @@ void initJitScriptBindings(PyObject* module) {
         PythonPrint(ss, self, tensors, false);
         return ss.str();
       })
-      .def("apply", &Module::apply);
+      .def("apply", &Module::apply)
+      .def("_copy_into", &Module::copy_into);
 
   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
     .def("graph", [&](Method& self) {
torch/csrc/jit/script/module.h
index ed3fade..0b57080 100644 (file)
@@ -460,6 +460,27 @@ struct Module {
 
   void save(const std::string& filename);
 
+  void copy_into(std::function<std::shared_ptr<Module>(std::vector<std::string>)> module_lookup, std::vector<std::string> names = {}) const {
+    std::unordered_map<at::Tensor*, at::Tensor*> parameter_remap;
+    auto curr = module_lookup(names);
+    for (auto &kv : parameters) {
+      curr->register_parameter(kv.key(), *kv.value().slot(), kv.value().is_buffer);
+      parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
+    }
+    for (auto &kv : modules) {
+      names.push_back(kv.key());
+      kv.value().module->copy_into(module_lookup, names);
+      names.pop_back();
+    }
+    for (auto &kv : methods) {
+      std::vector<at::Tensor*> params;
+      for (auto &p : kv.value()->params()) {
+        params.push_back(parameter_remap[p]);
+      }
+      curr->create_method(kv.key(), kv.value()->graph()->copy(), params);
+    }
+  }
+
  private:
   void to_impl(
       const c10::optional<at::Device>& device,
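
The protocol behind copy_into, briefly: it walks the source module tree,
extending the `names` path with each submodule name as it recurses, and calls
`module_lookup(names)` to obtain (creating on demand) the destination module
at that path; parameters are re-registered there and each method is re-created
against the remapped parameter slots. A hedged sketch of the lookup side,
using plain Python stand-ins rather than real ScriptModules (FakeModule and
the attribute names are illustrative):

    class FakeModule(object):
        """Stand-in for a destination ScriptModule."""
        pass

    root = FakeModule()

    def module_lookup(names):
        # Walk/create the destination hierarchy for a path of submodule
        # names, e.g. [], ['ssm'], ['ssm', 'asm'].
        curr = root
        for name in names:
            if not hasattr(curr, name):
                setattr(curr, name, FakeModule())
            curr = getattr(curr, name)
        return curr

    # copy_into would invoke the callback in this order while recursing:
    module_lookup([])                # destination root
    module_lookup(['ssm'])           # creates root.ssm
    module_lookup(['ssm', 'asm'])    # creates root.ssm.asm
    assert hasattr(root.ssm, 'asm')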
torch/jit/__init__.py
index f94bd2b..03dc011 100644 (file)
@@ -1141,6 +1141,19 @@ if _enabled:
             rcb = createResolutionCallback(frames_up=1)
             self._define(lang, rcb, True)
 
+        def copy(self):
+            m = ScriptModule()
+
+            def module_lookup(names):
+                curr = m
+                for name in names:
+                    if not hasattr(curr, name):
+                        setattr(curr, name, ScriptModule())
+                    curr = getattr(curr, name)
+                return curr
+            self._copy_into(module_lookup, [])
+            return m
+
     class WeakScriptModuleProxy(ScriptModule):
         def __init__(self, original, stubs):
             # Guards behavior of __setattr__ and __getattr__ so ScriptModule
@@ -1305,7 +1318,10 @@ class TracedModule(ScriptModule):
             raise ValueError("Modules that have hooks assigned can't be compiled")
 
         for name, submodule in orig._modules.items():
-            self._modules[name] = TracedModule(submodule, id_set, optimize=optimize)
+            if isinstance(submodule, ScriptModule) and not isinstance(submodule, TracedModule):
+                self._modules[name] = submodule.copy()
+            else:
+                self._modules[name] = TracedModule(submodule, id_set, optimize=optimize)
 
         self._freeze()
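
Taken together: during tracing, a ScriptModule submodule is now copied via the
new copy()/_copy_into path instead of being wrapped in another TracedModule,
so its methods and parameters survive on the traced module and round-trip
through export/import. A hedged sketch of ScriptModule.copy() used on its own
(the class M below is illustrative, not from this change):

    import torch

    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__()
            self.param = torch.nn.Parameter(torch.rand(2, 3))

        @torch.jit.script_method
        def forward(self, x):
            return x + self.param

    m = M()
    m2 = m.copy()  # independent ScriptModule with the same methods/parameters
    assert hasattr(m2, 'forward') and hasattr(m2, 'param')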