[mlir] Automatic reference counting for Async values + runtime support for ref counte...
authorEugene Zhulenev <ezhulenev@google.com>
Fri, 20 Nov 2020 10:42:28 +0000 (02:42 -0800)
committerEugene Zhulenev <ezhulenev@google.com>
Fri, 20 Nov 2020 11:08:44 +0000 (03:08 -0800)
Depends On D89963

**Automatic reference counting algorithm outline:**

1. `ReturnLike` operations forward the reference counted values without
    modifying the reference count.
2. Use liveness analysis to find blocks in the CFG where the lifetime of
   reference counted values ends, and insert `drop_ref` operations after
   the last use of the value.
3. Insert `add_ref` before the `async.execute` operation capturing the
   value, and pairing `drop_ref` before the async body region terminator,
   to release the captured reference counted value when execution
   completes.
4. If the reference counted value is passed only to some of the block
   successors, insert `drop_ref` operations in the beginning of the blocks
   that do not have reference coutned value uses.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D90716

19 files changed:
mlir/include/mlir/Dialect/Async/IR/Async.h
mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/Dialect/Async/Passes.h
mlir/include/mlir/Dialect/Async/Passes.td
mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp [new file with mode: 0644]
mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp [new file with mode: 0644]
mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
mlir/lib/ExecutionEngine/AsyncRuntime.cpp
mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/Async/async-ref-counting-optimization.mlir [new file with mode: 0644]
mlir/test/Dialect/Async/async-ref-counting.mlir [new file with mode: 0644]
mlir/test/Dialect/Async/ops.mlir
mlir/test/mlir-cpu-runner/async-group.mlir
mlir/test/mlir-cpu-runner/async.mlir

index ad5a8aa..d0664b0 100644 (file)
@@ -53,6 +53,16 @@ public:
   using Base::Base;
 };
 
+// -------------------------------------------------------------------------- //
+// Helper functions of Async dialect transformations.
+// -------------------------------------------------------------------------- //
+
+/// Returns true if the type is reference counted. All async dialect types are
+/// reference counted at runtime.
+inline bool isRefCounted(Type type) {
+  return type.isa<TokenType, ValueType, GroupType>();
+}
+
 } // namespace async
 } // namespace mlir
 
index e7a5e90..e33a9e2 100644 (file)
@@ -73,4 +73,8 @@ def Async_AnyValueType : DialectType<AsyncDialect,
 def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType,
                                            Async_TokenType]>;
 
+def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType,
+                                    Async_TokenType,
+                                    Async_GroupType]>;
+
 #endif // ASYNC_BASE_TD
index cc98785..80aeabf 100644 (file)
@@ -227,4 +227,62 @@ def Async_AwaitAllOp : Async_Op<"await_all", []> {
   let assemblyFormat = "$operand attr-dict";
 }
 
+//===----------------------------------------------------------------------===//
+// Async Dialect Automatic Reference Counting Operations.
+//===----------------------------------------------------------------------===//
+
+// All async values (values, tokens, groups) are reference counted at runtime
+// and automatically destructed when reference count drops to 0.
+//
+// All values are semantically created with a reference count of +1 and it is
+// the responsibility of the last async value user to drop reference count.
+//
+// Async values created when:
+//   1. Operation returns async result (e.g. the result of an `async.execute`).
+//   2. Async value passed in as a block argument.
+//
+// It is the responsiblity of the async value user to extend the lifetime by
+// adding a +1 reference, if the reference counted value captured by the
+// asynchronously executed region (`async.execute` operation), and drop it after
+// the last nested use.
+//
+// Reference counting operations can be added to the IR using automatic
+// reference count pass, that relies on liveness analysis to find the last uses
+// of all reference counted values and automatically inserts
+// `drop_ref` operations.
+//
+// See `AsyncRefCountingPass` documentation for the implementation details.
+
+def Async_AddRefOp : Async_Op<"add_ref"> {
+  let summary = "adds a reference to async value";
+  let description = [{
+    The `async.add_ref` operation adds a reference(s) to async value (token,
+    value or group).
+  }];
+
+  let arguments = (ins Async_AnyAsyncType:$operand,
+                       Confined<I32Attr, [IntPositive]>:$count);
+  let results = (outs );
+
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand)
+  }];
+}
+
+def Async_DropRefOp : Async_Op<"drop_ref"> {
+  let summary = "drops a reference to async value";
+  let description = [{
+    The `async.drop_ref` operation drops a reference(s) to async value (token,
+    value or group).
+  }];
+
+  let arguments = (ins Async_AnyAsyncType:$operand,
+                       Confined<I32Attr, [IntPositive]>:$count);
+  let results = (outs );
+
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand)
+  }];
+}
+
 #endif // ASYNC_OPS
index d5a8a82..9716bde 100644 (file)
@@ -19,6 +19,10 @@ namespace mlir {
 
 std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
 
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
+
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
index 51fd4e3..140a3b4 100644 (file)
@@ -24,4 +24,18 @@ def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
   let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
 }
 
+def AsyncRefCounting : FunctionPass<"async-ref-counting"> {
+  let summary = "Automatic reference counting for Async dialect data types";
+  let constructor = "mlir::createAsyncRefCountingPass()";
+  let dependentDialects = ["async::AsyncDialect"];
+}
+
+def AsyncRefCountingOptimization :
+    FunctionPass<"async-ref-counting-optimization"> {
+  let summary = "Optimize automatic reference counting operations for the"
+                "Async dialect by removing redundant operations";
+  let constructor = "mlir::createAsyncRefCountingOptimizationPass()";
+  let dependentDialects = ["async::AsyncDialect"];
+}
+
 #endif // MLIR_DIALECT_ASYNC_PASSES
