/// Attempt to convert the signature of the given block, if successful a new
/// block is returned containing the new arguments. Returns `block` if it did
/// not require conversion.
- FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
- ConversionValueMapping &mapping);
+ FailureOr<Block *>
+ convertSignature(Block *block, TypeConverter &converter,
+ ConversionValueMapping &mapping,
+ SmallVectorImpl<BlockArgument> &argReplacements);
/// Apply the given signature conversion on the given block. The new block
/// containing the updated signature is returned. If no conversions were
Block *applySignatureConversion(
Block *block, TypeConverter &converter,
TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping);
+ ConversionValueMapping &mapping,
+ SmallVectorImpl<BlockArgument> &argReplacements);
/// Insert a new conversion into the cache.
void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
//===----------------------------------------------------------------------===//
// Conversion
-FailureOr<Block *>
-ArgConverter::convertSignature(Block *block, TypeConverter &converter,
- ConversionValueMapping &mapping) {
+FailureOr<Block *> ArgConverter::convertSignature(
+ Block *block, TypeConverter &converter, ConversionValueMapping &mapping,
+ SmallVectorImpl<BlockArgument> &argReplacements) {
// Check if the block was already converted. If the block is detached,
// conservatively assume it is going to be deleted.
if (hasBeenConverted(block) || !block->getParent())
// Try to convert the signature for the block with the provided converter.
if (auto conversion = converter.convertBlockSignature(block))
- return applySignatureConversion(block, converter, *conversion, mapping);
+ return applySignatureConversion(block, converter, *conversion, mapping,
+ argReplacements);
return failure();
}
Block *ArgConverter::applySignatureConversion(
Block *block, TypeConverter &converter,
TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping) {
+ ConversionValueMapping &mapping,
+ SmallVectorImpl<BlockArgument> &argReplacements) {
// If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
auto convertedTypes = signatureConversion.getConvertedTypes();
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValue);
+ argReplacements.push_back(origArg);
continue;
}
newArg = replArgs.front();
}
mapping.map(origArg, newArg);
+ argReplacements.push_back(origArg);
info.argInfo[i] =
ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
Block *block, TypeConverter &converter,
TypeConverter::SignatureConversion *conversion) {
FailureOr<Block *> result =
- conversion ? argConverter.applySignatureConversion(block, converter,
- *conversion, mapping)
- : argConverter.convertSignature(block, converter, mapping);
+ conversion ? argConverter.applySignatureConversion(
+ block, converter, *conversion, mapping, argReplacements)
+ : argConverter.convertSignature(block, converter, mapping,
+ argReplacements);
if (Block *newBlock = result.getValue()) {
if (newBlock != block)
blockActions.push_back(BlockAction::getTypeConversion(newBlock));
}
};
+/// Call signature conversion and then fail the rewrite to trigger the undo
+/// mechanism.
+struct TestSignatureConversionUndo
+ : public OpConversionPattern<TestSignatureConversionUndoOp> {
+ using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
+ return failure();
+ }
+};
+
+/// Just forward the operands to the root op. This is essentially a no-op
+/// pattern that is used to trigger target materialization.
+struct TestTypeConsumerForward
+ : public OpConversionPattern<TestTypeConsumerOp> {
+ using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); });
+ return success();
+ }
+};
+
struct TestTypeConversionDriver
: public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry ®istry) const override {
// Initialize the set of rewrite patterns.
OwningRewritePatternList patterns;
- patterns.insert<TestTypeConversionProducer>(converter, &getContext());
+ patterns.insert<TestTypeConsumerForward, TestTypeConversionProducer,
+ TestSignatureConversionUndo>(converter, &getContext());
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);