From ab610e8a9961fb19f396ee4c2065261bd12b0013 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 16 Dec 2019 12:09:14 -0800 Subject: [PATCH] Insert signature-converted blocks into a region with a parent operation. This keeps the IR valid and consistent as it is expected that each block should have a valid parent region/operation. Previously, converted blocks were kept floating without a valid parent region. PiperOrigin-RevId: 285821687 --- mlir/include/mlir/IR/Block.h | 6 ++- mlir/lib/IR/Block.cpp | 10 ++++- mlir/lib/Transforms/DialectConversion.cpp | 75 ++++++++++++++++++++----------- 3 files changed, 63 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index a36ecdd..2ef7bf3 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -56,10 +56,14 @@ public: /// Return if this block is the entry block in the parent region. bool isEntryBlock(); - /// Insert this block (which must not already be in a function) right before + /// Insert this block (which must not already be in a region) right before /// the specified block. void insertBefore(Block *block); + /// Unlink this block from its current region and insert it right before the + /// specific block. + void moveBefore(Block *block); + /// Unlink this Block from its parent region and delete it. void erase(); diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index ea92422..63e8580 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -57,7 +57,15 @@ bool Block::isEntryBlock() { return this == &getParent()->front(); } void Block::insertBefore(Block *block) { assert(!getParent() && "already inserted into a block!"); assert(block->getParent() && "cannot insert before a block without a parent"); - block->getParent()->getBlocks().insert(Region::iterator(block), this); + block->getParent()->getBlocks().insert(block->getIterator(), this); +} + +/// Unlink this block from its current region and insert it right before the +/// specific block. +void Block::moveBefore(Block *block) { + assert(block->getParent() && "cannot insert before a block without a parent"); + block->getParent()->getBlocks().splice( + block->getIterator(), getParent()->getBlocks(), getIterator()); } /// Unlink this Block from its parent Region and delete it. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index ac13bc2..4b4575a 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -164,9 +164,10 @@ struct ArgConverter { // Rewrite Application //===--------------------------------------------------------------------===// - /// Erase any rewrites registered for the current block that is about to be - /// removed. This merely drops the rewrites without undoing them. - void notifyBlockRemoved(Block *block); + /// Erase any rewrites registered for the blocks within the given operation + /// which is about to be removed. This merely drops the rewrites without + /// undoing them. + void notifyOpRemoved(Operation *op); /// Cleanup and undo any generated conversions for the arguments of block. /// This method replaces the new block with the original, reverting the IR to @@ -194,9 +195,16 @@ struct ArgConverter { Block *block, TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping); + /// Insert a new conversion into the cache. + void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); + /// A collection of blocks that have had their arguments converted. llvm::MapVector conversionInfo; + /// A mapping from valid regions, to those containing the original blocks of a + /// conversion. + DenseMap> regionMapping; + /// An instance of the unknown location that is used when materializing /// conversions. Location loc; @@ -212,18 +220,26 @@ struct ArgConverter { //===----------------------------------------------------------------------===// // Rewrite Application -void ArgConverter::notifyBlockRemoved(Block *block) { - auto it = conversionInfo.find(block); - if (it == conversionInfo.end()) - return; - - // Drop all uses of the original arguments and delete the original block. - Block *origBlock = it->second.origBlock; - for (BlockArgument *arg : origBlock->getArguments()) - arg->dropAllUses(); - delete origBlock; - - conversionInfo.erase(it); +void ArgConverter::notifyOpRemoved(Operation *op) { + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + // Drop any rewrites from within. + for (Operation &nestedOp : block) + if (nestedOp.getNumRegions()) + notifyOpRemoved(&nestedOp); + + // Check if this block was converted. + auto it = conversionInfo.find(&block); + if (it == conversionInfo.end()) + return; + + // Drop all uses of the original arguments and delete the original block. + Block *origBlock = it->second.origBlock; + for (BlockArgument *arg : origBlock->getArguments()) + arg->dropAllUses(); + conversionInfo.erase(it); + } + } } void ArgConverter::discardRewrites(Block *block) { @@ -239,7 +255,7 @@ void ArgConverter::discardRewrites(Block *block) { // Move the operations back the original block and the delete the new block. origBlock->getOperations().splice(origBlock->end(), block->getOperations()); - origBlock->insertBefore(block); + origBlock->moveBefore(block); block->erase(); conversionInfo.erase(it); @@ -301,9 +317,6 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { if (castValue->use_empty()) castValue->getDefiningOp()->erase(); } - - // Drop the original block now the rewrites were applied. - delete origBlock; } } @@ -377,11 +390,24 @@ Block *ArgConverter::applySignatureConversion( } // Remove the original block from the region and return the new one. - newBlock->getParent()->getBlocks().remove(block); - conversionInfo.insert({newBlock, std::move(info)}); + insertConversion(newBlock, std::move(info)); return newBlock; } +void ArgConverter::insertConversion(Block *newBlock, + ConvertedBlockInfo &&info) { + // Get a region to insert the old block. + Region *region = newBlock->getParent(); + std::unique_ptr &mappedRegion = regionMapping[region]; + if (!mappedRegion) + mappedRegion = std::make_unique(region->getParentOp()); + + // Move the original block to the mapped region and emplace the conversion. + mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), + info.origBlock->getIterator()); + conversionInfo.insert({newBlock, std::move(info)}); +} + //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -642,11 +668,8 @@ void ConversionPatternRewriterImpl::applyRewrites() { // If this operation defines any regions, drop any pending argument // rewrites. - if (argConverter.typeConverter && repl.op->getNumRegions()) { - for (auto ®ion : repl.op->getRegions()) - for (auto &block : region) - argConverter.notifyBlockRemoved(&block); - } + if (argConverter.typeConverter && repl.op->getNumRegions()) + argConverter.notifyOpRemoved(repl.op); } // In a second pass, erase all of the replaced operations in reverse. This -- 2.7.4