add(Name, Flags);
}
+ /// Construct a SymbolLookupSet from DenseMap keys.
+ template <typename KeyT>
+ static SymbolLookupSet
+ fromMapKeys(const DenseMap<SymbolStringPtr, KeyT> &M,
+ SymbolLookupFlags Flags = SymbolLookupFlags::RequiredSymbol) {
+ SymbolLookupSet Result;
+ Result.Symbols.reserve(M.size());
+ for (const auto &KV : M)
+ Result.add(KV.first, Flags);
+ return Result;
+ }
+
/// Add an element to the set. The client is responsible for checking that
/// duplicates are not added.
SymbolLookupSet &
#include "llvm/Support/MSVCErrorWorkarounds.h"
#include <future>
+#include <mutex>
#include <vector>
namespace llvm {
/// ExecutorProcessControl supports interaction with a JIT target process.
class ExecutorProcessControl {
public:
+ /// Sender to return the result of a WrapperFunction executed in the JIT.
+ using SendResultFunction =
+ unique_function<void(shared::WrapperFunctionResult)>;
+
+ /// An asynchronous wrapper-function.
+ using AsyncWrapperFunction = unique_function<void(
+ SendResultFunction SendResult, const char *ArgData, size_t ArgSize)>;
+
+ /// A map associating tag names with asynchronous wrapper function
+ /// implementations in the JIT.
+ using WrapperFunctionAssociationMap =
+ DenseMap<SymbolStringPtr, AsyncWrapperFunction>;
+
/// APIs for manipulating memory in the target process.
class MemoryAccess {
public:
virtual Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
ArrayRef<std::string> Args) = 0;
- /// Run a wrapper function in the executor.
+ /// Run a wrapper function in the executor (async version).
+ ///
+ /// The wrapper function should be callable as:
+ ///
+ /// \code{.cpp}
+ /// CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size);
+ /// \endcode{.cpp}
+ ///
+ /// The given OnComplete function will be called to return the result.
+ virtual void runWrapperAsync(SendResultFunction OnComplete,
+ JITTargetAddress WrapperFnAddr,
+ ArrayRef<char> ArgBuffer) = 0;
+
+ /// Run a wrapper function in the executor. The wrapper function should be
+ /// callable as:
///
/// \code{.cpp}
/// CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size);
/// \endcode{.cpp}
+ shared::WrapperFunctionResult runWrapper(JITTargetAddress WrapperFnAddr,
+ ArrayRef<char> ArgBuffer) {
+ std::promise<shared::WrapperFunctionResult> RP;
+ auto RF = RP.get_future();
+ runWrapperAsync(
+ [&](shared::WrapperFunctionResult R) { RP.set_value(std::move(R)); },
+ WrapperFnAddr, ArgBuffer);
+ return RF.get();
+ }
+
+ /// Run a wrapper function using SPS to serialize the arguments and
+ /// deserialize the results.
+ template <typename SPSSignature, typename SendResultT, typename... ArgTs>
+ void runSPSWrapperAsync(SendResultT &&SendResult,
+ JITTargetAddress WrapperFnAddr,
+ const ArgTs &...Args) {
+ shared::WrapperFunction<SPSSignature>::callAsync(
+ [this, WrapperFnAddr](SendResultFunction SendResult,
+ const char *ArgData, size_t ArgSize) {
+ runWrapperAsync(std::move(SendResult), WrapperFnAddr,
+ ArrayRef<char>(ArgData, ArgSize));
+ },
+ std::move(SendResult), Args...);
+ }
+
+ /// Run a wrapper function using SPS to serialize the arguments and
+ /// deserialize the results.
+ template <typename SPSSignature, typename RetT, typename... ArgTs>
+ Error runSPSWrapper(JITTargetAddress WrapperFnAddr, RetT &RetVal,
+ const ArgTs &...Args) {
+ return shared::WrapperFunction<SPSSignature>::call(
+ [this, WrapperFnAddr](const char *ArgData, size_t ArgSize) {
+ return runWrapper(WrapperFnAddr, ArrayRef<char>(ArgData, ArgSize));
+ },
+ RetVal, Args...);
+ }
+
+ /// Wrap a handler that takes concrete argument types (and a sender for a
+ /// concrete return type) to produce an AsyncWrapperFunction. Uses SPS to
+ /// unpack the arguments and pack the result.
///
- virtual Expected<shared::WrapperFunctionResult>
- runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<char> ArgBuffer) = 0;
+ /// This function is usually used when building association maps.
+ template <typename SPSSignature, typename HandlerT>
+ static AsyncWrapperFunction wrapAsyncWithSPS(HandlerT &&H) {
+ return [H = std::forward<HandlerT>(H)](SendResultFunction SendResult,
+ const char *ArgData,
+ size_t ArgSize) mutable {
+ shared::WrapperFunction<SPSSignature>::handleAsync(ArgData, ArgSize, H,
+ std::move(SendResult));
+ };
+ }
+
+ /// For each symbol name, associate the AsyncWrapperFunction implementation
+ /// value with the address of that symbol.
+ ///
+ /// Symbols will be looked up using LookupKind::Static,
+ /// JITDylibLookupFlags::MatchAllSymbols (hidden tags will be found), and
+ /// LookupFlags::WeaklyReferencedSymbol (missing tags will not cause an
+ /// error, the implementations will simply be dropped).
+ Error associateJITSideWrapperFunctions(JITDylib &JD,
+ WrapperFunctionAssociationMap WFs);
+
+ /// Run a registered jit-side wrapper function.
+ void runJITSideWrapperFunction(SendResultFunction SendResult,
+ JITTargetAddress TagAddr,
+ ArrayRef<char> ArgBuffer);
/// Disconnect from the target process.
///
unsigned PageSize = 0;
MemoryAccess *MemAccess = nullptr;
jitlink::JITLinkMemoryManager *MemMgr = nullptr;
+
+ std::mutex TagToFuncMapMutex;
+ DenseMap<JITTargetAddress, std::shared_ptr<AsyncWrapperFunction>> TagToFunc;
};
/// Call a wrapper function via ExecutorProcessControl::runWrapper.
public:
EPCCaller(ExecutorProcessControl &EPC, JITTargetAddress WrapperFnAddr)
: EPC(EPC), WrapperFnAddr(WrapperFnAddr) {}
- Expected<shared::WrapperFunctionResult> operator()(const char *ArgData,
- size_t ArgSize) const {
+ shared::WrapperFunctionResult operator()(const char *ArgData,
+ size_t ArgSize) const {
return EPC.runWrapper(WrapperFnAddr, ArrayRef<char>(ArgData, ArgSize));
}
Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
ArrayRef<std::string> Args) override;
- Expected<shared::WrapperFunctionResult>
- runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<char> ArgBuffer) override;
+ void runWrapperAsync(SendResultFunction OnComplete,
+ JITTargetAddress WrapperFnAddr,
+ ArrayRef<char> ArgBuffer) override;
Error disconnect() override;
return Result;
}
- Expected<shared::WrapperFunctionResult>
- runWrapper(JITTargetAddress WrapperFnAddr,
- ArrayRef<char> ArgBuffer) override {
+ void runWrapperAsync(SendResultFunction OnComplete,
+ JITTargetAddress WrapperFnAddr,
+ ArrayRef<char> ArgBuffer) override {
DEBUG_WITH_TYPE("orc", {
dbgs() << "Running as wrapper function "
<< formatv("{0:x16}", WrapperFnAddr) << " with "
WrapperFnAddr,
ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(ArgBuffer.data()),
ArgBuffer.size()));
- return Result;
+
+ if (!Result)
+ OnComplete(shared::WrapperFunctionResult::createOutOfBandError(
+ toString(Result.takeError())));
+ OnComplete(std::move(*Result));
}
Error closeConnection(OnCloseConnectionFunction OnCloseConnection) {
namespace detail {
template <typename SPSArgListT, typename... ArgTs>
-Expected<WrapperFunctionResult>
+WrapperFunctionResult
serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
WrapperFunctionResult Result;
char *DataPtr =
WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...));
SPSOutputBuffer OB(DataPtr, Result.size());
if (!SPSArgListT::serialize(OB, Args...))
- return make_error<StringError>(
- "Error serializing arguments to blob in call",
- inconvertibleErrorCode());
- return std::move(Result);
+ return WrapperFunctionResult::createOutOfBandError(
+ "Error serializing arguments to blob in call");
+ return Result;
}
template <typename RetT> class WrapperFunctionHandlerCaller {
auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
std::forward<HandlerT>(H), Args, ArgIndices{});
- if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize(
- std::move(HandlerResult)))
- return std::move(*Result);
- else
- return WrapperFunctionResult::createOutOfBandError(
- toString(Result.takeError()));
+ return ResultSerializer<decltype(HandlerResult)>::serialize(
+ std::move(HandlerResult));
}
private:
}
};
-// Map function references to function types.
+// Map function pointers to function types.
template <typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
-class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer,
+class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
SPSTagTs...>
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
+template <typename WrapperFunctionImplT,
+ template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper
+ : public WrapperFunctionAsyncHandlerHelper<
+ decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
+ ResultSerializer, SPSTagTs...> {};
+
+template <typename RetT, typename SendResultT, typename... ArgTs,
+ template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...),
+ ResultSerializer, SPSTagTs...> {
+public:
+ using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
+ using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
+
+ template <typename HandlerT, typename SendWrapperFunctionResultT>
+ static void applyAsync(HandlerT &&H,
+ SendWrapperFunctionResultT &&SendWrapperFunctionResult,
+ const char *ArgData, size_t ArgSize) {
+ ArgTuple Args;
+ if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) {
+ SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError(
+ "Could not deserialize arguments for wrapper function call"));
+ return;
+ }
+
+ auto SendResult =
+ [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable {
+ using ResultT = decltype(Result);
+ SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result)));
+ };
+
+ callAsync(std::forward<HandlerT>(H), std::move(SendResult), Args,
+ ArgIndices{});
+ }
+
+private:
+ template <std::size_t... I>
+ static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
+ std::index_sequence<I...>) {
+ SPSInputBuffer IB(ArgData, ArgSize);
+ return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
+ }
+
+ template <typename HandlerT, typename SerializeAndSendResultT,
+ typename ArgTupleT, std::size_t... I>
+ static void callAsync(HandlerT &&H,
+ SerializeAndSendResultT &&SerializeAndSendResult,
+ ArgTupleT &Args, std::index_sequence<I...>) {
+ return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult),
+ std::get<I>(Args)...);
+ }
+};
+
+// Map function pointers to function types.
+template <typename RetT, typename... ArgTs,
+ template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
+ SPSTagTs...>
+ : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
+ SPSTagTs...> {};
+
+// Map non-const member function types to function types.
+template <typename ClassT, typename RetT, typename... ArgTs,
+ template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...),
+ ResultSerializer, SPSTagTs...>
+ : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
+ SPSTagTs...> {};
+
+// Map const member function types to function types.
+template <typename ClassT, typename RetT, typename... ArgTs,
+ template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
+ ResultSerializer, SPSTagTs...>
+ : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
+ SPSTagTs...> {};
+
template <typename SPSRetTagT, typename RetT> class ResultSerializer {
public:
- static Expected<WrapperFunctionResult> serialize(RetT Result) {
+ static WrapperFunctionResult serialize(RetT Result) {
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
Result);
}
template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
public:
- static Expected<WrapperFunctionResult> serialize(Error Err) {
+ static WrapperFunctionResult serialize(Error Err) {
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
toSPSSerializable(std::move(Err)));
}
template <typename SPSRetTagT, typename T>
class ResultSerializer<SPSRetTagT, Expected<T>> {
public:
- static Expected<WrapperFunctionResult> serialize(Expected<T> E) {
+ static WrapperFunctionResult serialize(Expected<T> E) {
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
toSPSSerializable(std::move(E)));
}
template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
public:
+ static RetT makeValue() { return RetT(); }
static void makeSafe(RetT &Result) {}
static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
template <> class ResultDeserializer<SPSError, Error> {
public:
+ static Error makeValue() { return Error::success(); }
static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
template <typename SPSTagT, typename T>
class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
public:
+ static Expected<T> makeValue() { return T(); }
static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
static Error deserialize(Expected<T> &E, const char *ArgData,
}
};
+template <typename SPSRetTagT, typename RetT> class AsyncCallResultHelper {
+ // Did you forget to use Error / Expected in your handler?
+};
+
} // end namespace detail
template <typename SPSSignature> class WrapperFunction;
using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
public:
- /// Call a wrapper function. Callere should be callable as
+ /// Call a wrapper function. Caller should be callable as
/// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize);
template <typename CallerFn, typename RetT, typename... ArgTs>
static Error call(const CallerFn &Caller, RetT &Result,
auto ArgBuffer =
detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
Args...);
- if (!ArgBuffer)
- return ArgBuffer.takeError();
-
- Expected<WrapperFunctionResult> ResultBuffer =
- Caller(ArgBuffer->data(), ArgBuffer->size());
- if (!ResultBuffer)
- return ResultBuffer.takeError();
- if (auto ErrMsg = ResultBuffer->getOutOfBandError())
+ if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
+ return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
+
+ WrapperFunctionResult ResultBuffer =
+ Caller(ArgBuffer.data(), ArgBuffer.size());
+ if (auto ErrMsg = ResultBuffer.getOutOfBandError())
return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
- Result, ResultBuffer->data(), ResultBuffer->size());
+ Result, ResultBuffer.data(), ResultBuffer.size());
+ }
+
+ /// Call an async wrapper function.
+ /// Caller should be callable as
+ /// void Fn(unique_function<void(WrapperFunctionResult)> SendResult,
+ /// WrapperFunctionResult ArgBuffer);
+ template <typename AsyncCallerFn, typename SendDeserializedResultFn,
+ typename... ArgTs>
+ static void callAsync(AsyncCallerFn &&Caller,
+ SendDeserializedResultFn &&SendDeserializedResult,
+ const ArgTs &...Args) {
+ using RetT = typename std::tuple_element<
+ 1, typename detail::WrapperFunctionHandlerHelper<
+ std::remove_reference_t<SendDeserializedResultFn>,
+ ResultSerializer, SPSRetTagT>::ArgTuple>::type;
+
+ auto ArgBuffer =
+ detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
+ Args...);
+ if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) {
+ SendDeserializedResult(
+ make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
+ detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue());
+ return;
+ }
+
+ auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)](
+ WrapperFunctionResult R) {
+ RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue();
+ detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal);
+
+ SPSInputBuffer IB(R.data(), R.size());
+ if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
+ RetVal, R.data(), R.size()))
+ SDR(std::move(Err), std::move(RetVal));
+
+ SDR(Error::success(), std::move(RetVal));
+ };
+
+ Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size());
}
/// Handle a call to a wrapper function.
static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
HandlerT &&Handler) {
using WFHH =
- detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer,
- SPSTagTs...>;
+ detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
+ ResultSerializer, SPSTagTs...>;
return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
}
+ /// Handle a call to an async wrapper function.
+ template <typename HandlerT, typename SendResultT>
+ static void handleAsync(const char *ArgData, size_t ArgSize,
+ HandlerT &&Handler, SendResultT &&SendResult) {
+ using WFAHH = detail::WrapperFunctionAsyncHandlerHelper<
+ std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>;
+ WFAHH::applyAsync(std::forward<HandlerT>(Handler),
+ std::forward<SendResultT>(SendResult), ArgData, ArgSize);
+ }
+
private:
template <typename T> static const T &makeSerializable(const T &Value) {
return Value;
template <typename... SPSTagTs>
class WrapperFunction<void(SPSTagTs...)>
: private WrapperFunction<SPSEmpty(SPSTagTs...)> {
+
public:
template <typename CallerFn, typename... ArgTs>
static Error call(const CallerFn &Caller, const ArgTs &...Args) {
}
using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
+ using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync;
};
} // end namespace shared
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Host.h"
#include "llvm/Support/Process.h"
-#include <mutex>
-
namespace llvm {
namespace orc {
ExecutorProcessControl::~ExecutorProcessControl() {}
+Error ExecutorProcessControl::associateJITSideWrapperFunctions(
+ JITDylib &JD, WrapperFunctionAssociationMap WFs) {
+
+ // Look up tag addresses.
+ auto &ES = JD.getExecutionSession();
+ auto TagAddrs =
+ ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}},
+ SymbolLookupSet::fromMapKeys(
+ WFs, SymbolLookupFlags::WeaklyReferencedSymbol));
+ if (!TagAddrs)
+ return TagAddrs.takeError();
+
+ // Associate tag addresses with implementations.
+ std::lock_guard<std::mutex> Lock(TagToFuncMapMutex);
+ for (auto &KV : *TagAddrs) {
+ auto TagAddr = KV.second.getAddress();
+ if (TagToFunc.count(TagAddr))
+ return make_error<StringError>("Tag " + formatv("{0:x16}", TagAddr) +
+ " (for " + *KV.first +
+ ") already registered",
+ inconvertibleErrorCode());
+ auto I = WFs.find(KV.first);
+ assert(I != WFs.end() && I->second &&
+ "AsyncWrapperFunction implementation missing");
+ TagToFunc[KV.second.getAddress()] =
+ std::make_shared<AsyncWrapperFunction>(std::move(I->second));
+ }
+ return Error::success();
+}
+
+void ExecutorProcessControl::runJITSideWrapperFunction(
+ SendResultFunction SendResult, JITTargetAddress TagAddr,
+ ArrayRef<char> ArgBuffer) {
+
+ std::shared_ptr<AsyncWrapperFunction> F;
+ {
+ std::lock_guard<std::mutex> Lock(TagToFuncMapMutex);
+ auto I = TagToFunc.find(TagAddr);
+ if (I != TagToFunc.end())
+ F = I->second;
+ }
+
+ if (F)
+ (*F)(std::move(SendResult), ArgBuffer.data(), ArgBuffer.size());
+ else
+ SendResult(shared::WrapperFunctionResult::createOutOfBandError(
+ ("No function registered for tag " + formatv("{0:x16}", TagAddr))
+ .str()));
+}
+
SelfExecutorProcessControl::SelfExecutorProcessControl(
std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple,
unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr)
return orc::runAsMain(jitTargetAddressToFunction<MainTy>(MainFnAddr), Args);
}
-Expected<shared::WrapperFunctionResult>
-SelfExecutorProcessControl::runWrapper(JITTargetAddress WrapperFnAddr,
- ArrayRef<char> ArgBuffer) {
- using WrapperFnTy = shared::detail::CWrapperFunctionResult (*)(
- const char *Data, uint64_t Size);
+void SelfExecutorProcessControl::runWrapperAsync(SendResultFunction SendResult,
+ JITTargetAddress WrapperFnAddr,
+ ArrayRef<char> ArgBuffer) {
+ using WrapperFnTy =
+ shared::detail::CWrapperFunctionResult (*)(const char *Data, size_t Size);
auto *WrapperFn = jitTargetAddressToFunction<WrapperFnTy>(WrapperFnAddr);
- return WrapperFn(ArgBuffer.data(), ArgBuffer.size());
+ SendResult(WrapperFn(ArgBuffer.data(), ArgBuffer.size()));
}
Error SelfExecutorProcessControl::disconnect() { return Error::success(); }
add_llvm_unittest(OrcJITTests
CoreAPIsTest.cpp
+ ExecutorProcessControlTest.cpp
IndirectionUtilsTest.cpp
JITTargetMachineBuilderTest.cpp
LazyCallThroughAndReexportsTest.cpp
--- /dev/null
+//===- ExecutorProcessControlTest.cpp - Test ExecutorProcessControl utils -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
+#include "llvm/ExecutionEngine/Orc/Core.h"
+#include "llvm/Support/MSVCErrorWorkarounds.h"
+#include "llvm/Testing/Support/Error.h"
+#include "gtest/gtest.h"
+
+#include <future>
+
+using namespace llvm;
+using namespace llvm::orc;
+using namespace llvm::orc::shared;
+
+static llvm::orc::shared::detail::CWrapperFunctionResult
+addWrapper(const char *ArgData, size_t ArgSize) {
+ return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
+ ArgData, ArgSize, [](int32_t X, int32_t Y) { return X + Y; })
+ .release();
+}
+
+static void addAsyncWrapper(unique_function<void(int32_t)> SendResult,
+ int32_t X, int32_t Y) {
+ SendResult(X + Y);
+}
+
+TEST(ExecutorProcessControl, RunWrapperTemplate) {
+ auto EPC = cantFail(
+ SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
+
+ int32_t Result;
+ EXPECT_THAT_ERROR(EPC->runSPSWrapper<int32_t(int32_t, int32_t)>(
+ pointerToJITTargetAddress(addWrapper), Result, 2, 3),
+ Succeeded());
+ EXPECT_EQ(Result, 5);
+}
+
+TEST(ExecutorProcessControl, RunWrapperAsyncTemplate) {
+ auto EPC = cantFail(
+ SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
+
+ std::promise<MSVCPExpected<int32_t>> RP;
+ using Sig = int32_t(int32_t, int32_t);
+ EPC->runSPSWrapperAsync<Sig>(
+ [&](Error SerializationErr, int32_t R) {
+ if (SerializationErr)
+ RP.set_value(std::move(SerializationErr));
+ RP.set_value(std::move(R));
+ },
+ pointerToJITTargetAddress(addWrapper), 2, 3);
+ Expected<int32_t> Result = RP.get_future().get();
+ EXPECT_THAT_EXPECTED(Result, HasValue(5));
+}
+
+TEST(ExecutorProcessControl, RegisterAsyncHandlerAndRun) {
+
+ constexpr JITTargetAddress AddAsyncTagAddr = 0x01;
+
+ auto EPC = cantFail(
+ SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
+ ExecutionSession ES(EPC->getSymbolStringPool());
+ auto &JD = ES.createBareJITDylib("JD");
+
+ auto AddAsyncTag = ES.intern("addAsync_tag");
+ cantFail(JD.define(absoluteSymbols(
+ {{AddAsyncTag,
+ JITEvaluatedSymbol(AddAsyncTagAddr, JITSymbolFlags::Exported)}})));
+
+ ExecutorProcessControl::WrapperFunctionAssociationMap Associations;
+
+ Associations[AddAsyncTag] =
+ EPC->wrapAsyncWithSPS<int32_t(int32_t, int32_t)>(addAsyncWrapper);
+
+ cantFail(EPC->associateJITSideWrapperFunctions(JD, std::move(Associations)));
+
+ std::promise<int32_t> RP;
+ auto RF = RP.get_future();
+
+ using ArgSerialization = SPSArgList<int32_t, int32_t>;
+ size_t ArgBufferSize = ArgSerialization::size(1, 2);
+ WrapperFunctionResult ArgBuffer;
+ char *ArgBufferData =
+ WrapperFunctionResult::allocate(ArgBuffer, ArgBufferSize);
+ SPSOutputBuffer OB(ArgBufferData, ArgBufferSize);
+ EXPECT_TRUE(ArgSerialization::serialize(OB, 1, 2));
+
+ EPC->runJITSideWrapperFunction(
+ [&](WrapperFunctionResult ResultBuffer) {
+ int32_t Result;
+ SPSInputBuffer IB(ResultBuffer.data(), ResultBuffer.size());
+ EXPECT_TRUE(SPSArgList<int32_t>::deserialize(IB, Result));
+ RP.set_value(Result);
+ },
+ AddAsyncTagAddr, ArrayRef<char>(ArgBuffer.data(), ArgBuffer.size()));
+
+ EXPECT_EQ(RF.get(), (int32_t)3);
+
+ cantFail(ES.endSession());
+}
//===----------------------------------------------------------------------===//
#include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h"
+#include "llvm/ADT/FunctionExtras.h"
#include "gtest/gtest.h"
+#include <future>
+
using namespace llvm;
using namespace llvm::orc::shared;
ArgData, ArgSize, [](int32_t X, int32_t Y) -> int32_t { return X + Y; });
}
-TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) {
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleVoid) {
EXPECT_FALSE(!!WrapperFunction<void()>::call(voidNoopWrapper));
}
-TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandle) {
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleRet) {
int32_t Result;
EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
addWrapper, Result, 1, 2));
EXPECT_EQ(Result, (int32_t)3);
}
+
+static void voidNoopAsync(unique_function<void(SPSEmpty)> SendResult) {
+ SendResult(SPSEmpty());
+}
+
+static WrapperFunctionResult voidNoopAsyncWrapper(const char *ArgData,
+ size_t ArgSize) {
+ std::promise<WrapperFunctionResult> RP;
+ auto RF = RP.get_future();
+
+ WrapperFunction<void()>::handleAsync(
+ ArgData, ArgSize, voidNoopAsync,
+ [&](WrapperFunctionResult R) { RP.set_value(std::move(R)); });
+
+ return RF.get();
+}
+
+static WrapperFunctionResult addAsyncWrapper(const char *ArgData,
+ size_t ArgSize) {
+ std::promise<WrapperFunctionResult> RP;
+ auto RF = RP.get_future();
+
+ WrapperFunction<int32_t(int32_t, int32_t)>::handleAsync(
+ ArgData, ArgSize,
+ [](unique_function<void(int32_t)> SendResult, int32_t X, int32_t Y) {
+ SendResult(X + Y);
+ },
+ [&](WrapperFunctionResult R) { RP.set_value(std::move(R)); });
+ return RF.get();
+}
+
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncVoid) {
+ EXPECT_FALSE(!!WrapperFunction<void()>::call(voidNoopAsyncWrapper));
+}
+
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncRet) {
+ int32_t Result;
+ EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
+ addAsyncWrapper, Result, 1, 2));
+ EXPECT_EQ(Result, (int32_t)3);
+}