for i in range(len(script_funs)):
self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
+ def test_module_copy_with_attributes(self):
+ # Regression test: ScriptModule.copy() must carry over non-tensor
+ # attributes (here a List[str], an int, and a Dict[str, int]) rather
+ # than only tensors — mirrors the C++ change that registers non-tensor
+ # attributes during module copy.
+ class Vocabulary(torch.jit.ScriptModule):
+ def __init__(self, vocab_list):
+ super().__init__()
+ # three differently-typed non-tensor attributes to exercise copy()
+ self._vocab = torch.jit.Attribute(vocab_list, List[str])
+ self.some_idx = torch.jit.Attribute(2, int)
+ self.idx = torch.jit.Attribute(
+ {word: i for i, word in enumerate(vocab_list)}, Dict[str, int]
+ )
+
+ @torch.jit.script_method
+ def lookup_indices_1d(self, values):
+ # type: (List[str]) -> List[int]
+ result = torch.jit.annotate(List[int], [])
+ # Direct list iteration not supported
+ for i in range(len(values)):
+ value = values[i]
+ # some_idx is the fallback returned for words missing from idx
+ result.append(self.idx.get(value, self.some_idx))
+ return result
+
+ @torch.jit.script_method
+ def forward(self, values):
+ # type: (List[List[str]]) -> List[List[int]]
+ result = torch.jit.annotate(List[List[int]], [])
+ # Direct list iteration not supported
+ for i in range(len(values)):
+ result.append(self.lookup_indices_1d(values[i]))
+ return result
+
+ # smoke test: construction must script both methods, and copy() must
+ # succeed without raising on the attribute-bearing module
+ v = Vocabulary(list('uabcdefg'))
+ v.copy()
def test_tuple_to_opt_list(self):
@torch.jit.script
def foo(x):
parameter_remap[param] = curr->parameter_slot(param.name());
}
for (auto& attr : get_attributes()) {
- if (!attr.type()->isSubtypeOf(TensorType::get())) {
- continue;
+ if (attr.type()->isSubtypeOf(TensorType::get())) {
+ curr->register_buffer(attr.name(), attr.value().toTensor());
+ parameter_remap[attr] = *curr->find_buffer(attr.name());
+ } else {
+ curr->register_attribute(attr.name(), attr.type(), attr.value());
+ parameter_remap[attr] = *curr->find_attribute(attr.name());
}
- curr->register_buffer(attr.name(), attr.value().toTensor());
- parameter_remap[attr] = *curr->find_buffer(attr.name());
}
for (auto& mod : get_modules()) {
names.push_back(mod->name());