using Base::Base;
};
+// -------------------------------------------------------------------------- //
+// Helper functions of Async dialect transformations.
+// -------------------------------------------------------------------------- //
+
+/// Returns true if the type is reference counted. All async dialect types are
+/// reference counted at runtime.
+inline bool isRefCounted(Type type) {
+ return type.isa<TokenType, ValueType, GroupType>();
+}
+
} // namespace async
} // namespace mlir
def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType,
Async_TokenType]>;
+def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType,
+ Async_TokenType,
+ Async_GroupType]>;
+
#endif // ASYNC_BASE_TD
let assemblyFormat = "$operand attr-dict";
}
+//===----------------------------------------------------------------------===//
+// Async Dialect Automatic Reference Counting Operations.
+//===----------------------------------------------------------------------===//
+
+// All async values (values, tokens, groups) are reference counted at runtime
+// and automatically destructed when reference count drops to 0.
+//
+// All values are semantically created with a reference count of +1 and it is
+// the responsibility of the last async value user to drop reference count.
+//
+// Async values created when:
+// 1. Operation returns async result (e.g. the result of an `async.execute`).
+// 2. Async value passed in as a block argument.
+//
+// It is the responsiblity of the async value user to extend the lifetime by
+// adding a +1 reference, if the reference counted value captured by the
+// asynchronously executed region (`async.execute` operation), and drop it after
+// the last nested use.
+//
+// Reference counting operations can be added to the IR using automatic
+// reference count pass, that relies on liveness analysis to find the last uses
+// of all reference counted values and automatically inserts
+// `drop_ref` operations.
+//
+// See `AsyncRefCountingPass` documentation for the implementation details.
+
+def Async_AddRefOp : Async_Op<"add_ref"> {
+ let summary = "adds a reference to async value";
+ let description = [{
+ The `async.add_ref` operation adds a reference(s) to async value (token,
+ value or group).
+ }];
+
+ let arguments = (ins Async_AnyAsyncType:$operand,
+ Confined<I32Attr, [IntPositive]>:$count);
+ let results = (outs );
+
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand)
+ }];
+}
+
+def Async_DropRefOp : Async_Op<"drop_ref"> {
+ let summary = "drops a reference to async value";
+ let description = [{
+ The `async.drop_ref` operation drops a reference(s) to async value (token,
+ value or group).
+ }];
+
+ let arguments = (ins Async_AnyAsyncType:$operand,
+ Confined<I32Attr, [IntPositive]>:$count);
+ let results = (outs );
+
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand)
+ }];
+}
+
#endif // ASYNC_OPS
std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
+
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
}
+def AsyncRefCounting : FunctionPass<"async-ref-counting"> {
+ let summary = "Automatic reference counting for Async dialect data types";
+ let constructor = "mlir::createAsyncRefCountingPass()";
+ let dependentDialects = ["async::AsyncDialect"];
+}
+
+def AsyncRefCountingOptimization :
+ FunctionPass<"async-ref-counting-optimization"> {
+ let summary = "Optimize automatic reference counting operations for the"
+ "Async dialect by removing redundant operations";
+ let constructor = "mlir::createAsyncRefCountingOptimizationPass()";
+ let dependentDialects = ["async::AsyncDialect"];
+}
+
#endif // MLIR_DIALECT_ASYNC_PASSES
using CoroHandle = void *; // coroutine handle
using CoroResume = void (*)(void *); // coroutine resume function
+// Async runtime uses reference counting to manage the lifetime of async values
+// (values of async types like tokens, values and groups).
+using RefCountedObjPtr = void *;
+
+// Adds references to reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+ mlirAsyncRuntimeAddRef(RefCountedObjPtr, int32_t);
+
+// Drops references from reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+ mlirAsyncRuntimeDropRef(RefCountedObjPtr, int32_t);
+
// Create a new `async.token` in not-ready state.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
// RUN: mlir-opt %s -async-parallel-for \
+// RUN: -async-ref-counting \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \
// RUN: mlir-opt %s -async-parallel-for \
+// RUN: -async-ref-counting \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \
// Async Runtime C API declaration.
//===----------------------------------------------------------------------===//
+static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
+static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
namespace {
// Async Runtime API function types.
struct AsyncAPI {
+ static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
+ auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
+ auto count = IntegerType::get(32, ctx);
+ return FunctionType::get({ref, count}, {}, ctx);
+ }
+
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
}
};
MLIRContext *ctx = module.getContext();
+ addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
+ addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
- addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+ addFuncDecl(kAwaitAllAndExecute,
+ AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
}
//===----------------------------------------------------------------------===//
} // namespace
//===----------------------------------------------------------------------===//
+// Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
+// to the corresponding API calls).
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+template <typename RefCountingOp>
+class RefCountingOpLowering : public ConversionPattern {
+public:
+ explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
+ : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
+ apiFunctionName(apiFunctionName) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ RefCountingOp refCountingOp = cast<RefCountingOp>(op);
+
+ auto count = rewriter.create<ConstantOp>(
+ op->getLoc(), rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(refCountingOp.count()));
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
+ ValueRange({operands[0], count}));
+
+ return success();
+ }
+
+private:
+ StringRef apiFunctionName;
+};
+
+// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
+class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
+public:
+ explicit AddRefOpLowering(MLIRContext *ctx)
+ : RefCountingOpLowering(ctx, kAddRef) {}
+};
+
+// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
+class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
+public:
+ explicit DropRefOpLowering(MLIRContext *ctx)
+ : RefCountingOpLowering(ctx, kDropRef) {}
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
//===----------------------------------------------------------------------===//
populateFuncOpTypeConversionPattern(patterns, ctx, converter);
patterns.insert<CallOpOpConversion>(ctx);
+ patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
ConversionTarget target(*ctx);
+ target.addLegalOp<ConstantOp>();
target.addLegalDialect<LLVM::LLVMDialect>();
target.addIllegalDialect<AsyncDialect>();
target.addDynamicallyLegalOp<FuncOp>(
--- /dev/null
+//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements automatic reference counting for Async dialect data
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
+public:
+ AsyncRefCountingPass() = default;
+ void runOnFunction() override;
+
+private:
+ /// Adds an automatic reference counting to the `value`.
+ ///
+ /// All values are semantically created with a reference count of +1 and it is
+ /// the responsibility of the last async value user to drop reference count.
+ ///
+ /// Async values created when:
+ /// 1. Operation returns async result (e.g. the result of an
+ /// `async.execute`).
+ /// 2. Async value passed in as a block argument.
+ ///
+ /// To implement automatic reference counting, we must insert a +1 reference
+ /// before each `async.execute` operation using the value, and drop it after
+ /// the last use inside the async body region (we currently drop the reference
+ /// before the `async.yield` terminator).
+ ///
+ /// Automatic reference counting algorithm outline:
+ ///
+ /// 1. `ReturnLike` operations forward the reference counted values without
+ /// modifying the reference count.
+ ///
+ /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
+ /// reference counted values ends, and insert `drop_ref` operations after
+ /// the last use of the value.
+ ///
+ /// 3. Insert `add_ref` before the `async.execute` operation capturing the
+ /// value, and pairing `drop_ref` before the async body region terminator,
+ /// to release the captured reference counted value when execution
+ /// completes.
+ ///
+ /// 4. If the reference counted value is passed only to some of the block
+ /// successors, insert `drop_ref` operations in the beginning of the blocks
+ /// that do not have reference counted value uses.
+ ///
+ ///
+ /// Example:
+ ///
+ /// %token = ...
+ /// async.execute {
+ /// async.await %token : !async.token // await #1
+ /// async.yield
+ /// }
+ /// async.await %token : !async.token // await #2
+ ///
+ /// Based on the liveness analysis await #2 is the last use of the %token,
+ /// however the execution of the async region can be delayed, and to guarantee
+ /// that the %token is still alive when await #1 executes we need to
+ /// explicitly extend its lifetime using `add_ref` operation.
+ ///
+ /// After automatic reference counting:
+ ///
+ /// %token = ...
+ ///
+ /// // Make sure that %token is alive inside async.execute.
+ /// async.add_ref %token {count = 1 : i32} : !async.token
+ ///
+ /// async.execute {
+ /// async.await %token : !async.token // await #1
+ ///
+ /// // Drop the extra reference added to keep %token alive.
+ /// async.drop_ref %token {count = 1 : i32} : !async.token
+ ///
+ /// async.yied
+ /// }
+ /// async.await %token : !async.token // await #2
+ ///
+ /// // Drop the reference after the last use of %token.
+ /// async.drop_ref %token {count = 1 : i32} : !async.token
+ ///
+ LogicalResult addAutomaticRefCounting(Value value);
+};
+
+} // namespace
+
+LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
+ MLIRContext *ctx = value.getContext();
+ OpBuilder builder(ctx);
+
+ // Set inserton point after the operation producing a value, or at the
+ // beginning of the block if the value defined by the block argument.
+ if (Operation *op = value.getDefiningOp())
+ builder.setInsertionPointAfter(op);
+ else
+ builder.setInsertionPointToStart(value.getParentBlock());
+
+ Location loc = value.getLoc();
+ auto i32 = IntegerType::get(32, ctx);
+
+ // Drop the reference count immediately if the value has no uses.
+ if (value.getUses().empty()) {
+ builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+ return success();
+ }
+
+ // Use liveness analysis to find the placement of `drop_ref`operation.
+ auto liveness = getAnalysis<Liveness>();
+
+ // We analyse only the blocks of the region that defines the `value`, and do
+ // not check nested blocks attached to operations.
+ //
+ // By analyzing only the `definingRegion` CFG we potentially loose an
+ // opportunity to drop the reference count earlier and can extend the lifetime
+ // of reference counted value longer then it is really required.
+ //
+ // We also assume that all nested regions finish their execution before the
+ // completion of the owner operation. The only exception to this rule is
+ // `async.execute` operation, which is handled explicitly below.
+ Region *definingRegion = value.getParentRegion();
+
+ // ------------------------------------------------------------------------ //
+ // Find blocks where the `value` dies: the value is in `liveIn` set and not
+ // in the `liveOut` set. We place `drop_ref` immediately after the last use
+ // of the `value` in such regions.
+ // ------------------------------------------------------------------------ //
+
+ // Last users of the `value` inside all blocks where the value dies.
+ llvm::SmallSet<Operation *, 4> lastUsers;
+
+ for (Block &block : definingRegion->getBlocks()) {
+ const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+ // Value in live input set or was defined in the block.
+ bool liveIn = blockLiveness->isLiveIn(value) ||
+ blockLiveness->getBlock() == value.getParentBlock();
+ if (!liveIn)
+ continue;
+
+ // Value is in the live out set.
+ bool liveOut = blockLiveness->isLiveOut(value);
+ if (liveOut)
+ continue;
+
+ // We proved that `value` dies in the `block`. Now find the last use of the
+ // `value` inside the `block`.
+
+ // Find any user of the `value` inside the block (including uses in nested
+ // regions attached to the operations in the block).
+ Operation *userInTheBlock = nullptr;
+ for (Operation *user : value.getUsers()) {
+ userInTheBlock = block.findAncestorOpInBlock(*user);
+ if (userInTheBlock)
+ break;
+ }
+
+ // Values with zero users handled explicitly in the beginning, if the value
+ // is in live out set it must have at least one use in the block.
+ assert(userInTheBlock && "value must have a user in the block");
+
+ // Find the last user of the `value` in the block;
+ Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
+ assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
+ lastUsers.insert(lastUser);
+ }
+
+ // Process all the last users of the `value` inside each block where the value
+ // dies.
+ for (Operation *lastUser : lastUsers) {
+ // Return like operations forward reference count.
+ if (lastUser->hasTrait<OpTrait::ReturnLike>())
+ continue;
+
+ // We can't currently handle other types of terminators.
+ if (lastUser->hasTrait<OpTrait::IsTerminator>())
+ return lastUser->emitError() << "async reference counting can't handle "
+ "terminators that are not ReturnLike";
+
+ // Add a drop_ref immediately after the last user.
+ builder.setInsertionPointAfter(lastUser);
+ builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+ }
+
+ // ------------------------------------------------------------------------ //
+ // Find blocks where the `value` is in `liveOut` set, however it is not in
+ // the `liveIn` set of all successors. If the `value` is not in the successor
+ // `liveIn` set, we add a `drop_ref` to the beginning of it.
+ // ------------------------------------------------------------------------ //
+
+ // Successors that we'll need a `drop_ref` for the `value`.
+ llvm::SmallSet<Block *, 4> dropRefSuccessors;
+
+ for (Block &block : definingRegion->getBlocks()) {
+ const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+ // Skip the block if value is not in the `liveOut` set.
+ if (!blockLiveness->isLiveOut(value))
+ continue;
+
+ // Find successors that do not have `value` in the `liveIn` set.
+ for (Block *successor : block.getSuccessors()) {
+ const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
+
+ if (!succLiveness->isLiveIn(value))
+ dropRefSuccessors.insert(successor);
+ }
+ }
+
+ // Drop reference in all successor blocks that do not have the `value` in
+ // their `liveIn` set.
+ for (Block *dropRefSuccessor : dropRefSuccessors) {
+ builder.setInsertionPointToStart(dropRefSuccessor);
+ builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+ }
+
+ // ------------------------------------------------------------------------ //
+ // Find all `async.execute` operation that take `value` as an operand
+ // (dependency token or async value), or capture implicitly by the nested
+ // region. Each `async.execute` operation will require `add_ref` operation
+ // to keep all captured values alive until it will finish its execution.
+ // ------------------------------------------------------------------------ //
+
+ llvm::SmallSet<ExecuteOp, 4> executeOperations;
+
+ auto trackAsyncExecute = [&](Operation *op) {
+ if (auto execute = dyn_cast<ExecuteOp>(op))
+ executeOperations.insert(execute);
+ };
+
+ for (Operation *user : value.getUsers()) {
+ // Follow parent operations up until the operation in the `definingRegion`.
+ while (user->getParentRegion() != definingRegion) {
+ trackAsyncExecute(user);
+ user = user->getParentOp();
+ assert(user != nullptr && "value user lies outside of the value region");
+ }
+
+ // Don't forget to process the parent in the `definingRegion` (can be the
+ // original user operation itself).
+ trackAsyncExecute(user);
+ }
+
+ // Process all `async.execute` operations capturing `value`.
+ for (ExecuteOp execute : executeOperations) {
+ // Add a reference before the execute operation to keep the reference
+ // counted alive before the async region completes execution.
+ builder.setInsertionPoint(execute.getOperation());
+ builder.create<AddRefOp>(loc, value, IntegerAttr::get(i32, 1));
+
+ // Drop the reference inside the async region before completion.
+ OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
+ executeBuilder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+ }
+
+ return success();
+}
+
+void AsyncRefCountingPass::runOnFunction() {
+ FuncOp func = getFunction();
+
+ // Check that we do not have explicit `add_ref` or `drop_ref` in the IR
+ // because otherwise automatic reference counting will produce incorrect
+ // results.
+ WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
+ if (isa<AddRefOp, DropRefOp>(op))
+ return op->emitError() << "explicit reference counting is not supported";
+ return WalkResult::advance();
+ });
+
+ if (refCountingWalk.wasInterrupted())
+ signalPassFailure();
+
+ // Add reference counting to block arguments.
+ WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+ for (BlockArgument arg : block->getArguments())
+ if (isRefCounted(arg.getType()))
+ if (failed(addAutomaticRefCounting(arg)))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (blockWalk.wasInterrupted())
+ signalPassFailure();
+
+ // Add reference counting to operation results.
+ WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+ for (unsigned i = 0; i < op->getNumResults(); ++i)
+ if (isRefCounted(op->getResultTypes()[i]))
+ if (failed(addAutomaticRefCounting(op->getResult(i))))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (opWalk.wasInterrupted())
+ signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
+ return std::make_unique<AsyncRefCountingPass>();
+}
--- /dev/null
+//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Optimize Async dialect reference counting operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+class AsyncRefCountingOptimizationPass
+ : public AsyncRefCountingOptimizationBase<
+ AsyncRefCountingOptimizationPass> {
+public:
+ AsyncRefCountingOptimizationPass() = default;
+ void runOnFunction() override;
+
+private:
+ LogicalResult optimizeReferenceCounting(Value value);
+};
+
+} // namespace
+
+LogicalResult
+AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
+ Region *definingRegion = value.getParentRegion();
+
+ // Find all users of the `value` inside each block, including operations that
+ // do not use `value` directly, but have a direct use inside nested region(s).
+ //
+ // Example:
+ //
+ // ^bb1:
+ // %token = ...
+ // scf.if %cond {
+ // ^bb2:
+ // async.await %token : !async.token
+ // }
+ //
+ // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
+ //
+ // In addition to the operation that uses the `value` we also keep track if
+ // this user is an `async.execute` operation itself, or has `async.execute`
+ // operations in the nested regions that do use the `value`.
+
+ struct UserInfo {
+ Operation *operation;
+ bool hasExecuteUser;
+ };
+
+ struct BlockUsersInfo {
+ llvm::SmallVector<AddRefOp, 4> addRefs;
+ llvm::SmallVector<DropRefOp, 4> dropRefs;
+ llvm::SmallVector<UserInfo, 4> users;
+ };
+
+ llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
+
+ auto updateBlockUsersInfo = [&](UserInfo user) {
+ BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
+ info.users.push_back(user);
+
+ if (auto addRef = dyn_cast<AddRefOp>(user.operation))
+ info.addRefs.push_back(addRef);
+ if (auto dropRef = dyn_cast<DropRefOp>(user.operation))
+ info.dropRefs.push_back(dropRef);
+ };
+
+ for (Operation *user : value.getUsers()) {
+ bool isAsyncUser = isa<ExecuteOp>(user);
+
+ while (user->getParentRegion() != definingRegion) {
+ updateBlockUsersInfo({user, isAsyncUser});
+ user = user->getParentOp();
+ isAsyncUser |= isa<ExecuteOp>(user);
+ assert(user != nullptr && "value user lies outside of the value region");
+ }
+
+ updateBlockUsersInfo({user, isAsyncUser});
+ }
+
+ // Sort all operations found in the block.
+ auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
+ auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
+ return a->isBeforeInBlock(b);
+ };
+ llvm::sort(info.addRefs, isBeforeInBlock);
+ llvm::sort(info.dropRefs, isBeforeInBlock);
+ llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
+ return isBeforeInBlock(a.operation, b.operation);
+ });
+
+ return info;
+ };
+
+ // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
+ // blocks that modify the reference count of the `value`.
+ for (auto &kv : blockUsers) {
+ BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
+
+ // Find all cancellable pairs first and erase them later to keep all
+ // pointers in the `info` valid until the end.
+ //
+ // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
+ llvm::SmallDenseMap<Operation *, Operation *> cancellable;
+
+ for (AddRefOp addRef : info.addRefs) {
+ for (DropRefOp dropRef : info.dropRefs) {
+ // `drop_ref` operation after the `add_ref` with matching count.
+ if (dropRef.count() != addRef.count() ||
+ dropRef.getOperation()->isBeforeInBlock(addRef.getOperation()))
+ continue;
+
+ // `drop_ref` was already marked for removal.
+ if (cancellable.find(dropRef.getOperation()) != cancellable.end())
+ continue;
+
+ // Check `value` users between `addRef` and `dropRef` in the `block`.
+ Operation *addRefOp = addRef.getOperation();
+ Operation *dropRefOp = dropRef.getOperation();
+
+ // If there is a "regular" user after the `async.execute` user it is
+ // unsafe to erase cancellable reference counting operations pair,
+ // because async region can complete before the "regular" user and
+ // destroy the reference counted value.
+ bool hasExecuteUser = false;
+ bool unsafeToCancel = false;
+
+ for (UserInfo &user : info.users) {
+ Operation *op = user.operation;
+
+ // `user` operation lies after `addRef` ...
+ if (op == addRefOp || op->isBeforeInBlock(addRefOp))
+ continue;
+ // ... and before `dropRef`.
+ if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
+ break;
+
+ bool isRegularUser = !user.hasExecuteUser;
+ bool isExecuteUser = user.hasExecuteUser;
+
+ // It is unsafe to cancel `addRef` / `dropRef` pair.
+ if (isRegularUser && hasExecuteUser) {
+ unsafeToCancel = true;
+ break;
+ }
+
+ hasExecuteUser |= isExecuteUser;
+ }
+
+ // Mark the pair of reference counting operations for removal.
+ if (!unsafeToCancel)
+ cancellable[dropRef.getOperation()] = addRef.getOperation();
+
+ // If it us unsafe to cancel `addRef <-> dropRef` pair at this point,
+ // all the following pairs will be also unsafe.
+ break;
+ }
+ }
+
+ // Erase all cancellable `addRef <-> dropRef` operation pairs.
+ for (auto &kv : cancellable) {
+ kv.first->erase();
+ kv.second->erase();
+ }
+ }
+
+ return success();
+}
+
+void AsyncRefCountingOptimizationPass::runOnFunction() {
+ FuncOp func = getFunction();
+
+ // Optimize reference counting for values defined by block arguments.
+ WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+ for (BlockArgument arg : block->getArguments())
+ if (isRefCounted(arg.getType()))
+ if (failed(optimizeReferenceCounting(arg)))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (blockWalk.wasInterrupted())
+ signalPassFailure();
+
+ // Optimize reference counting for values defined by operation results.
+ WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+ for (unsigned i = 0; i < op->getNumResults(); ++i)
+ if (isRefCounted(op->getResultTypes()[i]))
+ if (failed(optimizeReferenceCounting(op->getResult(i))))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (opWalk.wasInterrupted())
+ signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncRefCountingOptimizationPass() {
+ return std::make_unique<AsyncRefCountingOptimizationPass>();
+}
add_mlir_dialect_library(MLIRAsyncTransforms
AsyncParallelFor.cpp
+ AsyncRefCounting.cpp
+ AsyncRefCountingOptimization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
#include <atomic>
+#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
// Async runtime API.
//===----------------------------------------------------------------------===//
-struct AsyncToken {
- bool ready = false;
+namespace {
+
+// Forward declare class defined below.
+class RefCounted;
+
+// -------------------------------------------------------------------------- //
+// AsyncRuntime orchestrates all async operations and Async runtime API is built
+// on top of the default runtime instance.
+// -------------------------------------------------------------------------- //
+
+class AsyncRuntime {
+public:
+ AsyncRuntime() : numRefCountedObjects(0) {}
+
+ ~AsyncRuntime() {
+ assert(getNumRefCountedObjects() == 0 &&
+ "all ref counted objects must be destroyed");
+ }
+
+ int32_t getNumRefCountedObjects() {
+ return numRefCountedObjects.load(std::memory_order_relaxed);
+ }
+
+private:
+ friend class RefCounted;
+
+ // Count the total number of reference counted objects in this instance
+ // of an AsyncRuntime. For debugging purposes only.
+ void addNumRefCountedObjects() {
+ numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
+ }
+ void dropNumRefCountedObjects() {
+ numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
+ }
+
+ std::atomic<int32_t> numRefCountedObjects;
+};
+
+// Returns the default per-process instance of an async runtime.
+AsyncRuntime *getDefaultAsyncRuntimeInstance() {
+ static auto runtime = std::make_unique<AsyncRuntime>();
+ return runtime.get();
+}
+
+// -------------------------------------------------------------------------- //
+// A base class for all reference counted objects created by the async runtime.
+// -------------------------------------------------------------------------- //
+
+class RefCounted {
+public:
+ RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
+ : runtime(runtime), refCount(refCount) {
+ runtime->addNumRefCountedObjects();
+ }
+
+ virtual ~RefCounted() {
+ assert(refCount.load() == 0 && "reference count must be zero");
+ runtime->dropNumRefCountedObjects();
+ }
+
+ RefCounted(const RefCounted &) = delete;
+ RefCounted &operator=(const RefCounted &) = delete;
+
+ void addRef(int32_t count = 1) { refCount.fetch_add(count); }
+
+ void dropRef(int32_t count = 1) {
+ int32_t previous = refCount.fetch_sub(count);
+ assert(previous >= count && "reference count should not go below zero");
+ if (previous == count)
+ destroy();
+ }
+
+protected:
+ virtual void destroy() { delete this; }
+
+private:
+ AsyncRuntime *runtime;
+ std::atomic<int32_t> refCount;
+};
+
+} // namespace
+
+struct AsyncToken : public RefCounted {
+ // AsyncToken created with a reference count of 2 because it will be returned
+ // to the `async.execute` caller and also will be later on emplaced by the
+ // asynchronously executed task. If the caller immediately will drop its
+ // reference we must ensure that the token will be alive until the
+ // asynchronous operation is completed.
+ AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
+
+ // Internal state below guarded by a mutex.
std::mutex mu;
std::condition_variable cv;
+
+ bool ready = false;
std::vector<std::function<void()>> awaiters;
};
-struct AsyncGroup {
- std::atomic<int> pendingTokens{0};
- std::atomic<int> rank{0};
+struct AsyncGroup : public RefCounted {
+ AsyncGroup(AsyncRuntime *runtime)
+ : RefCounted(runtime), pendingTokens(0), rank(0) {}
+
+ std::atomic<int> pendingTokens;
+ std::atomic<int> rank;
+
+ // Internal state below guarded by a mutex.
std::mutex mu;
std::condition_variable cv;
+
std::vector<std::function<void()>> awaiters;
};
+// Adds references to reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
+ RefCounted *refCounted = static_cast<RefCounted *>(ptr);
+ refCounted->addRef(count);
+}
+
+// Drops references from reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
+ RefCounted *refCounted = static_cast<RefCounted *>(ptr);
+ refCounted->dropRef(count);
+}
+
// Create a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
- AsyncToken *token = new AsyncToken;
+ AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
return token;
}
// Create a new `async.group` in empty state.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
- AsyncGroup *group = new AsyncGroup;
+ AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
return group;
}
std::unique_lock<std::mutex> lockToken(token->mu);
std::unique_lock<std::mutex> lockGroup(group->mu);
+ // Get the rank of the token inside the group before we drop the reference.
+ int rank = group->rank.fetch_add(1);
group->pendingTokens.fetch_add(1);
- auto onTokenReady = [group]() {
+ auto onTokenReady = [group, token](bool dropRef) {
// Run all group awaiters if it was the last token in the group.
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();
for (auto &awaiter : group->awaiters)
awaiter();
}
+
+ // We no longer need the token or the group, drop references on them.
+ if (dropRef) {
+ group->dropRef();
+ token->dropRef();
+ }
};
- if (token->ready)
- onTokenReady();
- else
- token->awaiters.push_back([onTokenReady]() { onTokenReady(); });
+ if (token->ready) {
+ onTokenReady(false);
+ } else {
+ group->addRef();
+ token->addRef();
+ token->awaiters.push_back([onTokenReady]() { onTokenReady(true); });
+ }
- return group->rank.fetch_add(1);
+ return rank;
}
// Switches `async.token` to ready state and runs all awaiters.
token->cv.notify_all();
for (auto &awaiter : token->awaiters)
awaiter();
+
+ // Async tokens created with a ref count `2` to keep token alive until the
+ // async task completes. Drop this reference explicitly when token emplaced.
+ token->dropRef();
}
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
CoroResume resume) {
std::unique_lock<std::mutex> lock(token->mu);
- auto execute = [handle, resume]() {
+ auto execute = [handle, resume, token](bool dropRef) {
+ if (dropRef)
+ token->dropRef();
mlirAsyncRuntimeExecute(handle, resume);
};
- if (token->ready)
- execute();
- else
- token->awaiters.push_back([execute]() { execute(); });
+ if (token->ready) {
+ execute(false);
+ } else {
+ token->addRef();
+ token->awaiters.push_back([execute]() { execute(true); });
+ }
}
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
CoroResume resume) {
std::unique_lock<std::mutex> lock(group->mu);
- auto execute = [handle, resume]() {
+ auto execute = [handle, resume, group](bool dropRef) {
+ if (dropRef)
+ group->dropRef();
mlirAsyncRuntimeExecute(handle, resume);
};
- if (group->pendingTokens == 0)
- execute();
- else
- group->awaiters.push_back([execute]() { execute(); });
+ if (group->pendingTokens == 0) {
+ execute(false);
+ } else {
+ group->addRef();
+ group->awaiters.push_back([execute]() { execute(true); });
+ }
}
//===----------------------------------------------------------------------===//
// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s
+// CHECK-LABEL: reference_counting
+func @reference_counting(%arg0: !async.token) {
+ // CHECK: %[[C2:.*]] = constant 2 : i32
+ // CHECK: call @mlirAsyncRuntimeAddRef(%arg0, %[[C2]])
+ async.add_ref %arg0 {count = 2 : i32} : !async.token
+
+ // CHECK: %[[C1:.*]] = constant 1 : i32
+ // CHECK: call @mlirAsyncRuntimeDropRef(%arg0, %[[C1]])
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: execute_no_async_args
func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1)
--- /dev/null
+// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s
+
+// CHECK-LABEL: @cancellable_operations_0
+func @cancellable_operations_0(%arg0: !async.token) {
+ // CHECK-NOT: async.add_ref
+ // CHECK-NOT: async.drop_ref
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @cancellable_operations_1
+func @cancellable_operations_1(%arg0: !async.token) {
+ // CHECK-NOT: async.add_ref
+ // CHECK: async.execute
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.execute [%arg0] {
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK-NEXT: async.yield
+ async.yield
+ }
+ // CHECK-NOT: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @cancellable_operations_2
+func @cancellable_operations_2(%arg0: !async.token) {
+ // CHECK: async.await
+ // CHECK-NEXT: async.await
+ // CHECK-NEXT: async.await
+ // CHECK-NEXT: return
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.await %arg0 : !async.token
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ async.await %arg0 : !async.token
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.await %arg0 : !async.token
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ return
+}
+
+// CHECK-LABEL: @cancellable_operations_3
+func @cancellable_operations_3(%arg0: !async.token) {
+ // CHECK-NOT: add_ref
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ %token = async.execute {
+ async.await %arg0 : !async.token
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ async.yield
+ }
+ // CHECK-NOT: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: async.await
+ async.await %arg0 : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @not_cancellable_operations_0
+func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) {
+ // It is unsafe to cancel `add_ref` / `drop_ref` pair because it is possible
+ // that the body of the `async.execute` operation will run before the await
+ // operation in the function body, and will destroy the `%arg0` token.
+ // CHECK: add_ref
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ %token = async.execute {
+ // CHECK: async.await
+ async.await %arg0 : !async.token
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: async.yield
+ async.yield
+ }
+ // CHECK: async.await
+ async.await %arg0 : !async.token
+ // CHECK: drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @not_cancellable_operations_1
+func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) {
+ // Same reason as above, although `async.execute` is inside the nested
+ // region or "regular" opeation.
+ //
+ // NOTE: This test is not correct w.r.t. reference counting, and at runtime
+ // would leak %arg0 value if %arg1 is false. IR like this will not be
+ // constructed by automatic reference counting pass, because it would
+ // place `async.add_ref` right before the `async.execute` inside `scf.if`.
+
+ // CHECK: async.add_ref
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ scf.if %arg1 {
+ %token = async.execute {
+ async.await %arg0 : !async.token
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ async.yield
+ }
+ }
+ // CHECK: async.await
+ async.await %arg0 : !async.token
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
--- /dev/null
+// RUN: mlir-opt %s -async-ref-counting | FileCheck %s
+
+// CHECK-LABEL: @cond
+func private @cond() -> i1
+
+// CHECK-LABEL: @token_arg_no_uses
+func @token_arg_no_uses(%arg0: !async.token) {
+ // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+ return
+}
+
+// CHECK-LABEL: @token_arg_conditional_await
+func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) {
+ cond_br %arg1, ^bb1, ^bb2
+^bb1:
+ // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+ return
+^bb2:
+ // CHECK: async.await %arg0
+ // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+ async.await %arg0 : !async.token
+ return
+}
+
+// CHECK-LABEL: @token_no_uses
+func @token_no_uses() {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ %token = async.execute {
+ async.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @token_return
+func @token_return() -> !async.token {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: return %[[TOKEN]]
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @token_await
+func @token_await() {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: async.await %[[TOKEN]]
+ async.await %token : !async.token
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @token_await_and_return
+func @token_await_and_return() -> !async.token {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK-NOT: async.drop_ref
+ async.await %token : !async.token
+ // CHECK: return %[[TOKEN]]
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @token_await_inside_scf_if
+func @token_await_inside_scf_if(%arg0: i1) {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: scf.if %arg0 {
+ scf.if %arg0 {
+ // CHECK: async.await %[[TOKEN]]
+ async.await %token : !async.token
+ }
+ // CHECK: }
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @token_conditional_await
+func @token_conditional_await(%arg0: i1) {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ return
+^bb2:
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+ return
+}
+
+// CHECK-LABEL: @token_await_in_the_loop
+func @token_await_in_the_loop() {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ br ^bb1
+^bb1:
+ // CHECK: async.await %[[TOKEN]]
+ async.await %token : !async.token
+ %0 = call @cond(): () -> (i1)
+ cond_br %0, ^bb1, ^bb2
+^bb2:
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ return
+}
+
+// CHECK-LABEL: @token_defined_in_the_loop
+func @token_defined_in_the_loop() {
+ br ^bb1
+^bb1:
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+ %0 = call @cond(): () -> (i1)
+ cond_br %0, ^bb1, ^bb2
+^bb2:
+ return
+}
+
+// CHECK-LABEL: @token_capture
+func @token_capture() {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_0:.*]] = async.execute
+ %token_0 = async.execute {
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK-NEXT: async.yield
+ async.await %token : !async.token
+ async.yield
+ }
+ // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @token_nested_capture
+func @token_nested_capture() {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_0:.*]] = async.execute
+ %token_0 = async.execute {
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_1:.*]] = async.execute
+ %token_1 = async.execute {
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_2:.*]] = async.execute
+ %token_2 = async.execute {
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+ async.yield
+ }
+ // CHECK: async.drop_ref %[[TOKEN_2]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.yield
+ }
+ // CHECK: async.drop_ref %[[TOKEN_1]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.yield
+ }
+ // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @token_dependency
+func @token_dependency() {
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_0:.*]] = async.execute
+ %token_0 = async.execute[%token] {
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK-NEXT: async.yield
+ async.yield
+ }
+
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+ // CHECK: async.await %[[TOKEN_0]]
+ // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+ async.await %token_0 : !async.token
+
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @value_operand
+func @value_operand() -> f32 {
+ // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute
+ %token, %results = async.execute -> !async.value<f32> {
+ %0 = constant 0.0 : f32
+ async.yield %0 : f32
+ }
+
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: async.add_ref %[[RESULTS]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_0:.*]] = async.execute
+ %token_0 = async.execute[%token](%results as %arg0 : !async.value<f32>) {
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32}
+ // CHECK: async.yield
+ async.yield
+ }
+
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+
+ // CHECK: async.await %[[TOKEN_0]]
+ // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+ async.await %token_0 : !async.token
+
+ // CHECK: async.await %[[RESULTS]]
+ // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32}
+ %0 = async.await %results : !async.value<f32>
+
+ // CHECK: return
+ return %0 : f32
+}
%3 = addi %1, %2 : index
return %3 : index
}
+
+// CHECK-LABEL: @add_ref
+func @add_ref(%arg0: !async.token) {
+ // CHECK: async.add_ref %arg0 {count = 1 : i32}
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ return
+}
+
+// CHECK-LABEL: @drop_ref
+func @drop_ref(%arg0: !async.token) {
+ // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ return
+}
-// RUN: mlir-opt %s -convert-async-to-llvm \
+// RUN: mlir-opt %s -async-ref-counting \
+// RUN: -convert-async-to-llvm \
// RUN: -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: -e main -entry-point-result=void -O0 \
-// RUN: mlir-opt %s -convert-async-to-llvm \
+// RUN: mlir-opt %s -async-ref-counting \
+// RUN: -convert-async-to-llvm \
// RUN: -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm \
// RUN: -convert-std-to-llvm \