From 6a501e3d1b6725bf63f9cf053f7674c042794d49 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Mon, 26 Aug 2019 09:44:09 -0700 Subject: [PATCH] Support folding of ops with inner ops in GreedyPatternRewriteDriver. This fixes a bug when folding ops with inner ops and inner ops are still being visited. PiperOrigin-RevId: 265475780 --- .../Utils/GreedyPatternRewriteDriver.cpp | 15 ++++++------- mlir/test/Transforms/test-canonicalize.mlir | 25 +++++++++++++++++++--- mlir/test/lib/TestDialect/TestDialect.cpp | 13 +++++++---- mlir/test/lib/TestDialect/TestOps.td | 16 +++++++++++++- 4 files changed, 53 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 3fb1998..ddb92a5 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -96,8 +96,6 @@ protected: // worklist anymore because we'd get dangling references to it. void notifyOperationRemoved(Operation *op) override { addToWorklist(op->getOperands()); - removeFromWorklist(op); - folder.notifyRemoval(op); op->walk([this](Operation *operation) { removeFromWorklist(operation); folder.notifyRemoval(operation); @@ -174,17 +172,16 @@ bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) { // If the operation has no side effects, and no users, then it is // trivially dead - remove it. if (op->hasNoSideEffect() && op->use_empty()) { - // Be careful to update bookkeeping in OperationFolder to keep - // consistency if this is a constant op. - folder.notifyRemoval(op); + // Be careful to update bookkeeping. + notifyOperationRemoved(op); op->erase(); continue; } // Collects all the operands and result uses of the given `op` into work - // list. + // list. Also remove `op` and nested ops from worklist. originalOperands.assign(op->operand_begin(), op->operand_end()); - auto collectOperandsAndUses = [&](Operation *op) { + auto preReplaceAction = [&](Operation *op) { // Add the operands to the worklist for visitation. addToWorklist(originalOperands); @@ -193,10 +190,12 @@ bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) { for (auto *result : op->getResults()) for (auto *operand : result->getUsers()) addToWorklist(operand); + + notifyOperationRemoved(op); }; // Try to fold this op. - if (succeeded(folder.tryToFold(op, collectOps, collectOperandsAndUses))) { + if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) { changed |= true; continue; } diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir index a98be27..7bae125 100644 --- a/mlir/test/Transforms/test-canonicalize.mlir +++ b/mlir/test/Transforms/test-canonicalize.mlir @@ -1,10 +1,29 @@ // RUN: mlir-opt %s -canonicalize | FileCheck %s -// CHECK-LABEL: func @remove_op_with_inner_ops -func @remove_op_with_inner_ops() { +// CHECK-LABEL: func @remove_op_with_inner_ops_pattern +func @remove_op_with_inner_ops_pattern() { // CHECK-NEXT: return - "test.op_with_region"() ({ + "test.op_with_region_pattern"() ({ "foo.op_with_region_terminator"() : () -> () }) : () -> () return } + +// CHECK-LABEL: func @remove_op_with_inner_ops_fold_no_side_effect +func @remove_op_with_inner_ops_fold_no_side_effect() { + // CHECK-NEXT: return + "test.op_with_region_fold_no_side_effect"() ({ + "foo.op_with_region_terminator"() : () -> () + }) : () -> () + return +} + +// CHECK-LABEL: func @remove_op_with_inner_ops_fold +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: i32) +func @remove_op_with_inner_ops_fold(%arg0 : i32) -> (i32) { + // CHECK-NEXT: return %[[ARG_0]] + %0 = "test.op_with_region_fold"(%arg0) ({ + "foo.op_with_region_terminator"() : () -> () + }) : (i32) -> (i32) + return %0 : i32 +} diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 8b44b6c..af5c5c8 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -82,10 +82,11 @@ static ParseResult parsePolyForOp(OpAsmParser *parser, OperationState *result) { //===----------------------------------------------------------------------===// namespace { -struct TestRemoveOpWithInnerOps : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct TestRemoveOpWithInnerOps + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TestOpWithRegion op, + PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op, PatternRewriter &rewriter) const override { rewriter.replaceOp(op, llvm::None); return matchSuccess(); @@ -93,11 +94,15 @@ struct TestRemoveOpWithInnerOps : public OpRewritePattern { }; } // end anonymous namespace -void TestOpWithRegion::getCanonicalizationPatterns( +void TestOpWithRegionPattern::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } +OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { + return operand(); +} + // Static initialization for Test dialect registration. static mlir::DialectRegistration testDialect; diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 2926930..fd2d235 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -338,11 +338,25 @@ def : Pat<(OpAllAttrConstraint1 (OpAllAttrConstraint2 $attr)>; // Op for testing RewritePattern removing op with inner ops. -def TestOpWithRegion : TEST_Op<"op_with_region"> { +def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> { let regions = (region SizedRegion<1>:$region); let hasCanonicalizer = 1; } +// Op for testing trivial removal via folding of op with inner ops and no uses. +def TestOpWithRegionFoldNoSideEffect : TEST_Op< + "op_with_region_fold_no_side_effect", [NoSideEffect]> { + let regions = (region SizedRegion<1>:$region); +} + +// Op for testing folding of outer op with inner ops. +def TestOpWithRegionFold : TEST_Op<"op_with_region_fold"> { + let arguments = (ins I32:$operand); + let results = (outs I32:$result); + let regions = (region SizedRegion<1>:$region); + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // Test Patterns (Symbol Binding) -- 2.7.4