index 12beffe..26b0a23 100644 (file)
@@ -48,6 +48,18 @@ typedef struct AsyncGroup MLIR_AsyncGroup;
 using CoroHandle = void *;           // coroutine handle
 using CoroResume = void (*)(void *); // coroutine resume function
 
+// Async runtime uses reference counting to manage the lifetime of async values
+// (values of async types like tokens, values and groups).
+using RefCountedObjPtr = void *;
+
+// Adds references to reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+    mlirAsyncRuntimeAddRef(RefCountedObjPtr, int32_t);
+
+// Drops references from reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+    mlirAsyncRuntimeDropRef(RefCountedObjPtr, int32_t);
+
 // Create a new `async.token` in not-ready state.
 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
 
index 0618771..74c0556 100644 (file)
@@ -1,4 +1,5 @@
 // RUN:   mlir-opt %s -async-parallel-for                                      \
+// RUN:               -async-ref-counting                                      \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-std-to-llvm                                     \
index 79fa4c2..196ab89 100644 (file)
@@ -1,4 +1,5 @@
 // RUN:   mlir-opt %s -async-parallel-for                                      \
+// RUN:               -async-ref-counting                                      \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-std-to-llvm                                     \
index 0cbf3de..b08f7e4 100644 (file)
@@ -33,6 +33,8 @@ static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
 // Async Runtime C API declaration.
 //===----------------------------------------------------------------------===//
 
+static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
+static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
@@ -49,6 +51,12 @@ static constexpr const char *kAwaitAllAndExecute =
 namespace {
 // Async Runtime API function types.
 struct AsyncAPI {
+  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
+    auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
+    auto count = IntegerType::get(32, ctx);
+    return FunctionType::get({ref, count}, {}, ctx);
+  }
+
   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
     return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
   }
@@ -113,6 +121,8 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   };
 
   MLIRContext *ctx = module.getContext();
+  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
+  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
@@ -121,7 +131,8 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
   addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
-  addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+  addFuncDecl(kAwaitAllAndExecute,
+              AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
 }
 
 //===----------------------------------------------------------------------===//
@@ -589,6 +600,55 @@ public:
 } // namespace
 
 //===----------------------------------------------------------------------===//
+// Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
+// to the corresponding API calls).
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+template <typename RefCountingOp>
+class RefCountingOpLowering : public ConversionPattern {
+public:
+  explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
+      : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
+        apiFunctionName(apiFunctionName) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    RefCountingOp refCountingOp = cast<RefCountingOp>(op);
+
+    auto count = rewriter.create<ConstantOp>(
+        op->getLoc(), rewriter.getI32Type(),
+        rewriter.getI32IntegerAttr(refCountingOp.count()));
+
+    rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
+                                        ValueRange({operands[0], count}));
+
+    return success();
+  }
+
+private:
+  StringRef apiFunctionName;
+};
+
+// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
+class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
+public:
+  explicit AddRefOpLowering(MLIRContext *ctx)
+      : RefCountingOpLowering(ctx, kAddRef) {}
+};
+
+// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
+class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
+public:
+  explicit DropRefOpLowering(MLIRContext *ctx)
+      : RefCountingOpLowering(ctx, kDropRef) {}
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
 //===----------------------------------------------------------------------===//
 
@@ -794,10 +854,12 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
 
   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
   patterns.insert<CallOpOpConversion>(ctx);
+  patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
   patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
   patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
 
   ConversionTarget target(*ctx);
