/// illegal type to the original type to allow for undoing pending rewrites in
/// the case of failure.
struct ArgConverter {
- ArgConverter(MLIRContext *ctx)
- : castOpName(kCastName, ctx), loc(UnknownLoc::get(ctx)) {}
+ ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
+ : castOpName(kCastName, rewriter.getContext()),
+ loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
+ rewriter(rewriter) {}
/// Erase any rewrites registered for arguments to blocks within the given
/// region. This function is called when the given region is to be destroyed.
void discardRewrites();
/// Replace usages of the cast operations with the argument directly.
- LogicalResult applyRewrites();
+ void applyRewrites();
/// Converts the signature of the given entry block.
- void convertSignature(Block *block, PatternRewriter &rewriter,
- TypeConverter &converter,
+ void convertSignature(Block *block,
TypeConverter::SignatureConversion &signatureConversion,
BlockAndValueMapping &mapping);
/// Converts the arguments of the given block.
- LogicalResult convertArguments(Block *block, PatternRewriter &rewriter,
- TypeConverter &converter,
- BlockAndValueMapping &mapping);
+ LogicalResult convertArguments(Block *block, BlockAndValueMapping &mapping);
/// Convert the given block argument given the provided set of new argument
/// values that are to replace it. This function returns the operation used
/// to perform the conversion.
Operation *convertArgument(BlockArgument *origArg,
ArrayRef<Value *> newValues,
- PatternRewriter &rewriter,
- TypeConverter &converter,
BlockAndValueMapping &mapping);
/// A utility function used to create a conversion cast operation with the
/// An instance of the unknown location that is used when generating
/// producers.
Location loc;
+
+ /// The type converter to use when changing types.
+ TypeConverter *typeConverter;
+
+ /// The pattern rewriter to use when materializing conversions.
+ PatternRewriter &rewriter;
};
constexpr StringLiteral ArgConverter::kCastName;
}
/// Replace usages of the cast operations with the argument directly.
-LogicalResult ArgConverter::applyRewrites() {
+void ArgConverter::applyRewrites() {
Block *block;
ArrayRef<Operation *> argOps;
-
- LogicalResult result = success();
for (auto &mapping : argMapping) {
std::tie(block, argOps) = mapping;
continue;
}
- // Handle the case where this argument had a direct mapping.
- if (op->getNumOperands() == 1) {
- op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
- // Otherwise, this argument was expected to be dropped.
- } else if (!op->getResult(0)->use_empty()) {
- // Don't emit another error if we already have one.
- if (!failed(result)) {
- auto *parent = block->getParent();
- auto diag = emitError(parent->getLoc())
- << "block argument #" << i << " with type "
- << op->getResult(0)->getType()
- << " has unexpected remaining uses";
- auto *user = *op->getResult(0)->user_begin();
- diag.attachNote(user->getLoc())
- << "unexpected user defined here : " << *user;
- result = failure();
- }
- // Move this fake producer to the beginning of the parent block, we
- // can't recover from this failure and we want to make sure the
- // operations get cleaned up. Recovering from this would require
- // detecting that an argument would be unused before applying all of
- // the operation rewrites, which can get quite expensive.
- block->push_front(op);
- continue;
+ // Otherwise, if there are any dangling uses then replace the fake
+ // conversion operation with one generated by the type converter. This
+ // is necessary as the cast must persist in the IR after conversion.
+ auto *opResult = op->getResult(0);
+ if (!opResult->use_empty()) {
+ rewriter.setInsertionPointToStart(block);
+ SmallVector<Value *, 1> operands(op->getOperands());
+ auto *newOp = typeConverter->materializeConversion(
+ rewriter, opResult->getType(), operands, op->getLoc());
+ opResult->replaceAllUsesWith(newOp->getResult(0));
}
op->destroy();
}
}
- return result;
}
/// Converts the signature of the given entry block.
void ArgConverter::convertSignature(
- Block *block, PatternRewriter &rewriter, TypeConverter &converter,
- TypeConverter::SignatureConversion &signatureConversion,
+ Block *block, TypeConverter::SignatureConversion &signatureConversion,
BlockAndValueMapping &mapping) {
unsigned origArgCount = block->getNumArguments();
auto convertedTypes = signatureConversion.getConvertedArgTypes();
remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size);
BlockArgument *arg = block->getArgument(i);
- newArgMapping.push_back(
- convertArgument(arg, remappedValues, rewriter, converter, mapping));
+ newArgMapping.push_back(convertArgument(arg, remappedValues, mapping));
}
// Erase all of the original arguments.
/// Converts the arguments of the given block.
LogicalResult ArgConverter::convertArguments(Block *block,
- PatternRewriter &rewriter,
- TypeConverter &converter,
BlockAndValueMapping &mapping) {
unsigned origArgCount = block->getNumArguments();
if (origArgCount == 0)
SmallVector<SmallVector<Type, 1>, 4> newArgTypes(origArgCount);
for (unsigned i = 0; i != origArgCount; ++i) {
auto *arg = block->getArgument(i);
- if (failed(converter.convertType(arg->getType(), newArgTypes[i])))
+ if (failed(typeConverter->convertType(arg->getType(), newArgTypes[i])))
return emitError(block->getParent()->getLoc())
<< "could not convert block argument of type " << arg->getType();
}
rewriter.setInsertionPointToStart(block);
for (unsigned i = 0; i != origArgCount; ++i) {
SmallVector<Value *, 1> newArgs(block->addArguments(newArgTypes[i]));
- newArgMapping.push_back(convertArgument(block->getArgument(i), newArgs,
- rewriter, converter, mapping));
+ newArgMapping.push_back(
+ convertArgument(block->getArgument(i), newArgs, mapping));
}
// Erase all of the original arguments.
/// to perform the conversion.
Operation *ArgConverter::convertArgument(BlockArgument *origArg,
ArrayRef<Value *> newValues,
- PatternRewriter &rewriter,
- TypeConverter &converter,
BlockAndValueMapping &mapping) {
// Handle the cases of 1->0 or 1->1 mappings.
if (newValues.size() < 2) {
// Otherwise, this is a 1->N mapping. Call into the provided type converter
// to pack the new values.
- auto *cast = converter.materializeConversion(rewriter, origArg->getType(),
- newValues, loc);
+ auto *cast = typeConverter->materializeConversion(
+ rewriter, origArg->getType(), newValues, loc);
assert(cast->getNumResults() == 1 &&
cast->getNumOperands() == newValues.size());
origArg->replaceAllUsesWith(cast->getResult(0));
BlockActionKind kind;
};
- DialectConversionRewriter(Region ®ion)
- : PatternRewriter(region), argConverter(region.getContext()) {}
+ DialectConversionRewriter(Region ®ion, TypeConverter *converter)
+ : PatternRewriter(region), argConverter(converter, *this) {}
~DialectConversionRewriter() = default;
/// Return the current state of the rewriter.
/// Apply all requested operation rewrites. This method is invoked when the
/// conversion process succeeds.
- LogicalResult applyRewrites() {
+ void applyRewrites() {
// Apply all of the rewrites replacements requested during conversion.
for (auto &repl : replacements) {
for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
repl.op->erase();
}
- return argConverter.applyRewrites();
+ argConverter.applyRewrites();
}
/// PatternRewriter hook for replacing the results of an operation.
if (typeConverter) {
for (Block &block :
llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) {
- if (failed(rewriter.argConverter.convertArguments(
- &block, rewriter, *typeConverter, rewriter.mapping)))
+ if (failed(
+ rewriter.argConverter.convertArguments(&block, rewriter.mapping)))
return failure();
}
}
if (f->isExternal())
return success();
- DialectConversionRewriter rewriter(f->getBody());
+ DialectConversionRewriter rewriter(f->getBody(), typeConverter);
// Update the signature of the entry block.
if (signatureConversion) {
- rewriter.argConverter.convertSignature(&f->getBody().front(), rewriter,
- *typeConverter, *signatureConversion,
- rewriter.mapping);
+ rewriter.argConverter.convertSignature(
+ &f->getBody().front(), *signatureConversion, rewriter.mapping);
}
// Rewrite the function body.
return failure();
}
- // Otherwise the body conversion succeeded, so try to apply all rewrites.
- return rewriter.applyRewrites();
+ // Otherwise the body conversion succeeded, so apply all rewrites.
+ rewriter.applyRewrites();
+ return success();
}
//===----------------------------------------------------------------------===//
-// RUN: mlir-opt -test-legalize-patterns -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input=fail
+// RUN: mlir-opt -test-legalize-patterns %s
// CHECK-LABEL: verifyDirectPattern
func @verifyDirectPattern() -> i32 {
// CHECK-LABEL: func @remap_input_1_to_0()
func @remap_input_1_to_0(i16)
-// CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64) -> f64
-func @remap_input_1_to_1(%arg0: i64) -> i64 {
- // CHECK-NEXT: return %arg0 : f64
- return %arg0 : i64
+// CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64)
+func @remap_input_1_to_1(%arg0: i64) {
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64)
+ "test.invalid"(%arg0) : (i64) -> ()
}
-// CHECK-LABEL: func @remap_input_1_to_N(%arg0: f16, %arg1: f16) -> (f16, f16)
+// CHECK-LABEL: func @remap_input_1_to_N(%arg0: f16, %arg1: f16)
func @remap_input_1_to_N(%arg0: f32) -> f32 {
- // CHECK-NEXT: "test.return"(%arg0, %arg1) : (f16, f16) -> ()
- "test.return"(%arg0) : (f32) -> ()
+ // CHECK-NEXT: "test.valid"(%arg0, %arg1) : (f16, f16) -> ()
+ "test.invalid"(%arg0) : (f32) -> ()
}
// CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
// CHECK-LABEL: func @remap_multi(%arg0: f64, %arg1: f64) -> (f64, f64)
func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
- // CHECK-NEXT: return %arg0, %arg1 : f64, f64
- return %arg0, %arg1 : i64, i64
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+ "test.invalid"(%arg0, %arg1) : (i64, i64) -> ()
}
// CHECK-LABEL: func @remap_nested
"foo.region"() ({
// CHECK-NEXT: ^bb1(%i0: f64, %i1: f64):
^bb1(%i0: i64, %unused: i16, %i1: i64):
- // CHECK-NEXT: "work"{{.*}} : (f64, f64)
- "work"(%i0, %i1) : (i64, i64) -> ()
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+ "test.invalid"(%i0, %i1) : (i64, i64) -> ()
}) : () -> ()
return
}
// CHECK-NEXT: return
// CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
// CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
- // CHECK-NEXT: "work"{{.*}} : (f64, f64, f32)
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
"test.region"() ({
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
- "work"(%i0, %i1, %2) : (i64, i64, f32) -> ()
+ "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) : () -> ()
return
}
// CHECK-NEXT: }
"test.drop_op"() ({
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
- "work"(%i0, %i1, %2) : (i64, i64, f32) -> ()
+ "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) : () -> ()
return
}
-// -----
-
+// CHECK-LABEL: func @dropped_input_in_use
func @dropped_input_in_use(%arg: i16, %arg2: i64) {
- // expected-error@-1 {{block argument #0 with type 'i16' has unexpected remaining uses}}
- // expected-note@+1 {{unexpected user defined here}}
+ // CHECK-NEXT: "test.cast"{{.*}} : () -> i16
+ // CHECK-NEXT: "work"{{.*}} : (i16)
"work"(%arg) : (i16) -> ()
}