[mlir][NFC] GreedyPatternRewriteDriver: Consistent return values
authorMatthias Springer <springerm@google.com>
Mon, 16 Jan 2023 15:23:58 +0000 (16:23 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 16 Jan 2023 15:30:12 +0000 (16:30 +0100)
All `apply...` functions now return a LogicalResult indicating whether the iterative process converged or not.

Differential Revision: https://reviews.llvm.org/D141845

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 5478587..aaafebf 100644 (file)
@@ -80,6 +80,9 @@ inline LogicalResult applyPatternsAndFoldGreedily(
 /// success if no more patterns can be matched. `erased` is set to true if `op`
 /// was folded away or erased as a result of becoming dead. Note: This does not
 /// apply any patterns recursively to the regions of `op`.
+///
+/// Returns success if the iterative process converged and no more patterns can
+/// be matched.
 LogicalResult applyOpPatternsAndFold(Operation *op,
                                      const FrozenRewritePatternSet &patterns,
                                      bool *erased = nullptr);
@@ -93,10 +96,13 @@ LogicalResult applyOpPatternsAndFold(Operation *op,
 /// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a
 /// result of folding, becoming dead, or via pattern rewrites. If more far
 /// reaching simplification is desired, applyPatternsAndFoldGreedily should be
-/// used. Returns true if at all any IR was rewritten.
-bool applyOpPatternsAndFold(ArrayRef<Operation *> ops,
-                            const FrozenRewritePatternSet &patterns,
-                            bool strict);
+/// used.
+///
+/// Returns success if the iterative process converged and no more patterns can
+/// be matched. `changed` is set to true if the IR was modified at all.
+LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+                                     const FrozenRewritePatternSet &patterns,
+                                     bool strict, bool *changed = nullptr);
 
 } // namespace mlir
 
index 24ed10e..282a35b 100644 (file)
@@ -131,7 +131,12 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
                   SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   // Apply the simplification pattern to a fixpoint.
-  (void)applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true);
+  if (failed(
+          applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true))) {
+    auto diag = emitDefiniteFailure()
+                << "affine.min/max simplification did not converge";
+    return diag;
+  }
   return DiagnosedSilenceableFailure::success();
 }
 
index f0794f8..6bd3994 100644 (file)
@@ -574,7 +574,8 @@ public:
       : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
         strictMode(strict) {}
 
-  bool simplifyLocally(ArrayRef<Operation *> op);
+  LogicalResult simplifyLocally(ArrayRef<Operation *> op,
+                                bool *changed = nullptr);
 
   void addToWorklist(Operation *op) override {
     if (!strictMode || strictModeFilteredOps.contains(op))
@@ -625,13 +626,16 @@ private:
 // there is no strong rationale to re-add all operations into the worklist and
 // rerun until an iteration changes nothing. If more widereaching simplification
 // is desired, GreedyPatternRewriteDriver should be used.
-bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
+LogicalResult
+MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
+                                             bool *changed) {
   if (strictMode) {
     strictModeFilteredOps.clear();
     strictModeFilteredOps.insert(ops.begin(), ops.end());
   }
 
-  bool changed = false;
+  if (changed)
+    *changed = false;
   worklist.clear();
   worklistMap.clear();
   for (Operation *op : ops)
@@ -657,7 +661,8 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
     if (isOpTriviallyDead(op)) {
       notifyOperationRemoved(op);
       op->erase();
-      changed = true;
+      if (changed)
+        *changed = true;
       continue;
     }
 
@@ -687,7 +692,8 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
     bool inPlaceUpdate;
     if (succeeded(folder.tryToFold(op, processGeneratedConstants,
                                    preReplaceAction, &inPlaceUpdate))) {
-      changed = true;
+      if (changed)
+        *changed = true;
       if (!inPlaceUpdate) {
         // Op has been erased.
         continue;
@@ -698,12 +704,13 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
     // notified of any necessary changes, so there is nothing else to do
     // here.
     if (succeeded(matcher.matchAndRewrite(op, *this))) {
-      changed = true;
+      if (changed)
+        *changed = true;
       ++numRewrites;
     }
   }
 
-  return changed;
+  return success(worklist.empty());
 }
 
 /// Rewrites only `op` using the supplied canonicalization patterns and
@@ -726,14 +733,18 @@ LogicalResult mlir::applyOpPatternsAndFold(
   return converged;
 }
 
-bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
-                                  const FrozenRewritePatternSet &patterns,
-                                  bool strict) {
-  if (ops.empty())
-    return false;
+LogicalResult
+mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+                             const FrozenRewritePatternSet &patterns,
+                             bool strict, bool *changed) {
+  if (ops.empty()) {
+    if (changed)
+      *changed = false;
+    return success();
+  }
 
   // Start the pattern driver.
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
                                      strict);
-  return driver.simplifyLocally(ops);
+  return driver.simplifyLocally(ops, changed);
 }