[mlir][Async] Add option to LLVM lowering to use opaque pointers
authorMarkus Böck <markus.boeck02@gmail.com>
Fri, 10 Feb 2023 11:35:49 +0000 (12:35 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Fri, 10 Feb 2023 16:39:36 +0000 (17:39 +0100)
Part of https://discourse.llvm.org/t/rfc-switching-the-llvm-dialect-and-dialect-lowerings-to-opaque-pointers/68179

This patch adds the pass option 'use-opaque-pointers' to allow the dialect conversion from async to LLVM to create LLVM opaque pointers instead of typed pointers.
The gist of the changes boil down to having to propagate the choice of whether opaque or typed pointers should be used, to various helper functions that then either create typed pointers or opaque pointers.
This sadly creates a bit of a code duplication in comparison to other patches in this series, which I think is mostly unavoidable however, since a lot of the patterns in this lowering require the use of the AsyncTypeConverter, instead of the LLVMTypeConverter.

Besides that, the tests have been converter to opaque pointers with one file with typed pointer support having been created as regression tests.

Differential Revision: https://reviews.llvm.org/D143661

mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir [new file with mode: 0644]

index 0a7ee8a..d01e05c 100644 (file)
@@ -157,6 +157,11 @@ def ConvertAsyncToLLVM : Pass<"convert-async-to-llvm", "ModuleOp"> {
     "LLVM::LLVMDialect",
     "func::FuncDialect",
   ];
+  let options = [
+    Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+           /*default=*/"false", "Generate LLVM IR using opaque pointers "
+           "instead of typed pointers">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
index 450dbf0..f1acd84 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Async/IR/Async.h"
@@ -70,11 +71,14 @@ namespace {
 /// Async Runtime API function types.
 ///
 /// Because we can't create API function signature for type parametrized
-/// async.getValue type, we use opaque pointers (!llvm.ptr<i8>) instead. After
+/// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
 /// lowering all async data types become opaque pointers at runtime.
 struct AsyncAPI {
-  // All async types are lowered to opaque i8* LLVM pointers at runtime.
-  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
+  // All async types are lowered to opaque LLVM pointers at runtime.
+  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx,
+                                                 bool useLLVMOpaquePointers) {
+    if (useLLVMOpaquePointers)
+      return LLVM::LLVMPointerType::get(ctx);
     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
   }
 
@@ -82,8 +86,9 @@ struct AsyncAPI {
     return LLVM::LLVMTokenType::get(ctx);
   }
 
-  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
-    auto ref = opaquePointerType(ctx);
+  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx,
+                                               bool useLLVMOpaquePointers) {
+    auto ref = opaquePointerType(ctx, useLLVMOpaquePointers);
     auto count = IntegerType::get(ctx, 64);
     return FunctionType::get(ctx, {ref, count}, {});
   }
@@ -92,9 +97,10 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
   }
 
-  static FunctionType createValueFunctionType(MLIRContext *ctx) {
+  static FunctionType createValueFunctionType(MLIRContext *ctx,
+                                              bool useLLVMOpaquePointers) {
     auto i64 = IntegerType::get(ctx, 64);
-    auto value = opaquePointerType(ctx);
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {i64}, {value});
   }
 
@@ -103,9 +109,10 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
   }
 
-  static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
-    auto storage = opaquePointerType(ctx);
+  static FunctionType getValueStorageFunctionType(MLIRContext *ctx,
+                                                  bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+    auto storage = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {value}, {storage});
   }
 
@@ -113,8 +120,9 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
-  static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
+  static FunctionType emplaceValueFunctionType(MLIRContext *ctx,
+                                               bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {value}, {});
   }
 
@@ -122,8 +130,9 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
-  static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
+  static FunctionType setValueErrorFunctionType(MLIRContext *ctx,
+                                                bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {value}, {});
   }
 
@@ -132,8 +141,9 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
   }
 
-  static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
+  static FunctionType isValueErrorFunctionType(MLIRContext *ctx,
+                                               bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     auto i1 = IntegerType::get(ctx, 1);
     return FunctionType::get(ctx, {value}, {i1});
   }
@@ -147,8 +157,9 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
-  static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
+  static FunctionType awaitValueFunctionType(MLIRContext *ctx,
+                                             bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {value}, {});
   }
 
@@ -156,9 +167,15 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
   }
 
