[ORC] Add wrapper-function support methods to ExecutorProcessControl.
authorLang Hames <lhames@gmail.com>
Sat, 19 Jun 2021 07:36:47 +0000 (17:36 +1000)
committerLang Hames <lhames@gmail.com>
Thu, 1 Jul 2021 08:21:49 +0000 (18:21 +1000)
Adds support for both synchronous and asynchronous calls to wrapper functions
using SPS (Simple Packed Serialization). Also adds support for wrapping
functions on the JIT side in SPS-based wrappers that can be called from the
executor.

These new methods simplify calls between the JIT and Executor, and will be used
in upcoming ORC runtime patches to enable communication between ORC and the
runtime.

llvm/include/llvm/ExecutionEngine/Orc/Core.h
llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
llvm/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h
llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt
llvm/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp [new file with mode: 0644]
llvm/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp

index ae82691..42bcffd 100644 (file)
@@ -216,6 +216,18 @@ public:
       add(Name, Flags);
   }
 
+  /// Construct a SymbolLookupSet from DenseMap keys.
+  template <typename KeyT>
+  static SymbolLookupSet
+  fromMapKeys(const DenseMap<SymbolStringPtr, KeyT> &M,
+              SymbolLookupFlags Flags = SymbolLookupFlags::RequiredSymbol) {
+    SymbolLookupSet Result;
+    Result.Symbols.reserve(M.size());
+    for (const auto &KV : M)
+      Result.add(KV.first, Flags);
+    return Result;
+  }
+
   /// Add an element to the set. The client is responsible for checking that
   /// duplicates are not added.
   SymbolLookupSet &
index 7969a83..566637e 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/Support/MSVCErrorWorkarounds.h"
 
 #include <future>
