size_t inputNo, size;
};
- /// Return the converted type signature.
- FunctionType getConvertedType(MLIRContext *ctx) const {
- return FunctionType::get(argTypes, resultTypes, ctx);
- }
-
/// Return the argument types for the new signature.
- ArrayRef<Type> getConvertedArgTypes() const { return argTypes; }
-
- /// Return the result types for the new signature.
- ArrayRef<Type> getConvertedResultTypes() const { return resultTypes; }
+ ArrayRef<Type> getConvertedTypes() const { return argTypes; }
/// Get the input mapping for the given argument.
llvm::Optional<InputMapping> getInputMapping(unsigned input) const {
// Conversion Hooks
//===------------------------------------------------------------------===//
- /// Append new result types to the signature conversion.
- void addResults(ArrayRef<Type> results);
-
/// Remap an input of the original signature with a new set of types. The
/// new types are appended to the new signature conversion.
void addInputs(unsigned origInputNo, ArrayRef<Type> types);
/// The remapping information for each of the original arguments.
SmallVector<llvm::Optional<InputMapping>, 4> remappedInputs;
- /// The set of argument and results types.
- SmallVector<Type, 4> argTypes, resultTypes;
+ /// The set of new argument types.
+ SmallVector<Type, 4> argTypes;
};
/// This hooks allows for converting a type. This function should return
/// the type convert to on success, and a null type on failure.
virtual Type convertType(Type t) { return t; }
- /// Convert the given FunctionType signature. This functions returns a valid
- /// SignatureConversion on success, None otherwise.
- llvm::Optional<SignatureConversion> convertSignature(FunctionType type);
+ /// Convert the given set of types, filling 'results' as necessary. This
+ /// returns failure if the conversion of any of the types fails, success
+ /// otherwise.
+ LogicalResult convertTypes(ArrayRef<Type> types,
+ SmallVectorImpl<Type> &results);
+
+ /// Return true if the given type is legal for this type converter, i.e. the
+ /// type converts to itself.
+ bool isLegal(Type type);
- /// This hook allows for changing a FunctionType signature. This function
- /// should populate 'result' with the new arguments and results on success,
- /// otherwise return failure.
- ///
- /// The default behavior of this function is to call 'convertType' on
- /// individual function operands and results.
- virtual LogicalResult convertSignature(FunctionType type,
- SignatureConversion &result);
+ /// Return true if the inputs and outputs of the given function type are
+ /// legal.
+ bool isSignatureLegal(FunctionType funcType);
/// This hook allows for converting a specific argument of a signature. It
/// takes as inputs the original argument input number, type.
virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result);
- /// This hook allows for converting the signature of a region 'regionIdx',
- /// i.e. the signature of the entry to the region, on the given operation
- /// 'op'. This function should return a valid conversion for the signature on
- /// success, None otherwise. This hook is allowed to modify the attributes on
- /// the provided operation if necessary.
- ///
- /// The default behavior of this function is to invoke 'convertBlockSignature'
- /// on the entry block, if one is present. This function also provides special
- /// handling for FuncOp to update the type signature.
- ///
- /// TODO(riverriddle) This should be replaced in favor of using patterns, but
- /// the pattern rewriter needs to know how to properly replace/remap
- /// arguments.
- virtual llvm::Optional<SignatureConversion>
- convertRegionSignature(Operation *op, unsigned regionIdx);
-
/// This function converts the type signature of the given block, by invoking
/// 'convertSignatureArg' for each argument. This function should return a
/// valid conversion for the signature on success, None otherwise.
using RewritePattern::rewrite;
};
+/// Add a pattern to the given pattern list to convert the signature of a FuncOp
+/// with the given type converter.
+void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
+ MLIRContext *ctx,
+ TypeConverter &converter);
+
//===----------------------------------------------------------------------===//
// Conversion PatternRewriter
//===----------------------------------------------------------------------===//
struct ConversionPatternRewriterImpl;
} // end namespace detail
-/// This class implements a pattern rewriter for use with ConversionPatterns.
+/// This class implements a pattern rewriter for use with ConversionPatterns. It
+/// extends the base PatternRewriter and provides special conversion specific
+/// hooks.
class ConversionPatternRewriter final : public PatternRewriter {
public:
ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter);
~ConversionPatternRewriter() override;
+ /// Apply a signature conversion to the entry block of the given region.
+ void applySignatureConversion(Region *region,
+ TypeConverter::SignatureConversion &conversion);
+
+ /// Clone the given operation without cloning its regions.
+ Operation *cloneWithoutRegions(Operation *op);
+ template <typename OpT> OpT cloneWithoutRegions(OpT op) {
+ return cast<OpT>(cloneWithoutRegions(op.getOperation()));
+ }
+
//===--------------------------------------------------------------------===//
// PatternRewriter Hooks
//===--------------------------------------------------------------------===//
LLVM::LLVMDialect &dialect;
};
+struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
+ using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto funcOp = cast<FuncOp>(op);
+ FunctionType type = funcOp.getType();
+
+ // Convert the original function arguments.
+ TypeConverter::SignatureConversion result(type.getNumInputs());
+ for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+ if (failed(lowering.convertSignatureArg(i, type.getInput(i), result)))
+ return matchFailure();
+
+ // Pack the result types into a struct.
+ Type packedResult;
+ if (type.getNumResults() != 0) {
+ if (!(packedResult = lowering.packFunctionResults(type.getResults())))
+ return matchFailure();
+ }
+
+ // Create a new function with an updated signature.
+ auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ newFuncOp.setType(FunctionType::get(
+ result.getConvertedTypes(),
+ packedResult ? ArrayRef<Type>(packedResult) : llvm::None,
+ funcOp.getContext()));
+
+ // Tell the rewriter to convert the region signature.
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+ rewriter.replaceOp(op, llvm::None);
+ return matchSuccess();
+ }
+};
+
// Basic lowering implementation for one-to-one rewriting from Standard Ops to
// LLVM Dialect Ops.
template <typename SourceOp, typename TargetOp>
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
DimOpLowering, DivISOpLowering, DivIUOpLowering, DivFOpLowering,
- IndexCastOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering,
- MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering,
- RemFOpLowering, ReturnOpLowering, SelectOpLowering, StoreOpLowering,
- SubFOpLowering, SubIOpLowering,
+ FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
+ MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
+ RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
+ SelectOpLowering, StoreOpLowering, SubFOpLowering, SubIOpLowering,
XOrOpLowering>::build(patterns, *converter.getDialect(), converter);
}
return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
}
-// Convert function signatures using the stored LLVM IR module.
-LogicalResult LLVMTypeConverter::convertSignature(FunctionType type,
- SignatureConversion &result) {
- // Convert the original function arguments.
- for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
- if (failed(convertSignatureArg(i, type.getInput(i), result)))
- return failure();
-
- // If function does not return anything, return immediately.
- if (type.getNumResults() == 0)
- return success();
-
- // Otherwise pack the result types into a struct.
- if (auto packedRet = packFunctionResults(type.getResults())) {
- result.addResults(packedRet);
- return success();
- }
-
- return failure();
-}
-
/// Create an instance of LLVMTypeConverter in the given context.
static std::unique_ptr<LLVMTypeConverter>
makeStandardToLLVMTypeConverter(MLIRContext *context) {
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ return typeConverter->isSignatureLegal(op.getType());
+ });
if (failed(applyPartialConversion(m, target, std::move(patterns),
typeConverter.get())))
signalPassFailure();
/// Return if the signature of the given block has already been converted.
bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
- /// Attempt to convert the signature of the given region.
- LogicalResult convertSignature(Region ®ion, BlockAndValueMapping &mapping);
-
/// Attempt to convert the signature of the given block.
LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping);
}
}
-/// Converts the signature of the given region.
-LogicalResult ArgConverter::convertSignature(Region ®ion,
- BlockAndValueMapping &mapping) {
- if (auto conversion = typeConverter->convertRegionSignature(
- region.getContainingOp(), region.getRegionNumber())) {
- if (!region.empty())
- applySignatureConversion(®ion.front(), *conversion, mapping);
- return success();
- }
- return failure();
-}
-
/// Converts the signature of the given entry block.
LogicalResult ArgConverter::convertSignature(Block *block,
BlockAndValueMapping &mapping) {
- auto conversion = typeConverter->convertBlockSignature(block);
- if (conversion)
+ if (auto conversion = typeConverter->convertBlockSignature(block))
return applySignatureConversion(block, *conversion, mapping), success();
return failure();
}
Block *block, TypeConverter::SignatureConversion &signatureConversion,
BlockAndValueMapping &mapping) {
unsigned origArgCount = block->getNumArguments();
- auto convertedTypes = signatureConversion.getConvertedArgTypes();
+ auto convertedTypes = signatureConversion.getConvertedTypes();
if (origArgCount == 0 && convertedTypes.empty())
return;
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
- unsigned numBlockActions, unsigned numTypeConversions)
+ unsigned numBlockActions)
: numCreatedOperations(numCreatedOperations),
- numReplacements(numReplacements), numBlockActions(numBlockActions),
- numTypeConversions(numTypeConversions) {}
+ numReplacements(numReplacements), numBlockActions(numBlockActions) {}
/// The current number of created operations.
unsigned numCreatedOperations;
/// The current number of block actions performed.
unsigned numBlockActions;
-
- /// The current number of type conversion actions performed.
- unsigned numTypeConversions;
};
} // end anonymous namespace
/// The kind of the block action performed during the rewrite. Actions can be
/// undone if the conversion fails.
- enum class BlockActionKind { Split, Move };
+ enum class BlockActionKind { Split, Move, TypeConversion };
/// Original position of the given block in its parent region. We cannot use
/// a region iterator because it could have been invalidated by other region
/// The storage class for an undoable block action (one of BlockActionKind),
/// contains the information necessary to undo this action.
struct BlockAction {
+ static BlockAction getSplit(Block *block, Block *originalBlock) {
+ BlockAction action{BlockActionKind::Split, block};
+ action.originalBlock = originalBlock;
+ return action;
+ }
+ static BlockAction getMove(Block *block, BlockPosition originalPos) {
+ return {BlockActionKind::Move, block, {originalPos}};
+ }
+ static BlockAction getTypeConversion(Block *block) {
+ return BlockAction{BlockActionKind::TypeConversion, block};
+ }
+
+ // The action kind.
+ BlockActionKind kind;
+
// A pointer to the block that was created by the action.
Block *block;
// block that was split into two parts.
Block *originalBlock;
};
-
- BlockActionKind kind;
- };
-
- /// A storage class representing a type conversion of a block or a region.
- struct TypeConversion {
- /// The region, or block, that had its types converted.
- llvm::PointerUnion<Region *, Block *> object;
-
- /// If the object is a region, this corresponds to the original attributes
- /// of the parent operation.
- NamedAttributeList originalParentAttributes;
};
ConversionPatternRewriterImpl(PatternRewriter &rewriter,
/// "numActionsToKeep" actions remains.
void undoBlockActions(unsigned numActionsToKeep = 0);
- /// Undo the type conversion actions one by one, until "numActionsToKeep"
- /// actions remain.
- void undoTypeConversions(unsigned numActionsToKeep = 0);
-
/// Cleanup and destroy any generated rewrite operations. This method is
/// invoked when the conversion process fails.
void discardRewrites();
/// conversion process succeeds.
void applyRewrites();
- /// Return if the given block has already been converted.
- bool hasSignatureBeenConverted(Block *block);
-
- /// Convert the signature of the given region.
- LogicalResult convertRegionSignature(Region ®ion);
-
/// Convert the signature of the given block.
LogicalResult convertBlockSignature(Block *block);
+ /// Apply a signature conversion on the given region.
+ void applySignatureConversion(Region *region,
+ TypeConverter::SignatureConversion &conversion);
+
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead);
/// Ordered list of block operations (creations, splits, motions).
SmallVector<BlockAction, 4> blockActions;
-
- /// Ordered list of type conversion actions.
- SmallVector<TypeConversion, 4> typeConversions;
};
} // end namespace detail
} // end namespace mlir
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
- blockActions.size(), typeConversions.size());
+ blockActions.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
- // Undo any type conversions or block actions.
- undoTypeConversions(state.numTypeConversions);
+ // Undo any block actions.
undoBlockActions(state.numBlockActions);
// Reset any replaced operations and undo any saved mappings.
action.block->getParent()->getBlocks(), action.block);
break;
}
+ // Undo the type conversion.
+ case BlockActionKind::TypeConversion: {
+ argConverter.discardPendingRewrites(action.block);
+ break;
+ }
}
}
blockActions.resize(numActionsToKeep);
}
-void ConversionPatternRewriterImpl::undoTypeConversions(
- unsigned numActionsToKeep) {
- for (auto &conversion : llvm::drop_begin(typeConversions, numActionsToKeep)) {
- if (auto *region = conversion.object.dyn_cast<Region *>())
- region->getContainingOp()->setAttrs(conversion.originalParentAttributes);
- else
- argConverter.discardPendingRewrites(conversion.object.get<Block *>());
- }
- typeConversions.resize(numActionsToKeep);
-}
-
void ConversionPatternRewriterImpl::discardRewrites() {
- undoTypeConversions();
undoBlockActions();
// Remove any newly created ops.
argConverter.applyRewrites();
}
-bool ConversionPatternRewriterImpl::hasSignatureBeenConverted(Block *block) {
- return argConverter.hasBeenConverted(block);
-}
-
LogicalResult
-ConversionPatternRewriterImpl::convertRegionSignature(Region ®ion) {
- auto parentAttrs = region.getContainingOp()->getAttrList();
- auto result = argConverter.convertSignature(region, mapping);
- if (succeeded(result)) {
- typeConversions.push_back(TypeConversion{®ion, parentAttrs});
- if (!region.empty())
- typeConversions.push_back(
- TypeConversion{®ion.front(), NamedAttributeList()});
- }
- return result;
+ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
+ // Check to see if this block should not be converted:
+ // * The block is invalid, or there is no type converter.
+ // * The block has already been converted.
+ // * This is an entry block, these are converted explicitly via patterns.
+ if (!block || !argConverter.typeConverter ||
+ argConverter.hasBeenConverted(block) || block->isEntryBlock())
+ return success();
+
+ // Otherwise, try to convert the block signature.
+ if (failed(argConverter.convertSignature(block, mapping)))
+ return failure();
+ blockActions.push_back(BlockAction::getTypeConversion(block));
+ return success();
}
-LogicalResult
-ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
- auto result = argConverter.convertSignature(block, mapping);
- if (succeeded(result))
- typeConversions.push_back(TypeConversion{block, NamedAttributeList()});
- return result;
+void ConversionPatternRewriterImpl::applySignatureConversion(
+ Region *region, TypeConverter::SignatureConversion &conversion) {
+ if (!region->empty()) {
+ argConverter.applySignatureConversion(®ion->front(), conversion,
+ mapping);
+ blockActions.push_back(BlockAction::getTypeConversion(®ion->front()));
+ }
}
void ConversionPatternRewriterImpl::replaceOp(
void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
Block *continuation) {
- BlockAction action;
- action.kind = BlockActionKind::Split;
- action.block = continuation;
- action.originalBlock = block;
- blockActions.push_back(action);
+ blockActions.push_back(BlockAction::getSplit(continuation, block));
}
void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
for (auto &pair : llvm::enumerate(region)) {
Block &block = pair.value();
unsigned position = pair.index();
- BlockAction action;
- action.kind = BlockActionKind::Move;
- action.block = █
- action.originalPosition = {®ion, position};
- blockActions.push_back(action);
+ blockActions.push_back(BlockAction::getMove(&block, {®ion, position}));
}
}
impl->replaceOp(op, newValues, valuesToRemoveIfDead);
}
+/// Apply a signature conversion to the entry block of the given region.
+void ConversionPatternRewriter::applySignatureConversion(
+ Region *region, TypeConverter::SignatureConversion &conversion) {
+ impl->applySignatureConversion(region, conversion);
+}
+
+/// Clone the given operation without cloning its regions.
+Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
+ Operation *newOp = OpBuilder::cloneWithoutRegions(*op);
+ impl->createdOps.push_back(newOp);
+ return newOp;
+}
+
/// PatternRewriter hook for splitting a block into two parts.
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
LogicalResult
OperationLegalizer::legalize(Operation *op,
ConversionPatternRewriter &rewriter) {
- // Make sure that the signature of the parent block of this operation has been
- // converted.
- auto &rewriterImpl = rewriter.getImpl();
- if (rewriterImpl.argConverter.typeConverter) {
- auto *block = op->getBlock();
- if (block && !rewriterImpl.hasSignatureBeenConverted(block)) {
- if (failed(block->isEntryBlock()
- ? rewriterImpl.convertRegionSignature(*block->getParent())
- : rewriterImpl.convertBlockSignature(block)))
- return failure();
- }
- }
+ // Make sure that the signature of the parent block has been converted.
+ if (failed(rewriter.getImpl().convertBlockSignature(op->getBlock())))
+ return failure();
LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
<< "\n");
/// Converts an operation with the given rewriter.
LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
- /// Recursively collect all of the operations, to convert from within
- /// 'region'.
+ /// Recursively collect all of the operations to convert from within 'region'.
LogicalResult computeConversionSet(Region ®ion,
std::vector<Operation *> &toConvert);
+ /// Converts the type signatures of the blocks nested within 'op' that have
+ /// yet to be converted.
+ LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
+ Operation *op);
+
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
};
} // end anonymous namespace
-/// Recursively collect all of the blocks to convert from within 'region'.
+LogicalResult
+OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
+ Operation *op) {
+ SmallVector<Region *, 8> worklist;
+ for (auto ®ion : op->getRegions())
+ worklist.push_back(®ion);
+
+ while (!worklist.empty()) {
+ for (auto &block : *worklist.pop_back_val()) {
+ if (failed(rewriter.getImpl().convertBlockSignature(&block)))
+ return failure();
+ for (auto &nestedOp : block)
+ for (auto ®ion : nestedOp.getRegions())
+ worklist.push_back(®ion);
+ }
+ }
+ return success();
+}
+
LogicalResult
OperationConverter::computeConversionSet(Region ®ion,
std::vector<Operation *> &toConvert) {
return success();
}
-/// Converts an operation with the given rewriter.
LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
Operation *op) {
// Legalize the given operation.
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
}
-
- // Convert the signature of any empty regions of this operation, non-empty
- // regions are converted on demand when converting any operations contained
- // within.
- // FIXME(riverriddle) This should be replaced by patterns when the pattern
- // rewriter exposes functionality to remap region signatures.
- auto &rewriterImpl = rewriter.getImpl();
- if (rewriterImpl.argConverter.typeConverter) {
- for (auto ®ion : op->getRegions())
- if (region.empty() && failed(rewriterImpl.convertRegionSignature(region)))
- return failure();
- }
-
return success();
}
-/// Converts the given operations to the conversion target.
LogicalResult
OperationConverter::convertOperations(ArrayRef<Operation *> ops,
TypeConverter *typeConverter) {
// Convert each operation and discard rewrites on failure.
ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
- for (auto *op : toConvert) {
- if (failed(convert(rewriter, op))) {
- rewriter.getImpl().discardRewrites();
- return failure();
- }
+ for (auto *op : toConvert)
+ if (failed(convert(rewriter, op)))
+ return rewriter.getImpl().discardRewrites(), failure();
+
+ // If a type converter was provided, ensure that all blocks have had their
+ // signatures properly converted.
+ if (typeConverter) {
+ for (auto *op : ops)
+ if (failed(convertBlockSignatures(rewriter, op)))
+ return rewriter.getImpl().discardRewrites(), failure();
}
// Otherwise the body conversion succeeded, so apply all rewrites.
// Type Conversion
//===----------------------------------------------------------------------===//
-/// Append new result types to the signature conversion.
-void TypeConverter::SignatureConversion::addResults(ArrayRef<Type> results) {
- resultTypes.append(results.begin(), results.end());
-}
-
/// Remap an input of the original signature with a new set of types. The
/// new types are appended to the new signature conversion.
void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
return failure();
}
-/// Convert the given FunctionType signature.
-auto TypeConverter::convertSignature(FunctionType type)
- -> llvm::Optional<SignatureConversion> {
- SignatureConversion result(type.getNumInputs());
- if (failed(convertSignature(type, result)))
- return llvm::None;
- return result;
-}
-
-/// This hook allows for changing a FunctionType signature.
-LogicalResult TypeConverter::convertSignature(FunctionType type,
- SignatureConversion &result) {
- // Convert the original function arguments.
- for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
- if (failed(convertSignatureArg(i, type.getInput(i), result)))
+/// Convert the given set of types, filling 'results' as necessary. This
+/// returns failure if the conversion of any of the types fails, success
+/// otherwise.
+LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
+ SmallVectorImpl<Type> &results) {
+ for (auto type : types)
+ if (failed(convertType(type, results)))
return failure();
+ return success();
+}
- // Convert the original function results.
- SmallVector<Type, 1> convertedTypes;
- for (auto t : type.getResults()) {
- convertedTypes.clear();
- if (failed(convertType(t, convertedTypes)))
- return failure();
- result.addResults(convertedTypes);
- }
+/// Return true if the given type is legal for this type converter, i.e. the
+/// type converts to itself.
+bool TypeConverter::isLegal(Type type) {
+ SmallVector<Type, 1> results;
+ return succeeded(convertType(type, results)) && results.size() == 1 &&
+ results.front() == type;
+}
- return success();
+/// Return true if the inputs and outputs of the given function type are
+/// legal.
+bool TypeConverter::isSignatureLegal(FunctionType funcType) {
+ return llvm::all_of(
+ llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()),
+ [this](Type type) { return isLegal(type); });
}
/// This hook allows for converting a specific argument of a signature.
return success();
}
-/// This hook defines how the signature of a region 'regionIdx', i.e. the
-/// signature of the entry to the region, on the given operation 'op' is
-/// converted. This function should return a valid conversion for the signature
-/// on success, None otherwise.
-///
-/// The default behavior of this function is to invoke 'convertBlockSignature'
-/// on the entry block, if one is present. This function also provides special
-/// handling for FuncOp to update the type signature.
-///
-/// TODO(riverriddle) This should be replaced in favor of using patterns, but
-/// the pattern rewriter needs to know how to properly replace/remap
-/// arguments.
-auto TypeConverter::convertRegionSignature(Operation *op, unsigned regionIdx)
- -> llvm::Optional<SignatureConversion> {
- // Provide explicit handling for FuncOp.
- if (auto funcOp = dyn_cast<FuncOp>(op)) {
- auto conversion = convertSignature(funcOp.getType());
- if (conversion)
- funcOp.setType(conversion->getConvertedType(funcOp.getContext()));
- return conversion;
+/// Create a default conversion pattern that rewrites the type signature of a
+/// FuncOp.
+namespace {
+struct FuncOpSignatureConversion : public ConversionPattern {
+ FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+ : ConversionPattern(FuncOp::getOperationName(), 1, ctx),
+ converter(converter) {}
+
+ /// Hook for derived classes to implement combined matching and rewriting.
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto funcOp = cast<FuncOp>(op);
+ FunctionType type = funcOp.getType();
+
+ // Convert the original function arguments.
+ TypeConverter::SignatureConversion result(type.getNumInputs());
+ for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+ if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
+ return matchFailure();
+
+ // Convert the original function results.
+ SmallVector<Type, 1> convertedResults;
+ if (failed(converter.convertTypes(type.getResults(), convertedResults)))
+ return matchFailure();
+
+ // Create a new function with an updated signature.
+ auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
+ convertedResults, funcOp.getContext()));
+
+ // Tell the rewriter to convert the region signature.
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+ rewriter.replaceOp(op, llvm::None);
+ return matchSuccess();
}
- // Otherwise, default to handle the arguments of the entry block for the given
- // region.
- auto ®ion = op->getRegion(regionIdx);
- if (region.empty())
- return SignatureConversion(/*numOrigInputs=*/0);
- return convertBlockSignature(®ion.front());
+ /// The type converter to use when rewriting the signature.
+ TypeConverter &converter;
+};
+} // end anonymous namespace
+
+void mlir::populateFuncOpTypeConversionPattern(
+ OwningRewritePatternList &patterns, MLIRContext *ctx,
+ TypeConverter &converter) {
+ RewriteListBuilder<FuncOpSignatureConversion>::build(patterns, ctx,
+ converter);
}
/// This function converts the type signature of the given block, by invoking