[ONNX] Fix controlflow shape inference with contrib op (#60707) (#62762)
authorBowenBao <bowbao@microsoft.com>
Fri, 20 Aug 2021 19:44:29 +0000 (12:44 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 19:45:53 +0000 (12:45 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62762

`ONNXShapeTypeInference` for node `n` is skipped if `n` is non ONNX namespace, or if `n` contains any non ONNX namespace nodes. This prevents controlflow nodes containing contrib ops from running `SpecialPostProcess`, which sets up correct node output shape/type information in rare cases.

This PR depends on opset 14 export https://github.com/pytorch/pytorch/pull/59486

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D30375180

Pulled By: msaroufim

fbshipit-source-id: 5deacec39f091deb4d75ddd9e660e12fca7f16c5

Co-authored-by: BowenBao <bowbao@microsoft.com>
test/onnx/test_custom_ops.py
torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h
torch/csrc/jit/passes/onnx/shape_type_inference.cpp

index 739f267..04ac9a0 100644 (file)
@@ -125,5 +125,37 @@ class TestCustomAutogradFunction(unittest.TestCase):
         model = MyModule()
         run_model_test(self, model, input=(x, ))
 
+class TestExportAsContribOps(unittest.TestCase):
+    opset_version = 14
+    keep_initializers_as_inputs = False
+    onnx_shape_inference = True
+
+    def test_contrib_op_with_loop(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.gelu = torch.nn.GELU()
+
+            def forward(self, x):
+                res = []
+                res2 = []
+                for i in range(x.size(0)):
+                    if len(res) > 0:
+                        res2.append(res[0])
+                    else:
+                        res2.append(self.gelu(x[0]))
+                    res.append(x[0])
+                return torch.stack(res), torch.stack(res2)
+
+        def symbolic_custom_gelu(g, input):
+            return g.op("com.microsoft::Gelu", input).setType(input.type())
+
+        from torch.onnx import register_custom_op_symbolic
+        register_custom_op_symbolic("::gelu", symbolic_custom_gelu, 1)
+
+        x = torch.randn(3, 3, 4, requires_grad=True)
+        model = torch.jit.script(M())
+        run_model_test(self, model, input=(x, ))
+
 if __name__ == "__main__":
     unittest.main()
index abfb547..b0a310b 100644 (file)
@@ -238,9 +238,7 @@ std::vector<Value*> FixupONNXLoopNode(Node* node, int opset_version) {
   auto new_outputs = ConvertSequenceDependencies(node, opset_version);
 
   // Copy type of block output to node output.
-  for (size_t i = 0; i < node->outputs().size(); ++i) {
-    node->output(i)->setType(node->blocks().at(0)->outputs().at(i + 1)->type());
-  }
+  FixupONNXControlflowNodeOutputs(node);
   TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
   return new_outputs;
 }
@@ -347,25 +345,90 @@ void ONNXFixupUninitializedOutput(Node* node) {
           graph, else_block, else_block_output, then_block_output);
       if_node->outputs()[i]->setType(else_block->outputs()[i]->type());
     }
-    auto then_tensor_type =
-        then_block->outputs().at(i)->type()->castRaw<TensorType>();
-    auto else_tensor_type =
-        else_block->outputs().at(i)->type()->castRaw<TensorType>();
-    if (then_tensor_type && else_tensor_type) {
-      const auto& then_shape = then_tensor_type->symbolic_sizes();
-      const auto& else_shape = else_tensor_type->symbolic_sizes();
-      std::vector<::c10::ShapeSymbol> dims;
-      if (then_shape.rank() && else_shape.rank() &&
-          then_shape.rank() == else_shape.rank()) {
-        for (const auto j : c10::irange(then_shape.rank().value())) {
-          if (then_shape[j] == else_shape[j]) {
-            dims.emplace_back(then_shape[j]);
-          } else {
-            dims.emplace_back(::c10::ShapeSymbol::newSymbol());
-          }
+  }
+}
+
+void ONNXMergeIfBlockOutputShapes(Node* node) {
+  TORCH_INTERNAL_ASSERT(node->kind() == ::c10::onnx::If);
+  Block* then_block = node->blocks().at(0);
+  Block* else_block = node->blocks().at(1);
+
+  TORCH_INTERNAL_ASSERT(
+      then_block->outputs().size() == else_block->outputs().size())
+
+  auto findCommonShape =
+      [](const ::c10::SymbolicShape& a,
+         const ::c10::SymbolicShape& b) -> ::c10::SymbolicShape {
+    std::vector<::c10::ShapeSymbol> dims;
+    if (a.rank() && b.rank() && a.rank() == b.rank()) {
+      for (const auto j : c10::irange(a.rank().value())) {
+        if (a[j] == b[j]) {
+          dims.emplace_back(a[j]);
+        } else {
+          dims.emplace_back(::c10::ShapeSymbol::newSymbol());
         }
-        if_node->output(i)->setType(
-            then_tensor_type->withSymbolicShapes(::c10::SymbolicShape(dims)));
+      }
+      return ::c10::SymbolicShape(dims);
+    }
+    if (a.rank() && a.rank().value() > 0) {
+      return a;
+    }
+    if (b.rank() && b.rank().value() > 0) {
+      return b;
+    }
+
+    return ::c10::SymbolicShape();
+  };
+
+  auto mergeTensorType =
+      [&findCommonShape](TensorTypePtr a, TensorTypePtr b) -> TensorTypePtr {
+    if (a && b) {
+      const auto& a_shape = a->symbolic_sizes();
+      const auto& b_shape = b->symbolic_sizes();
+      auto commonShape = findCommonShape(a_shape, b_shape);
+      return a->withSymbolicShapes(commonShape);
+    } else if (a) {
+      return a;
+    } else if (b) {
+      return b;
+    }
+    return nullptr;
+  };
+
+  auto mergeListType = [&mergeTensorType](
+                           ListTypePtr a, ListTypePtr b) -> ListTypePtr {
+    if (a && b) {
+      auto a_tensor_type = a->getElementType()->cast<TensorType>();
+      auto b_tensor_type = b->getElementType()->cast<TensorType>();
+      auto tensor_type = mergeTensorType(a_tensor_type, b_tensor_type);
+      if (tensor_type) {
+        return a->withContained({tensor_type})->cast<ListType>();
+      }
+      // Both branches produce ListType without tensor shape.
+      return a;
+    } else if (a) {
+      return a;
+    } else if (b) {
+      return b;
+    }
+    return nullptr;
+  };
+
+  for (const auto i : c10::irange(else_block->outputs().size())) {
+    auto then_type = then_block->outputs().at(i)->type();
+    auto else_type = else_block->outputs().at(i)->type();
+    auto then_tensor_type = then_type->cast<TensorType>();
+    auto else_tensor_type = else_type->cast<TensorType>();
+    auto then_list_type = then_type->cast<ListType>();
+    auto else_list_type = else_type->cast<ListType>();
+    if (then_tensor_type || else_tensor_type) {
+      if (auto tensor_type =
+              mergeTensorType(then_tensor_type, else_tensor_type)) {
+        node->output(i)->setType(tensor_type);
+      }
+    } else if (then_list_type || else_list_type) {
+      if (auto list_type = mergeListType(then_list_type, else_list_type)) {
+        node->output(i)->setType(list_type);
       }
     }
   }
@@ -376,16 +439,13 @@ std::vector<Value*> FixupONNXIfNode(Node* node, int opset_version) {
     return node->outputs().vec();
   }
   GRAPH_DUMP("Graph before fixing controlflow: ", node->owningGraph());
-  auto* if_node = node;
   FixupONNXSubblockOutputs(node);
-  ONNXFixupUninitializedOutput(if_node);
+  ONNXFixupUninitializedOutput(node);
   // Copy type of block output to node output.
-  for (size_t i = 0; i < node->outputs().size(); ++i) {
-    node->output(i)->setType(node->blocks().at(0)->outputs().at(i)->type());
-  }
+  ONNXMergeIfBlockOutputShapes(node);
 
   GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph());
-  return if_node->outputs().vec();
+  return node->outputs().vec();
 }
 
 std::vector<Value*> FixupONNXControlflowNode(Node* n, int opset_version) {
@@ -401,5 +461,36 @@ std::vector<Value*> FixupONNXControlflowNode(Node* n, int opset_version) {
   }
 }
 
+void FixupONNXControlflowNodeOutputs(Node* n) {
+  switch (n->kind()) {
+    case ::c10::onnx::Loop: {
+      auto loop_carried_output_size = n->blocks().at(0)->inputs().size() - 2;
+      for (auto i : c10::irange(n->outputs().size())) {
+        auto type = n->blocks().at(0)->outputs().at(i + 1)->type();
+        if (i < loop_carried_output_size) {
+          n->output(i)->setType(type);
+        } else {
+          if (auto t_type = type->cast<TensorType>()) {
+            auto sizes = t_type->symbolic_sizes().sizes();
+            if (sizes.has_value()) {
+              sizes.value().emplace(
+                  sizes.value().begin(), c10::ShapeSymbol::newSymbol());
+              type = t_type->withSymbolicShapes(sizes.value());
+            }
+          }
+          n->output(i)->setType(type);
+        }
+      }
+      break;
+    }
+    case ::c10::onnx::If: {
+      ONNXMergeIfBlockOutputShapes(n);
+      break;
+    }
+    default:
+      break;
+  }
+}
+
 } // namespace jit
 } // namespace torch
index fad7611..8d33c2d 100644 (file)
@@ -6,6 +6,7 @@ namespace torch {
 namespace jit {
 
 std::vector<Value*> FixupONNXControlflowNode(Node* n, int opset_version);
+void FixupONNXControlflowNodeOutputs(Node* n);
 
 } // namespace jit
 } // namespace torch
index f630cf0..8ade722 100644 (file)
@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/onnx/constant_fold.h>
 #include <torch/csrc/jit/passes/onnx/constant_map.h>
+#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
 #include <torch/csrc/jit/passes/onnx/fold_if_node.h>
 #include <torch/csrc/jit/passes/onnx/helper.h>
 #include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
@@ -180,7 +181,21 @@ void UpdateTorchValueByOnnxValueInfo(
   }
 }
 
