[ORC] Fix SimpleRemoteEPC data races.
author: Lang Hames <lhames@gmail.com>
Mon, 27 Sep 2021 00:56:47 +0000 (17:56 -0700)
committer: Lang Hames <lhames@gmail.com>
Mon, 27 Sep 2021 01:11:48 +0000 (18:11 -0700)
Adds a 'start' method to SimpleRemoteEPCTransport to defer transport startup
until the client has been configured. This avoids races on client members if the
first message arrives while the client is being configured.

Also fixes races on the file descriptors in FDSimpleRemoteEPCTransport.

llvm/include/llvm/ExecutionEngine/Orc/Shared/SimpleRemoteEPCUtils.h
llvm/include/llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h
llvm/include/llvm/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.h
llvm/lib/ExecutionEngine/Orc/Shared/SimpleRemoteEPCUtils.cpp
llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp

index 9945a75..f3dcb0f 100644 (file)
@@ -21,6 +21,7 @@
 #include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h"
 #include "llvm/Support/Error.h"
 
+#include <atomic>
 #include <mutex>
 #include <string>
 #include <thread>
@@ -77,6 +78,13 @@ class SimpleRemoteEPCTransport {
 public:
   virtual ~SimpleRemoteEPCTransport();
 
+  /// Called during setup of the client to indicate that the client is ready
+  /// to receive messages.
+  ///
+  /// Transport objects should not access the client until this method is
+  /// called.
+  virtual Error start() = 0;
+
   /// Send a SimpleRemoteEPC message.
   ///
   /// This function may be called concurrently. Subclasses should implement
@@ -107,6 +115,8 @@ public:
 
   ~FDSimpleRemoteEPCTransport() override;
 
+  Error start() override;
+
   Error sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
                     ExecutorAddr TagAddr, ArrayRef<char> ArgBytes) override;
 
@@ -114,7 +124,8 @@ public:
 
 private:
   FDSimpleRemoteEPCTransport(SimpleRemoteEPCTransportClient &C, int InFD,
-                             int OutFD);
+                             int OutFD)
+      : C(C), InFD(InFD), OutFD(OutFD) {}
 
   Error readBytes(char *Dst, size_t Size, bool *IsEOF = nullptr);
   int writeBytes(const char *Src, size_t Size);
@@ -124,6 +135,7 @@ private:
   SimpleRemoteEPCTransportClient &C;
   std::thread ListenerThread;
   int InFD, OutFD;
+  std::atomic<bool> Disconnected{false};
 };
 
 struct RemoteSymbolLookupSetElement {
index 839252f..4fc7a58 100644 (file)
@@ -37,21 +37,12 @@ public:
   Create(TransportTCtorArgTs &&...TransportTCtorArgs) {
     std::unique_ptr<SimpleRemoteEPC> SREPC(
         new SimpleRemoteEPC(std::make_shared<SymbolStringPool>()));
-
-    // Prepare for setup packet.
-    std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
-    auto EIF = EIP.get_future();
-    SREPC->prepareToReceiveSetupMessage(EIP);
     auto T = TransportT::Create(
         *SREPC, std::forward<TransportTCtorArgTs>(TransportTCtorArgs)...);
     if (!T)
       return T.takeError();
-    auto EI = EIF.get();
-    if (!EI) {
-      (*T)->disconnect();
-      return EI.takeError();
-    }
-    if (auto Err = SREPC->setup(std::move(*T), std::move(*EI)))
+    SREPC->T = std::move(*T);
+    if (auto Err = SREPC->setup())
       return joinErrors(std::move(Err), SREPC->disconnect());
     return std::move(SREPC);
   }
@@ -96,10 +87,7 @@ private:
 
   Error handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
                     SimpleRemoteEPCArgBytesVector ArgBytes);
-  void prepareToReceiveSetupMessage(
-      std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> &ExecInfoP);
-  Error setup(std::unique_ptr<SimpleRemoteEPCTransport> T,
-              SimpleRemoteEPCExecutorInfo EI);
+  Error setup();
 
   Error handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
                      SimpleRemoteEPCArgBytesVector ArgBytes);
index 2014a97..38c413d 100644 (file)
@@ -104,6 +104,8 @@ public:
     if (!T)
       return T.takeError();
     Server->T = std::move(*T);
+    if (auto Err = Server->T->start())
+      return std::move(Err);
 
     // If transport creation succeeds then start up services.
     Server->Services = std::move(S.services());
index aaa1a4f..62f4ff8 100644 (file)
@@ -69,18 +69,18 @@ FDSimpleRemoteEPCTransport::Create(SimpleRemoteEPCTransportClient &C, int InFD,
 #endif
 }
 
