Dialect conversion: decouple function signature conversion from type conversion
authorAlex Zinenko <zinenko@google.com>
Fri, 15 Feb 2019 13:06:15 +0000 (05:06 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:28:41 +0000 (16:28 -0700)
Function types are built-in in MLIR and affect the validity of the IR itself.
However, advanced target dialects such as the LLVM IR dialect may include
custom function types.  Until now, dialect conversion was expecting function
types not to be converted to the custom type: although the signatures was
allowed to change, the outer type must have been an mlir::FunctionType.  This
effectively prevented dialect conversion from creating instructions that
operate on values of the custom function type.

Dissociate function signature conversion from general type conversion.
Function signature conversion must still produce an mlir::FunctionType and is
used in places where built-in types are required to make IR valid.  General
type conversion is used for SSA values, including function and block arguments
and function results.

Exercise this behavior in the LLVM IR dialect conversion by converting function
types to LLVM IR function pointer types.  The pointer to a function is chosen
to provide consistent lowering of higher-order functions: while it is possible
to have a value of function type, it is not possible to create a function type
accepting a returning another function type.

PiperOrigin-RevId: 234124494

mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/LLVMIR/convert-funcs.mlir [new file with mode: 0644]

index 348201b3976b8ca061fcb76c1700e07c5ed87fe0..7bd08bbd7662753123e6e96cd6a4ccf280ddb497 100644 (file)
@@ -164,12 +164,12 @@ protected:
   initConverters(MLIRContext *mlirContext) = 0;
 
   /// Derived classes must reimplement this hook if they need to convert
-  /// block or function argument types or function result types.
-  ///
-  /// For functions types, this function will be passed a function type and the
-  /// result must be another function type with arguments and results converted.
-  /// Note: even if some target dialects have first-class function types, they
-  /// cannot be used at the top level of MLIR function signature.
+  /// block or function argument types or function result types.  If the target
+  /// dialect has support for custom first-class function types, convertType
+  /// should create those types for arguments of MLIR function type.  It can be
+  /// used for values (constant, operands, resutls) of function type but not for
+  /// the function signatures.  For the latter, convertFunctionSignatureType is
+  /// used instead.
   ///
   /// For block attribute types, this function will be called for each attribute
   /// individually.
@@ -178,6 +178,19 @@ protected:
   /// default-constructed Type.  The failure will be then propagated to trigger
   /// the pass failure.
   virtual Type convertType(Type t) { return t; }
+
+  /// Derived classes must reimplement this hook if they need to change the
+  /// function signature during conversion.  This function will be called on
+  /// a function type corresponding to a function signature and must produce the
+  /// converted MLIR function type.
+  ///
+  /// Note: even if some target dialects have first-class function types, they
+  /// cannot be used at the top level of MLIR function signature.
+  ///
+  /// The default behavior of this function is to call convertType on individual
+  /// function operands and results, and then create a new MLIR function type
+  /// from those.
+  virtual FunctionType convertFunctionSignatureType(FunctionType t);
 };
 
 } // end namespace mlir
index 22b7bc50cfc734458b8a112c0421ec5ea3a7cafb..fa0ee7d4723bed4fe9e79228c0a916f03b271893 100644 (file)
@@ -33,6 +33,7 @@
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"
 
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Type.h"
 
@@ -60,6 +61,14 @@ public:
   static Type pack(ArrayRef<Type> types, llvm::Module &llvmModule,
                    MLIRContext &context);
 
+  // Convert a function signature type to the LLVM IR dialect.  The outer
+  // function type remains `mlir::FunctionType`.  Argument types are converted
+  // to LLVM IR as is.  If the function returns a single result, its type is
+  // converted.  Otherwise, the types of results are packed into an LLVM IR
+  // structure type.
+  static FunctionType convertFunctionSignature(FunctionType t,
+                                               llvm::Module &llvmModule);
+
 private:
   // Construct a type converter.
   explicit TypeConverter(llvm::Module &llvmModule, MLIRContext *context)
@@ -70,7 +79,11 @@ private:
   // one.  Additionally, if the function returns more than one value, pack the
   // results into an LLVM IR structure type so that the converted function type
   // returns at most one result.
-  FunctionType convertFunctionType(FunctionType type);
+  Type convertFunctionType(FunctionType type);
+
+  // Convert function type arguments and results without converting the
+  // function type itself.
+  FunctionType convertFunctionSignatureType(FunctionType type);
 
   // Convert the index type.  Uses llvmModule data layout to create an integer
   // of the pointer bitwidth.
