From 724a0e2c2d7a5724dd81b00db470ba4bb8b616ca Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 3 Feb 2023 09:44:42 +0100 Subject: [PATCH] [mlir] GreedyPatternRewriteDriver: Ignore scope when rewriting top-level ops Top-level ModuleOps cannot be transformed with the GreedyPatternRewriteDriver since D141945 because they do not have an enclosing region that could be used as a scope. Make the scope optional inside GreedyPatternRewriteDriver, so that top-level ops can be processed when they are on the initial list of ops. Note: This does not allow users to bypass the scoping mechanism by setting `config.scope = nullptr`. Fixes #60462. Differential Revision: https://reviews.llvm.org/D143151 --- .../mlir/Transforms/GreedyPatternRewriteDriver.h | 3 +- .../Utils/GreedyPatternRewriteDriver.cpp | 37 ++++++++++++---------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index d8c17c6..423221d 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -61,7 +61,8 @@ public: static constexpr int64_t kNoLimit = -1; /// Only ops within the scope are added to the worklist. If no scope is - /// specified, the closest enclosing region is used as a scope. + /// specified, the closest enclosing region around the initial list of ops + /// is used as a scope. Region *scope = nullptr; /// Strict mode can restrict the ops that are added to the worklist during diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 4c5868a..997bdc6 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -124,7 +124,6 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config) : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) { - assert(config.scope && "scope is not specified"); worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. @@ -266,19 +265,19 @@ bool GreedyPatternRewriteDriver::processWorklist() { void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { // Gather potential ancestors while looking for a "scope" parent region. SmallVector ancestors; - ancestors.push_back(op); - while (Region *region = op->getParentRegion()) { - if (config.scope == region) { - // All gathered ops are in fact ancestors. - for (Operation *op : ancestors) - addSingleOpToWorklist(op); - break; - } - op = region->getParentOp(); - if (!op) - break; + Region *region = nullptr; + do { ancestors.push_back(op); - } + region = op->getParentRegion(); + if (config.scope == region) { + // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops. + for (Operation *op : ancestors) + addSingleOpToWorklist(op); + return; + } + if (region == nullptr) + return; + } while ((op = region->getParentOp())); } void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { @@ -556,6 +555,9 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef ops, } /// Find the region that is the closest common ancestor of all given ops. +/// +/// Note: This function returns `nullptr` if there is a top-level op among the +/// given list of ops. static Region *findCommonAncestor(ArrayRef ops) { assert(!ops.empty() && "expected at least one op"); // Fast path in case there is only one op. @@ -566,7 +568,7 @@ static Region *findCommonAncestor(ArrayRef ops) { ops = ops.drop_front(); int sz = ops.size(); llvm::BitVector remainingOps(sz, true); - do { + while (region) { int pos = -1; // Iterate over all remaining ops. while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) { @@ -576,8 +578,8 @@ static Region *findCommonAncestor(ArrayRef ops) { } if (remainingOps.none()) break; - } while ((region = region->getParentRegion())); - assert(region && "could not find common parent region"); + region = region->getParentRegion(); + } return region; } @@ -594,7 +596,8 @@ LogicalResult mlir::applyOpPatternsAndFold( // Determine scope of rewrite. if (!config.scope) { - // Compute scope if none was provided. + // Compute scope if none was provided. The scope will remain `nullptr` if + // there is a top-level op among `ops`. config.scope = findCommonAncestor(ops); } else { // If a scope was provided, make sure that all ops are in scope. -- 2.7.4