From 5c757087c741b5f1299a23f5dc3f1bcb7d46f8fa Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 22 Apr 2019 16:06:09 -0700 Subject: [PATCH] Apply patterns repeatly if the function is modified During the pattern rewrite, if the function is changed, i.e. ops created, deleted or swapped, the pattern rewriter needs to re-scan the function entirely and apply the patterns again, so the patterns whose root ops have been popped out from the working list nor an immediate users of the changed ops can be reconsidered. A command line flag is added to set the max number of iterations rescanning the function for pattern match. If the rewrite doesn' converge after this number, this compiling will continue and the result can be sub-optimal. One unit test is updated because this change fixed the missing optimization opportunities. -- PiperOrigin-RevId: 244754190 --- mlir/include/mlir/IR/PatternMatch.h | 10 +- mlir/lib/IR/PatternMatch.cpp | 5 +- .../Quantization/Transforms/ConvertSimQuant.cpp | 1 + .../Utils/GreedyPatternRewriteDriver.cpp | 204 ++++++++++++--------- .../Vectorize/lower_vector_transfers.mlir | 8 +- 5 files changed, 130 insertions(+), 98 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 0b35bb3..3b02ed5 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -331,8 +331,9 @@ public: explicit RewritePatternMatcher(OwningRewritePatternList &&patterns, PatternRewriter &rewriter); - /// Try to match the given operation to a pattern and rewrite it. - void matchAndRewrite(Operation *op); + /// Try to match the given operation to a pattern and rewrite it. Return + /// true if any pattern matches. + bool matchAndRewrite(Operation *op); private: RewritePatternMatcher(const RewritePatternMatcher &) = delete; @@ -347,9 +348,10 @@ private: }; /// Rewrite the specified function by repeatedly applying the highest benefit -/// patterns in a greedy work-list driven manner. +/// patterns in a greedy work-list driven manner. Return true if no more +/// patterns can be matched in the result function. /// -void applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns); +bool applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns); } // end namespace mlir diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 7132408..539162b 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -134,7 +134,7 @@ RewritePatternMatcher::RewritePatternMatcher( } /// Try to match the given operation to a pattern and rewrite it. -void RewritePatternMatcher::matchAndRewrite(Operation *op) { +bool RewritePatternMatcher::matchAndRewrite(Operation *op) { for (auto &pattern : patterns) { // Ignore patterns that are for the wrong root or are impossible to match. if (pattern->getRootKind() != op->getName() || @@ -144,6 +144,7 @@ void RewritePatternMatcher::matchAndRewrite(Operation *op) { // Try to match and rewrite this pattern. The patterns are sorted by // benefit, so if we match we can immediately rewrite and return. if (pattern->matchAndRewrite(op, rewriter)) - return; + return true; } + return false; } diff --git a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp index 7137424..1d2cabd 100644 --- a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp @@ -52,6 +52,7 @@ public: // support for failable rewrites. if (failableRewrite(op, rewriter)) { *hadFailure = true; + return matchFailure(); } return matchSuccess(); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index dae0bf4..8c6a932 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -24,9 +24,20 @@ #include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/ConstantFoldUtils.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; +#define DEBUG_TYPE "pattern-matcher" + +static llvm::cl::opt maxPatternMatchIterations( + "mlir-max-pattern-match-iterations", + llvm::cl::desc( + "Max number of iterations scanning the functions for pattern match"), + llvm::cl::init(10)); + namespace { /// This is a worklist-driven driver for the PatternMatcher, which repeatedly @@ -38,13 +49,11 @@ public: : PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this), builder(&fn) { worklist.reserve(64); - - // Add all operations to the worklist. - fn.walk([&](Operation *op) { addToWorklist(op); }); } - /// Perform the rewrites. - void simplifyFunction(); + /// Perform the rewrites. Return true if the rewrite converges in + /// `maxIterations`. + bool simplifyFunction(unsigned maxIterations); void addToWorklist(Operation *op) { // Check to see if the worklist already contains this op. @@ -114,8 +123,7 @@ private: // TODO(riverriddle) This is based on the fact that zero use operations // may be deleted, and that single use values often have more // canonicalization opportunities. - if (!operand->use_empty() && - std::next(operand->use_begin()) != operand->use_end()) + if (!operand->use_empty() && !operand->hasOneUse()) continue; if (auto *defInst = operand->getDefiningOp()) addToWorklist(defInst); @@ -138,99 +146,121 @@ private: }; // end anonymous namespace /// Perform the rewrites. -void GreedyPatternRewriteDriver::simplifyFunction() { - ConstantFoldHelper helper(builder.getFunction()); - - // These are scratch vectors used in the folding loop below. - SmallVector originalOperands, resultValues; - - while (!worklist.empty()) { - auto *op = popFromWorklist(); - - // Nulls get added to the worklist when operations are removed, ignore them. - if (op == nullptr) - continue; - - // If the operation has no side effects, and no users, then it is trivially - // dead - remove it. - if (op->hasNoSideEffect() && op->use_empty()) { - // Be careful to update bookkeeping in ConstantHelper to keep consistency - // if this is a constant op. - if (op->isa()) - helper.notifyRemoval(op); - op->erase(); - continue; - } +bool GreedyPatternRewriteDriver::simplifyFunction(unsigned maxIterations) { + Function *fn = builder.getFunction(); + ConstantFoldHelper helper(fn); - // Collects all the operands and result uses of the given `op` into work - // list. - auto collectOperandsAndUses = [this](Operation *op) { - // Add the operands to the worklist for visitation. - addToWorklist(op->getOperands()); - // Add all the users of the result to the worklist so we make sure - // to revisit them. - // - // TODO: Add a result->getUsers() iterator. - for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { - for (auto &operand : op->getResult(i)->getUses()) - addToWorklist(operand.getOwner()); - } - }; + bool changed = false; + int i = 0; + do { + // Add all operations to the worklist. + fn->walk([&](Operation *op) { addToWorklist(op); }); - // Try to constant fold this op. - if (helper.tryToConstantFold(op, collectOperandsAndUses)) { - assert(op->hasNoSideEffect() && "Constant folded op with side effects?"); - op->erase(); - continue; - } + // These are scratch vectors used in the folding loop below. + SmallVector originalOperands, resultValues; + + changed = false; + while (!worklist.empty()) { + auto *op = popFromWorklist(); + + // Nulls get added to the worklist when operations are removed, ignore + // them. + if (op == nullptr) + continue; - // Otherwise see if we can use the generic folder API to simplify the - // operation. - originalOperands.assign(op->operand_begin(), op->operand_end()); - resultValues.clear(); - if (succeeded(op->fold(resultValues))) { - // If the result was an in-place simplification (e.g. max(x,x,y) -> - // max(x,y)) then add the original operands to the worklist so we can make - // sure to revisit them. - if (resultValues.empty()) { - // Add the operands back to the worklist as there may be more - // canonicalization opportunities now. - addToWorklist(originalOperands); - } else { - // Otherwise, the operation is simplified away completely. - assert(resultValues.size() == op->getNumResults()); - - // Notify that we are replacing this operation. - notifyRootReplaced(op); - - // Replace the result values and erase the operation. - for (unsigned i = 0, e = resultValues.size(); i != e; ++i) { - auto *res = op->getResult(i); - if (!res->use_empty()) - res->replaceAllUsesWith(resultValues[i]); + // If the operation has no side effects, and no users, then it is + // trivially dead - remove it. + if (op->hasNoSideEffect() && op->use_empty()) { + // Be careful to update bookkeeping in ConstantHelper to keep + // consistency if this is a constant op. + if (op->isa()) + helper.notifyRemoval(op); + op->erase(); + continue; + } + + // Collects all the operands and result uses of the given `op` into work + // list. + auto collectOperandsAndUses = [this](Operation *op) { + // Add the operands to the worklist for visitation. + addToWorklist(op->getOperands()); + // Add all the users of the result to the worklist so we make sure + // to revisit them. + // + // TODO: Add a result->getUsers() iterator. + for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { + for (auto &operand : op->getResult(i)->getUses()) + addToWorklist(operand.getOwner()); } + }; - notifyOperationRemoved(op); + // Try to constant fold this op. + if (helper.tryToConstantFold(op, collectOperandsAndUses)) { + assert(op->hasNoSideEffect() && + "Constant folded op with side effects?"); op->erase(); + changed |= true; + continue; } - continue; - } - // Make sure that any new operations are inserted at this point. - builder.setInsertionPoint(op); + // Otherwise see if we can use the generic folder API to simplify the + // operation. + originalOperands.assign(op->operand_begin(), op->operand_end()); + resultValues.clear(); + if (succeeded(op->fold(resultValues))) { + // If the result was an in-place simplification (e.g. max(x,x,y) -> + // max(x,y)) then add the original operands to the worklist so we can + // make sure to revisit them. + if (resultValues.empty()) { + // Add the operands back to the worklist as there may be more + // canonicalization opportunities now. + addToWorklist(originalOperands); + } else { + // Otherwise, the operation is simplified away completely. + assert(resultValues.size() == op->getNumResults()); + + // Notify that we are replacing this operation. + notifyRootReplaced(op); + + // Replace the result values and erase the operation. + for (unsigned i = 0, e = resultValues.size(); i != e; ++i) { + auto *res = op->getResult(i); + if (!res->use_empty()) + res->replaceAllUsesWith(resultValues[i]); + } + + notifyOperationRemoved(op); + op->erase(); + } + changed |= true; + continue; + } - // Try to match one of the canonicalization patterns. The rewriter is - // automatically notified of any necessary changes, so there is nothing else - // to do here. - matcher.matchAndRewrite(op); - } + // Make sure that any new operations are inserted at this point. + builder.setInsertionPoint(op); + + // Try to match one of the canonicalization patterns. The rewriter is + // automatically notified of any necessary changes, so there is nothing + // else to do here. + changed |= matcher.matchAndRewrite(op); + } + } while (changed && ++i < maxIterations); + // Whether the rewrite converges, i.e. wasn't changed in the last iteration. + return !changed; } /// Rewrite the specified function by repeatedly applying the highest benefit -/// patterns in a greedy work-list driven manner. +/// patterns in a greedy work-list driven manner. Return true if no more +/// patterns can be matched in the result function. /// -void mlir::applyPatternsGreedily(Function &fn, +bool mlir::applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns) { GreedyPatternRewriteDriver driver(fn, std::move(patterns)); - driver.simplifyFunction(); + bool converged = driver.simplifyFunction(maxPatternMatchIterations); + LLVM_DEBUG(if (!converged) { + llvm::dbgs() + << "The pattern rewrite doesn't converge after scanning the function " + << maxPatternMatchIterations << " times"; + }); + return converged; } diff --git a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir index b5345f3..f1ea8f4 100644 --- a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir +++ b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir @@ -53,6 +53,7 @@ func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %d // CHECK-LABEL: func @materialize_read(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { func @materialize_read(%M: index, %N: index, %O: index, %P: index) { + // CHECK-NEXT: %[[C0:.*]] = constant 0 : index // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 { // CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 { @@ -67,8 +68,6 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: affine.for %[[I4:.*]] = 0 to 3 { // CHECK-NEXT: affine.for %[[I5:.*]] = 0 to 4 { // CHECK-NEXT: affine.for %[[I6:.*]] = 0 to 5 { - // CHECK-NEXT: %[[C0:.*]] = constant 0 : index - // CHECK-NEXT: %[[C1:.*]] = constant 1 : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D0]]] // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index @@ -78,7 +77,7 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D1]]] // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index - // CHECK-NEXT: {{.*}} = select + // CHECK-NEXT: {{.*}} = select // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index // CHECK-NEXT: %[[L1:.*]] = select // @@ -127,6 +126,7 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-LABEL:func @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { func @materialize_write(%M: index, %N: index, %O: index, %P: index) { + // CHECK-NEXT: %[[C0:.*]] = constant 0 : index // CHECK-NEXT: %cst = constant splat, 1.000000e+00> : vector<5x4x3xf32> // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 { @@ -143,8 +143,6 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: affine.for %[[I4:.*]] = 0 to 3 { // CHECK-NEXT: affine.for %[[I5:.*]] = 0 to 4 { // CHECK-NEXT: affine.for %[[I6:.*]] = 0 to 5 { - // CHECK-NEXT: %[[C0:.*]] = constant 0 : index - // CHECK-NEXT: %[[C1:.*]] = constant 1 : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D0]]] // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index -- 2.7.4