From d75a611afbc7c5f8c343e0398dd2b506684e506b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 19 Mar 2021 16:19:23 -0700 Subject: [PATCH] [mlir] Update `simplifyRegions` to use RewriterBase for erasure notifications This allows for notifying callers when operations/blocks get erased, which is especially useful for the greedy pattern driver. The current greedy pattern driver "throws away" all information on constants in the operation folder because it doesn't know if they get erased or not. By passing in RewriterBase, we can directly track this and prevent the need for the pattern driver to rediscover all of the existing constants. In some situations this cuts the compile time of the canonicalizer in half. Differential Revision: https://reviews.llvm.org/D98755 --- mlir/include/mlir/Transforms/RegionUtils.h | 7 +++- .../Utils/GreedyPatternRewriteDriver.cpp | 7 +--- mlir/lib/Transforms/Utils/RegionUtils.cpp | 46 +++++++++++++--------- mlir/test/Dialect/SCF/canonicalize.mlir | 2 +- 4 files changed, 35 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 72c2f51..c2124d8 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -15,6 +15,7 @@ #include "llvm/ADT/SetVector.h" namespace mlir { +class RewriterBase; /// Check if all values in the provided range are defined above the `limit` /// region. That is, if they are defined in a region that is a proper ancestor @@ -53,8 +54,10 @@ void getUsedValuesDefinedAbove(MutableArrayRef regions, /// Run a set of structural simplifications over the given regions. This /// includes transformations like unreachable block elimination, dead argument /// elimination, as well as some other DCE. This function returns success if any -/// of the regions were simplified, failure otherwise. -LogicalResult simplifyRegions(MutableArrayRef regions); +/// of the regions were simplified, failure otherwise. The provided rewriter is +/// used to notify callers of operation and block deletion. +LogicalResult simplifyRegions(RewriterBase &rewriter, + MutableArrayRef regions); } // namespace mlir diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 9ed3b35..922fbb1 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -114,7 +114,7 @@ private: // TODO: This is based on the fact that zero use operations // may be deleted, and that single use values often have more // canonicalization opportunities. - if (!operand.use_empty() && !operand.hasOneUse()) + if (!operand || (!operand.use_empty() && !operand.hasOneUse())) continue; if (auto *defInst = operand.getDefiningOp()) addToWorklist(defInst); @@ -202,10 +202,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, // After applying patterns, make sure that the CFG of each of the regions is // kept up to date. - if (succeeded(simplifyRegions(regions))) { - folder.clear(); - changed = true; - } + changed |= succeeded(simplifyRegions(*this, regions)); } while (changed && ++i < maxIterations); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. return !changed; diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 21d0ff5..47635c3 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -9,6 +9,7 @@ #include "mlir/Transforms/RegionUtils.h" #include "mlir/IR/Block.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/RegionGraphTraits.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -75,7 +76,8 @@ void mlir::getUsedValuesDefinedAbove(MutableArrayRef regions, /// Erase the unreachable blocks within the provided regions. Returns success /// if any blocks were erased, failure otherwise. // TODO: We could likely merge this with the DCE algorithm below. -static LogicalResult eraseUnreachableBlocks(MutableArrayRef regions) { +static LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, + MutableArrayRef regions) { // Set of blocks found to be reachable within a given region. llvm::df_iterator_default_set reachable; // If any blocks were found to be dead. @@ -108,7 +110,7 @@ static LogicalResult eraseUnreachableBlocks(MutableArrayRef regions) { for (Block &block : llvm::make_early_inc_range(*region)) { if (!reachable.count(&block)) { block.dropAllDefinedValueUses(); - block.erase(); + rewriter.eraseBlock(&block); erasedDeadBlocks = true; continue; } @@ -305,7 +307,8 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator, } } -static LogicalResult deleteDeadness(MutableArrayRef regions, +static LogicalResult deleteDeadness(RewriterBase &rewriter, + MutableArrayRef regions, LiveMap &liveMap) { bool erasedAnything = false; for (Region ®ion : regions) { @@ -324,10 +327,10 @@ static LogicalResult deleteDeadness(MutableArrayRef regions, if (!liveMap.wasProvenLive(&childOp)) { erasedAnything = true; childOp.dropAllUses(); - childOp.erase(); + rewriter.eraseOp(&childOp); } else { - erasedAnything |= - succeeded(deleteDeadness(childOp.getRegions(), liveMap)); + erasedAnything |= succeeded( + deleteDeadness(rewriter, childOp.getRegions(), liveMap)); } } } @@ -359,7 +362,8 @@ static LogicalResult deleteDeadness(MutableArrayRef regions, // // This function returns success if any operations or arguments were deleted, // failure otherwise. -static LogicalResult runRegionDCE(MutableArrayRef regions) { +static LogicalResult runRegionDCE(RewriterBase &rewriter, + MutableArrayRef regions) { LiveMap liveMap; do { liveMap.resetChanged(); @@ -368,7 +372,7 @@ static LogicalResult runRegionDCE(MutableArrayRef regions) { propagateLiveness(region, liveMap); } while (liveMap.hasChanged()); - return deleteDeadness(regions, liveMap); + return deleteDeadness(rewriter, regions, liveMap); } //===----------------------------------------------------------------------===// @@ -456,7 +460,7 @@ public: LogicalResult addToCluster(BlockEquivalenceData &blockData); /// Try to merge all of the blocks within this cluster into the leader block. - LogicalResult merge(); + LogicalResult merge(RewriterBase &rewriter); private: /// The equivalence data for the leader of the cluster. @@ -550,7 +554,7 @@ static bool ableToUpdatePredOperands(Block *block) { return true; } -LogicalResult BlockMergeCluster::merge() { +LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { // Don't consider clusters that don't have blocks to merge. if (blocksToMerge.empty()) return failure(); @@ -613,7 +617,7 @@ LogicalResult BlockMergeCluster::merge() { // Replace all uses of the merged blocks with the leader and erase them. for (Block *block : blocksToMerge) { block->replaceAllUsesWith(leaderBlock); - block->erase(); + rewriter.eraseBlock(block); } return success(); } @@ -621,7 +625,8 @@ LogicalResult BlockMergeCluster::merge() { /// Identify identical blocks within the given region and merge them, inserting /// new block arguments as necessary. Returns success if any blocks were merged, /// failure otherwise. -static LogicalResult mergeIdenticalBlocks(Region ®ion) { +static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, + Region ®ion) { if (region.empty() || llvm::hasSingleElement(region)) return failure(); @@ -659,7 +664,7 @@ static LogicalResult mergeIdenticalBlocks(Region ®ion) { clusters.emplace_back(std::move(data)); } for (auto &cluster : clusters) - mergedAnyBlocks |= succeeded(cluster.merge()); + mergedAnyBlocks |= succeeded(cluster.merge(rewriter)); } return success(mergedAnyBlocks); @@ -667,14 +672,15 @@ static LogicalResult mergeIdenticalBlocks(Region ®ion) { /// Identify identical blocks within the given regions and merge them, inserting /// new block arguments as necessary. -static LogicalResult mergeIdenticalBlocks(MutableArrayRef regions) { +static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, + MutableArrayRef regions) { llvm::SmallSetVector worklist; for (auto ®ion : regions) worklist.insert(®ion); bool anyChanged = false; while (!worklist.empty()) { Region *region = worklist.pop_back_val(); - if (succeeded(mergeIdenticalBlocks(*region))) { + if (succeeded(mergeIdenticalBlocks(rewriter, *region))) { worklist.insert(region); anyChanged = true; } @@ -697,10 +703,12 @@ static LogicalResult mergeIdenticalBlocks(MutableArrayRef regions) { /// includes transformations like unreachable block elimination, dead argument /// elimination, as well as some other DCE. This function returns success if any /// of the regions were simplified, failure otherwise. -LogicalResult mlir::simplifyRegions(MutableArrayRef regions) { - bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions)); - bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions)); - bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions)); +LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, + MutableArrayRef regions) { + bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions)); + bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions)); + bool mergedIdenticalBlocks = + succeeded(mergeIdenticalBlocks(rewriter, regions)); return success(eliminatedBlocks || eliminatedOpsOrArgs || mergedIdenticalBlocks); } diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 2824fde..0a1558f 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -21,12 +21,12 @@ func @single_iteration(%A: memref) { // CHECK-LABEL: func @single_iteration( // CHECK-SAME: [[ARG0:%.*]]: memref) { +// CHECK: [[C42:%.*]] = constant 42 : i32 // CHECK: [[C0:%.*]] = constant 0 : index // CHECK: [[C2:%.*]] = constant 2 : index // CHECK: [[C3:%.*]] = constant 3 : index // CHECK: [[C6:%.*]] = constant 6 : index // CHECK: [[C7:%.*]] = constant 7 : index -// CHECK: [[C42:%.*]] = constant 42 : i32 // CHECK: scf.parallel ([[V0:%.*]]) = ([[C3]]) to ([[C6]]) step ([[C2]]) { // CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V0]], [[C7]]] : memref // CHECK: scf.yield -- 2.7.4