[ORC] Add TaskDispatch API and thread it through ExecutorProcessControl.
authorLang Hames <lhames@gmail.com>
Sat, 9 Oct 2021 00:12:06 +0000 (17:12 -0700)
committerLang Hames <lhames@gmail.com>
Mon, 11 Oct 2021 01:39:55 +0000 (18:39 -0700)
ExecutorProcessControl objects will now have a TaskDispatcher member which
should be used to dispatch work (in particular, handling incoming packets in
the implementation of remote EPC implementations like SimpleRemoteEPC).

The GenericNamedTask template can be used to wrap function objects that are
callable as 'void()' (along with an optional name to describe the task).
The makeGenericNamedTask functions can be used to create GenericNamedTask
instances without having to name the function object type.

In a future patch ExecutionSession will be updated to use the
ExecutorProcessControl's dispatcher, instead of its DispatchTaskFunction.

12 files changed:
llvm/include/llvm/ExecutionEngine/Orc/Core.h
llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
llvm/include/llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h
llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h [new file with mode: 0644]
llvm/lib/ExecutionEngine/Orc/CMakeLists.txt
llvm/lib/ExecutionEngine/Orc/Core.cpp
llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp [new file with mode: 0644]
llvm/tools/lli/lli.cpp
llvm/tools/llvm-jitlink/llvm-jitlink.cpp
llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt
llvm/unittests/ExecutionEngine/Orc/TaskDispatchTest.cpp [new file with mode: 0644]

index a40c78f..d2761d6 100644 (file)
@@ -21,6 +21,7 @@
 #include "llvm/ExecutionEngine/JITSymbol.h"
 #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
 #include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h"
+#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ExtensibleRTTI.h"
 
@@ -1254,21 +1255,6 @@ public:
                          const DenseMap<JITDylib *, SymbolLookupSet> &InitSyms);
 };
 
-/// Represents an abstract task for ORC to run.
-class Task : public RTTIExtends<Task, RTTIRoot> {
-public:
-  static char ID;
-
-  /// Description of the task to be performed. Used for logging.
-  virtual void printDescription(raw_ostream &OS) = 0;
-
-  /// Run the task.
-  virtual void run() = 0;
-
-private:
-  void anchor() override;
-};
-
 /// A materialization task.
 class MaterializationTask : public RTTIExtends<MaterializationTask, Task> {
 public:
index 7e05bef..147d1d3 100644 (file)
@@ -20,6 +20,7 @@
 #include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h"
 #include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h"
 #include "llvm/ExecutionEngine/Orc/SymbolStringPool.h"
+#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
 #include "llvm/Support/DynamicLibrary.h"
 #include "llvm/Support/MSVCErrorWorkarounds.h"
 
@@ -121,6 +122,10 @@ public:
     ExecutorAddr JITDispatchContext;
   };
 
+  ExecutorProcessControl(std::shared_ptr<SymbolStringPool> SSP,
+                         std::unique_ptr<TaskDispatcher> D)
+    : SSP(std::move(SSP)), D(std::move(D)) {}
+
   virtual ~ExecutorProcessControl();
 
   /// Return the ExecutionSession associated with this instance.
@@ -136,6 +141,8 @@ public:
   /// Return a shared pointer to the SymbolStringPool for this instance.
   std::shared_ptr<SymbolStringPool> getSymbolStringPool() const { return SSP; }
 
+  TaskDispatcher &getDispatcher() { return *D; }
+
   /// Return the Triple for the target process.
   const Triple &getTargetTriple() const { return TargetTriple; }
 
@@ -264,10 +271,9 @@ public:
   virtual Error disconnect() = 0;
 
 protected:
-  ExecutorProcessControl(std::shared_ptr<SymbolStringPool> SSP)
-      : SSP(std::move(SSP)) {}
 
   std::shared_ptr<SymbolStringPool> SSP;
