From: River Riddle Date: Fri, 21 Jun 2019 16:29:46 +0000 (-0700) Subject: Add support for 1->N type mappings in the dialect conversion infrastructure. To suppo... X-Git-Tag: llvmorg-11-init~1466^2~1375 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=704a7fb13ee45f6056d7922e2b7010c19a2148ab;p=platform%2Fupstream%2Fllvm.git Add support for 1->N type mappings in the dialect conversion infrastructure. To support these mappings a hook must be overridden on the type converter: 'materializeConversion' :to generate a cast operation from the new types to the old type. This operation is automatically erased if all uses are removed, otherwise it remains in the IR for the user to handle. PiperOrigin-RevId: 254411383 --- diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index ede1120..599c7f1 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -216,6 +216,19 @@ public: virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type, NamedAttributeList attrs, SignatureConversion &result); + + /// This hook allows for materializing a conversion from a set of types into + /// one result type by generating a cast operation of some kind. The generated + /// operation should produce one result, of 'resultType', with the provided + /// 'inputs' as operands. This hook must be overridden when a type conversion + /// results in more than one type. + virtual Operation *materializeConversion(PatternRewriter &rewriter, + Type resultType, + ArrayRef inputs, + Location loc) { + llvm_unreachable("expected 'materializeConversion' to be overridden when " + "generating 1->N type conversions"); + } }; /// This class describes a specific conversion target. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 3cec1f8..1684655 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -51,6 +51,12 @@ struct ArgConverter { if (it == argMapping.end()) continue; for (auto *op : it->second) { + // If the operation exists within the parent block, like with 1->N cast + // operations, we don't need to drop them. They will be automatically + // cleaned up with the region is destroyed. + if (op->getBlock()) + continue; + op->dropAllDefinedValueUses(); op->destroy(); } @@ -77,7 +83,13 @@ struct ArgConverter { auto *op = argOps[i]; auto *arg = block->addArgument(op->getResult(0)->getType()); op->getResult(0)->replaceAllUsesWith(arg); - op->destroy(); + + // If this was a 1->N value mapping it exists within the parent block so + // erase it instead of destroying. + if (op->getBlock()) + op->erase(); + else + op->destroy(); } } argMapping.clear(); @@ -97,8 +109,14 @@ struct ArgConverter { auto *op = argOps[i]; // Handle the case of a 1->N value mapping. - if (op->getNumOperands() > 1) - llvm_unreachable("1->N argument mappings are currently not handled"); + if (op->getNumOperands() > 1) { + // If all of the uses were removed, we can drop this op. Otherwise, + // keep the operation alive and let the user handle any remaining + // usages. + if (op->use_empty()) + op->erase(); + continue; + } // Handle the case where this argument had a direct mapping. if (op->getNumOperands() == 1) { @@ -132,7 +150,8 @@ struct ArgConverter { } /// Converts the signature of the given entry block. - void convertSignature(Block *block, + void convertSignature(Block *block, PatternRewriter &rewriter, + TypeConverter &converter, TypeConverter::SignatureConversion &signatureConversion, BlockAndValueMapping &mapping) { unsigned origArgCount = block->getNumArguments(); @@ -146,13 +165,15 @@ struct ArgConverter { // Remap each of the original arguments as determined by the signature // conversion. auto &newArgMapping = argMapping[block]; + rewriter.setInsertionPointToStart(block); for (unsigned i = 0; i != origArgCount; ++i) { ArrayRef remappedValues; if (auto inputMap = signatureConversion.getInputMapping(i)) remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size); BlockArgument *arg = block->getArgument(i); - newArgMapping.push_back(convertArgument(arg, remappedValues, mapping)); + newArgMapping.push_back( + convertArgument(arg, remappedValues, rewriter, converter, mapping)); } // Erase all of the original arguments. @@ -161,7 +182,8 @@ struct ArgConverter { } /// Converts the arguments of the given block. - LogicalResult convertArguments(Block *block, TypeConverter &converter, + LogicalResult convertArguments(Block *block, PatternRewriter &rewriter, + TypeConverter &converter, BlockAndValueMapping &mapping) { unsigned origArgCount = block->getNumArguments(); if (origArgCount == 0) @@ -178,10 +200,11 @@ struct ArgConverter { // Remap all of the original argument values. auto &newArgMapping = argMapping[block]; + rewriter.setInsertionPointToStart(block); for (unsigned i = 0; i != origArgCount; ++i) { SmallVector newArgs(block->addArguments(newArgTypes[i])); - newArgMapping.push_back( - convertArgument(block->getArgument(i), newArgs, mapping)); + newArgMapping.push_back(convertArgument(block->getArgument(i), newArgs, + rewriter, converter, mapping)); } // Erase all of the original arguments. @@ -195,6 +218,8 @@ struct ArgConverter { /// to perform the conversion. Operation *convertArgument(BlockArgument *origArg, ArrayRef newValues, + PatternRewriter &rewriter, + TypeConverter &converter, BlockAndValueMapping &mapping) { // Handle the cases of 1->0 or 1->1 mappings. if (newValues.size() < 2) { @@ -209,7 +234,15 @@ struct ArgConverter { mapping.map(cast->getResult(0), newValues[0]); return cast; } - llvm_unreachable("1->N argument mappings are currently not handled"); + + // Otherwise, this is a 1->N mapping. Call into the provided type converter + // to pack the new values. + auto *cast = converter.materializeConversion(rewriter, origArg->getType(), + newValues, loc); + assert(cast->getNumResults() == 1 && + cast->getNumOperands() == newValues.size()); + origArg->replaceAllUsesWith(cast->getResult(0)); + return cast; } /// A utility function used to create a conversion cast operation with the @@ -874,10 +907,11 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, // types. if (typeConverter) { for (Block &block : - llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) - if (failed(rewriter.argConverter.convertArguments(&block, *typeConverter, - rewriter.mapping))) + llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) { + if (failed(rewriter.argConverter.convertArguments( + &block, rewriter, *typeConverter, rewriter.mapping))) return failure(); + } } // Store the number of blocks before conversion (new blocks may be added due @@ -909,8 +943,9 @@ LogicalResult FunctionConverter::convertFunction( // Update the signature of the entry block. if (signatureConversion) { - rewriter.argConverter.convertSignature( - &f->getBody().front(), *signatureConversion, rewriter.mapping); + rewriter.argConverter.convertSignature(&f->getBody().front(), rewriter, + *typeConverter, *signatureConversion, + rewriter.mapping); } // Rewrite the function body. diff --git a/mlir/test/TestDialect/TestOps.td b/mlir/test/TestDialect/TestOps.td index 0286284..dfc9a0c 100644 --- a/mlir/test/TestDialect/TestOps.td +++ b/mlir/test/TestDialect/TestOps.td @@ -227,4 +227,13 @@ def : Pat<(ILLegalOpD), (LegalOpA Test_LegalizerEnum_Failure)>; def : Pat<(ILLegalOpC), (ILLegalOpE), [], (addBenefit 10)>; def : Pat<(ILLegalOpE), (LegalOpA Test_LegalizerEnum_Success)>; +//===----------------------------------------------------------------------===// +// Test Type Legalization +//===----------------------------------------------------------------------===// + +def TestReturnOp : TEST_Op<"return", [Terminator]>, + Arguments<(ins Variadic:$inputs)>; +def TestCastOp : TEST_Op<"cast">, + Arguments<(ins Variadic:$inputs)>, Results<(outs AnyType:$res)>; + #endif // TEST_OPS diff --git a/mlir/test/TestDialect/TestPatterns.cpp b/mlir/test/TestDialect/TestPatterns.cpp index d4e5f79..f323d7f 100644 --- a/mlir/test/TestDialect/TestPatterns.cpp +++ b/mlir/test/TestDialect/TestPatterns.cpp @@ -49,6 +49,7 @@ static mlir::PassRegistration //===----------------------------------------------------------------------===// // Legalization Driver. //===----------------------------------------------------------------------===// + namespace { /// This pattern is a simple pattern that inlines the first region of a given /// operation into the parent region. @@ -77,6 +78,29 @@ struct TestDropOp : public ConversionPattern { return matchSuccess(); } }; +/// This pattern handles the case of a split return value. +struct TestSplitReturnType : public ConversionPattern { + TestSplitReturnType(MLIRContext *ctx) + : ConversionPattern("test.return", 1, ctx) {} + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const final { + // Check for a return of F32. + if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32()) + return matchFailure(); + + // Check if the first operation is a cast operation, if it is we use the + // results directly. + auto *defOp = operands[0]->getDefiningOp(); + if (auto packerOp = llvm::dyn_cast_or_null(defOp)) { + SmallVector returnOperands(packerOp.getOperands()); + rewriter.replaceOpWithNewOp(op, returnOperands); + return matchSuccess(); + } + + // Otherwise, fail to match. + return matchFailure(); + } +}; } // namespace namespace { @@ -94,10 +118,35 @@ struct TestTypeConverter : public TypeConverter { return success(); } + // Split F32 into F16,F16. + if (t.isF32()) { + results.assign(2, FloatType::getF16(t.getContext())); + return success(); + } + // Otherwise, convert the type directly. results.push_back(t); return success(); } + + /// Override the hook to materialize a conversion. This is necessary because + /// we generate 1->N type mappings. + Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, + ArrayRef inputs, Location loc) { + return rewriter.create(loc, resultType, inputs); + } +}; + +struct TestConversionTarget : public ConversionTarget { + TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + addLegalOp(); + addDynamicallyLegalOp(); + } + bool isDynamicallyLegal(Operation *op) const final { + // Don't allow F32 operands. + return llvm::none_of(op->getOperandTypes(), + [](Type type) { return type.isF32(); }); + } }; struct TestLegalizePatternDriver @@ -105,12 +154,11 @@ struct TestLegalizePatternDriver void runOnModule() override { mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); - RewriteListBuilder::build( - patterns, &getContext()); + RewriteListBuilder::build(patterns, &getContext()); TestTypeConverter converter; - ConversionTarget target(getContext()); - target.addLegalOp(); + TestConversionTarget target(getContext()); if (failed(applyConversionPatterns(getModule(), target, converter, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index b71e149..449fba90 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -23,6 +23,19 @@ func @remap_input_1_to_1(%arg0: i64) -> i64 { return %arg0 : i64 } +// CHECK-LABEL: func @remap_input_1_to_N(%arg0: f16, %arg1: f16) -> (f16, f16) +func @remap_input_1_to_N(%arg0: f32) -> f32 { + // CHECK-NEXT: "test.return"(%arg0, %arg1) : (f16, f16) -> () + "test.return"(%arg0) : (f32) -> () +} + +// CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16) +func @remap_input_1_to_N_remaining_use(%arg0: f32) { + // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32 + // CHECK-NEXT: "work"([[CAST]]) : (f32) -> () + "work"(%arg0) : (f32) -> () +} + // CHECK-LABEL: func @remap_multi(%arg0: f64, %arg1: f64) -> (f64, f64) func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) { // CHECK-NEXT: return %arg0, %arg1 : f64, f64 @@ -44,11 +57,12 @@ func @remap_nested() { // CHECK-LABEL: func @remap_moved_region_args func @remap_moved_region_args() { // CHECK-NEXT: return - // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64): - // CHECK-NEXT: "work"{{.*}} : (f64, f64) + // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16): + // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32 + // CHECK-NEXT: "work"{{.*}} : (f64, f64, f32) "test.region"() ({ - ^bb1(%i0: i64, %unused: i16, %i1: i64): - "work"(%i0, %i1) : (i64, i64) -> () + ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): + "work"(%i0, %i1, %2) : (i64, i64, f32) -> () }) : () -> () return } @@ -58,8 +72,8 @@ func @remap_drop_region() { // CHECK-NEXT: return // CHECK-NEXT: } "test.drop_op"() ({ - ^bb1(%i0: i64, %unused: i16, %i1: i64): - "work"(%i0, %i1) : (i64, i64) -> () + ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): + "work"(%i0, %i1, %2) : (i64, i64, f32) -> () }) : () -> () return }