From eb5ec039607145c1d0d3b2a275047ce82b060e46 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 17 May 2019 15:57:49 -0700 Subject: [PATCH] Refactor PatternRewriter to inherit from FuncBuilder instead of Builder. This is necessary for allowing more complicated rewrites in the future that may do things like update the insertion point (e.g. for rewrites involving regions). -- PiperOrigin-RevId: 248803153 --- mlir/include/mlir/IR/Builders.h | 3 ++- mlir/include/mlir/IR/PatternMatch.h | 4 ++-- mlir/lib/IR/Builders.cpp | 2 ++ mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 12 ++++-------- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index d852a80..ca12f39 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -195,6 +195,7 @@ public: } explicit FuncBuilder(Function &func) : FuncBuilder(&func) {} + virtual ~FuncBuilder(); /// Create a function builder and set insertion point to the given /// operation, which will cause subsequent insertions to go right before it. @@ -262,7 +263,7 @@ public: Block *getBlock() const { return block; } /// Creates an operation given the fields represented as an OperationState. - Operation *createOperation(const OperationState &state); + virtual Operation *createOperation(const OperationState &state); /// Create operation of specific op type at the current insertion point. template diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 51528c1..e7e2f40 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -206,7 +206,7 @@ protected: /// to apply patterns and observe their effects (e.g. to keep worklists or /// other data structures up to date). /// -class PatternRewriter : public Builder { +class PatternRewriter : public FuncBuilder { public: /// Create operation of specific op type at the current insertion point /// without verifying to see if it is valid. @@ -282,7 +282,7 @@ public: ArrayRef valuesToRemoveIfDead = {}); protected: - PatternRewriter(MLIRContext *context) : Builder(context) {} + PatternRewriter(Function *fn) : FuncBuilder(fn) {} virtual ~PatternRewriter(); // These are the callback methods that subclasses can choose to implement if diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 574102c..276f1e3 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -330,6 +330,8 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { // Operations. //===----------------------------------------------------------------------===// +FuncBuilder::~FuncBuilder() {} + /// Add new block and set the insertion point to the end of it. If an /// 'insertBefore' block is passed, the block will be placed before the /// specified block. If not, the block will be appended to the end of the diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 58940c1..4f26050 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -46,8 +46,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(Function &fn, OwningRewritePatternList &&patterns) - : PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this), - builder(&fn) { + : PatternRewriter(&fn), matcher(std::move(patterns), *this) { worklist.reserve(64); } @@ -89,7 +88,7 @@ protected: // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. Operation *createOperation(const OperationState &state) override { - auto *result = builder.createOperation(state); + auto *result = FuncBuilder::createOperation(state); addToWorklist(result); return result; } @@ -133,9 +132,6 @@ private: /// The low-level pattern matcher. RewritePatternMatcher matcher; - /// This builder is used to create new operations. - FuncBuilder builder; - /// The worklist for this transformation keeps track of the operations that /// need to be revisited, plus their index in the worklist. This allows us to /// efficiently remove operations from the worklist when they are erased from @@ -147,7 +143,7 @@ private: /// Perform the rewrites. bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { - Function *fn = builder.getFunction(); + Function *fn = getFunction(); FoldHelper helper(fn); bool changed = false; @@ -201,7 +197,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { } // Make sure that any new operations are inserted at this point. - builder.setInsertionPoint(op); + setInsertionPoint(op); // Try to match one of the canonicalization patterns. The rewriter is // automatically notified of any necessary changes, so there is nothing -- 2.7.4