+  target.addLegalOp<ConstantOp>();
   target.addLegalDialect<LLVM::LLVMDialect>();
   target.addIllegalDialect<AsyncDialect>();
   target.addDynamicallyLegalOp<FuncOp>(
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
new file mode 100644 (file)
index 0000000..ea1da59
--- /dev/null
@@ -0,0 +1,324 @@
+//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements automatic reference counting for Async dialect data
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
+public:
+  AsyncRefCountingPass() = default;
+  void runOnFunction() override;
+
+private:
+  /// Adds an automatic reference counting to the `value`.
+  ///
+  /// All values are semantically created with a reference count of +1 and it is
+  /// the responsibility of the last async value user to drop reference count.
+  ///
+  /// Async values created when:
+  ///   1. Operation returns async result (e.g. the result of an
+  ///      `async.execute`).
+  ///   2. Async value passed in as a block argument.
+  ///
+  /// To implement automatic reference counting, we must insert a +1 reference
+  /// before each `async.execute` operation using the value, and drop it after
+  /// the last use inside the async body region (we currently drop the reference
+  /// before the `async.yield` terminator).
+  ///
+  /// Automatic reference counting algorithm outline:
+  ///
+  /// 1. `ReturnLike` operations forward the reference counted values without
+  ///     modifying the reference count.
+  ///
+  /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
+  ///    reference counted values ends, and insert `drop_ref` operations after
+  ///    the last use of the value.
+  ///
+  /// 3. Insert `add_ref` before the `async.execute` operation capturing the
+  ///    value, and pairing `drop_ref` before the async body region terminator,
+  ///    to release the captured reference counted value when execution
+  ///    completes.
+  ///
+  /// 4. If the reference counted value is passed only to some of the block
+  ///    successors, insert `drop_ref` operations in the beginning of the blocks
+  ///    that do not have reference counted value uses.
+  ///
+  ///
+  /// Example:
+  ///
+  ///   %token = ...
+  ///   async.execute {
+  ///     async.await %token : !async.token   // await #1
+  ///     async.yield
+  ///   }
+  ///   async.await %token : !async.token     // await #2
+  ///
+  /// Based on the liveness analysis await #2 is the last use of the %token,
+  /// however the execution of the async region can be delayed, and to guarantee
+  /// that the %token is still alive when await #1 executes we need to
+  /// explicitly extend its lifetime using `add_ref` operation.
+  ///
+  /// After automatic reference counting:
+  ///
+  ///   %token = ...
+  ///
+  ///   // Make sure that %token is alive inside async.execute.
+  ///   async.add_ref %token {count = 1 : i32} : !async.token
+  ///
+  ///   async.execute {
+  ///     async.await %token : !async.token   // await #1
+  ///
+  ///     // Drop the extra reference added to keep %token alive.
+  ///     async.drop_ref %token {count = 1 : i32} : !async.token
+  ///
+  ///     async.yied
+  ///   }
+  ///   async.await %token : !async.token     // await #2
+  ///
+  ///   // Drop the reference after the last use of %token.
+  ///   async.drop_ref %token {count = 1 : i32} : !async.token
+  ///
+  LogicalResult addAutomaticRefCounting(Value value);
+};
+
+} // namespace
+
+LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
+  MLIRContext *ctx = value.getContext();
+  OpBuilder builder(ctx);
+
+  // Set inserton point after the operation producing a value, or at the
+  // beginning of the block if the value defined by the block argument.
+  if (Operation *op = value.getDefiningOp())
+    builder.setInsertionPointAfter(op);
+  else
+    builder.setInsertionPointToStart(value.getParentBlock());
+
+  Location loc = value.getLoc();
+  auto i32 = IntegerType::get(32, ctx);
+
+  // Drop the reference count immediately if the value has no uses.
+  if (value.getUses().empty()) {
+    builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+    return success();
+  }
+
+  // Use liveness analysis to find the placement of `drop_ref`operation.
+  auto liveness = getAnalysis<Liveness>();
+
+  // We analyse only the blocks of the region that defines the `value`, and do
+  // not check nested blocks attached to operations.
+  //
+  // By analyzing only the `definingRegion` CFG we potentially loose an
+  // opportunity to drop the reference count earlier and can extend the lifetime
+  // of reference counted value longer then it is really required.
+  //
+  // We also assume that all nested regions finish their execution before the
+  // completion of the owner operation. The only exception to this rule is
+  // `async.execute` operation, which is handled explicitly below.
+  Region *definingRegion = value.getParentRegion();
+
+  // ------------------------------------------------------------------------ //
+  // Find blocks where the `value` dies: the value is in `liveIn` set and not
+  // in the `liveOut` set. We place `drop_ref` immediately after the last use
+  // of the `value` in such regions.
+  // ------------------------------------------------------------------------ //
+
+  // Last users of the `value` inside all blocks where the value dies.
+  llvm::SmallSet<Operation *, 4> lastUsers;
+
+  for (Block &block : definingRegion->getBlocks()) {
+    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+    // Value in live input set or was defined in the block.
+    bool liveIn = blockLiveness->isLiveIn(value) ||
+                  blockLiveness->getBlock() == value.getParentBlock();
+    if (!liveIn)
+      continue;
+
+    // Value is in the live out set.
+    bool liveOut = blockLiveness->isLiveOut(value);
+    if (liveOut)
+      continue;
+
+    // We proved that `value` dies in the `block`. Now find the last use of the
+    // `value` inside the `block`.
+
+    // Find any user of the `value` inside the block (including uses in nested
+    // regions attached to the operations in the block).
+    Operation *userInTheBlock = nullptr;
+    for (Operation *user : value.getUsers()) {
+      userInTheBlock = block.findAncestorOpInBlock(*user);
+      if (userInTheBlock)
+        break;
+    }
+
+    // Values with zero users handled explicitly in the beginning, if the value
+    // is in live out set it must have at least one use in the block.
+    assert(userInTheBlock && "value must have a user in the block");
+
+    // Find the last user of the `value` in the block;
+    Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
+    assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
+    lastUsers.insert(lastUser);
+  }
+
+  // Process all the last users of the `value` inside each block where the value
+  // dies.
+  for (Operation *lastUser : lastUsers) {
+    // Return like operations forward reference count.
+    if (lastUser->hasTrait<OpTrait::ReturnLike>())
+      continue;
+
+    // We can't currently handle other types of terminators.
+    if (lastUser->hasTrait<OpTrait::IsTerminator>())
+      return lastUser->emitError() << "async reference counting can't handle "
+                                      "terminators that are not ReturnLike";
+
+    // Add a drop_ref immediately after the last user.
+    builder.setInsertionPointAfter(lastUser);
+    builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+  }
+
+  // ------------------------------------------------------------------------ //
+  // Find blocks where the `value` is in `liveOut` set, however it is not in
+  // the `liveIn` set of all successors. If the `value` is not in the successor
+  // `liveIn` set, we add a `drop_ref` to the beginning of it.
+  // ------------------------------------------------------------------------ //
+
+  // Successors that we'll need a `drop_ref` for the `value`.
+  llvm::SmallSet<Block *, 4> dropRefSuccessors;
+
+  for (Block &block : definingRegion->getBlocks()) {
+    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+    // Skip the block if value is not in the `liveOut` set.
+    if (!blockLiveness->isLiveOut(value))
+      continue;
+
+    // Find successors that do not have `value` in the `liveIn` set.
+    for (Block *successor : block.getSuccessors()) {
+      const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
+
+      if (!succLiveness->isLiveIn(value))
+        dropRefSuccessors.insert(successor);
+    }
+  }
+
+  // Drop reference in all successor blocks that do not have the `value` in
+  // their `liveIn` set.
+  for (Block *dropRefSuccessor : dropRefSuccessors) {
+    builder.setInsertionPointToStart(dropRefSuccessor);
+    builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+  }
+
+  // ------------------------------------------------------------------------ //
+  // Find all `async.execute` operation that take `value` as an operand
+  // (dependency token or async value), or capture implicitly by the nested
+  // region. Each `async.execute` operation will require `add_ref` operation
+  // to keep all captured values alive until it will finish its execution.
+  // ------------------------------------------------------------------------ //
+
+  llvm::SmallSet<ExecuteOp, 4> executeOperations;
+
+  auto trackAsyncExecute = [&](Operation *op) {
+    if (auto execute = dyn_cast<ExecuteOp>(op))
+      executeOperations.insert(execute);
+  };
+
+  for (Operation *user : value.getUsers()) {
+    // Follow parent operations up until the operation in the `definingRegion`.
+    while (user->getParentRegion() != definingRegion) {
+      trackAsyncExecute(user);
+      user = user->getParentOp();
+      assert(user != nullptr && "value user lies outside of the value region");
+    }
+
+    // Don't forget to process the parent in the `definingRegion` (can be the
+    // original user operation itself).
+    trackAsyncExecute(user);
+  }
+
+  // Process all `async.execute` operations capturing `value`.
+  for (ExecuteOp execute : executeOperations) {
+    // Add a reference before the execute operation to keep the reference
+    // counted alive before the async region completes execution.
+    builder.setInsertionPoint(execute.getOperation());
+    builder.create<AddRefOp>(loc, value, IntegerAttr::get(i32, 1));
+
+    // Drop the reference inside the async region before completion.
+    OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
+    executeBuilder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+  }
+
+  return success();
+}
+
+void AsyncRefCountingPass::runOnFunction() {
+  FuncOp func = getFunction();
+
+  // Check that we do not have explicit `add_ref` or `drop_ref` in the IR
+  // because otherwise automatic reference counting will produce incorrect
+  // results.
+  WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
+    if (isa<AddRefOp, DropRefOp>(op))
+      return op->emitError() << "explicit reference counting is not supported";
+    return WalkResult::advance();
+  });
+
+  if (refCountingWalk.wasInterrupted())
+    signalPassFailure();
+
+  // Add reference counting to block arguments.
+  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+    for (BlockArgument arg : block->getArguments())
+      if (isRefCounted(arg.getType()))
+        if (failed(addAutomaticRefCounting(arg)))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (blockWalk.wasInterrupted())
+    signalPassFailure();
+
+  // Add reference counting to operation results.
+  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+    for (unsigned i = 0; i < op->getNumResults(); ++i)
+      if (isRefCounted(op->getResultTypes()[i]))
+        if (failed(addAutomaticRefCounting(op->getResult(i))))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (opWalk.wasInterrupted())
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
+  return std::make_unique<AsyncRefCountingPass>();
+}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
new file mode 100644 (file)
index 0000000..cbcb30c
--- /dev/null
@@ -0,0 +1,218 @@
+//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Optimize Async dialect reference counting operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+class AsyncRefCountingOptimizationPass
+    : public AsyncRefCountingOptimizationBase<
+          AsyncRefCountingOptimizationPass> {
+public:
+  AsyncRefCountingOptimizationPass() = default;
+  void runOnFunction() override;
+
+private:
+  LogicalResult optimizeReferenceCounting(Value value);
+};
+
+} // namespace
+
+LogicalResult
+AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
+  Region *definingRegion = value.getParentRegion();
+
+  // Find all users of the `value` inside each block, including operations that
+  // do not use `value` directly, but have a direct use inside nested region(s).
+  //
+  // Example:
+  //
+  //  ^bb1:
+  //    %token = ...
+  //    scf.if %cond {
+  //      ^bb2:
+  //      async.await %token : !async.token
+  //    }
+  //
+  // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
+  //
+  // In addition to the operation that uses the `value` we also keep track if
+  // this user is an `async.execute` operation itself, or has `async.execute`
+  // operations in the nested regions that do use the `value`.
+
+  struct UserInfo {
+    Operation *operation;
+    bool hasExecuteUser;
+  };
+
+  struct BlockUsersInfo {
+    llvm::SmallVector<AddRefOp, 4> addRefs;
+    llvm::SmallVector<DropRefOp, 4> dropRefs;
+    llvm::SmallVector<UserInfo, 4> users;
+  };
+
+  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
+
+  auto updateBlockUsersInfo = [&](UserInfo user) {
+    BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
+    info.users.push_back(user);
+
+    if (auto addRef = dyn_cast<AddRefOp>(user.operation))
+      info.addRefs.push_back(addRef);
+    if (auto dropRef = dyn_cast<DropRefOp>(user.operation))
+      info.dropRefs.push_back(dropRef);
+  };
+
+  for (Operation *user : value.getUsers()) {
+    bool isAsyncUser = isa<ExecuteOp>(user);
+
+    while (user->getParentRegion() != definingRegion) {
+      updateBlockUsersInfo({user, isAsyncUser});
+      user = user->getParentOp();
+      isAsyncUser |= isa<ExecuteOp>(user);
+      assert(user != nullptr && "value user lies outside of the value region");
+    }
+
+    updateBlockUsersInfo({user, isAsyncUser});
+  }
+
+  // Sort all operations found in the block.
+  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
+    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
+      return a->isBeforeInBlock(b);
+    };
+    llvm::sort(info.addRefs, isBeforeInBlock);
+    llvm::sort(info.dropRefs, isBeforeInBlock);
+    llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
+      return isBeforeInBlock(a.operation, b.operation);
+    });
+
+    return info;
+  };
+
+  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
+  // blocks that modify the reference count of the `value`.
+  for (auto &kv : blockUsers) {
+    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
+
+    // Find all cancellable pairs first and erase them later to keep all
+    // pointers in the `info` valid until the end.
+    //
+    // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
+    llvm::SmallDenseMap<Operation *, Operation *> cancellable;
+
+    for (AddRefOp addRef : info.addRefs) {
+      for (DropRefOp dropRef : info.dropRefs) {
+        // `drop_ref` operation after the `add_ref` with matching count.
+        if (dropRef.count() != addRef.count() ||
+            dropRef.getOperation()->isBeforeInBlock(addRef.getOperation()))
+          continue;
+
+        // `drop_ref` was already marked for removal.
+        if (cancellable.find(dropRef.getOperation()) != cancellable.end())
+          continue;
+
+        // Check `value` users between `addRef` and `dropRef` in the `block`.
+        Operation *addRefOp = addRef.getOperation();
+        Operation *dropRefOp = dropRef.getOperation();
+
+        // If there is a "regular" user after the `async.execute` user it is
+        // unsafe to erase cancellable reference counting operations pair,
+        // because async region can complete before the "regular" user and
+        // destroy the reference counted value.
+        bool hasExecuteUser = false;
+        bool unsafeToCancel = false;
+
+        for (UserInfo &user : info.users) {
+          Operation *op = user.operation;
+
+          // `user` operation lies after `addRef` ...
+          if (op == addRefOp || op->isBeforeInBlock(addRefOp))
+            continue;
+          // ... and before `dropRef`.
+          if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
+            break;
+
+          bool isRegularUser = !user.hasExecuteUser;
+          bool isExecuteUser = user.hasExecuteUser;
+
+          // It is unsafe to cancel `addRef` / `dropRef` pair.
+          if (isRegularUser && hasExecuteUser) {
+            unsafeToCancel = true;
+            break;
+          }
+
+          hasExecuteUser |= isExecuteUser;
+        }
+
+        // Mark the pair of reference counting operations for removal.
+        if (!unsafeToCancel)
+          cancellable[dropRef.getOperation()] = addRef.getOperation();
+
+        // If it us unsafe to cancel `addRef <-> dropRef` pair at this point,
+        // all the following pairs will be also unsafe.
+        break;
+      }
+    }
+
+    // Erase all cancellable `addRef <-> dropRef` operation pairs.
+    for (auto &kv : cancellable) {
+      kv.first->erase();
+      kv.second->erase();
+    }
+  }
+
+  return success();
+}
+
+void AsyncRefCountingOptimizationPass::runOnFunction() {
+  FuncOp func = getFunction();
+
+  // Optimize reference counting for values defined by block arguments.
+  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+    for (BlockArgument arg : block->getArguments())
+      if (isRefCounted(arg.getType()))
+        if (failed(optimizeReferenceCounting(arg)))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (blockWalk.wasInterrupted())
+    signalPassFailure();
+
+  // Optimize reference counting for values defined by operation results.
+  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+    for (unsigned i = 0; i < op->getNumResults(); ++i)
+      if (isRefCounted(op->getResultTypes()[i]))
+        if (failed(optimizeReferenceCounting(op->getResult(i))))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (opWalk.wasInterrupted())
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncRefCountingOptimizationPass() {
+  return std::make_unique<AsyncRefCountingOptimizationPass>();
+}
index 9de4387..dccae73 100644 (file)
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRAsyncTransforms
   AsyncParallelFor.cpp
+  AsyncRefCounting.cpp
+  AsyncRefCountingOptimization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
index 332c7ff..f769965 100644 (file)
@@ -16,6 +16,7 @@
 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
 
 #include <atomic>
+#include <cassert>
 #include <condition_variable>
 #include <functional>
 #include <iostream>
 // Async runtime API.
 //===----------------------------------------------------------------------===//
 
-struct AsyncToken {
-  bool ready = false;
+namespace {
+
+// Forward declare class defined below.
+class RefCounted;
+
+// -------------------------------------------------------------------------- //
+// AsyncRuntime orchestrates all async operations and Async runtime API is built
+// on top of the default runtime instance.
+// -------------------------------------------------------------------------- //
+
+class AsyncRuntime {
+public:
+  AsyncRuntime() : numRefCountedObjects(0) {}
+
+  ~AsyncRuntime() {
+    assert(getNumRefCountedObjects() == 0 &&
+           "all ref counted objects must be destroyed");
+  }
+
+  int32_t getNumRefCountedObjects() {
+    return numRefCountedObjects.load(std::memory_order_relaxed);
+  }
+
+private:
+  friend class RefCounted;
+
+  // Count the total number of reference counted objects in this instance
+  // of an AsyncRuntime. For debugging purposes only.
+  void addNumRefCountedObjects() {
+    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
+  }
+  void dropNumRefCountedObjects() {
+    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
+  }
+
+  std::atomic<int32_t> numRefCountedObjects;
+};
+
+// Returns the default per-process instance of an async runtime.
+AsyncRuntime *getDefaultAsyncRuntimeInstance() {
+  static auto runtime = std::make_unique<AsyncRuntime>();
+  return runtime.get();
+}
+
+// -------------------------------------------------------------------------- //
+// A base class for all reference counted objects created by the async runtime.
+// -------------------------------------------------------------------------- //
+
+class RefCounted {
+public:
+  RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
+      : runtime(runtime), refCount(refCount) {
+    runtime->addNumRefCountedObjects();
+  }
+
+  virtual ~RefCounted() {
+    assert(refCount.load() == 0 && "reference count must be zero");
+    runtime->dropNumRefCountedObjects();
+  }
+
+  RefCounted(const RefCounted &) = delete;
+  RefCounted &operator=(const RefCounted &) = delete;
+
+  void addRef(int32_t count = 1) { refCount.fetch_add(count); }
+
+  void dropRef(int32_t count = 1) {
+    int32_t previous = refCount.fetch_sub(count);
+    assert(previous >= count && "reference count should not go below zero");
+    if (previous == count)
+      destroy();
+  }
+
+protected:
+  virtual void destroy() { delete this; }
+
+private:
+  AsyncRuntime *runtime;
+  std::atomic<int32_t> refCount;
+};
+
+} // namespace
+
+struct AsyncToken : public RefCounted {
+  // AsyncToken created with a reference count of 2 because it will be returned
+  // to the `async.execute` caller and also will be later on emplaced by the
+  // asynchronously executed task. If the caller immediately will drop its
+  // reference we must ensure that the token will be alive until the
+  // asynchronous operation is completed.
+  AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
+
+  // Internal state below guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
+
+  bool ready = false;
   std::vector<std::function<void()>> awaiters;
 };
 
-struct AsyncGroup {
-  std::atomic<int> pendingTokens{0};
-  std::atomic<int> rank{0};
+struct AsyncGroup : public RefCounted {
+  AsyncGroup(AsyncRuntime *runtime)
+      : RefCounted(runtime), pendingTokens(0), rank(0) {}
+
+  std::atomic<int> pendingTokens;
+  std::atomic<int> rank;
+
+  // Internal state below guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
+
   std::vector<std::function<void()>> awaiters;
 };
 
+// Adds references to reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
+  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
+  refCounted->addRef(count);
+}
+
+// Drops references from reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
+  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
+  refCounted->dropRef(count);
+}
+
 // Create a new `async.token` in not-ready state.
 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
-  AsyncToken *token = new AsyncToken;
+  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
   return token;
 }
 
 // Create a new `async.group` in empty state.
 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
