Support attributes when copying modules (#19040)
authorDavid Riazati <davidriazati@fb.com>
Wed, 10 Apr 2019 22:56:42 +0000 (15:56 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 10 Apr 2019 23:12:29 +0000 (16:12 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19040
ghimport-source-id: 37933efd717795751283cae8141e2e2caaae2e95

Differential Revision: D14878128

Pulled By: driazati

fbshipit-source-id: 7ef5f7b1b16b9bf9254e8503564fa3a750d841ab

test/test_jit.py
torch/csrc/jit/script/module.h

index f2b5d44..b7e4499 100644 (file)
@@ -3505,6 +3505,38 @@ a")
             for i in range(len(script_funs)):
                 self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
 
+    def test_module_copy_with_attributes(self):
+        class Vocabulary(torch.jit.ScriptModule):
+            def __init__(self, vocab_list):
+                super().__init__()
+                self._vocab = torch.jit.Attribute(vocab_list, List[str])
+                self.some_idx = torch.jit.Attribute(2, int)
+                self.idx = torch.jit.Attribute(
+                    {word: i for i, word in enumerate(vocab_list)}, Dict[str, int]
+                )
+
+            @torch.jit.script_method
+            def lookup_indices_1d(self, values):
+                # type: (List[str]) -> List[int]
+                result = torch.jit.annotate(List[int], [])
+                # Direct list iteration not supported
+                for i in range(len(values)):
+                    value = values[i]
+                    result.append(self.idx.get(value, self.some_idx))
+                return result
+
+            @torch.jit.script_method
+            def forward(self, values):
+                # type: (List[List[str]]) -> List[List[int]]
+                result = torch.jit.annotate(List[List[int]], [])
+                # Direct list iteration not supported
+                for i in range(len(values)):
+                    result.append(self.lookup_indices_1d(values[i]))
+                return result
+
+        v = Vocabulary(list('uabcdefg'))
+        v.copy()
+
     def test_tuple_to_opt_list(self):
         @torch.jit.script
         def foo(x):
index 13233d9..d354b2b 100644 (file)
@@ -633,11 +633,13 @@ struct Module {
       parameter_remap[param] = curr->parameter_slot(param.name());
     }
     for (auto& attr : get_attributes()) {
-      if (!attr.type()->isSubtypeOf(TensorType::get())) {
-        continue;
+      if (attr.type()->isSubtypeOf(TensorType::get())) {
+        curr->register_buffer(attr.name(), attr.value().toTensor());
+        parameter_remap[attr] = *curr->find_buffer(attr.name());
+      } else {
+        curr->register_attribute(attr.name(), attr.type(), attr.value());
+        parameter_remap[attr] = *curr->find_attribute(attr.name());
       }
-      curr->register_buffer(attr.name(), attr.value().toTensor());
-      parameter_remap[attr] = *curr->find_buffer(attr.name());
     }
     for (auto& mod : get_modules()) {
       names.push_back(mod->name());