From 93377888ae89560ba6d3976e2762d3d4724c4dfd Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Sun, 11 Oct 2020 10:40:28 +0200 Subject: [PATCH] [mlir] add scf.if op canonicalization pattern that removes unused results The patch adds a canonicalization pattern that removes the unused results of scf.if operation. As a result, cse may remove unused computations in the then and else regions of the scf.if operation. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D89029 --- mlir/include/mlir/Dialect/SCF/SCFOps.td | 2 + mlir/lib/Dialect/SCF/SCF.cpp | 61 ++++++++++++++++++++++++ mlir/test/Dialect/SCF/canonicalize.mlir | 84 +++++++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+) diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td index d7ff8b6..476898a 100644 --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -262,6 +262,8 @@ def IfOp : SCF_Op<"if", : OpBuilder::atBlockEnd(body, listener); } }]; + + let hasCanonicalizer = 1; } def ParallelOp : SCF_Op<"parallel", diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index e36ffc2..f25ccc4 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -508,6 +508,67 @@ void IfOp::getSuccessorRegions(Optional index, regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion)); } +namespace { +// Pattern to remove unused IfOp results. +struct RemoveUnusedResults : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + void transferBody(Block *source, Block *dest, ArrayRef usedResults, + PatternRewriter &rewriter) const { + // Move all operations to the destination block. + rewriter.mergeBlocks(source, dest); + // Replace the yield op by one that returns only the used values. + auto yieldOp = cast(dest->getTerminator()); + SmallVector usedOperands; + llvm::transform(usedResults, std::back_inserter(usedOperands), + [&](OpResult result) { + return yieldOp.getOperand(result.getResultNumber()); + }); + rewriter.updateRootInPlace( + yieldOp, [&]() { yieldOp.getOperation()->setOperands(usedOperands); }); + } + + LogicalResult matchAndRewrite(IfOp op, + PatternRewriter &rewriter) const override { + // Compute the list of used results. + SmallVector usedResults; + llvm::copy_if(op.getResults(), std::back_inserter(usedResults), + [](OpResult result) { return !result.use_empty(); }); + + // Replace the operation if only a subset of its results have uses. + if (usedResults.size() == op.getNumResults()) + return failure(); + + // Compute the result types of the replacement operation. + SmallVector newTypes; + llvm::transform(usedResults, std::back_inserter(newTypes), + [](OpResult result) { return result.getType(); }); + + // Create a replacement operation with empty then and else regions. + auto emptyBuilder = [](OpBuilder &, Location) {}; + auto newOp = rewriter.create(op.getLoc(), newTypes, op.condition(), + emptyBuilder, emptyBuilder); + + // Move the bodies and replace the terminators (note there is a then and + // an else region since the operation returns results). + transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter); + transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter); + + // Replace the operation by the new one. + SmallVector repResults(op.getNumResults()); + for (auto en : llvm::enumerate(usedResults)) + repResults[en.value().getResultNumber()] = newOp.getResult(en.index()); + rewriter.replaceOp(op, repResults); + return success(); + } +}; +} // namespace + +void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index fc98dab..a967860 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -53,3 +53,87 @@ func @no_iteration(%A: memref) { // CHECK: scf.yield // CHECK: } // CHECK: return + +// ----- + +func @one_unused() -> (index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %true = constant true + %0, %1 = scf.if %true -> (index, index) { + scf.yield %c0, %c1 : index, index + } else { + scf.yield %c0, %c1 : index, index + } + return %1 : index +} + +// CHECK-LABEL: func @one_unused +// CHECK: [[C0:%.*]] = constant 1 : index +// CHECK: [[C1:%.*]] = constant true +// CHECK: [[V0:%.*]] = scf.if [[C1]] -> (index) { +// CHECK: scf.yield [[C0]] : index +// CHECK: } else +// CHECK: scf.yield [[C0]] : index +// CHECK: } +// CHECK: return [[V0]] : index + +// ----- + +func @nested_unused() -> (index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %true = constant true + %0, %1 = scf.if %true -> (index, index) { + %2, %3 = scf.if %true -> (index, index) { + scf.yield %c0, %c1 : index, index + } else { + scf.yield %c0, %c1 : index, index + } + scf.yield %2, %3 : index, index + } else { + scf.yield %c0, %c1 : index, index + } + return %1 : index +} + +// CHECK-LABEL: func @nested_unused +// CHECK: [[C0:%.*]] = constant 1 : index +// CHECK: [[C1:%.*]] = constant true +// CHECK: [[V0:%.*]] = scf.if [[C1]] -> (index) { +// CHECK: [[V1:%.*]] = scf.if [[C1]] -> (index) { +// CHECK: scf.yield [[C0]] : index +// CHECK: } else +// CHECK: scf.yield [[C0]] : index +// CHECK: } +// CHECK: scf.yield [[V1]] : index +// CHECK: } else +// CHECK: scf.yield [[C0]] : index +// CHECK: } +// CHECK: return [[V0]] : index + +// ----- + +func @side_effect() {} +func @all_unused() { + %c0 = constant 0 : index + %c1 = constant 1 : index + %true = constant true + %0, %1 = scf.if %true -> (index, index) { + call @side_effect() : () -> () + scf.yield %c0, %c1 : index, index + } else { + call @side_effect() : () -> () + scf.yield %c0, %c1 : index, index + } + return +} + +// CHECK-LABEL: func @all_unused +// CHECK: [[C1:%.*]] = constant true +// CHECK: scf.if [[C1]] { +// CHECK: call @side_effect() : () -> () +// CHECK: } else +// CHECK: call @side_effect() : () -> () +// CHECK: } +// CHECK: return -- 2.7.4