+#include <mutex>
 #include <vector>
 
 namespace llvm {
@@ -32,6 +33,19 @@ namespace orc {
 /// ExecutorProcessControl supports interaction with a JIT target process.
 class ExecutorProcessControl {
 public:
+  /// Sender to return the result of a WrapperFunction executed in the JIT.
+  using SendResultFunction =
+      unique_function<void(shared::WrapperFunctionResult)>;
+
+  /// An asynchronous wrapper-function.
+  using AsyncWrapperFunction = unique_function<void(
+      SendResultFunction SendResult, const char *ArgData, size_t ArgSize)>;
+
+  /// A map associating tag names with asynchronous wrapper function
+  /// implementations in the JIT.
+  using WrapperFunctionAssociationMap =
+      DenseMap<SymbolStringPtr, AsyncWrapperFunction>;
+
   /// APIs for manipulating memory in the target process.
   class MemoryAccess {
   public:
@@ -138,14 +152,91 @@ public:
   virtual Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
                                       ArrayRef<std::string> Args) = 0;
 
-  /// Run a wrapper function in the executor.
+  /// Run a wrapper function in the executor (async version).
+  ///
+  /// The wrapper function should be callable as:
+  ///
+  /// \code{.cpp}
+  ///   CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size);
+  /// \endcode{.cpp}
+  ///
+  /// The given OnComplete function will be called to return the result.
+  virtual void runWrapperAsync(SendResultFunction OnComplete,
+                               JITTargetAddress WrapperFnAddr,
+                               ArrayRef<char> ArgBuffer) = 0;
+
+  /// Run a wrapper function in the executor. The wrapper function should be
+  /// callable as:
   ///
   /// \code{.cpp}
   ///   CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size);
   /// \endcode{.cpp}
+  shared::WrapperFunctionResult runWrapper(JITTargetAddress WrapperFnAddr,
+                                           ArrayRef<char> ArgBuffer) {
+    std::promise<shared::WrapperFunctionResult> RP;
+    auto RF = RP.get_future();
+    runWrapperAsync(
+        [&](shared::WrapperFunctionResult R) { RP.set_value(std::move(R)); },
+        WrapperFnAddr, ArgBuffer);
+    return RF.get();
+  }
+
+  /// Run a wrapper function using SPS to serialize the arguments and
+  /// deserialize the results.
+  template <typename SPSSignature, typename SendResultT, typename... ArgTs>
+  void runSPSWrapperAsync(SendResultT &&SendResult,
+                          JITTargetAddress WrapperFnAddr,
+                          const ArgTs &...Args) {
+    shared::WrapperFunction<SPSSignature>::callAsync(
+        [this, WrapperFnAddr](SendResultFunction SendResult,
+                              const char *ArgData, size_t ArgSize) {
+          runWrapperAsync(std::move(SendResult), WrapperFnAddr,
+                          ArrayRef<char>(ArgData, ArgSize));
+        },
+        std::move(SendResult), Args...);
+  }
+
+  /// Run a wrapper function using SPS to serialize the arguments and
+  /// deserialize the results.
+  template <typename SPSSignature, typename RetT, typename... ArgTs>
+  Error runSPSWrapper(JITTargetAddress WrapperFnAddr, RetT &RetVal,
+                      const ArgTs &...Args) {
+    return shared::WrapperFunction<SPSSignature>::call(
+        [this, WrapperFnAddr](const char *ArgData, size_t ArgSize) {
+          return runWrapper(WrapperFnAddr, ArrayRef<char>(ArgData, ArgSize));
+        },
+        RetVal, Args...);
+  }
+
+  /// Wrap a handler that takes concrete argument types (and a sender for a
+  /// concrete return type) to produce an AsyncWrapperFunction. Uses SPS to
+  /// unpack the arguments and pack the result.
   ///
-  virtual Expected<shared::WrapperFunctionResult>
-  runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<char> ArgBuffer) = 0;
+  /// This function is usually used when building association maps.
+  template <typename SPSSignature, typename HandlerT>
+  static AsyncWrapperFunction wrapAsyncWithSPS(HandlerT &&H) {
+    return [H = std::forward<HandlerT>(H)](SendResultFunction SendResult,
+                                           const char *ArgData,
+                                           size_t ArgSize) mutable {
+      shared::WrapperFunction<SPSSignature>::handleAsync(ArgData, ArgSize, H,
+                                                         std::move(SendResult));
+    };
+  }
+
+  /// For each symbol name, associate the AsyncWrapperFunction implementation
+  /// value with the address of that symbol.
+  ///
+  /// Symbols will be looked up using LookupKind::Static,
+  /// JITDylibLookupFlags::MatchAllSymbols (hidden tags will be found), and
+  /// LookupFlags::WeaklyReferencedSymbol (missing tags will not cause an
+  /// error, the implementations will simply be dropped).
+  Error associateJITSideWrapperFunctions(JITDylib &JD,
+                                         WrapperFunctionAssociationMap WFs);
+
+  /// Run a registered jit-side wrapper function.
+  void runJITSideWrapperFunction(SendResultFunction SendResult,
+                                 JITTargetAddress TagAddr,
+                                 ArrayRef<char> ArgBuffer);
 
   /// Disconnect from the target process.
   ///
@@ -161,6 +252,9 @@ protected:
   unsigned PageSize = 0;
   MemoryAccess *MemAccess = nullptr;
   jitlink::JITLinkMemoryManager *MemMgr = nullptr;
+
+  std::mutex TagToFuncMapMutex;
+  DenseMap<JITTargetAddress, std::shared_ptr<AsyncWrapperFunction>> TagToFunc;
 };
 
 /// Call a wrapper function via ExecutorProcessControl::runWrapper.
@@ -168,8 +262,8 @@ class EPCCaller {
 public:
   EPCCaller(ExecutorProcessControl &EPC, JITTargetAddress WrapperFnAddr)
       : EPC(EPC), WrapperFnAddr(WrapperFnAddr) {}
-  Expected<shared::WrapperFunctionResult> operator()(const char *ArgData,
-                                                     size_t ArgSize) const {
+  shared::WrapperFunctionResult operator()(const char *ArgData,
+                                           size_t ArgSize) const {
     return EPC.runWrapper(WrapperFnAddr, ArrayRef<char>(ArgData, ArgSize));
   }
 
@@ -202,8 +296,9 @@ public:
   Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
                               ArrayRef<std::string> Args) override;
 
-  Expected<shared::WrapperFunctionResult>
-  runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<char> ArgBuffer) override;
+  void runWrapperAsync(SendResultFunction OnComplete,
+                       JITTargetAddress WrapperFnAddr,
+                       ArrayRef<char> ArgBuffer) override;
 
   Error disconnect() override;
 
index 0b5ee26..69e37f9 100644 (file)
@@ -354,9 +354,9 @@ public:
     return Result;
   }
 
