/// The kind of the block action performed during the rewrite. Actions can be
/// undone if the conversion fails.
- enum class BlockActionKind { Create, Move, Split, TypeConversion };
+ enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };
/// Original position of the given block in its parent region. We cannot use
/// a region iterator because it could have been invalidated by other region
static BlockAction getCreate(Block *block) {
return {BlockActionKind::Create, block, {}};
}
+ static BlockAction getErase(Block *block, BlockPosition originalPos) {
+ return {BlockActionKind::Erase, block, {originalPos}};
+ }
static BlockAction getMove(Block *block, BlockPosition originalPos) {
return {BlockActionKind::Move, block, {originalPos}};
}
Block *block;
union {
- // In use if kind == BlockActionKind::Move and contains a pointer to the
- // region that originally contained the block as well as the position of
- // the block in that region.
+ // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
+ // contains a pointer to the region that originally contained the block as
+ // well as the position of the block in that region.
BlockPosition originalPosition;
// In use if kind == BlockActionKind::Split and contains a pointer to the
// block that was split into two parts.
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state);
+ /// Erase any blocks that were unlinked from their regions and stored in block
+ /// actions.
+ void eraseDanglingBlocks();
+
/// Undo the block actions (motions, splits) one by one in reverse order until
/// "numActionsToKeep" actions remains.
void undoBlockActions(unsigned numActionsToKeep = 0);
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues);
+ /// Notifies that a block is about to be erased.
+ void notifyBlockIsBeingErased(Block *block);
+
/// Notifies that a block was created.
void notifyCreatedBlock(Block *block);
ignoredOps.pop_back();
}
+void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
+ for (auto &action : blockActions) {
+ if (action.kind != BlockActionKind::Erase)
+ continue;
+ delete action.block;
+ }
+}
+
void ConversionPatternRewriterImpl::undoBlockActions(
unsigned numActionsToKeep) {
for (auto &action :
action.block->erase();
break;
}
+ // Put the block (owned by action) back into its original position.
+ case BlockActionKind::Erase: {
+ auto &blockList = action.originalPosition.region->getBlocks();
+ blockList.insert(
+ std::next(blockList.begin(), action.originalPosition.position),
+ action.block);
+ break;
+ }
// Move the block back to its original position.
case BlockActionKind::Move: {
Region *originalRegion = action.originalPosition.region;
repl.op->erase();
argConverter.applyRewrites(mapping);
+
+ // Now that the ops have been erased, also erase dangling blocks.
+ eraseDanglingBlocks();
}
LogicalResult
markNestedOpsIgnored(op);
}
+void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
+ Region *region = block->getParent();
+ auto position = std::distance(region->begin(), Region::iterator(block));
+ blockActions.push_back(BlockAction::getErase(block, {region, position}));
+}
+
void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
blockActions.push_back(BlockAction::getCreate(block));
}
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
- llvm_unreachable("erasing blocks for dialect conversion not implemented");
+ impl->notifyBlockIsBeingErased(block);
+
+ // Mark all ops for erasure.
+ for (Operation &op : *block)
+ eraseOp(&op);
+
+ // Unlink the block from its parent region. The block is kept in the block
+ // action and will be actually destroyed when rewrites are applied. This
+ // allows us to keep the operations in the block live and undo the removal by
+ // re-inserting the block.
+ block->getParent()->getBlocks().remove(block);
}
/// Apply a signature conversion to the entry block of the given region.
i != e; ++i) {
auto &action = rewriterImpl.blockActions[i];
if (action.kind ==
- ConversionPatternRewriterImpl::BlockActionKind::TypeConversion)
+ ConversionPatternRewriterImpl::BlockActionKind::TypeConversion ||
+ action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase)
continue;
// Convert the block signature.
// -----
+// The op in this function is rewritten to itself (and thus remains illegal) by
+// a pattern that removes its second block after adding an operation into it.
+// Check that we can undo block removal succesfully.
+// CHECK-LABEL: @undo_block_erase
+func @undo_block_erase() {
+ // CHECK: test.undo_block_erase
+ "test.undo_block_erase"() ({
+ // expected-remark@-1 {{not legalizable}}
+ // CHECK: "unregistered.return"()[^[[BB:.*]]]
+ "unregistered.return"()[^bb1] : () -> ()
+ // expected-remark@-1 {{not legalizable}}
+ // CHECK: ^[[BB]]
+ ^bb1:
+ // CHECK: unregistered.return
+ "unregistered.return"() : () -> ()
+ // expected-remark@-1 {{not legalizable}}
+ }) : () -> ()
+}
+
+// -----
+
// The op in this function is attempted to be rewritten to another illegal op
// with an attached region containing an invalid terminator. The terminator is
// created before the parent op. The deletion should not crash when deleting
}
};
+/// A rewrite pattern that tests the undo mechanism when erasing a block.
+struct TestUndoBlockErase : public ConversionPattern {
+ TestUndoBlockErase(MLIRContext *ctx)
+ : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ Block *secondBlock = &*std::next(op->getRegion(0).begin());
+ rewriter.setInsertionPointToStart(secondBlock);
+ rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+ rewriter.eraseBlock(secondBlock);
+ rewriter.updateRootInPlace(op, [] {});
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Type-Conversion Rewrite Testing
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
- patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
- TestCreateBlock, TestCreateIllegalBlock,
- TestUndoBlockArgReplace, TestPassthroughInvalidOp,
- TestSplitReturnType, TestChangeProducerTypeI32ToF32,
- TestChangeProducerTypeF32ToF64,
- TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
- TestNonRootReplacement, TestBoundedRecursiveRewrite,
- TestNestedOpCreationUndoRewrite>(&getContext());
+ patterns.insert<
+ TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
+ TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
+ TestPassthroughInvalidOp, TestSplitReturnType,
+ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite,
+ TestNestedOpCreationUndoRewrite>(&getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);