+  std::unique_ptr<TaskDispatcher> D;
   ExecutionSession *ES = nullptr;
   Triple TargetTriple;
   unsigned PageSize = 0;
@@ -284,9 +290,12 @@ class UnsupportedExecutorProcessControl : public ExecutorProcessControl {
 public:
   UnsupportedExecutorProcessControl(
       std::shared_ptr<SymbolStringPool> SSP = nullptr,
+      std::unique_ptr<TaskDispatcher> D = nullptr,
       const std::string &TT = "", unsigned PageSize = 0)
       : ExecutorProcessControl(SSP ? std::move(SSP)
-                                   : std::make_shared<SymbolStringPool>()) {
+                               : std::make_shared<SymbolStringPool>(),
+                               D ? std::move(D)
+                               : std::make_unique<InPlaceTaskDispatcher>()) {
     this->TargetTriple = Triple(TT);
     this->PageSize = PageSize;
   }
@@ -320,8 +329,9 @@ class SelfExecutorProcessControl
       private ExecutorProcessControl::MemoryAccess {
 public:
   SelfExecutorProcessControl(
-      std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple,
-      unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr);
+      std::shared_ptr<SymbolStringPool> SSP, std::unique_ptr<TaskDispatcher> D,
+      Triple TargetTriple, unsigned PageSize,
+      std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr);
 
   /// Create a SelfExecutorProcessControl with the given symbol string pool and
   /// memory manager.
@@ -330,6 +340,7 @@ public:
   /// be created and used by default.
   static Expected<std::unique_ptr<SelfExecutorProcessControl>>
   Create(std::shared_ptr<SymbolStringPool> SSP = nullptr,
+         std::unique_ptr<TaskDispatcher> D = nullptr,
          std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr = nullptr);
 
   Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) override;
index a4b5ca8..55449f9 100644 (file)
@@ -34,9 +34,11 @@ public:
   /// Create a SimpleRemoteEPC using the given transport type and args.
   template <typename TransportT, typename... TransportTCtorArgTs>
   static Expected<std::unique_ptr<SimpleRemoteEPC>>
