From 87e345b1bdb76867cc6e9ae59b6dd2633a480d38 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 20 Jan 2023 10:01:27 +0100 Subject: [PATCH] [mlir] GreedyPatternRewriteDriver: Add new strict mode option There are now three options: * `AnyOp` (previously `false`) * `ExistingAndNewOps` (previously `true`) * `ExistingOps`: this one is new. The last option corresponds to what the `applyOpPatternsAndFold(Operation*, ...)` overload is doing. It is now also supported on the `applyOpPatternsAndFold(ArrayRef, ...)` overload. Differential Revision: https://reviews.llvm.org/D141904 --- .../mlir/Transforms/GreedyPatternRewriteDriver.h | 39 +++++++++---- .../Affine/TransformOps/AffineTransformOps.cpp | 3 +- .../Affine/Transforms/AffineDataCopyGeneration.cpp | 3 +- .../Affine/Transforms/SimplifyAffineStructures.cpp | 3 +- .../Utils/GreedyPatternRewriteDriver.cpp | 67 +++++++++------------- .../Transforms/test-strict-pattern-driver.mlir | 46 ++++++++++----- .../test/lib/Dialect/Affine/TestAffineDataCopy.cpp | 3 +- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 25 ++++++-- 8 files changed, 115 insertions(+), 74 deletions(-) diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index aaafebf..72b2475 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -18,6 +18,17 @@ namespace mlir { +/// This enum controls which ops are put on the worklist during a greedy +/// pattern rewrite. +enum class GreedyRewriteStrictness { + /// No restrictions wrt. which ops are processed. + AnyOp, + /// Only pre-existing and newly created ops are processed. + ExistingAndNewOps, + /// Only pre-existing ops are processed. + ExistingOps +}; + /// This class allows control over how the GreedyPatternRewriteDriver works. class GreedyRewriteConfig { public: @@ -88,21 +99,29 @@ LogicalResult applyOpPatternsAndFold(Operation *op, bool *erased = nullptr); /// Applies the specified rewrite patterns on `ops` while also trying to fold -/// these ops as well as any other ops that were in turn created due to such -/// rewrites. Furthermore, any pre-existing ops in the IR outside of `ops` -/// remain completely unmodified if `strict` is set to true. If `strict` is -/// false, other operations that use results of rewritten ops or supply operands -/// to such ops are in turn simplified; any other ops still remain unmodified -/// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a -/// result of folding, becoming dead, or via pattern rewrites. If more far -/// reaching simplification is desired, applyPatternsAndFoldGreedily should be -/// used. +/// these ops. +/// +/// Newly created ops and other pre-existing ops that use results of rewritten +/// ops or supply operands to such ops are simplified, unless such ops are +/// excluded via `strictMode`. Any other ops remain unmodified (i.e., regardless +/// of `strictMode`). +/// +/// * GreedyRewriteStrictness::AnyOp: No ops are excluded. +/// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing and newly +/// created ops are simplified. All other ops are excluded. +/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are +/// simplified. All other ops are excluded. +/// +/// Note that ops in `ops` could be erased as result of folding, becoming dead, +/// or via pattern rewrites. If more far reaching simplification is desired, +/// applyPatternsAndFoldGreedily should be used. /// /// Returns success if the iterative process converged and no more patterns can /// be matched. `changed` is set to true if the IR was modified at all. LogicalResult applyOpPatternsAndFold(ArrayRef ops, const FrozenRewritePatternSet &patterns, - bool strict, bool *changed = nullptr); + GreedyRewriteStrictness strictMode, + bool *changed = nullptr); } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index 282a35b..d516de8 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -132,7 +132,8 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results, FrozenRewritePatternSet frozenPatterns(std::move(patterns)); // Apply the simplification pattern to a fixpoint. if (failed( - applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true))) { + applyOpPatternsAndFold(targets, frozenPatterns, + GreedyRewriteStrictness::ExistingAndNewOps))) { auto diag = emitDefiniteFailure() << "affine.min/max simplification did not converge"; return diag; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index 0d84d38..a9d6f94 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -239,5 +239,6 @@ void AffineDataCopyGeneration::runOnOperation() { AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - (void)applyOpPatternsAndFold(copyOps, frozenPatterns, /*strict=*/true); + (void)applyOpPatternsAndFold(copyOps, frozenPatterns, + GreedyRewriteStrictness::ExistingAndNewOps); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp index bb5b390..6cb0a30 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -105,5 +105,6 @@ void SimplifyAffineStructures::runOnOperation() { if (isa(op)) opsToSimplify.push_back(op); }); - (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, /*strict=*/true); + (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, + GreedyRewriteStrictness::ExistingAndNewOps); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index b7ea592..56a9466 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -575,66 +575,54 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - bool strict) + GreedyRewriteStrictness strictMode) : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()), - strictMode(strict) {} + strictMode(strictMode) {} + /// Performs the specified rewrites on `ops` while also trying to fold these + /// ops. `strictMode` controls which other ops are simplified. + /// + /// Note that ops in `ops` could be erased as a result of folding, becoming + /// dead, or via pattern rewrites. The return value indicates convergence. LogicalResult simplifyLocally(ArrayRef op, bool *changed = nullptr); void addToWorklist(Operation *op) override { - if (!strictMode || strictModeFilteredOps.contains(op)) + if (strictMode == GreedyRewriteStrictness::AnyOp || + strictModeFilteredOps.contains(op)) GreedyPatternRewriteDriver::addSingleOpToWorklist(op); } private: void notifyOperationInserted(Operation *op) override { - if (strictMode) + if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps) strictModeFilteredOps.insert(op); GreedyPatternRewriteDriver::notifyOperationInserted(op); } void notifyOperationRemoved(Operation *op) override { GreedyPatternRewriteDriver::notifyOperationRemoved(op); - if (strictMode) + if (strictMode != GreedyRewriteStrictness::AnyOp) strictModeFilteredOps.erase(op); } - /// If `strictMode` is true, any pre-existing ops outside of - /// `strictModeFilteredOps` remain completely untouched by the rewrite driver. - /// If `strictMode` is false, operations that use results of (or supply - /// operands to) any rewritten ops stemming from the simplification of the - /// provided ops are in turn simplified; any other ops still remain untouched - /// (i.e., regardless of `strictMode`). - bool strictMode = false; - - /// The list of ops we are restricting our rewrites to if `strictMode` is on. - /// These include the supplied set of ops as well as new ops created while - /// rewriting those ops. This set is not maintained when strictMode is off. + /// `strictMode` control which ops are added to the worklist during + /// simplification. + GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; + + /// The list of ops we are restricting our rewrites to. These include the + /// supplied set of ops as well as new ops created while rewriting those ops + /// depending on `strictMode`. This set is not maintained when `strictMode` + /// is GreedyRewriteStrictness::AnyOp. llvm::SmallDenseSet strictModeFilteredOps; }; } // namespace -/// Performs the specified rewrites on `ops` while also trying to fold these ops -/// as well as any other ops that were in turn created due to these rewrite -/// patterns. Any pre-existing ops outside of `ops` remain completely -/// unmodified if `strictMode` is true. If `strictMode` is false, other -/// operations that use results of rewritten ops or supply operands to such ops -/// are in turn simplified; any other ops still remain unmodified (i.e., -/// regardless of `strictMode`). Note that ops in `ops` could be erased as a -/// result of folding, becoming dead, or via pattern rewrites. Returns true if -/// at all any changes happened. -// Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op -// or GreedyPatternRewriteDriver::simplify, this method just iterates until -// the worklist is empty. As our objective is to keep simplification "local", -// there is no strong rationale to re-add all operations into the worklist and -// rerun until an iteration changes nothing. If more widereaching simplification -// is desired, GreedyPatternRewriteDriver should be used. LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, bool *changed) { - if (strictMode) { + if (strictMode != GreedyRewriteStrictness::AnyOp) { strictModeFilteredOps.clear(); strictModeFilteredOps.insert(ops.begin(), ops.end()); } @@ -659,7 +647,8 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, if (op == nullptr) continue; - assert((!strictMode || strictModeFilteredOps.contains(op)) && + assert((strictMode == GreedyRewriteStrictness::AnyOp || + strictModeFilteredOps.contains(op)) && "unexpected op was inserted under strict mode"); // If the operation is trivially dead - remove it. @@ -718,9 +707,6 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, return success(worklist.empty()); } -/// Rewrites only `op` using the supplied canonicalization patterns and -/// folding. `erased` is set to true if the op is erased as a result of being -/// folded, replaced, or dead. LogicalResult mlir::applyOpPatternsAndFold( Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) { // Start the pattern driver. @@ -738,10 +724,9 @@ LogicalResult mlir::applyOpPatternsAndFold( return converged; } -LogicalResult -mlir::applyOpPatternsAndFold(ArrayRef ops, - const FrozenRewritePatternSet &patterns, - bool strict, bool *changed) { +LogicalResult mlir::applyOpPatternsAndFold( + ArrayRef ops, const FrozenRewritePatternSet &patterns, + GreedyRewriteStrictness strictMode, bool *changed) { if (ops.empty()) { if (changed) *changed = false; @@ -750,6 +735,6 @@ mlir::applyOpPatternsAndFold(ArrayRef ops, // Start the pattern driver. MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - strict); + strictMode); return driver.simplifyLocally(ops, changed); } diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir index 8c6eaf3..ad6f6a5 100644 --- a/mlir/test/Transforms/test-strict-pattern-driver.mlir +++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir @@ -1,9 +1,15 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s +// RUN: mlir-opt \ +// RUN: -test-strict-pattern-driver="strictness=ExistingAndNewOps" \ +// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EN -// CHECK-LABEL: func @test_erase -// CHECK: test.arg0 -// CHECK: test.arg1 -// CHECK-NOT: test.erase_op +// RUN: mlir-opt \ +// RUN: -test-strict-pattern-driver="strictness=ExistingOps" \ +// RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EX + +// CHECK-EN-LABEL: func @test_erase +// CHECK-EN: test.arg0 +// CHECK-EN: test.arg1 +// CHECK-EN-NOT: test.erase_op func.func @test_erase() { %0 = "test.arg0"() : () -> (i32) %1 = "test.arg1"() : () -> (i32) @@ -11,18 +17,22 @@ func.func @test_erase() { return } -// CHECK-LABEL: func @test_insert_same_op -// CHECK: "test.insert_same_op"() {skip = true} -// CHECK: "test.insert_same_op"() {skip = true} +// ----- + +// CHECK-EN-LABEL: func @test_insert_same_op +// CHECK-EN: "test.insert_same_op"() {skip = true} +// CHECK-EN: "test.insert_same_op"() {skip = true} func.func @test_insert_same_op() { %0 = "test.insert_same_op"() : () -> (i32) return } -// CHECK-LABEL: func @test_replace_with_new_op -// CHECK: %[[n:.*]] = "test.new_op" -// CHECK: "test.dummy_user"(%[[n]]) -// CHECK: "test.dummy_user"(%[[n]]) +// ----- + +// CHECK-EN-LABEL: func @test_replace_with_new_op +// CHECK-EN: %[[n:.*]] = "test.new_op" +// CHECK-EN: "test.dummy_user"(%[[n]]) +// CHECK-EN: "test.dummy_user"(%[[n]]) func.func @test_replace_with_new_op() { %0 = "test.replace_with_new_op"() : () -> (i32) %1 = "test.dummy_user"(%0) : (i32) -> (i32) @@ -30,9 +40,15 @@ func.func @test_replace_with_new_op() { return } -// CHECK-LABEL: func @test_replace_with_erase_op -// CHECK-NOT: test.replace_with_new_op -// CHECK-NOT: test.erase_op +// ----- + +// CHECK-EN-LABEL: func @test_replace_with_erase_op +// CHECK-EN-NOT: test.replace_with_new_op +// CHECK-EN-NOT: test.erase_op + +// CHECK-EX-LABEL: func @test_replace_with_erase_op +// CHECK-EX-NOT: test.replace_with_new_op +// CHECK-EX: test.erase_op func.func @test_replace_with_erase_op() { "test.replace_with_new_op"() {create_erase_op} : () -> () return diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp index 117f83e..7dc478c 100644 --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -132,7 +132,8 @@ void TestAffineDataCopy::runOnOperation() { AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); } } - (void)applyOpPatternsAndFold(copyOps, std::move(patterns), /*strict=*/true); + (void)applyOpPatternsAndFold(copyOps, std::move(patterns), + GreedyRewriteStrictness::ExistingAndNewOps); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index d3ef160..286d0de 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -244,11 +244,13 @@ public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) TestStrictPatternDriver() = default; - TestStrictPatternDriver(const TestStrictPatternDriver &other) = default; + TestStrictPatternDriver(const TestStrictPatternDriver &other) { + strictMode = other.strictMode; + } StringRef getArgument() const final { return "test-strict-pattern-driver"; } StringRef getDescription() const final { - return "Run strict mode of pattern driver"; + return "Test strict mode of pattern driver"; } void runOnOperation() override { @@ -263,13 +265,28 @@ public: } }); + GreedyRewriteStrictness mode; + if (strictMode == "AnyOp") { + mode = GreedyRewriteStrictness::AnyOp; + } else if (strictMode == "ExistingAndNewOps") { + mode = GreedyRewriteStrictness::ExistingAndNewOps; + } else if (strictMode == "ExistingOps") { + mode = GreedyRewriteStrictness::ExistingOps; + } else { + llvm_unreachable("invalid strictness option"); + } + // Check if these transformations introduce visiting of operations that // are not in the `ops` set (The new created ops are valid). An invalid // operation will trigger the assertion while processing. - (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), - /*strict=*/true); + (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode); } + Option strictMode{ + *this, "strictness", + llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), + llvm::cl::init("AnyOp")}; + private: // New inserted operation is valid for further transformation. class InsertSameOp : public RewritePattern { -- 2.7.4