-  AsyncGroup *group = new AsyncGroup;
+  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
   return group;
 }
 
@@ -59,23 +171,34 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
   std::unique_lock<std::mutex> lockToken(token->mu);
   std::unique_lock<std::mutex> lockGroup(group->mu);
 
+  // Get the rank of the token inside the group before we drop the reference.
+  int rank = group->rank.fetch_add(1);
   group->pendingTokens.fetch_add(1);
 
-  auto onTokenReady = [group]() {
+  auto onTokenReady = [group, token](bool dropRef) {
     // Run all group awaiters if it was the last token in the group.
     if (group->pendingTokens.fetch_sub(1) == 1) {
       group->cv.notify_all();
       for (auto &awaiter : group->awaiters)
         awaiter();
     }
+
+    // We no longer need the token or the group, drop references on them.
+    if (dropRef) {
+      group->dropRef();
+      token->dropRef();
+    }
   };
 
-  if (token->ready)
-    onTokenReady();
-  else
-    token->awaiters.push_back([onTokenReady]() { onTokenReady(); });
+  if (token->ready) {
+    onTokenReady(false);
+  } else {
+    group->addRef();
+    token->addRef();
+    token->awaiters.push_back([onTokenReady]() { onTokenReady(true); });
+  }
 
-  return group->rank.fetch_add(1);
+  return rank;
 }
 
 // Switches `async.token` to ready state and runs all awaiters.
