From: River Riddle Date: Sun, 21 Jul 2019 02:05:41 +0000 (-0700) Subject: Refactor region type signature conversion to be explicit via patterns. X-Git-Tag: llvmorg-11-init~1466^2~1130 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=00bdc8e070ec54c5df3eb57c10eb830d2f369d5e;p=platform%2Fupstream%2Fllvm.git Refactor region type signature conversion to be explicit via patterns. This cl enforces that the conversion of the type signatures for regions, and thus their entry blocks, is handled via ConversionPatterns. A new hook 'applySignatureConversion' is added to the ConversionPatternRewriter to perform the desired conversion on a region. This also means that the handling of rewriting the signature of a FuncOp is moved to a pattern. A default implementation is provided via 'mlir::populateFuncOpTypeConversionPattern'. This removes the hacky implicit 'dynamically legal' status of FuncOp that was present previously, and leaves it up to the user to decide when/how to convert the signature of a function. PiperOrigin-RevId: 259161999 --- diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 67b0ac0..411a7af 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -425,7 +425,9 @@ LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) { ConversionTarget target(*module.getContext()); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); return applyFullConversion(module, target, std::move(patterns), &converter); } diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 68a48d6..8c77737 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -162,7 +162,9 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) { ConversionTarget target(*module.getContext()); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); if (failed( applyFullConversion(module, target, std::move(patterns), &converter))) return failure(); diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 8b80588..5a01122 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -355,12 +355,17 @@ struct LateLoweringPass : public ModulePass { RewriteListBuilder::build(toyPatterns, &getContext()); + mlir::populateFuncOpTypeConversionPattern(toyPatterns, &getContext(), + typeConverter); // Perform Toy specific lowering. ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); + target.addDynamicallyLegalOp([&](FuncOp op) { + return typeConverter.isSignatureLegal(op.getType()); + }); if (failed(applyPartialConversion( getModule(), target, std::move(toyPatterns), &typeConverter))) { emitError(UnknownLoc::get(getModule().getContext()), diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 864d805..d5c4c11 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -63,12 +63,6 @@ public: LLVM::LLVMDialect *getDialect() { return llvmDialect; } protected: - /// Convert function signatures to LLVM IR. In particular, convert functions - /// with multiple results into functions returning LLVM IR's structure type. - /// Use `convertType` to convert individual argument and result types. - LogicalResult convertSignature(FunctionType t, - SignatureConversion &result) final; - /// LLVM IR module used to parse/create types. llvm::Module *module; LLVM::LLVMDialect *llvmDialect; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5543c21..1ffd5bb 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -61,16 +61,8 @@ public: size_t inputNo, size; }; - /// Return the converted type signature. - FunctionType getConvertedType(MLIRContext *ctx) const { - return FunctionType::get(argTypes, resultTypes, ctx); - } - /// Return the argument types for the new signature. - ArrayRef getConvertedArgTypes() const { return argTypes; } - - /// Return the result types for the new signature. - ArrayRef getConvertedResultTypes() const { return resultTypes; } + ArrayRef getConvertedTypes() const { return argTypes; } /// Get the input mapping for the given argument. llvm::Optional getInputMapping(unsigned input) const { @@ -81,9 +73,6 @@ public: // Conversion Hooks //===------------------------------------------------------------------===// - /// Append new result types to the signature conversion. - void addResults(ArrayRef results); - /// Remap an input of the original signature with a new set of types. The /// new types are appended to the new signature conversion. void addInputs(unsigned origInputNo, ArrayRef types); @@ -101,8 +90,8 @@ public: /// The remapping information for each of the original arguments. SmallVector, 4> remappedInputs; - /// The set of argument and results types. - SmallVector argTypes, resultTypes; + /// The set of new argument types. + SmallVector argTypes; }; /// This hooks allows for converting a type. This function should return @@ -115,18 +104,19 @@ public: /// the type convert to on success, and a null type on failure. virtual Type convertType(Type t) { return t; } - /// Convert the given FunctionType signature. This functions returns a valid - /// SignatureConversion on success, None otherwise. - llvm::Optional convertSignature(FunctionType type); + /// Convert the given set of types, filling 'results' as necessary. This + /// returns failure if the conversion of any of the types fails, success + /// otherwise. + LogicalResult convertTypes(ArrayRef types, + SmallVectorImpl &results); + + /// Return true if the given type is legal for this type converter, i.e. the + /// type converts to itself. + bool isLegal(Type type); - /// This hook allows for changing a FunctionType signature. This function - /// should populate 'result' with the new arguments and results on success, - /// otherwise return failure. - /// - /// The default behavior of this function is to call 'convertType' on - /// individual function operands and results. - virtual LogicalResult convertSignature(FunctionType type, - SignatureConversion &result); + /// Return true if the inputs and outputs of the given function type are + /// legal. + bool isSignatureLegal(FunctionType funcType); /// This hook allows for converting a specific argument of a signature. It /// takes as inputs the original argument input number, type. @@ -134,22 +124,6 @@ public: virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result); - /// This hook allows for converting the signature of a region 'regionIdx', - /// i.e. the signature of the entry to the region, on the given operation - /// 'op'. This function should return a valid conversion for the signature on - /// success, None otherwise. This hook is allowed to modify the attributes on - /// the provided operation if necessary. - /// - /// The default behavior of this function is to invoke 'convertBlockSignature' - /// on the entry block, if one is present. This function also provides special - /// handling for FuncOp to update the type signature. - /// - /// TODO(riverriddle) This should be replaced in favor of using patterns, but - /// the pattern rewriter needs to know how to properly replace/remap - /// arguments. - virtual llvm::Optional - convertRegionSignature(Operation *op, unsigned regionIdx); - /// This function converts the type signature of the given block, by invoking /// 'convertSignatureArg' for each argument. This function should return a /// valid conversion for the signature on success, None otherwise. @@ -244,6 +218,12 @@ private: using RewritePattern::rewrite; }; +/// Add a pattern to the given pattern list to convert the signature of a FuncOp +/// with the given type converter. +void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx, + TypeConverter &converter); + //===----------------------------------------------------------------------===// // Conversion PatternRewriter //===----------------------------------------------------------------------===// @@ -252,12 +232,24 @@ namespace detail { struct ConversionPatternRewriterImpl; } // end namespace detail -/// This class implements a pattern rewriter for use with ConversionPatterns. +/// This class implements a pattern rewriter for use with ConversionPatterns. It +/// extends the base PatternRewriter and provides special conversion specific +/// hooks. class ConversionPatternRewriter final : public PatternRewriter { public: ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter); ~ConversionPatternRewriter() override; + /// Apply a signature conversion to the entry block of the given region. + void applySignatureConversion(Region *region, + TypeConverter::SignatureConversion &conversion); + + /// Clone the given operation without cloning its regions. + Operation *cloneWithoutRegions(Operation *op); + template OpT cloneWithoutRegions(OpT op) { + return cast(cloneWithoutRegions(op.getOperation())); + } + //===--------------------------------------------------------------------===// // PatternRewriter Hooks //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 042e768..c17909b 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -266,6 +266,44 @@ protected: LLVM::LLVMDialect &dialect; }; +struct FuncOpConversion : public LLVMLegalizationPattern { + using LLVMLegalizationPattern::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = cast(op); + FunctionType type = funcOp.getType(); + + // Convert the original function arguments. + TypeConverter::SignatureConversion result(type.getNumInputs()); + for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) + if (failed(lowering.convertSignatureArg(i, type.getInput(i), result))) + return matchFailure(); + + // Pack the result types into a struct. + Type packedResult; + if (type.getNumResults() != 0) { + if (!(packedResult = lowering.packFunctionResults(type.getResults()))) + return matchFailure(); + } + + // Create a new function with an updated signature. + auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + newFuncOp.setType(FunctionType::get( + result.getConvertedTypes(), + packedResult ? ArrayRef(packedResult) : llvm::None, + funcOp.getContext())); + + // Tell the rewriter to convert the region signature. + rewriter.applySignatureConversion(&newFuncOp.getBody(), result); + rewriter.replaceOp(op, llvm::None); + return matchSuccess(); + } +}; + // Basic lowering implementation for one-to-one rewriting from Standard Ops to // LLVM Dialect Ops. template @@ -985,10 +1023,10 @@ void mlir::populateStdToLLVMConversionPatterns( BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering, DivFOpLowering, - IndexCastOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, - MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering, - RemFOpLowering, ReturnOpLowering, SelectOpLowering, StoreOpLowering, - SubFOpLowering, SubIOpLowering, + FuncOpConversion, IndexCastOpLowering, LoadOpLowering, + MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, + RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, + SelectOpLowering, StoreOpLowering, SubFOpLowering, SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(), converter); } @@ -1014,27 +1052,6 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); } -// Convert function signatures using the stored LLVM IR module. -LogicalResult LLVMTypeConverter::convertSignature(FunctionType type, - SignatureConversion &result) { - // Convert the original function arguments. - for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) - if (failed(convertSignatureArg(i, type.getInput(i), result))) - return failure(); - - // If function does not return anything, return immediately. - if (type.getNumResults() == 0) - return success(); - - // Otherwise pack the result types into a struct. - if (auto packedRet = packFunctionResults(type.getResults())) { - result.addResults(packedRet); - return success(); - } - - return failure(); -} - /// Create an instance of LLVMTypeConverter in the given context. static std::unique_ptr makeStandardToLLVMTypeConverter(MLIRContext *context) { @@ -1071,6 +1088,9 @@ struct LLVMLoweringPass : public ModulePass { ConversionTarget target(getContext()); target.addLegalDialect(); + target.addDynamicallyLegalOp([&](FuncOp op) { + return typeConverter->isSignatureLegal(op.getType()); + }); if (failed(applyPartialConversion(m, target, std::move(patterns), typeConverter.get()))) signalPassFailure(); diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 98be230..b6bfa58 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -774,6 +774,8 @@ void LowerLinalgToLLVMPass::runOnModule() { ConversionTarget target(getContext()); target.addLegalDialect(); + target.addDynamicallyLegalOp( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); if (failed(applyPartialConversion(module, target, std::move(patterns), &converter))) { signalPassFailure(); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 02ca31f..aac2e11 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -61,9 +61,6 @@ struct ArgConverter { /// Return if the signature of the given block has already been converted. bool hasBeenConverted(Block *block) const { return argMapping.count(block); } - /// Attempt to convert the signature of the given region. - LogicalResult convertSignature(Region ®ion, BlockAndValueMapping &mapping); - /// Attempt to convert the signature of the given block. LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping); @@ -196,23 +193,10 @@ void ArgConverter::applyRewrites() { } } -/// Converts the signature of the given region. -LogicalResult ArgConverter::convertSignature(Region ®ion, - BlockAndValueMapping &mapping) { - if (auto conversion = typeConverter->convertRegionSignature( - region.getContainingOp(), region.getRegionNumber())) { - if (!region.empty()) - applySignatureConversion(®ion.front(), *conversion, mapping); - return success(); - } - return failure(); -} - /// Converts the signature of the given entry block. LogicalResult ArgConverter::convertSignature(Block *block, BlockAndValueMapping &mapping) { - auto conversion = typeConverter->convertBlockSignature(block); - if (conversion) + if (auto conversion = typeConverter->convertBlockSignature(block)) return applySignatureConversion(block, *conversion, mapping), success(); return failure(); } @@ -222,7 +206,7 @@ void ArgConverter::applySignatureConversion( Block *block, TypeConverter::SignatureConversion &signatureConversion, BlockAndValueMapping &mapping) { unsigned origArgCount = block->getNumArguments(); - auto convertedTypes = signatureConversion.getConvertedArgTypes(); + auto convertedTypes = signatureConversion.getConvertedTypes(); if (origArgCount == 0 && convertedTypes.empty()) return; @@ -292,10 +276,9 @@ namespace { /// This is useful when saving and undoing a set of rewrites. struct RewriterState { RewriterState(unsigned numCreatedOperations, unsigned numReplacements, - unsigned numBlockActions, unsigned numTypeConversions) + unsigned numBlockActions) : numCreatedOperations(numCreatedOperations), - numReplacements(numReplacements), numBlockActions(numBlockActions), - numTypeConversions(numTypeConversions) {} + numReplacements(numReplacements), numBlockActions(numBlockActions) {} /// The current number of created operations. unsigned numCreatedOperations; @@ -305,9 +288,6 @@ struct RewriterState { /// The current number of block actions performed. unsigned numBlockActions; - - /// The current number of type conversion actions performed. - unsigned numTypeConversions; }; } // end anonymous namespace @@ -326,7 +306,7 @@ struct ConversionPatternRewriterImpl { /// The kind of the block action performed during the rewrite. Actions can be /// undone if the conversion fails. - enum class BlockActionKind { Split, Move }; + enum class BlockActionKind { Split, Move, TypeConversion }; /// Original position of the given block in its parent region. We cannot use /// a region iterator because it could have been invalidated by other region @@ -339,6 +319,21 @@ struct ConversionPatternRewriterImpl { /// The storage class for an undoable block action (one of BlockActionKind), /// contains the information necessary to undo this action. struct BlockAction { + static BlockAction getSplit(Block *block, Block *originalBlock) { + BlockAction action{BlockActionKind::Split, block}; + action.originalBlock = originalBlock; + return action; + } + static BlockAction getMove(Block *block, BlockPosition originalPos) { + return {BlockActionKind::Move, block, {originalPos}}; + } + static BlockAction getTypeConversion(Block *block) { + return BlockAction{BlockActionKind::TypeConversion, block}; + } + + // The action kind. + BlockActionKind kind; + // A pointer to the block that was created by the action. Block *block; @@ -351,18 +346,6 @@ struct ConversionPatternRewriterImpl { // block that was split into two parts. Block *originalBlock; }; - - BlockActionKind kind; - }; - - /// A storage class representing a type conversion of a block or a region. - struct TypeConversion { - /// The region, or block, that had its types converted. - llvm::PointerUnion object; - - /// If the object is a region, this corresponds to the original attributes - /// of the parent operation. - NamedAttributeList originalParentAttributes; }; ConversionPatternRewriterImpl(PatternRewriter &rewriter, @@ -379,10 +362,6 @@ struct ConversionPatternRewriterImpl { /// "numActionsToKeep" actions remains. void undoBlockActions(unsigned numActionsToKeep = 0); - /// Undo the type conversion actions one by one, until "numActionsToKeep" - /// actions remain. - void undoTypeConversions(unsigned numActionsToKeep = 0); - /// Cleanup and destroy any generated rewrite operations. This method is /// invoked when the conversion process fails. void discardRewrites(); @@ -391,15 +370,13 @@ struct ConversionPatternRewriterImpl { /// conversion process succeeds. void applyRewrites(); - /// Return if the given block has already been converted. - bool hasSignatureBeenConverted(Block *block); - - /// Convert the signature of the given region. - LogicalResult convertRegionSignature(Region ®ion); - /// Convert the signature of the given block. LogicalResult convertBlockSignature(Block *block); + /// Apply a signature conversion on the given region. + void applySignatureConversion(Region *region, + TypeConverter::SignatureConversion &conversion); + /// PatternRewriter hook for replacing the results of an operation. void replaceOp(Operation *op, ArrayRef newValues, ArrayRef valuesToRemoveIfDead); @@ -430,21 +407,17 @@ struct ConversionPatternRewriterImpl { /// Ordered list of block operations (creations, splits, motions). SmallVector blockActions; - - /// Ordered list of type conversion actions. - SmallVector typeConversions; }; } // end namespace detail } // end namespace mlir RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), replacements.size(), - blockActions.size(), typeConversions.size()); + blockActions.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { - // Undo any type conversions or block actions. - undoTypeConversions(state.numTypeConversions); + // Undo any block actions. undoBlockActions(state.numBlockActions); // Reset any replaced operations and undo any saved mappings. @@ -478,24 +451,17 @@ void ConversionPatternRewriterImpl::undoBlockActions( action.block->getParent()->getBlocks(), action.block); break; } + // Undo the type conversion. + case BlockActionKind::TypeConversion: { + argConverter.discardPendingRewrites(action.block); + break; + } } } blockActions.resize(numActionsToKeep); } -void ConversionPatternRewriterImpl::undoTypeConversions( - unsigned numActionsToKeep) { - for (auto &conversion : llvm::drop_begin(typeConversions, numActionsToKeep)) { - if (auto *region = conversion.object.dyn_cast()) - region->getContainingOp()->setAttrs(conversion.originalParentAttributes); - else - argConverter.discardPendingRewrites(conversion.object.get()); - } - typeConversions.resize(numActionsToKeep); -} - void ConversionPatternRewriterImpl::discardRewrites() { - undoTypeConversions(); undoBlockActions(); // Remove any newly created ops. @@ -530,29 +496,30 @@ void ConversionPatternRewriterImpl::applyRewrites() { argConverter.applyRewrites(); } -bool ConversionPatternRewriterImpl::hasSignatureBeenConverted(Block *block) { - return argConverter.hasBeenConverted(block); -} - LogicalResult -ConversionPatternRewriterImpl::convertRegionSignature(Region ®ion) { - auto parentAttrs = region.getContainingOp()->getAttrList(); - auto result = argConverter.convertSignature(region, mapping); - if (succeeded(result)) { - typeConversions.push_back(TypeConversion{®ion, parentAttrs}); - if (!region.empty()) - typeConversions.push_back( - TypeConversion{®ion.front(), NamedAttributeList()}); - } - return result; +ConversionPatternRewriterImpl::convertBlockSignature(Block *block) { + // Check to see if this block should not be converted: + // * The block is invalid, or there is no type converter. + // * The block has already been converted. + // * This is an entry block, these are converted explicitly via patterns. + if (!block || !argConverter.typeConverter || + argConverter.hasBeenConverted(block) || block->isEntryBlock()) + return success(); + + // Otherwise, try to convert the block signature. + if (failed(argConverter.convertSignature(block, mapping))) + return failure(); + blockActions.push_back(BlockAction::getTypeConversion(block)); + return success(); } -LogicalResult -ConversionPatternRewriterImpl::convertBlockSignature(Block *block) { - auto result = argConverter.convertSignature(block, mapping); - if (succeeded(result)) - typeConversions.push_back(TypeConversion{block, NamedAttributeList()}); - return result; +void ConversionPatternRewriterImpl::applySignatureConversion( + Region *region, TypeConverter::SignatureConversion &conversion) { + if (!region->empty()) { + argConverter.applySignatureConversion(®ion->front(), conversion, + mapping); + blockActions.push_back(BlockAction::getTypeConversion(®ion->front())); + } } void ConversionPatternRewriterImpl::replaceOp( @@ -574,11 +541,7 @@ void ConversionPatternRewriterImpl::replaceOp( void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, Block *continuation) { - BlockAction action; - action.kind = BlockActionKind::Split; - action.block = continuation; - action.originalBlock = block; - blockActions.push_back(action); + blockActions.push_back(BlockAction::getSplit(continuation, block)); } void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( @@ -586,11 +549,7 @@ void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( for (auto &pair : llvm::enumerate(region)) { Block &block = pair.value(); unsigned position = pair.index(); - BlockAction action; - action.kind = BlockActionKind::Move; - action.block = █ - action.originalPosition = {®ion, position}; - blockActions.push_back(action); + blockActions.push_back(BlockAction::getMove(&block, {®ion, position})); } } @@ -618,6 +577,19 @@ void ConversionPatternRewriter::replaceOp( impl->replaceOp(op, newValues, valuesToRemoveIfDead); } +/// Apply a signature conversion to the entry block of the given region. +void ConversionPatternRewriter::applySignatureConversion( + Region *region, TypeConverter::SignatureConversion &conversion) { + impl->applySignatureConversion(region, conversion); +} + +/// Clone the given operation without cloning its regions. +Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) { + Operation *newOp = OpBuilder::cloneWithoutRegions(*op); + impl->createdOps.push_back(newOp); + return newOp; +} + /// PatternRewriter hook for splitting a block into two parts. Block *ConversionPatternRewriter::splitBlock(Block *block, Block::iterator before) { @@ -766,18 +738,9 @@ bool OperationLegalizer::isIllegal(Operation *op) const { LogicalResult OperationLegalizer::legalize(Operation *op, ConversionPatternRewriter &rewriter) { - // Make sure that the signature of the parent block of this operation has been - // converted. - auto &rewriterImpl = rewriter.getImpl(); - if (rewriterImpl.argConverter.typeConverter) { - auto *block = op->getBlock(); - if (block && !rewriterImpl.hasSignatureBeenConverted(block)) { - if (failed(block->isEntryBlock() - ? rewriterImpl.convertRegionSignature(*block->getParent()) - : rewriterImpl.convertBlockSignature(block))) - return failure(); - } - } + // Make sure that the signature of the parent block has been converted. + if (failed(rewriter.getImpl().convertBlockSignature(op->getBlock()))) + return failure(); LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName() << "\n"); @@ -1008,11 +971,15 @@ private: /// Converts an operation with the given rewriter. LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); - /// Recursively collect all of the operations, to convert from within - /// 'region'. + /// Recursively collect all of the operations to convert from within 'region'. LogicalResult computeConversionSet(Region ®ion, std::vector &toConvert); + /// Converts the type signatures of the blocks nested within 'op' that have + /// yet to be converted. + LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter, + Operation *op); + /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; @@ -1021,7 +988,25 @@ private: }; } // end anonymous namespace -/// Recursively collect all of the blocks to convert from within 'region'. +LogicalResult +OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter, + Operation *op) { + SmallVector worklist; + for (auto ®ion : op->getRegions()) + worklist.push_back(®ion); + + while (!worklist.empty()) { + for (auto &block : *worklist.pop_back_val()) { + if (failed(rewriter.getImpl().convertBlockSignature(&block))) + return failure(); + for (auto &nestedOp : block) + for (auto ®ion : nestedOp.getRegions()) + worklist.push_back(®ion); + } + } + return success(); +} + LogicalResult OperationConverter::computeConversionSet(Region ®ion, std::vector &toConvert) { @@ -1055,7 +1040,6 @@ OperationConverter::computeConversionSet(Region ®ion, return success(); } -/// Converts an operation with the given rewriter. LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, Operation *op) { // Legalize the given operation. @@ -1072,23 +1056,9 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, << "failed to legalize operation '" << op->getName() << "' that was explicitly marked illegal"; } - - // Convert the signature of any empty regions of this operation, non-empty - // regions are converted on demand when converting any operations contained - // within. - // FIXME(riverriddle) This should be replaced by patterns when the pattern - // rewriter exposes functionality to remap region signatures. - auto &rewriterImpl = rewriter.getImpl(); - if (rewriterImpl.argConverter.typeConverter) { - for (auto ®ion : op->getRegions()) - if (region.empty() && failed(rewriterImpl.convertRegionSignature(region))) - return failure(); - } - return success(); } -/// Converts the given operations to the conversion target. LogicalResult OperationConverter::convertOperations(ArrayRef ops, TypeConverter *typeConverter) { @@ -1106,11 +1076,16 @@ OperationConverter::convertOperations(ArrayRef ops, // Convert each operation and discard rewrites on failure. ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter); - for (auto *op : toConvert) { - if (failed(convert(rewriter, op))) { - rewriter.getImpl().discardRewrites(); - return failure(); - } + for (auto *op : toConvert) + if (failed(convert(rewriter, op))) + return rewriter.getImpl().discardRewrites(), failure(); + + // If a type converter was provided, ensure that all blocks have had their + // signatures properly converted. + if (typeConverter) { + for (auto *op : ops) + if (failed(convertBlockSignatures(rewriter, op))) + return rewriter.getImpl().discardRewrites(), failure(); } // Otherwise the body conversion succeeded, so apply all rewrites. @@ -1122,11 +1097,6 @@ OperationConverter::convertOperations(ArrayRef ops, // Type Conversion //===----------------------------------------------------------------------===// -/// Append new result types to the signature conversion. -void TypeConverter::SignatureConversion::addResults(ArrayRef results) { - resultTypes.append(results.begin(), results.end()); -} - /// Remap an input of the original signature with a new set of types. The /// new types are appended to the new signature conversion. void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, @@ -1164,33 +1134,31 @@ LogicalResult TypeConverter::convertType(Type t, return failure(); } -/// Convert the given FunctionType signature. -auto TypeConverter::convertSignature(FunctionType type) - -> llvm::Optional { - SignatureConversion result(type.getNumInputs()); - if (failed(convertSignature(type, result))) - return llvm::None; - return result; -} - -/// This hook allows for changing a FunctionType signature. -LogicalResult TypeConverter::convertSignature(FunctionType type, - SignatureConversion &result) { - // Convert the original function arguments. - for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) - if (failed(convertSignatureArg(i, type.getInput(i), result))) +/// Convert the given set of types, filling 'results' as necessary. This +/// returns failure if the conversion of any of the types fails, success +/// otherwise. +LogicalResult TypeConverter::convertTypes(ArrayRef types, + SmallVectorImpl &results) { + for (auto type : types) + if (failed(convertType(type, results))) return failure(); + return success(); +} - // Convert the original function results. - SmallVector convertedTypes; - for (auto t : type.getResults()) { - convertedTypes.clear(); - if (failed(convertType(t, convertedTypes))) - return failure(); - result.addResults(convertedTypes); - } +/// Return true if the given type is legal for this type converter, i.e. the +/// type converts to itself. +bool TypeConverter::isLegal(Type type) { + SmallVector results; + return succeeded(convertType(type, results)) && results.size() == 1 && + results.front() == type; +} - return success(); +/// Return true if the inputs and outputs of the given function type are +/// legal. +bool TypeConverter::isSignatureLegal(FunctionType funcType) { + return llvm::all_of( + llvm::concat(funcType.getInputs(), funcType.getResults()), + [this](Type type) { return isLegal(type); }); } /// This hook allows for converting a specific argument of a signature. @@ -1210,34 +1178,55 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type, return success(); } -/// This hook defines how the signature of a region 'regionIdx', i.e. the -/// signature of the entry to the region, on the given operation 'op' is -/// converted. This function should return a valid conversion for the signature -/// on success, None otherwise. -/// -/// The default behavior of this function is to invoke 'convertBlockSignature' -/// on the entry block, if one is present. This function also provides special -/// handling for FuncOp to update the type signature. -/// -/// TODO(riverriddle) This should be replaced in favor of using patterns, but -/// the pattern rewriter needs to know how to properly replace/remap -/// arguments. -auto TypeConverter::convertRegionSignature(Operation *op, unsigned regionIdx) - -> llvm::Optional { - // Provide explicit handling for FuncOp. - if (auto funcOp = dyn_cast(op)) { - auto conversion = convertSignature(funcOp.getType()); - if (conversion) - funcOp.setType(conversion->getConvertedType(funcOp.getContext())); - return conversion; +/// Create a default conversion pattern that rewrites the type signature of a +/// FuncOp. +namespace { +struct FuncOpSignatureConversion : public ConversionPattern { + FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) + : ConversionPattern(FuncOp::getOperationName(), 1, ctx), + converter(converter) {} + + /// Hook for derived classes to implement combined matching and rewriting. + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = cast(op); + FunctionType type = funcOp.getType(); + + // Convert the original function arguments. + TypeConverter::SignatureConversion result(type.getNumInputs()); + for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) + if (failed(converter.convertSignatureArg(i, type.getInput(i), result))) + return matchFailure(); + + // Convert the original function results. + SmallVector convertedResults; + if (failed(converter.convertTypes(type.getResults(), convertedResults))) + return matchFailure(); + + // Create a new function with an updated signature. + auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + newFuncOp.setType(FunctionType::get(result.getConvertedTypes(), + convertedResults, funcOp.getContext())); + + // Tell the rewriter to convert the region signature. + rewriter.applySignatureConversion(&newFuncOp.getBody(), result); + rewriter.replaceOp(op, llvm::None); + return matchSuccess(); } - // Otherwise, default to handle the arguments of the entry block for the given - // region. - auto ®ion = op->getRegion(regionIdx); - if (region.empty()) - return SignatureConversion(/*numOrigInputs=*/0); - return convertBlockSignature(®ion.front()); + /// The type converter to use when rewriting the signature. + TypeConverter &converter; +}; +} // end anonymous namespace + +void mlir::populateFuncOpTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &converter) { + RewriteListBuilder::build(patterns, ctx, + converter); } /// This function converts the type signature of the given block, by invoking diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 9fb6cc9..c44e489 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -49,13 +49,13 @@ func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) { "test.invalid"(%arg0, %arg1) : (i64, i64) -> () } -// CHECK-LABEL: func @remap_nested -func @remap_nested() { +// CHECK-LABEL: func @no_remap_nested +func @no_remap_nested() { // CHECK-NEXT: "foo.region" "foo.region"() ({ - // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: f64): + // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64): ^bb0(%i0: i64, %unused: i16, %i1: i64): - // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64) + // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64) "test.invalid"(%i0, %i1) : (i64, i64) -> () }) : () -> () return diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index d452edb..1cbd253 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -185,23 +185,26 @@ struct TestTypeConverter : public TypeConverter { struct TestLegalizePatternDriver : public ModulePass { void runOnModule() override { + TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); RewriteListBuilder::build(patterns, &getContext()); + mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), + converter); // Define the conversion target used for the test. ConversionTarget target(getContext()); - target.addLegalOp(); + target.addLegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](TestReturnOp op) { // Don't allow F32 operands. return llvm::none_of(op.getOperandTypes(), [](Type type) { return type.isF32(); }); }); - - TestTypeConverter converter; + target.addDynamicallyLegalOp( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); (void)applyPartialConversion(getModule(), target, std::move(patterns), &converter); }