-  Expected<shared::WrapperFunctionResult>
-  runWrapper(JITTargetAddress WrapperFnAddr,
-             ArrayRef<char> ArgBuffer) override {
+  void runWrapperAsync(SendResultFunction OnComplete,
+                       JITTargetAddress WrapperFnAddr,
+                       ArrayRef<char> ArgBuffer) override {
     DEBUG_WITH_TYPE("orc", {
       dbgs() << "Running as wrapper function "
              << formatv("{0:x16}", WrapperFnAddr) << " with "
@@ -366,7 +366,11 @@ public:
         WrapperFnAddr,
         ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(ArgBuffer.data()),
                           ArgBuffer.size()));
-    return Result;
+
+    if (!Result)
+      OnComplete(shared::WrapperFunctionResult::createOutOfBandError(
+          toString(Result.takeError())));
+    OnComplete(std::move(*Result));
   }
 
   Error closeConnection(OnCloseConnectionFunction OnCloseConnection) {
index 0fc8af7..ceaea1d 100644 (file)
@@ -172,17 +172,16 @@ private:
 namespace detail {
 
 template <typename SPSArgListT, typename... ArgTs>
-Expected<WrapperFunctionResult>
+WrapperFunctionResult
 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
   WrapperFunctionResult Result;
   char *DataPtr =
       WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...));
   SPSOutputBuffer OB(DataPtr, Result.size());
   if (!SPSArgListT::serialize(OB, Args...))
-    return make_error<StringError>(
-        "Error serializing arguments to blob in call",
-        inconvertibleErrorCode());
-  return std::move(Result);
+    return WrapperFunctionResult::createOutOfBandError(
+        "Error serializing arguments to blob in call");
+  return Result;
 }
 
 template <typename RetT> class WrapperFunctionHandlerCaller {
@@ -230,12 +229,8 @@ public:
     auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
         std::forward<HandlerT>(H), Args, ArgIndices{});
 
-    if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize(
-            std::move(HandlerResult)))
-      return std::move(*Result);
-    else
-      return WrapperFunctionResult::createOutOfBandError(
-          toString(Result.takeError()));
+    return ResultSerializer<decltype(HandlerResult)>::serialize(
+        std::move(HandlerResult));
   }
 
 private:
@@ -247,10 +242,10 @@ private:
   }
 };
 
-// Map function references to function types.
+// Map function pointers to function types.
 template <typename RetT, typename... ArgTs,
           template <typename> class ResultSerializer, typename... SPSTagTs>
-class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer,
+class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
                                    SPSTagTs...>
     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                           SPSTagTs...> {};
@@ -271,9 +266,87 @@ class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                           SPSTagTs...> {};
 
+template <typename WrapperFunctionImplT,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper
+    : public WrapperFunctionAsyncHandlerHelper<
+          decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
+          ResultSerializer, SPSTagTs...> {};
+
+template <typename RetT, typename SendResultT, typename... ArgTs,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...),
+                                        ResultSerializer, SPSTagTs...> {
+public:
+  using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
+  using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
+
+  template <typename HandlerT, typename SendWrapperFunctionResultT>
+  static void applyAsync(HandlerT &&H,
+                         SendWrapperFunctionResultT &&SendWrapperFunctionResult,
+                         const char *ArgData, size_t ArgSize) {
+    ArgTuple Args;
+    if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) {
+      SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError(
+          "Could not deserialize arguments for wrapper function call"));
+      return;
+    }
+
+    auto SendResult =
+        [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable {
+          using ResultT = decltype(Result);
+          SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result)));
+        };
+
+    callAsync(std::forward<HandlerT>(H), std::move(SendResult), Args,
+              ArgIndices{});
+  }
+
+private:
+  template <std::size_t... I>
+  static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
+                          std::index_sequence<I...>) {
+    SPSInputBuffer IB(ArgData, ArgSize);
+    return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
+  }
+
+  template <typename HandlerT, typename SerializeAndSendResultT,
+            typename ArgTupleT, std::size_t... I>
+  static void callAsync(HandlerT &&H,
+                        SerializeAndSendResultT &&SerializeAndSendResult,
+                        ArgTupleT &Args, std::index_sequence<I...>) {
+    return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult),
+                                     std::get<I>(Args)...);
+  }
+};
+
+// Map function pointers to function types.
+template <typename RetT, typename... ArgTs,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
+                                        SPSTagTs...>
+    : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
+                                               SPSTagTs...> {};
+
+// Map non-const member function types to function types.
+template <typename ClassT, typename RetT, typename... ArgTs,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...),
+                                        ResultSerializer, SPSTagTs...>
+    : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
+                                               SPSTagTs...> {};
+
+// Map const member function types to function types.
+template <typename ClassT, typename RetT, typename... ArgTs,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
+                                        ResultSerializer, SPSTagTs...>
+    : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
+                                               SPSTagTs...> {};
+
 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
 public:
