From: Ivan Butygin Date: Fri, 4 Nov 2022 22:26:02 +0000 (+0100) Subject: [mlir] Add `populateFunctionOpInterfaceTypeConversionPattern` version which operates... X-Git-Tag: upstream/17.0.6~28443 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ed4749f9373d0079a69e947486aa29042d606458;p=platform%2Fupstream%2Fllvm.git [mlir] Add `populateFunctionOpInterfaceTypeConversionPattern` version which operates on any `FunctionOpInterface` Exisitng version is always limited to some specific op. Differential Revision: https://reviews.llvm.org/D137469 --- diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 061edb1..6045b22 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -507,6 +507,9 @@ void populateFunctionOpInterfaceTypeConversionPattern( patterns, converter); } +void populateAnyFunctionOpInterfaceTypeConversionPattern( + RewritePatternSet &patterns, TypeConverter &converter); + //===----------------------------------------------------------------------===// // Conversion PatternRewriter //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 505127c..61bc4ff 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3056,6 +3056,29 @@ auto TypeConverter::convertBlockSignature(Block *block) // FunctionOpInterfaceSignatureConversion //===----------------------------------------------------------------------===// +static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, + TypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { + FunctionType type = funcOp.getFunctionType().cast(); + + // Convert the original function types. + TypeConverter::SignatureConversion result(type.getNumInputs()); + SmallVector newResults; + if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || + failed(typeConverter.convertTypes(type.getResults(), newResults)) || + failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), + typeConverter, &result))) + return failure(); + + // Update the function signature in-place. + auto newType = FunctionType::get(rewriter.getContext(), + result.getConvertedTypes(), newResults); + + rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); }); + + return success(); +} + /// Create a default conversion pattern that rewrites the type signature of a /// FunctionOpInterface op. This only supports ops which use FunctionType to /// represent their type. @@ -3067,27 +3090,21 @@ struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef /*operands*/, ConversionPatternRewriter &rewriter) const override { FunctionOpInterface funcOp = cast(op); - FunctionType type = funcOp.getFunctionType().cast(); - - // Convert the original function types. - TypeConverter::SignatureConversion result(type.getNumInputs()); - SmallVector newResults; - if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) || - failed(typeConverter->convertTypes(type.getResults(), newResults)) || - failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), - *typeConverter, &result))) - return failure(); - - // Update the function signature in-place. - auto newType = FunctionType::get(rewriter.getContext(), - result.getConvertedTypes(), newResults); + return convertFuncOpTypes(funcOp, *typeConverter, rewriter); + } +}; - rewriter.updateRootInPlace(op, [&] { funcOp.setType(newType); }); +struct AnyFunctionOpInterfaceSignatureConversion + : public OpInterfaceConversionPattern { + using OpInterfaceConversionPattern::OpInterfaceConversionPattern; - return success(); + LogicalResult + matchAndRewrite(FunctionOpInterface funcOp, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + return convertFuncOpTypes(funcOp, *typeConverter, rewriter); } }; } // namespace @@ -3099,6 +3116,12 @@ void mlir::populateFunctionOpInterfaceTypeConversionPattern( functionLikeOpName, patterns.getContext(), converter); } +void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern( + RewritePatternSet &patterns, TypeConverter &converter) { + patterns.add( + converter, patterns.getContext()); +} + //===----------------------------------------------------------------------===// // ConversionTarget //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 17c8c1f..12f3747 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -786,8 +786,8 @@ struct TestLegalizePatternDriver TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, TestCreateUnregisteredOp>(&getContext()); patterns.add(&getContext(), converter); - mlir::populateFunctionOpInterfaceTypeConversionPattern( - patterns, converter); + mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, + converter); mlir::populateCallOpTypeConversionPattern(patterns, converter); // Define the conversion target used for the test. @@ -1313,8 +1313,8 @@ struct TestTypeConversionDriver TestTestSignatureConversionNoConverter>(converter, &getContext()); patterns.add(&getContext()); - mlir::populateFunctionOpInterfaceTypeConversionPattern( - patterns, converter); + mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, + converter); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))