[mlir] Add an option to still use bottom-up traversal
authorAdrian Kuegel <akuegel@google.com>
Mon, 22 Mar 2021 08:42:57 +0000 (09:42 +0100)
committerAdrian Kuegel <akuegel@google.com>
Mon, 22 Mar 2021 08:49:44 +0000 (09:49 +0100)
GreedyPatternRewriteDriver was changed from bottom-up traversal to top-down traversal. Not all passes work yet with that change for traversal order. To give some time for fixing, add an option to allow to switch back to bottom-up traversal. Use this option in FusionOfTensorOpsPass which fails otherwise.

Differential Revision: https://reviews.llvm.org/D99059

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 4a084c5..aa06c02 100644 (file)
@@ -35,26 +35,26 @@ namespace mlir {
 ///       before attempting to match any of the provided patterns.
 LogicalResult
 applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternList &patterns);
+                             const FrozenRewritePatternList &patterns,
+                             bool useTopDownTraversal = true);
 
 /// Rewrite the regions of the specified operation, with a user-provided limit
 /// on iterations to attempt before reaching convergence.
-LogicalResult
-applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternList &patterns,
-                             unsigned maxIterations);
+LogicalResult applyPatternsAndFoldGreedily(
+    Operation *op, const FrozenRewritePatternList &patterns,
+    unsigned maxIterations, bool useTopDownTraversal = true);
 
 /// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternList &patterns);
+                             const FrozenRewritePatternList &patterns,
+                             bool useTopDownTraversal = true);
 
 /// Rewrite the given regions, with a user-provided limit on iterations to
 /// attempt before reaching convergence.
-LogicalResult
-applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternList &patterns,
-                             unsigned maxIterations);
+LogicalResult applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
+    unsigned maxIterations, bool useTopDownTraversal = true);
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
 /// by selecting the highest benefits patterns in a greedy manner. Returns
index a61102d..1e94dfd 100644 (file)
@@ -1115,7 +1115,8 @@ struct FusionOfTensorOpsPass
     Operation *op = getOperation();
     OwningRewritePatternList patterns(op->getContext());
     populateLinalgTensorOpsFusionPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
+    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
+                                       /*useTopDown=*/false);
   }
 };
 
index 38aa749..c4b5fe0 100644 (file)
@@ -37,8 +37,10 @@ namespace {
 class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
-                                      const FrozenRewritePatternList &patterns)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
+                                      const FrozenRewritePatternList &patterns,
+                                      bool useTopDownTraversal)
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx),
+        useTopDownTraversal(useTopDownTraversal) {
     worklist.reserve(64);
 
     // Apply a simple cost model based solely on pattern benefit.
@@ -134,6 +136,9 @@ private:
 
   /// Non-pattern based folder for operations.
   OperationFolder folder;
+
+  // Whether to use top-down or bottom-up traversal order.
+  bool useTopDownTraversal;
 };
 } // end anonymous namespace
 
@@ -153,14 +158,19 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
 
     // Add all nested operations to the worklist in preorder.
     for (auto &region : regions)
-      region.walk<WalkOrder::PreOrder>(
-          [this](Operation *op) { worklist.push_back(op); });
-
-    // Reverse the list so our pop-back loop processes them in-order.
-    std::reverse(worklist.begin(), worklist.end());
-    // Remember the reverse index.
-    for (unsigned i = 0, e = worklist.size(); i != e; ++i)
-      worklistMap[worklist[i]] = i;
+      if (useTopDownTraversal)
+        region.walk<WalkOrder::PreOrder>(
+            [this](Operation *op) { worklist.push_back(op); });
+      else
+        region.walk([this](Operation *op) { addToWorklist(op); });
+
+    if (useTopDownTraversal) {
+      // Reverse the list so our pop-back loop processes them in-order.
+      std::reverse(worklist.begin(), worklist.end());
+      // Remember the reverse index.
+      for (unsigned i = 0, e = worklist.size(); i != e; ++i)
+        worklistMap[worklist[i]] = i;
+    }
 
     // These are scratch vectors used in the folding loop below.
     SmallVector<Value, 8> originalOperands, resultValues;
@@ -232,27 +242,28 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
 ///
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(Operation *op,
-                                   const FrozenRewritePatternList &patterns) {
-  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations);
-}
-LogicalResult
-mlir::applyPatternsAndFoldGreedily(Operation *op,
                                    const FrozenRewritePatternList &patterns,
-                                   unsigned maxIterations) {
-  return applyPatternsAndFoldGreedily(op->getRegions(), patterns,
-                                      maxIterations);
+                                   bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
+                                      useTopDownTraversal);
 }
-/// Rewrite the given regions, which must be isolated from above.
-LogicalResult
-mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                                   const FrozenRewritePatternList &patterns) {
-  return applyPatternsAndFoldGreedily(regions, patterns,
-                                      maxPatternMatchIterations);
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    Operation *op, const FrozenRewritePatternList &patterns,
+    unsigned maxIterations, bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
+                                      useTopDownTraversal);
 }
+/// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                    const FrozenRewritePatternList &patterns,
-                                   unsigned maxIterations) {
+                                   bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(
+      regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
+}
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
+    unsigned maxIterations, bool useTopDownTraversal) {
   if (regions.empty())
     return success();
 
@@ -267,7 +278,8 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
          "patterns can only be applied to operations IsolatedFromAbove");
 
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
+  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
+                                    useTopDownTraversal);
   bool converged = driver.simplify(regions, maxIterations);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "