PyTorch frontend: fix handling of duplicate use of a model weight (#5897)
author Thomas Viehmann <tv.code@beamnet.de>
Wed, 24 Jun 2020 04:42:38 +0000 (06:42 +0200)
committer GitHub <noreply@github.com>
Wed, 24 Jun 2020 04:42:38 +0000 (10:12 +0530)
This happens, for example, with shared input/output embeddings
in BERT or in siamese networks.

Thank you @siju-samuel for reporting.
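
For reference, the weight sharing that triggers this looks roughly like the
following (a hypothetical minimal module, not taken from BERT itself):

    import torch

    class TiedEmbeddings(torch.nn.Module):
        # the input embedding and the output projection share one weight tensor
        def __init__(self, vocab_size=10, dim=4):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, dim)

        def forward(self, tokens):
            h = self.embed(tokens)            # first use of the weight
            return h @ self.embed.weight.t()  # second use: tied output projection

Tracing such a module produces several getattr paths that resolve to the same
parameter, which the converter previously turned into multiple conflicting
relay vars for a single weight.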

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

python/tvm/relay/frontend/pytorch.py
index 374e1c2..9237303 100644
@@ -2335,6 +2335,7 @@ def convert_params(graph, state_dict):
     params = {}
     param_tensors = {}
     packed_param_map = {}
+    vars_by_name = {}  # weight name -> relay var, shared across repeated uses
     seen = set()
 
     for node in getattr_nodes:
@@ -2352,10 +2353,14 @@ def convert_params(graph, state_dict):
                 assert full_attr in state_dict, err_msg
                 packed_param_map[full_attr_node_name] = full_attr
             elif full_attr in state_dict:
-                torch_tensor = state_dict[full_attr]
-                tensor, var = _get_tensor_and_var(torch_tensor,
-                                                  full_attr)
-                param_tensors[full_attr] = tensor
+                if full_attr in vars_by_name:
+                    var = vars_by_name[full_attr]  # weight seen before: reuse its var
+                else:
+                    torch_tensor = state_dict[full_attr]
+                    tensor, var = _get_tensor_and_var(torch_tensor,
+                                                      full_attr)
+                    param_tensors[full_attr] = tensor
+                    vars_by_name[full_attr] = var  # cache for any later use
                 params[full_attr_node_name] = var
 
     return params, param_tensors, packed_param_map
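
The fix is a straightforward memoization: the first access to a weight name
creates the relay var and records it in vars_by_name; every later access to
the same name reuses the cached var instead of creating a second one. A
standalone sketch of the pattern (collect_params and make_var are hypothetical
stand-ins, not the TVM API):

    def collect_params(accesses, state_dict, make_var):
        # accesses: (node_name, weight_name) pairs from the getattr nodes
        vars_by_name = {}  # weight name -> the single var created for it
        params = {}        # node name -> var; aliased when a weight is reused
        for node_name, weight_name in accesses:
            if weight_name not in vars_by_name:
                vars_by_name[weight_name] = make_var(state_dict[weight_name],
                                                     weight_name)
            params[node_name] = vars_by_name[weight_name]
        return params

Without the cache, a weight referenced from two getattr nodes got two distinct
vars carrying the same name hint for one underlying tensor.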
tests/python/frontend/pytorch/test_forward.py
index d564965..12d1260 100644
@@ -2390,6 +2390,23 @@ def test_weight_names():
     assert set(params.keys()) == set(n for n, p in tm.named_parameters())
 
 
+def test_duplicate_weight_use():
+    # The test case doesn't make sense as a neural network;
+    # the issue popped up in the shared input/output embeddings of BERT,
+    # but this reproduces it more quickly.
+    class Test(Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = torch.nn.Linear(5, 3)
+
+        def forward(self, x):
+            x = self.lin(x)
+            x = x @ self.lin.weight  # second use of the same weight
+            return x
+
+    verify_model(Test(), input_data=[torch.randn(5, 5)])
+
+
 def test_forward_matmul():
     torch.set_grad_enabled(False)
 
@@ -2556,6 +2573,7 @@ if __name__ == "__main__":
     test_forward_traced_function()
     test_forward_dtypes()
     test_weight_names()
+    test_duplicate_weight_use()
 
     # Single operator tests
     test_forward_add()
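
The new test can also be run on its own with pytest's node-id syntax
(standard pytest usage):

    pytest tests/python/frontend/pytorch/test_forward.py::test_duplicate_weight_use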