Fix Module::copy_into
authorJames Sun <jamessun@fb.com>
Thu, 20 Dec 2018 01:06:54 +0000 (17:06 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 01:09:59 +0000 (17:09 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15393

Differential Revision: D13519477

Pulled By: highker

fbshipit-source-id: d62928597ec0700b550e7cf481c8febae57b200d

test/test_jit.py
torch/csrc/jit/script/module.h
torch/jit/__init__.py

index 84daa2e..5828801 100644 (file)
@@ -7226,6 +7226,45 @@ a")
         self.assertTrue(imported.ssm.asm._has_parameter('param'))
         self.assertTrue(hasattr(imported.ssm.asm, 'param'))
 
+    def test_trace_parameter(self):
+        class Param(nn.Module):
+            def __init__(self):
+                super(Param, self).__init__()
+                self.register_parameter("bias", nn.Parameter(torch.Tensor(4, 4)))
+
+            def forward(self, x):
+                return x
+
+        class M3(torch.jit.ScriptModule):
+            def __init__(self, model):
+                super(M3, self).__init__(False)
+                self.traced = torch.jit.trace(model, (torch.rand(3, 3)))
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.traced(x)
+
+        class M2(nn.Module):
+            def __init__(self, model):
+                super(M2, self).__init__()
+                self.module = M3(model)
+
+            def forward(self, x):
+                return self.module(x)
+
+        class M1(torch.jit.ScriptModule):
+            def __init__(self, model):
+                super(M1, self).__init__(False)
+                self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3)))
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.traced(x)
+
+        module = M1(Param())
+        f = io.BytesIO()
+        torch.jit.save(module, f)
+
     def test_call_traced_module_from_traced_module(self):
         class TracedModule1(torch.nn.Module):
             def __init__(self):
index 48684b6..ae6375c 100644 (file)
@@ -472,8 +472,11 @@ struct Module {
 
   void save(const std::string& filename);
 
-  void copy_into(std::function<std::shared_ptr<Module>(std::vector<std::string>)> module_lookup, std::vector<std::string> names = {}) const {
-    std::unordered_map<at::Tensor*, at::Tensor*> parameter_remap;
+  void copy_into(std::function<std::shared_ptr<Module>(
+      std::vector<std::string>)> module_lookup,
+      // parameter_remap is needed when a parent module uses a parameter of a submodule
+      std::unordered_map<at::Tensor*, at::Tensor*>& parameter_remap,
+      std::vector<std::string> names = {}) const {
     auto curr = module_lookup(names);
     for (auto &kv : parameters) {
       curr->register_parameter(kv.key(), *kv.value().slot(), kv.value().is_buffer);
@@ -481,13 +484,15 @@ struct Module {
     }
     for (auto &kv : modules) {
       names.push_back(kv.key());
-      kv.value().module->copy_into(module_lookup, names);
+      // Submodules must be translated first, otherwise parameter_remap entries
+      // will not be filled in for methods of this module.
+      kv.value().module->copy_into(module_lookup, parameter_remap, names);
       names.pop_back();
     }
     for (auto &kv : methods) {
       std::vector<at::Tensor*> params;
       for (auto &p : kv.value()->params()) {
-        params.push_back(parameter_remap[p]);
+        params.push_back(parameter_remap.at(p));
       }
       curr->create_method(kv.key(), kv.value()->graph()->copy(), params);
     }
index 88f52f9..fe90ab1 100644 (file)
@@ -1152,7 +1152,7 @@ if _enabled:
                         setattr(curr, name, ScriptModule())
                     curr = getattr(curr, name)
                 return curr
-            self._copy_into(module_lookup, [])
+            self._copy_into(module_lookup, {}, [])
             return m
 
         def __getstate__(self):