From 65dfe1203ffab064d4e32fa8f76833042369d2f5 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Wed, 3 Apr 2019 10:51:41 -0700 Subject: [PATCH] add an assertion to check the param num (#18145) Summary: Introduce this check to see whether it will break any existing workflow Pull Request resolved: https://github.com/pytorch/pytorch/pull/18145 Reviewed By: dzhulgakov Differential Revision: D14511711 Pulled By: houseroad fbshipit-source-id: a7bb6ac84c9133fe94d3fe2f1a8566faed14a136 --- torch/onnx/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index b0cf121..bf468c3 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -253,6 +253,10 @@ def _model_to_graph(model, args, f, verbose=False, training=False, _set_input_and_output_names(graph, input_names, output_names) + # make sure that the param dict and the graph match each other + flatten_args, _ = torch._C._jit_flatten(args) + assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs()) + input_and_param_names = [val.uniqueName() for val in graph.inputs()] param_names = input_and_param_names[len(input_and_param_names) - len(params):] params_dict = dict(zip(param_names, params)) -- 2.7.4