From e6a343e491d4ee52b4085bf2b2c24669f1f9a6ce Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 24 Jun 2020 17:34:16 -0700 Subject: [PATCH] [mlir][DialectConversion][NFC] Add comment blocks and organize a bit of the code This helps improve the readability when scrolling through the many functions of ConversionPatternRewriterImpl. --- mlir/lib/Transforms/DialectConversion.cpp | 372 ++++++++++++++++-------------- 1 file changed, 203 insertions(+), 169 deletions(-) diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index ecebe61..60c9e78 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -450,7 +450,7 @@ void ArgConverter::insertConversion(Block *newBlock, } //===----------------------------------------------------------------------===// -// ConversionPatternRewriterImpl +// Rewriter and Transation State //===----------------------------------------------------------------------===// namespace { /// This class contains a snapshot of the current conversion rewriter state. @@ -515,74 +515,89 @@ private: SmallVector operands; SmallVector successors; }; -} // end anonymous namespace -namespace mlir { -namespace detail { -struct ConversionPatternRewriterImpl { - /// This class represents one requested operation replacement via 'replaceOp'. - struct OpReplacement { - OpReplacement() = default; - OpReplacement(Operation *op, ValueRange newValues) - : op(op), newValues(newValues.begin(), newValues.end()) {} - - Operation *op; - SmallVector newValues; - }; +/// This class represents one requested operation replacement via 'replaceOp'. +struct OpReplacement { + OpReplacement() = default; + OpReplacement(Operation *op, ValueRange newValues) + : op(op), newValues(newValues.begin(), newValues.end()) {} - /// The kind of the block action performed during the rewrite. Actions can be - /// undone if the conversion fails. - enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion }; + Operation *op; + SmallVector newValues; +}; - /// Original position of the given block in its parent region. We cannot use - /// a region iterator because it could have been invalidated by other region - /// operations since the position was stored. - struct BlockPosition { - Region *region; - Region::iterator::difference_type position; - }; +/// The kind of the block action performed during the rewrite. Actions can be +/// undone if the conversion fails. +enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion }; - /// The storage class for an undoable block action (one of BlockActionKind), - /// contains the information necessary to undo this action. - struct BlockAction { - static BlockAction getCreate(Block *block) { - return {BlockActionKind::Create, block, {}}; - } - static BlockAction getErase(Block *block, BlockPosition originalPos) { - return {BlockActionKind::Erase, block, {originalPos}}; - } - static BlockAction getMove(Block *block, BlockPosition originalPos) { - return {BlockActionKind::Move, block, {originalPos}}; - } - static BlockAction getSplit(Block *block, Block *originalBlock) { - BlockAction action{BlockActionKind::Split, block, {}}; - action.originalBlock = originalBlock; - return action; - } - static BlockAction getTypeConversion(Block *block) { - return BlockAction{BlockActionKind::TypeConversion, block, {}}; - } +/// Original position of the given block in its parent region. We cannot use +/// a region iterator because it could have been invalidated by other region +/// operations since the position was stored. +struct BlockPosition { + Region *region; + Region::iterator::difference_type position; +}; - // The action kind. - BlockActionKind kind; - - // A pointer to the block that was created by the action. - Block *block; - - union { - // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and - // contains a pointer to the region that originally contained the block as - // well as the position of the block in that region. - BlockPosition originalPosition; - // In use if kind == BlockActionKind::Split and contains a pointer to the - // block that was split into two parts. - Block *originalBlock; - }; +/// The storage class for an undoable block action (one of BlockActionKind), +/// contains the information necessary to undo this action. +struct BlockAction { + static BlockAction getCreate(Block *block) { + return {BlockActionKind::Create, block, {}}; + } + static BlockAction getErase(Block *block, BlockPosition originalPos) { + return {BlockActionKind::Erase, block, {originalPos}}; + } + static BlockAction getMove(Block *block, BlockPosition originalPos) { + return {BlockActionKind::Move, block, {originalPos}}; + } + static BlockAction getSplit(Block *block, Block *originalBlock) { + BlockAction action{BlockActionKind::Split, block, {}}; + action.originalBlock = originalBlock; + return action; + } + static BlockAction getTypeConversion(Block *block) { + return BlockAction{BlockActionKind::TypeConversion, block, {}}; + } + + // The action kind. + BlockActionKind kind; + + // A pointer to the block that was created by the action. + Block *block; + + union { + // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and + // contains a pointer to the region that originally contained the block as + // well as the position of the block in that region. + BlockPosition originalPosition; + // In use if kind == BlockActionKind::Split and contains a pointer to the + // block that was split into two parts. + Block *originalBlock; }; +}; +} // end anonymous namespace +//===----------------------------------------------------------------------===// +// ConversionPatternRewriterImpl +//===----------------------------------------------------------------------===// +namespace mlir { +namespace detail { +struct ConversionPatternRewriterImpl { ConversionPatternRewriterImpl(PatternRewriter &rewriter) : argConverter(rewriter) {} + /// Cleanup and destroy any generated rewrite operations. This method is + /// invoked when the conversion process fails. + void discardRewrites(); + + /// Apply all requested operation rewrites. This method is invoked when the + /// conversion process succeeds. + void applyRewrites(); + + //===--------------------------------------------------------------------===// + // State Management + //===--------------------------------------------------------------------===// + /// Return the current state of the rewriter. RewriterState getCurrentState(); @@ -597,13 +612,21 @@ struct ConversionPatternRewriterImpl { /// "numActionsToKeep" actions remains. void undoBlockActions(unsigned numActionsToKeep = 0); - /// Cleanup and destroy any generated rewrite operations. This method is - /// invoked when the conversion process fails. - void discardRewrites(); + /// Remap the given operands to those with potentially different types. + void remapValues(Operation::operand_range operands, + SmallVectorImpl &remapped); - /// Apply all requested operation rewrites. This method is invoked when the - /// conversion process succeeds. - void applyRewrites(); + /// Returns true if the given operation is ignored, and does not need to be + /// converted. + bool isOpIgnored(Operation *op) const; + + /// Recursively marks the nested operations under 'op' as ignored. This + /// removes them from being considered for legalization. + void markNestedOpsIgnored(Operation *op); + + //===--------------------------------------------------------------------===// + // Type Conversion + //===--------------------------------------------------------------------===// /// Convert the signature of the given block. FailureOr convertBlockSignature( @@ -620,8 +643,12 @@ struct ConversionPatternRewriterImpl { convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion); + //===--------------------------------------------------------------------===// + // Rewriter Notification Hooks + //===--------------------------------------------------------------------===// + /// PatternRewriter hook for replacing the results of an operation. - void replaceOp(Operation *op, ValueRange newValues); + void notifyOpReplaced(Operation *op, ValueRange newValues); /// Notifies that a block is about to be erased. void notifyBlockIsBeingErased(Block *block); @@ -640,17 +667,9 @@ struct ConversionPatternRewriterImpl { void notifyRegionWasClonedBefore(iterator_range &blocks, Location origRegionLoc); - /// Remap the given operands to those with potentially different types. - void remapValues(Operation::operand_range operands, - SmallVectorImpl &remapped); - - /// Returns true if the given operation is ignored, and does not need to be - /// converted. - bool isOpIgnored(Operation *op) const; - - /// Recursively marks the nested operations under 'op' as ignored. This - /// removes them from being considered for legalization. - void markNestedOpsIgnored(Operation *op); + //===--------------------------------------------------------------------===// + // State + //===--------------------------------------------------------------------===// // Mapping between replaced values that differ in type. This happens when // replacing a value with one of a different type. @@ -700,12 +719,6 @@ struct ConversionPatternRewriterImpl { } // end namespace detail } // end namespace mlir -RewriterState ConversionPatternRewriterImpl::getCurrentState() { - return RewriterState(createdOps.size(), replacements.size(), - argReplacements.size(), blockActions.size(), - ignoredOps.size(), rootUpdates.size()); -} - /// Detach any operations nested in the given operation from their parent /// blocks, and erase the given operation. This can be used when the nested /// operations are scheduled for erasure themselves, so deleting the regions of @@ -722,6 +735,73 @@ static void detachNestedAndErase(Operation *op) { op->erase(); } +void ConversionPatternRewriterImpl::discardRewrites() { + // Reset any operations that were updated in place. + for (auto &state : rootUpdates) + state.resetOperation(); + + undoBlockActions(); + + // Remove any newly created ops. + for (auto *op : llvm::reverse(createdOps)) + detachNestedAndErase(op); +} + +void ConversionPatternRewriterImpl::applyRewrites() { + // Apply all of the rewrites replacements requested during conversion. + for (auto &repl : replacements) { + for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) { + if (auto newValue = repl.newValues[i]) + repl.op->getResult(i).replaceAllUsesWith( + mapping.lookupOrDefault(newValue)); + } + + // If this operation defines any regions, drop any pending argument + // rewrites. + if (repl.op->getNumRegions()) + argConverter.notifyOpRemoved(repl.op); + } + + // Apply all of the requested argument replacements. + for (BlockArgument arg : argReplacements) { + Value repl = mapping.lookupOrDefault(arg); + if (repl.isa()) { + arg.replaceAllUsesWith(repl); + continue; + } + + // If the replacement value is an operation, we check to make sure that we + // don't replace uses that are within the parent operation of the + // replacement value. + Operation *replOp = repl.cast().getOwner(); + Block *replBlock = replOp->getBlock(); + arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + }); + } + + // In a second pass, erase all of the replaced operations in reverse. This + // allows processing nested operations before their parent region is + // destroyed. + for (auto &repl : llvm::reverse(replacements)) + repl.op->erase(); + + argConverter.applyRewrites(mapping); + + // Now that the ops have been erased, also erase dangling blocks. + eraseDanglingBlocks(); +} + +//===----------------------------------------------------------------------===// +// State Management + +RewriterState ConversionPatternRewriterImpl::getCurrentState() { + return RewriterState(createdOps.size(), replacements.size(), + argReplacements.size(), blockActions.size(), + ignoredOps.size(), rootUpdates.size()); +} + void ConversionPatternRewriterImpl::resetState(RewriterState state) { // Reset any operations that were updated in place. for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) @@ -810,64 +890,34 @@ void ConversionPatternRewriterImpl::undoBlockActions( blockActions.resize(numActionsToKeep); } -void ConversionPatternRewriterImpl::discardRewrites() { - // Reset any operations that were updated in place. - for (auto &state : rootUpdates) - state.resetOperation(); - - undoBlockActions(); - - // Remove any newly created ops. - for (auto *op : llvm::reverse(createdOps)) - detachNestedAndErase(op); +void ConversionPatternRewriterImpl::remapValues( + Operation::operand_range operands, SmallVectorImpl &remapped) { + remapped.reserve(llvm::size(operands)); + for (Value operand : operands) + remapped.push_back(mapping.lookupOrDefault(operand)); } -void ConversionPatternRewriterImpl::applyRewrites() { - // Apply all of the rewrites replacements requested during conversion. - for (auto &repl : replacements) { - for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) { - if (auto newValue = repl.newValues[i]) - repl.op->getResult(i).replaceAllUsesWith( - mapping.lookupOrDefault(newValue)); - } - - // If this operation defines any regions, drop any pending argument - // rewrites. - if (repl.op->getNumRegions()) - argConverter.notifyOpRemoved(repl.op); - } - - // Apply all of the requested argument replacements. - for (BlockArgument arg : argReplacements) { - Value repl = mapping.lookupOrDefault(arg); - if (repl.isa()) { - arg.replaceAllUsesWith(repl); - continue; - } - - // If the replacement value is an operation, we check to make sure that we - // don't replace uses that are within the parent operation of the - // replacement value. - Operation *replOp = repl.cast().getOwner(); - Block *replBlock = replOp->getBlock(); - arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { - Operation *user = operand.getOwner(); - return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); - }); - } - - // In a second pass, erase all of the replaced operations in reverse. This - // allows processing nested operations before their parent region is - // destroyed. - for (auto &repl : llvm::reverse(replacements)) - repl.op->erase(); - - argConverter.applyRewrites(mapping); +bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { + // Check to see if this operation or its parent were ignored. + return ignoredOps.count(op) || ignoredOps.count(op->getParentOp()); +} - // Now that the ops have been erased, also erase dangling blocks. - eraseDanglingBlocks(); +void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { + // Walk this operation and collect nested operations that define non-empty + // regions. We mark such operations as 'ignored' so that we know we don't have + // to convert them, or their nested ops. + if (op->getNumRegions() == 0) + return; + op->walk([&](Operation *op) { + if (llvm::any_of(op->getRegions(), + [](Region ®ion) { return !region.empty(); })) + ignoredOps.insert(op); + }); } +//===----------------------------------------------------------------------===// +// Type Conversion + FailureOr ConversionPatternRewriterImpl::convertBlockSignature( Block *block, TypeConverter &converter, TypeConverter::SignatureConversion *conversion) { @@ -907,8 +957,11 @@ FailureOr ConversionPatternRewriterImpl::convertRegionTypes( return newEntry; } -void ConversionPatternRewriterImpl::replaceOp(Operation *op, - ValueRange newValues) { +//===----------------------------------------------------------------------===// +// Rewriter Notification Hooks + +void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, + ValueRange newValues) { assert(newValues.size() == op->getNumResults()); // Create mappings for each of the new result values. @@ -962,31 +1015,6 @@ void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore( assert(succeeded(result) && "expected region to have no unreachable blocks"); } -void ConversionPatternRewriterImpl::remapValues( - Operation::operand_range operands, SmallVectorImpl &remapped) { - remapped.reserve(llvm::size(operands)); - for (Value operand : operands) - remapped.push_back(mapping.lookupOrDefault(operand)); -} - -bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { - // Check to see if this operation or its parent were ignored. - return ignoredOps.count(op) || ignoredOps.count(op->getParentOp()); -} - -void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { - // Walk this operation and collect nested operations that define non-empty - // regions. We mark such operations as 'ignored' so that we know we don't have - // to convert them, or their nested ops. - if (op->getNumRegions() == 0) - return; - op->walk([&](Operation *op) { - if (llvm::any_of(op->getRegions(), - [](Region ®ion) { return !region.empty(); })) - ignoredOps.insert(op); - }); -} - //===----------------------------------------------------------------------===// // ConversionPatternRewriter //===----------------------------------------------------------------------===// @@ -1002,7 +1030,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - impl->replaceOp(op, newValues); + impl->notifyOpReplaced(op, newValues); } /// PatternRewriter hook for erasing a dead operation. The uses of this @@ -1014,7 +1042,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); SmallVector nullRepls(op->getNumResults(), nullptr); - impl->replaceOp(op, nullRepls); + impl->notifyOpReplaced(op, nullRepls); } void ConversionPatternRewriter::eraseBlock(Block *block) { @@ -1160,7 +1188,7 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { } //===----------------------------------------------------------------------===// -// Conversion Patterns +// ConversionPattern //===----------------------------------------------------------------------===// /// Attempt to match and rewrite the IR root at the specified operation. @@ -1234,6 +1262,10 @@ private: RewriterState &state, RewriterState &newState); + //===--------------------------------------------------------------------===// + // Cost Model + //===--------------------------------------------------------------------===// + /// Build an optimistic legalization graph given the provided patterns. This /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with /// patterns for operations that are not directly legal, but may be @@ -1528,9 +1560,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions( for (int i = state.numBlockActions, e = newState.numBlockActions; i != e; ++i) { auto &action = impl.blockActions[i]; - if (action.kind == - ConversionPatternRewriterImpl::BlockActionKind::TypeConversion || - action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase) + if (action.kind == BlockActionKind::TypeConversion || + action.kind == BlockActionKind::Erase) continue; // Only check blocks outside of the current operation. Operation *parentOp = action.block->getParentOp(); @@ -1599,6 +1630,9 @@ LogicalResult OperationLegalizer::legalizePatternRootUpdates( return success(); } +//===----------------------------------------------------------------------===// +// Cost Model + void OperationLegalizer::buildLegalizationGraph( LegalizationPatterns &anyOpLegalizerPatterns, DenseMap &legalizerPatterns) { -- 2.7.4