keep parameter names from PyTorch (#5887)
author Thomas Viehmann <tv.code@beamnet.de>
Mon, 22 Jun 2020 23:40:45 +0000 (01:40 +0200)
committer GitHub <noreply@github.com>
Mon, 22 Jun 2020 23:40:45 +0000 (08:40 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index f70a64a..374e1c2 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2354,8 +2354,8 @@ def convert_params(graph, state_dict):
             elif full_attr in state_dict:
                 torch_tensor = state_dict[full_attr]
                 tensor, var = _get_tensor_and_var(torch_tensor,
-                                                  full_attr_node_name)
-                param_tensors[full_attr_node_name] = tensor
+                                                  full_attr)
+                param_tensors[full_attr] = tensor
                 params[full_attr_node_name] = var
 
     return params, param_tensors, packed_param_map
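
The hunk above changes convert_params so that converted weights are keyed by the original PyTorch attribute path (full_attr, e.g. "fc.weight") rather than by the internal attribute-node name; the Relay vars in params still use the node name for graph lookup. A minimal sketch of the observable effect (assuming torch and tvm are installed; the printed keys reflect this patch):

    import torch
    from tvm import relay

    # Trace a simple module and convert it; after this patch the returned
    # params dict uses PyTorch's own parameter names.
    tm = torch.jit.trace(torch.nn.Linear(3, 4), torch.randn(2, 3))
    mod, params = relay.frontend.from_pytorch(tm, [("input", (2, 3))])
    print(sorted(params.keys()))  # expected after this patch: ['bias', 'weight']
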
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6ec3110..d564965 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2384,6 +2384,12 @@ def test_forward_dtypes():
         verify_model(fn, input_data=[tensor1, tensor2])
 
 
+def test_weight_names():
+    tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])
+    mod, params = relay.frontend.from_pytorch(tm, [('input', (2, 3))])
+    assert set(params.keys()) == set(n for n, p in tm.named_parameters())
+
+
 def test_forward_matmul():
     torch.set_grad_enabled(False)
 
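The new test_weight_names above asserts that the converter's param keys match named_parameters() on the traced module. For nested modules those names are dotted attribute paths, so the same property covers hierarchical models as well; a quick illustration (not part of the patch, Net is a hypothetical example module):

    import torch

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(3, 4)

        def forward(self, x):
            return self.fc(x)

    # Tracing preserves the module hierarchy, so parameter names keep
    # their dotted attribute paths.
    tm = torch.jit.trace(Net(), torch.randn(2, 3))
    print([n for n, _ in tm.named_parameters()])  # ['fc.weight', 'fc.bias']
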
@@ -2546,8 +2552,11 @@ def test_forward_pretrained_bert_base_uncased():
 
 
 if __name__ == "__main__":
+    # some structural tests
     test_forward_traced_function()
     test_forward_dtypes()
+    test_weight_names()
+
     # Single operator tests
     test_forward_add()
     test_forward_subtract()