-  static FunctionType executeFunctionType(MLIRContext *ctx) {
-    auto hdl = opaquePointerType(ctx);
-    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+  static FunctionType executeFunctionType(MLIRContext *ctx,
+                                          bool useLLVMOpaquePointers) {
+    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+    Type resume;
+    if (useLLVMOpaquePointers)
+      resume = LLVM::LLVMPointerType::get(ctx);
+    else
+      resume = LLVM::LLVMPointerType::get(
+          resumeFunctionType(ctx, useLLVMOpaquePointers));
     return FunctionType::get(ctx, {hdl, resume}, {});
   }
 
@@ -168,22 +185,42 @@ struct AsyncAPI {
                              {i64});
   }
 
-  static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
-    auto hdl = opaquePointerType(ctx);
-    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+  static FunctionType
+  awaitTokenAndExecuteFunctionType(MLIRContext *ctx,
+                                   bool useLLVMOpaquePointers) {
+    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+    Type resume;
+    if (useLLVMOpaquePointers)
+      resume = LLVM::LLVMPointerType::get(ctx);
+    else
+      resume = LLVM::LLVMPointerType::get(
+          resumeFunctionType(ctx, useLLVMOpaquePointers));
     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
   }
 
-  static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
-    auto hdl = opaquePointerType(ctx);
-    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+  static FunctionType
+  awaitValueAndExecuteFunctionType(MLIRContext *ctx,
+                                   bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+    Type resume;
+    if (useLLVMOpaquePointers)
+      resume = LLVM::LLVMPointerType::get(ctx);
+    else
+      resume = LLVM::LLVMPointerType::get(
+          resumeFunctionType(ctx, useLLVMOpaquePointers));
     return FunctionType::get(ctx, {value, hdl, resume}, {});
   }
 
-  static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
-    auto hdl = opaquePointerType(ctx);
-    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+  static FunctionType
+  awaitAllAndExecuteFunctionType(MLIRContext *ctx, bool useLLVMOpaquePointers) {
+    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+    Type resume;
+    if (useLLVMOpaquePointers)
+      resume = LLVM::LLVMPointerType::get(ctx);
+    else
+      resume = LLVM::LLVMPointerType::get(
+          resumeFunctionType(ctx, useLLVMOpaquePointers));
     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
   }
 
@@ -192,16 +229,17 @@ struct AsyncAPI {
   }
 
   // Auxiliary coroutine resume intrinsic wrapper.
-  static Type resumeFunctionType(MLIRContext *ctx) {
+  static Type resumeFunctionType(MLIRContext *ctx, bool useLLVMOpaquePointers) {
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
-    auto i8Ptr = opaquePointerType(ctx);
-    return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
+    auto ptrType = opaquePointerType(ctx, useLLVMOpaquePointers);
+    return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
   }
 };
 } // namespace
 
 /// Adds Async Runtime C API declarations to the module.
-static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
+static void addAsyncRuntimeApiDeclarations(ModuleOp module,
+                                           bool useLLVMOpaquePointers) {
   auto builder =
       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
 
@@ -212,30 +250,39 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   };
 
   MLIRContext *ctx = module.getContext();
