add an assertion to check the param num (#18145)
author Lu Fang <lufang@fb.com>
Wed, 3 Apr 2019 17:51:41 +0000 (10:51 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 19:47:23 +0000 (12:47 -0700)
Summary:
Introduce this check to see whether it breaks any existing workflows.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18145

Reviewed By: dzhulgakov

Differential Revision: D14511711

Pulled By: houseroad

fbshipit-source-id: a7bb6ac84c9133fe94d3fe2f1a8566faed14a136

torch/onnx/utils.py

index b0cf121..bf468c3 100644
@@ -253,6 +253,10 @@ def _model_to_graph(model, args, f, verbose=False, training=False,
 
     _set_input_and_output_names(graph, input_names, output_names)
 
+    # make sure that the param dict and the graph match each other
+    flatten_args, _ = torch._C._jit_flatten(args)
+    assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs())
+
     input_and_param_names = [val.uniqueName() for val in graph.inputs()]
     param_names = input_and_param_names[len(input_and_param_names) - len(params):]
     params_dict = dict(zip(param_names, params))
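The assertion encodes an invariant of the export path: the graph's inputs are the flattened user arguments followed by the model parameters, so their counts must add up. A minimal pure-Python sketch of that invariant, with hypothetical `flatten_args` and `check_param_count` helpers standing in for `torch._C._jit_flatten` and the real graph API:

```python
# Sketch only: illustrates the counting invariant the new assertion
# enforces, without depending on torch internals.

def flatten_args(args):
    """Recursively flatten nested tuples/lists into a flat list,
    loosely mimicking what torch._C._jit_flatten does for args."""
    flat = []
    for a in args:
        if isinstance(a, (tuple, list)):
            flat.extend(flatten_args(a))
        else:
            flat.append(a)
    return flat

def check_param_count(args, params, graph_inputs):
    """Assert that graph inputs == flattened args + parameters,
    mirroring the check added in _model_to_graph."""
    flat = flatten_args(args)
    assert len(params) + len(flat) == len(graph_inputs), (
        "graph inputs (%d) != flattened args (%d) + params (%d)"
        % (len(graph_inputs), len(flat), len(params)))

# Example: three flattened args (one nested pair) plus two params
# must correspond to exactly five graph inputs.
check_param_count(args=("x", ("y", "z")), params=["w", "b"],
                  graph_inputs=["x", "y", "z", "w", "b"])
```

If the counts disagree (for example, a parameter was folded into the graph as a constant, or an argument failed to flatten), the assertion fires instead of silently mispairing `param_names` with `params` in the `zip` below.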