[mlir] canonicalize away zero-iteration SCF for loops
authorAlex Zinenko <zinenko@google.com>
Fri, 20 Nov 2020 18:22:30 +0000 (19:22 +0100)
committerAlex Zinenko <zinenko@google.com>
Mon, 23 Nov 2020 14:04:31 +0000 (15:04 +0100)
An SCF 'for' loop does not iterate if its lower bound is equal to its upper
bound. Remove loops where both bounds are the same SSA value as such bounds are
guaranteed to be equal. Similarly, remove 'parallel' loops where at least one
pair of respective lower/upper bounds is specified by the same SSA value.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D91880

mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir

index 5da9f7c..48b1b47 100644 (file)
@@ -521,6 +521,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
+    // If the upper bound is the same as the lower bound, the loop does not
+    // iterate, just remove it.
+    if (op.lowerBound() == op.upperBound()) {
+      rewriter.replaceOp(op, op.getIterOperands());
+      return success();
+    }
+
     auto lb = op.lowerBound().getDefiningOp<ConstantOp>();
     auto ub = op.upperBound().getDefiningOp<ConstantOp>();
     if (!lb || !ub)
@@ -1066,11 +1073,30 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
     return success();
   }
 };
+
+/// Removes parallel loops in which at least one lower/upper bound pair consists
+/// of the same values - such loops have an empty iteration domain.
+struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
+  using OpRewritePattern<ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) {
+      if (std::get<0>(dim) == std::get<1>(dim)) {
+        rewriter.replaceOp(op, op.initVals());
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
-  results.insert<CollapseSingleIterationLoops>(context);
+  results.insert<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
index faac86b..d575634 100644 (file)
@@ -32,30 +32,6 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
 
 // -----
 
-func @no_iteration(%A: memref<?x?xi32>) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  scf.parallel (%i0, %i1) = (%c0, %c0) to (%c1, %c0) step (%c1, %c1) {
-    %c42 = constant 42 : i32
-    store %c42, %A[%i0, %i1] : memref<?x?xi32>
-    scf.yield
-  }
-  return
-}
-
-// CHECK-LABEL:   func @no_iteration(
-// CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?xi32>) {
-// CHECK:           [[C0:%.*]] = constant 0 : index
-// CHECK:           [[C1:%.*]] = constant 1 : index
-// CHECK:           [[C42:%.*]] = constant 42 : i32
-// CHECK:           scf.parallel ([[V1:%.*]]) = ([[C0]]) to ([[C0]]) step ([[C1]]) {
-// CHECK:             store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V1]]] : memref<?x?xi32>
-// CHECK:             scf.yield
-// CHECK:           }
-// CHECK:           return
-
-// -----
-
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -241,6 +217,22 @@ func @remove_zero_iteration_loop() {
   return
 }
 
+// CHECK-LABEL: @remove_zero_iteration_loop_vals
+func @remove_zero_iteration_loop_vals(%arg0: index) {
+  %c2 = constant 2 : index
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> i32
+  // CHECK-NOT: scf.for
+  // CHECK-NOT: test.op
+  %0 = scf.for %i = %arg0 to %arg0 step %c2 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
+    scf.yield %1 : i32
+  }
+  // CHECK: "test.consume"(%[[INIT]])
+  "test.consume"(%0) : (i32) -> ()
+  return
+}
+
 // CHECK-LABEL: @replace_single_iteration_loop
 func @replace_single_iteration_loop() {
   // CHECK: %[[LB:.*]] = constant 42
@@ -278,3 +270,24 @@ func @replace_single_iteration_loop_non_unit_step() {
   "test.consume"(%0) : (i32) -> ()
   return
 }
+
+// CHECK-LABEL: @remove_empty_parallel_loop
+func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> f32
+  // CHECK-NOT: scf.parallel
+  // CHECK-NOT: test.produce
+  // CHECK-NOT: test.transform
+  %0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 {
+    %1 = "test.produce"() : () -> f32
+    scf.reduce(%1) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32
+      scf.reduce.return %2 : f32
+    }
+    scf.yield
+  }
+  // CHECK: "test.consume"(%[[INIT]])
+  "test.consume"(%0) : (f32) -> ()
+  return
+}