}
namespace {
+/// Returns the constant trip count of `forOp` in trivial cases, i.e. when both
+/// bounds are constant and the step is positive; returns None otherwise.
+/// The count is ceil((ub - lb) / step), or 0 when ub <= lb.
+static Optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
+  int64_t step = forOp.getStep();
+  if (!forOp.hasConstantBounds() || step <= 0)
+    return None;
+  int64_t lb = forOp.getConstantLowerBound();
+  int64_t ub = forOp.getConstantUpperBound();
+  // Compare the bounds directly: `ub - lb` is signed arithmetic and overflows
+  // (undefined behavior) when the bounds are far apart.
+  if (ub <= lb)
+    return uint64_t(0);
+  // ub > lb here, so the distance fits in uint64_t even when it exceeds
+  // INT64_MAX; unsigned subtraction yields it without overflow.
+  uint64_t range = static_cast<uint64_t>(ub) - static_cast<uint64_t>(lb);
+  // ceil(range / step), written so that the numerator cannot wrap around.
+  return (range - 1) / static_cast<uint64_t>(step) + 1;
+}
+
/// This is a pattern to fold trivially empty loop bodies.
/// TODO: This should be moved into the folding hook.
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
// Check that the body only contains a yield.
if (!llvm::hasSingleElement(*forOp.getBody()))
return failure();
- // The initial values of the iteration arguments would be the op's results.
- rewriter.replaceOp(forOp, forOp.getIterOperands());
+ if (forOp.getNumResults() == 0)
+ return success();
+ Optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
+ if (tripCount.hasValue() && tripCount.getValue() == 0) {
+ // The initial values of the iteration arguments would be the op's
+ // results.
+ rewriter.replaceOp(forOp, forOp.getIterOperands());
+ return success();
+ }
+ SmallVector<Value, 4> replacements;
+ auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
+ auto iterArgs = forOp.getRegionIterArgs();
+ bool hasValDefinedOutsideLoop = false;
+ bool iterArgsNotInOrder = false;
+ for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
+ Value val = yieldOp.getOperand(i);
+ auto iterArgIt = llvm::find(iterArgs, val);
+ if (iterArgIt == iterArgs.end()) {
+ // `val` is defined outside of the loop.
+ assert(forOp.isDefinedOutsideOfLoop(val) &&
+ "must be defined outside of the loop");
+ hasValDefinedOutsideLoop = true;
+ replacements.push_back(val);
+ } else {
+ unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
+ if (pos != i)
+ iterArgsNotInOrder = true;
+ replacements.push_back(forOp.getIterOperands()[pos]);
+ }
+ }
+ // Bail out when the trip count is unknown and the loop returns any value
+ // defined outside of the loop or any iterArg out of order.
+ if (!tripCount.hasValue() &&
+ (hasValDefinedOutsideLoop || iterArgsNotInOrder))
+ return failure();
+ // Bail out when the loop iterates more than once and it returns any iterArg
+ // out of order.
+ if (tripCount.hasValue() && tripCount.getValue() >= 2 && iterArgsNotInOrder)
+ return failure();
+ rewriter.replaceOp(forOp, replacements);
return success();
}
};
/// Returns true if the affine.for has zero iterations in trivial cases.
static bool hasTrivialZeroTripCount(AffineForOp op) {
- if (!op.hasConstantBounds())
- return false;
- int64_t lb = op.getConstantLowerBound();
- int64_t ub = op.getConstantUpperBound();
- return ub - lb <= 0;
+  // Zero iterations can only be claimed when the trip count is a known
+  // constant; an unknown trip count conservatively yields false.
+  Optional<uint64_t> maybeTripCount = getTrivialConstantTripCount(op);
+  return maybeTripCount.hasValue() && maybeTripCount.getValue() == 0;
}
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
// -----
// Empty loop with a constant trip count of 10. The first result is yielded
// from a value defined outside of the loop (%c2) and the second from %arg1 in
// its own position, so the results are expected to fold to %c2 and to the init
// operand of %arg1 (%c1).
+// CHECK-LABEL: func @fold_empty_loop()
+func @fold_empty_loop() -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+ affine.yield %c2, %arg1 : index, index
+ }
+ // CHECK-DAG: %[[one:.*]] = arith.constant 1
+ // CHECK-DAG: %[[two:.*]] = arith.constant 2
+ // CHECK-NEXT: return %[[two]], %[[one]]
+ return %res#0, %res#1 : index, index
+}
+
+// -----
+
// Empty loops that run exactly once. With a single iteration a yielded
// iter_arg may sit in any position: each result is expected to become the init
// operand of the iter_arg yielded there, or the yielded outside value itself.
+// CHECK-LABEL: func @fold_empty_loops_trip_count_1()
+func @fold_empty_loops_trip_count_1() -> (index, index, index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %res1:2 = affine.for %i = 0 to 1 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+ affine.yield %c1, %arg0 : index, index
+ }
// Bounds 0 to 2 with step 3 also give a single iteration; here the two
// iter_args are yielded in swapped order.
+ %res2:2 = affine.for %i = 0 to 2 step 3 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+ affine.yield %arg1, %arg0 : index, index
+ }
+ // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+ // CHECK-DAG: %[[one:.*]] = arith.constant 1
+ // CHECK-DAG: %[[two:.*]] = arith.constant 2
+ // CHECK-NEXT: return %[[one]], %[[two]], %[[zero]], %[[two]]
+ return %res1#0, %res1#1, %res2#0, %res2#1 : index, index, index, index
+}
+
+// -----
+
// Empty loop that never runs (bounds 0 to 0): the results are expected to be
// the init operands (%c2, %c0), regardless of what the body yields.
+// CHECK-LABEL: func @fold_empty_loop_trip_count_0()
+func @fold_empty_loop_trip_count_0() -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %res:2 = affine.for %i = 0 to 0 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+ affine.yield %c1, %arg0 : index, index
+ }
+ // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+ // CHECK-DAG: %[[two:.*]] = arith.constant 2
+ // CHECK-NEXT: return %[[two]], %[[zero]]
+ return %res#0, %res#1 : index, index
+}
+
+// -----
+
// Empty loop with a non-constant upper bound (%in). It is still expected to
// fold because every iter_arg is yielded in its own position, so the results
// equal the init operands for any number of iterations.
+// CHECK-LABEL: func @fold_empty_loop_trip_count_unknown
+func @fold_empty_loop_trip_count_unknown(%in : index) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+ affine.yield %arg0, %arg1 : index, index
+ }
+ // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+ // CHECK-DAG: %[[one:.*]] = arith.constant 1
+ // CHECK-NEXT: return %[[zero]], %[[one]]
+ return %res#0, %res#1 : index, index
+}
+
+// -----
+
// Negative test: the trip count is unknown (%in) and the yielded value %c1 is
// defined outside of the loop. The result is %c0 after zero iterations and %c1
// otherwise, so the loop is expected to stay.
+// CHECK-LABEL: func @empty_loops_not_folded_1
+func @empty_loops_not_folded_1(%in : index) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: affine.for
+ %res = affine.for %i = 0 to %in iter_args(%arg = %c0) -> index {
+ affine.yield %c1 : index
+ }
+ return %res : index
+}
+
+// -----
+
// Negative test: the trip count is unknown (%in) and the iter_args are yielded
// in swapped order, so the results depend on how many times the loop runs and
// the loop is expected to stay.
+// CHECK-LABEL: func @empty_loops_not_folded_2
+func @empty_loops_not_folded_2(%in : index) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: affine.for
+ %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+ affine.yield %arg1, %arg0 : index, index
+ }
+ return %res#0, %res#1 : index, index
+}
+
+// -----
+
// Negative test: the trip count is a known constant (10) but larger than one,
// and the iter_args are yielded in swapped order; the loop is expected to
// stay.
+// CHECK-LABEL: func @empty_loops_not_folded_3
+func @empty_loops_not_folded_3() -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: affine.for
+ %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+ affine.yield %arg1, %arg0 : index, index
+ }
+ return %res#0, %res#1 : index, index
+}
+
+// -----
+
// CHECK-LABEL: func @fold_zero_iter_loops
// CHECK-SAME: %[[ARG:.*]]: index
func @fold_zero_iter_loops(%in : index) -> index {