From d43b23608ad664f02f56e965ca78916bde220950 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 23 Jun 2021 06:24:09 -0700 Subject: [PATCH] [mlir:Async] Add the size parameter to the async.group Specify the `!async.group` size (the number of tokens that will be added to it) at construction time. `async.await_all` operation can potentially race with `async.execute` operations that keep updating the group, for this reason it is required to know upfront how many tokens will be added to the group. Reviewed By: ftynse, herhut Differential Revision: https://reviews.llvm.org/D104780 --- mlir/include/mlir/Dialect/Async/IR/AsyncOps.td | 37 ++++++++++++++------ mlir/include/mlir/ExecutionEngine/AsyncRuntime.h | 2 +- mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 40 +++++++++++++++++----- .../Dialect/Async/Transforms/AsyncParallelFor.cpp | 8 ++++- .../Async/Transforms/AsyncToAsyncRuntime.cpp | 6 ++-- mlir/lib/ExecutionEngine/AsyncRuntime.cpp | 13 ++++--- .../AsyncToLLVM/convert-runtime-to-llvm.mlir | 18 ++++++---- .../Conversion/AsyncToLLVM/convert-to-llvm.mlir | 9 ++--- .../test/Dialect/Async/async-to-async-runtime.mlir | 6 ++-- mlir/test/Dialect/Async/ops.mlir | 6 ++-- mlir/test/Dialect/Async/runtime.mlir | 8 +++-- mlir/test/mlir-cpu-runner/async-error.mlir | 3 +- mlir/test/mlir-cpu-runner/async-group.mlir | 7 ++-- 13 files changed, 114 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index 9ef218e..f9ddd67a 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -160,20 +160,24 @@ def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> { let summary = "creates an empty async group"; let description = [{ The `async.create_group` allocates an empty async group. Async tokens or - values can be added to this group later. + values can be added to this group later. The size of the group must be + specified at construction time, and `await_all` operation will first + wait until the number of added tokens or values reaches the group size. Example: ```mlir - %0 = async.create_group + %size = ... : index + %group = async.create_group %size : !async.group ... - async.await_all %0 + async.await_all %group ``` }]; + let arguments = (ins Index:$size); let results = (outs Async_GroupType:$result); - let assemblyFormat = "attr-dict"; + let assemblyFormat = "$size `:` type($result) attr-dict"; } def Async_AddToGroupOp : Async_Op<"add_to_group", []> { @@ -186,7 +190,7 @@ def Async_AddToGroupOp : Async_Op<"add_to_group", []> { Example: ```mlir - %0 = async.create_group + %0 = async.create_group %size : !async.group %1 = ... : !async.token %2 = async.add_to_group %1, %0 : !async.token ``` @@ -209,7 +213,7 @@ def Async_AwaitAllOp : Async_Op<"await_all", []> { Example: ```mlir - %0 = async.create_group + %0 = async.create_group %size : !async.group %1 = ... : !async.token %2 = async.add_to_group %1, %0 : !async.token @@ -331,17 +335,28 @@ def Async_CoroSuspendOp : Async_Op<"coro.suspend", [Terminator]> { // Runtime API defined in the `ExecutionEngine/AsyncRuntime.h`. def Async_RuntimeCreateOp : Async_Op<"runtime.create"> { - let summary = "creates an async runtime value (token, value or group)"; + let summary = "creates an async runtime token or value"; let description = [{ - The `async.runtime.create` operation creates an async dialect value - (token, value or group). Tokens and values are created in non-ready state. - Groups are created in empty state. + The `async.runtime.create` operation creates an async dialect token or + value. Tokens and values are created in the non-ready state. }]; - let results = (outs Async_AnyAsyncType:$result); + let results = (outs Async_AnyValueOrTokenType:$result); let assemblyFormat = "attr-dict `:` type($result)"; } +def Async_RuntimeCreateGroupOp : Async_Op<"runtime.create_group"> { + let summary = "creates an async runtime group"; + let description = [{ + The `async.runtime.create_group` operation creates an async dialect group + of the given size. Group created in the empty state. + }]; + + let arguments = (ins Index:$size); + let results = (outs Async_GroupType:$result); + let assemblyFormat = "$size `:` type($result) attr-dict "; +} + def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> { let summary = "switches token or value to available state"; let description = [{ diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h index 3b26bf6..a101b28 100644 --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -66,7 +66,7 @@ extern "C" AsyncToken *mlirAsyncRuntimeCreateToken(); extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t); // Create a new `async.group` in empty state. -extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(); +extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size); extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index ff2460b..0156ede 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -89,7 +89,8 @@ struct AsyncAPI { } static FunctionType createGroupFunctionType(MLIRContext *ctx) { - return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); + auto i64 = IntegerType::get(ctx, 64); + return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); } static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { @@ -543,11 +544,10 @@ public: TypeConverter *converter = getTypeConverter(); Type resultType = op->getResultTypes()[0]; - // Tokens and Groups lowered to function calls without arguments. - if (resultType.isa() || resultType.isa()) { - rewriter.replaceOpWithNewOp( - op, resultType.isa() ? kCreateToken : kCreateGroup, - converter->convertType(resultType)); + // Tokens creation maps to a simple function call. + if (resultType.isa()) { + rewriter.replaceOpWithNewOp(op, kCreateToken, + converter->convertType(resultType)); return success(); } @@ -583,6 +583,29 @@ public: } // namespace //===----------------------------------------------------------------------===// +// Convert async.runtime.create_group to the corresponding runtime API call. +//===----------------------------------------------------------------------===// + +namespace { +class RuntimeCreateGroupOpLowering + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TypeConverter *converter = getTypeConverter(); + Type resultType = op->getResultTypes()[0]; + + rewriter.replaceOpWithNewOp( + op, kCreateGroup, converter->convertType(resultType), operands); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// // Convert async.runtime.set_available to the corresponding runtime API call. //===----------------------------------------------------------------------===// @@ -967,8 +990,9 @@ void ConvertAsyncToLLVMPass::runOnOperation() { // Lower async.runtime operations that rely on LLVM type converter to convert // from async value payload type to the LLVM type. - patterns.add(llvmConverter, ctx); + patterns.add(llvmConverter, + ctx); // Lower async coroutine operations to LLVM coroutine intrinsics. patterns diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index ce2bc70..ba09123 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -165,8 +165,14 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, numBlocks[i] = divup(tripCounts[i], blockSize[i]); } + // Total number of async compute blocks. + Value totalBlocks = numBlocks[0]; + for (size_t i = 1; i < op.getNumLoops(); ++i) + totalBlocks = rewriter.create(loc, totalBlocks, numBlocks[i]); + // Create an async.group to wait on all async tokens from async execute ops. - auto group = rewriter.create(loc, GroupType::get(ctx)); + auto group = + rewriter.create(loc, GroupType::get(ctx), totalBlocks); // Build a scf.for loop nest from the parallel operation. diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp index 0789a0e..ea8e353 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -302,7 +302,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { } //===----------------------------------------------------------------------===// -// Convert async.create_group operation to async.runtime.create +// Convert async.create_group operation to async.runtime.create_group //===----------------------------------------------------------------------===// namespace { @@ -313,8 +313,8 @@ public: LogicalResult matchAndRewrite(CreateGroupOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, GroupType::get(op->getContext())); + rewriter.replaceOpWithNewOp( + op, GroupType::get(op->getContext()), operands); return success(); } }; diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp index 6bbb8e4..a8aeaec 100644 --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -211,8 +211,8 @@ struct AsyncValue : public RefCounted { // values to await on all of them together (wait for the completion of all // tokens or values added to the group). struct AsyncGroup : public RefCounted { - AsyncGroup(AsyncRuntime *runtime) - : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {} + AsyncGroup(AsyncRuntime *runtime, int64_t size) + : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {} std::atomic pendingTokens; std::atomic numErrors; @@ -249,8 +249,8 @@ extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) { } // Create a new `async.group` in empty state. -extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { - AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime()); +extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) { + AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size); return group; } @@ -261,13 +261,16 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, // Get the rank of the token inside the group before we drop the reference. int rank = group->rank.fetch_add(1); - group->pendingTokens.fetch_add(1); auto onTokenReady = [group, token]() { // Increment the number of errors in the group. if (State(token->state).isError()) group->numErrors.fetch_add(1); + // If pending tokens go below zero it means that more tokens than the group + // size were added to this group. + assert(group->pendingTokens > 0 && "wrong group size"); + // Run all group awaiters if it was the last token in the group. if (group->pendingTokens.fetch_sub(1) == 1) { group->cv.notify_all(); diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir index 74c091e7..9d57ef3 100644 --- a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s +// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s --dump-input=always // CHECK-LABEL: @create_token func @create_token() { @@ -20,8 +20,11 @@ func @create_value() { // CHECK-LABEL: @create_group func @create_group() { - // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup - %0 = async.runtime.create : !async.group + // CHECK: %[[C:.*]] = constant 1 : index + // CHECK: %[[S:.*]] = llvm.mlir.cast %[[C]] : index to i64 + %c = constant 1 : index + // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup(%[[S]]) + %0 = async.runtime.create_group %c: !async.group return } @@ -81,8 +84,9 @@ func @await_value() { // CHECK-LABEL: @await_group func @await_group() { + %c = constant 1 : index // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup - %0 = async.runtime.create : !async.group + %0 = async.runtime.create_group %c: !async.group // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]]) async.runtime.await %0 : !async.group return @@ -118,11 +122,12 @@ func @await_and_resume_value() { // CHECK-LABEL: @await_and_resume_group func @await_and_resume_group() { + %c = constant 1 : index %0 = async.coro.id // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin %1 = async.coro.begin %0 // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup - %2 = async.runtime.create : !async.group + %2 = async.runtime.create_group %c : !async.group // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume // CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute // CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]]) @@ -168,10 +173,11 @@ func @load() -> f32 { // CHECK-LABEL: @add_token_to_group func @add_token_to_group() { + %c = constant 1 : index // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken %0 = async.runtime.create : !async.token // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup - %1 = async.runtime.create : !async.group + %1 = async.runtime.create_group %c : !async.group // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]]) async.runtime.add_to_group %0, %1 : !async.token return diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir index ba2a391..da96f30 100644 --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -170,12 +170,13 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) { // CHECK-LABEL: async_group_await_all func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) { - // CHECK: %0 = call @mlirAsyncRuntimeCreateGroup() - %0 = async.create_group + %c = constant 1 : index + // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup + %0 = async.create_group %c : !async.group // CHECK: %[[TOKEN:.*]] = call @async_execute_fn %token = async.execute { async.yield } - // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %0) + // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]]) async.add_to_group %token, %0 : !async.token // CHECK: call @async_execute_fn_0 @@ -184,7 +185,7 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) { async.yield } - // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%0) + // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]]) async.await_all %0 return diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir index b77b0d6..7564f13 100644 --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -179,8 +179,10 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) { // CHECK-LABEL: @async_group_await_all func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) { - // CHECK: %[[GROUP:.*]] = async.runtime.create : !async.group - %0 = async.create_group + // CHECK: %[[C:.*]] = constant 1 : index + %c = constant 1 : index + // CHECK: %[[GROUP:.*]] = async.runtime.create_group %[[C]] : !async.group + %0 = async.create_group %c : !async.group // CHECK: %[[TOKEN:.*]] = call @async_execute_fn %token = async.execute { async.yield } diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir index a95be65..1ec2b6d 100644 --- a/mlir/test/Dialect/Async/ops.mlir +++ b/mlir/test/Dialect/Async/ops.mlir @@ -122,8 +122,10 @@ func @await_value(%arg0: !async.value) -> f32 { } // CHECK-LABEL: @create_group_and_await_all -func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value) -> index { - %0 = async.create_group +func @create_group_and_await_all(%arg0: !async.token, + %arg1: !async.value) -> index { + %c = constant 2 : index + %0 = async.create_group %c : !async.group // CHECK: async.add_to_group %arg0 // CHECK: async.add_to_group %arg1 diff --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir index ede523f..1b39e64 100644 --- a/mlir/test/Dialect/Async/runtime.mlir +++ b/mlir/test/Dialect/Async/runtime.mlir @@ -18,9 +18,11 @@ func @create_value() -> !async.value { // CHECK-LABEL: @create_group func @create_group() -> !async.group { - // CHECK: %0 = async.runtime.create : !async.group - %0 = async.runtime.create : !async.group - // CHECK: return %0 : !async.group + // CHECK: %[[C:.*]] = constant 10 : index + %c = constant 10 : index + // CHECK: %[[V:.*]] = async.runtime.create_group %[[C]] : !async.group + %0 = async.runtime.create_group %c : !async.group + // CHECK: return %[[V]] : !async.group return %0 : !async.group } diff --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir index 63b9a00..77616de 100644 --- a/mlir/test/mlir-cpu-runner/async-error.mlir +++ b/mlir/test/mlir-cpu-runner/async-error.mlir @@ -85,7 +85,8 @@ func @main() { // Check error propagation from a token to the group. // ------------------------------------------------------------------------ // - %group0 = async.create_group + %c2 = constant 2 : index + %group0 = async.create_group %c2 : !async.group %token4 = async.execute { async.yield diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir index 8216d15..7df8877 100644 --- a/mlir/test/mlir-cpu-runner/async-group.mlir +++ b/mlir/test/mlir-cpu-runner/async-group.mlir @@ -11,7 +11,10 @@ // RUN: | FileCheck %s func @main() { - %group = async.create_group + %c1 = constant 1 : index + %c5 = constant 5 : index + + %group = async.create_group %c5 : !async.group %token0 = async.execute { async.yield } %token1 = async.execute { async.yield } @@ -30,7 +33,7 @@ func @main() { async.yield } - %group0 = async.create_group + %group0 = async.create_group %c1 : !async.group %5 = async.add_to_group %token5, %group0 : !async.token async.await_all %group0 -- 2.7.4