From 612998f2eeb1140102a9f6d16ad3738748a4698c Mon Sep 17 00:00:00 2001 From: David Riazati Date: Wed, 10 Apr 2019 15:56:42 -0700 Subject: [PATCH] Support attributes when copying modules (#19040) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19040 ghimport-source-id: 37933efd717795751283cae8141e2e2caaae2e95 Differential Revision: D14878128 Pulled By: driazati fbshipit-source-id: 7ef5f7b1b16b9bf9254e8503564fa3a750d841ab --- test/test_jit.py | 32 ++++++++++++++++++++++++++++++++ torch/csrc/jit/script/module.h | 10 ++++++---- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index f2b5d44..b7e4499 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3505,6 +3505,38 @@ a") for i in range(len(script_funs)): self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)) + def test_module_copy_with_attributes(self): + class Vocabulary(torch.jit.ScriptModule): + def __init__(self, vocab_list): + super().__init__() + self._vocab = torch.jit.Attribute(vocab_list, List[str]) + self.some_idx = torch.jit.Attribute(2, int) + self.idx = torch.jit.Attribute( + {word: i for i, word in enumerate(vocab_list)}, Dict[str, int] + ) + + @torch.jit.script_method + def lookup_indices_1d(self, values): + # type: (List[str]) -> List[int] + result = torch.jit.annotate(List[int], []) + # Direct list iteration not supported + for i in range(len(values)): + value = values[i] + result.append(self.idx.get(value, self.some_idx)) + return result + + @torch.jit.script_method + def forward(self, values): + # type: (List[List[str]]) -> List[List[int]] + result = torch.jit.annotate(List[List[int]], []) + # Direct list iteration not supported + for i in range(len(values)): + result.append(self.lookup_indices_1d(values[i])) + return result + + v = Vocabulary(list('uabcdefg')) + v.copy() + def test_tuple_to_opt_list(self): @torch.jit.script def foo(x): diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 13233d9..d354b2b 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -633,11 +633,13 @@ struct Module { parameter_remap[param] = curr->parameter_slot(param.name()); } for (auto& attr : get_attributes()) { - if (!attr.type()->isSubtypeOf(TensorType::get())) { - continue; + if (attr.type()->isSubtypeOf(TensorType::get())) { + curr->register_buffer(attr.name(), attr.value().toTensor()); + parameter_remap[attr] = *curr->find_buffer(attr.name()); + } else { + curr->register_attribute(attr.name(), attr.type(), attr.value()); + parameter_remap[attr] = *curr->find_attribute(attr.name()); } - curr->register_buffer(attr.name(), attr.value().toTensor()); - parameter_remap[attr] = *curr->find_buffer(attr.name()); } for (auto& mod : get_modules()) { names.push_back(mod->name()); -- 2.7.4