From 7c755d06aa6681e25cd1d289b937a505de47c1f8 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 28 Jun 2019 11:28:30 -0700 Subject: [PATCH] Refactor DialectConversion to use 'materializeConversion' when a type conversion must persist after the conversion has finished. During conversion, if a type conversion has dangling uses a type conversion must persist after conversion has finished to maintain valid IR. In these cases, we now query the TypeConverter to materialize a conversion for us. This allows for the default case of a full conversion to continue working as expected, but also handle the degenerate cases more robustly. PiperOrigin-RevId: 255637171 --- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 8 ++ mlir/include/mlir/Transforms/DialectConversion.h | 11 ++- mlir/lib/Transforms/DialectConversion.cpp | 106 +++++++++-------------- mlir/test/Transforms/test-legalizer.mlir | 37 ++++---- mlir/test/lib/TestDialect/TestOps.td | 4 + mlir/test/lib/TestDialect/TestPatterns.cpp | 14 ++- 6 files changed, 91 insertions(+), 89 deletions(-) diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 34a2dc3..8b2a392 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -331,6 +331,14 @@ protected: return array.toMemref(); return t; } + + /// Materialize a conversion to allow for partial lowering of types. + Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, + ArrayRef inputs, + Location loc) override { + assert(inputs.size() == 1 && "expected only one input value"); + return rewriter.create(loc, inputs[0], resultType); + } }; /// This is lowering to Linalg the parts that can be (matmul and add on arrays) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 599c7f1..00da0d5 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -221,13 +221,13 @@ public: /// one result type by generating a cast operation of some kind. The generated /// operation should produce one result, of 'resultType', with the provided /// 'inputs' as operands. This hook must be overridden when a type conversion - /// results in more than one type. + /// results in more than one type, or if a type conversion may persist after + /// the conversion has finished. virtual Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, ArrayRef inputs, Location loc) { - llvm_unreachable("expected 'materializeConversion' to be overridden when " - "generating 1->N type conversions"); + llvm_unreachable("expected 'materializeConversion' to be overridden"); } }; @@ -337,14 +337,13 @@ private: /// Convert the given module with the provided conversion patterns and type /// conversion object. This function returns failure if a type conversion -/// failed, potentially leaving the IR in an invalid state. +/// failed. LLVM_NODISCARD LogicalResult applyConversionPatterns( Module &module, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns); /// Convert the given functions with the provided conversion patterns. This -/// function returns failure if a type conversion failed, potentially leaving -/// the IR in an invalid state. +/// function returns failure if a type conversion failed. LLVM_NODISCARD LogicalResult applyConversionPatterns(ArrayRef fns, ConversionTarget &target, diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 88db7b6..f6d6329 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -40,8 +40,10 @@ namespace { /// illegal type to the original type to allow for undoing pending rewrites in /// the case of failure. struct ArgConverter { - ArgConverter(MLIRContext *ctx) - : castOpName(kCastName, ctx), loc(UnknownLoc::get(ctx)) {} + ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter) + : castOpName(kCastName, rewriter.getContext()), + loc(rewriter.getUnknownLoc()), typeConverter(typeConverter), + rewriter(rewriter) {} /// Erase any rewrites registered for arguments to blocks within the given /// region. This function is called when the given region is to be destroyed. @@ -51,26 +53,21 @@ struct ArgConverter { void discardRewrites(); /// Replace usages of the cast operations with the argument directly. - LogicalResult applyRewrites(); + void applyRewrites(); /// Converts the signature of the given entry block. - void convertSignature(Block *block, PatternRewriter &rewriter, - TypeConverter &converter, + void convertSignature(Block *block, TypeConverter::SignatureConversion &signatureConversion, BlockAndValueMapping &mapping); /// Converts the arguments of the given block. - LogicalResult convertArguments(Block *block, PatternRewriter &rewriter, - TypeConverter &converter, - BlockAndValueMapping &mapping); + LogicalResult convertArguments(Block *block, BlockAndValueMapping &mapping); /// Convert the given block argument given the provided set of new argument /// values that are to replace it. This function returns the operation used /// to perform the conversion. Operation *convertArgument(BlockArgument *origArg, ArrayRef newValues, - PatternRewriter &rewriter, - TypeConverter &converter, BlockAndValueMapping &mapping); /// A utility function used to create a conversion cast operation with the @@ -90,6 +87,12 @@ struct ArgConverter { /// An instance of the unknown location that is used when generating /// producers. Location loc; + + /// The type converter to use when changing types. + TypeConverter *typeConverter; + + /// The pattern rewriter to use when materializing conversions. + PatternRewriter &rewriter; }; constexpr StringLiteral ArgConverter::kCastName; @@ -147,11 +150,9 @@ void ArgConverter::discardRewrites() { } /// Replace usages of the cast operations with the argument directly. -LogicalResult ArgConverter::applyRewrites() { +void ArgConverter::applyRewrites() { Block *block; ArrayRef argOps; - - LogicalResult result = success(); for (auto &mapping : argMapping) { std::tie(block, argOps) = mapping; @@ -169,41 +170,25 @@ LogicalResult ArgConverter::applyRewrites() { continue; } - // Handle the case where this argument had a direct mapping. - if (op->getNumOperands() == 1) { - op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); - // Otherwise, this argument was expected to be dropped. - } else if (!op->getResult(0)->use_empty()) { - // Don't emit another error if we already have one. - if (!failed(result)) { - auto *parent = block->getParent(); - auto diag = emitError(parent->getLoc()) - << "block argument #" << i << " with type " - << op->getResult(0)->getType() - << " has unexpected remaining uses"; - auto *user = *op->getResult(0)->user_begin(); - diag.attachNote(user->getLoc()) - << "unexpected user defined here : " << *user; - result = failure(); - } - // Move this fake producer to the beginning of the parent block, we - // can't recover from this failure and we want to make sure the - // operations get cleaned up. Recovering from this would require - // detecting that an argument would be unused before applying all of - // the operation rewrites, which can get quite expensive. - block->push_front(op); - continue; + // Otherwise, if there are any dangling uses then replace the fake + // conversion operation with one generated by the type converter. This + // is necessary as the cast must persist in the IR after conversion. + auto *opResult = op->getResult(0); + if (!opResult->use_empty()) { + rewriter.setInsertionPointToStart(block); + SmallVector operands(op->getOperands()); + auto *newOp = typeConverter->materializeConversion( + rewriter, opResult->getType(), operands, op->getLoc()); + opResult->replaceAllUsesWith(newOp->getResult(0)); } op->destroy(); } } - return result; } /// Converts the signature of the given entry block. void ArgConverter::convertSignature( - Block *block, PatternRewriter &rewriter, TypeConverter &converter, - TypeConverter::SignatureConversion &signatureConversion, + Block *block, TypeConverter::SignatureConversion &signatureConversion, BlockAndValueMapping &mapping) { unsigned origArgCount = block->getNumArguments(); auto convertedTypes = signatureConversion.getConvertedArgTypes(); @@ -223,8 +208,7 @@ void ArgConverter::convertSignature( remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size); BlockArgument *arg = block->getArgument(i); - newArgMapping.push_back( - convertArgument(arg, remappedValues, rewriter, converter, mapping)); + newArgMapping.push_back(convertArgument(arg, remappedValues, mapping)); } // Erase all of the original arguments. @@ -234,8 +218,6 @@ void ArgConverter::convertSignature( /// Converts the arguments of the given block. LogicalResult ArgConverter::convertArguments(Block *block, - PatternRewriter &rewriter, - TypeConverter &converter, BlockAndValueMapping &mapping) { unsigned origArgCount = block->getNumArguments(); if (origArgCount == 0) @@ -245,7 +227,7 @@ LogicalResult ArgConverter::convertArguments(Block *block, SmallVector, 4> newArgTypes(origArgCount); for (unsigned i = 0; i != origArgCount; ++i) { auto *arg = block->getArgument(i); - if (failed(converter.convertType(arg->getType(), newArgTypes[i]))) + if (failed(typeConverter->convertType(arg->getType(), newArgTypes[i]))) return emitError(block->getParent()->getLoc()) << "could not convert block argument of type " << arg->getType(); } @@ -255,8 +237,8 @@ LogicalResult ArgConverter::convertArguments(Block *block, rewriter.setInsertionPointToStart(block); for (unsigned i = 0; i != origArgCount; ++i) { SmallVector newArgs(block->addArguments(newArgTypes[i])); - newArgMapping.push_back(convertArgument(block->getArgument(i), newArgs, - rewriter, converter, mapping)); + newArgMapping.push_back( + convertArgument(block->getArgument(i), newArgs, mapping)); } // Erase all of the original arguments. @@ -270,8 +252,6 @@ LogicalResult ArgConverter::convertArguments(Block *block, /// to perform the conversion. Operation *ArgConverter::convertArgument(BlockArgument *origArg, ArrayRef newValues, - PatternRewriter &rewriter, - TypeConverter &converter, BlockAndValueMapping &mapping) { // Handle the cases of 1->0 or 1->1 mappings. if (newValues.size() < 2) { @@ -289,8 +269,8 @@ Operation *ArgConverter::convertArgument(BlockArgument *origArg, // Otherwise, this is a 1->N mapping. Call into the provided type converter // to pack the new values. - auto *cast = converter.materializeConversion(rewriter, origArg->getType(), - newValues, loc); + auto *cast = typeConverter->materializeConversion( + rewriter, origArg->getType(), newValues, loc); assert(cast->getNumResults() == 1 && cast->getNumOperands() == newValues.size()); origArg->replaceAllUsesWith(cast->getResult(0)); @@ -370,8 +350,8 @@ struct DialectConversionRewriter final : public PatternRewriter { BlockActionKind kind; }; - DialectConversionRewriter(Region ®ion) - : PatternRewriter(region), argConverter(region.getContext()) {} + DialectConversionRewriter(Region ®ion, TypeConverter *converter) + : PatternRewriter(region), argConverter(converter, *this) {} ~DialectConversionRewriter() = default; /// Return the current state of the rewriter. @@ -438,7 +418,7 @@ struct DialectConversionRewriter final : public PatternRewriter { /// Apply all requested operation rewrites. This method is invoked when the /// conversion process succeeds. - LogicalResult applyRewrites() { + void applyRewrites() { // Apply all of the rewrites replacements requested during conversion. for (auto &repl : replacements) { for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) @@ -454,7 +434,7 @@ struct DialectConversionRewriter final : public PatternRewriter { repl.op->erase(); } - return argConverter.applyRewrites(); + argConverter.applyRewrites(); } /// PatternRewriter hook for replacing the results of an operation. @@ -943,8 +923,8 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, if (typeConverter) { for (Block &block : llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) { - if (failed(rewriter.argConverter.convertArguments( - &block, rewriter, *typeConverter, rewriter.mapping))) + if (failed( + rewriter.argConverter.convertArguments(&block, rewriter.mapping))) return failure(); } } @@ -973,13 +953,12 @@ LogicalResult FunctionConverter::convertFunction( if (f->isExternal()) return success(); - DialectConversionRewriter rewriter(f->getBody()); + DialectConversionRewriter rewriter(f->getBody(), typeConverter); // Update the signature of the entry block. if (signatureConversion) { - rewriter.argConverter.convertSignature(&f->getBody().front(), rewriter, - *typeConverter, *signatureConversion, - rewriter.mapping); + rewriter.argConverter.convertSignature( + &f->getBody().front(), *signatureConversion, rewriter.mapping); } // Rewrite the function body. @@ -990,8 +969,9 @@ LogicalResult FunctionConverter::convertFunction( return failure(); } - // Otherwise the body conversion succeeded, so try to apply all rewrites. - return rewriter.applyRewrites(); + // Otherwise the body conversion succeeded, so apply all rewrites. + rewriter.applyRewrites(); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 7dc076b..66777d7 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-legalize-patterns -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input=fail +// RUN: mlir-opt -test-legalize-patterns %s // CHECK-LABEL: verifyDirectPattern func @verifyDirectPattern() -> i32 { @@ -17,16 +17,16 @@ func @verifyLargerBenefit() -> i32 { // CHECK-LABEL: func @remap_input_1_to_0() func @remap_input_1_to_0(i16) -// CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64) -> f64 -func @remap_input_1_to_1(%arg0: i64) -> i64 { - // CHECK-NEXT: return %arg0 : f64 - return %arg0 : i64 +// CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64) +func @remap_input_1_to_1(%arg0: i64) { + // CHECK-NEXT: "test.valid"{{.*}} : (f64) + "test.invalid"(%arg0) : (i64) -> () } -// CHECK-LABEL: func @remap_input_1_to_N(%arg0: f16, %arg1: f16) -> (f16, f16) +// CHECK-LABEL: func @remap_input_1_to_N(%arg0: f16, %arg1: f16) func @remap_input_1_to_N(%arg0: f32) -> f32 { - // CHECK-NEXT: "test.return"(%arg0, %arg1) : (f16, f16) -> () - "test.return"(%arg0) : (f32) -> () + // CHECK-NEXT: "test.valid"(%arg0, %arg1) : (f16, f16) -> () + "test.invalid"(%arg0) : (f32) -> () } // CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16) @@ -38,8 +38,8 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) { // CHECK-LABEL: func @remap_multi(%arg0: f64, %arg1: f64) -> (f64, f64) func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) { - // CHECK-NEXT: return %arg0, %arg1 : f64, f64 - return %arg0, %arg1 : i64, i64 + // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64) + "test.invalid"(%arg0, %arg1) : (i64, i64) -> () } // CHECK-LABEL: func @remap_nested @@ -48,8 +48,8 @@ func @remap_nested() { "foo.region"() ({ // CHECK-NEXT: ^bb1(%i0: f64, %i1: f64): ^bb1(%i0: i64, %unused: i16, %i1: i64): - // CHECK-NEXT: "work"{{.*}} : (f64, f64) - "work"(%i0, %i1) : (i64, i64) -> () + // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64) + "test.invalid"(%i0, %i1) : (i64, i64) -> () }) : () -> () return } @@ -59,10 +59,10 @@ func @remap_moved_region_args() { // CHECK-NEXT: return // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16): // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32 - // CHECK-NEXT: "work"{{.*}} : (f64, f64, f32) + // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32) "test.region"() ({ ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): - "work"(%i0, %i1, %2) : (i64, i64, f32) -> () + "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () }) : () -> () return } @@ -73,15 +73,14 @@ func @remap_drop_region() { // CHECK-NEXT: } "test.drop_op"() ({ ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): - "work"(%i0, %i1, %2) : (i64, i64, f32) -> () + "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () }) : () -> () return } -// ----- - +// CHECK-LABEL: func @dropped_input_in_use func @dropped_input_in_use(%arg: i16, %arg2: i64) { - // expected-error@-1 {{block argument #0 with type 'i16' has unexpected remaining uses}} - // expected-note@+1 {{unexpected user defined here}} + // CHECK-NEXT: "test.cast"{{.*}} : () -> i16 + // CHECK-NEXT: "work"{{.*}} : (i16) "work"(%arg) : (i16) -> () } diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index e12aaac..ba5362b 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -245,5 +245,9 @@ def TestReturnOp : TEST_Op<"return", [Terminator]>, Arguments<(ins Variadic:$inputs)>; def TestCastOp : TEST_Op<"cast">, Arguments<(ins Variadic:$inputs)>, Results<(outs AnyType:$res)>; +def TestInvalidOp : TEST_Op<"invalid", [Terminator]>, + Arguments<(ins Variadic:$inputs)>; +def TestValidOp : TEST_Op<"valid", [Terminator]>, + Arguments<(ins Variadic:$inputs)>; #endif // TEST_OPS diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index f323d7f..bde01f7 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -78,6 +78,17 @@ struct TestDropOp : public ConversionPattern { return matchSuccess(); } }; +/// This pattern simply updates the operands of the given operation. +struct TestPassthroughInvalidOp : public ConversionPattern { + TestPassthroughInvalidOp(MLIRContext *ctx) + : ConversionPattern("test.invalid", 1, ctx) {} + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp(op, llvm::None, operands, + llvm::None); + return matchSuccess(); + } +}; /// This pattern handles the case of a split return value. struct TestSplitReturnType : public ConversionPattern { TestSplitReturnType(MLIRContext *ctx) @@ -139,7 +150,7 @@ struct TestTypeConverter : public TypeConverter { struct TestConversionTarget : public ConversionTarget { TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { - addLegalOp(); + addLegalOp(); addDynamicallyLegalOp(); } bool isDynamicallyLegal(Operation *op) const final { @@ -155,6 +166,7 @@ struct TestLegalizePatternDriver mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); RewriteListBuilder::build(patterns, &getContext()); TestTypeConverter converter; -- 2.7.4