Back out "[ONNX] Fix an issue that optimizations might adjust graph inputs unexpected...
author    Meghan Lele <meghanl@fb.com>
Thu, 26 Aug 2021 19:48:01 +0000 (12:48 -0700)
committer    Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 19:49:42 +0000 (12:49 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63904

Fixes T98808160

Test Plan: T98808160

Reviewed By: msaroufim

Differential Revision: D30527450

fbshipit-source-id: 6262901a78ca929cecda1cf740893139aa26f1b4
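
For context, this backout restores unconditional graph-input-adjusting optimizations (such as Conv+BatchNorm fusion) during eval-mode ONNX export, regardless of export_params or keep_initializers_as_inputs. A minimal sketch of the restored behavior, assuming the onnx package is available (the model, shapes, and buffer are arbitrary):

    import io
    import onnx
    import torch

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3),
        torch.nn.BatchNorm2d(8),
    ).eval()

    buf = io.BytesIO()
    torch.onnx.export(model, torch.randn(1, 3, 16, 16), buf)
    graph = onnx.load_from_string(buf.getvalue()).graph
    # With the fusion restored, BatchNorm folds into Conv, so a single
    # Conv node is expected in the exported graph.
    print([n.op_type for n in graph.node])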

test/onnx/expect/TestOperators.test_prelu.expect
test/onnx/expect/TestOperators.test_retain_param_name_disabled.expect
torch/_C/__init__.pyi.in
torch/csrc/jit/passes/onnx/eval_peephole.cpp
torch/csrc/jit/passes/onnx/eval_peephole.h
torch/csrc/jit/python/init.cpp
torch/onnx/__init__.py
torch/onnx/utils.py

diff --git a/test/onnx/expect/TestOperators.test_prelu.expect b/test/onnx/expect/TestOperators.test_prelu.expect
index be0328e..e19623c 100644
--- a/test/onnx/expect/TestOperators.test_prelu.expect
+++ b/test/onnx/expect/TestOperators.test_prelu.expect
@@ -3,29 +3,19 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
-    input: "weight"
-    output: "2"
-    name: "Unsqueeze_0"
-    op_type: "Unsqueeze"
-    attribute {
-      name: "axes"
-      ints: 1
-      ints: 2
-      type: INTS
-    }
-  }
-  node {
     input: "input"
-    input: "2"
+    input: "4"
     output: "3"
-    name: "PRelu_1"
+    name: "PRelu_0"
     op_type: "PRelu"
   }
   name: "torch-jit-export"
   initializer {
     dims: 2
+    dims: 1
+    dims: 1
     data_type: 1
-    name: "weight"
+    name: "4"
     raw_data: "\000\000\200>\000\000\200>"
   }
   input {
@@ -51,7 +41,7 @@ graph {
     }
   }
   input {
-    name: "weight"
+    name: "4"
     type {
       tensor_type {
         elem_type: 1
@@ -59,6 +49,12 @@ graph {
           dim {
             dim_value: 2
           }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
         }
       }
     }
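
The updated expect file reflects that folding: the standalone Unsqueeze on the PRelu weight is gone, and the weight is stored pre-reshaped as a 2x1x1 initializer under a fresh debug name ("4"). A hedged reconstruction of the kind of export behind this expect file (the actual test inputs may differ):

    import torch

    # Two-channel PReLU; its (2,)-shaped slope is exported pre-unsqueezed
    # to (2, 1, 1) so it broadcasts against NCHW input.
    torch.onnx.export(torch.nn.PReLU(num_parameters=2).eval(),
                      torch.randn(1, 2, 4, 4), "prelu.onnx")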
diff --git a/test/onnx/expect/TestOperators.test_retain_param_name_disabled.expect b/test/onnx/expect/TestOperators.test_retain_param_name_disabled.expect
index aa9499e..5eeaa87 100644
--- a/test/onnx/expect/TestOperators.test_retain_param_name_disabled.expect
+++ b/test/onnx/expect/TestOperators.test_retain_param_name_disabled.expect
@@ -3,56 +3,32 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
-    input: "1"
-    output: "3"
-    name: "Transpose_0"
-    op_type: "Transpose"
-    attribute {
-      name: "perm"
-      ints: 1
-      ints: 0
-      type: INTS
-    }
-  }
-  node {
     input: "input.1"
-    input: "3"
+    input: "7"
     output: "4"
-    name: "MatMul_1"
+    name: "MatMul_0"
     op_type: "MatMul"
   }
   node {
-    input: "2"
-    output: "5"
-    name: "Transpose_2"
-    op_type: "Transpose"
-    attribute {
-      name: "perm"
-      ints: 1
-      ints: 0
-      type: INTS
-    }
-  }
-  node {
     input: "4"
-    input: "5"
+    input: "8"
     output: "6"
-    name: "MatMul_3"
+    name: "MatMul_1"
     op_type: "MatMul"
   }
   name: "torch-jit-export"
   initializer {
-    dims: 5
     dims: 4
+    dims: 5
     data_type: 1
-    name: "1"
+    name: "7"
     raw_data: "\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@\000\000\000@"
   }
   initializer {
-    dims: 6
     dims: 5
+    dims: 6
     data_type: 1
-    name: "2"
+    name: "8"
     raw_data: "\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@"
   }
   input {
@@ -72,32 +48,32 @@ graph {
     }
   }
   input {
-    name: "1"
+    name: "7"
     type {
       tensor_type {
         elem_type: 1
         shape {
           dim {
-            dim_value: 5
+            dim_value: 4
           }
           dim {
-            dim_value: 4
+            dim_value: 5
           }
         }
       }
     }
   }
   input {
-    name: "2"
+    name: "8"
     type {
       tensor_type {
         elem_type: 1
         shape {
           dim {
-            dim_value: 6
+            dim_value: 5
           }
           dim {
-            dim_value: 5
+            dim_value: 6
           }
         }
       }
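
Likewise here: the two Transpose nodes are folded away, and the weight initializers are stored pre-transposed (5x4 -> 4x5, 6x5 -> 5x6) under fresh debug names ("7", "8"), since parameter-name retention is disabled in this test. The graph is consistent with two stacked bias-free linear layers, roughly as below (shapes inferred from the initializer dims; the real test module may differ):

    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(4, 5, bias=False),
        torch.nn.Linear(5, 6, bias=False),
    ).eval()
    torch.onnx.export(model, torch.randn(3, 4), "linears.onnx")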
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 0b6bb6b..3629150 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -326,7 +326,7 @@ def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
 def _jit_pass_onnx_fold_if(graph: Graph) -> None: ...
 def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
 def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
-def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue], isAllowedToAdjustGraphInputs: _bool) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
 def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ...
 def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
 def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.cpp b/torch/csrc/jit/passes/onnx/eval_peephole.cpp
index 4bad936..05afb69 100644
--- a/torch/csrc/jit/passes/onnx/eval_peephole.cpp
+++ b/torch/csrc/jit/passes/onnx/eval_peephole.cpp
@@ -141,27 +141,14 @@ static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
   }
 }
 
-void EvalPeepholeONNX(
-    Block* b,
-    ParamMap& paramsDict,
-    bool isAllowedToAdjustGraphInputs) {
+void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
   auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
-
-  // Optimizations like fusing Conv and BatchNorm ops may adjust the graph
-  // inputs. If the graph inputs are not allowed to be adjusted, for example
-  // export_params is False, such optimizations will be skipped.
-  if (isAllowedToAdjustGraphInputs) {
-    fuseConvBatchNorm(b, valsToParamsMap);
-  }
-
+  fuseConvBatchNorm(b, valsToParamsMap);
   buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
 }
 
-void EvalPeepholeONNX(
-    std::shared_ptr<Graph>& g,
-    ParamMap& paramsDict,
-    bool isAllowedToAdjustGraphInputs) {
-  EvalPeepholeONNX(g->block(), paramsDict, isAllowedToAdjustGraphInputs);
+void EvalPeepholeONNX(std::shared_ptr<Graph>& g, ParamMap& paramsDict) {
+  EvalPeepholeONNX(g->block(), paramsDict);
   GRAPH_DUMP("After EvalPeepholeONNX:", g);
 }
 
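For reference, fuseConvBatchNorm (again run unconditionally) folds the BatchNorm statistics and affine parameters into the convolution's weight and bias. A standalone sketch of the equivalent tensor arithmetic (hypothetical helper, not the pass itself):

    import torch

    def fuse_conv_bn(conv_w, conv_b, bn_mean, bn_var, bn_w, bn_b, eps=1e-5):
        # y = bn_w * (conv(x) - bn_mean) / sqrt(bn_var + eps) + bn_b,
        # so scaling the conv weight and bias by bn_w / sqrt(bn_var + eps)
        # and shifting the bias absorbs the BatchNorm entirely.
        scale = bn_w / torch.sqrt(bn_var + eps)
        fused_w = conv_w * scale.reshape(-1, 1, 1, 1)
        fused_b = (conv_b - bn_mean) * scale + bn_b
        return fused_w, fused_b
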
diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.h b/torch/csrc/jit/passes/onnx/eval_peephole.h
index d953f2c..6f8961d 100644
--- a/torch/csrc/jit/passes/onnx/eval_peephole.h
+++ b/torch/csrc/jit/passes/onnx/eval_peephole.h
@@ -9,8 +9,7 @@ namespace jit {
 
 void EvalPeepholeONNX(
     std::shared_ptr<Graph>& g,
-    std::map<std::string, IValue>& paramDict,
-    bool isAllowedToAdjustGraphInputs);
+    std::map<std::string, IValue>& paramDict);
 
 } // namespace jit
 
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 645fea2..7e43e51 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -203,9 +203,8 @@ void initJITBindings(PyObject* module) {
       .def(
           "_jit_pass_onnx_eval_peephole",
           [](std::shared_ptr<Graph>& graph,
-             std::map<std::string, IValue>& paramsDict,
-             bool isAllowedToAdjustGraphInputs) {
-            EvalPeepholeONNX(graph, paramsDict, isAllowedToAdjustGraphInputs);
+             std::map<std::string, IValue>& paramsDict) {
+            EvalPeepholeONNX(graph, paramsDict);
             return paramsDict;
           },
           pybind11::return_value_policy::move)
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index e058acc..b726b2b 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -103,17 +103,11 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
         export_params (bool, default True): if True, all parameters will
             be exported. Set this to False if you want to export an untrained model.
             In this case, the exported model will first take all of its parameters
-            as arguments, with the ordering as specified by ``model.state_dict().values()``.
-            This helps in stripping parameters from the model which is useful for training.
-            Besides, if this is False, any optimization that may adjust graph inputs will
-            be skipped - for example, Conv and BatchNorm fusion.
+            as arguments, with the ordering as specified by ``model.state_dict().values()``
         verbose (bool, default False): if True, prints a description of the
             model being exported to stdout.
         training (enum, default TrainingMode.EVAL):
-            * ``TrainingMode.EVAL``: export the model in inference mode. In this case, optimizations
-              (e.g., fusing Conv and BatchNorm ops) may adjust graph inputs by modifying model params
-              and model param names. Such adjustment could be skipped by setting export_params = False
-              or keep_initializers_as_inputs = True.
+            * ``TrainingMode.EVAL``: export the model in inference mode.
             * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is
               False and in training mode if model.training is True.
             * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations
@@ -190,8 +184,6 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
         do_constant_folding (bool, default False): Apply the constant-folding optimization.
             Constant-folding will replace some of the ops that have all constant inputs
             with pre-computed constant nodes.
-            Since this optimization adjusts model initializers, it will be disabled if
-            export_params = False or keep_initializers_as_inputs = True.
         example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None):
             Must be provided when exporting a ScriptModule or ScriptFunction, ignored otherwise.
             Used to determine the type and shape of the outputs without tracing the execution of
@@ -273,13 +265,9 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
 
         keep_initializers_as_inputs (bool, default None): If True, all the
             initializers (typically corresponding to parameters) in the
-            exported graph will also be added as inputs to the graph.
-
-            If False, then initializers are not added as inputs to the graph, and only
-            the non-parameter inputs are added as inputs. Meanwhile, the optimization
-            that might adjust graph inputs will be skipped (e.g., fusing Conv and
-            BatchNorm ops), even when the user export this model in inference mode.
-
+            exported graph will also be added as inputs to the graph. If False,
+            then initializers are not added as inputs to the graph, and only
+            the non-parameter inputs are added as inputs.
             This may allow for better optimizations (e.g. constant folding) by
             backends/runtimes.
 
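With the docstring reverted, neither export_params nor keep_initializers_as_inputs gates the eval-time optimizations any more. A typical call consistent with the restored documentation (model and file name are placeholders):

    import torch

    torch.onnx.export(
        torch.nn.Linear(4, 5).eval(), torch.randn(1, 4), "linear.onnx",
        export_params=True,                 # parameters exported as initializers
        do_constant_folding=True,           # no longer skipped based on the other flags
        keep_initializers_as_inputs=False,  # initializers omitted from graph inputs
    )
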
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 7860e38..41ba20f 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -439,8 +439,7 @@ def _model_to_graph(model, args, verbose=False,
                     example_outputs=None,
                     _retain_param_name=False, do_constant_folding=True,
                     _disable_torch_constant_prop=False, fixed_batch_size=False,
-                    training=None, dynamic_axes=None, export_params=True,
-                    keep_initializers_as_inputs=False):
+                    training=None, dynamic_axes=None):
     r"""Converts model into an ONNX graph.
 
     Returns:
@@ -499,12 +498,10 @@ def _model_to_graph(model, args, verbose=False,
 
     params_dict = _get_named_param_dict(graph, params)
 
-    allow_adjust_graph_inputs = (export_params and not keep_initializers_as_inputs)
-    if (training is None or training == TrainingMode.EVAL):
-        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict, allow_adjust_graph_inputs)
+    if training is None or training == TrainingMode.EVAL:
+        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
 
-    if do_constant_folding and allow_adjust_graph_inputs and \
-            _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
+    if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
         params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
                                                             _export_onnx_opset_version)
         torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
@@ -572,9 +569,7 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
                                                         output_names, operator_export_type,
                                                         example_outputs, _retain_param_name,
                                                         val_do_constant_folding, fixed_batch_size=fixed_batch_size,
-                                                        training=training,
-                                                        export_params=export_params,
-                                                        keep_initializers_as_inputs=val_keep_init_as_ip)
+                                                        training=training)
 
         return graph._pretty_print_onnx(params_dict, opset_version, False,
                                         operator_export_type, google_printer,
@@ -690,9 +685,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
                                 val_do_constant_folding,
                                 fixed_batch_size=fixed_batch_size,
                                 training=training,
-                                dynamic_axes=dynamic_axes,
-                                export_params=export_params,
-                                keep_initializers_as_inputs=val_keep_init_as_ip)
+                                dynamic_axes=dynamic_axes)
 
             # TODO: Don't allocate a in-memory string for the protobuf
             defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
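
After the backout, the internal _model_to_graph helper again takes neither export_params nor keep_initializers_as_inputs. A hedged sketch of a direct call, treating both the helper and its (graph, params_dict, torch_out) return as private details of this file that may change:

    import torch
    from torch.onnx import TrainingMode
    from torch.onnx.utils import _model_to_graph

    # Private helper; signature per this commit, subject to change.
    graph, params_dict, torch_out = _model_to_graph(
        torch.nn.Linear(4, 5).eval(), (torch.randn(1, 4),),
        do_constant_folding=True, training=TrainingMode.EVAL)
    print(graph)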