From 9c29273ddc4666dd2dc1df53cc2901a59bad0b03 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 16 Aug 2019 10:16:09 -0700 Subject: [PATCH] Refactor DialectConversion to convert the signatures of blocks when they are moved. Often we want to ensure that block arguments are converted before operations that use them. This refactors the current implementation to be cleaner/less frequent by triggering conversion when a set of blocks are moved/inlined; or when legalization is successful. PiperOrigin-RevId: 263795005 --- mlir/lib/Transforms/DialectConversion.cpp | 74 +++++++++++++++++-------------- 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 5a4145a..adbed67 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -499,11 +499,11 @@ void ConversionPatternRewriterImpl::applyRewrites() { LogicalResult ConversionPatternRewriterImpl::convertBlockSignature(Block *block) { // Check to see if this block should not be converted: - // * The block is invalid, or there is no type converter. + // * There is no type converter. // * The block has already been converted. // * This is an entry block, these are converted explicitly via patterns. - if (!block || !argConverter.typeConverter || - argConverter.hasBeenConverted(block) || block->isEntryBlock()) + if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) || + block->isEntryBlock()) return success(); // Otherwise, try to convert the block signature. @@ -738,10 +738,6 @@ bool OperationLegalizer::isIllegal(Operation *op) const { LogicalResult OperationLegalizer::legalize(Operation *op, ConversionPatternRewriter &rewriter) { - // Make sure that the signature of the parent block has been converted. - if (failed(rewriter.getImpl().convertBlockSignature(op->getBlock()))) - return failure(); - LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName() << "\n"); @@ -802,6 +798,24 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, return cleanupFailure(); } + // If the pattern moved any blocks, try to legalize their types. This ensures + // that the types of the block arguments are legal for the region they were + // moved into. + for (unsigned i = curState.numBlockActions, + e = rewriterImpl.blockActions.size(); + i != e; ++i) { + auto &action = rewriterImpl.blockActions[i]; + if (action.kind != ConversionPatternRewriterImpl::BlockActionKind::Move) + continue; + + // Convert the block signature. + if (failed(rewriterImpl.convertBlockSignature(action.block))) { + LLVM_DEBUG(llvm::dbgs() + << "-- FAIL: failed to convert types of moved block.\n"); + return cleanupFailure(); + } + } + // Recursively legalize each of the new operations. for (unsigned i = curState.numCreatedOperations, e = rewriterImpl.createdOps.size(); @@ -958,9 +972,9 @@ enum OpConversionMode { Analysis, }; -// This class converts operations using the given pattern matcher. If a -// TypeConverter object is provided, then the types of block arguments will be -// converted using the appropriate 'convertType' calls. +// This class converts operations to a given conversion target via a set of +// rewrite patterns. The conversion behaves differently depending on the +// conversion mode. struct OperationConverter { explicit OperationConverter(ConversionTarget &target, const OwningRewritePatternList &patterns, @@ -981,8 +995,7 @@ private: LogicalResult computeConversionSet(Region ®ion, std::vector &toConvert); - /// Converts the type signatures of the blocks nested within 'op' that have - /// yet to be converted. + /// Converts the type signatures of the blocks nested within 'op'. LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter, Operation *op); @@ -1001,18 +1014,14 @@ private: LogicalResult OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter, Operation *op) { - SmallVector worklist; - for (auto ®ion : op->getRegions()) - worklist.push_back(®ion); + // Check to see if type signatures need to be converted. + if (!rewriter.getImpl().argConverter.typeConverter) + return success(); - while (!worklist.empty()) { - for (auto &block : *worklist.pop_back_val()) { + for (auto ®ion : op->getRegions()) { + for (auto &block : region) if (failed(rewriter.getImpl().convertBlockSignature(&block))) return failure(); - for (auto &nestedOp : block) - for (auto ®ion : nestedOp.getRegions()) - worklist.push_back(®ion); - } } return success(); } @@ -1065,10 +1074,17 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, return op->emitError() << "failed to legalize operation '" << op->getName() << "' that was explicitly marked illegal"; - } else if (mode == OpConversionMode::Analysis) { - /// Analysis conversions don't fail if any operations fail to legalize, they - /// are only interested in the operations that were successfully legalized. - legalizableOps->insert(op); + } else { + /// Analysis conversions don't fail if any operations fail to legalize, + /// they are only interested in the operations that were successfully + /// legalized. + if (mode == OpConversionMode::Analysis) + legalizableOps->insert(op); + + // If legalization succeeded, convert the types any of the blocks within + // this operation. + if (failed(convertBlockSignatures(rewriter, op))) + return failure(); } return success(); } @@ -1094,14 +1110,6 @@ OperationConverter::convertOperations(ArrayRef ops, if (failed(convert(rewriter, op))) return rewriter.getImpl().discardRewrites(), failure(); - // If a type converter was provided, ensure that all blocks have had their - // signatures properly converted. - if (typeConverter) { - for (auto *op : ops) - if (failed(convertBlockSignatures(rewriter, op))) - return rewriter.getImpl().discardRewrites(), failure(); - } - // Otherwise, the body conversion succeeded. Apply rewrites if this is not an // analysis conversion. if (mode == OpConversionMode::Analysis) -- 2.7.4