-  static Expected<WrapperFunctionResult> serialize(RetT Result) {
+  static WrapperFunctionResult serialize(RetT Result) {
     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
         Result);
   }
@@ -281,7 +354,7 @@ public:
 
 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
 public:
-  static Expected<WrapperFunctionResult> serialize(Error Err) {
+  static WrapperFunctionResult serialize(Error Err) {
     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
         toSPSSerializable(std::move(Err)));
   }
@@ -290,7 +363,7 @@ public:
 template <typename SPSRetTagT, typename T>
 class ResultSerializer<SPSRetTagT, Expected<T>> {
 public:
-  static Expected<WrapperFunctionResult> serialize(Expected<T> E) {
+  static WrapperFunctionResult serialize(Expected<T> E) {
     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
         toSPSSerializable(std::move(E)));
   }
@@ -298,6 +371,7 @@ public:
 
 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
 public:
+  static RetT makeValue() { return RetT(); }
   static void makeSafe(RetT &Result) {}
 
   static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
@@ -312,6 +386,7 @@ public:
 
 template <> class ResultDeserializer<SPSError, Error> {
 public:
+  static Error makeValue() { return Error::success(); }
   static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
 
   static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
@@ -329,6 +404,7 @@ public:
 template <typename SPSTagT, typename T>
 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
 public:
+  static Expected<T> makeValue() { return T(); }
   static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
 
   static Error deserialize(Expected<T> &E, const char *ArgData,
@@ -344,6 +420,10 @@ public:
   }
 };
 
+template <typename SPSRetTagT, typename RetT> class AsyncCallResultHelper {
+  // Did you forget to use Error / Expected in your handler?
+};
+
 } // end namespace detail
 
 template <typename SPSSignature> class WrapperFunction;
@@ -355,7 +435,7 @@ private:
   using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
 
 public:
-  /// Call a wrapper function. Callere should be callable as
+  /// Call a wrapper function. Caller should be callable as
   /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize);
   template <typename CallerFn, typename RetT, typename... ArgTs>
   static Error call(const CallerFn &Caller, RetT &Result,
@@ -369,18 +449,56 @@ public:
     auto ArgBuffer =
         detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
             Args...);
-    if (!ArgBuffer)
-      return ArgBuffer.takeError();
-
-    Expected<WrapperFunctionResult> ResultBuffer =
-        Caller(ArgBuffer->data(), ArgBuffer->size());
-    if (!ResultBuffer)
-      return ResultBuffer.takeError();
-    if (auto ErrMsg = ResultBuffer->getOutOfBandError())
+    if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
+      return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
+
+    WrapperFunctionResult ResultBuffer =
+        Caller(ArgBuffer.data(), ArgBuffer.size());
+    if (auto ErrMsg = ResultBuffer.getOutOfBandError())
       return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
 
     return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
-        Result, ResultBuffer->data(), ResultBuffer->size());
+        Result, ResultBuffer.data(), ResultBuffer.size());
+  }
+
+  /// Call an async wrapper function.
+  /// Caller should be callable as
+  /// void Fn(unique_function<void(WrapperFunctionResult)> SendResult,
+  ///         WrapperFunctionResult ArgBuffer);
+  template <typename AsyncCallerFn, typename SendDeserializedResultFn,
+            typename... ArgTs>
+  static void callAsync(AsyncCallerFn &&Caller,
+                        SendDeserializedResultFn &&SendDeserializedResult,
+                        const ArgTs &...Args) {
+    using RetT = typename std::tuple_element<
+        1, typename detail::WrapperFunctionHandlerHelper<
+               std::remove_reference_t<SendDeserializedResultFn>,
+               ResultSerializer, SPSRetTagT>::ArgTuple>::type;
+
+    auto ArgBuffer =
+        detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
+            Args...);
+    if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) {
+      SendDeserializedResult(
+          make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
+          detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue());
+      return;
+    }
+
+    auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)](
+                                    WrapperFunctionResult R) {
+      RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue();
+      detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal);
+
+      SPSInputBuffer IB(R.data(), R.size());
+      if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
+              RetVal, R.data(), R.size()))
+        SDR(std::move(Err), std::move(RetVal));
+
+      SDR(Error::success(), std::move(RetVal));
+    };
+
+    Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size());
   }
 
   /// Handle a call to a wrapper function.
