#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
/// Simplify the operations within the given regions.
bool simplify(MutableArrayRef<Region> regions);
- /// Add the given operation to the worklist. Parent ops may or may not be
- /// added to the worklist, depending on the type of rewrite driver. By
- /// default, parent ops are added.
- virtual void addToWorklist(Operation *op);
+ /// Add the given operation and its ancestors to the worklist.
+ void addToWorklist(Operation *op);
/// Pop the next operation from the worklist.
Operation *popFromWorklist();
protected:
/// Add the given operation to the worklist.
- void addSingleOpToWorklist(Operation *op);
+ virtual void addSingleOpToWorklist(Operation *op);
// Implement the hook for inserting operations, and make sure that newly
// inserted ops are added to the worklist for processing.
/// Configuration information for how to simplify.
GreedyRewriteConfig config;
-private:
/// Only ops within this scope are simplified. This is set at the beginning
- /// of `simplify()` to the current scope the rewriter operates on.
+ /// of `simplify()` and `simplifyLocally()` to the current scope the rewriter
+ /// operates on.
DenseSet<Region *> scope;
+private:
#ifndef NDEBUG
/// A logger used to emit information during the application process.
llvm::ScopedPrinter logger{llvm::dbgs()};
}
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
+ scope.clear();
for (Region &r : regions)
scope.insert(&r);
strictMode(strictMode) {}
/// Performs the specified rewrites on `ops` while also trying to fold these
- /// ops. `strictMode` controls which other ops are simplified.
+ /// ops. `strictMode` controls which other ops are simplified. Only ops
+ /// within the given scope region are added to the worklist. If no scope is
+ /// specified, it assumed to be closest common region of all `ops`.
///
/// Note that ops in `ops` could be erased as a result of folding, becoming
/// dead, or via pattern rewrites. The return value indicates convergence.
/// All `ops` that survived the rewrite are stored in `surviving`.
LogicalResult
simplifyLocally(ArrayRef<Operation *> ops, bool *changed = nullptr,
- llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr);
+ llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr,
+ Region *scope = nullptr);
- void addToWorklist(Operation *op) override {
+protected:
+ void addSingleOpToWorklist(Operation *op) override {
if (strictMode == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op))
GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
ArrayRef<Operation *> ops, bool *changed,
- llvm::SmallDenseSet<Operation *, 4> *surviving) {
+ llvm::SmallDenseSet<Operation *, 4> *surviving, Region *scope) {
auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; });
if (surviving) {
survivingOps = surviving;
strictModeFilteredOps.insert(ops.begin(), ops.end());
}
+ assert(scope && "scope is mandatory");
+ this->scope.clear();
+ this->scope.insert(scope);
+
if (changed)
*changed = false;
worklist.clear();
worklistMap.clear();
for (Operation *op : ops)
- addToWorklist(op);
+ addSingleOpToWorklist(op);
// These are scratch vectors used in the folding loop below.
SmallVector<Value, 8> originalOperands, resultValues;
return converged;
}
-LogicalResult mlir::applyOpPatternsAndFold(
- ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, bool *changed, bool *allErased) {
+/// Find the region that is the closest common ancestor of all given ops.
+static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
+ assert(!ops.empty() && "expected at least one op");
+ // Fast path in case there is only one op.
+ if (ops.size() == 1)
+ return ops.front()->getParentRegion();
+
+ Region *region = ops.front()->getParentRegion();
+ ops = ops.drop_front();
+ int sz = ops.size();
+ llvm::BitVector remainingOps(sz, true);
+ do {
+ int pos = -1;
+ // Iterate over all remaining ops.
+ while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
+ // Is this op contained in `region`?
+ if (region->findAncestorOpInRegion(*ops[pos]))
+ remainingOps.reset(pos);
+ }
+ if (remainingOps.none())
+ break;
+ } while ((region = region->getParentRegion()));
+ assert(region && "could not find common parent region");
+ return region;
+}
+
+LogicalResult
+mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+ const FrozenRewritePatternSet &patterns,
+ GreedyRewriteStrictness strictMode, bool *changed,
+ bool *allErased, Region *scope) {
if (ops.empty()) {
if (changed)
*changed = false;
return success();
}
+ if (!scope) {
+ // Compute scope if none was provided.
+ scope = findCommonAncestor(ops);
+ } else {
+ // If a scope was provided, make sure that all ops are in scope.
+#ifndef NDEBUG
+ bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
+ return static_cast<bool>(scope->findAncestorOpInRegion(*op));
+ });
+ assert(allOpsInScope && "ops must be within the specified scope");
+#endif // NDEBUG
+ }
+
// Start the pattern driver.
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
strictMode);
llvm::SmallDenseSet<Operation *, 4> surviving;
- LogicalResult converged =
- driver.simplifyLocally(ops, changed, allErased ? &surviving : nullptr);
+ LogicalResult converged = driver.simplifyLocally(
+ ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope);
if (allErased)
*allErased = surviving.empty();
return converged;