#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
/// Async Runtime API function types.
///
/// Because we can't create API function signature for type parametrized
-/// async.getValue type, we use opaque pointers (!llvm.ptr<i8>) instead. After
+/// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
/// lowering all async data types become opaque pointers at runtime.
struct AsyncAPI {
- // All async types are lowered to opaque i8* LLVM pointers at runtime.
- static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
+ // All async types are lowered to opaque LLVM pointers at runtime.
+ static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ if (useLLVMOpaquePointers)
+ return LLVM::LLVMPointerType::get(ctx);
return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
}
return LLVM::LLVMTokenType::get(ctx);
}
- static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
- auto ref = opaquePointerType(ctx);
+ static FunctionType addOrDropRefFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ auto ref = opaquePointerType(ctx, useLLVMOpaquePointers);
auto count = IntegerType::get(ctx, 64);
return FunctionType::get(ctx, {ref, count}, {});
}
return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
}
- static FunctionType createValueFunctionType(MLIRContext *ctx) {
+ static FunctionType createValueFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
auto i64 = IntegerType::get(ctx, 64);
- auto value = opaquePointerType(ctx);
+ auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
return FunctionType::get(ctx, {i64}, {value});
}
return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
}
- static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
- auto value = opaquePointerType(ctx);
- auto storage = opaquePointerType(ctx);
+ static FunctionType getValueStorageFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+ auto storage = opaquePointerType(ctx, useLLVMOpaquePointers);
return FunctionType::get(ctx, {value}, {storage});
}
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
- static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
- auto value = opaquePointerType(ctx);
+ static FunctionType emplaceValueFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
return FunctionType::get(ctx, {value}, {});
}
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
- static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
- auto value = opaquePointerType(ctx);
+ static FunctionType setValueErrorFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
return FunctionType::get(ctx, {value}, {});
}
return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
}
- static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
- auto value = opaquePointerType(ctx);
+ static FunctionType isValueErrorFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
auto i1 = IntegerType::get(ctx, 1);
return FunctionType::get(ctx, {value}, {i1});
}
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
- static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
- auto value = opaquePointerType(ctx);
+ static FunctionType awaitValueFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
return FunctionType::get(ctx, {value}, {});
}
return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
}
- static FunctionType executeFunctionType(MLIRContext *ctx) {
- auto hdl = opaquePointerType(ctx);
- auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+ static FunctionType executeFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+ Type resume;
+ if (useLLVMOpaquePointers)
+ resume = LLVM::LLVMPointerType::get(ctx);
+ else
+ resume = LLVM::LLVMPointerType::get(
+ resumeFunctionType(ctx, useLLVMOpaquePointers));
return FunctionType::get(ctx, {hdl, resume}, {});
}
{i64});
}
- static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
- auto hdl = opaquePointerType(ctx);
- auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+ static FunctionType
+ awaitTokenAndExecuteFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+ Type resume;
+ if (useLLVMOpaquePointers)
+ resume = LLVM::LLVMPointerType::get(ctx);
+ else
+ resume = LLVM::LLVMPointerType::get(
+ resumeFunctionType(ctx, useLLVMOpaquePointers));
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
}
- static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
- auto value = opaquePointerType(ctx);
- auto hdl = opaquePointerType(ctx);
- auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+ static FunctionType
+ awaitValueAndExecuteFunctionType(MLIRContext *ctx,
+ bool useLLVMOpaquePointers) {
+ auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+ auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+ Type resume;
+ if (useLLVMOpaquePointers)
+ resume = LLVM::LLVMPointerType::get(ctx);
+ else
+ resume = LLVM::LLVMPointerType::get(
+ resumeFunctionType(ctx, useLLVMOpaquePointers));
return FunctionType::get(ctx, {value, hdl, resume}, {});
}
- static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
- auto hdl = opaquePointerType(ctx);
- auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+ static FunctionType
+ awaitAllAndExecuteFunctionType(MLIRContext *ctx, bool useLLVMOpaquePointers) {
+ auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+ Type resume;
+ if (useLLVMOpaquePointers)
+ resume = LLVM::LLVMPointerType::get(ctx);
+ else
+ resume = LLVM::LLVMPointerType::get(
+ resumeFunctionType(ctx, useLLVMOpaquePointers));
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
}
}
// Auxiliary coroutine resume intrinsic wrapper.
- static Type resumeFunctionType(MLIRContext *ctx) {
+ static Type resumeFunctionType(MLIRContext *ctx, bool useLLVMOpaquePointers) {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
- auto i8Ptr = opaquePointerType(ctx);
- return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
+ auto ptrType = opaquePointerType(ctx, useLLVMOpaquePointers);
+ return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
}
};
} // namespace
/// Adds Async Runtime C API declarations to the module.
-static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
+static void addAsyncRuntimeApiDeclarations(ModuleOp module,
+ bool useLLVMOpaquePointers) {
auto builder =
ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
};
MLIRContext *ctx = module.getContext();
- addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
- addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
+ addFuncDecl(kAddRef,
+ AsyncAPI::addOrDropRefFunctionType(ctx, useLLVMOpaquePointers));
+ addFuncDecl(kDropRef,
+ AsyncAPI::addOrDropRefFunctionType(ctx, useLLVMOpaquePointers));
addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
- addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
+ addFuncDecl(kCreateValue,
+ AsyncAPI::createValueFunctionType(ctx, useLLVMOpaquePointers));
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
- addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
+ addFuncDecl(kEmplaceValue,
+ AsyncAPI::emplaceValueFunctionType(ctx, useLLVMOpaquePointers));
addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
- addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
+ addFuncDecl(kSetValueError,
+ AsyncAPI::setValueErrorFunctionType(ctx, useLLVMOpaquePointers));
addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
- addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
+ addFuncDecl(kIsValueError,
+ AsyncAPI::isValueErrorFunctionType(ctx, useLLVMOpaquePointers));
addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
- addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
+ addFuncDecl(kAwaitValue,
+ AsyncAPI::awaitValueFunctionType(ctx, useLLVMOpaquePointers));
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
- addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
- addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
+ addFuncDecl(kExecute,
+ AsyncAPI::executeFunctionType(ctx, useLLVMOpaquePointers));
+ addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(
+ ctx, useLLVMOpaquePointers));
addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
- addFuncDecl(kAwaitTokenAndExecute,
- AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
- addFuncDecl(kAwaitValueAndExecute,
- AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
- addFuncDecl(kAwaitAllAndExecute,
- AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+ addFuncDecl(kAwaitTokenAndExecute, AsyncAPI::awaitTokenAndExecuteFunctionType(
+ ctx, useLLVMOpaquePointers));
+ addFuncDecl(kAwaitValueAndExecute, AsyncAPI::awaitValueAndExecuteFunctionType(
+ ctx, useLLVMOpaquePointers));
+ addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(
+ ctx, useLLVMOpaquePointers));
addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
}
/// A function that takes a coroutine handle and calls a `llvm.coro.resume`
/// intrinsics. We need this function to be able to pass it to the async
/// runtime execute API.
-static void addResumeFunction(ModuleOp module) {
+static void addResumeFunction(ModuleOp module, bool useOpaquePointers) {
if (module.lookupSymbol(kResume))
return;
auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
- auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+ Type ptrType;
+ if (useOpaquePointers)
+ ptrType = LLVM::LLVMPointerType::get(ctx);
+ else
+ ptrType = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
- kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
+ kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
resumeOp.setPrivate();
auto *block = resumeOp.addEntryBlock();
/// AsyncRuntimeTypeConverter only converts types from the Async dialect to
/// their runtime type (opaque pointers) and does not convert any other types.
class AsyncRuntimeTypeConverter : public TypeConverter {
+ bool llvmOpaquePointers = false;
+
public:
- AsyncRuntimeTypeConverter() {
+ AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options)
+ : llvmOpaquePointers(options.useOpaquePointers) {
addConversion([](Type type) { return type; });
- addConversion(convertAsyncTypes);
+ addConversion([this](Type type) {
+ return convertAsyncTypes(type, llvmOpaquePointers);
+ });
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
// in patterns for other dialects.
addTargetMaterialization(addUnrealizedCast);
}
- static std::optional<Type> convertAsyncTypes(Type type) {
+ /// Returns whether LLVM opaque pointers should be used instead of typed
+ /// pointers.
+ bool useOpaquePointers() const { return llvmOpaquePointers; }
+
+ /// Creates an LLVM pointer type which may either be a typed pointer or an
+ /// opaque pointer, depending on what options the converter was constructed
+ /// with.
+ LLVM::LLVMPointerType getPointerType(Type elementType) {
+ if (llvmOpaquePointers)
+ return LLVM::LLVMPointerType::get(elementType.getContext());
+ return LLVM::LLVMPointerType::get(elementType);
+ }
+
+ static std::optional<Type> convertAsyncTypes(Type type,
+ bool useOpaquePointers) {
if (type.isa<TokenType, GroupType, ValueType>())
- return AsyncAPI::opaquePointerType(type.getContext());
+ return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
if (type.isa<CoroIdType, CoroStateType>())
return AsyncAPI::tokenType(type.getContext());
if (type.isa<CoroHandleType>())
- return AsyncAPI::opaquePointerType(type.getContext());
+ return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
return std::nullopt;
}
};
+
+/// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
+/// as type converter. Allows access to it via the 'getTypeConverter'
+/// convenience method.
+template <typename SourceOp>
+class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
+
+ using Base = OpConversionPattern<SourceOp>;
+
+public:
+ AsyncOpConversionPattern(AsyncRuntimeTypeConverter &typeConverter,
+ MLIRContext *context)
+ : Base(typeConverter, context) {}
+
+ /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
+ AsyncRuntimeTypeConverter *getTypeConverter() const {
+ return static_cast<AsyncRuntimeTypeConverter *>(Base::getTypeConverter());
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
namespace {
-class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
+class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto token = AsyncAPI::tokenType(op->getContext());
- auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
+ auto ptrType = AsyncAPI::opaquePointerType(
+ op->getContext(), getTypeConverter()->useOpaquePointers());
auto loc = op->getLoc();
// Constants for initializing coroutine frame.
auto constZero =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
- auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
+ auto nullPtr = rewriter.create<LLVM::NullOp>(loc, ptrType);
// Get coroutine id: @llvm.coro.id.
rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
//===----------------------------------------------------------------------===//
namespace {
-class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
+class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
+ auto ptrType = AsyncAPI::opaquePointerType(
+ op->getContext(), getTypeConverter()->useOpaquePointers());
auto loc = op->getLoc();
// Get coroutine frame size: @llvm.coro.size.i64.
// Allocate memory for the coroutine frame.
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
op->getParentOfType<ModuleOp>(), rewriter.getI64Type(),
- /*TODO: opaquePointers=*/false);
+ getTypeConverter()->useOpaquePointers());
auto coroAlloc = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange{coroAlign, coroSize});
// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
- op, i8Ptr, ValueRange({coroId, coroAlloc.getResult()}));
+ op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
return success();
}
//===----------------------------------------------------------------------===//
namespace {
-class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
+class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
+ auto ptrType = AsyncAPI::opaquePointerType(
+ op->getContext(), getTypeConverter()->useOpaquePointers());
auto loc = op->getLoc();
// Get a pointer to the coroutine frame memory: @llvm.coro.free.
auto coroMem =
- rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
+ rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
// Free the memory.
- auto freeFuncOp = LLVM::lookupOrCreateFreeFn(
- op->getParentOfType<ModuleOp>(), /*TODO: opaquePointers=*/false);
+ auto freeFuncOp =
+ LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>(),
+ getTypeConverter()->useOpaquePointers());
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
ValueRange(coroMem.getResult()));
//===----------------------------------------------------------------------===//
namespace {
-class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
+class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
auto i64 = rewriter.getI64Type();
auto storedType = converter->convertType(valueType.getValueType());
- auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
+ auto storagePtrType = getTypeConverter()->getPointerType(storedType);
// %Size = getelementptr %T* null, int 1
// %SizeI = ptrtoint %T* %Size to i64
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
- auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
- ArrayRef<LLVM::GEPArg>{1});
+ auto gep =
+ rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
+ nullPtr, ArrayRef<LLVM::GEPArg>{1});
return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
};
namespace {
class RuntimeCreateGroupOpLowering
- : public OpConversionPattern<RuntimeCreateGroupOp> {
+ : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
namespace {
class RuntimeAwaitAndResumeOpLowering
- : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
+ : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
Value handle = adaptor.getHandle();
// A pointer to coroutine resume intrinsic wrapper.
- addResumeFunction(op->getParentOfType<ModuleOp>());
- auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
+ addResumeFunction(op->getParentOfType<ModuleOp>(),
+ getTypeConverter()->useOpaquePointers());
+ auto resumeFnTy = AsyncAPI::resumeFunctionType(
+ op->getContext(), getTypeConverter()->useOpaquePointers());
auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
- op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+ op->getLoc(), getTypeConverter()->getPointerType(resumeFnTy), kResume);
rewriter.create<func::CallOp>(
op->getLoc(), apiFuncName, TypeRange(),
//===----------------------------------------------------------------------===//
namespace {
-class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
+class RuntimeResumeOpLowering
+ : public AsyncOpConversionPattern<RuntimeResumeOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// A pointer to coroutine resume intrinsic wrapper.
- addResumeFunction(op->getParentOfType<ModuleOp>());
- auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
+ addResumeFunction(op->getParentOfType<ModuleOp>(),
+ getTypeConverter()->useOpaquePointers());
+ auto resumeFnTy = AsyncAPI::resumeFunctionType(
+ op->getContext(), getTypeConverter()->useOpaquePointers());
auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
- op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+ op->getLoc(), getTypeConverter()->getPointerType(resumeFnTy), kResume);
// Call async runtime API to execute a coroutine in the managed thread.
auto coroHdl = adaptor.getHandle();
//===----------------------------------------------------------------------===//
namespace {
-class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
+class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
Location loc = op->getLoc();
// Get a pointer to the async value storage from the runtime.
- auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+ auto ptrType = AsyncAPI::opaquePointerType(
+ rewriter.getContext(), getTypeConverter()->useOpaquePointers());
auto storage = adaptor.getStorage();
- auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
- TypeRange(i8Ptr), storage);
+ auto storagePtr = rewriter.create<func::CallOp>(
+ loc, kGetValueStorage, TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getValue().getType();
return rewriter.notifyMatchFailure(
op, "failed to convert stored value type to LLVM type");
- auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
- loc, LLVM::LLVMPointerType::get(llvmValueType),
- storagePtr.getResult(0));
+ Value castedStoragePtr = storagePtr.getResult(0);
+ if (!getTypeConverter()->useOpaquePointers())
+ castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
+ loc, getTypeConverter()->getPointerType(llvmValueType),
+ castedStoragePtr);
// Store the yielded value into the async value storage.
auto value = adaptor.getValue();
- rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
+ rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
// Erase the original runtime store operation.
rewriter.eraseOp(op);
//===----------------------------------------------------------------------===//
namespace {
-class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
+class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
Location loc = op->getLoc();
// Get a pointer to the async value storage from the runtime.
- auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+ auto ptrType = AsyncAPI::opaquePointerType(
+ rewriter.getContext(), getTypeConverter()->useOpaquePointers());
auto storage = adaptor.getStorage();
- auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
- TypeRange(i8Ptr), storage);
+ auto storagePtr = rewriter.create<func::CallOp>(
+ loc, kGetValueStorage, TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getResult().getType();
return rewriter.notifyMatchFailure(
op, "failed to convert loaded value type to LLVM type");
- auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
- loc, LLVM::LLVMPointerType::get(llvmValueType),
- storagePtr.getResult(0));
+ Value castedStoragePtr = storagePtr.getResult(0);
+ if (!getTypeConverter()->useOpaquePointers())
+ castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
+ loc, getTypeConverter()->getPointerType(llvmValueType),
+ castedStoragePtr);
// Load from the casted pointer.
- rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
+ castedStoragePtr);
return success();
}
ModuleOp module = getOperation();
MLIRContext *ctx = module->getContext();
+ LowerToLLVMOptions options(ctx);
+ options.useOpaquePointers = useOpaquePointers;
+
// Add declarations for most functions required by the coroutines lowering.
// We delay adding the resume function until it's needed because it currently
// fails to compile unless '-O0' is specified.
- addAsyncRuntimeApiDeclarations(module);
+ addAsyncRuntimeApiDeclarations(module, useOpaquePointers);
// Lower async.runtime and async.coro operations to Async Runtime API and
// LLVM coroutine intrinsics.
// Convert async dialect types and operations to LLVM dialect.
- AsyncRuntimeTypeConverter converter;
+ AsyncRuntimeTypeConverter converter(options);
RewritePatternSet patterns(ctx);
// We use conversion to LLVM type to lower async.runtime load and store
// operations.
- LLVMTypeConverter llvmConverter(ctx);
- llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
+ LLVMTypeConverter llvmConverter(ctx, options);
+ llvmConverter.addConversion([&](Type type) {
+ return AsyncRuntimeTypeConverter::convertAsyncTypes(
+ type, llvmConverter.useOpaquePointers());
+ });
// Convert async types in function signatures and function calls.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
// Lower async.runtime operations that rely on LLVM type converter to convert
// from async value payload type to the LLVM type.
patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
- RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
- ctx);
+ RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
// Lower async coroutine operations to LLVM coroutine intrinsics.
patterns