[MLIR] Add PatternRewriter::mergeBlockBefore() to merge a block in the middle of...
authorRahul Joshi <jurahul@google.com>
Wed, 19 Aug 2020 23:07:42 +0000 (16:07 -0700)
committerRahul Joshi <jurahul@google.com>
Wed, 19 Aug 2020 23:24:59 +0000 (16:24 -0700)
- This utility to merge a block anywhere into another one can help inline single
  block regions into other blocks.
- Modified patterns test to use the new function.

Differential Revision: https://reviews.llvm.org/D86251

mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/Block.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp

index 859fa17..ca25230 100644 (file)
@@ -207,7 +207,10 @@ public:
   }
 
   /// Return true if this block has no predecessors.
-  bool hasNoPredecessors();
+  bool hasNoPredecessors() { return pred_begin() == pred_end(); }
+
+  /// Returns true if this blocks has no successors.
+  bool hasNoSuccessors() { return succ_begin() == succ_end(); }
 
   /// If this block has exactly one predecessor, return it.  Otherwise, return
   /// null.
index f1c7c39..46dd964 100644 (file)
@@ -326,6 +326,11 @@ public:
   virtual void mergeBlocks(Block *source, Block *dest,
                            ValueRange argValues = llvm::None);
 
+  // Merge the operations of block 'source' before the operation 'op'. Source
+  // block should not have existing predecessors or successors.
+  void mergeBlockBefore(Block *source, Operation *op,
+                        ValueRange argValues = llvm::None);
+
   /// Split the operations starting at "before" (inclusive) out of the given
   /// block into a new block, and return it.
   virtual Block *splitBlock(Block *block, Block::iterator before);
index 8013132..71f368c 100644 (file)
@@ -201,9 +201,6 @@ Operation *Block::getTerminator() {
   return &back();
 }
 
-/// Return true if this block has no predecessors.
-bool Block::hasNoPredecessors() { return pred_begin() == pred_end(); }
-
 // Indexed successor access.
 unsigned Block::getNumSuccessors() {
   return empty() ? 0 : back().getNumSuccessors();
index e05d234..a26bc63 100644 (file)
@@ -128,6 +128,28 @@ void PatternRewriter::mergeBlocks(Block *source, Block *dest,
   source->erase();
 }
 
+// Merge the operations of block 'source' before the operation 'op'. Source
+// block should not have existing predecessors or successors.
+void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
+                                       ValueRange argValues) {
+  assert(source->hasNoPredecessors() &&
+         "expected 'source' to have no predecessors");
+  assert(source->hasNoSuccessors() &&
+         "expected 'source' to have no successors");
+
+  // Split the block containing 'op' into two, one containg all operations
+  // before 'op' (prologue) and another (epilogue) containing 'op' and all
+  // operations after it.
+  Block *prologue = op->getBlock();
+  Block *epilogue = splitBlock(prologue, op->getIterator());
+
+  // Merge the source block at the end of the prologue.
+  mergeBlocks(source, prologue, argValues);
+
+  // Merge the epilogue at the end the prologue.
+  mergeBlocks(epilogue, prologue);
+}
+
 /// Split the operations starting at "before" (inclusive) out of the given
 /// block into a new block, and return it.
 Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
index be5d799..cee8798 100644 (file)
@@ -893,16 +893,12 @@ struct TestMergeSingleBlockOps
         op.getParentOfType<SingleBlockImplicitTerminatorOp>();
     if (!parentOp)
       return failure();
-    Block &parentBlock = parentOp.region().front();
     Block &innerBlock = op.region().front();
     TerminatorOp innerTerminator =
         cast<TerminatorOp>(innerBlock.getTerminator());
-    Block *parentPrologue =
-        rewriter.splitBlock(&parentBlock, Block::iterator(op));
+    rewriter.mergeBlockBefore(&innerBlock, op);
     rewriter.eraseOp(innerTerminator);
-    rewriter.mergeBlocks(&innerBlock, &parentBlock, {});
     rewriter.eraseOp(op);
-    rewriter.mergeBlocks(parentPrologue, &parentBlock, {});
     rewriter.updateRootInPlace(op, [] {});
     return success();
   }