-  Create(TransportTCtorArgTs &&...TransportTCtorArgs) {
+  Create(std::unique_ptr<TaskDispatcher> D,
+         TransportTCtorArgTs &&...TransportTCtorArgs) {
     std::unique_ptr<SimpleRemoteEPC> SREPC(
-        new SimpleRemoteEPC(std::make_shared<SymbolStringPool>()));
+                                           new SimpleRemoteEPC(std::make_shared<SymbolStringPool>(),
+                                                               std::move(D)));
     auto T = TransportT::Create(
         *SREPC, std::forward<TransportTCtorArgTs>(TransportTCtorArgs)...);
     if (!T)
@@ -79,8 +81,9 @@ protected:
   virtual Expected<std::unique_ptr<MemoryAccess>> createMemoryAccess();
 
 private:
-  SimpleRemoteEPC(std::shared_ptr<SymbolStringPool> SSP)
-      : ExecutorProcessControl(std::move(SSP)) {}
+  SimpleRemoteEPC(std::shared_ptr<SymbolStringPool> SSP,
+                  std::unique_ptr<TaskDispatcher> D)
+    : ExecutorProcessControl(std::move(SSP), std::move(D)) {}
 
   Error sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
                     ExecutorAddr TagAddr, ArrayRef<char> ArgBytes);
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h b/llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h
new file mode 100644 (file)
index 0000000..7bd81b8
--- /dev/null
@@ -0,0 +1,129 @@
+//===--------- TaskDispatch.h - ORC task dispatch utils ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Task and TaskDispatch classes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H
+#define LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H
+
+#include "llvm/Config/llvm-config.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ExtensibleRTTI.h"
+
+#include <string>
+
+#if LLVM_ENABLE_THREADS
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+#endif
+
+namespace llvm {
+namespace orc {
+
+/// Represents an abstract task for ORC to run.
+class Task : public RTTIExtends<Task, RTTIRoot> {
+public:
+  static char ID;
+
+  virtual ~Task() {}
+
+  /// Description of the task to be performed. Used for logging.
+  virtual void printDescription(raw_ostream &OS) = 0;
+
+  /// Run the task.
+  virtual void run() = 0;
+
+private:
+  void anchor() override;
+};
+
+/// Base class for generic tasks.
+class GenericNamedTask : public RTTIExtends<GenericNamedTask, Task> {
+public:
+  static char ID;
+  static const char *DefaultDescription;
+};
+
+/// Generic task implementation.
+template <typename FnT> class GenericNamedTaskImpl : public GenericNamedTask {
+public:
+  GenericNamedTaskImpl(FnT &&Fn, std::string DescBuffer)
+      : Fn(std::forward<FnT>(Fn)), Desc(DescBuffer.c_str()),
+        DescBuffer(std::move(DescBuffer)) {}
+  GenericNamedTaskImpl(FnT &&Fn, const char *Desc)
+      : Fn(std::forward<FnT>(Fn)), Desc(Desc) {
+    assert(Desc && "Description cannot be null");
+  }
+  void printDescription(raw_ostream &OS) override { OS << Desc; }
+  void run() override { Fn(); }
+
+private:
+  FnT Fn;
+  const char *Desc;
+  std::string DescBuffer;
+};
+
+/// Create a generic named task from a std::string description.
+template <typename FnT>
+std::unique_ptr<GenericNamedTask> makeGenericNamedTask(FnT &&Fn,
+                                                       std::string Desc) {
+  return std::make_unique<GenericNamedTaskImpl<FnT>>(std::forward<FnT>(Fn),
+                                                     std::move(Desc));
+}
+
+/// Create a generic named task from a const char * description.
+template <typename FnT>
+std::unique_ptr<GenericNamedTask>
+makeGenericNamedTask(FnT &&Fn, const char *Desc = nullptr) {
+  if (!Desc)
+    Desc = GenericNamedTask::DefaultDescription;
+  return std::make_unique<GenericNamedTaskImpl<FnT>>(std::forward<FnT>(Fn),
+                                                     Desc);
+}
+
+/// Abstract base for classes that dispatch ORC Tasks.
+class TaskDispatcher {
+public:
+  virtual ~TaskDispatcher();
+
+  /// Run the given task.
+  virtual void dispatch(std::unique_ptr<Task> T) = 0;
+
+  /// Called by ExecutionSession. Waits until all tasks have completed.
+  virtual void shutdown() = 0;
+};
+
+/// Runs all tasks on the current thread.
+class InPlaceTaskDispatcher : public TaskDispatcher {
+public:
+  void dispatch(std::unique_ptr<Task> T) override;
+  void shutdown() override;
+};
+
+#if LLVM_ENABLE_THREADS
+
+class DynamicThreadPoolTaskDispatcher : public TaskDispatcher {
+public:
+  void dispatch(std::unique_ptr<Task> T) override;
+  void shutdown() override;
+private:
+  std::mutex DispatchMutex;
+  bool Running = true;
+  size_t Outstanding = 0;
+  std::condition_variable OutstandingCV;
+};
+
+#endif // LLVM_ENABLE_THREADS
+
+} // End namespace orc
+} // End namespace llvm
+
+#endif // LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H
index 8d69cc3..390cfe6 100644 (file)
@@ -32,6 +32,7 @@ add_llvm_component_library(LLVMOrcJIT
   Speculation.cpp
   SpeculateAnalyses.cpp
   ExecutorProcessControl.cpp
+  TaskDispatch.cpp
   ThreadSafeModule.cpp
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc
index 8ec3bf6..c29593b 100644 (file)
@@ -29,7 +29,6 @@ char SymbolsNotFound::ID = 0;
 char SymbolsCouldNotBeRemoved::ID = 0;
 char MissingSymbolDefinitions::ID = 0;
 char UnexpectedSymbolDefinitions::ID = 0;
-char Task::ID = 0;
 char MaterializationTask::ID = 0;
 
 RegisterDependenciesFunction NoDependenciesToRegister =
@@ -1799,8 +1798,6 @@ void Platform::lookupInitSymbolsAsync(
   }
 }
 
