[mlir][SCF] Fix incorrect API usage in RewritePatterns
authorMatthias Springer <me@m-sp.org>
Mon, 27 Feb 2023 08:35:37 +0000 (09:35 +0100)
committerMatthias Springer <me@m-sp.org>
Mon, 27 Feb 2023 08:36:14 +0000 (09:36 +0100)
Incorrect API usage was detected by D144552.

Differential Revision: https://reviews.llvm.org/D144636

mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
mlir/test/Dialect/SCF/canonicalize.mlir

index fce6633bb7105281f738cdc697860e009dabfa4e..acc7e3ef8e1dc7f5e2a9b7f4b0bfb4149f623501 100644 (file)
@@ -546,6 +546,11 @@ public:
       updateRootInPlace(op, [&]() { operand.set(to); });
     }
   }
+  void replaceAllUsesWith(ValueRange from, ValueRange to) {
+    assert(from.size() == to.size() && "incorrect number of replacements");
+    for (auto it : llvm::zip(from, to))
+      replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+  }
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. It also marks every modified uses and notifies the rewriter that an
index f8fd2016cd9a969b4ca8ff5635027aa51448f685..84caca9806e22f6b79d7adea77c62e0ffee96868 100644 (file)
@@ -1441,21 +1441,23 @@ public:
         failed(foldDynamicIndexList(rewriter, mixedStep)))
       return failure();
 
-    SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
-    SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
-    dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
-                               staticLowerBound);
-    op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
-    op.setStaticLowerBound(staticLowerBound);
-
-    dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
-                               staticUpperBound);
-    op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
-    op.setStaticUpperBound(staticUpperBound);
-
-    dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
-    op.getDynamicStepMutable().assign(dynamicStep);
-    op.setStaticStep(staticStep);
+    rewriter.updateRootInPlace(op, [&]() {
+      SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
+      SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
+      dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
+                                 staticLowerBound);
+      op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
+      op.setStaticLowerBound(staticLowerBound);
+
+      dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
+                                 staticUpperBound);
+      op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
+      op.setStaticUpperBound(staticUpperBound);
+
+      dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
+      op.getDynamicStepMutable().assign(dynamicStep);
+      op.setStaticStep(staticStep);
+    });
     return success();
   }
 };
@@ -3073,7 +3075,8 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
                 op.getLoc(), term.getCondition().getType(),
                 rewriter.getBoolAttr(true));
 
-          std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
+          rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
+                                      constantTrue);
           replaced = true;
         }
       }
index 5a9e0217edbda29e83d04e1c43884dbefc1f427a..8863b0833d3e7274e653225eb0d7a02d0b17ca66 100644 (file)
@@ -78,7 +78,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     // Rewrite uses of the for-loop block arguments to the new while-loop
     // "after" arguments
     for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
-      barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
+      rewriter.replaceAllUsesWith(barg.value(),
+                                  afterBlock->getArgument(barg.index()));
 
     // Inline for-loop body operations into 'after' region.
     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
@@ -88,7 +89,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
       SmallVector<Value> yieldOperands = yieldOp.getOperands();
       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
-      yieldOp->setOperands(yieldOperands);
+      rewriter.updateRootInPlace(
+          yieldOp, [&]() { yieldOp->setOperands(yieldOperands); });
     }
 
     // We cannot do a direct replacement of the forOp since the while op returns
@@ -96,7 +98,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     // carried in the set of iterargs). Instead, rewrite uses of the forOp
     // results.
     for (const auto &arg : llvm::enumerate(forOp.getResults()))
-      arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
+      rewriter.replaceAllUsesWith(arg.value(),
+                                  whileOp.getResult(arg.index() + 1));
 
     rewriter.eraseOp(forOp);
     return success();
index 82c2223b6eec04baa8aa61b5adb6e2b60e442eba..3ee440da66dcd62c5fe9263b02625108fb2df4e6 100644 (file)
@@ -144,7 +144,7 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
   b.setInsertionPointAfter(forOp);
   partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
   partialIteration.getLowerBoundMutable().assign(splitBound);
-  forOp.replaceAllUsesWith(partialIteration->getResults());
+  b.replaceAllUsesWith(forOp.getResults(), partialIteration->getResults());
   partialIteration.getInitArgsMutable().assign(forOp->getResults());
 
   // Set new upper loop bound.
@@ -221,11 +221,13 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
     if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, partialIteration)))
       return failure();
     // Apply label, so that the same loop is not rewritten a second time.
-    partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+    rewriter.updateRootInPlace(partialIteration, [&]() {
+      partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+      partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
+    });
     rewriter.updateRootInPlace(forOp, [&]() {
       forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
     });
-    partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
     return success();
   }
 
