[mlir:Async] Remove async operations if it is statically known that the parallel...
authorEugene Zhulenev <ezhulenev@google.com>
Mon, 28 Jun 2021 00:44:31 +0000 (17:44 -0700)
committerEugene Zhulenev <ezhulenev@google.com>
Tue, 29 Jun 2021 16:26:28 +0000 (09:26 -0700)
Depends On D104850

Add a test that verifies that canonicalization removes all async overheads if it is statically known that the scf.parallel operation will be computed using a single block.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D104891

mlir/include/mlir/Dialect/Async/IR/Async.h
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir [new file with mode: 0644]

index d84b8f8..0783009 100644 (file)
@@ -20,6 +20,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
index f9ddd67..d168b8c 100644 (file)
@@ -177,6 +177,8 @@ def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> {
   let arguments = (ins Index:$size);
   let results = (outs Async_GroupType:$result);
 
+  let hasCanonicalizeMethod = 1;
+
   let assemblyFormat = "$size `:` type($result) attr-dict";
 }
 
index a06b2b6..bd627ed 100644 (file)
@@ -246,6 +246,36 @@ static LogicalResult verify(ExecuteOp op) {
 }
 
 //===----------------------------------------------------------------------===//
+/// CreateGroupOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
+                                          PatternRewriter &rewriter) {
+  // Find all `await_all` users of the group.
+  llvm::SmallVector<AwaitAllOp> awaitAllUsers;
+
+  auto isAwaitAll = [&](Operation *op) -> bool {
+    if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
+      awaitAllUsers.push_back(awaitAll);
+      return true;
+    }
+    return false;
+  };
+
+  // Check if all users of the group are `await_all` operations.
+  if (!llvm::all_of(op->getUsers(), isAwaitAll))
+    return failure();
+
+  // If group is only awaited without adding anything to it, we can safely erase
+  // the create operation and all users.
+  for (AwaitAllOp awaitAll : awaitAllUsers)
+    rewriter.eraseOp(awaitAll);
+  rewriter.eraseOp(op);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 /// AwaitOp
 //===----------------------------------------------------------------------===//
 
index 1d545a5..a104fb7 100644 (file)
@@ -513,18 +513,48 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
   Value groupSize = b.create<SubIOp>(blockCount, c1);
   Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
 
-  // Pack the async dispath function operands to launch the work splitting.
-  SmallVector<Value> asyncDispatchOperands = {group, c0, blockCount, blockSize};
-  asyncDispatchOperands.append(tripCounts);
-  asyncDispatchOperands.append(op.lowerBound().begin(), op.lowerBound().end());
-  asyncDispatchOperands.append(op.upperBound().begin(), op.upperBound().end());
-  asyncDispatchOperands.append(op.step().begin(), op.step().end());
-  asyncDispatchOperands.append(parallelComputeFunction.captures);
-
-  // Launch async dispatch function for [0, blockCount) range.
-  b.create<CallOp>(asyncDispatchFunction.sym_name(),
-                   asyncDispatchFunction.getCallableResults(),
-                   asyncDispatchOperands);
+  // Appends operands shared by async dispatch and parallel compute functions to
+  // the given operands vector.
+  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
+    operands.append(tripCounts);
+    operands.append(op.lowerBound().begin(), op.lowerBound().end());
+    operands.append(op.upperBound().begin(), op.upperBound().end());
+    operands.append(op.step().begin(), op.step().end());
+    operands.append(parallelComputeFunction.captures);
+  };
+
+  // Check if the block size is one, in this case we can skip the async dispatch
+  // completely. If this will be known statically, then canonicalization will
+  // erase async group operations.
+  Value isSingleBlock = b.create<CmpIOp>(CmpIPredicate::eq, blockCount, c1);
+
+  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+    ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+    // Call parallel compute function for the single block.
+    SmallVector<Value> operands = {c0, blockSize};
+    appendBlockComputeOperands(operands);
+
+    nb.create<CallOp>(parallelComputeFunction.func.sym_name(),
+                      parallelComputeFunction.func.getCallableResults(),
+                      operands);
+    nb.create<scf::YieldOp>();
+  };
+
+  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+    ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+    // Launch async dispatch function for [0, blockCount) range.
+    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
+    appendBlockComputeOperands(operands);
+
+    nb.create<CallOp>(asyncDispatchFunction.sym_name(),
+                      asyncDispatchFunction.getCallableResults(), operands);
+    nb.create<scf::YieldOp>();
+  };
+
+  // Dispatch either single block compute function, or launch async dispatch.
+  b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
 
   // Wait for the completion of all parallel compute operations.
   b.create<AwaitAllOp>(group);
index df538a4..a6e308e 100644 (file)
@@ -3,8 +3,13 @@
 
 // CHECK-LABEL: @loop_1d
 func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
+  // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[GROUP:.*]] = async.create_group
-  // CHECK: call @async_dispatch_fn
+  // CHECK: scf.if {{.*}} {
+  // CHECK:   call @parallel_compute_fn(%[[C0]]
+  // CHECK: } else {
+  // CHECK:   call @async_dispatch_fn
+  // CHECK: }
   // CHECK: async.await_all %[[GROUP]]
   scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
     %one = constant 1.0 : f32
diff --git a/mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir b/mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir
new file mode 100644 (file)
index 0000000..e26d990
--- /dev/null
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s                                                            \
+// RUN:    -async-parallel-for=async-dispatch=true                             \
+// RUN:    -canonicalize -inline -symbol-dce                                   \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s                                                            \
+// RUN:    -async-parallel-for=async-dispatch=false                            \
+// RUN:    -canonicalize -inline -symbol-dce                                   \
+// RUN: | FileCheck %s
+
+// Check that if we statically know that the parallel operation has a single
+// block then all async operations will be canonicalized away and we will
+// end up with a single synchonous compute function call.
+
+// CHECK-LABEL: @loop_1d(
+// CHECK:       %[[MEMREF:.*]]: memref<?xf32>
+func @loop_1d(%arg0: memref<?xf32>) {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C100:.*]] = constant 100 : index
+  // CHECK-DAG: %[[ONE:.*]] = constant 1.000000e+00 : f32
+  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]]
+  // CHECK:       memref.store %[[ONE]], %[[MEMREF]][%[[I]]]
+  %lb = constant 0 : index
+  %ub = constant 100 : index
+  %st = constant 1 : index
+  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+    %one = constant 1.0 : f32
+    memref.store %one, %arg0[%i] : memref<?xf32>
+  }
+
+  return
+}