public:
virtual ~TypeConverter() = default;
- /// This class provides all of the information necessary to convert a
- /// FunctionType signature.
+ /// This class provides all of the information necessary to convert a type
+ /// signature.
class SignatureConversion {
public:
SignatureConversion(unsigned numOrigInputs)
/// Return the result types for the new signature.
ArrayRef<Type> getConvertedResultTypes() const { return resultTypes; }
- /// Returns the attributes for the arguments of the new signature.
- ArrayRef<NamedAttributeList> getConvertedArgAttrs() const {
- return argAttrs;
- }
-
/// Get the input mapping for the given argument.
llvm::Optional<InputMapping> getInputMapping(unsigned input) const {
return remappedInputs[input];
/// Remap an input of the original signature with a new set of types. The
/// new types are appended to the new signature conversion.
- void addInputs(unsigned origInputNo, ArrayRef<Type> types,
- ArrayRef<NamedAttributeList> attrs = llvm::None);
+ void addInputs(unsigned origInputNo, ArrayRef<Type> types);
/// Append new input types to the signature conversion, this should only be
/// used if the new types are not intended to remap an existing input.
- void addInputs(ArrayRef<Type> types,
- ArrayRef<NamedAttributeList> attrs = llvm::None);
+ void addInputs(ArrayRef<Type> types);
/// Remap an input of the original signature with a range of types in the
/// new signature.
/// The set of argument and results types.
SmallVector<Type, 4> argTypes, resultTypes;
-
- /// The set of attributes for each new argument type.
- SmallVector<NamedAttributeList, 4> argAttrs;
};
/// This hooks allows for converting a type. This function should return
/// Convert the given FunctionType signature. This functions returns a valid
/// SignatureConversion on success, None otherwise.
- llvm::Optional<SignatureConversion>
- convertSignature(FunctionType type, ArrayRef<NamedAttributeList> argAttrs);
- llvm::Optional<SignatureConversion> convertSignature(FunctionType type) {
- SmallVector<NamedAttributeList, 4> argAttrs(type.getNumInputs());
- return convertSignature(type, argAttrs);
- }
+ llvm::Optional<SignatureConversion> convertSignature(FunctionType type);
/// This hook allows for changing a FunctionType signature. This function
- /// should populate 'result' with the new arguments and result on success,
+ /// should populate 'result' with the new arguments and results on success,
/// otherwise return failure.
///
/// The default behavior of this function is to call 'convertType' on
- /// individual function operands and results. Any argument attributes are
- /// dropped if the resultant conversion is not a 1->1 mapping.
+ /// individual function operands and results.
virtual LogicalResult convertSignature(FunctionType type,
- ArrayRef<NamedAttributeList> argAttrs,
SignatureConversion &result);
/// This hook allows for converting a specific argument of a signature. It
- /// takes as inputs the original argument input number, type, and attributes.
+ /// takes as inputs the original argument input number, type.
/// On success, this function should populate 'result' with any new mappings.
virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type,
- NamedAttributeList attrs,
SignatureConversion &result);
+ /// This hook allows for converting the signature of a region 'regionIdx',
+ /// i.e. the signature of the entry to the region, on the given operation
+ /// 'op'. This function should return a valid conversion for the signature on
+ /// success, None otherwise. This hook is allowed to modify the attributes on
+ /// the provided operation if necessary.
+ ///
+ /// The default behavior of this function is to invoke 'convertBlockSignature'
+ /// on the entry block, if one is present. This function also provides special
+ /// handling for FuncOp to update the type signature.
+ ///
+ /// TODO(riverriddle) This should be replaced in favor of using patterns, but
+ /// the pattern rewriter needs to know how to properly replace/remap
+ /// arguments.
+ virtual llvm::Optional<SignatureConversion>
+ convertRegionSignature(Operation *op, unsigned regionIdx);
+
+ /// This function converts the type signature of the given block, by invoking
+ /// 'convertSignatureArg' for each argument. This function should return a
+ /// valid conversion for the signature on success, None otherwise.
+ llvm::Optional<SignatureConversion> convertBlockSignature(Block *block);
+
/// This hook allows for materializing a conversion from a set of types into
/// one result type by generating a cast operation of some kind. The generated
/// operation should produce one result, of 'resultType', with the provided
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there are unreachable blocks in any of the regions nested
-/// within 'ops'.
-LLVM_NODISCARD LogicalResult
-applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
- OwningRewritePatternList &&patterns);
-LLVM_NODISCARD LogicalResult
-applyPartialConversion(Operation *op, ConversionTarget &target,
- OwningRewritePatternList &&patterns);
+/// within 'ops'. If 'converter' is provided, the signatures of blocks and
+/// regions are also converted.
+LLVM_NODISCARD LogicalResult applyPartialConversion(
+ ArrayRef<Operation *> ops, ConversionTarget &target,
+ OwningRewritePatternList &&patterns, TypeConverter *converter = nullptr);
+LLVM_NODISCARD LogicalResult applyPartialConversion(
+ Operation *op, ConversionTarget &target,
+ OwningRewritePatternList &&patterns, TypeConverter *converter = nullptr);
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
/// fails, or if there are unreachable blocks in any of the regions nested
-/// within 'ops'.
-LLVM_NODISCARD LogicalResult
-applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
- OwningRewritePatternList &&patterns);
-LLVM_NODISCARD LogicalResult
-applyFullConversion(Operation *op, ConversionTarget &target,
- OwningRewritePatternList &&patterns);
-
-//===----------------------------------------------------------------------===//
-// Op + Type Conversion Entry Points
-//===----------------------------------------------------------------------===//
-
-/// Apply a partial conversion on the function operations within the given
-/// module. This method returns failure if a type conversion was encountered.
-LLVM_NODISCARD LogicalResult applyPartialConversion(
- ModuleOp module, ConversionTarget &target, TypeConverter &converter,
- OwningRewritePatternList &&patterns);
-
-/// Apply a partial conversion on the given function operations. This method
-/// returns failure if a type conversion was encountered.
-LLVM_NODISCARD LogicalResult applyPartialConversion(
- MutableArrayRef<FuncOp> fns, ConversionTarget &target,
- TypeConverter &converter, OwningRewritePatternList &&patterns);
-
-/// Apply a full conversion on the function operations within the given
-/// module. This method returns failure if a type conversion was encountered, or
-/// if the conversion of any operations failed.
+/// within 'ops'. If 'converter' is provided, the signatures of blocks and
+/// regions are also converted.
LLVM_NODISCARD LogicalResult applyFullConversion(
- ModuleOp module, ConversionTarget &target, TypeConverter &converter,
- OwningRewritePatternList &&patterns);
-
-/// Apply a partial conversion on the given function operations. This method
-/// returns failure if a type conversion was encountered, or if the conversion
-/// of any operation failed.
+ ArrayRef<Operation *> ops, ConversionTarget &target,
+ OwningRewritePatternList &&patterns, TypeConverter *converter = nullptr);
LLVM_NODISCARD LogicalResult applyFullConversion(
- MutableArrayRef<FuncOp> fns, ConversionTarget &target,
- TypeConverter &converter, OwningRewritePatternList &&patterns);
+ Operation *op, ConversionTarget &target,
+ OwningRewritePatternList &&patterns, TypeConverter *converter = nullptr);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
/// Erase any rewrites registered for arguments to blocks within the given
/// region. This function is called when the given region is to be destroyed.
- void cancelPendingRewrites(Region ®ion);
+ void cancelPendingRewrites(Block *block);
- /// Cleanup and undo any generated conversion values.
- void discardRewrites();
+ /// Cleanup and undo any generated conversions for the arguments of block.
+ /// This method differs from 'cancelPendingRewrites' in that it returns the
+ /// block signature to its original state.
+ void discardPendingRewrites(Block *block);
/// Replace usages of the cast operations with the argument directly.
void applyRewrites();
- /// Converts the signature of the given entry block.
- void convertSignature(Block *block,
- TypeConverter::SignatureConversion &signatureConversion,
- BlockAndValueMapping &mapping);
+ /// Return if the signature of the given block has already been converted.
+ bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
- /// Converts the arguments of the given block.
- LogicalResult convertArguments(Block *block, BlockAndValueMapping &mapping);
+ /// Attempt to convert the signature of the given region.
+ LogicalResult convertSignature(Region ®ion, BlockAndValueMapping &mapping);
+
+ /// Attempt to convert the signature of the given block.
+ LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping);
+
+ /// Apply the given signature conversion on the given block.
+ void applySignatureConversion(
+ Block *block, TypeConverter::SignatureConversion &signatureConversion,
+ BlockAndValueMapping &mapping);
/// Convert the given block argument given the provided set of new argument
/// values that are to replace it. This function returns the operation used
constexpr StringLiteral ArgConverter::kCastName;
-/// Erase any rewrites registered for arguments to blocks within the given
-/// region. This function is called when the given region is to be destroyed.
-void ArgConverter::cancelPendingRewrites(Region ®ion) {
- for (auto &block : region) {
- auto it = argMapping.find(&block);
- if (it == argMapping.end())
- continue;
- for (auto *op : it->second) {
- // If the operation exists within the parent block, like with 1->N cast
- // operations, we don't need to drop them. They will be automatically
- // cleaned up with the region is destroyed.
- if (op->getBlock())
- continue;
-
- op->dropAllDefinedValueUses();
- op->destroy();
- }
- argMapping.erase(it);
+/// Erase any rewrites registered for arguments to the given block.
+void ArgConverter::cancelPendingRewrites(Block *block) {
+ auto it = argMapping.find(block);
+ if (it == argMapping.end())
+ return;
+ for (auto *op : it->second) {
+ op->dropAllDefinedValueUses();
+ op->erase();
}
+ argMapping.erase(it);
}
-/// Cleanup and undo any generated conversion values.
-void ArgConverter::discardRewrites() {
- // On failure reinstate all of the original block arguments.
- Block *block;
- ArrayRef<Operation *> argOps;
- for (auto &mapping : argMapping) {
- std::tie(block, argOps) = mapping;
+/// Cleanup and undo any generated conversions for the arguments of block.
+/// This method differs from 'cancelPendingRewrites' in that it returns the
+/// block signature to its original state.
+void ArgConverter::discardPendingRewrites(Block *block) {
+ auto it = argMapping.find(block);
+ if (it == argMapping.end())
+ return;
- // Erase all of the new arguments.
- for (int i = block->getNumArguments() - 1; i >= 0; --i) {
- block->getArgument(i)->dropAllUses();
- block->eraseArgument(i, /*updatePredTerms=*/false);
- }
+ // Erase all of the new arguments.
+ for (int i = block->getNumArguments() - 1; i >= 0; --i) {
+ block->getArgument(i)->dropAllUses();
+ block->eraseArgument(i, /*updatePredTerms=*/false);
+ }
- // Re-instate the old arguments.
- for (unsigned i = 0, e = argOps.size(); i != e; ++i) {
- auto *op = argOps[i];
- auto *arg = block->addArgument(op->getResult(0)->getType());
- op->getResult(0)->replaceAllUsesWith(arg);
+ // Re-instate the old arguments.
+ auto &mapping = it->second;
+ for (unsigned i = 0, e = mapping.size(); i != e; ++i) {
+ auto *op = mapping[i];
+ auto *arg = block->addArgument(op->getResult(0)->getType());
+ op->getResult(0)->replaceAllUsesWith(arg);
- // If this was a 1->N value mapping it exists within the parent block so
- // erase it instead of destroying.
- if (op->getBlock())
- op->erase();
- else
- op->destroy();
- }
+ // If this operation is within a block, it will be cleaned up automatically.
+ if (!op->getBlock())
+ op->erase();
}
- argMapping.clear();
+ argMapping.erase(it);
}
/// Replace usages of the cast operations with the argument directly.
}
}
+/// Converts the signature of the given region.
+LogicalResult ArgConverter::convertSignature(Region ®ion,
+ BlockAndValueMapping &mapping) {
+ if (auto conversion = typeConverter->convertRegionSignature(
+ region.getContainingOp(), region.getRegionNumber())) {
+ if (!region.empty())
+ applySignatureConversion(®ion.front(), *conversion, mapping);
+ return success();
+ }
+ return failure();
+}
+
/// Converts the signature of the given entry block.
-void ArgConverter::convertSignature(
+LogicalResult ArgConverter::convertSignature(Block *block,
+ BlockAndValueMapping &mapping) {
+ auto conversion = typeConverter->convertBlockSignature(block);
+ if (conversion)
+ return applySignatureConversion(block, *conversion, mapping), success();
+ return failure();
+}
+
+/// Apply the given signature conversion on the given block.
+void ArgConverter::applySignatureConversion(
Block *block, TypeConverter::SignatureConversion &signatureConversion,
BlockAndValueMapping &mapping) {
unsigned origArgCount = block->getNumArguments();
block->eraseArgument(0, /*updatePredTerms=*/false);
}
-/// Converts the arguments of the given block.
-LogicalResult ArgConverter::convertArguments(Block *block,
- BlockAndValueMapping &mapping) {
- unsigned origArgCount = block->getNumArguments();
- if (origArgCount == 0 || argMapping.count(block))
- return success();
-
- // Convert the types of each of the block arguments.
- SmallVector<SmallVector<Type, 1>, 4> newArgTypes(origArgCount);
- for (unsigned i = 0; i != origArgCount; ++i) {
- auto *arg = block->getArgument(i);
- if (failed(typeConverter->convertType(arg->getType(), newArgTypes[i])))
- return emitError(block->getParent()->getLoc())
- << "could not convert block argument of type " << arg->getType();
- }
-
- // Remap all of the original argument values.
- auto &newArgMapping = argMapping[block];
- rewriter.setInsertionPointToStart(block);
- for (unsigned i = 0; i != origArgCount; ++i) {
- SmallVector<Value *, 1> newArgs(block->addArguments(newArgTypes[i]));
- newArgMapping.push_back(
- convertArgument(block->getArgument(i), newArgs, mapping));
- }
-
- // Erase all of the original arguments.
- for (unsigned i = 0; i != origArgCount; ++i)
- block->eraseArgument(0, /*updatePredTerms=*/false);
- return success();
-}
-
/// Convert the given block argument given the provided set of new argument
/// values that are to replace it. This function returns the operation used
/// to perform the conversion.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
- unsigned numBlockActions)
+ unsigned numBlockActions, unsigned numTypeConversions)
: numCreatedOperations(numCreatedOperations),
- numReplacements(numReplacements), numBlockActions(numBlockActions) {}
+ numReplacements(numReplacements), numBlockActions(numBlockActions),
+ numTypeConversions(numTypeConversions) {}
/// The current number of created operations.
unsigned numCreatedOperations;
/// The current number of block actions performed.
unsigned numBlockActions;
+
+ /// The current number of type conversion actions performed.
+ unsigned numTypeConversions;
};
/// This class implements a pattern rewriter for ConversionPattern
BlockActionKind kind;
};
+ /// A storage class representing a type conversion of a block or a region.
+ struct TypeConversion {
+ /// The region, or block, that had its types converted.
+ llvm::PointerUnion<Region *, Block *> object;
+
+ /// If the object is a region, this corresponds to the original attributes
+ /// of the parent operation.
+ NamedAttributeList originalParentAttributes;
+ };
+
DialectConversionRewriter(MLIRContext *ctx, TypeConverter *converter)
: PatternRewriter(ctx), argConverter(converter, *this) {}
~DialectConversionRewriter() = default;
/// Return the current state of the rewriter.
RewriterState getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
- blockActions.size());
+ blockActions.size(), typeConversions.size());
}
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state) {
+ // Undo any type conversions or block actions.
+ undoTypeConversions(state.numTypeConversions);
+ undoBlockActions(state.numBlockActions);
+
// Reset any replaced operations and undo any saved mappings.
for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
for (auto *result : repl.op->getResults())
// Pop all of the newly created operations.
while (createdOps.size() != state.numCreatedOperations)
createdOps.pop_back_val()->erase();
-
- // Undo any block operations.
- undoBlockActions(state.numBlockActions);
}
/// Undo the block actions (motions, splits) one by one in reverse order until
}
}
}
+ blockActions.resize(numActionsToKeep);
+ }
+
+ /// Undo the type conversion actions one by one, until "numActionsToKeep"
+ /// actions remain.
+ void undoTypeConversions(unsigned numActionsToKeep = 0) {
+ for (auto &conversion :
+ llvm::drop_begin(typeConversions, numActionsToKeep)) {
+ if (auto *region = conversion.object.dyn_cast<Region *>())
+ region->getContainingOp()->setAttrs(
+ conversion.originalParentAttributes);
+ else
+ argConverter.discardPendingRewrites(conversion.object.get<Block *>());
+ }
+ typeConversions.resize(numActionsToKeep);
}
/// Cleanup and destroy any generated rewrite operations. This method is
/// invoked when the conversion process fails.
void discardRewrites() {
- argConverter.discardRewrites();
+ undoTypeConversions();
+ undoBlockActions();
// Remove any newly created ops.
for (auto *op : createdOps) {
op->dropAllDefinedValueUses();
op->erase();
}
-
- undoBlockActions();
}
/// Apply all requested operation rewrites. This method is invoked when the
// if this operation defines any regions, drop any pending argument
// rewrites.
- if (repl.op->getNumRegions() && !argConverter.argMapping.empty()) {
+ if (argConverter.typeConverter && repl.op->getNumRegions()) {
for (auto ®ion : repl.op->getRegions())
- argConverter.cancelPendingRewrites(region);
+ for (auto &block : region)
+ argConverter.cancelPendingRewrites(&block);
}
}
argConverter.applyRewrites();
}
+ /// Return if the given block has already been converted.
+ bool hasSignatureBeenConverted(Block *block) {
+ return argConverter.hasBeenConverted(block);
+ }
+
+ /// Convert the signature of the given region.
+ LogicalResult convertRegionSignature(Region ®ion) {
+ auto parentAttrs = region.getContainingOp()->getAttrList();
+ auto result = argConverter.convertSignature(region, mapping);
+ if (succeeded(result)) {
+ typeConversions.push_back(TypeConversion{®ion, parentAttrs});
+ if (!region.empty())
+ typeConversions.push_back(
+ TypeConversion{®ion.front(), NamedAttributeList()});
+ }
+ return result;
+ }
+
+ /// Convert the signature of the given block.
+ LogicalResult convertBlockSignature(Block *block) {
+ auto result = argConverter.convertSignature(block, mapping);
+ if (succeeded(result))
+ typeConversions.push_back(TypeConversion{block, NamedAttributeList()});
+ return result;
+ }
+
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead) override {
/// Ordered list of block operations (creations, splits, motions).
SmallVector<BlockAction, 4> blockActions;
+
+ /// Ordered list of type conversion actions.
+ SmallVector<TypeConversion, 4> typeConversions;
};
} // end anonymous namespace
LogicalResult
OperationLegalizer::legalize(Operation *op,
DialectConversionRewriter &rewriter) {
+ // Make sure that the signature of the parent block of this operation has been
+ // converted.
+ if (rewriter.argConverter.typeConverter) {
+ auto *block = op->getBlock();
+ if (block && !rewriter.hasSignatureBeenConverted(block)) {
+ if (failed(block->isEntryBlock()
+ ? rewriter.convertRegionSignature(*block->getParent())
+ : rewriter.convertBlockSignature(block)))
+ return failure();
+ }
+ }
+
LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
<< "\n");
struct OperationConverter {
explicit OperationConverter(ConversionTarget &target,
OwningRewritePatternList &patterns,
- OpConversionMode mode,
- TypeConverter *conversion = nullptr)
- : typeConverter(conversion), opLegalizer(target, patterns), mode(mode) {}
-
- /// Converts the given function to the conversion target. Returns failure on
- /// error, success otherwise.
- LogicalResult
- convertFunction(FuncOp f,
- TypeConverter::SignatureConversion &signatureConversion);
+ OpConversionMode mode)
+ : opLegalizer(target, patterns), mode(mode) {}
/// Converts the given operations to the conversion target.
- LogicalResult convertOperations(ArrayRef<Operation *> ops);
+ LogicalResult convertOperations(ArrayRef<Operation *> ops,
+ TypeConverter *typeConverter);
private:
- /// Converts a block or operation with the given rewriter.
- LogicalResult convert(DialectConversionRewriter &rewriter,
- llvm::PointerUnion<Operation *, Block *> &ptr);
-
- /// Converts a set of blocks/operations with the given rewriter.
- LogicalResult
- convert(DialectConversionRewriter &rewriter,
- std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert);
-
- /// Recursively collect all of the blocks, and operations, to convert from
- /// within 'region'.
- LogicalResult computeConversionSet(
- Region ®ion,
- std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert);
-
- /// Pointer to the type converter.
- TypeConverter *typeConverter;
+ /// Converts an operation with the given rewriter.
+ LogicalResult convert(DialectConversionRewriter &rewriter, Operation *op);
+
+ /// Recursively collect all of the operations, to convert from within
+ /// 'region'.
+ LogicalResult computeConversionSet(Region ®ion,
+ std::vector<Operation *> &toConvert);
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
} // end anonymous namespace
/// Recursively collect all of the blocks to convert from within 'region'.
-LogicalResult OperationConverter::computeConversionSet(
- Region ®ion,
- std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
+LogicalResult
+OperationConverter::computeConversionSet(Region ®ion,
+ std::vector<Operation *> &toConvert) {
if (region.empty())
return success();
while (!worklist.empty()) {
auto *block = worklist.pop_back_val();
- // We only need to process blocks if we are changing argument types.
- if (typeConverter)
- toConvert.emplace_back(block);
-
// Compute the conversion set of each of the nested operations.
for (auto &op : *block) {
toConvert.emplace_back(&op);
return success();
}
-/// Converts a block or operation with the given rewriter.
-LogicalResult
-OperationConverter::convert(DialectConversionRewriter &rewriter,
- llvm::PointerUnion<Operation *, Block *> &ptr) {
- // If this is a block, then convert the types of each of the arguments.
- if (auto *block = ptr.dyn_cast<Block *>()) {
- assert(typeConverter && "expected valid type converter");
- return rewriter.argConverter.convertArguments(block, rewriter.mapping);
- }
-
- // Otherwise, legalize the given operation.
- auto *op = ptr.get<Operation *>();
+/// Converts an operation with the given rewriter.
+LogicalResult OperationConverter::convert(DialectConversionRewriter &rewriter,
+ Operation *op) {
+ // Legalize the given operation.
if (failed(opLegalizer.legalize(op, rewriter))) {
// Handle the case of a failed conversion for each of the different modes.
/// Full conversions expect all operations to be converted.
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
}
- return success();
-}
-LogicalResult OperationConverter::convert(
- DialectConversionRewriter &rewriter,
- std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
- // Convert each operation/block and discard rewrites on failure.
- for (auto &it : toConvert) {
- if (failed(convert(rewriter, it))) {
- rewriter.discardRewrites();
- return failure();
- }
+ // Convert the signature of any empty regions of this operation, non-empty
+ // regions are converted on demand when converting any operations contained
+ // within.
+ // FIXME(riverriddle) This should be replaced by patterns when the pattern
+ // rewriter exposes functionality to remap region signatures.
+ if (rewriter.argConverter.typeConverter) {
+ for (auto ®ion : op->getRegions())
+ if (region.empty() && failed(rewriter.convertRegionSignature(region)))
+ return failure();
}
- // Otherwise the body conversion succeeded, so apply all rewrites.
- rewriter.applyRewrites();
return success();
}
-LogicalResult OperationConverter::convertFunction(
- FuncOp f, TypeConverter::SignatureConversion &signatureConversion) {
- // If this is an external function, there is nothing else to do.
- if (f.isExternal())
- return success();
-
- // Update the signature of the entry block.
- DialectConversionRewriter rewriter(f.getContext(), typeConverter);
- rewriter.argConverter.convertSignature(&f.getBody().front(),
- signatureConversion, rewriter.mapping);
-
- // Compute the set of operations and blocks to convert.
- std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
- if (failed(computeConversionSet(f.getBody(), toConvert)))
- return failure();
- return convert(rewriter, toConvert);
-}
-
/// Converts the given top-level operation to the conversion target.
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+LogicalResult
+OperationConverter::convertOperations(ArrayRef<Operation *> ops,
+ TypeConverter *typeConverter) {
if (ops.empty())
return success();
/// Compute the set of operations and blocks to convert.
- std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
+ std::vector<Operation *> toConvert;
for (auto *op : ops) {
toConvert.emplace_back(op);
for (auto ®ion : op->getRegions())
return failure();
}
- // Rewrite the blocks and operations.
+ // Convert each operation and discard rewrites on failure.
DialectConversionRewriter rewriter(ops.front()->getContext(), typeConverter);
- return convert(rewriter, toConvert);
+ for (auto *op : toConvert) {
+ if (failed(convert(rewriter, op))) {
+ rewriter.discardRewrites();
+ return failure();
+ }
+ }
+
+ // Otherwise the body conversion succeeded, so apply all rewrites.
+ rewriter.applyRewrites();
+ return success();
}
//===----------------------------------------------------------------------===//
/// Remap an input of the original signature with a new set of types. The
/// new types are appended to the new signature conversion.
-void TypeConverter::SignatureConversion::addInputs(
- unsigned origInputNo, ArrayRef<Type> types,
- ArrayRef<NamedAttributeList> attrs) {
+void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
+ ArrayRef<Type> types) {
assert(!types.empty() && "expected valid types");
remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
- addInputs(types, attrs);
+ addInputs(types);
}
/// Append new input types to the signature conversion, this should only be
/// used if the new types are not intended to remap an existing input.
-void TypeConverter::SignatureConversion::addInputs(
- ArrayRef<Type> types, ArrayRef<NamedAttributeList> attrs) {
+void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
assert(!types.empty() &&
"1->0 type remappings don't need to be added explicitly");
- assert(attrs.empty() || types.size() == attrs.size());
-
argTypes.append(types.begin(), types.end());
- if (attrs.empty())
- argAttrs.resize(argTypes.size());
- else
- argAttrs.append(attrs.begin(), attrs.end());
}
/// Remap an input of the original signature with a range of types in the
}
/// Convert the given FunctionType signature.
-auto TypeConverter::convertSignature(FunctionType type,
- ArrayRef<NamedAttributeList> argAttrs)
+auto TypeConverter::convertSignature(FunctionType type)
-> llvm::Optional<SignatureConversion> {
SignatureConversion result(type.getNumInputs());
- if (failed(convertSignature(type, argAttrs, result)))
+ if (failed(convertSignature(type, result)))
return llvm::None;
return result;
}
/// This hook allows for changing a FunctionType signature.
-LogicalResult
-TypeConverter::convertSignature(FunctionType type,
- ArrayRef<NamedAttributeList> argAttrs,
- SignatureConversion &result) {
+LogicalResult TypeConverter::convertSignature(FunctionType type,
+ SignatureConversion &result) {
// Convert the original function arguments.
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
- if (failed(convertSignatureArg(i, type.getInput(i), argAttrs[i], result)))
+ if (failed(convertSignatureArg(i, type.getInput(i), result)))
return failure();
// Convert the original function results.
/// This hook allows for converting a specific argument of a signature.
LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
- NamedAttributeList attrs,
SignatureConversion &result) {
// Try to convert the given input type.
SmallVector<Type, 1> convertedTypes;
return success();
// Otherwise, add the new inputs.
- auto convertedAttrs =
- convertedTypes.size() == 1 ? llvm::makeArrayRef(attrs) : llvm::None;
- result.addInputs(inputNo, convertedTypes, convertedAttrs);
+ result.addInputs(inputNo, convertedTypes);
return success();
}
+/// This hook defines how the signature of a region 'regionIdx', i.e. the
+/// signature of the entry to the region, on the given operation 'op' is
+/// converted. This function should return a valid conversion for the signature
+/// on success, None otherwise.
+///
+/// The default behavior of this function is to invoke 'convertBlockSignature'
+/// on the entry block, if one is present. This function also provides special
+/// handling for FuncOp to update the type signature.
+///
+/// TODO(riverriddle) This should be replaced in favor of using patterns, but
+/// the pattern rewriter needs to know how to properly replace/remap
+/// arguments.
+auto TypeConverter::convertRegionSignature(Operation *op, unsigned regionIdx)
+ -> llvm::Optional<SignatureConversion> {
+ // Provide explicit handling for FuncOp.
+ if (auto funcOp = dyn_cast<FuncOp>(op)) {
+ auto conversion = convertSignature(funcOp.getType());
+ if (conversion)
+ funcOp.setType(conversion->getConvertedType(funcOp.getContext()));
+ return conversion;
+ }
+
+ // Otherwise, default to handle the arguments of the entry block for the given
+ // region.
+ auto ®ion = op->getRegion(regionIdx);
+ if (region.empty())
+ return SignatureConversion(/*numOrigInputs=*/0);
+ return convertBlockSignature(®ion.front());
+}
+
+/// This function converts the type signature of the given block, by invoking
+/// 'convertSignatureArg' for each argument. This function should return a valid
+/// conversion for the signature on success, None otherwise.
+auto TypeConverter::convertBlockSignature(Block *block)
+ -> llvm::Optional<SignatureConversion> {
+ SignatureConversion conversion(block->getNumArguments());
+ for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i)
+ if (failed(convertSignatureArg(i, block->getArgument(i)->getType(),
+ conversion)))
+ return llvm::None;
+ return conversion;
+}
+
//===----------------------------------------------------------------------===//
// ConversionTarget
//===----------------------------------------------------------------------===//
/// Apply a partial conversion on the given operations, and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize.
-LogicalResult
-mlir::applyPartialConversion(ArrayRef<Operation *> ops,
- ConversionTarget &target,
- OwningRewritePatternList &&patterns) {
- OperationConverter converter(target, patterns, OpConversionMode::Partial);
- return converter.convertOperations(ops);
+LogicalResult mlir::applyPartialConversion(ArrayRef<Operation *> ops,
+ ConversionTarget &target,
+ OwningRewritePatternList &&patterns,
+ TypeConverter *converter) {
+ OperationConverter opConverter(target, patterns, OpConversionMode::Partial);
+ return opConverter.convertOperations(ops, converter);
}
-LogicalResult
-mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
- OwningRewritePatternList &&patterns) {
+LogicalResult mlir::applyPartialConversion(Operation *op,
+ ConversionTarget &target,
+ OwningRewritePatternList &&patterns,
+ TypeConverter *converter) {
return applyPartialConversion(llvm::makeArrayRef(op), target,
- std::move(patterns));
+ std::move(patterns), converter);
}
/// Apply a complete conversion on the given operations, and all nested
/// operation fails.
LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
ConversionTarget &target,
- OwningRewritePatternList &&patterns) {
- OperationConverter converter(target, patterns, OpConversionMode::Full);
- return converter.convertOperations(ops);
+ OwningRewritePatternList &&patterns,
+ TypeConverter *converter) {
+ OperationConverter opConverter(target, patterns, OpConversionMode::Full);
+ return opConverter.convertOperations(ops, converter);
}
LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target,
- OwningRewritePatternList &&patterns) {
+ OwningRewritePatternList &&patterns,
+ TypeConverter *converter) {
return applyFullConversion(llvm::makeArrayRef(op), target,
- std::move(patterns));
-}
-
-//===----------------------------------------------------------------------===//
-// Op + Type Conversion Entry Points
-//===----------------------------------------------------------------------===//
-
-static LogicalResult applyConversion(MutableArrayRef<FuncOp> fns,
- ConversionTarget &target,
- TypeConverter &converter,
- OwningRewritePatternList &&patterns,
- OpConversionMode mode) {
- if (fns.empty())
- return success();
-
- // Build the function converter.
- OperationConverter funcConverter(target, patterns, mode, &converter);
-
- // Try to convert each of the functions within the module.
- SmallVector<NamedAttributeList, 4> argAttrs;
- auto *ctx = fns.front().getContext();
- for (auto func : fns) {
- argAttrs.clear();
- func.getAllArgAttrs(argAttrs);
-
- // Convert the function type using the type converter.
- auto conversion = converter.convertSignature(func.getType(), argAttrs);
- if (!conversion)
- return failure();
-
- // Update the function signature.
- func.setType(conversion->getConvertedType(ctx));
- func.setAllArgAttrs(conversion->getConvertedArgAttrs());
-
- // Convert the body of this function.
- if (failed(funcConverter.convertFunction(func, *conversion)))
- return failure();
- }
-
- return success();
-}
-
-/// Apply a partial conversion on the function operations within the given
-/// module. This method returns failure if a type conversion was encountered.
-LogicalResult
-mlir::applyPartialConversion(ModuleOp module, ConversionTarget &target,
- TypeConverter &converter,
- OwningRewritePatternList &&patterns) {
- SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
- return applyPartialConversion(allFunctions, target, converter,
- std::move(patterns));
-}
-
-/// Apply a partial conversion on the given function operations. This method
-/// returns failure if a type conversion was encountered.
-LogicalResult
-mlir::applyPartialConversion(MutableArrayRef<FuncOp> fns,
- ConversionTarget &target, TypeConverter &converter,
- OwningRewritePatternList &&patterns) {
- return applyConversion(fns, target, converter, std::move(patterns),
- OpConversionMode::Partial);
-}
-
-/// Apply a full conversion on the function operations within the given module.
-/// This method returns failure if a type conversion was encountered, or if the
-/// conversion of any operations failed.
-LogicalResult mlir::applyFullConversion(ModuleOp module,
- ConversionTarget &target,
- TypeConverter &converter,
- OwningRewritePatternList &&patterns) {
- SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
- return applyFullConversion(allFunctions, target, converter,
- std::move(patterns));
-}
-
-/// Apply a full conversion on the given function operations. This method
-/// returns failure if a type conversion was encountered, or if the conversion
-/// of any operation failed.
-LogicalResult mlir::applyFullConversion(MutableArrayRef<FuncOp> fns,
- ConversionTarget &target,
- TypeConverter &converter,
- OwningRewritePatternList &&patterns) {
- return applyConversion(fns, target, converter, std::move(patterns),
- OpConversionMode::Full);
+ std::move(patterns), converter);
}