params = {}
param_tensors = {}
packed_param_map = {}
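+ # remember the var created for each parameter name, so a weight tensor that
+ # is referenced in more than one place maps to a single var instead of duplicates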
+ vars_by_name = {}
seen = set()
for node in getattr_nodes:
        assert full_attr in state_dict, err_msg
        packed_param_map[full_attr_node_name] = full_attr
    elif full_attr in state_dict:
-       torch_tensor = state_dict[full_attr]
-       tensor, var = _get_tensor_and_var(torch_tensor,
-                                         full_attr)
-       param_tensors[full_attr] = tensor
+       if full_attr in vars_by_name:
+           var = vars_by_name[full_attr]
+       else:
+           torch_tensor = state_dict[full_attr]
+           tensor, var = _get_tensor_and_var(torch_tensor,
+                                             full_attr)
+           param_tensors[full_attr] = tensor
+           vars_by_name[full_attr] = var
        params[full_attr_node_name] = var
return params, param_tensors, packed_param_map
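# Illustration (hypothetical, not part of the patch): the case the vars_by_name
# cache handles is a single weight tensor reached through the same attribute
# path more than once, e.g. tied input/output embeddings as in BERT. A minimal
# sketch of such a module:
#
#     class TiedEmbeddings(torch.nn.Module):
#         def __init__(self, vocab_size=8, hidden=4):
#             super().__init__()
#             self.emb = torch.nn.Embedding(vocab_size, hidden)
#
#         def forward(self, ids):
#             h = self.emb(ids)                # input embedding
#             return h @ self.emb.weight.t()   # output projection, same weight
#
# Previously each use of emb.weight created a fresh var with the same name;
# with vars_by_name both uses resolve to one var and one param_tensors entry.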
assert set(params.keys()) == set(n for n, p in tm.named_parameters())
+def test_duplicate_weight_use():
+    # The test case doesn't make sense as a neural network;
+    # the issue popped up in the shared input/output embeddings of BERT,
+    # but this is quicker.
+    class Test(Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = torch.nn.Linear(5, 3)
+
+        def forward(self, x):
+            x = self.lin(x)
+            x = x @ self.lin.weight  # second use of the same weight tensor
+            return x
+
+    verify_model(Test(), input_data=[torch.randn(5, 5)])
+
+
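# Note: verify_model, as used throughout this test file, converts the traced
# module through the PyTorch frontend and compares the compiled output with
# eager PyTorch, so this serves as a regression test for the vars_by_name change.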
def test_forward_matmul():
torch.set_grad_enabled(False)
test_forward_traced_function()
test_forward_dtypes()
test_weight_names()
+ test_duplicate_weight_use()
# Single operator tests
test_forward_add()