Update 'applyPatternsGreedily' to work on the regions of any operations.
authorRiver Riddle <riverriddle@google.com>
Mon, 15 Jul 2019 16:52:52 +0000 (09:52 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 16 Jul 2019 20:44:39 +0000 (13:44 -0700)
'applyPatternsGreedily' is a useful utility outside of just function regions.

PiperOrigin-RevId: 258182937

mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 0e4d6ea3337f6c9f71c5d8c7db556c2005c8c871..97efae159797c75b196c4c60fdedceff6d228bbf 100644 (file)
@@ -22,7 +22,6 @@
 
 namespace mlir {
 
-class FuncOp;
 class PatternRewriter;
 
 //===----------------------------------------------------------------------===//
@@ -417,11 +416,13 @@ private:
   OwningRewritePatternList patterns;
 };
 
-/// Rewrite the specified function by repeatedly applying the highest benefit
-/// patterns in a greedy work-list driven manner. Return true if no more
-/// patterns can be matched in the result function.
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return true if no more patterns can be matched in
+/// the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself.
 ///
-bool applyPatternsGreedily(FuncOp fn, OwningRewritePatternList &&patterns);
+bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
 
 /// Helper class to create a list of rewrite patterns given a list of their
 /// types and a list of attributes perfect-forwarded to each of the conversion
index 33ae17d610c123617b49cd5b58e64a423f494a98..2e8ecfa5dab252a1682c94c4ffd31ad809540dda 100644 (file)
@@ -31,6 +31,7 @@ namespace mlir {
 
 // Forward declarations.
 class Block;
+class FuncOp;
 class MLIRContext;
 class Operation;
 class Type;
index c2f885ac1654a4cbd261ee4bbdf7ff6e2aab9179..52952178b37828594a2c1c0c2801ce25a762da15 100644 (file)
@@ -20,7 +20,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/StandardOps/Ops.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -35,8 +34,7 @@ using namespace mlir;
 
 static llvm::cl::opt<unsigned> maxPatternMatchIterations(
     "mlir-max-pattern-match-iterations",
-    llvm::cl::desc(
-        "Max number of iterations scanning the functions for pattern match"),
+    llvm::cl::desc("Max number of iterations scanning for pattern match"),
     llvm::cl::init(10));
 
 namespace {
@@ -53,7 +51,7 @@ public:
 
   /// Perform the rewrites. Return true if the rewrite converges in
   /// `maxIterations`.
-  bool simplifyFunction(Region *region, int maxIterations);
+  bool simplify(Operation *op, int maxIterations);
 
   void addToWorklist(Operation *op) {
     // Check to see if the worklist already contains this op.
@@ -135,8 +133,8 @@ private:
 
   /// The worklist for this transformation keeps track of the operations that
   /// need to be revisited, plus their index in the worklist.  This allows us to
-  /// efficiently remove operations from the worklist when they are erased from
-  /// the function, even if they aren't the root of a pattern.
+  /// efficiently remove operations from the worklist when they are erased, even
+  /// if they aren't the root of a pattern.
   std::vector<Operation *> worklist;
   DenseMap<Operation *, unsigned> worklistMap;
 
@@ -146,16 +144,16 @@ private:
 } // end anonymous namespace
 
 /// Perform the rewrites.
-bool GreedyPatternRewriteDriver::simplifyFunction(Region *region,
-                                                  int maxIterations) {
+bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) {
   // Add the given operation to the worklist.
   auto collectOps = [this](Operation *op) { addToWorklist(op); };
 
   bool changed = false;
   int i = 0;
   do {
-    // Add all operations to the worklist.
-    region->walk(collectOps);
+    // Add all nested operations to the worklist.
+    for (auto &region : op->getRegions())
+      region.walk(collectOps);
 
     // These are scratch vectors used in the folding loop below.
     SmallVector<Value *, 8> originalOperands, resultValues;
@@ -212,19 +210,25 @@ bool GreedyPatternRewriteDriver::simplifyFunction(Region *region,
   return !changed;
 }
 
-/// Rewrite the specified function by repeatedly applying the highest benefit
-/// patterns in a greedy work-list driven manner. Return true if no more
-/// patterns can be matched in the result function.
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return true if no more patterns can be matched in
+/// the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself.
 ///
-bool mlir::applyPatternsGreedily(FuncOp fn,
+bool mlir::applyPatternsGreedily(Operation *op,
                                  OwningRewritePatternList &&patterns) {
-  GreedyPatternRewriteDriver driver(fn.getContext(), std::move(patterns));
-  bool converged =
-      driver.simplifyFunction(&fn.getBody(), maxPatternMatchIterations);
+  // The top-level operation must be known to be isolated from above to
+  // prevent performing canonicalizations on operations defined at or above
+  // the region containing 'op'.
+  if (!op->isKnownIsolatedFromAbove())
+    return false;
+
+  GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns));
+  bool converged = driver.simplify(op, maxPatternMatchIterations);
   LLVM_DEBUG(if (!converged) {
-    llvm::dbgs()
-        << "The pattern rewrite doesn't converge after scanning the function "
-        << maxPatternMatchIterations << " times";
+    llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+                 << maxPatternMatchIterations << " times";
   });
   return converged;
 }