OpBuilder moduleBuilder(module.getBody()->getTerminator());
- // Get values captured by the async region
- llvm::SetVector<mlir::Value> usedAbove;
- getUsedValuesDefinedAbove(execute.body(), usedAbove);
-
- // Collect types of the captured values.
- auto usedAboveTypes =
- llvm::map_range(usedAbove, [](Value value) { return value.getType(); });
- SmallVector<Type, 4> inputTypes(usedAboveTypes.begin(), usedAboveTypes.end());
+ // Collect all outlined function inputs.
+ llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
+ execute.dependencies().end());
+ getUsedValuesDefinedAbove(execute.body(), functionInputs);
+
+ // Collect types for the outlined function inputs and outputs.
+ auto typesRange = llvm::map_range(
+ functionInputs, [](Value value) { return value.getType(); });
+ SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
auto outputTypes = execute.getResultTypes();
auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
entryBlock->getTerminator());
- // Map from values defined above the execute op to the function arguments.
+ // Await on all dependencies before starting to execute the body region.
+ builder.setInsertionPointToStart(resume);
+ for (size_t i = 0; i < execute.dependencies().size(); ++i)
+ builder.create<AwaitOp>(loc, func.getArgument(i));
+
+ // Map from function inputs defined above the execute op to the function
+ // arguments.
BlockAndValueMapping valueMapping;
- valueMapping.map(usedAbove, func.getArguments());
+ valueMapping.map(functionInputs, func.getArguments());
// Clone all operations from the execute operation body into the outlined
// function body, and replace all `async.yield` operations with a call
// to async runtime to emplace the result token.
- builder.setInsertionPointToStart(resume);
for (Operation &op : execute.body().getOps()) {
if (isa<async::YieldOp>(op)) {
builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
// Replace the original `async.execute` with a call to outlined function.
OpBuilder callBuilder(execute);
- SmallVector<Value, 4> usedAboveArgs(usedAbove.begin(), usedAbove.end());
- auto callOutlinedFunc = callBuilder.create<CallOp>(
- loc, func.getName(), execute.getResultTypes(), usedAboveArgs);
+ auto callOutlinedFunc =
+ callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
+ functionInputs.getArrayRef());
execute.replaceAllUsesWith(callOutlinedFunc.getResults());
execute.erase();
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
- // We currently do not support execute operations that take async
- // token dependencies, async value arguments or produce async results.
- if (!execute.dependencies().empty() || !execute.operands().empty() ||
- !execute.results().empty()) {
- execute.emitOpError(
- "Can't outline async.execute op with async dependencies, arguments "
- "or returned async results");
+ // We currently do not support execute operations that have async value
+ // operands or produce async results.
+ if (!execute.operands().empty() || !execute.results().empty()) {
+ execute.emitOpError("can't outline async.execute op with async value "
+ "operands or returned async results");
return WalkResult::interrupt();
}
}
// Function outlined from the async.execute operation.
-// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
+// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
// Create token for return op, and mark a function as a coroutine.
}
// Function outlined from the inner async.execute operation.
-// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
+// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
// CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
// Function outlined from the outer async.execute operation.
-// CHECK: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32)
+// CHECK-LABEL: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32)
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin
// CHECK: store %arg2, %arg1[%c0] : memref<1xf32>
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
+// -----
+
+// CHECK-LABEL: async_execute_token_dependency
+func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
+ // CHECK: %0 = call @async_execute_fn(%arg0, %arg1)
+ %token = async.execute {
+ %c0 = constant 0 : index
+ store %arg0, %arg1[%c0] : memref<1xf32>
+ async.yield
+ }
+ // CHECK: %1 = call @async_execute_fn_0(%0, %arg0, %arg1)
+ %token_0 = async.execute [%token] {
+ %c0 = constant 0 : index
+ store %arg0, %arg1[%c0] : memref<1xf32>
+ async.yield
+ }
+ return
+}
+
+// Function outlined from the first async.execute operation.
+// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
+// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
+// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin
+// CHECK: call @mlirAsyncRuntimeExecute
+// CHECK: llvm.call @llvm.coro.suspend
+// CHECK: store %arg0, %arg1[%c0] : memref<1xf32>
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
+
+// Function outlined from the second async.execute operation with dependency.
+// CHECK-LABEL: func @async_execute_fn_0(%arg0: !llvm.ptr<i8>, %arg1: f32, %arg2: memref<1xf32>)
+// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
+// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin
+
+// Suspend coroutine in the beginning.
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL_1]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Suspend coroutine second time waiting for the completion of token dependency.
+// CHECK: llvm.call @llvm.coro.save
+// CHECK: call @mlirAsyncRuntimeAwaitTokenAndExecute(%arg0, %[[HDL_1]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Emplace result token after second resumption.
+// CHECK: store %arg1, %arg2[%c0] : memref<1xf32>
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
+
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
call @print_memref_f32(%U): (memref<*xf32>) -> ()
- %inner = async.execute {
+ // No op async region to create a token for testing async dependency.
+ %noop = async.execute {
// CHECK: Current thread id: [[THREAD1:.*]]
+ call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
+ async.yield
+ }
+
+ %inner = async.execute [%noop] {
+ // CHECK: Current thread id: [[THREAD2:.*]]
// CHECK: [1, 2, 3, 0]
store %c3, %A[%i2]: memref<4xf32>
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
}
async.await %inner : !async.token
- // CHECK: Current thread id: [[THREAD2:.*]]
+ // CHECK: Current thread id: [[THREAD3:.*]]
// CHECK: [1, 2, 3, 4]
store %c4, %A[%i3]: memref<4xf32>
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()