index d3dfd16ba04426798a6c4be3f0b74e50ede32bfe..a3ce8a63d4c9fd45792699b58ab2a399ebe0a6b6 100644 (file)
@@ -302,6 +302,7 @@ func.func @to_select_same_val(%cond: i1) -> (index, index) {
 // CHECK:           [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
 // CHECK:           return [[V0]], [[C1]] : index, index
 
+// -----
 
 func.func @to_select_with_body(%cond: i1) -> index {
   %c0 = arith.constant 0 : index
@@ -323,6 +324,7 @@ func.func @to_select_with_body(%cond: i1) -> index {
 // CHECK:             "test.op"() : () -> ()
 // CHECK:           }
 // CHECK:           return [[V0]] : index
+
 // -----
 
 func.func @to_select2(%cond: i1) -> (index, index) {
@@ -363,6 +365,10 @@ func.func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {
 //  CHECK-NEXT:     %[[R:.*]] = call @make_i32() : () -> i32
 //  CHECK-NEXT:     return %[[R]] : i32
 
+// -----
+
+func.func private @make_i32() -> i32
+
 func.func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
   %a = call @make_i32() : () -> (i32)
   %b = call @make_i32() : () -> (i32)
@@ -523,6 +529,8 @@ func.func @merge_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8)
   return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
 }
 
+// -----
+
 // CHECK-LABEL: @merge_yielding_nested_if_nv1
 // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
 func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
@@ -547,6 +555,8 @@ func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @merge_yielding_nested_if_nv2
 // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
 func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
@@ -571,6 +581,8 @@ func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
   return %r : i32
 }
 
+// -----
+
 // CHECK-LABEL: @merge_fail_yielding_nested_if
 // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
 func.func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
@@ -1125,6 +1137,8 @@ func.func @while_unused_result() -> i32 {
 // CHECK-NEXT:         }
 // CHECK-NEXT:         return %[[res]] : i32
 
+// -----
+
 // CHECK-LABEL: @while_cmp_lhs
 func.func @while_cmp_lhs(%arg0 : i32) {
   %0 = scf.while () : () -> i32 {
@@ -1152,6 +1166,8 @@ func.func @while_cmp_lhs(%arg0 : i32) {
 // CHECK-NEXT:           scf.yield
 // CHECK-NEXT:         }
 
+// -----
+
 // CHECK-LABEL: @while_cmp_rhs
 func.func @while_cmp_rhs(%arg0 : i32) {
   %0 = scf.while () : () -> i32 {
@@ -1210,6 +1226,7 @@ func.func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
 
+// -----
 
 // CHECK-LABEL: @combineIfs2
 func.func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
@@ -1236,6 +1253,7 @@ func.func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[res]] : i32
 
+// -----
 
 // CHECK-LABEL: @combineIfs3
 func.func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
@@ -1262,6 +1280,8 @@ func.func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[res]] : i32
 
+// -----
+
 // CHECK-LABEL: @combineIfs4
 func.func @combineIfs4(%arg0 : i1, %arg2: i64) {
   scf.if %arg0 {
@@ -1280,6 +1300,8 @@ func.func @combineIfs4(%arg0 : i1, %arg2: i64) {
 // CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
 // CHECK-NEXT:     }
 
+// -----
+
 // CHECK-LABEL: @combineIfsUsed
 // CHECK-SAME: %[[arg0:.+]]: i1
 func.func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
@@ -1310,6 +1332,8 @@ func.func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
 
+// -----
+
 // CHECK-LABEL: @combineIfsNot
 // CHECK-SAME: %[[arg0:.+]]: i1
 func.func @combineIfsNot(%arg0 : i1, %arg2: i64) {
@@ -1332,6 +1356,8 @@ func.func @combineIfsNot(%arg0 : i1, %arg2: i64) {
 // CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
 // CHECK-NEXT:     }
 
+// -----
+
 // CHECK-LABEL: @combineIfsNot2
 // CHECK-SAME: %[[arg0:.+]]: i1
 func.func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
@@ -1353,6 +1379,7 @@ func.func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
 // CHECK-NEXT:     } else {
 // CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
 // CHECK-NEXT:     }
+
 // -----
 
 // CHECK-LABEL: func @propagate_into_execute_region
@@ -1403,7 +1430,6 @@ func.func @execute_region_elim() {
 // CHECK-NEXT:       "test.bar"(%[[VAL]]) : (i64) -> ()
 // CHECK-NEXT:     }
 
-
 // -----
 
 // CHECK-LABEL: func @func_execute_region_elim
@@ -1439,7 +1465,6 @@ func.func @func_execute_region_elim() {
 // CHECK:     "test.bar"(%[[z]])
 // CHECK:     return
 
-
 // -----
 
 // CHECK-LABEL: func @func_execute_region_elim_multi_yield