[mlir] Update `simplifyRegions` to use RewriterBase for erasure notifications
authorRiver Riddle <riddleriver@gmail.com>
Fri, 19 Mar 2021 23:19:23 +0000 (16:19 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Fri, 19 Mar 2021 23:33:54 +0000 (16:33 -0700)
This allows for notifying callers when operations/blocks get erased, which is especially useful for the greedy pattern driver. The current greedy pattern driver "throws away" all information on constants in the operation folder because it doesn't know if they get erased or not. By passing in RewriterBase, we can directly track this and prevent the need for the pattern driver to rediscover all of the existing constants. In some situations this cuts the compile time of the canonicalizer in half.

Differential Revision: https://reviews.llvm.org/D98755

mlir/include/mlir/Transforms/RegionUtils.h
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/lib/Transforms/Utils/RegionUtils.cpp
mlir/test/Dialect/SCF/canonicalize.mlir

index 72c2f51..c2124d8 100644 (file)
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SetVector.h"
 
 namespace mlir {
+class RewriterBase;
 
 /// Check if all values in the provided range are defined above the `limit`
 /// region.  That is, if they are defined in a region that is a proper ancestor
@@ -53,8 +54,10 @@ void getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
-/// of the regions were simplified, failure otherwise.
-LogicalResult simplifyRegions(MutableArrayRef<Region> regions);
+/// of the regions were simplified, failure otherwise. The provided rewriter is
+/// used to notify callers of operation and block deletion.
+LogicalResult simplifyRegions(RewriterBase &rewriter,
+                              MutableArrayRef<Region> regions);
 
 } // namespace mlir
 
index 9ed3b35..922fbb1 100644 (file)
@@ -114,7 +114,7 @@ private:
       // TODO: This is based on the fact that zero use operations
       // may be deleted, and that single use values often have more
       // canonicalization opportunities.
-      if (!operand.use_empty() && !operand.hasOneUse())
+      if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
         continue;
       if (auto *defInst = operand.getDefiningOp())
         addToWorklist(defInst);
@@ -202,10 +202,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
 
     // After applying patterns, make sure that the CFG of each of the regions is
     // kept up to date.
-    if (succeeded(simplifyRegions(regions))) {
-      folder.clear();
-      changed = true;
-    }
+    changed |= succeeded(simplifyRegions(*this, regions));
   } while (changed && ++i < maxIterations);
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
index 21d0ff5..47635c3 100644 (file)
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -75,7 +76,8 @@ void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
 /// Erase the unreachable blocks within the provided regions. Returns success
 /// if any blocks were erased, failure otherwise.
 // TODO: We could likely merge this with the DCE algorithm below.
-static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
+static LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
   // Set of blocks found to be reachable within a given region.
   llvm::df_iterator_default_set<Block *, 16> reachable;
   // If any blocks were found to be dead.
@@ -108,7 +110,7 @@ static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
     for (Block &block : llvm::make_early_inc_range(*region)) {
       if (!reachable.count(&block)) {
         block.dropAllDefinedValueUses();
-        block.erase();
+        rewriter.eraseBlock(&block);
         erasedDeadBlocks = true;
         continue;
       }
@@ -305,7 +307,8 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
   }
 }
 
-static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
+static LogicalResult deleteDeadness(RewriterBase &rewriter,
+                                    MutableArrayRef<Region> regions,
                                     LiveMap &liveMap) {
   bool erasedAnything = false;
   for (Region &region : regions) {
@@ -324,10 +327,10 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
         if (!liveMap.wasProvenLive(&childOp)) {
           erasedAnything = true;
           childOp.dropAllUses();
-          childOp.erase();
+          rewriter.eraseOp(&childOp);
         } else {
-          erasedAnything |=
-              succeeded(deleteDeadness(childOp.getRegions(), liveMap));
+          erasedAnything |= succeeded(
+              deleteDeadness(rewriter, childOp.getRegions(), liveMap));
         }
       }
     }
@@ -359,7 +362,8 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
 //
 // This function returns success if any operations or arguments were deleted,
 // failure otherwise.
