elif full_attr in state_dict:
torch_tensor = state_dict[full_attr]
tensor, var = _get_tensor_and_var(torch_tensor,
- full_attr_node_name)
- param_tensors[full_attr_node_name] = tensor
+ full_attr)
+ param_tensors[full_attr] = tensor
params[full_attr_node_name] = var
return params, param_tensors, packed_param_map
verify_model(fn, input_data=[tensor1, tensor2])
+def test_weight_names():
+ tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])
+ mod, params = relay.frontend.from_pytorch(tm, [('input', (2, 3))])
+ assert set(params.keys()) == set(n for n, p in tm.named_parameters())
+
+
def test_forward_matmul():
torch.set_grad_enabled(False)
if __name__ == "__main__":
+ # some structural tests
test_forward_traced_function()
test_forward_dtypes()
+ test_weight_names()
+
# Single operator tests
test_forward_add()
test_forward_subtract()