initConverters(MLIRContext *mlirContext) = 0;
/// Derived classes must reimplement this hook if they need to convert
- /// block or function argument types or function result types.
- ///
- /// For functions types, this function will be passed a function type and the
- /// result must be another function type with arguments and results converted.
- /// Note: even if some target dialects have first-class function types, they
- /// cannot be used at the top level of MLIR function signature.
+ /// block or function argument types or function result types. If the target
+ /// dialect has support for custom first-class function types, convertType
+ /// should create those types for arguments of MLIR function type. It can be
+ /// used for values (constant, operands, resutls) of function type but not for
+ /// the function signatures. For the latter, convertFunctionSignatureType is
+ /// used instead.
///
/// For block attribute types, this function will be called for each attribute
/// individually.
/// default-constructed Type. The failure will be then propagated to trigger
/// the pass failure.
virtual Type convertType(Type t) { return t; }
+
+ /// Derived classes must reimplement this hook if they need to change the
+ /// function signature during conversion. This function will be called on
+ /// a function type corresponding to a function signature and must produce the
+ /// converted MLIR function type.
+ ///
+ /// Note: even if some target dialects have first-class function types, they
+ /// cannot be used at the top level of MLIR function signature.
+ ///
+ /// The default behavior of this function is to call convertType on individual
+ /// function operands and results, and then create a new MLIR function type
+ /// from those.
+ virtual FunctionType convertFunctionSignatureType(FunctionType t);
};
} // end namespace mlir
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
+#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
static Type pack(ArrayRef<Type> types, llvm::Module &llvmModule,
MLIRContext &context);
+ // Convert a function signature type to the LLVM IR dialect. The outer
+ // function type remains `mlir::FunctionType`. Argument types are converted
+ // to LLVM IR as is. If the function returns a single result, its type is
+ // converted. Otherwise, the types of results are packed into an LLVM IR
+ // structure type.
+ static FunctionType convertFunctionSignature(FunctionType t,
+ llvm::Module &llvmModule);
+
private:
// Construct a type converter.
explicit TypeConverter(llvm::Module &llvmModule, MLIRContext *context)
// one. Additionally, if the function returns more than one value, pack the
// results into an LLVM IR structure type so that the converted function type
// returns at most one result.
- FunctionType convertFunctionType(FunctionType type);
+ Type convertFunctionType(FunctionType type);
+
+ // Convert function type arguments and results without converting the
+ // function type itself.
+ FunctionType convertFunctionSignatureType(FunctionType type);
// Convert the index type. Uses llvmModule data layout to create an integer
// of the pointer bitwidth.
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are into an LLVM StructType in their order of appearance.
-FunctionType TypeConverter::convertFunctionType(FunctionType type) {
+Type TypeConverter::convertFunctionType(FunctionType type) {
// Convert argument types one by one and check for errors.
+ SmallVector<llvm::Type *, 8> argTypes;
+ for (auto t : type.getInputs()) {
+ auto converted = convertType(t);
+ if (!converted)
+ return {};
+ argTypes.push_back(unwrap(converted));
+ }
+
+ // If function does not return anything, create the void result type,
+ // if it returns on element, convert it, otherwise pack the result types into
+ // a struct.
+ llvm::Type *resultType = type.getNumResults() == 0
+ ? llvm::Type::getVoidTy(llvmContext)
+ : unwrap(getPackedResultType(type.getResults()));
+ if (!resultType)
+ return {};
+ return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false)
+ ->getPointerTo());
+}
+
+FunctionType TypeConverter::convertFunctionSignatureType(FunctionType type) {
SmallVector<Type, 8> argTypes;
for (auto t : type.getInputs()) {
auto converted = convertType(t);
// If function does not return anything, return immediately.
if (type.getNumResults() == 0)
- return FunctionType::get(argTypes, {}, mlirContext);
+ return FunctionType::get(argTypes, {}, type.getContext());
- // Convert result types to a single LLVM result type.
- Type resultType = getPackedResultType(type.getResults());
- if (!resultType)
- return {};
- return FunctionType::get(argTypes, {resultType}, mlirContext);
+ // Otherwise pack the result types into a struct.
+ if (auto result = getPackedResultType(type.getResults()))
+ return FunctionType::get(argTypes, {result}, type.getContext());
+
+ return {};
}
// MemRefs are converted into LLVM structure types to accommodate dynamic sizes.
return TypeConverter(module, t.getContext()).convertType(t);
}
+FunctionType TypeConverter::convertFunctionSignature(FunctionType t,
+ llvm::Module &module) {
+ return TypeConverter(module, t.getContext()).convertFunctionSignatureType(t);
+}
+
Type TypeConverter::getMemRefElementPtrType(MemRefType t,
llvm::Module &module) {
auto elementType = t.getElementType();
return TypeConverter::convert(t, *module);
}
+ // Convert function signatures using the stored LLVM IR module.
+ FunctionType convertFunctionSignatureType(FunctionType t) override {
+ return TypeConverter::convertFunctionSignature(t, *module);
+ }
+
private:
// Storage for the conversion patterns.
llvm::BumpPtrAllocator converterStorage;
// Create a new function with argument types and result types converted. Wrap
// it into a unique_ptr to make sure it is cleaned up in case of error.
- Type newFunctionType = dialectConversion->convertType(f->getType());
+ Type newFunctionType =
+ dialectConversion->convertFunctionSignatureType(f->getType());
if (!newFunctionType)
return emitError("could not convert function type");
auto newFunction = llvm::make_unique<Function>(
return false;
}
+// Create a function type with arguments and results converted.
+FunctionType
+DialectConversion::convertFunctionSignatureType(FunctionType type) {
+ SmallVector<Type, 8> arguments;
+ SmallVector<Type, 4> results;
+
+ arguments.reserve(type.getNumInputs());
+ for (auto t : type.getInputs())
+ arguments.push_back(convertType(t));
+
+ results.reserve(type.getNumResults());
+ for (auto t : type.getResults())
+ results.push_back(convertType(t));
+
+ return FunctionType::get(arguments, results, type.getContext());
+}
+
PassResult DialectConversion::runOnModule(Module *m) {
return impl::FunctionConversion::convert(this, m) ? failure() : success();
}
--- /dev/null
+// RUN: mlir-opt -convert-to-llvmir %s | FileCheck %s
+
+//CHECK: func @second_order_arg(!llvm<"void ()*">)
+func @second_order_arg(%arg0 : () -> ())
+
+//CHECK: func @second_order_result() -> !llvm<"void ()*">
+func @second_order_result() -> (() -> ())
+
+//CHECK: func @second_order_multi_result() -> !llvm<"{ i32 ()*, i64 ()*, float ()* }">
+func @second_order_multi_result() -> (() -> (i32), () -> (i64), () -> (f32))
+
+//CHECK: func @third_order(!llvm<"void ()* (void ()*)*">) -> !llvm<"void ()* (void ()*)*">
+func @third_order(%arg0 : (() -> ()) -> (() -> ())) -> ((() -> ()) -> (() -> ()))
+
+//CHECK: func @fifth_order_left(!llvm<"void (void (void (void ()*)*)*)*">)
+func @fifth_order_left(%arg0: (((() -> ()) -> ()) -> ()) -> ())
+
+//CHECK: func @fifth_order_right(!llvm<"void ()* ()* ()* ()*">)
+func @fifth_order_right(%arg0: () -> (() -> (() -> (() -> ()))))
+
+//CHECK-LABEL: func @pass_through(%arg0: !llvm<"void ()*">) -> !llvm<"void ()*"> {
+func @pass_through(%arg0: () -> ()) -> (() -> ()) {
+//CHECK-NEXT: "llvm.br"()[^bb1(%arg0 : !llvm<"void ()*">)] : () -> ()
+ br ^bb1(%arg0 : () -> ())
+
+//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">): // pred: ^bb0
+^bb1(%bbarg: () -> ()):
+//CHECK-NEXT: "llvm.return"(%0) : (!llvm<"void ()*">) -> ()
+ return %bbarg : () -> ()
+}