@@ -85,6 +208,10 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
   token->cv.notify_all();
   for (auto &awaiter : token->awaiters)
     awaiter();
+
+  // Async tokens created with a ref count `2` to keep token alive until the
+  // async task completes. Drop this reference explicitly when token emplaced.
+  token->dropRef();
 }
 
 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
@@ -114,14 +241,18 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                      CoroResume resume) {
   std::unique_lock<std::mutex> lock(token->mu);
 
-  auto execute = [handle, resume]() {
+  auto execute = [handle, resume, token](bool dropRef) {
+    if (dropRef)
+      token->dropRef();
     mlirAsyncRuntimeExecute(handle, resume);
   };
 
-  if (token->ready)
-    execute();
-  else
-    token->awaiters.push_back([execute]() { execute(); });
+  if (token->ready) {
+    execute(false);
+  } else {
+    token->addRef();
+    token->awaiters.push_back([execute]() { execute(true); });
+  }
 }
 
 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
@@ -129,14 +260,18 @@ mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
                                           CoroResume resume) {
   std::unique_lock<std::mutex> lock(group->mu);
 
-  auto execute = [handle, resume]() {
+  auto execute = [handle, resume, group](bool dropRef) {
+    if (dropRef)
+      group->dropRef();
     mlirAsyncRuntimeExecute(handle, resume);
   };
 
-  if (group->pendingTokens == 0)
-    execute();
-  else
-    group->awaiters.push_back([execute]() { execute(); });
+  if (group->pendingTokens == 0) {
+    execute(false);
+  } else {
+    group->addRef();
+    group->awaiters.push_back([execute]() { execute(true); });
+  }
 }
 
 //===----------------------------------------------------------------------===//
index 1fd71a6..dadb28d 100644 (file)
@@ -1,5 +1,20 @@
 // RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s
 
+// CHECK-LABEL: reference_counting
+func @reference_counting(%arg0: !async.token) {
+  // CHECK: %[[C2:.*]] = constant 2 : i32
+  // CHECK: call @mlirAsyncRuntimeAddRef(%arg0, %[[C2]])
+  async.add_ref %arg0 {count = 2 : i32} : !async.token
+
+  // CHECK: %[[C1:.*]] = constant 1 : i32
+  // CHECK: call @mlirAsyncRuntimeDropRef(%arg0, %[[C1]])
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: execute_no_async_args
 func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
   // CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1)
diff --git a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
new file mode 100644 (file)
index 0000000..6500fa0
--- /dev/null
@@ -0,0 +1,113 @@
+// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s
+
+// CHECK-LABEL: @cancellable_operations_0
+func @cancellable_operations_0(%arg0: !async.token) {
+  // CHECK-NOT: async.add_ref
+  // CHECK-NOT: async.drop_ref
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @cancellable_operations_1
+func @cancellable_operations_1(%arg0: !async.token) {
+  // CHECK-NOT: async.add_ref
+  // CHECK: async.execute
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  async.execute [%arg0] {
+    // CHECK: async.drop_ref
+    async.drop_ref %arg0 {count = 1 : i32} : !async.token
+    // CHECK-NEXT: async.yield
+    async.yield
+  }
+  // CHECK-NOT: async.drop_ref
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @cancellable_operations_2
+func @cancellable_operations_2(%arg0: !async.token) {
+  // CHECK: async.await
+  // CHECK-NEXT: async.await
+  // CHECK-NEXT: async.await
+  // CHECK-NEXT: return
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  async.await %arg0 : !async.token
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  async.await %arg0 : !async.token
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  async.await %arg0 : !async.token
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  return
+}
+
+// CHECK-LABEL: @cancellable_operations_3
+func @cancellable_operations_3(%arg0: !async.token) {
+  // CHECK-NOT: add_ref
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  %token = async.execute {
+    async.await %arg0 : !async.token
+    // CHECK: async.drop_ref
+    async.drop_ref %arg0 {count = 1 : i32} : !async.token
+    async.yield
+  }
+  // CHECK-NOT: async.drop_ref
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: async.await
+  async.await %arg0 : !async.token
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @not_cancellable_operations_0
+func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) {
+  // It is unsafe to cancel `add_ref` / `drop_ref` pair because it is possible
+  // that the body of the `async.execute` operation will run before the await
+  // operation in the function body, and will destroy the `%arg0` token.
+  // CHECK: add_ref
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  %token = async.execute {
+    // CHECK: async.await
+    async.await %arg0 : !async.token
+    // CHECK: async.drop_ref
+    async.drop_ref %arg0 {count = 1 : i32} : !async.token
+    // CHECK: async.yield
+    async.yield
+  }
+  // CHECK: async.await
+  async.await %arg0 : !async.token
+  // CHECK: drop_ref
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @not_cancellable_operations_1
+func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) {
+  // Same reason as above, although `async.execute` is inside the nested
+  // region or "regular" opeation.
+  //
+  // NOTE: This test is not correct w.r.t. reference counting, and at runtime
+  // would leak %arg0 value if %arg1 is false. IR like this will not be
+  // constructed by automatic reference counting pass, because it would
+  // place `async.add_ref` right before the `async.execute` inside `scf.if`.
+  
+  // CHECK: async.add_ref
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  scf.if %arg1 {
+    %token = async.execute {
+      async.await %arg0 : !async.token
+      // CHECK: async.drop_ref
+      async.drop_ref %arg0 {count = 1 : i32} : !async.token
+      async.yield
+    }
+  }
+  // CHECK: async.await
+  async.await %arg0 : !async.token
+  // CHECK: async.drop_ref
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: return
+  return
+}
diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir b/mlir/test/Dialect/Async/async-ref-counting.mlir
new file mode 100644 (file)
index 0000000..504a18f
--- /dev/null
@@ -0,0 +1,253 @@
+// RUN: mlir-opt %s -async-ref-counting | FileCheck %s
+
+// CHECK-LABEL: @cond
+func private @cond() -> i1
+
+// CHECK-LABEL: @token_arg_no_uses
+func @token_arg_no_uses(%arg0: !async.token) {
+  // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+  return
+}
+
+// CHECK-LABEL: @token_arg_conditional_await
+func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) {
+  cond_br %arg1, ^bb1, ^bb2
+^bb1:
+  // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+  return
+^bb2:
+  // CHECK: async.await %arg0
+  // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+  async.await %arg0 : !async.token
+  return
+}
+
+// CHECK-LABEL: @token_no_uses
+func @token_no_uses() {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  %token = async.execute {
+    async.yield
+  }
+  return
+}
+
+// CHECK-LABEL: @token_return
+func @token_return() -> !async.token {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+  // CHECK: return %[[TOKEN]]
+  return %token : !async.token
+}
+
+// CHECK-LABEL: @token_await
+func @token_await() {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+  // CHECK: async.await %[[TOKEN]]
+  async.await %token : !async.token
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @token_await_and_return
+func @token_await_and_return() -> !async.token {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+  // CHECK: async.await %[[TOKEN]]
+  // CHECK-NOT: async.drop_ref
+  async.await %token : !async.token
+  // CHECK: return %[[TOKEN]]
+  return %token : !async.token
+}
+
+// CHECK-LABEL: @token_await_inside_scf_if
+func @token_await_inside_scf_if(%arg0: i1) {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+  // CHECK: scf.if %arg0 {
+  scf.if %arg0 {
+    // CHECK: async.await %[[TOKEN]]
+    async.await %token : !async.token
+  }
+  // CHECK: }
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @token_conditional_await
+func @token_conditional_await(%arg0: i1) {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+  cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  return
+^bb2:
+  // CHECK: async.await %[[TOKEN]]
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  async.await %token : !async.token
+  return
+}
+
+// CHECK-LABEL: @token_await_in_the_loop
+func @token_await_in_the_loop() {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+  br ^bb1
+^bb1:
+  // CHECK: async.await %[[TOKEN]]
+  async.await %token : !async.token
+  %0 = call @cond(): () -> (i1)
+  cond_br %0, ^bb1, ^bb2
+^bb2:
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  return
+}
+
+// CHECK-LABEL: @token_defined_in_the_loop
+func @token_defined_in_the_loop() {
+  br ^bb1
+^bb1:
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+  // CHECK: async.await %[[TOKEN]]
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  async.await %token : !async.token
+  %0 = call @cond(): () -> (i1)
+  cond_br %0, ^bb1, ^bb2
+^bb2:
+  return
+}
+
+// CHECK-LABEL: @token_capture
+func @token_capture() {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+
+  // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: %[[TOKEN_0:.*]] = async.execute
+  %token_0 = async.execute {
+    // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+    // CHECK-NEXT: async.yield
+    async.await %token : !async.token
+    async.yield
+  }
+  // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @token_nested_capture
+func @token_nested_capture() {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+
+  // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: %[[TOKEN_0:.*]] = async.execute
+  %token_0 = async.execute {
+    // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+    // CHECK: %[[TOKEN_1:.*]] = async.execute
+    %token_1 = async.execute {
+      // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+      // CHECK: %[[TOKEN_2:.*]] = async.execute
+      %token_2 = async.execute {
+        // CHECK: async.await %[[TOKEN]]
+        // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+        async.await %token : !async.token
+        async.yield
+      }
+      // CHECK: async.drop_ref %[[TOKEN_2]] {count = 1 : i32}
+      // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+      async.yield
+    }
+    // CHECK: async.drop_ref %[[TOKEN_1]] {count = 1 : i32}
+    // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+    async.yield
+  }
+  // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @token_dependency
+func @token_dependency() {
+  // CHECK: %[[TOKEN:.*]] = async.execute
+  %token = async.execute {
+    async.yield
+  }
+
+  // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: %[[TOKEN_0:.*]] = async.execute
+  %token_0 = async.execute[%token] {
+    // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+    // CHECK-NEXT: async.yield
+    async.yield
+  }
+
+  // CHECK: async.await %[[TOKEN]]
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  async.await %token : !async.token
+  // CHECK: async.await %[[TOKEN_0]]
+  // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+  async.await %token_0 : !async.token
+
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @value_operand
+func @value_operand() -> f32 {
+  // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute
+  %token, %results = async.execute -> !async.value<f32> {
+    %0 = constant 0.0 : f32
+    async.yield %0 : f32
+  }
+
+  // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: async.add_ref %[[RESULTS]] {count = 1 : i32}
+  // CHECK: %[[TOKEN_0:.*]] = async.execute
+  %token_0 = async.execute[%token](%results as %arg0 : !async.value<f32>)  {
+    // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+    // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32}
+    // CHECK: async.yield
+    async.yield
+  }
+
+  // CHECK: async.await %[[TOKEN]]
+  // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+  async.await %token : !async.token
+
+  // CHECK: async.await %[[TOKEN_0]]
+  // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+  async.await %token_0 : !async.token
+
+  // CHECK: async.await %[[RESULTS]]
+  // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32}
+  %0 = async.await %results : !async.value<f32>
+
+  // CHECK: return
+  return %0 : f32
+}
index a95be65..54dc673 100644 (file)
@@ -134,3 +134,17 @@ func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value<f32>)
   %3 = addi %1, %2 : index
   return %3 : index
 }
+
+// CHECK-LABEL: @add_ref
+func @add_ref(%arg0: !async.token) {
+  // CHECK: async.add_ref %arg0 {count = 1 : i32}
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  return
+}
+
+// CHECK-LABEL: @drop_ref
+func @drop_ref(%arg0: !async.token) {
+  // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  return
+}
index 87004ff..50f85ff 100644 (file)
@@ -1,4 +1,5 @@
-// RUN:   mlir-opt %s -convert-async-to-llvm                                   \
+// RUN:   mlir-opt %s -async-ref-counting                                      \
+// RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-std-to-llvm                                     \
 // RUN: | mlir-cpu-runner                                                      \
 // RUN:     -e main -entry-point-result=void -O0                               \
index fd0268e..5f06dd1 100644 (file)
@@ -1,4 +1,5 @@
-// RUN:   mlir-opt %s -convert-async-to-llvm                                   \
+// RUN:   mlir-opt %s -async-ref-counting                                      \
+// RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-linalg-to-loops                                 \
 // RUN:               -convert-linalg-to-llvm                                  \
 // RUN:               -convert-std-to-llvm                                     \