[mlir] Extend AffineForEmptyLoopFolder
authorAmy Zhuang <amy.zhuang@intel.com>
Wed, 9 Mar 2022 01:17:22 +0000 (17:17 -0800)
committerAmy Zhuang <amy.zhuang@intel.com>
Wed, 9 Mar 2022 01:17:22 +0000 (17:17 -0800)
Currently when we fold an empty loop, we assume that any loop
with iterArgs returns its iterArgs in order, which is not always
the case. It may return values defined outside of the loop or
return its iterArgs out of order. This patch adds support to
those cases.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D120776

mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/test/Dialect/Affine/canonicalize.mlir

index 90aa7ba..c295206 100644 (file)
@@ -1657,6 +1657,16 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
 }
 
 namespace {
+/// Returns constant trip count in trivial cases.
+static Optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
+  int64_t step = forOp.getStep();
+  if (!forOp.hasConstantBounds() || step <= 0)
+    return None;
+  int64_t lb = forOp.getConstantLowerBound();
+  int64_t ub = forOp.getConstantUpperBound();
+  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+}
+
 /// This is a pattern to fold trivially empty loop bodies.
 /// TODO: This should be moved into the folding hook.
 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
@@ -1667,8 +1677,46 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
     // Check that the body only contains a yield.
     if (!llvm::hasSingleElement(*forOp.getBody()))
       return failure();
-    // The initial values of the iteration arguments would be the op's results.
-    rewriter.replaceOp(forOp, forOp.getIterOperands());
+    if (forOp.getNumResults() == 0)
+      return success();
+    Optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
+    if (tripCount.hasValue() && tripCount.getValue() == 0) {
+      // The initial values of the iteration arguments would be the op's
+      // results.
+      rewriter.replaceOp(forOp, forOp.getIterOperands());
+      return success();
+    }
+    SmallVector<Value, 4> replacements;
+    auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
+    auto iterArgs = forOp.getRegionIterArgs();
+    bool hasValDefinedOutsideLoop = false;
+    bool iterArgsNotInOrder = false;
+    for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
+      Value val = yieldOp.getOperand(i);
+      auto iterArgIt = llvm::find(iterArgs, val);
+      if (iterArgIt == iterArgs.end()) {
+        // `val` is defined outside of the loop.
+        assert(forOp.isDefinedOutsideOfLoop(val) &&
+               "must be defined outside of the loop");
+        hasValDefinedOutsideLoop = true;
+        replacements.push_back(val);
+      } else {
+        unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
+        if (pos != i)
+          iterArgsNotInOrder = true;
+        replacements.push_back(forOp.getIterOperands()[pos]);
+      }
+    }
+    // Bail out when the trip count is unknown and the loop returns any value
+    // defined outside of the loop or any iterArg out of order.
+    if (!tripCount.hasValue() &&
+        (hasValDefinedOutsideLoop || iterArgsNotInOrder))
+      return failure();
+    // Bail out when the loop iterates more than once and it returns any iterArg
+    // out of order.
+    if (tripCount.hasValue() && tripCount.getValue() >= 2 && iterArgsNotInOrder)
+      return failure();
+    rewriter.replaceOp(forOp, replacements);
     return success();
   }
 };
@@ -1681,11 +1729,10 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 /// Returns true if the affine.for has zero iterations in trivial cases.
 static bool hasTrivialZeroTripCount(AffineForOp op) {
-  if (!op.hasConstantBounds())
-    return false;
-  int64_t lb = op.getConstantLowerBound();
-  int64_t ub = op.getConstantUpperBound();
-  return ub - lb <= 0;
+  Optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
+  if (tripCount.hasValue() && tripCount.getValue() == 0)
+    return true;
+  return false;
 }
 
 LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
index bed5f0c..cd6b08a 100644 (file)
@@ -475,6 +475,112 @@ func @fold_empty_loops() -> index {
 
 // -----
 
+// CHECK-LABEL:  func @fold_empty_loop()
+func @fold_empty_loop() -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %c2, %arg1 : index, index
+  }
+  // CHECK-DAG: %[[one:.*]] = arith.constant 1
+  // CHECK-DAG: %[[two:.*]] = arith.constant 2
+  // CHECK-NEXT: return %[[two]], %[[one]]
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @fold_empty_loops_trip_count_1()
+func @fold_empty_loops_trip_count_1() -> (index, index, index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %res1:2 = affine.for %i = 0 to 1 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+    affine.yield %c1, %arg0 : index, index
+  }
+  %res2:2 = affine.for %i = 0 to 2 step 3 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+    affine.yield %arg1, %arg0 : index, index
+  }
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+  // CHECK-DAG: %[[one:.*]] = arith.constant 1
+  // CHECK-DAG: %[[two:.*]] = arith.constant 2
+  // CHECK-NEXT: return %[[one]], %[[two]], %[[zero]], %[[two]]
+  return %res1#0, %res1#1, %res2#0, %res2#1 : index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @fold_empty_loop_trip_count_0()
+func @fold_empty_loop_trip_count_0() -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %res:2 = affine.for %i = 0 to 0 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+    affine.yield %c1, %arg0 : index, index
+  }
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+  // CHECK-DAG: %[[two:.*]] = arith.constant 2
+  // CHECK-NEXT: return %[[two]], %[[zero]]
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @fold_empty_loop_trip_count_unknown
+func @fold_empty_loop_trip_count_unknown(%in : index) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %arg0, %arg1 : index, index
+  }
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+  // CHECK-DAG: %[[one:.*]] = arith.constant 1
+  // CHECK-NEXT: return %[[zero]], %[[one]]
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @empty_loops_not_folded_1
+func @empty_loops_not_folded_1(%in : index) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: affine.for
+  %res = affine.for %i = 0 to %in iter_args(%arg = %c0) -> index {
+    affine.yield %c1 : index
+  }
+  return %res : index
+}
+
+// -----
+
+// CHECK-LABEL:  func @empty_loops_not_folded_2
+func @empty_loops_not_folded_2(%in : index) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: affine.for
+  %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %arg1, %arg0 : index, index
+  }
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @empty_loops_not_folded_3
+func @empty_loops_not_folded_3() -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: affine.for
+  %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %arg1, %arg0 : index, index
+  }
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
 // CHECK-LABEL:  func @fold_zero_iter_loops
 // CHECK-SAME: %[[ARG:.*]]: index
 func @fold_zero_iter_loops(%in : index) -> index {