From fefe655baafb9aa11ae3e2a34b19aef1f47e2b8d Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 16 Jan 2023 16:23:58 +0100 Subject: [PATCH] [mlir][NFC] GreedyPatternRewriteDriver: Consistent return values All `apply...` functions now return a LogicalResult indicating whether the iterative process converged or not. Differential Revision: https://reviews.llvm.org/D141845 --- .../mlir/Transforms/GreedyPatternRewriteDriver.h | 14 +++++--- .../Affine/TransformOps/AffineTransformOps.cpp | 7 +++- .../Utils/GreedyPatternRewriteDriver.cpp | 37 ++++++++++++++-------- 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index 5478587..aaafebf 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -80,6 +80,9 @@ inline LogicalResult applyPatternsAndFoldGreedily( /// success if no more patterns can be matched. `erased` is set to true if `op` /// was folded away or erased as a result of becoming dead. Note: This does not /// apply any patterns recursively to the regions of `op`. +/// +/// Returns success if the iterative process converged and no more patterns can +/// be matched. LogicalResult applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns, bool *erased = nullptr); @@ -93,10 +96,13 @@ LogicalResult applyOpPatternsAndFold(Operation *op, /// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a /// result of folding, becoming dead, or via pattern rewrites. If more far /// reaching simplification is desired, applyPatternsAndFoldGreedily should be -/// used. Returns true if at all any IR was rewritten. -bool applyOpPatternsAndFold(ArrayRef ops, - const FrozenRewritePatternSet &patterns, - bool strict); +/// used. +/// +/// Returns success if the iterative process converged and no more patterns can +/// be matched. `changed` is set to true if the IR was modified at all. +LogicalResult applyOpPatternsAndFold(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + bool strict, bool *changed = nullptr); } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index 24ed10e..282a35b 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -131,7 +131,12 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results, SimplifyAffineMinMaxOp>(getContext(), cstr); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); // Apply the simplification pattern to a fixpoint. - (void)applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true); + if (failed( + applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true))) { + auto diag = emitDefiniteFailure() + << "affine.min/max simplification did not converge"; + return diag; + } return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index f0794f8..6bd3994 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -574,7 +574,8 @@ public: : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()), strictMode(strict) {} - bool simplifyLocally(ArrayRef op); + LogicalResult simplifyLocally(ArrayRef op, + bool *changed = nullptr); void addToWorklist(Operation *op) override { if (!strictMode || strictModeFilteredOps.contains(op)) @@ -625,13 +626,16 @@ private: // there is no strong rationale to re-add all operations into the worklist and // rerun until an iteration changes nothing. If more widereaching simplification // is desired, GreedyPatternRewriteDriver should be used. -bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops) { +LogicalResult +MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, + bool *changed) { if (strictMode) { strictModeFilteredOps.clear(); strictModeFilteredOps.insert(ops.begin(), ops.end()); } - bool changed = false; + if (changed) + *changed = false; worklist.clear(); worklistMap.clear(); for (Operation *op : ops) @@ -657,7 +661,8 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops) { if (isOpTriviallyDead(op)) { notifyOperationRemoved(op); op->erase(); - changed = true; + if (changed) + *changed = true; continue; } @@ -687,7 +692,8 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops) { bool inPlaceUpdate; if (succeeded(folder.tryToFold(op, processGeneratedConstants, preReplaceAction, &inPlaceUpdate))) { - changed = true; + if (changed) + *changed = true; if (!inPlaceUpdate) { // Op has been erased. continue; @@ -698,12 +704,13 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops) { // notified of any necessary changes, so there is nothing else to do // here. if (succeeded(matcher.matchAndRewrite(op, *this))) { - changed = true; + if (changed) + *changed = true; ++numRewrites; } } - return changed; + return success(worklist.empty()); } /// Rewrites only `op` using the supplied canonicalization patterns and @@ -726,14 +733,18 @@ LogicalResult mlir::applyOpPatternsAndFold( return converged; } -bool mlir::applyOpPatternsAndFold(ArrayRef ops, - const FrozenRewritePatternSet &patterns, - bool strict) { - if (ops.empty()) - return false; +LogicalResult +mlir::applyOpPatternsAndFold(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + bool strict, bool *changed) { + if (ops.empty()) { + if (changed) + *changed = false; + return success(); + } // Start the pattern driver. MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, strict); - return driver.simplifyLocally(ops); + return driver.simplifyLocally(ops, changed); } -- 2.7.4