[mlir] GreedyPatternRewriter: Add ancestors to worklist
authorMatthias Springer <springerm@google.com>
Fri, 13 Jan 2023 09:42:01 +0000 (10:42 +0100)
committerMatthias Springer <springerm@google.com>
Fri, 13 Jan 2023 09:51:28 +0000 (10:51 +0100)
When adding an op to the worklist, also add its ancestors to the worklist. This allows for RewritePatterns to match an op `a` based on what is inside of the body of `a`.

This change fixes a problem that became apparent with `vector.warp_execute_on_lane_0`, but could probably be triggered with similar patterns. The pattern extracts an op `b` with `eligible = true` from the body of an op `a`:
```
test.a {
  %0 = test.b() {eligible = true}
  yield %0
}
```

Afterwards:
```
%0 = test.b() {eligible = true}
test.a {
  yield %0
}
```

The pattern is an `OpRewritePattern<OpA>`. For some reason, `test.a` is not on the GreedyPatternRewriter's worklist. E.g., because no pattern could be applied and it was removed. Now, another pattern updates `test.b`, so that `eligible` is changed from `true` to `false`. The `OpRewritePattern<OpA>` could now be applied, but (without this revision) `test.a` is still not on the worklist.

Note: In the above example, an `OpRewritePattern<OpB>` could have been used instead of an `OpRewritePattern<OpA>`. With such a design, we can run into the same problem (when the `eligible` attr is on `test.a` and `test.b` is removed from the worklist because no patterns could be applied).

Note: This change uncovered an unrelated bug in TestSCFUtils.cpp that was triggered due to a change in the order in which ops are processed. A TODO is added to the broken code and test cases are adapted so that the bug is no longer triggered.

Differential Revision: https://reviews.llvm.org/D140304

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/Dialect/SCF/loop-pipelining.mlir
mlir/test/IR/greedy-pattern-rewriter-driver.mlir
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp

index cdb0b78..2cf895e 100644 (file)
@@ -42,7 +42,9 @@ public:
   /// Simplify the operations within the given regions.
   bool simplify(MutableArrayRef<Region> regions);
 
-  /// Add the given operation to the worklist.
+  /// Add the given operation to the worklist. Parent ops may or may not be
+  /// added to the worklist, depending on the type of rewrite driver. By
+  /// default, parent ops are added.
   virtual void addToWorklist(Operation *op);
 
   /// Pop the next operation from the worklist.
@@ -56,6 +58,9 @@ public:
   void finalizeRootUpdate(Operation *op) override;
 
 protected:
+  /// Add the given operation to the worklist.
+  void addSingleOpToWorklist(Operation *op);
+
   // Implement the hook for inserting operations, and make sure that newly
   // inserted ops are added to the worklist for processing.
   void notifyOperationInserted(Operation *op) override;
@@ -101,6 +106,10 @@ protected:
   GreedyRewriteConfig config;
 
 private:
+  /// Only ops within this scope are simplified. This is set at the beginning
+  /// of `simplify()` to the current scope the rewriter operates on.
+  DenseSet<Region *> scope;
+
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
   llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -119,6 +128,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 }
 
 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
+  for (Region &r : regions)
+    scope.insert(&r);
+
 #ifndef NDEBUG
   const char *logLineComment =
       "//===-------------------------------------------===//\n";
@@ -306,6 +318,24 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
 }
 
 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
+  // Gather potential ancestors while looking for a "scope" parent region.
+  SmallVector<Operation *, 8> ancestors;
+  ancestors.push_back(op);
+  while (Region *region = op->getParentRegion()) {
+    if (scope.contains(region)) {
+      // All gathered ops are in fact ancestors.
+      for (Operation *op : ancestors)
+        addSingleOpToWorklist(op);
+      break;
+    }
+    op = region->getParentOp();
+    if (!op)
+      break;
+    ancestors.push_back(op);
+  }
+}
+
+void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
   // Check to see if the worklist already contains this op.
   if (worklistMap.count(op))
     return;
