namespace mlir {
-class FuncOp;
class PatternRewriter;
//===----------------------------------------------------------------------===//
OwningRewritePatternList patterns;
};
-/// Rewrite the specified function by repeatedly applying the highest benefit
-/// patterns in a greedy work-list driven manner. Return true if no more
-/// patterns can be matched in the result function.
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return true if no more patterns can be matched in
+/// the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself.
///
-bool applyPatternsGreedily(FuncOp fn, OwningRewritePatternList &&patterns);
+bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
/// Helper class to create a list of rewrite patterns given a list of their
/// types and a list of attributes perfect-forwarded to each of the conversion
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/FoldUtils.h"
static llvm::cl::opt<unsigned> maxPatternMatchIterations(
"mlir-max-pattern-match-iterations",
- llvm::cl::desc(
- "Max number of iterations scanning the functions for pattern match"),
+ llvm::cl::desc("Max number of iterations scanning for pattern match"),
llvm::cl::init(10));
namespace {
/// Perform the rewrites. Return true if the rewrite converges in
/// `maxIterations`.
- bool simplifyFunction(Region *region, int maxIterations);
+ bool simplify(Operation *op, int maxIterations);
void addToWorklist(Operation *op) {
// Check to see if the worklist already contains this op.
/// The worklist for this transformation keeps track of the operations that
/// need to be revisited, plus their index in the worklist. This allows us to
- /// efficiently remove operations from the worklist when they are erased from
- /// the function, even if they aren't the root of a pattern.
+ /// efficiently remove operations from the worklist when they are erased, even
+ /// if they aren't the root of a pattern.
std::vector<Operation *> worklist;
DenseMap<Operation *, unsigned> worklistMap;
} // end anonymous namespace
/// Perform the rewrites.
-bool GreedyPatternRewriteDriver::simplifyFunction(Region *region,
- int maxIterations) {
+bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) {
// Add the given operation to the worklist.
auto collectOps = [this](Operation *op) { addToWorklist(op); };
bool changed = false;
int i = 0;
do {
- // Add all operations to the worklist.
- region->walk(collectOps);
+ // Add all nested operations to the worklist.
+ for (auto ®ion : op->getRegions())
+ region.walk(collectOps);
// These are scratch vectors used in the folding loop below.
SmallVector<Value *, 8> originalOperands, resultValues;
return !changed;
}
-/// Rewrite the specified function by repeatedly applying the highest benefit
-/// patterns in a greedy work-list driven manner. Return true if no more
-/// patterns can be matched in the result function.
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return true if no more patterns can be matched in
+/// the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself.
///
-bool mlir::applyPatternsGreedily(FuncOp fn,
+bool mlir::applyPatternsGreedily(Operation *op,
OwningRewritePatternList &&patterns) {
- GreedyPatternRewriteDriver driver(fn.getContext(), std::move(patterns));
- bool converged =
- driver.simplifyFunction(&fn.getBody(), maxPatternMatchIterations);
+ // The top-level operation must be known to be isolated from above to
+ // prevent performing canonicalizations on operations defined at or above
+ // the region containing 'op'.
+ if (!op->isKnownIsolatedFromAbove())
+ return false;
+
+ GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns));
+ bool converged = driver.simplify(op, maxPatternMatchIterations);
LLVM_DEBUG(if (!converged) {
- llvm::dbgs()
- << "The pattern rewrite doesn't converge after scanning the function "
- << maxPatternMatchIterations << " times";
+ llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+ << maxPatternMatchIterations << " times";
});
return converged;
}