[mlir] Extend SimplifyTrivialLoops
authorAmy Zhuang <amy.zhuang@intel.com>
Thu, 17 Mar 2022 15:40:25 +0000 (08:40 -0700)
committerAmy Zhuang <amy.zhuang@intel.com>
Thu, 17 Mar 2022 16:11:24 +0000 (09:11 -0700)
Fold away empty loops that iterate at least once and only return
values defined outside of the loop.

Reviewed By: bondhugula, dcaballe

Differential Revision: https://reviews.llvm.org/D121148

mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir

index e025cf6..cd2b8a0 100644 (file)
@@ -713,8 +713,9 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
   }
 };
 
-/// Rewriting pattern that erases loops that are known not to iterate and
-/// replaces single-iteration loops with their bodies.
+/// Rewriting pattern that erases loops that are known not to iterate, replaces
+/// single-iteration loops with their bodies, and removes empty loops that
+/// iterate at least once and only return values defined outside of the loop.
 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
   using OpRewritePattern<ForOp>::OpRewritePattern;
 
@@ -756,7 +757,19 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
       return success();
     }
 
-    return failure();
+    // Now we are left with loops that have more than 1 iterations.
+    Block &block = op.getRegion().front();
+    if (!llvm::hasSingleElement(block))
+      return failure();
+    // If the loop is empty, iterates at least once, and only returns values
+    // defined outside of the loop, remove it and replace it with yield values.
+    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
+    auto yieldOperands = yieldOp.getOperands();
+    if (llvm::any_of(yieldOperands,
+                     [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
+      return failure();
+    rewriter.replaceOp(op, yieldOperands);
+    return success();
   }
 };
 
index 955f6bb..6edf9ff 100644 (file)
@@ -363,6 +363,26 @@ func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
 
 // -----
 
+// Test that an empty loop which iterates at least once and only returns
+// values defined outside of the loop is folded away.
+func @for_yields_4() -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %a = arith.constant 3 : i32
+  %b = arith.constant 4 : i32
+  %r = scf.for %i = %c0 to %c2 step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r : i32
+}
+
+// CHECK-LABEL:   func @for_yields_4
+//  CHECK-NEXT:     %[[b:.*]] = arith.constant 4 : i32
+//  CHECK-NEXT:     return %[[b]] : i32
+
+// -----
+
 // CHECK-LABEL: @replace_true_if
 func @replace_true_if() {
   %true = arith.constant true