From 71d50c890bad943ab23ee9b32638b2366351f8f8 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 14 Jun 2023 08:41:19 +0200 Subject: [PATCH] [mlir][IR] Improve listener notifications for ops without results `RewriterBase::Listener::notifyOperationReplaced` notifies observers that an op is about to be replaced with a range of values. This notification is not very useful for ops without results, because it does not specify the replacement op (and it cannot be deduced from the replacement values). It provides no additional information over the `notifyOperationRemoved` notification. This revision adds an additional notification when a rewriter replaces an op with another op. By default, this notification triggers the original "op replaced with values" notification, so there is no functional change for existing code. This new API is useful for the transform dialect, which needs to track op replacements. (Updated in a subsequent revision.) Also includes minor documentation improvements. Differential Revision: https://reviews.llvm.org/D152814 --- mlir/include/mlir/IR/PatternMatch.h | 35 ++++++++++++----- .../mlir/Transforms/DialectConversion.h | 10 +++-- .../AMDGPU/Transforms/EmulateAtomics.cpp | 2 +- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 5 +-- mlir/lib/IR/PatternMatch.cpp | 38 +++++++++++-------- .../Transforms/Utils/DialectConversion.cpp | 7 ++++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 8 ++-- 7 files changed, 67 insertions(+), 38 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 4614649caae1..3843ff249ddf 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -406,13 +406,24 @@ public: virtual void notifyOperationModified(Operation *op) {} /// Notify the listener that the specified operation is about to be replaced - /// with the set of values potentially produced by new operations. This is - /// called before the uses of the operation have been changed. + /// with another operation. This is called before the uses of the old + /// operation have been changed. + /// + /// By default, this function calls the "operation replaced with values" + /// notification. + virtual void notifyOperationReplaced(Operation *op, + Operation *replacement) { + notifyOperationReplaced(op, replacement->getResults()); + } + + /// Notify the listener that the specified operation is about to be replaced + /// with the a range of values, potentially produced by other operations. + /// This is called before the uses of the operation have been changed. virtual void notifyOperationReplaced(Operation *op, ValueRange replacement) {} - /// This is called on an operation that a rewrite is removing, right before - /// the operation is deleted. At this point, the operation has zero uses. + /// Notify the listener that the specified operation is about to be erased. + /// At this point, the operation has zero uses. virtual void notifyOperationRemoved(Operation *op) {} /// Notify the listener that the pattern failed to match the given @@ -444,6 +455,9 @@ public: void notifyOperationModified(Operation *op) override { listener->notifyOperationModified(op); } + void notifyOperationReplaced(Operation *op, Operation *newOp) override { + listener->notifyOperationReplaced(op, newOp); + } void notifyOperationReplaced(Operation *op, ValueRange replacement) override { listener->notifyOperationReplaced(op, replacement); @@ -505,15 +519,20 @@ public: /// This method replaces the results of the operation with the specified list /// of values. The number of provided values must match the number of results - /// of the operation. + /// of the operation. The replaced op is erased. virtual void replaceOp(Operation *op, ValueRange newValues); + /// This method replaces the results of the operation with the specified + /// new op (replacement). The number of results of the two operations must + /// match. The replaced op is erased. + virtual void replaceOp(Operation *op, Operation *newOp); + /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. template OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { auto newOp = create(op->getLoc(), std::forward(args)...); - replaceOpWithResultsOfAnotherOp(op, newOp.getOperation()); + replaceOp(op, newOp.getOperation()); return newOp; } @@ -666,10 +685,6 @@ protected: private: void operator=(const RewriterBase &) = delete; RewriterBase(const RewriterBase &) = delete; - - /// 'op' and 'newOp' are known to have the same number of results, replace the - /// uses of op with uses of newOp. - void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f242eea76778..f5206b1a4da4 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -695,15 +695,17 @@ public: /// patterns even if a failure is encountered during the rewrite step. bool canRecoverFromRewriteFailure() const override { return true; } - /// PatternRewriter hook for replacing the results of an operation when the - /// given functor returns true. + /// PatternRewriter hook for replacing an operation when the given functor + /// returns "true". void replaceOpWithIf( Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function functor) override; - /// PatternRewriter hook for replacing the results of an operation. + /// PatternRewriter hook for replacing an operation. void replaceOp(Operation *op, ValueRange newValues) override; - using PatternRewriter::replaceOp; + + /// PatternRewriter hook for replacing an operation. + void replaceOp(Operation *op, Operation *newOp) override; /// PatternRewriter hook for erasing a dead operation. The uses of this /// operation *must* be made dead by the end of the conversion process, diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp index d07d6518d57c..9dfe07797ff4 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp @@ -139,7 +139,7 @@ LogicalResult RawBufferAtomicByCasPattern::matchAndRewrite( loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare); rewriter.create(loc, canLeave, afterAtomic, ValueRange{}, loopBlock, atomicRes); - rewriter.replaceOp(atomicOp, {}); + rewriter.eraseOp(atomicOp); return success(); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 8ed790c421a4..01ad2dd20e7c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -331,10 +331,7 @@ struct SimplifyAllocConst : public OpRewritePattern { alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(), alloc.getAlignmentAttr()); // Insert a cast so we have the same type as the old alloc. - auto resultCast = - rewriter.create(alloc.getLoc(), alloc.getType(), newAlloc); - - rewriter.replaceOp(alloc, {resultCast}); + rewriter.replaceOpWithNewOp(alloc, alloc.getType(), newAlloc); return success(); } }; diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 052696d5cb13..db920c14ea08 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -262,12 +262,12 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues, /// This method replaces the results of the operation with the specified list of /// values. The number of provided values must match the number of results of -/// the operation. +/// the operation. The replaced op is erased. void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - // Notify the listener that we're about to remove this op. + // Notify the listener that we're about to replace this op. if (auto *rewriteListener = dyn_cast_if_present(listener)) rewriteListener->notifyOperationReplaced(op, newValues); @@ -275,9 +275,28 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { for (auto it : llvm::zip(op->getResults(), newValues)) replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); + // Erase the op. + eraseOp(op); +} + +/// This method replaces the results of the operation with the specified new op +/// (replacement). The number of results of the two operations must match. The +/// replaced op is erased. +void RewriterBase::replaceOp(Operation *op, Operation *newOp) { + assert(op && newOp && "expected non-null op"); + assert(op->getNumResults() == newOp->getNumResults() && + "ops have different number of results"); + + // Notify the listener that we're about to replace this op. if (auto *rewriteListener = dyn_cast_if_present(listener)) - rewriteListener->notifyOperationRemoved(op); - op->erase(); + rewriteListener->notifyOperationReplaced(op, newOp); + + // Replace results one-by-one. Also notifies the listener of modifications. + for (auto it : llvm::zip(op->getResults(), newOp->getResults())) + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); + + // Erase the old op. + eraseOp(op); } /// This method erases an operation that is known to have no uses. The uses of @@ -364,17 +383,6 @@ Block *RewriterBase::splitBlock(Block *block, Block::iterator before) { return block->splitBlock(before); } -/// 'op' and 'newOp' are known to have the same number of results, replace the -/// uses of op with uses of newOp -void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op, - Operation *newOp) { - assert(op->getNumResults() == newOp->getNumResults() && - "replacement op doesn't match results of original op"); - if (op->getNumResults() == 1) - return replaceOp(op, newOp->getResult(0)); - return replaceOp(op, newOp->getResults()); -} - /// Move the blocks that belong to "region" before the given position in /// another region. The two regions must be different. The caller is in /// charge to update create the operation transferring the control flow to the diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 615c8e4a99ce..411111358b1e 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1452,7 +1452,14 @@ void ConversionPatternRewriter::replaceOpWithIf( "replaceOpWithIf is currently not supported by DialectConversion"); } +void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { + assert(op && newOp && "expected non-null op"); + replaceOp(op, newOp->getResults()); +} + void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { + assert(op->getNumResults() == newValues.size() && + "incorrect # of replacement values"); LLVM_DEBUG({ impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 3a1faeabe84c..8a0ca056cd8b 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -601,7 +601,7 @@ struct TestCreateBlock : public RewritePattern { Location loc = op->getLoc(); rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); rewriter.create(loc); - rewriter.replaceOp(op, {}); + rewriter.eraseOp(op); return success(); } }; @@ -621,7 +621,7 @@ struct TestCreateIllegalBlock : public RewritePattern { // Create an illegal op to ensure the conversion fails. rewriter.create(loc, i32Type); rewriter.create(loc); - rewriter.replaceOp(op, {}); + rewriter.eraseOp(op); return success(); } }; @@ -793,8 +793,8 @@ struct TestNonRootReplacement : public RewritePattern { auto illegalOp = rewriter.create(op->getLoc(), resultType); auto legalOp = rewriter.create(op->getLoc(), resultType); - rewriter.replaceOp(illegalOp, {legalOp}); - rewriter.replaceOp(op, {illegalOp}); + rewriter.replaceOp(illegalOp, legalOp); + rewriter.replaceOp(op, illegalOp); return success(); } }; -- 2.34.1