@@ -187,8 +200,29 @@ Type TypeConverter::getPackedResultType(ArrayRef<Type> types) {
 // argument and result types.  If MLIR Function has zero results, the LLVM
 // Function has one VoidType result.  If MLIR Function has more than one result,
 // they are into an LLVM StructType in their order of appearance.
-FunctionType TypeConverter::convertFunctionType(FunctionType type) {
+Type TypeConverter::convertFunctionType(FunctionType type) {
   // Convert argument types one by one and check for errors.
+  SmallVector<llvm::Type *, 8> argTypes;
+  for (auto t : type.getInputs()) {
+    auto converted = convertType(t);
+    if (!converted)
+      return {};
+    argTypes.push_back(unwrap(converted));
+  }
+
+  // If function does not return anything, create the void result type,
+  // if it returns on element, convert it, otherwise pack the result types into
+  // a struct.
+  llvm::Type *resultType = type.getNumResults() == 0
+                               ? llvm::Type::getVoidTy(llvmContext)
+                               : unwrap(getPackedResultType(type.getResults()));
+  if (!resultType)
+    return {};
+  return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false)
+                  ->getPointerTo());
+}
+
+FunctionType TypeConverter::convertFunctionSignatureType(FunctionType type) {
   SmallVector<Type, 8> argTypes;
   for (auto t : type.getInputs()) {
     auto converted = convertType(t);
@@ -199,13 +233,13 @@ FunctionType TypeConverter::convertFunctionType(FunctionType type) {
 
   // If function does not return anything, return immediately.
   if (type.getNumResults() == 0)
-    return FunctionType::get(argTypes, {}, mlirContext);
+    return FunctionType::get(argTypes, {}, type.getContext());
 
-  // Convert result types to a single LLVM result type.
-  Type resultType = getPackedResultType(type.getResults());
-  if (!resultType)
-    return {};
-  return FunctionType::get(argTypes, {resultType}, mlirContext);
+  // Otherwise pack the result types into a struct.
+  if (auto result = getPackedResultType(type.getResults()))
+    return FunctionType::get(argTypes, {result}, type.getContext());
+
+  return {};
 }
 
 // MemRefs are converted into LLVM structure types to accommodate dynamic sizes.
@@ -269,6 +303,11 @@ Type TypeConverter::convert(Type t, llvm::Module &module) {
   return TypeConverter(module, t.getContext()).convertType(t);
 }
 
+FunctionType TypeConverter::convertFunctionSignature(FunctionType t,
+                                                     llvm::Module &module) {
+  return TypeConverter(module, t.getContext()).convertFunctionSignatureType(t);
+}
+
 Type TypeConverter::getMemRefElementPtrType(MemRefType t,
                                             llvm::Module &module) {
   auto elementType = t.getElementType();
@@ -995,6 +1034,11 @@ protected:
     return TypeConverter::convert(t, *module);
   }
 
+  // Convert function signatures using the stored LLVM IR module.
+  FunctionType convertFunctionSignatureType(FunctionType t) override {
+    return TypeConverter::convertFunctionSignature(t, *module);
+  }
+
 private:
   // Storage for the conversion patterns.
   llvm::BumpPtrAllocator converterStorage;
index 2d9a703773191d0f8e75173d09281ddfd469d7ac..60e7b11cf389265d406953e6b123ea4a1fa3d5ae 100644 (file)
@@ -218,7 +218,8 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
 
   // Create a new function with argument types and result types converted.  Wrap
   // it into a unique_ptr to make sure it is cleaned up in case of error.
-  Type newFunctionType = dialectConversion->convertType(f->getType());
+  Type newFunctionType =
+      dialectConversion->convertFunctionSignatureType(f->getType());
   if (!newFunctionType)
     return emitError("could not convert function type");
   auto newFunction = llvm::make_unique<Function>(
@@ -305,6 +306,23 @@ bool impl::FunctionConversion::run(Module *module) {
   return false;
 }
 
+// Create a function type with arguments and results converted.
+FunctionType
+DialectConversion::convertFunctionSignatureType(FunctionType type) {
+  SmallVector<Type, 8> arguments;
+  SmallVector<Type, 4> results;
+
+  arguments.reserve(type.getNumInputs());
+  for (auto t : type.getInputs())
+    arguments.push_back(convertType(t));
+
+  results.reserve(type.getNumResults());
+  for (auto t : type.getResults())
+    results.push_back(convertType(t));
+
+  return FunctionType::get(arguments, results, type.getContext());
+}
+
 PassResult DialectConversion::runOnModule(Module *m) {
   return impl::FunctionConversion::convert(this, m) ? failure() : success();
 }
diff --git a/mlir/test/LLVMIR/convert-funcs.mlir b/mlir/test/LLVMIR/convert-funcs.mlir
new file mode 100644 (file)
index 0000000..ed828a3
--- /dev/null
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -convert-to-llvmir %s | FileCheck %s
+
+//CHECK: func @second_order_arg(!llvm<"void ()*">)
+func @second_order_arg(%arg0 : () -> ())
+
+//CHECK: func @second_order_result() -> !llvm<"void ()*">
+func @second_order_result() -> (() -> ())
+
+//CHECK: func @second_order_multi_result() -> !llvm<"{ i32 ()*, i64 ()*, float ()* }">
+func @second_order_multi_result() -> (() -> (i32), () -> (i64), () -> (f32))
+
+//CHECK: func @third_order(!llvm<"void ()* (void ()*)*">) -> !llvm<"void ()* (void ()*)*">
+func @third_order(%arg0 : (() -> ()) -> (() -> ())) -> ((() -> ()) -> (() -> ()))
+
+//CHECK: func @fifth_order_left(!llvm<"void (void (void (void ()*)*)*)*">)
+func @fifth_order_left(%arg0: (((() -> ()) -> ()) -> ()) -> ())
+
+//CHECK: func @fifth_order_right(!llvm<"void ()* ()* ()* ()*">)
+func @fifth_order_right(%arg0: () -> (() -> (() -> (() -> ()))))
+
+//CHECK-LABEL: func @pass_through(%arg0: !llvm<"void ()*">) -> !llvm<"void ()*"> {
+func @pass_through(%arg0: () -> ()) -> (() -> ()) {
+//CHECK-NEXT:   "llvm.br"()[^bb1(%arg0 : !llvm<"void ()*">)] : () -> ()
+  br ^bb1(%arg0 : () -> ())
+
+//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">):     // pred: ^bb0
+^bb1(%bbarg: () -> ()):
+//CHECK-NEXT:   "llvm.return"(%0) : (!llvm<"void ()*">) -> ()
+  return %bbarg : () -> ()
+}