@@ -388,11 +506,21 @@ public:
   static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
                                       HandlerT &&Handler) {
     using WFHH =
-        detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer,
-                                             SPSTagTs...>;
+        detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
+                                             ResultSerializer, SPSTagTs...>;
     return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
   }
 
+  /// Handle a call to an async wrapper function.
+  template <typename HandlerT, typename SendResultT>
+  static void handleAsync(const char *ArgData, size_t ArgSize,
+                          HandlerT &&Handler, SendResultT &&SendResult) {
+    using WFAHH = detail::WrapperFunctionAsyncHandlerHelper<
+        std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>;
+    WFAHH::applyAsync(std::forward<HandlerT>(Handler),
+                      std::forward<SendResultT>(SendResult), ArgData, ArgSize);
+  }
+
 private:
   template <typename T> static const T &makeSerializable(const T &Value) {
     return Value;
@@ -411,6 +539,7 @@ private:
 template <typename... SPSTagTs>
 class WrapperFunction<void(SPSTagTs...)>
     : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
+
 public:
   template <typename CallerFn, typename... ArgTs>
   static Error call(const CallerFn &Caller, const ArgTs &...Args) {
@@ -419,6 +548,7 @@ public:
   }
 
   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
+  using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync;
 };
 
 } // end namespace shared
index f8bd74e..12fa42c 100644 (file)
 
 #include "llvm/ExecutionEngine/Orc/Core.h"
 #include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Host.h"
 #include "llvm/Support/Process.h"
 
-#include <mutex>
-
 namespace llvm {
 namespace orc {
 
@@ -22,6 +21,56 @@ ExecutorProcessControl::MemoryAccess::~MemoryAccess() {}
 
 ExecutorProcessControl::~ExecutorProcessControl() {}
 
+Error ExecutorProcessControl::associateJITSideWrapperFunctions(
+    JITDylib &JD, WrapperFunctionAssociationMap WFs) {
+
+  // Look up tag addresses.
+  auto &ES = JD.getExecutionSession();
+  auto TagAddrs =
+      ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}},
+                SymbolLookupSet::fromMapKeys(
+                    WFs, SymbolLookupFlags::WeaklyReferencedSymbol));
+  if (!TagAddrs)
+    return TagAddrs.takeError();
+
+  // Associate tag addresses with implementations.
+  std::lock_guard<std::mutex> Lock(TagToFuncMapMutex);
+  for (auto &KV : *TagAddrs) {
+    auto TagAddr = KV.second.getAddress();
+    if (TagToFunc.count(TagAddr))
+      return make_error<StringError>("Tag " + formatv("{0:x16}", TagAddr) +
+                                         " (for " + *KV.first +
+                                         ") already registered",
+                                     inconvertibleErrorCode());
+    auto I = WFs.find(KV.first);
+    assert(I != WFs.end() && I->second &&
+           "AsyncWrapperFunction implementation missing");
+    TagToFunc[KV.second.getAddress()] =
+        std::make_shared<AsyncWrapperFunction>(std::move(I->second));
+  }
+  return Error::success();
+}
+
+void ExecutorProcessControl::runJITSideWrapperFunction(
+    SendResultFunction SendResult, JITTargetAddress TagAddr,
+    ArrayRef<char> ArgBuffer) {
+
+  std::shared_ptr<AsyncWrapperFunction> F;
+  {
+    std::lock_guard<std::mutex> Lock(TagToFuncMapMutex);
+    auto I = TagToFunc.find(TagAddr);
+    if (I != TagToFunc.end())
+      F = I->second;
+  }
+
+  if (F)
+    (*F)(std::move(SendResult), ArgBuffer.data(), ArgBuffer.size());
+  else
+    SendResult(shared::WrapperFunctionResult::createOutOfBandError(
+        ("No function registered for tag " + formatv("{0:x16}", TagAddr))
+            .str()));
+}
+
 SelfExecutorProcessControl::SelfExecutorProcessControl(
     std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple,
     unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr)
