From 5164622ba462fe07fc9f2325fccf7f85aecb3ec8 Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Wed, 10 Apr 2019 22:21:45 -0700 Subject: [PATCH] Revert D14878128: [jit] Support attributes when copying modules Differential Revision: D14878128 Original commit changeset: 7ef5f7b1b16b fbshipit-source-id: 3818222a897f8c01bc67f550ed0fd3ddecf61015 --- test/test_jit.py | 32 -------------------------------- torch/csrc/jit/script/module.h | 10 ++++------ 2 files changed, 4 insertions(+), 38 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 26160ad..787dee4 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3513,38 +3513,6 @@ a") for i in range(len(script_funs)): self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)) - def test_module_copy_with_attributes(self): - class Vocabulary(torch.jit.ScriptModule): - def __init__(self, vocab_list): - super().__init__() - self._vocab = torch.jit.Attribute(vocab_list, List[str]) - self.some_idx = torch.jit.Attribute(2, int) - self.idx = torch.jit.Attribute( - {word: i for i, word in enumerate(vocab_list)}, Dict[str, int] - ) - - @torch.jit.script_method - def lookup_indices_1d(self, values): - # type: (List[str]) -> List[int] - result = torch.jit.annotate(List[int], []) - # Direct list iteration not supported - for i in range(len(values)): - value = values[i] - result.append(self.idx.get(value, self.some_idx)) - return result - - @torch.jit.script_method - def forward(self, values): - # type: (List[List[str]]) -> List[List[int]] - result = torch.jit.annotate(List[List[int]], []) - # Direct list iteration not supported - for i in range(len(values)): - result.append(self.lookup_indices_1d(values[i])) - return result - - v = Vocabulary(list('uabcdefg')) - v.copy() - def test_tuple_to_opt_list(self): @torch.jit.script def foo(x): diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index d354b2b..13233d9 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -633,13 +633,11 @@ struct Module { parameter_remap[param] = curr->parameter_slot(param.name()); } for (auto& attr : get_attributes()) { - if (attr.type()->isSubtypeOf(TensorType::get())) { - curr->register_buffer(attr.name(), attr.value().toTensor()); - parameter_remap[attr] = *curr->find_buffer(attr.name()); - } else { - curr->register_attribute(attr.name(), attr.type(), attr.value()); - parameter_remap[attr] = *curr->find_attribute(attr.name()); + if (!attr.type()->isSubtypeOf(TensorType::get())) { + continue; } + curr->register_buffer(attr.name(), attr.value().toTensor()); + parameter_remap[attr] = *curr->find_buffer(attr.name()); } for (auto& mod : get_modules()) { names.push_back(mod->name()); -- 2.7.4