[ORC] Handle void and no-argument async wrapper calls.
authorLang Hames <lhames@gmail.com>
Thu, 19 Aug 2021 02:19:36 +0000 (12:19 +1000)
committerLang Hames <lhames@gmail.com>
Thu, 19 Aug 2021 02:20:31 +0000 (12:20 +1000)
llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
llvm/unittests/ExecutionEngine/Orc/ExecutionSessionWrapperFunctionCallsTest.cpp

index e92f580..d1d7d54 100644 (file)
@@ -547,6 +547,18 @@ public:
     return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(Caller, BE, Args...);
   }
 
+  template <typename AsyncCallerFn, typename SendDeserializedResultFn,
+            typename... ArgTs>
+  static void callAsync(AsyncCallerFn &&Caller,
+                        SendDeserializedResultFn &&SendDeserializedResult,
+                        const ArgTs &...Args) {
+    WrapperFunction<SPSEmpty(SPSTagTs...)>::callAsync(
+        Caller,
+        [SDR = std::move(SendDeserializedResult)](
+            Error SerializeErr, SPSEmpty E) { SDR(std::move(SerializeErr)); },
+        Args...);
+  }
+
   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync;
 };
index 9044ead..39554e2 100644 (file)
@@ -30,6 +30,11 @@ static void addAsyncWrapper(unique_function<void(int32_t)> SendResult,
   SendResult(X + Y);
 }
 
+static llvm::orc::shared::detail::CWrapperFunctionResult
+voidWrapper(const char *ArgData, size_t ArgSize) {
+  return WrapperFunction<void()>::handle(ArgData, ArgSize, []() {}).release();
+}
+
 TEST(ExecutionSessionWrapperFunctionCalls, RunWrapperTemplate) {
   ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));
 
@@ -40,12 +45,24 @@ TEST(ExecutionSessionWrapperFunctionCalls, RunWrapperTemplate) {
   EXPECT_EQ(Result, 5);
 }
 
-TEST(ExecutionSessionWrapperFunctionCalls, RunWrapperAsyncTemplate) {
+TEST(ExecutionSessionWrapperFunctionCalls, RunVoidWrapperAsyncTemplate) {
+  ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));
+
+  std::promise<MSVCPError> RP;
+  ES.callSPSWrapperAsync<void()>(
+      [&](Error SerializationErr) {
+        RP.set_value(std::move(SerializationErr));
+      },
+      pointerToJITTargetAddress(voidWrapper));
+  Error Err = RP.get_future().get();
+  EXPECT_THAT_ERROR(std::move(Err), Succeeded());
+}
+
+TEST(ExecutionSessionWrapperFunctionCalls, RunNonVoidWrapperAsyncTemplate) {
   ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));
 
   std::promise<MSVCPExpected<int32_t>> RP;
-  using Sig = int32_t(int32_t, int32_t);
-  ES.callSPSWrapperAsync<Sig>(
+  ES.callSPSWrapperAsync<int32_t(int32_t, int32_t)>(
       [&](Error SerializationErr, int32_t R) {
         if (SerializationErr)
           RP.set_value(std::move(SerializationErr));