[MLIR][SCF] Inline single block ExecuteRegionOp
authorWilliam S. Moses <gh@wsmoses.com>
Thu, 24 Jun 2021 16:33:54 +0000 (12:33 -0400)
committerWilliam S. Moses <gh@wsmoses.com>
Thu, 24 Jun 2021 17:15:26 +0000 (13:15 -0400)
This commit adds a canonicalization pass which inlines any single block execute region

Differential Revision: https://reviews.llvm.org/D104865

mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir

index a558439..c10441f 100644 (file)
@@ -111,7 +111,7 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
   // TODO: If the parent is a func like op (which would be the case if all other
   // ops are from the std dialect), the inliner logic could be readily used to
   // inline.
-  let hasCanonicalizer = 0;
+  let hasCanonicalizer = 1;
 
   // TODO: can fold if it returns a constant.
   // TODO: Single block execute_region ops can be readily inlined irrespective
index 0d22e17..99d2386 100644 (file)
@@ -73,6 +73,19 @@ void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
 
+/// Replaces the given op with the contents of the given single-block region,
+/// using the operands of the block terminator to replace operation results.
+static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
+                                Region &region, ValueRange blockArgs = {}) {
+  assert(llvm::hasSingleElement(region) && "expected single-region block");
+  Block *block = &region.front();
+  Operation *terminator = block->getTerminator();
+  ValueRange results = terminator->getOperands();
+  rewriter.mergeBlockBefore(block, op, blockArgs);
+  rewriter.replaceOp(op, results);
+  rewriter.eraseOp(terminator);
+}
+
 ///
 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
 ///    block+
@@ -118,6 +131,37 @@ static LogicalResult verify(ExecuteRegionOp op) {
   return success();
 }
 
+// Inline an ExecuteRegionOp if it only contains one block.
+//     "test.foo"() : () -> ()
+//      %v = scf.execute_region -> i64 {
+//        %x = "test.val"() : () -> i64
+//        scf.yield %x : i64
+//      }
+//      "test.bar"(%v) : (i64) -> ()
+//
+//  becomes
+//
+//     "test.foo"() : () -> ()
+//     %x = "test.val"() : () -> i64
+//     "test.bar"(%v) : (i64) -> ()
+//
+struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.region().getBlocks().size() != 1)
+      return failure();
+    replaceOpWithRegion(rewriter, op, op.region());
+    return success();
+  }
+};
+
+void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<SingleBlockExecuteInliner>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
@@ -444,19 +488,6 @@ LoopNest mlir::scf::buildLoopNest(
                        });
 }
 
-/// Replaces the given op with the contents of the given single-block region,
-/// using the operands of the block terminator to replace operation results.
-static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
-                                Region &region, ValueRange blockArgs = {}) {
-  assert(llvm::hasSingleElement(region) && "expected single-region block");
-  Block *block = &region.front();
-  Operation *terminator = block->getTerminator();
-  ValueRange results = terminator->getOperands();
-  rewriter.mergeBlockBefore(block, op, blockArgs);
-  rewriter.replaceOp(op, results);
-  rewriter.eraseOp(terminator);
-}
-
 namespace {
 // Fold away ForOp iter arguments when:
 // 1) The op yields the iter arguments.
index 8b57bc8..8692f2d 100644 (file)
@@ -921,9 +921,30 @@ func @propagate_into_execute_region() {
     }
     "test.bar"(%v) : (i64) -> ()
     // CHECK:      %[[C2:.*]] = constant 2 : i64
-    // CHECK:        scf.execute_region -> i64 {
-    // CHECK-NEXT:     scf.yield %[[C2]] : i64
-    // CHECK-NEXT:   }
+    // CHECK: "test.foo"
+    // CHECK-NEXT: "test.bar"(%[[C2]]) : (i64) -> ()
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @execute_region_elim
+func @execute_region_elim() {
+  affine.for %i = 0 to 100 {
+    "test.foo"() : () -> ()
+    %v = scf.execute_region -> i64 {
+      %x = "test.val"() : () -> i64
+      scf.yield %x : i64
+    }
+    "test.bar"(%v) : (i64) -> ()
+  }
+  return
+}
+
+// CHECK-NEXT:     affine.for %arg0 = 0 to 100 {
+// CHECK-NEXT:       "test.foo"() : () -> ()
+// CHECK-NEXT:       %[[VAL:.*]] = "test.val"() : () -> i64
+// CHECK-NEXT:       "test.bar"(%[[VAL]]) : (i64) -> ()
+// CHECK-NEXT:     }
+