Value c0 = b.create<ConstantIndexOp>(0);
Value c1 = b.create<ConstantIndexOp>(1);
- // Create an async.group to wait on all async tokens from the concurrent
- // execution of multiple parallel compute function. First block will be
- // executed synchronously in the caller thread.
- Value groupSize = b.create<SubIOp>(blockCount, c1);
- Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
-
// Appends operands shared by async dispatch and parallel compute functions to
// the given operands vector.
auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
};
auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+ // Create an async.group to wait on all async tokens from the concurrent
+ // execution of multiple parallel compute functions. The first block is
+ // executed synchronously in the caller thread, hence `blockCount - 1`.
+ Value groupSize = nb.create<SubIOp>(blockCount, c1);
+ Value group = nb.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
// Launch async dispatch function for [0, blockCount) range.
nb.create<CallOp>(asyncDispatchFunction.sym_name(),
asyncDispatchFunction.getCallableResults(), operands);
+
+ // Wait for the completion of all parallel compute operations.
+ nb.create<AwaitAllOp>(group);
+
nb.create<scf::YieldOp>();
};
// Dispatch either single block compute function, or launch async dispatch.
b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
-
- // Wait for the completion of all parallel compute operations.
- b.create<AwaitAllOp>(group);
}
// Dispatch parallel compute functions by submitting all async compute tasks
// CHECK: scf.if %[[IS_NOOP]] {
// CHECK-NEXT: } else {
- // CHECK: %[[GROUP:.*]] = async.create_group
- // CHECK: scf.if {{.*}} {
+ // CHECK: scf.if {{.*}} {
// CHECK: call @parallel_compute_fn(%[[C0]]
// CHECK: } else {
+ // CHECK: %[[GROUP:.*]] = async.create_group
// CHECK: call @async_dispatch_fn
+ // CHECK: async.await_all %[[GROUP]]
// CHECK: }
- // CHECK: async.await_all %[[GROUP]]
// CHECK: }
scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
%one = constant 1.0 : f32