@@ -102,13 +151,13 @@ SelfExecutorProcessControl::runAsMain(JITTargetAddress MainFnAddr,
   return orc::runAsMain(jitTargetAddressToFunction<MainTy>(MainFnAddr), Args);
 }
 
-Expected<shared::WrapperFunctionResult>
-SelfExecutorProcessControl::runWrapper(JITTargetAddress WrapperFnAddr,
-                                       ArrayRef<char> ArgBuffer) {
-  using WrapperFnTy = shared::detail::CWrapperFunctionResult (*)(
-      const char *Data, uint64_t Size);
+void SelfExecutorProcessControl::runWrapperAsync(SendResultFunction SendResult,
+                                                 JITTargetAddress WrapperFnAddr,
+                                                 ArrayRef<char> ArgBuffer) {
+  using WrapperFnTy =
+      shared::detail::CWrapperFunctionResult (*)(const char *Data, size_t Size);
   auto *WrapperFn = jitTargetAddressToFunction<WrapperFnTy>(WrapperFnAddr);
-  return WrapperFn(ArgBuffer.data(), ArgBuffer.size());
+  SendResult(WrapperFn(ArgBuffer.data(), ArgBuffer.size()));
 }
 
 Error SelfExecutorProcessControl::disconnect() { return Error::success(); }
index b1cfd18..b544cfa 100644 (file)
@@ -16,6 +16,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_unittest(OrcJITTests
   CoreAPIsTest.cpp
+  ExecutorProcessControlTest.cpp
   IndirectionUtilsTest.cpp
   JITTargetMachineBuilderTest.cpp
   LazyCallThroughAndReexportsTest.cpp
diff --git a/llvm/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp b/llvm/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp
new file mode 100644 (file)
index 0000000..23096c8
--- /dev/null
@@ -0,0 +1,105 @@
+//===- ExecutorProcessControlTest.cpp - Test ExecutorProcessControl utils -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
+#include "llvm/ExecutionEngine/Orc/Core.h"
+#include "llvm/Support/MSVCErrorWorkarounds.h"
+#include "llvm/Testing/Support/Error.h"
+#include "gtest/gtest.h"
+
+#include <future>
+
+using namespace llvm;
+using namespace llvm::orc;
+using namespace llvm::orc::shared;
+
+static llvm::orc::shared::detail::CWrapperFunctionResult
+addWrapper(const char *ArgData, size_t ArgSize) {
+  return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
+             ArgData, ArgSize, [](int32_t X, int32_t Y) { return X + Y; })
+      .release();
+}
+
+static void addAsyncWrapper(unique_function<void(int32_t)> SendResult,
+                            int32_t X, int32_t Y) {
+  SendResult(X + Y);
+}
+
+TEST(ExecutorProcessControl, RunWrapperTemplate) {
+  auto EPC = cantFail(
+      SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
+
+  int32_t Result;
+  EXPECT_THAT_ERROR(EPC->runSPSWrapper<int32_t(int32_t, int32_t)>(
+                        pointerToJITTargetAddress(addWrapper), Result, 2, 3),
+                    Succeeded());
+  EXPECT_EQ(Result, 5);
+}
+
+TEST(ExecutorProcessControl, RunWrapperAsyncTemplate) {
+  auto EPC = cantFail(
+      SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
+
+  std::promise<MSVCPExpected<int32_t>> RP;
+  using Sig = int32_t(int32_t, int32_t);
+  EPC->runSPSWrapperAsync<Sig>(
+      [&](Error SerializationErr, int32_t R) {
+        if (SerializationErr)
+          RP.set_value(std::move(SerializationErr));
+        RP.set_value(std::move(R));
+      },
+      pointerToJITTargetAddress(addWrapper), 2, 3);
+  Expected<int32_t> Result = RP.get_future().get();
+  EXPECT_THAT_EXPECTED(Result, HasValue(5));
+}
+
+TEST(ExecutorProcessControl, RegisterAsyncHandlerAndRun) {
+
+  constexpr JITTargetAddress AddAsyncTagAddr = 0x01;
+
+  auto EPC = cantFail(
+      SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
+  ExecutionSession ES(EPC->getSymbolStringPool());
+  auto &JD = ES.createBareJITDylib("JD");
+
+  auto AddAsyncTag = ES.intern("addAsync_tag");
+  cantFail(JD.define(absoluteSymbols(
+      {{AddAsyncTag,
+        JITEvaluatedSymbol(AddAsyncTagAddr, JITSymbolFlags::Exported)}})));
+
+  ExecutorProcessControl::WrapperFunctionAssociationMap Associations;
+
+  Associations[AddAsyncTag] =
+      EPC->wrapAsyncWithSPS<int32_t(int32_t, int32_t)>(addAsyncWrapper);
+
+  cantFail(EPC->associateJITSideWrapperFunctions(JD, std::move(Associations)));
+
+  std::promise<int32_t> RP;
+  auto RF = RP.get_future();
+
+  using ArgSerialization = SPSArgList<int32_t, int32_t>;
+  size_t ArgBufferSize = ArgSerialization::size(1, 2);
+  WrapperFunctionResult ArgBuffer;
+  char *ArgBufferData =
+      WrapperFunctionResult::allocate(ArgBuffer, ArgBufferSize);
+  SPSOutputBuffer OB(ArgBufferData, ArgBufferSize);
+  EXPECT_TRUE(ArgSerialization::serialize(OB, 1, 2));
+
+  EPC->runJITSideWrapperFunction(
+      [&](WrapperFunctionResult ResultBuffer) {
+        int32_t Result;
+        SPSInputBuffer IB(ResultBuffer.data(), ResultBuffer.size());
+        EXPECT_TRUE(SPSArgList<int32_t>::deserialize(IB, Result));
+        RP.set_value(Result);
+      },
+      AddAsyncTagAddr, ArrayRef<char>(ArgBuffer.data(), ArgBuffer.size()));
+
+  EXPECT_EQ(RF.get(), (int32_t)3);
+
+  cantFail(ES.endSession());
+}
index 1f177b4..4205183 100644 (file)
@@ -7,8 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h"
+#include "llvm/ADT/FunctionExtras.h"
 #include "gtest/gtest.h"
 
+#include <future>
+
 using namespace llvm;
 using namespace llvm::orc::shared;
 
@@ -65,13 +68,54 @@ static WrapperFunctionResult addWrapper(const char *ArgData, size_t ArgSize) {
       ArgData, ArgSize, [](int32_t X, int32_t Y) -> int32_t { return X + Y; });
 }
 
-TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) {
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleVoid) {
   EXPECT_FALSE(!!WrapperFunction<void()>::call(voidNoopWrapper));
 }
 
-TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandle) {
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleRet) {
   int32_t Result;
   EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
       addWrapper, Result, 1, 2));
   EXPECT_EQ(Result, (int32_t)3);
 }
