EventHandler(support::RingBuffer* reader,
support::RingBuffer* writer,
std::string name,
- std::string* remote_key)
+ std::string* remote_key,
+ std::function<void()> flush_writer)
: reader_(reader),
writer_(writer),
name_(name),
- remote_key_(remote_key) {
+ remote_key_(remote_key),
+ flush_writer_(flush_writer) {
this->Clear();
if (*remote_key == "%toinit") {
/*!
* \brief Enter the io loop until the next event.
* \param client_mode Whether we are in the client.
+ * \param async_server_mode Whether we are in the async server mode.
* \param setreturn The function to set the return value encoding.
* \return The function to set return values when there is a return event.
*/
- RPCCode HandleNextEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) {
+ RPCCode HandleNextEvent(bool client_mode,
+ bool async_server_mode,
+ RPCSession::FEncodeReturn setreturn) {
std::swap(client_mode_, client_mode);
+ std::swap(async_server_mode_, async_server_mode);
- while (this->Ready()) {
+ RPCCode status = RPCCode::kNone;
+
+ while (status == RPCCode::kNone &&
+ state_ != kWaitForAsyncCallback &&
+ this->Ready()) {
switch (state_) {
case kInitHeader: HandleInitHeader(); break;
case kRecvPacketNumBytes: {
this->HandleProcessPacket(setreturn);
break;
}
+ case kWaitForAsyncCallback: {
+ break;
+ }
case kReturnReceived: {
this->SwitchToState(kRecvPacketNumBytes);
- std::swap(client_mode_, client_mode);
- return RPCCode::kReturn;
+ status = RPCCode::kReturn;
+ break;
}
case kCopyAckReceived: {
- std::swap(client_mode_, client_mode);
- return RPCCode::kCopyAck;
+ status = RPCCode::kCopyAck;
+ break;
}
case kShutdownReceived: {
- std::swap(client_mode_, client_mode);
- return RPCCode::kShutdown;
+ status = RPCCode::kShutdown;
}
}
}
+
+ std::swap(async_server_mode_, async_server_mode);
std::swap(client_mode_, client_mode);
- return RPCCode::kNone;
+ return status;
}
/*! \brief Clear all the states in the Handler.*/
kInitHeader,
kRecvPacketNumBytes,
kProcessPacket,
+ kWaitForAsyncCallback,
kReturnReceived,
kCopyAckReceived,
kShutdownReceived
bool init_header_step_{0};
// Whether current handler is client or server mode.
bool client_mode_{false};
+ // Whether current handler is in the async server mode.
+ bool async_server_mode_{false};
// Internal arena
support::Arena arena_;
CHECK_EQ(pending_request_bytes_, 0U)
<< "state=" << state;
}
+ // need to actively flush the writer
+ // so the data get pushed out.
+ if (state_ == kWaitForAsyncCallback) {
+ flush_writer_();
+ }
state_ = state;
CHECK(state != kInitHeader)
<< "cannot switch to init header";
this->Read(&type_hint);
size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
- char* data_ptr;
auto* sess = GetServingSession();
+ // Return Copy Ack with the given data
+ auto fcopyack = [this](char* data_ptr, size_t num_bytes) {
+ RPCCode code = RPCCode::kCopyAck;
+ uint64_t packet_nbytes = sizeof(code) + num_bytes;
+
+ this->Write(packet_nbytes);
+ this->Write(code);
+ this->WriteArray(data_ptr, num_bytes);
+ this->SwitchToState(kRecvPacketNumBytes);
+ };
+
// When session is local, we can directly treat handle
// as the cpu pointer without allocating a temp space.
if (ctx.device_type == kDLCPU &&
sess->IsLocalSession() &&
DMLC_IO_NO_ENDIAN_SWAP) {
- data_ptr = reinterpret_cast<char*>(handle) + offset;
+ char* data_ptr = reinterpret_cast<char*>(handle) + offset;
+ fcopyack(data_ptr, num_bytes);
} else {
- try {
- data_ptr = this->ArenaAlloc<char>(num_bytes);
- sess->CopyFromRemote(
- reinterpret_cast<void*>(handle), offset,
- data_ptr, 0,
- num_bytes, ctx, type_hint);
- // endian aware handling
- if (!DMLC_IO_NO_ENDIAN_SWAP) {
- dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes);
+ char* data_ptr = this->ArenaAlloc<char>(num_bytes);
+
+ auto on_copy_complete = [this, elem_bytes, num_bytes, data_ptr, fcopyack](
+ RPCCode status, TVMArgs args) {
+ if (status == RPCCode::kException) {
+ this->ReturnException(args.values[0].v_str);
+ this->SwitchToState(kRecvPacketNumBytes);
+ } else {
+ // endian aware handling
+ if (!DMLC_IO_NO_ENDIAN_SWAP) {
+ dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes);
+ }
+ fcopyack(data_ptr, num_bytes);
}
- } catch (const std::runtime_error &e) {
- this->ReturnException(e.what());
- this->SwitchToState(kRecvPacketNumBytes);
- return;
- }
+ };
+
+ this->SwitchToState(kWaitForAsyncCallback);
+ sess->AsyncCopyFromRemote(
+ reinterpret_cast<void*>(handle), offset,
+ data_ptr, 0,
+ num_bytes, ctx, type_hint,
+ on_copy_complete);
}
- RPCCode code = RPCCode::kCopyAck;
- uint64_t packet_nbytes = sizeof(code) + num_bytes;
-
- // Return Copy Ack
- this->Write(packet_nbytes);
- this->Write(code);
- this->WriteArray(data_ptr, num_bytes);
-
- this->SwitchToState(kRecvPacketNumBytes);
}
void HandleCopyToRemote() {
char* dptr = reinterpret_cast<char*>(handle) + offset;
this->ReadArray(dptr, num_bytes);
- if (!DMLC_IO_NO_ENDIAN_SWAP) {
- dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes);
- }
+ if (!DMLC_IO_NO_ENDIAN_SWAP) {
+ dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes);
+ }
+ this->ReturnVoid();
+ this->SwitchToState(kRecvPacketNumBytes);
} else {
char* temp_data = this->ArenaAlloc<char>(num_bytes);
this->ReadArray(temp_data, num_bytes);
dmlc::ByteSwap(temp_data, elem_bytes, num_bytes / elem_bytes);
}
- try {
- sess->CopyToRemote(
+ auto on_copy_complete = [this](RPCCode status, TVMArgs args) {
+ if (status == RPCCode::kException) {
+ this->ReturnException(args.values[0].v_str);
+ this->SwitchToState(kRecvPacketNumBytes);
+ } else {
+ this->ReturnVoid();
+ this->SwitchToState(kRecvPacketNumBytes);
+ }
+ };
+
+ this->SwitchToState(kWaitForAsyncCallback);
+ sess->AsyncCopyToRemote(
temp_data, 0,
reinterpret_cast<void*>(handle), offset,
- num_bytes, ctx, type_hint);
- } catch (const std::runtime_error &e) {
- this->ReturnException(e.what());
- this->SwitchToState(kRecvPacketNumBytes);
- return;
- }
+ num_bytes, ctx, type_hint,
+ on_copy_complete);
}
-
- this->ReturnVoid();
- this->SwitchToState(kRecvPacketNumBytes);
}
// Handle for packed call.
this->Read(&call_handle);
TVMArgs args = RecvPackedSeq();
- try {
- GetServingSession()->CallFunc(
- reinterpret_cast<void*>(call_handle),
- args.values, args.type_codes, args.size(),
- [this](TVMArgs ret) { this->ReturnPackedSeq(ret); });
- } catch (const std::runtime_error& e) {
- this->ReturnException(e.what());
- }
-
- this->SwitchToState(kRecvPacketNumBytes);
+ this->SwitchToState(kWaitForAsyncCallback);
+ GetServingSession()->AsyncCallFunc(
+ reinterpret_cast<void*>(call_handle),
+ args.values, args.type_codes, args.size(),
+ [this](RPCCode status, TVMArgs args) {
+ if (status == RPCCode::kException) {
+ this->ReturnException(args.values[0].v_str);
+ } else {
+ this->ReturnPackedSeq(args);
+ }
+ this->SwitchToState(kRecvPacketNumBytes);
+ });
}
void HandleInitServer() {
<< " server protocol=" << server_protocol_ver
<< ", client protocol=" << client_protocol_ver;
+ std::string constructor_name;
+ TVMArgs constructor_args = TVMArgs(nullptr, nullptr, 0);
+
if (args.size() == 0) {
+ constructor_name = "rpc.LocalSession";
serving_session_ = std::make_shared<LocalSession>();
} else {
- std::string constructor_name = args[0];
- auto* fconstructor = Registry::Get(constructor_name);
- CHECK(fconstructor != nullptr)
- << " Cannot find session constructor " << constructor_name;
- TVMRetValue con_ret;
-
- try {
- fconstructor->CallPacked(
- TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), &con_ret);
- } catch (const dmlc::Error& e) {
- LOG(FATAL) << "Server[" << name_ << "]:"
- << " Error caught from session constructor " << constructor_name
- << ":\n" << e.what();
- }
+ constructor_name = args[0].operator std::string();
+ constructor_args = TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1);
+ }
+
+ auto* fconstructor = Registry::Get(constructor_name);
+ CHECK(fconstructor != nullptr)
+ << " Cannot find session constructor " << constructor_name;
+ TVMRetValue con_ret;
- CHECK_EQ(con_ret.type_code(), kTVMModuleHandle)
- << "Server[" << name_ << "]:"
- << " Constructor " << constructor_name
- << " need to return an RPCModule";
- Module mod = con_ret;
- std::string tkey = mod->type_key();
- CHECK_EQ(tkey, "rpc")
- << "Constructor " << constructor_name << " to return an RPCModule";
- serving_session_ = RPCModuleGetSession(mod);
+ try {
+ fconstructor->CallPacked(constructor_args, &con_ret);
+ } catch (const dmlc::Error& e) {
+ LOG(FATAL) << "Server[" << name_ << "]:"
+ << " Error caught from session constructor " << constructor_name
+ << ":\n" << e.what();
}
+ CHECK_EQ(con_ret.type_code(), kTVMModuleHandle)
+ << "Server[" << name_ << "]:"
+ << " Constructor " << constructor_name
+ << " need to return an RPCModule";
+ Module mod = con_ret;
+ std::string tkey = mod->type_key();
+ CHECK_EQ(tkey, "rpc")
+ << "Constructor " << constructor_name << " to return an RPCModule";
+ serving_session_ = RPCModuleGetSession(mod);
this->ReturnVoid();
} catch (const std::runtime_error &e) {
this->ReturnException(e.what());
this->SwitchToState(kRecvPacketNumBytes);
}
+ void HandleSyscallStreamSync() {
+ TVMArgs args = RecvPackedSeq();
+ try {
+ TVMContext ctx = args[0];
+ TVMStreamHandle handle = args[1];
+
+ this->SwitchToState(kWaitForAsyncCallback);
+ GetServingSession()->AsyncStreamWait(
+ ctx, handle, [this](RPCCode status, TVMArgs args) {
+ if (status == RPCCode::kException) {
+ this->ReturnException(args.values[0].v_str);
+ } else {
+ this->ReturnVoid();
+ }
+ this->SwitchToState(kRecvPacketNumBytes);
+ });
+ } catch (const std::runtime_error& e) {
+ this->ReturnException(e.what());
+ this->SwitchToState(kRecvPacketNumBytes);
+ }
+ }
+
// Handler for special syscalls that have a specific RPCCode.
template<typename F>
void SysCallHandler(F f) {
RPCSession* GetServingSession() const {
CHECK(serving_session_ != nullptr)
<< "Need to call InitRemoteSession first before any further actions";
+ CHECK(!serving_session_->IsAsync() || async_server_mode_)
+ << "Cannot host an async session in a non-Event driven server";
+
return serving_session_.get();
}
// Utility functions
std::string name_;
// remote key
std::string* remote_key_;
+ // function to flush the writer.
+ std::function<void()> flush_writer_;
};
RPCCode RPCEndpoint::HandleUntilReturnEvent(
- bool client_mode, RPCSession::FEncodeReturn setreturn) {
+ bool client_mode,
+ RPCSession::FEncodeReturn setreturn) {
RPCCode code = RPCCode::kCallFunc;
while (code != RPCCode::kReturn &&
code != RPCCode::kShutdown &&
}
}
}
- code = handler_->HandleNextEvent(client_mode, setreturn);
+ code = handler_->HandleNextEvent(client_mode, false, setreturn);
}
return code;
}
void RPCEndpoint::Init() {
+ // callback to flush the writer.
+ auto flush_writer = [this]() {
+ while (writer_.bytes_available() != 0) {
+ size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
+ return channel_->Send(data, size);
+ }, writer_.bytes_available());
+ if (n == 0) break;
+ }
+ };
+
// Event handler
handler_ = std::make_shared<EventHandler>(
- &reader_, &writer_, name_, &remote_key_);
+ &reader_, &writer_, name_, &remote_key_, flush_writer);
+
// Quick function to for syscall remote.
syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) {
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = RPCCode::kNone;
if (in_bytes.length() != 0) {
reader_.Write(in_bytes.c_str(), in_bytes.length());
- code = handler_->HandleNextEvent(false, [](TVMArgs) {});
+ code = handler_->HandleNextEvent(false, true, [](TVMArgs) {});
}
if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
writer_.ReadWithCallback([this](const void *data, size_t size) {
handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr);
}
-void RPCDevStreamSync(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
- TVMContext ctx = args[0];
- TVMStreamHandle handle = args[1];
- handler->GetDeviceAPI(ctx)->StreamSync(ctx, handle);
-}
-
void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
void* from = args[0];
uint64_t from_offset = args[1];
case RPCCode::kDevGetAttr: SysCallHandler(RPCDevGetAttr); break;
case RPCCode::kDevAllocData: SysCallHandler(RPCDevAllocData); break;
case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break;
- case RPCCode::kDevStreamSync: SysCallHandler(RPCDevStreamSync); break;
+ case RPCCode::kDevStreamSync: this->HandleSyscallStreamSync(); break;
case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break;
default: LOG(FATAL) << "Unknown event " << static_cast<int>(code);
}
- CHECK_EQ(state_, kRecvPacketNumBytes);
+ if (state_ != kWaitForAsyncCallback) {
+ CHECK_EQ(state_, kRecvPacketNumBytes);
+ }
}
/*!
namespace tvm {
namespace runtime {
+bool RPCSession::IsAsync() const {
+ return false;
+}
+
+void RPCSession::SendException(FAsyncCallback callback, const char* msg) {
+ TVMValue value;
+ value.v_str = msg;
+ int32_t tcode = kTVMStr;
+ callback(RPCCode::kException, TVMArgs(&value, &tcode, 1));
+}
+
+void RPCSession::AsyncCallFunc(PackedFuncHandle func,
+ const TVMValue* arg_values,
+ const int* arg_type_codes,
+ int num_args,
+ FAsyncCallback callback) {
+ try {
+ this->CallFunc(func, arg_values, arg_type_codes, num_args,
+ [&callback](TVMArgs args) {
+ callback(RPCCode::kReturn, args);
+ });
+ } catch (const std::runtime_error& e) {
+ this->SendException(callback, e.what());
+ }
+}
+
+
+void RPCSession::AsyncCopyToRemote(void* local_from,
+ size_t local_from_offset,
+ void* remote_to,
+ size_t remote_to_offset,
+ size_t nbytes,
+ TVMContext remote_ctx_to,
+ DLDataType type_hint,
+ RPCSession::FAsyncCallback callback) {
+ TVMValue value;
+ int32_t tcode = kTVMNullptr;
+ value.v_handle = nullptr;
+
+ try {
+ this->CopyToRemote(local_from, local_from_offset,
+ remote_to, remote_to_offset,
+ nbytes, remote_ctx_to, type_hint);
+ callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
+ } catch (const std::runtime_error& e) {
+ this->SendException(callback, e.what());
+ }
+}
+
+void RPCSession::AsyncCopyFromRemote(void* remote_from,
+ size_t remote_from_offset,
+ void* local_to,
+ size_t local_to_offset,
+ size_t nbytes,
+ TVMContext remote_ctx_from,
+ DLDataType type_hint,
+ RPCSession::FAsyncCallback callback) {
+ TVMValue value;
+ int32_t tcode = kTVMNullptr;
+ value.v_handle = nullptr;
+
+ try {
+ this->CopyFromRemote(remote_from, remote_from_offset,
+ local_to, local_to_offset,
+ nbytes, remote_ctx_from, type_hint);
+ callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
+ } catch (const std::runtime_error& e) {
+ this->SendException(callback, e.what());
+ }
+}
+
+void RPCSession::AsyncStreamWait(TVMContext ctx,
+ TVMStreamHandle stream,
+ RPCSession::FAsyncCallback callback) {
+ TVMValue value;
+ int32_t tcode = kTVMNullptr;
+ value.v_handle = nullptr;
+
+ try {
+ this->GetDeviceAPI(ctx)->StreamSync(ctx, stream);
+ callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
+ } catch (const std::runtime_error& e) {
+ this->SendException(callback, e.what());
+ }
+}
+
+
class RPCSessTable {
public:
static constexpr int kMaxRPCSession = 32;
#include <functional>
#include <memory>
#include <string>
+#include "rpc_protocol.h"
namespace tvm {
namespace runtime {
* \brief Callback to send an encoded return values via encode_args.
*
* \param encode_args The arguments that we can encode the return values into.
- * \param ret_tcode The actual remote type code of the return value.
*
* Encoding convention (as list of arguments):
* - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention.
*/
using FEncodeReturn = std::function<void(TVMArgs encoded_args)>;
+ /*!
+ * \brief Callback to send an encoded return values via encode_args.
+ *
+ * \param status The return status, can be RPCCode::kReturn or RPCCode::kException.
+ * \param encode_args The arguments that we can encode the return values into.
+ */
+ using FAsyncCallback = std::function<void(RPCCode status, TVMArgs encoded_args)>;
+
/*! \brief Destructor.*/
virtual ~RPCSession() {}
*/
virtual bool IsLocalSession() const = 0;
+ // Asynchrous variant of API
+ // These APIs are used by the RPC server to allow sessions that
+ // have special implementations for the async functions.
+ //
+ // In the async APIs, an exception is returned by the passing
+ // async_error=true, encode_args=[error_msg].
+
+ /*!
+ * \brief Whether the session is async.
+ *
+ * If the session is not async, its Aync implementations
+ * simply calls into the their synchronize counterparts,
+ * and the callback is guaranteed to be called before the async function finishes.
+ *
+ * \return the async state.
+ *
+ * \note We can only use async session in an Event driven RPC server.
+ */
+ virtual bool IsAsync() const;
+
+ /*!
+ * \brief Asynchrously call func.
+ * \param func The function handle.
+ * \param arg_values The argument values.
+ * \param arg_type_codes the type codes of the argument.
+ * \param num_args Number of arguments.
+ *
+ * \param callback The callback to pass the return value or exception.
+ */
+ virtual void AsyncCallFunc(PackedFuncHandle func,
+ const TVMValue* arg_values,
+ const int* arg_type_codes,
+ int num_args,
+ FAsyncCallback callback);
+
+ /*!
+ * \brief Asynchrous version of CopyToRemote.
+ *
+ * \param local_from The source host data.
+ * \param local_from_offset The byte offeset in the from.
+ * \param remote_to The target array.
+ * \param remote_to_offset The byte offset in the to.
+ * \param nbytes The size of the memory in bytes.
+ * \param remote_ctx_to The target context.
+ * \param type_hint Hint of content data type.
+ *
+ * \param on_complete The callback to signal copy complete.
+ * \note All the allocated memory in local_from, and remote_to
+ * must stay alive until on_compelete is called.
+ */
+ virtual void AsyncCopyToRemote(void* local_from,
+ size_t local_from_offset,
+ void* remote_to,
+ size_t remote_to_offset,
+ size_t nbytes,
+ TVMContext remote_ctx_to,
+ DLDataType type_hint,
+ FAsyncCallback on_complete);
+
+ /*!
+ * \brief Asynchrous version of CopyFromRemote.
+ *
+ * \param remote_from The source host data.
+ * \param remote_from_offset The byte offeset in the from.
+ * \param to The target array.
+ * \param to_offset The byte offset in the to.
+ * \param nbytes The size of the memory in bytes.
+ * \param remote_ctx_from The source context in the remote.
+ * \param type_hint Hint of content data type.
+ *
+ * \param on_complete The callback to signal copy complete.
+ * \note All the allocated memory in remote_from, and local_to
+ * must stay alive until on_compelete is called.
+ */
+ virtual void AsyncCopyFromRemote(void* remote_from,
+ size_t remote_from_offset,
+ void* local_to,
+ size_t local_to_offset,
+ size_t nbytes,
+ TVMContext remote_ctx_from,
+ DLDataType type_hint,
+ FAsyncCallback on_complete);
+ /*!
+ * \brief Asynchrously wait for all events in ctx, stream compeletes.
+ * \param ctx The device context.
+ * \param stream The stream to wait on.
+ * \param on_complete The callback to signal copy complete.
+ */
+ virtual void AsyncStreamWait(TVMContext ctx,
+ TVMStreamHandle stream,
+ FAsyncCallback on_compelte);
+
/*!
* \return The session table index of the session.
*/
*/
static std::shared_ptr<RPCSession> Get(int table_index);
+ protected:
+ /*!
+ * \brief Send an exception to the callback.
+ * \param msg The exception message.
+ */
+ void SendException(FAsyncCallback callback, const char* msg);
+
private:
/*! \brief index of this session in RPC session table */
int table_index_{0};