[mlir:async] Fix data races in AsyncRuntime
authorEugene Zhulenev <ezhulenev@google.com>
Wed, 20 Jan 2021 13:17:12 +0000 (05:17 -0800)
committerEugene Zhulenev <ezhulenev@google.com>
Wed, 20 Jan 2021 21:23:39 +0000 (13:23 -0800)
Resumed coroutine potentially can deallocate the token/value/group and destroy the mutex before the std::unique_ptr destructor.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D95037

mlir/lib/ExecutionEngine/AsyncRuntime.cpp

index a20bd6d..e38ebf9 100644 (file)
@@ -136,13 +136,14 @@ struct AsyncToken : public RefCounted {
   // asynchronously executed task. If the caller immediately will drop its
   // reference we must ensure that the token will be alive until the
   // asynchronous operation is completed.
-  AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
+  AsyncToken(AsyncRuntime *runtime)
+      : RefCounted(runtime, /*count=*/2), ready(false) {}
 
-  // Internal state below guarded by a mutex.
+  std::atomic<bool> ready;
+
+  // Pending awaiters are guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
-
-  bool ready = false;
   std::vector<std::function<void()>> awaiters;
 };
 
@@ -152,17 +153,17 @@ struct AsyncToken : public RefCounted {
 struct AsyncValue : public RefCounted {
   // AsyncValue similar to an AsyncToken created with a reference count of 2.
   AsyncValue(AsyncRuntime *runtime, int32_t size)
-      : RefCounted(runtime, /*count=*/2), storage(size) {}
-
-  // Internal state below guarded by a mutex.
-  std::mutex mu;
-  std::condition_variable cv;
+      : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {}
 
-  bool ready = false;
-  std::vector<std::function<void()>> awaiters;
+  std::atomic<bool> ready;
 
   // Use vector of bytes to store async value payload.
   std::vector<int8_t> storage;
+
+  // Pending awaiters are guarded by a mutex.
+  std::mutex mu;
+  std::condition_variable cv;
+  std::vector<std::function<void()>> awaiters;
 };
 
 // Async group provides a mechanism to group together multiple async tokens or
@@ -175,10 +176,9 @@ struct AsyncGroup : public RefCounted {
   std::atomic<int> pendingTokens;
   std::atomic<int> rank;
 
-  // Internal state below guarded by a mutex.
+  // Pending awaiters are guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
-
   std::vector<std::function<void()>> awaiters;
 };
 
@@ -291,13 +291,13 @@ extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
   std::unique_lock<std::mutex> lock(token->mu);
   if (!token->ready)
-    token->cv.wait(lock, [token] { return token->ready; });
+    token->cv.wait(lock, [token] { return token->ready.load(); });
 }
 
 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
   std::unique_lock<std::mutex> lock(value->mu);
   if (!value->ready)
-    value->cv.wait(lock, [value] { return value->ready; });
+    value->cv.wait(lock, [value] { return value->ready.load(); });
 }
 
 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
@@ -319,34 +319,37 @@ extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                      CoroHandle handle,
                                                      CoroResume resume) {
-  std::unique_lock<std::mutex> lock(token->mu);
   auto execute = [handle, resume]() { (*resume)(handle); };
-  if (token->ready)
+  if (token->ready) {
     execute();
-  else
+  } else {
+    std::unique_lock<std::mutex> lock(token->mu);
     token->awaiters.push_back([execute]() { execute(); });
+  }
 }
 
 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                      CoroHandle handle,
                                                      CoroResume resume) {
-  std::unique_lock<std::mutex> lock(value->mu);
   auto execute = [handle, resume]() { (*resume)(handle); };
-  if (value->ready)
+  if (value->ready) {
     execute();
-  else
+  } else {
+    std::unique_lock<std::mutex> lock(value->mu);
     value->awaiters.push_back([execute]() { execute(); });
+  }
 }
 
 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                           CoroHandle handle,
                                                           CoroResume resume) {
-  std::unique_lock<std::mutex> lock(group->mu);
   auto execute = [handle, resume]() { (*resume)(handle); };
-  if (group->pendingTokens == 0)
+  if (group->pendingTokens == 0) {
     execute();
-  else
+  } else {
+    std::unique_lock<std::mutex> lock(group->mu);
     group->awaiters.push_back([execute]() { execute(); });
+  }
 }
 
 //===----------------------------------------------------------------------===//