+
+static void voidNoopAsync(unique_function<void(SPSEmpty)> SendResult) {
+  SendResult(SPSEmpty());
+}
+
+static WrapperFunctionResult voidNoopAsyncWrapper(const char *ArgData,
+                                                  size_t ArgSize) {
+  std::promise<WrapperFunctionResult> RP;
+  auto RF = RP.get_future();
+
+  WrapperFunction<void()>::handleAsync(
+      ArgData, ArgSize, voidNoopAsync,
+      [&](WrapperFunctionResult R) { RP.set_value(std::move(R)); });
+
+  return RF.get();
+}
+
+static WrapperFunctionResult addAsyncWrapper(const char *ArgData,
+                                             size_t ArgSize) {
+  std::promise<WrapperFunctionResult> RP;
+  auto RF = RP.get_future();
+
+  WrapperFunction<int32_t(int32_t, int32_t)>::handleAsync(
+      ArgData, ArgSize,
+      [](unique_function<void(int32_t)> SendResult, int32_t X, int32_t Y) {
+        SendResult(X + Y);
+      },
+      [&](WrapperFunctionResult R) { RP.set_value(std::move(R)); });
+  return RF.get();
+}
+
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncVoid) {
+  EXPECT_FALSE(!!WrapperFunction<void()>::call(voidNoopAsyncWrapper));
+}
+
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncRet) {
+  int32_t Result;
+  EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
+      addAsyncWrapper, Result, 1, 2));
+  EXPECT_EQ(Result, (int32_t)3);
+}