Specify the `!async.group` size (the number of tokens that will be added to it) at construction time. `async.await_all` operation can potentially race with `async.execute` operations that keep updating the group, for this reason it is required to know upfront how many tokens will be added to the group.
Reviewed By: ftynse, herhut
Differential Revision: https://reviews.llvm.org/D104780
let summary = "creates an empty async group";
let description = [{
The `async.create_group` allocates an empty async group. Async tokens or
- values can be added to this group later.
+ values can be added to this group later. The size of the group must be
+ specified at construction time, and `await_all` operation will first
+ wait until the number of added tokens or values reaches the group size.
Example:
```mlir
- %0 = async.create_group
+ %size = ... : index
+ %group = async.create_group %size : !async.group
...
- async.await_all %0
+ async.await_all %group
```
}];
+ let arguments = (ins Index:$size);
let results = (outs Async_GroupType:$result);
- let assemblyFormat = "attr-dict";
+ let assemblyFormat = "$size `:` type($result) attr-dict";
}
def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
Example:
```mlir
- %0 = async.create_group
+ %0 = async.create_group %size : !async.group
%1 = ... : !async.token
%2 = async.add_to_group %1, %0 : !async.token
```
Example:
```mlir
- %0 = async.create_group
+ %0 = async.create_group %size : !async.group
%1 = ... : !async.token
%2 = async.add_to_group %1, %0 : !async.token
// Runtime API defined in the `ExecutionEngine/AsyncRuntime.h`.
def Async_RuntimeCreateOp : Async_Op<"runtime.create"> {
- let summary = "creates an async runtime value (token, value or group)";
+ let summary = "creates an async runtime token or value";
let description = [{
- The `async.runtime.create` operation creates an async dialect value
- (token, value or group). Tokens and values are created in non-ready state.
- Groups are created in empty state.
+ The `async.runtime.create` operation creates an async dialect token or
+ value. Tokens and values are created in the non-ready state.
}];
- let results = (outs Async_AnyAsyncType:$result);
+ let results = (outs Async_AnyValueOrTokenType:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
+def Async_RuntimeCreateGroupOp : Async_Op<"runtime.create_group"> {
+ let summary = "creates an async runtime group";
+ let description = [{
+ The `async.runtime.create_group` operation creates an async dialect group
+ of the given size. Group created in the empty state.
+ }];
+
+ let arguments = (ins Index:$size);
+ let results = (outs Async_GroupType:$result);
+ let assemblyFormat = "$size `:` type($result) attr-dict ";
+}
+
def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
let summary = "switches token or value to available state";
let description = [{
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t);
// Create a new `async.group` in empty state.
-extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup();
+extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size);
extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);
}
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
+ auto i64 = IntegerType::get(ctx, 64);
+ return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
}
static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
TypeConverter *converter = getTypeConverter();
Type resultType = op->getResultTypes()[0];
- // Tokens and Groups lowered to function calls without arguments.
- if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
- rewriter.replaceOpWithNewOp<CallOp>(
- op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
- converter->convertType(resultType));
+ // Tokens creation maps to a simple function call.
+ if (resultType.isa<TokenType>()) {
+ rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
+ converter->convertType(resultType));
return success();
}
} // namespace
//===----------------------------------------------------------------------===//
+// Convert async.runtime.create_group to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeCreateGroupOpLowering
+ : public OpConversionPattern<RuntimeCreateGroupOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ TypeConverter *converter = getTypeConverter();
+ Type resultType = op->getResultTypes()[0];
+
+ rewriter.replaceOpWithNewOp<CallOp>(
+ op, kCreateGroup, converter->convertType(resultType), operands);
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
// Convert async.runtime.set_available to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
// Lower async.runtime operations that rely on LLVM type converter to convert
// from async value payload type to the LLVM type.
- patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
- RuntimeLoadOpLowering>(llvmConverter, ctx);
+ patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
+ RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
+ ctx);
// Lower async coroutine operations to LLVM coroutine intrinsics.
patterns
numBlocks[i] = divup(tripCounts[i], blockSize[i]);
}
+ // Total number of async compute blocks.
+ Value totalBlocks = numBlocks[0];
+ for (size_t i = 1; i < op.getNumLoops(); ++i)
+ totalBlocks = rewriter.create<MulIOp>(loc, totalBlocks, numBlocks[i]);
+
// Create an async.group to wait on all async tokens from async execute ops.
- auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));
+ auto group =
+ rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx), totalBlocks);
// Build a scf.for loop nest from the parallel operation.
}
//===----------------------------------------------------------------------===//
-// Convert async.create_group operation to async.runtime.create
+// Convert async.create_group operation to async.runtime.create_group
//===----------------------------------------------------------------------===//
namespace {
LogicalResult
matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
- op, GroupType::get(op->getContext()));
+ rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
+ op, GroupType::get(op->getContext()), operands);
return success();
}
};
// values to await on all of them together (wait for the completion of all
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
- AsyncGroup(AsyncRuntime *runtime)
- : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {}
+ AsyncGroup(AsyncRuntime *runtime, int64_t size)
+ : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
std::atomic<int> pendingTokens;
std::atomic<int> numErrors;
}
// Create a new `async.group` in empty state.
-extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
- AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
+extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
+ AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
return group;
}
// Get the rank of the token inside the group before we drop the reference.
int rank = group->rank.fetch_add(1);
- group->pendingTokens.fetch_add(1);
auto onTokenReady = [group, token]() {
// Increment the number of errors in the group.
if (State(token->state).isError())
group->numErrors.fetch_add(1);
+ // If pending tokens go below zero it means that more tokens than the group
+ // size were added to this group.
+ assert(group->pendingTokens > 0 && "wrong group size");
+
// Run all group awaiters if it was the last token in the group.
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();
-// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s --dump-input=always
// CHECK-LABEL: @create_token
func @create_token() {
// CHECK-LABEL: @create_group
func @create_group() {
- // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
- %0 = async.runtime.create : !async.group
+ // CHECK: %[[C:.*]] = constant 1 : index
+ // CHECK: %[[S:.*]] = llvm.mlir.cast %[[C]] : index to i64
+ %c = constant 1 : index
+ // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup(%[[S]])
+ %0 = async.runtime.create_group %c: !async.group
return
}
// CHECK-LABEL: @await_group
func @await_group() {
+ %c = constant 1 : index
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
- %0 = async.runtime.create : !async.group
+ %0 = async.runtime.create_group %c: !async.group
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
async.runtime.await %0 : !async.group
return
// CHECK-LABEL: @await_and_resume_group
func @await_and_resume_group() {
+ %c = constant 1 : index
%0 = async.coro.id
// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
%1 = async.coro.begin %0
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup
- %2 = async.runtime.create : !async.group
+ %2 = async.runtime.create_group %c : !async.group
// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute
// CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]])
// CHECK-LABEL: @add_token_to_group
func @add_token_to_group() {
+ %c = constant 1 : index
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken
%0 = async.runtime.create : !async.token
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
- %1 = async.runtime.create : !async.group
+ %1 = async.runtime.create_group %c : !async.group
// CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
async.runtime.add_to_group %0, %1 : !async.token
return
// CHECK-LABEL: async_group_await_all
func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
- // CHECK: %0 = call @mlirAsyncRuntimeCreateGroup()
- %0 = async.create_group
+ %c = constant 1 : index
+ // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
+ %0 = async.create_group %c : !async.group
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
%token = async.execute { async.yield }
- // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %0)
+ // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
async.add_to_group %token, %0 : !async.token
// CHECK: call @async_execute_fn_0
async.yield
}
- // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%0)
+ // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
async.await_all %0
return
// CHECK-LABEL: @async_group_await_all
func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
- // CHECK: %[[GROUP:.*]] = async.runtime.create : !async.group
- %0 = async.create_group
+ // CHECK: %[[C:.*]] = constant 1 : index
+ %c = constant 1 : index
+ // CHECK: %[[GROUP:.*]] = async.runtime.create_group %[[C]] : !async.group
+ %0 = async.create_group %c : !async.group
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
%token = async.execute { async.yield }
}
// CHECK-LABEL: @create_group_and_await_all
-func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value<f32>) -> index {
- %0 = async.create_group
+func @create_group_and_await_all(%arg0: !async.token,
+ %arg1: !async.value<f32>) -> index {
+ %c = constant 2 : index
+ %0 = async.create_group %c : !async.group
// CHECK: async.add_to_group %arg0
// CHECK: async.add_to_group %arg1
// CHECK-LABEL: @create_group
func @create_group() -> !async.group {
- // CHECK: %0 = async.runtime.create : !async.group
- %0 = async.runtime.create : !async.group
- // CHECK: return %0 : !async.group
+ // CHECK: %[[C:.*]] = constant 10 : index
+ %c = constant 10 : index
+ // CHECK: %[[V:.*]] = async.runtime.create_group %[[C]] : !async.group
+ %0 = async.runtime.create_group %c : !async.group
+ // CHECK: return %[[V]] : !async.group
return %0 : !async.group
}
// Check error propagation from a token to the group.
// ------------------------------------------------------------------------ //
- %group0 = async.create_group
+ %c2 = constant 2 : index
+ %group0 = async.create_group %c2 : !async.group
%token4 = async.execute {
async.yield
// RUN: | FileCheck %s
func @main() {
- %group = async.create_group
+ %c1 = constant 1 : index
+ %c5 = constant 5 : index
+
+ %group = async.create_group %c5 : !async.group
%token0 = async.execute { async.yield }
%token1 = async.execute { async.yield }
async.yield
}
- %group0 = async.create_group
+ %group0 = async.create_group %c1 : !async.group
%5 = async.add_to_group %token5, %group0 : !async.token
async.await_all %group0