#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ScopedPrinter.h"
///
/// Note that ops in `ops` could be erased as a result of folding, becoming
/// dead, or via pattern rewrites. The return value indicates convergence.
- LogicalResult simplifyLocally(ArrayRef<Operation *> op,
- bool *changed = nullptr);
+ ///
+ /// All `ops` that survived the rewrite are stored in `surviving`.
+ LogicalResult
+ simplifyLocally(ArrayRef<Operation *> ops, bool *changed = nullptr,
+ llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr);
void addToWorklist(Operation *op) override {
if (strictMode == GreedyRewriteStrictness::AnyOp ||
void notifyOperationRemoved(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationRemoved(op);
+ if (survivingOps)
+ survivingOps->erase(op);
if (strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
}
/// depending on `strictMode`. This set is not maintained when `strictMode`
/// is GreedyRewriteStrictness::AnyOp.
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+
+ /// An optional set of ops that survived the rewrite. This set is populated
+ /// at the beginning of `simplifyLocally` with the inititally provided list
+ /// of ops.
+ llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr;
};
} // namespace
-LogicalResult
-MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
- bool *changed) {
+LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
+ ArrayRef<Operation *> ops, bool *changed,
+ llvm::SmallDenseSet<Operation *, 4> *surviving) {
+ auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; });
+ if (surviving) {
+ survivingOps = surviving;
+ survivingOps->clear();
+ survivingOps->insert(ops.begin(), ops.end());
+ }
+
if (strictMode != GreedyRewriteStrictness::AnyOp) {
strictModeFilteredOps.clear();
strictModeFilteredOps.insert(ops.begin(), ops.end());
LogicalResult mlir::applyOpPatternsAndFold(
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, bool *changed) {
+ GreedyRewriteStrictness strictMode, bool *changed, bool *allErased) {
if (ops.empty()) {
if (changed)
*changed = false;
+ if (allErased)
+ *allErased = true;
return success();
}
// Start the pattern driver.
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
strictMode);
- return driver.simplifyLocally(ops, changed);
+ llvm::SmallDenseSet<Operation *, 4> surviving;
+ LogicalResult converged =
+ driver.simplifyLocally(ops, changed, allErased ? &surviving : nullptr);
+ if (allErased)
+ *allErased = surviving.empty();
+ return converged;
}
// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EX
// CHECK-EN-LABEL: func @test_erase
+// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: test.arg0
// CHECK-EN: test.arg1
// CHECK-EN-NOT: test.erase_op
// -----
// CHECK-EN-LABEL: func @test_insert_same_op
+// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
// CHECK-EN: "test.insert_same_op"() {skip = true}
// CHECK-EN: "test.insert_same_op"() {skip = true}
func.func @test_insert_same_op() {
// -----
// CHECK-EN-LABEL: func @test_replace_with_new_op
+// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: %[[n:.*]] = "test.new_op"
// CHECK-EN: "test.dummy_user"(%[[n]])
// CHECK-EN: "test.dummy_user"(%[[n]])
// -----
// CHECK-EN-LABEL: func @test_replace_with_erase_op
+// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: test.replace_with_new_op
// CHECK-EN-NOT: test.erase_op
// CHECK-EX-LABEL: func @test_replace_with_erase_op
+// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EX-NOT: test.replace_with_new_op
// CHECK-EX: test.erase_op
func.func @test_replace_with_erase_op() {
}
void runOnOperation() override {
- mlir::RewritePatternSet patterns(&getContext());
- patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(&getContext());
+ MLIRContext *ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+ patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(ctx);
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
// Check if these transformations introduce visiting of operations that
// are not in the `ops` set (The new created ops are valid). An invalid
// operation will trigger the assertion while processing.
- (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode);
+ bool changed = false;
+ bool allErased = false;
+ (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
+ &changed, &allErased);
+ Builder b(ctx);
+ getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
+ getOperation()->setAttr("pattern_driver_all_erased",
+ b.getBoolAttr(allErased));
}
Option<std::string> strictMode{