-FDSimpleRemoteEPCTransport::FDSimpleRemoteEPCTransport(
-    SimpleRemoteEPCTransportClient &C, int InFD, int OutFD)
-    : C(C), InFD(InFD), OutFD(OutFD) {
+FDSimpleRemoteEPCTransport::~FDSimpleRemoteEPCTransport() {
 #if LLVM_ENABLE_THREADS
-  ListenerThread = std::thread([this]() { listenLoop(); });
+  ListenerThread.join();
 #endif
 }
 
-FDSimpleRemoteEPCTransport::~FDSimpleRemoteEPCTransport() {
+Error FDSimpleRemoteEPCTransport::start() {
 #if LLVM_ENABLE_THREADS
-  ListenerThread.join();
+  ListenerThread = std::thread([this]() { listenLoop(); });
+  return Error::success();
 #endif
+  llvm_unreachable("Should not be called with LLVM_ENABLE_THREADS=Off");
 }
 
 Error FDSimpleRemoteEPCTransport::sendMessage(SimpleRemoteEPCOpcode OpC,
@@ -98,7 +98,7 @@ Error FDSimpleRemoteEPCTransport::sendMessage(SimpleRemoteEPCOpcode OpC,
       TagAddr.getValue();
 
   std::lock_guard<std::mutex> Lock(M);
-  if (OutFD == -1)
+  if (Disconnected)
     return make_error<StringError>("FD-transport disconnected",
                                    inconvertibleErrorCode());
   if (int ErrNo = writeBytes(HeaderBuffer, FDMsgHeader::Size))
@@ -109,28 +109,21 @@ Error FDSimpleRemoteEPCTransport::sendMessage(SimpleRemoteEPCOpcode OpC,
 }
 
 void FDSimpleRemoteEPCTransport::disconnect() {
-  int CloseInFD = -1, CloseOutFD = -1;
-  {
-    std::lock_guard<std::mutex> Lock(M);
-    std::swap(InFD, CloseInFD);
-    std::swap(OutFD, CloseOutFD);
-  }
+  if (Disconnected)
+    return; // Return if already disconnected.
 
-  // If CloseOutFD == CloseInFD then set CloseOutFD to -1 up-front so that we
-  // don't double-close.
-  if (CloseOutFD == CloseInFD)
-    CloseOutFD = -1;
+  Disconnected = true;
+  bool CloseOutFD = InFD != OutFD;
 
   // Close InFD.
-  if (CloseInFD != -1)
-    while (close(CloseInFD) == -1) {
-      if (errno == EBADF)
-        break;
-    }
+  while (close(InFD) == -1) {
+    if (errno == EBADF)
+      break;
+  }
 
   // Close OutFD.
-  if (CloseOutFD != -1) {
-    while (close(CloseOutFD) == -1) {
+  if (CloseOutFD) {
+    while (close(OutFD) == -1) {
       if (errno == EBADF)
         break;
     }
@@ -160,7 +153,7 @@ Error FDSimpleRemoteEPCTransport::readBytes(char *Dst, size_t Size,
         continue;
       else {
         std::lock_guard<std::mutex> Lock(M);
-        if (InFD == -1 && IsEOF) { // Disconnected locally. Pretend this is EOF.
+        if (Disconnected && IsEOF) { // disconnect called,  pretend this is EOF.
           *IsEOF = true;
           return Error::success();
         }
index 7f47f19..9f6ecce 100644 (file)
@@ -238,12 +238,17 @@ Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
   return Error::success();
 }
 
-void SimpleRemoteEPC::prepareToReceiveSetupMessage(
-    std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> &ExecInfoP) {
+Error SimpleRemoteEPC::setup() {
+  using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
+
+  std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
+  auto EIF = EIP.get_future();
+
+  // Prepare a handler for the setup packet.
   PendingCallWrapperResults[0] =
       [&](shared::WrapperFunctionResult SetupMsgBytes) {
         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
-          ExecInfoP.set_value(
+          EIP.set_value(
               make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
           return;
         }
@@ -252,29 +257,35 @@ void SimpleRemoteEPC::prepareToReceiveSetupMessage(
         shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
         SimpleRemoteEPCExecutorInfo EI;
         if (SPSSerialize::deserialize(IB, EI))
-          ExecInfoP.set_value(EI);
+          EIP.set_value(EI);
         else
-          ExecInfoP.set_value(make_error<StringError>(
+          EIP.set_value(make_error<StringError>(
               "Could not deserialize setup message", inconvertibleErrorCode()));
       };
-}
 
-Error SimpleRemoteEPC::setup(std::unique_ptr<SimpleRemoteEPCTransport> T,
-                             SimpleRemoteEPCExecutorInfo EI) {
-  using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
+  // Start the transport.
+  if (auto Err = T->start())
+    return Err;
+
+  // Wait for setup packet to arrive.
+  auto EI = EIF.get();
+  if (!EI) {
+    T->disconnect();
+    return EI.takeError();
+  }
+
   LLVM_DEBUG({
     dbgs() << "SimpleRemoteEPC received setup message:\n"
-           << "  Triple: " << EI.TargetTriple << "\n"
-           << "  Page size: " << EI.PageSize << "\n"
+           << "  Triple: " << EI->TargetTriple << "\n"
+           << "  Page size: " << EI->PageSize << "\n"
            << "  Bootstrap symbols:\n";
-    for (const auto &KV : EI.BootstrapSymbols)
+    for (const auto &KV : EI->BootstrapSymbols)
       dbgs() << "    " << KV.first() << ": "
              << formatv("{0:x16}", KV.second.getValue()) << "\n";
   });
-  this->T = std::move(T);
-  TargetTriple = Triple(EI.TargetTriple);
-  PageSize = EI.PageSize;
-  BootstrapSymbols = std::move(EI.BootstrapSymbols);
+  TargetTriple = Triple(EI->TargetTriple);
+  PageSize = EI->PageSize;
+  BootstrapSymbols = std::move(EI->BootstrapSymbols);
 
   if (auto Err = getBootstrapSymbols(
           {{JDI.JITDispatchContext, ExecutorSessionObjectName},