Fix the duplication problem in _unique_state_dict (#18139)
authorLu Fang <lufang@fb.com>
Thu, 4 Apr 2019 06:14:07 +0000 (23:14 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 06:16:44 +0000 (23:16 -0700)
Summary:
Since Parameter.data creates a new torch.Tensor instance on each access, _unique_state_dict currently returns duplicate entries for parameters registered under multiple names. Deduplicate on the Parameter/Buffer objects themselves before detaching them into new tensors.
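
A minimal sketch of the id instability that causes the duplication:

    import torch

    p = torch.nn.Parameter(torch.ones(1))
    a, b = p.data, p.data  # each .data access detaches into a fresh Tensor object
    print(a is b)                         # False: two distinct Python objects
    print(id(a) == id(b))                 # False, so id()-based dedup on .data fails
    print(a.data_ptr() == b.data_ptr())   # True: both views share one storage
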
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18139

Reviewed By: dzhulgakov

Differential Revision: D14511262

Pulled By: houseroad

fbshipit-source-id: cb69795d0b6509721220650bbb19edeb3459a503

test/test_jit.py
torch/jit/__init__.py

index 98d8bd4..9348447 100644 (file)
@@ -2318,6 +2318,25 @@ class TestJit(JitTestCase):
             torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
         self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
 
+    def test_unique_state_dict(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                shared_param = torch.nn.Parameter(torch.ones(1))
+                self.register_parameter('w1', shared_param)
+                self.register_parameter('w2', shared_param)
+
+            def forward(self, input):
+                return input + self.w1 + self.w2
+
+        model = MyModule()
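+        # Call unittest.TestCase.assertEqual directly to bypass the
+        # overridden assertEqual in the JIT test harness.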
+        unittest.TestCase.assertEqual(
+            self, len(torch.jit._unique_state_dict(model, keep_vars=False)), 1)
+        unittest.TestCase.assertEqual(
+            self, len(torch.jit._unique_state_dict(model, keep_vars=True)), 1)
+
     def test_trace_dict_input(self):
         class Bar(torch.nn.Module):
             def __init__(self):
index 4032f78..593f508 100644 (file)
@@ -216,14 +216,20 @@ def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False, return_input
 
 
 def _unique_state_dict(module, keep_vars=False):
-    state_dict = module.state_dict(keep_vars=keep_vars)
+    # Since Parameter.data always creates a new torch.Tensor instance,
+    # id(v) is not stable across accesses. So we fetch the Parameter and
+    # Buffer objects themselves (keep_vars=True) and deduplicate on those.
+    state_dict = module.state_dict(keep_vars=True)
     filtered_dict = type(state_dict)()
     seen_ids = set()
     for k, v in state_dict.items():
         if id(v) in seen_ids:
             continue
         seen_ids.add(id(v))
-        filtered_dict[k] = v
+        if keep_vars:
+            filtered_dict[k] = v
+        else:
+            filtered_dict[k] = v.data
     return filtered_dict
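
As a quick sanity check of the fixed behavior, a sketch mirroring the new test (assuming the patched torch.jit is on the path):

    import torch

    class Shared(torch.nn.Module):
        def __init__(self):
            super(Shared, self).__init__()
            w = torch.nn.Parameter(torch.ones(1))
            self.register_parameter('w1', w)
            self.register_parameter('w2', w)

        def forward(self, x):
            return x + self.w1 + self.w2

    m = Shared()
    # One entry in both modes; before the fix, keep_vars=False yielded two,
    # because each .data access produced a tensor with a fresh id.
    print(len(torch.jit._unique_state_dict(m, keep_vars=False)))  # 1
    print(len(torch.jit._unique_state_dict(m, keep_vars=True)))   # 1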