torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
def test_unique_state_dict(self):
    """A parameter registered under two names appears once in the unique state dict."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            # The very same Parameter object is registered under two names.
            shared_param = torch.nn.Parameter(torch.ones(1))
            self.register_parameter('w1', shared_param)
            self.register_parameter('w2', shared_param)

        def forward(self, input):
            return input + self.w1 + self.w2

    model = MyModule()

    # NOTE(review): the assertions go through unittest.TestCase directly,
    # presumably to bypass an override on the enclosing test class -- confirm.
    without_vars = torch.jit._unique_state_dict(model, keep_vars=False)
    unittest.TestCase.assertEqual(self, len(without_vars), 1)
    # Only the first registered name survives deduplication.
    unittest.TestCase.assertEqual(self, list(without_vars.keys()), ['w1'])
    # keep_vars=False must not hand back the Parameter object itself.
    unittest.TestCase.assertIsNot(self, without_vars['w1'], model.w1)

    with_vars = torch.jit._unique_state_dict(model, keep_vars=True)
    unittest.TestCase.assertEqual(self, len(with_vars), 1)
    unittest.TestCase.assertEqual(self, list(with_vars.keys()), ['w1'])
    # keep_vars=True must return the registered Parameter object itself.
    unittest.TestCase.assertIs(self, with_vars['w1'], model.w1)

def test_trace_dict_input(self):
class Bar(torch.nn.Module):
def __init__(self):
def _unique_state_dict(module, keep_vars=False):
    """Return ``module``'s state dict with shared entries listed only once.

    When the same object is stored under several names in the state dict,
    only the first name (in ``state_dict`` iteration order) is kept.

    Args:
        module: object exposing ``state_dict(keep_vars=...)``.
        keep_vars: if true, the values are the objects returned by
            ``module.state_dict(keep_vars=True)``; otherwise each value is
            that object's ``.data``.

    Returns:
        A new mapping, of the same type as the one ``state_dict`` returned,
        holding one entry per distinct object.
    """
    # Since Parameter.data always creates a new torch.Tensor instance,
    # id(v) doesn't work with it. So we always get the Parameter or Buffer
    # as values, and deduplicate the params using Parameters and Buffers.
    state_dict = module.state_dict(keep_vars=True)
    filtered_dict = type(state_dict)()
    seen_ids = set()
    for k, v in state_dict.items():
        if id(v) in seen_ids:
            continue
        seen_ids.add(id(v))
        # Unwrap to .data only after deduplication, for the reason above.
        filtered_dict[k] = v if keep_vars else v.data
    return filtered_dict