-  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
-  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
+  addFuncDecl(kAddRef,
+              AsyncAPI::addOrDropRefFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kDropRef,
+              AsyncAPI::addOrDropRefFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
-  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
+  addFuncDecl(kCreateValue,
+              AsyncAPI::createValueFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
-  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
+  addFuncDecl(kEmplaceValue,
+              AsyncAPI::emplaceValueFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
-  addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
+  addFuncDecl(kSetValueError,
+              AsyncAPI::setValueErrorFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
-  addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
+  addFuncDecl(kIsValueError,
+              AsyncAPI::isValueErrorFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
-  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
+  addFuncDecl(kAwaitValue,
+              AsyncAPI::awaitValueFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
-  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
-  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
+  addFuncDecl(kExecute,
+              AsyncAPI::executeFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(
+                                    ctx, useLLVMOpaquePointers));
   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
-  addFuncDecl(kAwaitTokenAndExecute,
-              AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
-  addFuncDecl(kAwaitValueAndExecute,
-              AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
-  addFuncDecl(kAwaitAllAndExecute,
-              AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+  addFuncDecl(kAwaitTokenAndExecute, AsyncAPI::awaitTokenAndExecuteFunctionType(
+                                         ctx, useLLVMOpaquePointers));
+  addFuncDecl(kAwaitValueAndExecute, AsyncAPI::awaitValueAndExecuteFunctionType(
+                                         ctx, useLLVMOpaquePointers));
+  addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(
+                                       ctx, useLLVMOpaquePointers));
   addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
 }
 
@@ -248,7 +295,7 @@ static constexpr const char *kResume = "__resume";
 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
 /// intrinsics. We need this function to be able to pass it to the async
 /// runtime execute API.
-static void addResumeFunction(ModuleOp module) {
+static void addResumeFunction(ModuleOp module, bool useOpaquePointers) {
   if (module.lookupSymbol(kResume))
     return;
 
@@ -257,10 +304,14 @@ static void addResumeFunction(ModuleOp module) {
   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
 
   auto voidTy = LLVM::LLVMVoidType::get(ctx);
-  auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+  Type ptrType;
+  if (useOpaquePointers)
+    ptrType = LLVM::LLVMPointerType::get(ctx);
+  else
+    ptrType = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
 
   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
-      kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
+      kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
   resumeOp.setPrivate();
 
   auto *block = resumeOp.addEntryBlock();
@@ -278,10 +329,15 @@ namespace {
 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
 /// their runtime type (opaque pointers) and does not convert any other types.
 class AsyncRuntimeTypeConverter : public TypeConverter {
+  bool llvmOpaquePointers = false;
+
 public:
-  AsyncRuntimeTypeConverter() {
+  AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options)
+      : llvmOpaquePointers(options.useOpaquePointers) {
     addConversion([](Type type) { return type; });
-    addConversion(convertAsyncTypes);
+    addConversion([this](Type type) {
+      return convertAsyncTypes(type, llvmOpaquePointers);
+    });
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
@@ -295,18 +351,52 @@ public:
     addTargetMaterialization(addUnrealizedCast);
   }
 
-  static std::optional<Type> convertAsyncTypes(Type type) {
+  /// Returns whether LLVM opaque pointers should be used instead of typed
+  /// pointers.
+  bool useOpaquePointers() const { return llvmOpaquePointers; }
+
+  /// Creates an LLVM pointer type which may either be a typed pointer or an
+  /// opaque pointer, depending on what options the converter was constructed
+  /// with.
+  LLVM::LLVMPointerType getPointerType(Type elementType) {
+    if (llvmOpaquePointers)
+      return LLVM::LLVMPointerType::get(elementType.getContext());
+    return LLVM::LLVMPointerType::get(elementType);
+  }
+
+  static std::optional<Type> convertAsyncTypes(Type type,
+                                               bool useOpaquePointers) {
     if (type.isa<TokenType, GroupType, ValueType>())
-      return AsyncAPI::opaquePointerType(type.getContext());
+      return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
 
     if (type.isa<CoroIdType, CoroStateType>())
       return AsyncAPI::tokenType(type.getContext());
     if (type.isa<CoroHandleType>())
-      return AsyncAPI::opaquePointerType(type.getContext());
+      return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
 
     return std::nullopt;
   }
 };
+
+/// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
+/// as type converter. Allows access to it via the 'getTypeConverter'
+/// convenience method.
+template <typename SourceOp>
+class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
+
+  using Base = OpConversionPattern<SourceOp>;
+
+public:
+  AsyncOpConversionPattern(AsyncRuntimeTypeConverter &typeConverter,
+                           MLIRContext *context)
+      : Base(typeConverter, context) {}
+
+  /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
+  AsyncRuntimeTypeConverter *getTypeConverter() const {
+    return static_cast<AsyncRuntimeTypeConverter *>(Base::getTypeConverter());
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -314,21 +404,22 @@ public:
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
+class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto token = AsyncAPI::tokenType(op->getContext());
-    auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto loc = op->getLoc();
 
     // Constants for initializing coroutine frame.
     auto constZero =
         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
-    auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
+    auto nullPtr = rewriter.create<LLVM::NullOp>(loc, ptrType);
 
     // Get coroutine id: @llvm.coro.id.
     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
@@ -344,14 +435,15 @@ public:
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
+class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto loc = op->getLoc();
 
     // Get coroutine frame size: @llvm.coro.size.i64.
@@ -379,14 +471,14 @@ public:
     // Allocate memory for the coroutine frame.
     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
         op->getParentOfType<ModuleOp>(), rewriter.getI64Type(),
-        /*TODO: opaquePointers=*/false);
+        getTypeConverter()->useOpaquePointers());
     auto coroAlloc = rewriter.create<LLVM::CallOp>(
         loc, allocFuncOp, ValueRange{coroAlign, coroSize});
 
     // Begin a coroutine: @llvm.coro.begin.
     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
-        op, i8Ptr, ValueRange({coroId, coroAlloc.getResult()}));
+        op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
 
     return success();
   }
@@ -398,23 +490,25 @@ public:
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
+class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto loc = op->getLoc();
 
     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
     auto coroMem =
-        rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
+        rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
 
     // Free the memory.
-    auto freeFuncOp = LLVM::lookupOrCreateFreeFn(
-        op->getParentOfType<ModuleOp>(), /*TODO: opaquePointers=*/false);
+    auto freeFuncOp =
+        LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>(),
+                                   getTypeConverter()->useOpaquePointers());
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
                                               ValueRange(coroMem.getResult()));
 
@@ -551,9 +645,9 @@ public:
 //===----------------------------------------------------------------------===//
 
 namespace {
-class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
+class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
@@ -576,13 +670,14 @@ public:
         auto i64 = rewriter.getI64Type();
 
         auto storedType = converter->convertType(valueType.getValueType());
-        auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
+        auto storagePtrType = getTypeConverter()->getPointerType(storedType);
 
         // %Size = getelementptr %T* null, int 1
         // %SizeI = ptrtoint %T* %Size to i64
         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
-        auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
-                                                ArrayRef<LLVM::GEPArg>{1});
+        auto gep =
+            rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
+                                         nullPtr, ArrayRef<LLVM::GEPArg>{1});
         return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
       };
 
@@ -603,9 +698,9 @@ public:
 
 namespace {
 class RuntimeCreateGroupOpLowering
-    : public OpConversionPattern<RuntimeCreateGroupOp> {
+    : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
@@ -731,9 +826,9 @@ public:
 
 namespace {
 class RuntimeAwaitAndResumeOpLowering
-    : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
+    : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
@@ -748,10 +843,12 @@ public:
     Value handle = adaptor.getHandle();
 
     // A pointer to coroutine resume intrinsic wrapper.
-    addResumeFunction(op->getParentOfType<ModuleOp>());
-    auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
+    addResumeFunction(op->getParentOfType<ModuleOp>(),
+                      getTypeConverter()->useOpaquePointers());
+    auto resumeFnTy = AsyncAPI::resumeFunctionType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
-        op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+        op->getLoc(), getTypeConverter()->getPointerType(resumeFnTy), kResume);
 
     rewriter.create<func::CallOp>(
         op->getLoc(), apiFuncName, TypeRange(),
@@ -768,18 +865,21 @@ public:
 //===----------------------------------------------------------------------===//
 
 namespace {
-class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
+class RuntimeResumeOpLowering
+    : public AsyncOpConversionPattern<RuntimeResumeOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // A pointer to coroutine resume intrinsic wrapper.
-    addResumeFunction(op->getParentOfType<ModuleOp>());
-    auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
+    addResumeFunction(op->getParentOfType<ModuleOp>(),
+                      getTypeConverter()->useOpaquePointers());
+    auto resumeFnTy = AsyncAPI::resumeFunctionType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
-        op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+        op->getLoc(), getTypeConverter()->getPointerType(resumeFnTy), kResume);
 
     // Call async runtime API to execute a coroutine in the managed thread.
     auto coroHdl = adaptor.getHandle();
@@ -796,9 +896,9 @@ public:
 //===----------------------------------------------------------------------===//
 
 namespace {
-class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
+class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
@@ -806,10 +906,11 @@ public:
     Location loc = op->getLoc();
 
     // Get a pointer to the async value storage from the runtime.
-    auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        rewriter.getContext(), getTypeConverter()->useOpaquePointers());
     auto storage = adaptor.getStorage();
-    auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
-                                                    TypeRange(i8Ptr), storage);
+    auto storagePtr = rewriter.create<func::CallOp>(
+        loc, kGetValueStorage, TypeRange(ptrType), storage);
 
     // Cast from i8* to the LLVM pointer type.
     auto valueType = op.getValue().getType();
@@ -818,13 +919,15 @@ public:
       return rewriter.notifyMatchFailure(
           op, "failed to convert stored value type to LLVM type");
 
-    auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
-        loc, LLVM::LLVMPointerType::get(llvmValueType),
-        storagePtr.getResult(0));
+    Value castedStoragePtr = storagePtr.getResult(0);
+    if (!getTypeConverter()->useOpaquePointers())
+      castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
+          loc, getTypeConverter()->getPointerType(llvmValueType),
+          castedStoragePtr);
 
     // Store the yielded value into the async value storage.
     auto value = adaptor.getValue();
-    rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
+    rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
 
     // Erase the original runtime store operation.
     rewriter.eraseOp(op);
@@ -839,9 +942,9 @@ public:
 //===----------------------------------------------------------------------===//
 
 namespace {
-class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
+class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
@@ -849,10 +952,11 @@ public:
     Location loc = op->getLoc();
 
     // Get a pointer to the async value storage from the runtime.
-    auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        rewriter.getContext(), getTypeConverter()->useOpaquePointers());
     auto storage = adaptor.getStorage();
-    auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
-                                                    TypeRange(i8Ptr), storage);
+    auto storagePtr = rewriter.create<func::CallOp>(
+        loc, kGetValueStorage, TypeRange(ptrType), storage);
 
     // Cast from i8* to the LLVM pointer type.
     auto valueType = op.getResult().getType();
@@ -861,12 +965,15 @@ public:
       return rewriter.notifyMatchFailure(
           op, "failed to convert loaded value type to LLVM type");
 
-    auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
-        loc, LLVM::LLVMPointerType::get(llvmValueType),
-        storagePtr.getResult(0));
+    Value castedStoragePtr = storagePtr.getResult(0);
+    if (!getTypeConverter()->useOpaquePointers())
+      castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
+          loc, getTypeConverter()->getPointerType(llvmValueType),
+          castedStoragePtr);
 
     // Load from the casted pointer.
-    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
+    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
+                                              castedStoragePtr);
 
     return success();
   }
