Ops that were modifed in-place (`finalizeRootUpdate` was called) should be reprocessed by the GreedyPatternRewriter. This is currently not happening with `GreedyRewriteConfig::maxIterations = 1`.
Note: If your project goes into an infinite loop because of this change, you likely have one or multiple faulty patterns that modify the same operations in-place (`updateRootInplace`) indefinitely.
Differential Revision: https://reviews.llvm.org/D138038
SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
auto inBoundsAttr = b.getBoolArrayAttr(bools);
if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
- xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+ b.updateRootInPlace(xferOp, [&]() {
+ xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+ });
return success();
}
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
- xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+ b.updateRootInPlace(xferOp, [&]() {
+ xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+ });
return success();
}
else
createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
- xferOp->erase();
+ b.eraseOp(xferOp);
return success();
}
if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
failed(filter(xferOp)))
return failure();
- rewriter.startRootUpdate(xferOp);
- if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
- rewriter.finalizeRootUpdate(xferOp);
- return success();
- }
- rewriter.cancelRootUpdate(xferOp);
- return failure();
+ return splitFullAndPartialTransfer(rewriter, xferOp, options);
}
/// If the specified operation is in the worklist, remove it.
void removeFromWorklist(Operation *op);
+ /// Notifies the driver that the specified operation may have been modified
+ /// in-place.
+ void finalizeRootUpdate(Operation *op) override;
+
protected:
// Implement the hook for inserting operations, and make sure that newly
// inserted ops are added to the worklist for processing.
addToWorklist(op);
}
+void GreedyPatternRewriteDriver::finalizeRootUpdate(Operation *op) {
+ LLVM_DEBUG({
+ logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+ addToWorklist(op);
+}
+
void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
for (Value operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining
--- /dev/null
+// RUN: mlir-opt %s -test-patterns="max-iterations=1" | FileCheck %s
+
+// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
+func.func @add_to_worklist_after_inplace_update() {
+ // The following op is updated in-place and should be added back to the
+ // worklist of the GreedyPatternRewriteDriver (regardless of the value of
+ // config.max_iterations).
+
+ // CHECK: "test.any_attr_of_i32_str"() {attr = 3 : i32} : () -> ()
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ return
+}
}
};
+/// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
+/// attribute with value smaller than MaxVal, it increments the value by 1.
+template <int MaxVal>
+struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
+ using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AnyAttrOfOp op,
+ PatternRewriter &rewriter) const override {
+ auto intAttr = op.getAttr().dyn_cast<IntegerAttr>();
+ if (!intAttr)
+ return failure();
+ int64_t val = intAttr.getInt();
+ if (val >= MaxVal)
+ return failure();
+ rewriter.updateRootInPlace(
+ op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
FolderInsertBeforePreviouslyFoldedConstantPattern,
FolderCommutativeOp2WithConstant>(&getContext());
+ // Additional patterns for testing the GreedyPatternRewriteDriver.
+ patterns.insert<IncrementIntAttribute<3>>(&getContext());
+
GreedyRewriteConfig config;
config.useTopDownTraversal = this->useTopDownTraversal;
+ config.maxIterations = this->maxIterations;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}
*this, "top-down",
llvm::cl::desc("Seed the worklist in general top-down order"),
llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
+ Option<int> maxIterations{
+ *this, "max-iterations",
+ llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
+ llvm::cl::init(GreedyRewriteConfig().maxIterations)};
};
struct TestStrictPatternDriver