@@ -540,7 +570,8 @@ namespace {
 /// This is a specialized GreedyPatternRewriteDriver to apply patterns and
 /// perform folding for a supplied set of ops. It repeatedly simplifies while
 /// restricting the rewrites to only the provided set of ops or optionally
-/// to those directly affected by it (result users or operand providers).
+/// to those directly affected by it (result users or operand providers). Parent
+/// ops are not considered.
 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
@@ -553,7 +584,7 @@ public:
 
   void addToWorklist(Operation *op) override {
     if (!strictMode || strictModeFilteredOps.contains(op))
-      GreedyPatternRewriteDriver::addToWorklist(op);
+      GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
   }
 
 private:
index 0d7a635..bfd33d5 100644 (file)
@@ -22,13 +22,13 @@ func.func @tanh(%arg: f32) -> f32 {
 // CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
 // CHECK: return %[[RESULT]]
 
-// ----
+// -----
 
 // CHECK-LABEL: func @ctlz
 func.func @ctlz(%arg: i32) -> i32 {
-  // CHECK: %[[C0:.+]] = arith.constant 0 : i32
-  // CHECK: %[[C32:.+]] = arith.constant 32 : i32
-  // CHECK: %[[C1:.+]] = arith.constant 1 : i32
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
   // CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]])
   // CHECK:   %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]]
   // CHECK:   scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]]
index ead1e71..68b5133 100644 (file)
@@ -417,7 +417,7 @@ func.func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG:   %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32
+//   CHECK-DAG:   %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
 // Prologue:
 //       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
 //  CHECK-NEXT:   %[[ADD0:.*]] = arith.addf %[[L0]], %[[CSTF]] : f32
@@ -426,19 +426,22 @@ func.func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
 //  CHECK-NEXT:   %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
 //  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
 //  CHECK-SAME:     %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
-//  CHECK-NEXT:     %[[ADD1:.*]] = arith.addf %[[LARG]], %[[ADDARG]] : f32
+//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[ADDARG]], %[[CSTF]] : f32
+//  CHECK-NEXT:     %[[ADD1:.*]] = arith.addf %[[LARG]], %[[MUL0]] : f32
 //  CHECK-NEXT:     %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
 //  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
-//  CHECK-NEXT:     scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
+//  CHECK-NEXT:     scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32
 //  CHECK-NEXT:   }
 // Epilogue:
-//  CHECK-NEXT:   %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[R]]#1 : f32
-//  CHECK-NEXT:   return %[[ADD2]] : f32
+//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[R]]#1, %[[CSTF]] : f32
+//  CHECK-NEXT:   %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[MUL1]] : f32
+//  CHECK-NEXT:   %[[MUL2:.*]] = arith.mulf %[[ADD2]], %[[CSTF]] : f32
+//  CHECK-NEXT:   return %[[MUL2]] : f32
 func.func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
-  %cf = arith.constant 1.0 : f32
+  %cf = arith.constant 2.0 : f32
   %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
     %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
     %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
@@ -455,7 +458,7 @@ func.func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG:   %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32
+//   CHECK-DAG:   %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
 // Prologue:
 //       CHECK:   %[[L0:.*]] = scf.execute_region
 //  CHECK-NEXT:     memref.load %[[A]][%[[C0]]] : memref<?xf32>
@@ -467,23 +470,26 @@ func.func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
 //       CHECK:   %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
 //  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
 //  CHECK-SAME:     %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
+//       CHECK:     %[[MUL0:.*]] = arith.mulf %[[ADDARG]], %[[CSTF]] : f32
 //       CHECK:     %[[ADD1:.*]] = scf.execute_region
-//  CHECK-NEXT:       arith.addf %[[LARG]], %[[ADDARG]] : f32
+//  CHECK-NEXT:       arith.addf %[[LARG]], %[[MUL0]] : f32
 //       CHECK:     %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
 //       CHECK:     %[[L2:.*]] = scf.execute_region
 //  CHECK-NEXT:       memref.load %[[A]][%[[IV2]]] : memref<?xf32>
-//       CHECK:     scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
+//       CHECK:     scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32
 //  CHECK-NEXT:   }
 // Epilogue:
+//       CHECK:   %[[MUL1:.*]] = arith.mulf %[[R]]#1, %[[CSTF]] : f32
 //       CHECK:   %[[ADD2:.*]] = scf.execute_region
-//  CHECK-NEXT:    arith.addf %[[R]]#2, %[[R]]#1 : f32
-//       CHECK:   return %[[ADD2]] : f32
+//  CHECK-NEXT:    arith.addf %[[R]]#2, %[[MUL1]] : f32
+//       CHECK:   %[[MUL2:.*]] = arith.mulf %[[ADD2]], %[[CSTF]] : f32
+//       CHECK:   return %[[MUL2]] : f32
 
 func.func @region_backedge_different_stage(%A: memref<?xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
-  %cf = arith.constant 1.0 : f32
+  %cf = arith.constant 2.0 : f32
   %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
     %A_elem = scf.execute_region -> f32 {
       %A_elem1 = memref.load %A[%i0] : memref<?xf32>
@@ -507,7 +513,7 @@ func.func @region_backedge_different_stage(%A: memref<?xf32>) -> f32 {
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG:   %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32
+//   CHECK-DAG:   %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
 // Prologue:
 //       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
 // Kernel:
@@ -515,18 +521,20 @@ func.func @region_backedge_different_stage(%A: memref<?xf32>) -> f32 {
 //  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
 //  CHECK-SAME:     %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
 //  CHECK-NEXT:     %[[ADD0:.*]] = arith.addf %[[LARG]], %[[C]] : f32
+//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CSTF]] : f32
 //  CHECK-NEXT:     %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
 //  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
-//  CHECK-NEXT:     scf.yield %[[ADD0]], %[[L2]] : f32, f32
+//  CHECK-NEXT:     scf.yield %[[MUL0]], %[[L2]] : f32, f32
 //  CHECK-NEXT:   }
 // Epilogue:
 //  CHECK-NEXT:   %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32
-//  CHECK-NEXT:   return %[[ADD1]] : f32
+//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CSTF]] : f32
+//  CHECK-NEXT:   return %[[MUL1]] : f32
 func.func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
-  %cf = arith.constant 1.0 : f32
+  %cf = arith.constant 2.0 : f32
   %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
     %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
     %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
@@ -538,7 +546,7 @@ func.func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
 
 // -----
 
-// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref<?xf32>, %[[ARG1:.+]]: memref<?xf32>, %[[ARG2:.+]]: memref<?xf32>) {
+// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref<?xf32>, %[[ARG1:.+]]: memref<?xf32>, %[[ARG2:.+]]: memref<?xf32>, %[[CF:.*]]: f32) {
 // CHECK:   %[[C0:.+]] = arith.constant 0 :
 // CHECK:   %[[C3:.+]] = arith.constant 3 :
 // CHECK:   %[[C1:.+]] = arith.constant 1 :
@@ -590,11 +598,10 @@ func.func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
   __test_pipelining_stage__ = 1,
   __test_pipelining_op_order__ = 2
 }
-func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result: memref<?xf32>) {
+func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result: memref<?xf32>, %cf: f32) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
-  %cf = arith.constant 1.0 : f32
   %a_buf = memref.alloc() : memref<2x8xf32>
   %b_buf = memref.alloc() : memref<2x8xf32>
   scf.for %i0 = %c0 to %c4 step %c1 {
index 4f1a06f..6f4923a 100644 (file)
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1" | FileCheck %s
+// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
+// RUN:     -allow-unregistered-dialect --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @add_to_worklist_after_inplace_update()
 func.func @add_to_worklist_after_inplace_update() {
@@ -10,3 +11,16 @@ func.func @add_to_worklist_after_inplace_update() {
   "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @add_ancestors_to_worklist()
+func.func @add_ancestors_to_worklist() {
+       // CHECK: "foo.maybe_eligible_op"() {eligible} : () -> index
+  // CHECK-NEXT: "test.one_region_op"()
+  "test.one_region_op"() ({
+    %0 = "foo.maybe_eligible_op" () : () -> (index)
+    "foo.yield"(%0) : (index) -> ()
+  }) {hoist_eligible_ops}: () -> ()
+  return
+}
index 4beccf9..151da35 100644 (file)
@@ -140,6 +140,8 @@ struct TestSCFPipeliningPass
       auto attrCycle =
           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
       if (attrCycle && attrStage) {
+        // TODO: Index can be out-of-bounds if ops of the loop body disappear
+        // due to folding.
         schedule[attrCycle.getInt()] =
             std::make_pair(op, unsigned(attrStage.getInt()));
       }
index 2573f76..d3ef160 100644 (file)
@@ -167,6 +167,38 @@ struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
   }
 };
 
+/// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
+struct MakeOpEligible : public RewritePattern {
+  MakeOpEligible(MLIRContext *context)
+      : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (op->hasAttr("eligible"))
+      return failure();
+    rewriter.updateRootInPlace(
+        op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
+    return success();
+  }
+};
+
+/// This pattern hoists eligible ops out of a "test.one_region_op".
+struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
+  using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(test::OneRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *terminator = op.getRegion().front().getTerminator();
+    Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
+    if (toBeHoisted->getParentOp() != op)
+      return failure();
+    if (!toBeHoisted->hasAttr("eligible"))
+      return failure();
+    toBeHoisted->moveBefore(op);
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -183,7 +215,8 @@ struct TestPatternDriver
     // Verify named pattern is generated with expected name.
     patterns.add<FoldingPattern, TestNamedPatternRule,
                  FolderInsertBeforePreviouslyFoldedConstantPattern,
-                 FolderCommutativeOp2WithConstant>(&getContext());
+                 FolderCommutativeOp2WithConstant, HoistEligibleOps,
+                 MakeOpEligible>(&getContext());
 
     // Additional patterns for testing the GreedyPatternRewriteDriver.
     patterns.insert<IncrementIntAttribute<3>>(&getContext());