-void Task::anchor() {}
-
 void MaterializationTask::printDescription(raw_ostream &OS) {
   OS << "Materialization task: " << MU->getName() << " in "
      << MR->getTargetJITDylib().getName();
index dd57fbd..1485789 100644 (file)
@@ -24,9 +24,10 @@ ExecutorProcessControl::MemoryAccess::~MemoryAccess() {}
 ExecutorProcessControl::~ExecutorProcessControl() {}
 
 SelfExecutorProcessControl::SelfExecutorProcessControl(
-    std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple,
-    unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr)
-    : ExecutorProcessControl(std::move(SSP)) {
+    std::shared_ptr<SymbolStringPool> SSP, std::unique_ptr<TaskDispatcher> D,
+    Triple TargetTriple, unsigned PageSize,
+    std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr)
+    : ExecutorProcessControl(std::move(SSP), std::move(D)) {
 
   OwnedMemMgr = std::move(MemMgr);
   if (!OwnedMemMgr)
@@ -45,11 +46,20 @@ SelfExecutorProcessControl::SelfExecutorProcessControl(
 Expected<std::unique_ptr<SelfExecutorProcessControl>>
 SelfExecutorProcessControl::Create(
     std::shared_ptr<SymbolStringPool> SSP,
+    std::unique_ptr<TaskDispatcher> D,
     std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr) {
 
   if (!SSP)
     SSP = std::make_shared<SymbolStringPool>();
 
+  if (!D) {
+#if LLVM_ENABLE_THREADS
+    D = std::make_unique<DynamicThreadPoolTaskDispatcher>();
+#else
+    D = std::make_unique<InPlaceTaskDispatcher>();
+#endif
+  }
+
   auto PageSize = sys::Process::getPageSize();
   if (!PageSize)
     return PageSize.takeError();
@@ -57,7 +67,8 @@ SelfExecutorProcessControl::Create(
   Triple TT(sys::getProcessTriple());
 
   return std::make_unique<SelfExecutorProcessControl>(
-      std::move(SSP), std::move(TT), *PageSize, std::move(MemMgr));
+      std::move(SSP), std::move(D), std::move(TT), *PageSize,
+      std::move(MemMgr));
 }
 
 Expected<tpctypes::DylibHandle>
diff --git a/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp b/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp
new file mode 100644 (file)
index 0000000..111c84e
--- /dev/null
@@ -0,0 +1,48 @@
+//===------------ TaskDispatch.cpp - ORC task dispatch utils --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
+
+namespace llvm {
+namespace orc {
+
+char Task::ID = 0;
+char GenericNamedTask::ID = 0;
+const char *GenericNamedTask::DefaultDescription = "Generic Task";
+
+void Task::anchor() {}
+TaskDispatcher::~TaskDispatcher() {}
+
+void InPlaceTaskDispatcher::dispatch(std::unique_ptr<Task> T) { T->run(); }
+
+void InPlaceTaskDispatcher::shutdown() {}
+
+#if LLVM_ENABLE_THREADS
+void DynamicThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) {
+  {
+    std::lock_guard<std::mutex> Lock(DispatchMutex);
+    ++Outstanding;
+  }
+
+  std::thread([this, T = std::move(T)]() mutable {
+    T->run();
+    std::lock_guard<std::mutex> Lock(DispatchMutex);
+    --Outstanding;
+    OutstandingCV.notify_all();
+  }).detach();
+}
+
+void DynamicThreadPoolTaskDispatcher::shutdown() {
+  std::unique_lock<std::mutex> Lock(DispatchMutex);
+  Running = false;
+  OutstandingCV.wait(Lock, [this]() { return Outstanding == 0; });
+}
+#endif
+
+} // namespace orc
+} // namespace llvm
index 385ba2a..5a05dd7 100644 (file)
@@ -1150,6 +1150,7 @@ Expected<std::unique_ptr<orc::ExecutorProcessControl>> launchRemote() {
 
   // Return a SimpleRemoteEPC instance connected to our end of the pipes.
   return orc::SimpleRemoteEPC::Create<orc::FDSimpleRemoteEPCTransport>(
+      std::make_unique<llvm::orc::InPlaceTaskDispatcher>(),
       PipeFD[1][0], PipeFD[0][1]);
 #endif
 }
index 00dab88..7eb1da5 100644 (file)
@@ -718,6 +718,7 @@ static Expected<std::unique_ptr<ExecutorProcessControl>> launchExecutor() {
   close(FromExecutor[WriteEnd]);
 
   return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(
+      std::make_unique<DynamicThreadPoolTaskDispatcher>(),
       FromExecutor[ReadEnd], ToExecutor[WriteEnd]);
 #endif
 }
@@ -795,7 +796,8 @@ static Expected<std::unique_ptr<ExecutorProcessControl>> connectToExecutor() {
   if (!SockFD)
     return SockFD.takeError();
 
-  return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(*SockFD, *SockFD);
+  return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(
+      std::make_unique<DynamicThreadPoolTaskDispatcher>(), *SockFD, *SockFD);
 #endif
 }
 
@@ -832,8 +834,9 @@ Expected<std::unique_ptr<Session>> Session::Create(Triple TT) {
     if (!PageSize)
       return PageSize.takeError();
     EPC = std::make_unique<SelfExecutorProcessControl>(
-        std::make_shared<SymbolStringPool>(), std::move(TT), *PageSize,
-        createMemoryManager());
+        std::make_shared<SymbolStringPool>(),
+        std::make_unique<DynamicThreadPoolTaskDispatcher>(),
+        std::move(TT), *PageSize, createMemoryManager());
   }
 
   Error Err = Error::success();
index e8904db..404afac 100644 (file)
@@ -32,6 +32,7 @@ add_llvm_unittest(OrcJITTests
   SimpleExecutorMemoryManagerTest.cpp
   SimplePackedSerializationTest.cpp
   SymbolStringPoolTest.cpp
+  TaskDispatchTest.cpp
   ThreadSafeModuleTest.cpp
   WrapperFunctionUtilsTest.cpp
   )
diff --git a/llvm/unittests/ExecutionEngine/Orc/TaskDispatchTest.cpp b/llvm/unittests/ExecutionEngine/Orc/TaskDispatchTest.cpp
new file mode 100644 (file)
index 0000000..60a5e75
--- /dev/null
@@ -0,0 +1,33 @@
+//===----------- TaskDispatchTest.cpp - Test TaskDispatch APIs ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
+#include "gtest/gtest.h"
+
+#include <future>
+
+using namespace llvm;
+using namespace llvm::orc;
+
+TEST(InPlaceTaskDispatchTest, GenericNamedTask) {
+  auto D = std::make_unique<InPlaceTaskDispatcher>();
+  bool B = false;
+  D->dispatch(makeGenericNamedTask([&]() { B = true; }));
+  EXPECT_TRUE(B);
+}
+
+#if LLVM_ENABLE_THREADS
+TEST(DynamicThreadPoolDispatchTest, GenericNamedTask) {
+  auto D = std::make_unique<DynamicThreadPoolTaskDispatcher>();
+  std::promise<bool> P;
+  auto F = P.get_future();
+  D->dispatch(makeGenericNamedTask(
+      [P = std::move(P)]() mutable { P.set_value(true); }));
+  EXPECT_TRUE(F.get());
+}
+#endif