[mlir] Add `populateFunctionOpInterfaceTypeConversionPattern` version which operates...
authorIvan Butygin <ivan.butygin@gmail.com>
Fri, 4 Nov 2022 22:26:02 +0000 (23:26 +0100)
committerIvan Butygin <ivan.butygin@gmail.com>
Sat, 5 Nov 2022 11:10:36 +0000 (12:10 +0100)
Exisitng version is always limited to some specific op.

Differential Revision: https://reviews.llvm.org/D137469

mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp

index 061edb1..6045b22 100644 (file)
@@ -507,6 +507,9 @@ void populateFunctionOpInterfaceTypeConversionPattern(
                                                    patterns, converter);
 }
 
+void populateAnyFunctionOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter);
+
 //===----------------------------------------------------------------------===//
 // Conversion PatternRewriter
 //===----------------------------------------------------------------------===//
index 505127c..61bc4ff 100644 (file)
@@ -3056,6 +3056,29 @@ auto TypeConverter::convertBlockSignature(Block *block)
 // FunctionOpInterfaceSignatureConversion
 //===----------------------------------------------------------------------===//
 
+static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
+                                        TypeConverter &typeConverter,
+                                        ConversionPatternRewriter &rewriter) {
+  FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
+
+  // Convert the original function types.
+  TypeConverter::SignatureConversion result(type.getNumInputs());
+  SmallVector<Type, 1> newResults;
+  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
+      failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
+      failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
+                                         typeConverter, &result)))
+    return failure();
+
+  // Update the function signature in-place.
+  auto newType = FunctionType::get(rewriter.getContext(),
+                                   result.getConvertedTypes(), newResults);
+
+  rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
+
+  return success();
+}
+
 /// Create a default conversion pattern that rewrites the type signature of a
 /// FunctionOpInterface op. This only supports ops which use FunctionType to
 /// represent their type.
@@ -3067,27 +3090,21 @@ struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
       : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
                   ConversionPatternRewriter &rewriter) const override {
     FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
-    FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
-
-    // Convert the original function types.
-    TypeConverter::SignatureConversion result(type.getNumInputs());
-    SmallVector<Type, 1> newResults;
-    if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
-        failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
-        failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
-                                           *typeConverter, &result)))
-      return failure();
-
-    // Update the function signature in-place.
-    auto newType = FunctionType::get(rewriter.getContext(),
-                                     result.getConvertedTypes(), newResults);
+    return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
+  }
+};
 
-    rewriter.updateRootInPlace(op, [&] { funcOp.setType(newType); });
+struct AnyFunctionOpInterfaceSignatureConversion
+    : public OpInterfaceConversionPattern<FunctionOpInterface> {
+  using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
 
-    return success();
+  LogicalResult
+  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
   }
 };
 } // namespace
@@ -3099,6 +3116,12 @@ void mlir::populateFunctionOpInterfaceTypeConversionPattern(
       functionLikeOpName, patterns.getContext(), converter);
 }
 
+void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter) {
+  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
+      converter, patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//
index 17c8c1f..12f3747 100644 (file)
@@ -786,8 +786,8 @@ struct TestLegalizePatternDriver
              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
              TestCreateUnregisteredOp>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
-    mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, converter);
+    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+                                                              converter);
     mlir::populateCallOpTypeConversionPattern(patterns, converter);
 
     // Define the conversion target used for the test.
@@ -1313,8 +1313,8 @@ struct TestTypeConversionDriver
                  TestTestSignatureConversionNoConverter>(converter,
                                                          &getContext());
     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
-    mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, converter);
+    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+                                                              converter);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))