[ONNX] Fix an issue that optimizations might adjust graph inputs unexpectedly. (...
author    BowenBao <bowbao@microsoft.com>
          Fri, 20 Aug 2021 19:44:29 +0000 (12:44 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Fri, 20 Aug 2021 19:46:52 +0000 (12:46 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62763

This PR fixes an issue where the graph inputs might be updated when the model is exported in inference mode.

When a model is exported in inference mode, some optimizations are applied. One side effect of these optimizations is that the graph inputs might be adjusted. Such optimizations include (a minimal sketch of an affected module follows the list):

1. Conv and BatchNorm op fusion.
2. Constant folding.
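For context, here is a minimal sketch (not part of this PR; the module and names are illustrative) of a model whose eval-mode export exercises the Conv and BatchNorm fusion, replacing the original conv/bn parameters among the graph inputs with a fused initializer:

    import torch
    import torch.nn as nn

    # Minimal module that triggers Conv + BatchNorm fusion when exported
    # in inference mode: the fused weight/bias become new initializers,
    # and the original conv.* / bn.* parameters disappear from the graph.
    class ConvBN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))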

If the user sets export_params=False or keep_initializers_as_inputs=True, it is highly likely that the user wants to provide the corresponding parameters or initializers as the inputs of the graph.
In that situation, regardless of whether the model is exported in inference mode or training mode, the exporter needs to prevent the above optimizations from adjusting the graph inputs, so that the graph inputs match the inputs the user provides.

The changes in this PR add a common check that decides whether the above optimizations should run: from the values of the export_params and keep_initializers_as_inputs arguments, the exporter infers whether the graph inputs are allowed to be adjusted.
If not, these optimizations are skipped, even when their other requirements are met. A usage sketch follows.
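The sketch below (file names are illustrative; ConvBN refers to the module sketched above) shows the resulting behavior:

    import torch

    model = ConvBN().eval()
    dummy = torch.randn(1, 3, 32, 32)

    # Default: export_params=True and keep_initializers_as_inputs unset,
    # so the exporter may adjust graph inputs (fusion and constant folding run).
    torch.onnx.export(model, dummy, "conv_bn_fused.onnx")

    # With either of the settings below, the new check
    # (export_params and not keep_initializers_as_inputs) is False,
    # so the graph inputs are left untouched.
    torch.onnx.export(model, dummy, "conv_bn_kept_inputs.onnx",
                      keep_initializers_as_inputs=True)
    torch.onnx.export(model, dummy, "conv_bn_no_params.onnx",
                      export_params=False)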

Besides these code changes, the documentation of the parameters below has been updated so that users can better understand how to leverage these parameters for different purposes:

1. export_params
2. training
3. do_constant_folding
4. keep_initializers_as_inputs

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D30375183

Pulled By: msaroufim

fbshipit-source-id: 4db8b9695649eb32a3a0fefa950ee2e5651bdba0

Co-authored-by: fatcat-z <jiz@microsoft.com>
test/onnx/expect/TestOperators.test_prelu.expect
test/onnx/expect/TestOperators.test_retain_param_name_disabled.expect
torch/_C/__init__.pyi.in
torch/csrc/jit/passes/onnx/eval_peephole.cpp
torch/csrc/jit/passes/onnx/eval_peephole.h
torch/csrc/jit/python/init.cpp
torch/onnx/__init__.py
torch/onnx/utils.py

diff --git a/test/onnx/expect/TestOperators.test_prelu.expect b/test/onnx/expect/TestOperators.test_prelu.expect
index e19623c..be0328e 100644
@@ -3,19 +3,29 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
+    input: "weight"
+    output: "2"
+    name: "Unsqueeze_0"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 1
+      ints: 2
+      type: INTS
+    }
+  }
+  node {
     input: "input"
-    input: "4"
+    input: "2"
     output: "3"
-    name: "PRelu_0"
+    name: "PRelu_1"
     op_type: "PRelu"
   }
   name: "torch-jit-export"
   initializer {
     dims: 2
-    dims: 1
-    dims: 1
     data_type: 1
-    name: "4"
+    name: "weight"
     raw_data: "\000\000\200>\000\000\200>"
   }
   input {
@@ -41,7 +51,7 @@ graph {
     }
   }
   input {
-    name: "4"
+    name: "weight"
     type {
       tensor_type {
         elem_type: 1
@@ -49,12 +59,6 @@ graph {
           dim {
             dim_value: 2
           }
-          dim {
-            dim_value: 1
-          }
-          dim {
-            dim_value: 1
-          }
         }
       }
     }
diff --git a/test/onnx/expect/TestOperators.test_retain_param_name_disabled.expect b/test/onnx/expect/TestOperators.test_retain_param_name_disabled.expect
index 5eeaa87..aa9499e 100644
@@ -3,32 +3,56 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
+    input: "1"
+    output: "3"
+    name: "Transpose_0"
+    op_type: "Transpose"
+    attribute {
+      name: "perm"
+      ints: 1
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
     input: "input.1"
-    input: "7"
+    input: "3"
     output: "4"
-    name: "MatMul_0"
+    name: "MatMul_1"
     op_type: "MatMul"
   }
   node {
+    input: "2"
+    output: "5"
+    name: "Transpose_2"
+    op_type: "Transpose"
+    attribute {
+      name: "perm"
+      ints: 1
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
     input: "4"
-    input: "8"
+    input: "5"
     output: "6"
-    name: "MatMul_1"
+    name: "MatMul_3"
     op_type: "MatMul"
   }
   name: "torch-jit-export"
   initializer {
-    dims: 4
     dims: 5
+    dims: 4
     data_type: 1
-    name: "7"
+    name: "1"
     raw_data: "\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@"
   }
   initializer {
-    dims: 5
     dims: 6
+    dims: 5
     data_type: 1
-    name: "8"
+    name: "2"
     raw_data: "\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@"
   }
   input {
@@ -48,32 +72,32 @@ graph {
     }
   }
   input {
-    name: "7"
+    name: "1"
     type {
       tensor_type {
         elem_type: 1
         shape {
           dim {
-            dim_value: 4
+            dim_value: 5
           }
           dim {
-            dim_value: 5
+            dim_value: 4
           }
         }
       }
     }
   }
   input {
-    name: "8"
+    name: "2"
     type {
       tensor_type {
         elem_type: 1
         shape {
           dim {
-            dim_value: 5
+            dim_value: 6
           }
           dim {
-            dim_value: 6
+            dim_value: 5
           }
         }
       }
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index b683a60..4d0245c 100644
@@ -324,7 +324,7 @@ def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
 def _jit_pass_onnx_fold_if(graph: Graph) -> None: ...
 def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
 def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
-def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue], isAllowedToAdjustGraphInputs: _bool) -> Dict[str, IValue]: ...
 def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ...
 def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
 def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.cpp b/torch/csrc/jit/passes/onnx/eval_peephole.cpp
index 05afb69..4bad936 100644
@@ -141,14 +141,27 @@ static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
   }
 }
 
-void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
+void EvalPeepholeONNX(
+    Block* b,
+    ParamMap& paramsDict,
+    bool isAllowedToAdjustGraphInputs) {
   auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
-  fuseConvBatchNorm(b, valsToParamsMap);
+
+  // Optimizations like fusing Conv and BatchNorm ops may adjust the graph
+  // inputs. If the graph inputs are not allowed to be adjusted, for example
+  // export_params is False, such optimizations will be skipped.
+  if (isAllowedToAdjustGraphInputs) {
+    fuseConvBatchNorm(b, valsToParamsMap);
+  }
+
   buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
 }
 
-void EvalPeepholeONNX(std::shared_ptr<Graph>& g, ParamMap& paramsDict) {
-  EvalPeepholeONNX(g->block(), paramsDict);
+void EvalPeepholeONNX(
+    std::shared_ptr<Graph>& g,
+    ParamMap& paramsDict,
+    bool isAllowedToAdjustGraphInputs) {
+  EvalPeepholeONNX(g->block(), paramsDict, isAllowedToAdjustGraphInputs);
   GRAPH_DUMP("After EvalPeepholeONNX:", g);
 }
 
diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.h b/torch/csrc/jit/passes/onnx/eval_peephole.h
index 6f8961d..d953f2c 100644
@@ -9,7 +9,8 @@ namespace jit {
 
 void EvalPeepholeONNX(
     std::shared_ptr<Graph>& g,
-    std::map<std::string, IValue>& paramDict);
+    std::map<std::string, IValue>& paramDict,
+    bool isAllowedToAdjustGraphInputs);
 
 } // namespace jit
 
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 992e60e..86b64b8 100644
@@ -203,8 +203,9 @@ void initJITBindings(PyObject* module) {
       .def(
           "_jit_pass_onnx_eval_peephole",
           [](std::shared_ptr<Graph>& graph,
-             std::map<std::string, IValue>& paramsDict) {
-            EvalPeepholeONNX(graph, paramsDict);
+             std::map<std::string, IValue>& paramsDict,
+             bool isAllowedToAdjustGraphInputs) {
+            EvalPeepholeONNX(graph, paramsDict, isAllowedToAdjustGraphInputs);
             return paramsDict;
           },
           pybind11::return_value_policy::move)
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index b726b2b..e058acc 100644
@@ -103,11 +103,17 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
         export_params (bool, default True): if True, all parameters will
             be exported. Set this to False if you want to export an untrained model.
             In this case, the exported model will first take all of its parameters
-            as arguments, with the ordering as specified by ``model.state_dict().values()``
+            as arguments, with the ordering as specified by ``model.state_dict().values()``.
+            This helps in stripping parameters from the model, which is useful for training.
+            In addition, if this is False, any optimization that may adjust graph inputs will
+            be skipped - for example, Conv and BatchNorm fusion.
         verbose (bool, default False): if True, prints a description of the
             model being exported to stdout.
         training (enum, default TrainingMode.EVAL):
-            * ``TrainingMode.EVAL``: export the model in inference mode.
+            * ``TrainingMode.EVAL``: export the model in inference mode. In this case, optimizations
+              (e.g., fusing Conv and BatchNorm ops) may adjust graph inputs by modifying model params
+              and model param names. Such adjustments can be skipped by setting export_params = False
+              or keep_initializers_as_inputs = True.
             * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is
               False and in training mode if model.training is True.
             * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations
@@ -184,6 +190,8 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
         do_constant_folding (bool, default False): Apply the constant-folding optimization.
             Constant-folding will replace some of the ops that have all constant inputs
             with pre-computed constant nodes.
+            Since this optimization adjusts model initializers, it will be disabled if
+            export_params = False or keep_initializers_as_inputs = True.
         example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None):
             Must be provided when exporting a ScriptModule or ScriptFunction, ignored otherwise.
             Used to determine the type and shape of the outputs without tracing the execution of
@@ -265,9 +273,13 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
 
         keep_initializers_as_inputs (bool, default None): If True, all the
             initializers (typically corresponding to parameters) in the
-            exported graph will also be added as inputs to the graph. If False,
-            then initializers are not added as inputs to the graph, and only
-            the non-parameter inputs are added as inputs.
+            exported graph will also be added as inputs to the graph.
+
+            If False, then initializers are not added as inputs to the graph, and only
+            the non-parameter inputs are added as inputs. In addition, any optimization
+            that might adjust graph inputs will be skipped (e.g., fusing Conv and
+            BatchNorm ops), even when the user exports the model in inference mode.
+
             This may allow for better optimizations (e.g. constant folding) by
             backends/runtimes.
 
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 41ba20f..7860e38 100644
@@ -439,7 +439,8 @@ def _model_to_graph(model, args, verbose=False,
                     example_outputs=None,
                     _retain_param_name=False, do_constant_folding=True,
                     _disable_torch_constant_prop=False, fixed_batch_size=False,
-                    training=None, dynamic_axes=None):
+                    training=None, dynamic_axes=None, export_params=True,
+                    keep_initializers_as_inputs=False):
     r"""Converts model into an ONNX graph.
 
     Returns:
@@ -498,10 +499,12 @@ def _model_to_graph(model, args, verbose=False,
 
     params_dict = _get_named_param_dict(graph, params)
 
-    if training is None or training == TrainingMode.EVAL:
-        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
+    allow_adjust_graph_inputs = (export_params and not keep_initializers_as_inputs)
+    if (training is None or training == TrainingMode.EVAL):
+        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict, allow_adjust_graph_inputs)
 
-    if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
+    if do_constant_folding and allow_adjust_graph_inputs and \
+            _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
         params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
                                                             _export_onnx_opset_version)
         torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
@@ -569,7 +572,9 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
                                                         output_names, operator_export_type,
                                                         example_outputs, _retain_param_name,
                                                         val_do_constant_folding, fixed_batch_size=fixed_batch_size,
-                                                        training=training)
+                                                        training=training,
+                                                        export_params=export_params,
+                                                        keep_initializers_as_inputs=val_keep_init_as_ip)
 
         return graph._pretty_print_onnx(params_dict, opset_version, False,
                                         operator_export_type, google_printer,
@@ -685,7 +690,9 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
                                 val_do_constant_folding,
                                 fixed_batch_size=fixed_batch_size,
                                 training=training,
-                                dynamic_axes=dynamic_axes)
+                                dynamic_axes=dynamic_axes,
+                                export_params=export_params,
+                                keep_initializers_as_inputs=val_keep_init_as_ip)
 
             # TODO: Don't allocate a in-memory string for the protobuf
             defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE