[mlir] GreedyPatternRewriter: fix counting of iterations
authorMatthias Springer <springerm@google.com>
Tue, 10 Jan 2023 11:02:33 +0000 (12:02 +0100)
committerMatthias Springer <springerm@google.com>
Tue, 10 Jan 2023 11:21:08 +0000 (12:21 +0100)
The GreedyPatternRewriteDriver did previously not count the first iteration. I.e., when setting `config.maxIterations = 1`, two iterations were performed. In pratice, this number is not really important; we usually just need a limit in some reasonable order of magnitude. However, this fix allows us to write better convergence/worklist tests with carefully crafted test patterns to purposely trigger edge cases in the driver.

Similarly, the first rewrite was previously not counted towards `config.maxNumRewrites`.

For consistency, `OpPatternRewriteDriver` now uses `config.maxNumRewrites` instead of `config.maxIterations`; this driver does not have "iterations", it consists of a single loop (corresponding to the inner loop in the GreedyPatternRewriteDriver).

Differential Revision: https://reviews.llvm.org/D141365

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 0d6fdaf..5005a08 100644 (file)
@@ -96,10 +96,11 @@ protected:
   /// Non-pattern based folder for operations.
   OperationFolder folder;
 
-private:
+protected:
   /// Configuration information for how to simplify.
   GreedyRewriteConfig config;
 
+private:
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
   llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -147,8 +148,13 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
   };
 
   bool changed = false;
-  unsigned iteration = 0;
+  int64_t iteration = 0;
   do {
+    // Check if the iteration limit was reached.
+    if (iteration++ >= config.maxIterations &&
+        config.maxIterations != GreedyRewriteConfig::kNoLimit)
+      break;
+
     worklist.clear();
     worklistMap.clear();
 
@@ -184,7 +190,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
 
     changed = false;
     int64_t numRewrites = 0;
-    while (!worklist.empty()) {
+    while (!worklist.empty() &&
+           (numRewrites < config.maxNumRewrites ||
+            config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
       auto *op = popFromWorklist();
 
       // Nulls get added to the worklist when operations are removed, ignore
@@ -280,11 +288,10 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
 #else
       LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
 #endif
+
       if (succeeded(matchResult)) {
         changed = true;
-        if (numRewrites++ >= config.maxNumRewrites &&
-            config.maxNumRewrites != GreedyRewriteConfig::kNoLimit)
-          break;
+        ++numRewrites;
       }
     }
 
@@ -292,8 +299,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
     // is kept up to date.
     if (config.enableRegionSimplification)
       changed |= succeeded(simplifyRegions(*this, regions));
-  } while (changed && (iteration++ < config.maxIterations ||
-                       config.maxIterations == GreedyRewriteConfig::kNoLimit));
+  } while (changed);
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
@@ -421,7 +427,7 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
   GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
   bool converged = driver.simplify(regions);
   LLVM_DEBUG(if (!converged) {
-    llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+    llvm::dbgs() << "The pattern rewrite did not converge after scanning "
                  << config.maxIterations << " times\n";
   });
   return success(converged);
@@ -443,7 +449,8 @@ public:
     matcher.applyDefaultCostModel();
   }
 
-  LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
+  LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites,
+                                bool &erased);
 
   // These are hooks implemented for PatternRewriter.
 protected:
@@ -473,18 +480,22 @@ private:
 /// Performs the rewrites and folding only on `op`. The simplification
 /// converges if the op is erased as a result of being folded, replaced, or
 /// becoming dead, or no more changes happen in an iteration. Returns success if
-/// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
+/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op`
 /// gets erased.
 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
-                                                      int maxIterations,
+                                                      int64_t maxNumRewrites,
                                                       bool &erased) {
   bool changed = false;
   erased = false;
   opErasedViaPatternRewrites = false;
-  int iterations = 0;
-  // Iterate until convergence or until maxIterations. Deletion of the op as
+  int64_t numRewrites = 0;
+  // Iterate until convergence or until maxNumRewrites. Deletion of the op as
   // a result of being dead or folded is convergence.
   do {
+    if (numRewrites >= maxNumRewrites &&
+        maxNumRewrites != GreedyRewriteConfig::kNoLimit)
+      break;
+
     changed = false;
 
     // If the operation is trivially dead - remove it.
@@ -508,11 +519,13 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
 
     // Try to match one of the patterns. The rewriter is automatically
     // notified of any necessary changes, so there is nothing else to do here.
-    changed |= succeeded(matcher.matchAndRewrite(op, *this));
+    if (succeeded(matcher.matchAndRewrite(op, *this))) {
+      changed = true;
+      ++numRewrites;
+    }
     if ((erased = opErasedViaPatternRewrites))
       return success();
-  } while (changed && (++iterations < maxIterations ||
-                       maxIterations == GreedyRewriteConfig::kNoLimit));
+  } while (changed);
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return failure(changed);
@@ -601,7 +614,10 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
 
   // These are scratch vectors used in the folding loop below.
   SmallVector<Value, 8> originalOperands, resultValues;
-  while (!worklist.empty()) {
+  int64_t numRewrites = 0;
+  while (!worklist.empty() &&
+         (numRewrites < config.maxNumRewrites ||
+          config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
     Operation *op = popFromWorklist();
 
     // Nulls get added to the worklist when operations are removed, ignore
@@ -656,7 +672,10 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
     // Try to match one of the patterns. The rewriter is automatically
     // notified of any necessary changes, so there is nothing else to do
     // here.
-    changed |= succeeded(matcher.matchAndRewrite(op, *this));
+    if (succeeded(matcher.matchAndRewrite(op, *this))) {
+      changed = true;
+      ++numRewrites;
+    }
   }
 
   return changed;
@@ -672,12 +691,12 @@ LogicalResult mlir::applyOpPatternsAndFold(
   OpPatternRewriteDriver driver(op->getContext(), patterns);
   bool opErased;
   LogicalResult converged =
-      driver.simplifyLocally(op, config.maxIterations, opErased);
+      driver.simplifyLocally(op, config.maxNumRewrites, opErased);
   if (erased)
     *erased = opErased;
   LLVM_DEBUG(if (failed(converged)) {
-    llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
-                 << config.maxIterations << " times";
+    llvm::dbgs() << "The pattern rewrite did not converge after "
+                 << config.maxNumRewrites << " rewrites";
   });
   return converged;
 }