Support folding of ops with inner ops in GreedyPatternRewriteDriver.
authorAndy Ly <lyandy@google.com>
Mon, 26 Aug 2019 16:44:09 +0000 (09:44 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 26 Aug 2019 16:44:39 +0000 (09:44 -0700)
This fixes a bug when folding ops with inner ops and inner ops are still being visited.

PiperOrigin-RevId: 265475780

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Transforms/test-canonicalize.mlir
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestOps.td

index 3fb1998..ddb92a5 100644 (file)
@@ -96,8 +96,6 @@ protected:
   // worklist anymore because we'd get dangling references to it.
   void notifyOperationRemoved(Operation *op) override {
     addToWorklist(op->getOperands());
-    removeFromWorklist(op);
-    folder.notifyRemoval(op);
     op->walk([this](Operation *operation) {
       removeFromWorklist(operation);
       folder.notifyRemoval(operation);
@@ -174,17 +172,16 @@ bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) {
       // If the operation has no side effects, and no users, then it is
       // trivially dead - remove it.
       if (op->hasNoSideEffect() && op->use_empty()) {
-        // Be careful to update bookkeeping in OperationFolder to keep
-        // consistency if this is a constant op.
-        folder.notifyRemoval(op);
+        // Be careful to update bookkeeping.
+        notifyOperationRemoved(op);
         op->erase();
         continue;
       }
 
       // Collects all the operands and result uses of the given `op` into work
-      // list.
+      // list. Also remove `op` and nested ops from worklist.
       originalOperands.assign(op->operand_begin(), op->operand_end());
-      auto collectOperandsAndUses = [&](Operation *op) {
+      auto preReplaceAction = [&](Operation *op) {
         // Add the operands to the worklist for visitation.
         addToWorklist(originalOperands);
 
@@ -193,10 +190,12 @@ bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) {
         for (auto *result : op->getResults())
           for (auto *operand : result->getUsers())
             addToWorklist(operand);
+
+        notifyOperationRemoved(op);
       };
 
       // Try to fold this op.
-      if (succeeded(folder.tryToFold(op, collectOps, collectOperandsAndUses))) {
+      if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
         changed |= true;
         continue;
       }
index a98be27..7bae125 100644 (file)
@@ -1,10 +1,29 @@
 // RUN: mlir-opt %s -canonicalize | FileCheck %s
 
-// CHECK-LABEL: func @remove_op_with_inner_ops
-func @remove_op_with_inner_ops() {
+// CHECK-LABEL: func @remove_op_with_inner_ops_pattern
+func @remove_op_with_inner_ops_pattern() {
   // CHECK-NEXT: return
-  "test.op_with_region"() ({
+  "test.op_with_region_pattern"() ({
     "foo.op_with_region_terminator"() : () -> ()
   }) : () -> ()
   return
 }
+
+// CHECK-LABEL: func @remove_op_with_inner_ops_fold_no_side_effect
+func @remove_op_with_inner_ops_fold_no_side_effect() {
+  // CHECK-NEXT: return
+  "test.op_with_region_fold_no_side_effect"() ({
+    "foo.op_with_region_terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @remove_op_with_inner_ops_fold
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: i32)
+func @remove_op_with_inner_ops_fold(%arg0 : i32) -> (i32) {
+  // CHECK-NEXT: return %[[ARG_0]]
+  %0 = "test.op_with_region_fold"(%arg0) ({
+    "foo.op_with_region_terminator"() : () -> ()
+  }) : (i32) -> (i32)
+  return %0 : i32
+}
index 8b44b6c..af5c5c8 100644 (file)
@@ -82,10 +82,11 @@ static ParseResult parsePolyForOp(OpAsmParser *parser, OperationState *result) {
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct TestRemoveOpWithInnerOps : public OpRewritePattern<TestOpWithRegion> {
-  using OpRewritePattern<TestOpWithRegion>::OpRewritePattern;
+struct TestRemoveOpWithInnerOps
+    : public OpRewritePattern<TestOpWithRegionPattern> {
+  using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(TestOpWithRegion op,
+  PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op,
                                      PatternRewriter &rewriter) const override {
     rewriter.replaceOp(op, llvm::None);
     return matchSuccess();
@@ -93,11 +94,15 @@ struct TestRemoveOpWithInnerOps : public OpRewritePattern<TestOpWithRegion> {
 };
 } // end anonymous namespace
 
-void TestOpWithRegion::getCanonicalizationPatterns(
+void TestOpWithRegionPattern::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<TestRemoveOpWithInnerOps>(context);
 }
 
+OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
+  return operand();
+}
+
 // Static initialization for Test dialect registration.
 static mlir::DialectRegistration<mlir::TestDialect> testDialect;
 
index 2926930..fd2d235 100644 (file)
@@ -338,11 +338,25 @@ def : Pat<(OpAllAttrConstraint1
           (OpAllAttrConstraint2 $attr)>;
 
 // Op for testing RewritePattern removing op with inner ops.
-def TestOpWithRegion : TEST_Op<"op_with_region"> {
+def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> {
   let regions = (region SizedRegion<1>:$region);
   let hasCanonicalizer = 1;
 }
 
+// Op for testing trivial removal via folding of op with inner ops and no uses.
+def TestOpWithRegionFoldNoSideEffect : TEST_Op<
+    "op_with_region_fold_no_side_effect", [NoSideEffect]> {
+  let regions = (region SizedRegion<1>:$region);
+}
+
+// Op for testing folding of outer op with inner ops.
+def TestOpWithRegionFold : TEST_Op<"op_with_region_fold"> {
+  let arguments = (ins I32:$operand);
+  let results = (outs I32:$result);
+  let regions = (region SizedRegion<1>:$region);
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Symbol Binding)