-static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
+static LogicalResult runRegionDCE(RewriterBase &rewriter,
+                                  MutableArrayRef<Region> regions) {
   LiveMap liveMap;
   do {
     liveMap.resetChanged();
@@ -368,7 +372,7 @@ static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
       propagateLiveness(region, liveMap);
   } while (liveMap.hasChanged());
 
-  return deleteDeadness(regions, liveMap);
+  return deleteDeadness(rewriter, regions, liveMap);
 }
 
 //===----------------------------------------------------------------------===//
@@ -456,7 +460,7 @@ public:
   LogicalResult addToCluster(BlockEquivalenceData &blockData);
 
   /// Try to merge all of the blocks within this cluster into the leader block.
-  LogicalResult merge();
+  LogicalResult merge(RewriterBase &rewriter);
 
 private:
   /// The equivalence data for the leader of the cluster.
@@ -550,7 +554,7 @@ static bool ableToUpdatePredOperands(Block *block) {
   return true;
 }
 
-LogicalResult BlockMergeCluster::merge() {
+LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
   // Don't consider clusters that don't have blocks to merge.
   if (blocksToMerge.empty())
     return failure();
@@ -613,7 +617,7 @@ LogicalResult BlockMergeCluster::merge() {
   // Replace all uses of the merged blocks with the leader and erase them.
   for (Block *block : blocksToMerge) {
     block->replaceAllUsesWith(leaderBlock);
-    block->erase();
+    rewriter.eraseBlock(block);
   }
   return success();
 }
@@ -621,7 +625,8 @@ LogicalResult BlockMergeCluster::merge() {
 /// Identify identical blocks within the given region and merge them, inserting
 /// new block arguments as necessary. Returns success if any blocks were merged,
 /// failure otherwise.
-static LogicalResult mergeIdenticalBlocks(Region &region) {
+static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
+                                          Region &region) {
   if (region.empty() || llvm::hasSingleElement(region))
     return failure();
 
@@ -659,7 +664,7 @@ static LogicalResult mergeIdenticalBlocks(Region &region) {
         clusters.emplace_back(std::move(data));
     }
     for (auto &cluster : clusters)
-      mergedAnyBlocks |= succeeded(cluster.merge());
+      mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
   }
 
   return success(mergedAnyBlocks);
@@ -667,14 +672,15 @@ static LogicalResult mergeIdenticalBlocks(Region &region) {
 
 /// Identify identical blocks within the given regions and merge them, inserting
 /// new block arguments as necessary.
-static LogicalResult mergeIdenticalBlocks(MutableArrayRef<Region> regions) {
+static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
+                                          MutableArrayRef<Region> regions) {
   llvm::SmallSetVector<Region *, 1> worklist;
   for (auto &region : regions)
     worklist.insert(&region);
   bool anyChanged = false;
   while (!worklist.empty()) {
     Region *region = worklist.pop_back_val();
-    if (succeeded(mergeIdenticalBlocks(*region))) {
+    if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
       worklist.insert(region);
       anyChanged = true;
     }
@@ -697,10 +703,12 @@ static LogicalResult mergeIdenticalBlocks(MutableArrayRef<Region> regions) {
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
 /// of the regions were simplified, failure otherwise.
-LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
-  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions));
-  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions));
-  bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions));
+LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
+                                    MutableArrayRef<Region> regions) {
+  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
+  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
+  bool mergedIdenticalBlocks =
+      succeeded(mergeIdenticalBlocks(rewriter, regions));
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }
index 2824fde..0a1558f 100644 (file)
@@ -21,12 +21,12 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
 
 // CHECK-LABEL:   func @single_iteration(
 // CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?x?xi32>) {
+// CHECK:           [[C42:%.*]] = constant 42 : i32
 // CHECK:           [[C0:%.*]] = constant 0 : index
 // CHECK:           [[C2:%.*]] = constant 2 : index
 // CHECK:           [[C3:%.*]] = constant 3 : index
 // CHECK:           [[C6:%.*]] = constant 6 : index
 // CHECK:           [[C7:%.*]] = constant 7 : index
-// CHECK:           [[C42:%.*]] = constant 42 : i32
 // CHECK:           scf.parallel ([[V0:%.*]]) = ([[C3]]) to ([[C6]]) step ([[C2]]) {
 // CHECK:             memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V0]], [[C7]]] : memref<?x?x?xi32>
 // CHECK:             scf.yield