@@ -1000,22 +1107,28 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   ModuleOp module = getOperation();
   MLIRContext *ctx = module->getContext();
 
+  LowerToLLVMOptions options(ctx);
+  options.useOpaquePointers = useOpaquePointers;
+
   // Add declarations for most functions required by the coroutines lowering.
   // We delay adding the resume function until it's needed because it currently
   // fails to compile unless '-O0' is specified.
-  addAsyncRuntimeApiDeclarations(module);
+  addAsyncRuntimeApiDeclarations(module, useOpaquePointers);
 
   // Lower async.runtime and async.coro operations to Async Runtime API and
   // LLVM coroutine intrinsics.
 
   // Convert async dialect types and operations to LLVM dialect.
-  AsyncRuntimeTypeConverter converter;
+  AsyncRuntimeTypeConverter converter(options);
   RewritePatternSet patterns(ctx);
 
   // We use conversion to LLVM type to lower async.runtime load and store
   // operations.
-  LLVMTypeConverter llvmConverter(ctx);
-  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
+  LLVMTypeConverter llvmConverter(ctx, options);
+  llvmConverter.addConversion([&](Type type) {
+    return AsyncRuntimeTypeConverter::convertAsyncTypes(
+        type, llvmConverter.useOpaquePointers());
+  });
 
   // Convert async types in function signatures and function calls.
   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