-bool IsSupportedNode(const Node* n) {
+bool IsValidONNXControlflowNode(const Node* n) {
+  // Skip when block size is zero. This is when the node is being created,
+  // and doesn't have subblocks attached yet. Run shape inference for these
+  // nodes later, when the subgraph has already completed shape inferencing.
+  auto node_kind = n->kind();
+  if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) {
+    if (n->blocks().size() == 0) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool IsValidONNXNode(const Node* n) {
   auto node_kind = n->kind();
 
   if (!node_kind.is_onnx()) {
@@ -188,18 +203,14 @@ bool IsSupportedNode(const Node* n) {
     return false;
   }
 
-  // Skip when block size is zero. This is when the node is first created,
-  // doesn't have subblocks attached yet. Run shape inference for these nodes
-  // when the subgraph has already completed shape inferencing.
-  if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) {
-    if (n->blocks().size() == 0) {
-      return false;
-    }
-    for (auto b : n->blocks()) {
-      for (auto b_n : b->nodes()) {
-        if (!IsSupportedNode(b_n)) {
-          return false;
-        }
+  if (!IsValidONNXControlflowNode(n)) {
+    return false;
+  }
+
+  for (auto b : n->blocks()) {
+    for (auto b_n : b->nodes()) {
+      if (!IsValidONNXNode(b_n)) {
+        return false;
       }
     }
   }
@@ -1310,6 +1321,20 @@ void SpecialPostProcess(Node* n) {
       }
       break;
     }
+    case ::c10::onnx::If: {
+      if (!IsValidONNXControlflowNode(n)) {
+        break;
+      }
+      FixupONNXControlflowNodeOutputs(n);
+      break;
+    }
+    case ::c10::onnx::Loop: {
+      if (!IsValidONNXControlflowNode(n)) {
+        break;
+      }
+      FixupONNXControlflowNodeOutputs(n);
+      break;
+    }
   }
 }
 
@@ -1391,64 +1416,67 @@ void ONNXShapeTypeInference(
     int opset_version) {
   GRAPH_UPDATE(
       "Running ONNX shape inference for node: ", n->kind().toDisplayString());
-  if (!IsSupportedNode(n)) {
-    return;
-  }
-  // Create a Graph containing only the single node n.
-  // This graph is later converted to ONNX to run shape inference.
-  auto n_graph = std::make_shared<Graph>();
-  auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version);
-  n_graph->insertNode(clone_node);
+  if (IsValidONNXNode(n)) {
+    // Create a Graph containing only the single node n.
+    // This graph is later converted to ONNX to run shape inference.
+    auto n_graph = std::make_shared<Graph>();
+    auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version);
+    n_graph->insertNode(clone_node);
 
-  // Register all node outputs as graph outputs.
-  for (auto output : clone_node->outputs()) {
-    n_graph->registerOutput(output);
-  }
+    // Register all node outputs as graph outputs.
+    for (auto output : clone_node->outputs()) {
+      n_graph->registerOutput(output);
+    }
 
-  // Use scalar_type_analysis without low precision cast
-  ScalarTypeAnalysisForONNX(n_graph, false, opset_version);
+    // Use scalar_type_analysis without low precision cast
+    ScalarTypeAnalysisForONNX(n_graph, false, opset_version);
 
-  GRAPH_DEBUG("Original torch graph: ", n->owningGraph()->toString());
-  GRAPH_DEBUG(
-      "Cloned torch graph to run shape inference: ", n_graph->toString());
-
-  if (IsGraphValidForInference(n_graph)) {
-    // TODO: Some ops have conversion happen at Peephole pass.
-    //       The conversion here is incomplete for these ops.
-    //       e.g: ListConstruct, ListUnpack, etc.
-    std::shared_ptr<onnx::ModelProto> model_proto;
-    SymbolDimMap symbol_map;
-    ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version);
+    GRAPH_DEBUG("Original torch graph: ", n->owningGraph()->toString());
     GRAPH_DEBUG(
-        "ONNX graph to run shape inference: ", prettyPrint(*model_proto));
-
-    // infer shape
-    try {
-      onnx::shape_inference::InferShapes(*model_proto);
-      UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map);
-    } catch (std::runtime_error& ex) {
-      // TODO: include this as warning once we have a more consolidated warning
-      // system.
+        "Cloned torch graph to run shape inference: ", n_graph->toString());
+
+    if (IsGraphValidForInference(n_graph)) {
+      // TODO: Some ops have conversion happen at Peephole pass.
+      //       The conversion here is incomplete for these ops.
+      //       e.g: ListConstruct, ListUnpack, etc.
+      std::shared_ptr<onnx::ModelProto> model_proto;
+      SymbolDimMap symbol_map;
+      ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version);
       GRAPH_DEBUG(
-          "ONNX shape inference fails with: ",
-          ex.what(),
-          " on graph: ",
-          n_graph->toString());
-      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-      const char shape_err[] = "ShapeInferenceError";
-      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-      const char type_err[] = "TypeInferenceError";
-      if ((strstr(ex.what(), shape_err) == nullptr) &&
-          (strstr(ex.what(), type_err) == nullptr)) {
-        throw;
+          "ONNX graph to run shape inference: ", prettyPrint(*model_proto));
+
+      // infer shape
+      try {
+        onnx::shape_inference::InferShapes(*model_proto);
+        UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map);
+      } catch (std::runtime_error& ex) {
+        // TODO: include this as warning once we have a more consolidated
+        // warning system.
+        GRAPH_DEBUG(
+            "ONNX shape inference fails with: ",
+            ex.what(),
+            " on graph: ",
+            n_graph->toString());
+        // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+        const char shape_err[] = "ShapeInferenceError";
+        // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+        const char type_err[] = "TypeInferenceError";
+        // NOLINTNEXTLINE(modernize-use-nullptr)
+        if ((strstr(ex.what(), shape_err) == NULL) &&
+            // NOLINTNEXTLINE(modernize-use-nullptr)
+            (strstr(ex.what(), type_err) == NULL)) {
+          throw;
+        }
       }
+      GRAPH_DEBUG(
+          "ONNX graph after shape inference: ", prettyPrint(*model_proto));
     }
-    GRAPH_DEBUG(
-        "ONNX graph after shape inference: ", prettyPrint(*model_proto));
   }
 
   SpecialPostProcess(n);
-  ProcessConstantValueMap(n, opset_version);
+  if (IsValidONNXNode(n)) {
+    ProcessConstantValueMap(n, opset_version);
+  }
   GRAPH_DEBUG(
       "Torch graph after shape inference:", n->owningGraph()->toString());
 }