From 6edfb628f9cc1008d6d0dd7719483458a324daa8 Mon Sep 17 00:00:00 2001 From: Amy Zhuang Date: Tue, 8 Mar 2022 17:17:22 -0800 Subject: [PATCH] [mlir] Extend AffineForEmptyLoopFolder Currently when we fold an empty loop, we assume that any loop with iterArgs returns its iterArgs in order, which is not always the case. It may return values defined outside of the loop or return its iterArgs out of order. This patch adds support to those cases. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D120776 --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 61 +++++++++++++++-- mlir/test/Dialect/Affine/canonicalize.mlir | 106 +++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 90aa7ba..c295206 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1657,6 +1657,16 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { } namespace { +/// Returns constant trip count in trivial cases. +static Optional getTrivialConstantTripCount(AffineForOp forOp) { + int64_t step = forOp.getStep(); + if (!forOp.hasConstantBounds() || step <= 0) + return None; + int64_t lb = forOp.getConstantLowerBound(); + int64_t ub = forOp.getConstantUpperBound(); + return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step; +} + /// This is a pattern to fold trivially empty loop bodies. /// TODO: This should be moved into the folding hook. struct AffineForEmptyLoopFolder : public OpRewritePattern { @@ -1667,8 +1677,46 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern { // Check that the body only contains a yield. if (!llvm::hasSingleElement(*forOp.getBody())) return failure(); - // The initial values of the iteration arguments would be the op's results. - rewriter.replaceOp(forOp, forOp.getIterOperands()); + if (forOp.getNumResults() == 0) + return success(); + Optional tripCount = getTrivialConstantTripCount(forOp); + if (tripCount.hasValue() && tripCount.getValue() == 0) { + // The initial values of the iteration arguments would be the op's + // results. + rewriter.replaceOp(forOp, forOp.getIterOperands()); + return success(); + } + SmallVector replacements; + auto yieldOp = cast(forOp.getBody()->getTerminator()); + auto iterArgs = forOp.getRegionIterArgs(); + bool hasValDefinedOutsideLoop = false; + bool iterArgsNotInOrder = false; + for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) { + Value val = yieldOp.getOperand(i); + auto iterArgIt = llvm::find(iterArgs, val); + if (iterArgIt == iterArgs.end()) { + // `val` is defined outside of the loop. + assert(forOp.isDefinedOutsideOfLoop(val) && + "must be defined outside of the loop"); + hasValDefinedOutsideLoop = true; + replacements.push_back(val); + } else { + unsigned pos = std::distance(iterArgs.begin(), iterArgIt); + if (pos != i) + iterArgsNotInOrder = true; + replacements.push_back(forOp.getIterOperands()[pos]); + } + } + // Bail out when the trip count is unknown and the loop returns any value + // defined outside of the loop or any iterArg out of order. + if (!tripCount.hasValue() && + (hasValDefinedOutsideLoop || iterArgsNotInOrder)) + return failure(); + // Bail out when the loop iterates more than once and it returns any iterArg + // out of order. + if (tripCount.hasValue() && tripCount.getValue() >= 2 && iterArgsNotInOrder) + return failure(); + rewriter.replaceOp(forOp, replacements); return success(); } }; @@ -1681,11 +1729,10 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results, /// Returns true if the affine.for has zero iterations in trivial cases. static bool hasTrivialZeroTripCount(AffineForOp op) { - if (!op.hasConstantBounds()) - return false; - int64_t lb = op.getConstantLowerBound(); - int64_t ub = op.getConstantUpperBound(); - return ub - lb <= 0; + Optional tripCount = getTrivialConstantTripCount(op); + if (tripCount.hasValue() && tripCount.getValue() == 0) + return true; + return false; } LogicalResult AffineForOp::fold(ArrayRef operands, diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index bed5f0c..cd6b08a 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -475,6 +475,112 @@ func @fold_empty_loops() -> index { // ----- +// CHECK-LABEL: func @fold_empty_loop() +func @fold_empty_loop() -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) { + affine.yield %c2, %arg1 : index, index + } + // CHECK-DAG: %[[one:.*]] = arith.constant 1 + // CHECK-DAG: %[[two:.*]] = arith.constant 2 + // CHECK-NEXT: return %[[two]], %[[one]] + return %res#0, %res#1 : index, index +} + +// ----- + +// CHECK-LABEL: func @fold_empty_loops_trip_count_1() +func @fold_empty_loops_trip_count_1() -> (index, index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %res1:2 = affine.for %i = 0 to 1 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) { + affine.yield %c1, %arg0 : index, index + } + %res2:2 = affine.for %i = 0 to 2 step 3 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) { + affine.yield %arg1, %arg0 : index, index + } + // CHECK-DAG: %[[zero:.*]] = arith.constant 0 + // CHECK-DAG: %[[one:.*]] = arith.constant 1 + // CHECK-DAG: %[[two:.*]] = arith.constant 2 + // CHECK-NEXT: return %[[one]], %[[two]], %[[zero]], %[[two]] + return %res1#0, %res1#1, %res2#0, %res2#1 : index, index, index, index +} + +// ----- + +// CHECK-LABEL: func @fold_empty_loop_trip_count_0() +func @fold_empty_loop_trip_count_0() -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %res:2 = affine.for %i = 0 to 0 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) { + affine.yield %c1, %arg0 : index, index + } + // CHECK-DAG: %[[zero:.*]] = arith.constant 0 + // CHECK-DAG: %[[two:.*]] = arith.constant 2 + // CHECK-NEXT: return %[[two]], %[[zero]] + return %res#0, %res#1 : index, index +} + +// ----- + +// CHECK-LABEL: func @fold_empty_loop_trip_count_unknown +func @fold_empty_loop_trip_count_unknown(%in : index) -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) { + affine.yield %arg0, %arg1 : index, index + } + // CHECK-DAG: %[[zero:.*]] = arith.constant 0 + // CHECK-DAG: %[[one:.*]] = arith.constant 1 + // CHECK-NEXT: return %[[zero]], %[[one]] + return %res#0, %res#1 : index, index +} + +// ----- + +// CHECK-LABEL: func @empty_loops_not_folded_1 +func @empty_loops_not_folded_1(%in : index) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: affine.for + %res = affine.for %i = 0 to %in iter_args(%arg = %c0) -> index { + affine.yield %c1 : index + } + return %res : index +} + +// ----- + +// CHECK-LABEL: func @empty_loops_not_folded_2 +func @empty_loops_not_folded_2(%in : index) -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: affine.for + %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) { + affine.yield %arg1, %arg0 : index, index + } + return %res#0, %res#1 : index, index +} + +// ----- + +// CHECK-LABEL: func @empty_loops_not_folded_3 +func @empty_loops_not_folded_3() -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: affine.for + %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) { + affine.yield %arg1, %arg0 : index, index + } + return %res#0, %res#1 : index, index +} + +// ----- + // CHECK-LABEL: func @fold_zero_iter_loops // CHECK-SAME: %[[ARG:.*]]: index func @fold_zero_iter_loops(%in : index) -> index { -- 2.7.4