@@ -1036,8 +1149,7 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   // Lower async.runtime operations that rely on LLVM type converter to convert
   // from async value payload type to the LLVM type.
   patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
-               RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
-                                                              ctx);
+               RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
 
   // Lower async coroutine operations to LLVM coroutine intrinsics.
   patterns
index 404be5c..ad5595a 100644 (file)
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s
 
 // CHECK-LABEL: @coro_id
 func.func @coro_id() {
   // CHECK: %0 = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: %1 = llvm.mlir.null : !llvm.ptr<i8>
+  // CHECK: %1 = llvm.mlir.null : !llvm.ptr
   // CHECK: %2 = llvm.intr.coro.id %0, %1, %1, %1 : !llvm.token
   %0 = async.coro.id
   return
index 5929ab0..7ff5a2c 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: @create_token
 func.func @create_token() {
@@ -9,7 +9,7 @@ func.func @create_token() {
 
 // CHECK-LABEL: @create_value
 func.func @create_value() {
-  // CHECK: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr<f32>
+  // CHECK: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr
   // CHECK: %[[OFFSET:.*]] = llvm.getelementptr %[[NULL]][1]
   // CHECK: %[[SIZE:.*]] = llvm.ptrtoint %[[OFFSET]]
   // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue(%[[SIZE]])
@@ -152,8 +152,7 @@ func.func @store() {
   // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
   %1 = async.runtime.create : !async.value<f32>
   // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
-  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-  // CHECK: llvm.store %[[CST]], %[[P1]]
+  // CHECK: llvm.store %[[CST]], %[[P0]] : f32, !llvm.ptr
   async.runtime.store %0, %1 : !async.value<f32>
   return
 }
@@ -163,8 +162,7 @@ func.func @load() -> f32 {
   // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
   %0 = async.runtime.create : !async.value<f32>
   // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
-  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-  // CHECK: %[[VALUE:.*]] = llvm.load %[[P1]]
+  // CHECK: %[[VALUE:.*]] = llvm.load %[[P0]] : !llvm.ptr -> f32
   %1 = async.runtime.load %0 : !async.value<f32>
   // CHECK: return %[[VALUE]] : f32
   return %1 : f32
index 05bffc6..fd419dc 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s
 
 // CHECK-LABEL: reference_counting
 func.func @reference_counting(%arg0: !async.token) {
@@ -35,7 +35,7 @@ func.func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
 
 // Function outlined from the async.execute operation.
 // CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-SAME: -> !llvm.ptr
 
 // Create token for return op, and mark a function as a coroutine.
 // CHECK: %[[RET:.*]] = call @mlirAsyncRuntimeCreateToken()
@@ -97,7 +97,7 @@ func.func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
 
 // Function outlined from the inner async.execute operation.
 // CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-SAME: -> !llvm.ptr
 // CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
 // CHECK: call @mlirAsyncRuntimeExecute
@@ -108,7 +108,7 @@ func.func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
 
 // Function outlined from the outer async.execute operation.
 // CHECK-LABEL: func private @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-SAME: -> !llvm.ptr
 // CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_1:.*]] = llvm.intr.coro.begin
 
@@ -147,7 +147,7 @@ func.func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
 
 // Function outlined from the first async.execute operation.
 // CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-SAME: -> !llvm.ptr
 // CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
 // CHECK: call @mlirAsyncRuntimeExecute
@@ -156,8 +156,8 @@ func.func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
 // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
 
 // Function outlined from the second async.execute operation with dependency.
-// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr<i8>, %arg1: f32, %arg2: memref<1xf32>)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr, %arg1: f32, %arg2: memref<1xf32>)
+// CHECK-SAME: -> !llvm.ptr
 // CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_1:.*]] = llvm.intr.coro.begin
 
@@ -200,7 +200,7 @@ func.func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
 }
 
 // Function outlined from the async.execute operation.
-// CHECK: func private @async_execute_fn_0(%arg0: !llvm.ptr<i8>)
+// CHECK: func private @async_execute_fn_0(%arg0: !llvm.ptr)
 // CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_1:.*]] = llvm.intr.coro.begin
 
@@ -227,8 +227,7 @@ func.func @execute_and_return_f32() -> f32 {
   }
 
   // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[RET]]#1)
-  // CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
-  // CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] :  !llvm.ptr<f32>
+  // CHECK: %[[LOADED:.*]] = llvm.load %[[STORAGE]] : !llvm.ptr -> f32
   %0 = async.await %result : !async.value<f32>
 
   return %0 : f32
@@ -247,8 +246,7 @@ func.func @execute_and_return_f32() -> f32 {
 // Emplace result value.
 // CHECK: %[[CST:.*]] = arith.constant 1.230000e+02 : f32
 // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
-// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
-// CHECK: llvm.store %[[CST]], %[[ST_F32]] : !llvm.ptr<f32>
+// CHECK: llvm.store %[[CST]], %[[STORAGE]] : f32, !llvm.ptr
 // CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]])
 
 // Emplace result token.
@@ -280,7 +278,7 @@ func.func @async_value_operands() {
 // CHECK-LABEL: func private @async_execute_fn()
 
 // Function outlined from the second async.execute operation.
-// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr<i8>)
+// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr)
 // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
 
@@ -295,8 +293,7 @@ func.func @async_value_operands() {
 
 // Get the operand value storage, cast to f32 and add the value.
 // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%arg0)
-// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
-// CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] :  !llvm.ptr<f32>
+// CHECK: %[[LOADED:.*]] = llvm.load %[[STORAGE]] : !llvm.ptr -> f32
 // CHECK: arith.addf %[[LOADED]], %[[LOADED]] : f32
 
 // Emplace result token.
diff --git a/mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir b/mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir
new file mode 100644 (file)
index 0000000..07cd2ad
--- /dev/null
@@ -0,0 +1,138 @@
+// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm='use-opaque-pointers=0' | FileCheck %s
+
+
+
+// CHECK-LABEL: @store
+func.func @store() {
+  // CHECK: %[[CST:.*]] = arith.constant 1.0
+  %0 = arith.constant 1.0 : f32
+  // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+  %1 = async.runtime.create : !async.value<f32>
+  // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
+  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
+  // CHECK: llvm.store %[[CST]], %[[P1]]
+  async.runtime.store %0, %1 : !async.value<f32>
+  return
+}
+
+// CHECK-LABEL: @load
+func.func @load() -> f32 {
+  // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+  %0 = async.runtime.create : !async.value<f32>
+  // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
+  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
+  // CHECK: %[[VALUE:.*]] = llvm.load %[[P1]]
+  %1 = async.runtime.load %0 : !async.value<f32>
+  // CHECK: return %[[VALUE]] : f32
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: execute_no_async_args
+func.func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
+  // CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1)
+  %token = async.execute {
+    %c0 = arith.constant 0 : index
+    memref.store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
+  // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
+  // CHECK: %[[TRUE:.*]] = arith.constant true
+  // CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
+  // CHECK: cf.assert %[[NOT_ERROR]]
+  // CHECK-NEXT: return
+  async.await %token : !async.token
+  return
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
+// CHECK-SAME: -> !llvm.ptr<i8>
+
+// Create token for return op, and mark a function as a coroutine.
+// CHECK: %[[RET:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
+
+// Pass a suspended coroutine to the async runtime.
+// CHECK: %[[STATE:.*]] = llvm.intr.coro.save
+// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]], %[[RESUME]])
+// CHECK: %[[SUSPENDED:.*]] = llvm.intr.coro.suspend %[[STATE]]
+
+// Decide the next block based on the code returned from suspend.
+// CHECK: %[[SEXT:.*]] = llvm.sext %[[SUSPENDED]] : i8 to i32
+// CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
+// CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
+// CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
+
+// Resume coroutine after suspension.
+// CHECK: ^[[RESUME]]:
+// CHECK: memref.store %arg0, %arg1[%c0] : memref<1xf32>
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET]])
+
+// Delete coroutine.
+// CHECK: ^[[CLEANUP]]:
+// CHECK: %[[MEM:.*]] = llvm.intr.coro.free
+// CHECK: llvm.call @free(%[[MEM]])
+
+// Suspend coroutine, and also a return statement for ramp function.
+// CHECK: ^[[SUSPEND]]:
+// CHECK: llvm.intr.coro.end
+// CHECK: return %[[RET]]
+
+// -----
+
+// CHECK-LABEL: execute_and_return_f32
+func.func @execute_and_return_f32() -> f32 {
+ // CHECK: %[[RET:.*]]:2 = call @async_execute_fn
+  %token, %result = async.execute -> !async.value<f32> {
+    %c0 = arith.constant 123.0 : f32
+    async.yield %c0 : f32
+  }
+
+  // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[RET]]#1)
+  // CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
+  // CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] :  !llvm.ptr<f32>
+  %0 = async.await %result : !async.value<f32>
+
+  return %0 : f32
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn()
+// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
+
+// Suspend coroutine in the beginning.
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]],
+// CHECK: llvm.intr.coro.suspend
+
+// Emplace result value.
+// CHECK: %[[CST:.*]] = arith.constant 1.230000e+02 : f32
+// CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
+// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
+// CHECK: llvm.store %[[CST]], %[[ST_F32]] : !llvm.ptr<f32>
+// CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]])
+
+// Emplace result token.
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]])
+
+// -----
+
+// CHECK-LABEL: @await_and_resume_group
+func.func @await_and_resume_group() {
+  %c = arith.constant 1 : index
+  %0 = async.coro.id
+  // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
+  %1 = async.coro.begin %0
+  // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup
+  %2 = async.runtime.create_group %c : !async.group
+  // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
+  // CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute
+  // CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]])
+  async.runtime.await_and_resume %2, %1 : !async.group
+  return
+}