[REFACTOR][RPC][PROCOTOL-CHANGE] Modularize the RPC infra (#5484)
authorTianqi Chen <tqchen@users.noreply.github.com>
Tue, 5 May 2020 00:05:33 +0000 (17:05 -0700)
committerGitHub <noreply@github.com>
Tue, 5 May 2020 00:05:33 +0000 (17:05 -0700)
* Update dmlc-core which was mistakenly overriden

* [REFACTOR][RPC][PROCOTOL-CHANGE] Modularize the RPC infra.

This PR refactors the RPC protocol to make it more modularized.

- RPCSession: represent a set of features that need to be implemented
- RPCEndPont: End point that forwards the RPCSession requests over a communication channel.
- RPCModule: Exposes an RPCSession as an rpc device in the TVM Runtime API.

In the new design, the local machine is presented as a special case of RPCSession.
The remote is just another client session that calls into RPCEndPoint.
The RPC communication path is as follows.

```
client -> ClientSession -> EndPoint[client@n0]
-> networking[between n0 <=> n1]
-> EndPoint[server@n1] -> LocalSession[@n1]

```

Because of the new modular design, we can now chain more sessions together.
For example, we can now run the following proxy setup (testcase in test_runtime_rpc.test_session_constructor).

```
client -> ClientSession -> Endpoint[client@n0]
-> networking[between n0 <=> n1]
-> Endpoint[server@n1] -> ClientSession -> Endpoint[client@n1]
-> networking[between n1 <=> n2]
-> Endpoint[server@n2] -> LocalSession[@n2]
```

We can also implement other types of Sessions.
As an example, We introduced a PopenSession that communicates with
the another process via a pipe.

We also add more comments about the internal of the RPC.
The communication protocol is simplfied using a similar convention as PackedFunc.
This allows us to further reduce the amount of special remote syscalls.

Due to the major improvement and simplification, we are making a non-compatible update to the RPC protocol.
It means that the client and server needs to be upgraded to together in order for it to function correctly.

This PR also introduces a versioning mechanism to the current RPC procotol,
so that future upgrade will be produce more user friendly with error messages.

* Address review comments

* Remove ld library path

47 files changed:
.gitignore
3rdparty/dmlc-core
apps/cpp_rpc/rpc_server.cc
apps/cpp_rpc/rpc_tracker_client.h
include/tvm/runtime/c_runtime_api.h
include/tvm/runtime/device_api.h
jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java
jvm/core/src/main/java/org/apache/tvm/rpc/Client.java
jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java
jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java
python/tvm/_ffi/_ctypes/packed_func.py
python/tvm/_ffi/_cython/packed_func.pxi
python/tvm/_ffi/base.py
python/tvm/contrib/cc.py
python/tvm/contrib/graph_runtime.py
python/tvm/error.py
python/tvm/rpc/__init__.py
python/tvm/rpc/_ffi_api.py [new file with mode: 0644]
python/tvm/rpc/base.py
python/tvm/rpc/client.py
python/tvm/rpc/minrpc.py [new file with mode: 0644]
python/tvm/rpc/proxy.py
python/tvm/rpc/server.py
python/tvm/runtime/module.py
src/runtime/c_runtime_api.cc
src/runtime/module.cc
src/runtime/registry.cc
src/runtime/rpc/minrpc/minrpc_server.h [new file with mode: 0644]
src/runtime/rpc/minrpc/posix_popen_server.cc [new file with mode: 0644]
src/runtime/rpc/rpc_channel.cc [new file with mode: 0644]
src/runtime/rpc/rpc_channel.h [new file with mode: 0644]
src/runtime/rpc/rpc_device_api.cc
src/runtime/rpc/rpc_endpoint.cc [new file with mode: 0644]
src/runtime/rpc/rpc_endpoint.h [new file with mode: 0644]
src/runtime/rpc/rpc_event_impl.cc
src/runtime/rpc/rpc_local_session.cc [new file with mode: 0644]
src/runtime/rpc/rpc_local_session.h [new file with mode: 0644]
src/runtime/rpc/rpc_module.cc
src/runtime/rpc/rpc_pipe_impl.cc [new file with mode: 0644]
src/runtime/rpc/rpc_protocol.h [new file with mode: 0644]
src/runtime/rpc/rpc_server_env.cc
src/runtime/rpc/rpc_session.cc
src/runtime/rpc/rpc_session.h
src/runtime/rpc/rpc_socket_impl.cc
src/support/arena.h
tests/python/unittest/test_runtime_rpc.py
web/tvm_runtime.js

index 068cb87..1fcb2dc 100644 (file)
@@ -2,9 +2,10 @@
 __pycache__/
 *.py[cod]
 *$py.class
-
+*.S
 # C extensions
 *.so
+*.ll
 
 # Distribution / packaging
 .Python
index 808f485..ff3db43 160000 (submodule)
@@ -1 +1 @@
-Subproject commit 808f485387f9a03f78fa9f1159f387d0d91b7a28
+Subproject commit ff3db4367a30f542aafb83b4af45e685b80102d0
index ea4ab00..57a68f4 100644 (file)
@@ -33,7 +33,7 @@
 #include <string>
 
 #include "../../src/support/socket.h"
-#include "../../src/runtime/rpc/rpc_session.h"
+#include "../../src/runtime/rpc/rpc_endpoint.h"
 #include "../../src/runtime/rpc/rpc_socket_impl.h"
 #include "rpc_env.h"
 #include "rpc_server.h"
@@ -86,7 +86,7 @@ class RPCServer {
     tracker_addr_(std::move(tracker_addr)), key_(std::move(key)),
     custom_addr_(std::move(custom_addr))
   {
-    
+
   }
 
   /*!
@@ -98,7 +98,7 @@ class RPCServer {
       tracker_sock_.Close();
       listen_sock_.Close();
     } catch(...) {
-      
+
     }
   }
 
@@ -144,7 +144,7 @@ class RPCServer {
       }
 
       int timeout = GetTimeOutFromOpts(opts);
-#if defined(__linux__) || defined(__ANDROID__) 
+#if defined(__linux__) || defined(__ANDROID__)
       // step 3: serving
       if (timeout != 0) {
         const pid_t timer_pid = fork();
@@ -197,7 +197,7 @@ class RPCServer {
       try {
         SpawnRPCChild(conn.sockfd, seconds(timeout));
       } catch (const std::exception&) {
-        
+
       }
       auto dur = high_resolution_clock::now() - start_time;
 
@@ -217,10 +217,10 @@ class RPCServer {
    * \param opts Parsed options for socket
    * \param ping_period Timeout for select call waiting
    */
-  void AcceptConnection(TrackerClient* tracker, 
+  void AcceptConnection(TrackerClient* tracker,
                         support::TCPSocket* conn_sock,
-                        support::SockAddr* addr, 
-                        std::string* opts, 
+                        support::SockAddr* addr,
+                        std::string* opts,
                         int ping_period = 2) {
     std::set<std::string> old_keyset;
     std::string matchkey;
@@ -330,7 +330,7 @@ void ServerLoopFromChild(SOCKET socket) {
   tvm::support::TCPSocket sock(socket);
   const auto env = RPCEnv();
   RPCServerLoop(int(sock.sockfd));
-  
+
   sock.Close();
   env.CleanUp();
 }
@@ -357,7 +357,7 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track
   rpc.Start();
 }
 
-TVM_REGISTER_GLOBAL("rpc._ServerCreate")
+TVM_REGISTER_GLOBAL("rpc.ServerCreate")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
     RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
   });
index dfd576f..112f7d2 100644 (file)
@@ -31,7 +31,7 @@
 #include <vector>
 #include <string>
 
-#include "../../src/runtime/rpc/rpc_session.h"
+#include "../../src/runtime/rpc/rpc_endpoint.h"
 #include "../../src/support/socket.h"
 
 namespace tvm {
index 920ecfb..79bcdc6 100644 (file)
@@ -550,6 +550,54 @@ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
  */
 TVM_DLL int TVMObjectFree(TVMObjectHandle obj);
 
+/*!
+ * \brief Allocate a data space on device.
+ * \param ctx The device context to perform operation.
+ * \param nbytes The number of bytes in memory.
+ * \param alignment The alignment of the memory.
+ * \param type_hint The type of elements. Only needed by certain backends such
+ *                   as nbytes & alignment are sufficient for most backends.
+ * \param out_data The allocated device pointer.
+ * \return 0 when success, -1 when failure happens
+ */
+TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx,
+                                    size_t nbytes,
+                                    size_t alignment,
+                                    DLDataType type_hint,
+                                    void** out_data);
+
+/*!
+ * \brief Free a data space on device.
+ * \param ctx The device context to perform operation.
+ * \param ptr The data space.
+ * \return 0 when success, -1 when failure happens
+ */
+TVM_DLL int TVMDeviceFreeDataSpace(TVMContext ctx, void* ptr);
+
+/*!
+ * \brief Copy data from one place to another.
+ * \param from The source array.
+ * \param from_offset The byte offeset in the from.
+ * \param to The target array.
+ * \param to_offset The byte offset in the to.
+ * \param num_bytes The size of the memory in bytes
+ * \param ctx_from The source context
+ * \param ctx_to The target context
+ * \param type_hint The type of elements, only neded by certain backends.
+ *                  can be useful for cross device endian converison.
+ * \param stream Optional stream object.
+ * \return 0 when success, -1 when failure happens.
+ */
+TVM_DLL int TVMDeviceCopyDataFromTo(const void* from,
+                                    size_t from_offset,
+                                    void* to,
+                                    size_t to_offset,
+                                    size_t num_bytes,
+                                    TVMContext ctx_from,
+                                    TVMContext ctx_to,
+                                    DLDataType type_hint,
+                                    TVMStreamHandle stream);
+
 #ifdef __cplusplus
 }  // TVM_EXTERN_C
 #endif
index f2ddc84..1206918 100644 (file)
@@ -157,9 +157,9 @@ class TVM_DLL DeviceAPI {
    * \param event_dst The destination stream to synchronize.
    */
   virtual void SyncStreamFromTo(TVMContext ctx,
-                                        TVMStreamHandle event_src,
-                                        TVMStreamHandle event_dst);
 /*!
+                                TVMStreamHandle event_src,
+                                TVMStreamHandle event_dst);
+ /*!
    * \brief Allocate temporal workspace for backend execution.
    *
    *  \note We have the following assumption about backend temporal
@@ -176,8 +176,8 @@ class TVM_DLL DeviceAPI {
    * as OpenGL, as nbytes is sufficient for most backends.
    */
   virtual void* AllocWorkspace(TVMContext ctx,
-                                       size_t nbytes,
-                                       DLDataType type_hint = {});
+                               size_t nbytes,
+                               DLDataType type_hint = {});
   /*!
    * \brief Free temporal workspace in backend execution.
    *
index c31c67f..61ff966 100644 (file)
@@ -38,53 +38,14 @@ public class GraphRuntime {
    * @return Runtime graph module that can be used to execute the graph.
    */
   public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) {
-    Module graphModule = null;
-    if (ctx.deviceType >= RPC.RPC_SESS_MASK) {
-      if (!(ctx instanceof  TVMRemoteContext)) {
-        throw new IllegalArgumentException(
-            "Looks like you are using remote context with no RPCSession bind."
-            + "Use session.context instead.");
-      }
-      RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession;
-      // check arguments
-      if (!"rpc".equals(libmod.typeKey())) {
-        throw new IllegalArgumentException("libmod.typeKey != rpc");
-      }
-      final int sessIndex = (int) ((Function) reflectionStaticCall(
-          RPC.class, "getApi", "_SessTableIndex"))
-          .pushArg(libmod).invoke().asLong();
-      if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) {
-        throw new IllegalArgumentException(String.format(
-            "libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d",
-            sessIndex, reflectionGetField(rpcSession, "tblIndex")));
-      }
-
-      Function rpcModuleHandle = (Function) reflectionStaticCall(
-          RPC.class, "getApi","_ModuleHandle");
-      if (rpcModuleHandle == null) {
-        throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle."
-            + "Did you compile tvm_runtime with the correct version?");
-      }
-
-      Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create");
-      if (fcreate == null) {
-        throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create."
-            + "Did you compile tvm_runtime with correct version?");
-      }
-
-      TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke();
-      graphModule = fcreate.call(graphJson, hmod,
-          ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule();
-    } else {
-      Function fcreate = Function.getFunction("tvm.graph_runtime.create");
-      if (fcreate == null) {
-        throw new RuntimeException("Cannot find global function tvm.graph_runtime.create."
-            + "Did you compile tvm_runtime with correct version?");
-      }
-      graphModule = fcreate.pushArg(graphJson)
-          .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId)
-          .invoke().asModule();
+    Function fcreate = Function.getFunction("tvm.graph_runtime.create");
+    if (fcreate == null) {
+      throw new RuntimeException("Cannot find global function tvm.graph_runtime.create."
+          + "Did you compile tvm_runtime with correct version?");
     }
+    Module graphModule = fcreate.pushArg(graphJson)
+        .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId)
+        .invoke().asModule();
 
     return new GraphModule(graphModule, ctx);
   }
index 5178ac9..69321c3 100644 (file)
@@ -29,7 +29,7 @@ public class Client {
    * @return The connected session.
    */
   public static RPCSession connect(String url, int port, String key) {
-    Function doConnect = RPC.getApi("_Connect");
+    Function doConnect = RPC.getApi("Connect");
     if (doConnect == null) {
       throw new RuntimeException("Please compile with USE_RPC=1");
     }
index 29a457f..1f3191f 100644 (file)
@@ -46,7 +46,7 @@ public class NativeServerLoop implements Runnable {
     try {
       tempDir = serverEnv();
       System.err.println("starting server loop...");
-      RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
+      RPC.getApi("ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
       System.err.println("done server loop...");
     } catch (IOException e) {
       e.printStackTrace();
index 92b3284..b9f6214 100644 (file)
@@ -39,7 +39,7 @@ public class RPCSession {
 
   RPCSession(Module sess) {
     session = sess;
-    tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong();
+    tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(session).invoke().asLong();
   }
 
   /**
@@ -237,7 +237,7 @@ public class RPCSession {
    * @return The remote module containing remote function.
    */
   public Module loadModule(String path) {
-    return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
+    return RPC.getApi("LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
   }
 
 
index dc2dc19..b17174a 100644 (file)
@@ -141,7 +141,13 @@ def _make_tvm_args(args, temp_args):
         elif isinstance(arg, TVMContext):
             values[i].v_int64 = _ctx_to_int64(arg)
             type_codes[i] = TypeCode.TVM_CONTEXT
-        elif isinstance(arg, bytearray):
+        elif isinstance(arg, (bytearray, bytes)):
+            # from_buffer only taeks in bytearray.
+            if isinstance(arg, bytes):
+                byte_arr = bytearray(arg)
+                temp_args.append(byte_arr)
+                arg = byte_arr
+
             arr = TVMByteArray()
             arr.data = ctypes.cast(
                 (ctypes.c_byte * len(arg)).from_buffer(arg),
index 1f68df1..45bcf64 100644 (file)
@@ -142,7 +142,13 @@ cdef inline int make_arg(object arg,
         value[0].v_ctx = (<DLContext*>(
             <unsigned long long>ctypes.addressof(arg)))[0]
         tcode[0] = kTVMContext
-    elif isinstance(arg, bytearray):
+    elif isinstance(arg, (bytes, bytearray)):
+        # from_buffer only taeks in bytearray.
+        if isinstance(arg, bytes):
+            byte_arr = bytearray(arg)
+            temp_args.append(byte_arr)
+            arg = byte_arr
+
         arr = TVMByteArray()
         arr.data = ctypes.cast(
             (ctypes.c_byte * len(arg)).from_buffer(arg),
index 8d3ce19..8674e31 100644 (file)
@@ -48,7 +48,6 @@ def _load_lib():
     """Load libary by searching possible path."""
     lib_path = libinfo.find_lib_path()
     lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
-    # DMatrix functions
     lib.TVMGetLastError.restype = ctypes.c_char_p
     return lib, os.path.basename(lib_path[0])
 
index ae37923..8ad47ac 100644 (file)
@@ -90,7 +90,8 @@ create_shared.get_target_triple = get_target_by_dump_machine(
 def cross_compiler(compile_func,
                    options=None,
                    output_format=None,
-                   get_target_triple=None):
+                   get_target_triple=None,
+                   add_files=None):
     """Create a cross compiler function by specializing compile_func with options.
 
     This function can be used to construct compile functions that
@@ -111,6 +112,10 @@ def cross_compiler(compile_func,
     get_target_triple: Optional[Callable]
         Function that can target triple according to dumpmachine option of compiler.
 
+    add_files: Optional[List[str]]
+        List of paths to additional object, source, library files
+        to pass as part of the compilation.
+
     Returns
     -------
     fcompile : Callable[[str, str, Optional[str]], None]
@@ -133,6 +138,7 @@ def cross_compiler(compile_func,
     """
     base_options = [] if options is None else options
     kwargs = {}
+    add_files = [] if add_files is None else add_files
 
     # handle case where compile_func is the name of the cc
     if isinstance(compile_func, str):
@@ -144,7 +150,7 @@ def cross_compiler(compile_func,
         all_options = base_options
         if options is not None:
             all_options += options
-        compile_func(outputs, objects, options=all_options, **kwargs)
+        compile_func(outputs, objects + add_files, options=all_options, **kwargs)
 
     if not output_format and hasattr(compile_func, "output_format"):
         output_format = compile_func.output_format
index 73235f7..740d1c3 100644 (file)
 import numpy as np
 import tvm._ffi
 
-from .._ffi.base import string_types
-from .._ffi.runtime_ctypes import TVMContext
-from ..rpc import base as rpc_base
+from tvm.rpc import _ffi_api as _rpc_ffi_api
+from tvm.rpc import base as rpc_base
+from tvm._ffi.base import string_types
+from tvm._ffi.runtime_ctypes import TVMContext
 
 
 def create(graph_json_str, libmod, ctx):
@@ -99,7 +100,7 @@ def get_device_ctx(libmod, ctx):
         device_type = cur_ctx.device_type
         if device_type >= rpc_base.RPC_SESS_MASK:
             assert libmod.type_key == "rpc"
-            assert rpc_base._SessTableIndex(
+            assert _rpc_ffi_api.SessTableIndex(
                 libmod) == cur_ctx._rpc_sess._tbl_index
             num_rpc_ctx += 1
             device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
index 4c3e606..b3502f6 100644 (file)
@@ -58,6 +58,11 @@ register_error("KeyError", KeyError)
 
 
 @register_error
+class RPCError(RuntimeError):
+    """Error thrown by the remote server handling the RPC call."""
+
+
+@register_error
 class OpError(TVMError):
     """Base class of all operator errors in frontends."""
 
index 5f959eb..b64ba33 100644 (file)
@@ -26,4 +26,6 @@ upload and run remote RPC server, get the result back to verify correctness.
 """
 
 from .server import Server
-from .client import RPCSession, LocalSession, TrackerSession, connect, connect_tracker
+from .client import connect, connect_tracker
+from .client import RPCSession, LocalSession, PopenSession, TrackerSession
+from .minrpc import with_minrpc
diff --git a/python/tvm/rpc/_ffi_api.py b/python/tvm/rpc/_ffi_api.py
new file mode 100644 (file)
index 0000000..1a7cc73
--- /dev/null
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.rpc"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("rpc", __name__)
index bc81534..f0e33f8 100644 (file)
@@ -17,8 +17,6 @@
 """Base definitions for RPC."""
 # pylint: disable=invalid-name
 
-from __future__ import absolute_import
-
 import socket
 import time
 import json
@@ -26,7 +24,6 @@ import errno
 import struct
 import random
 import logging
-import tvm._ffi
 
 from .._ffi.base import py_str
 
@@ -176,7 +173,3 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
             logger.warning("Cannot connect to tracker %s, retry in %g secs...",
                            str(addr), retry_period)
             time.sleep(retry_period)
-
-
-# Still use tvm.rpc for the foreign functions
-tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base")
index ed57e0d..9997673 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """RPC client tools"""
-from __future__ import absolute_import
-
 import os
+import stat
 import socket
 import struct
 import time
+
 import tvm._ffi
 from tvm.contrib import util
 from tvm._ffi.base import TVMError
 from tvm.runtime import ndarray as nd
-from tvm.runtime import load_module as _load_module
 
 from . import base
+from . import server
+from . import _ffi_api
 
 
 class RPCSession(object):
@@ -38,9 +39,23 @@ class RPCSession(object):
     # pylint: disable=invalid-name
     def __init__(self, sess):
         self._sess = sess
-        self._tbl_index = base._SessTableIndex(sess)
+        self._tbl_index = _ffi_api.SessTableIndex(sess)
         self._remote_funcs = {}
 
+    def system_lib(self):
+        """Get system-wide library module.
+
+        Returns
+        -------
+        module : runtime.Module
+            The system-wide library module.
+
+        See Also
+        --------
+        tvm.runtime.system_lib
+        """
+        return self.get_function("runtime.SystemLib")()
+
     def get_function(self, name):
         """Get function from the session.
 
@@ -145,7 +160,7 @@ class RPCSession(object):
         m : Module
             The remote module containing remote function.
         """
-        return base._LoadRemoteModule(self._sess, path)
+        return _ffi_api.LoadRemoteModule(self._sess, path)
 
     def cpu(self, dev_id=0):
         """Construct CPU device."""
@@ -183,28 +198,41 @@ class LocalSession(RPCSession):
     need to be ran both locally and remotely.
     """
     def __init__(self):
-        # pylint: disable=super-init-not-called
-        self.context = nd.context
-        self.get_function = tvm._ffi.get_global_func
-        self._temp = util.tempdir()
+        self._temp = server._server_env([])
+        RPCSession.__init__(self, _ffi_api.LocalSession())
 
-    def upload(self, data, target=None):
-        if isinstance(data, bytearray):
-            if not target:
-                raise ValueError("target must present when file is a bytearray")
-            blob = data
-        else:
-            blob = bytearray(open(data, "rb").read())
-            if not target:
-                target = os.path.basename(data)
-        with open(self._temp.relpath(target), "wb") as f:
-            f.write(blob)
 
-    def download(self, path):
-        return bytearray(open(self._temp.relpath(path), "rb").read())
+@tvm._ffi.register_func("rpc.PopenSession")
+def _popen_session(binary):
+    temp = util.tempdir()
 
-    def load_module(self, path):
-        return _load_module(self._temp.relpath(path))
+    if isinstance(binary, (bytes, bytearray)):
+        path_exec = temp.relpath("server.minrpc")
+        with open(path_exec, "wb") as outfile:
+            outfile.write(binary)
+        os.chmod(path_exec, stat.S_IXUSR | stat.S_IRUSR)
+        path_exec = os.path.abspath(path_exec)
+    else:
+        path_exec = os.path.abspath(binary)
+        if not os.path.isfile(path_exec):
+            raise RuntimeError(f"{path_exec} does not exist.")
+        if not os.access(path_exec, os.X_OK):
+            raise RuntimeError(f"{path_exec} is not executable.")
+
+    sess = _ffi_api.CreatePipeClient(path_exec)
+    return sess
+
+
+class PopenSession(RPCSession):
+    """RPCSession interface backed by popen.
+
+    Parameters
+    ----------
+    binary : List[Union[str, bytes]]
+        The binary to be executed.
+    """
+    def __init__(self, binary):
+        RPCSession.__init__(self, _popen_session(binary))
 
 
 class TrackerSession(object):
@@ -378,7 +406,7 @@ class TrackerSession(object):
                 key, max_retry, str(last_err)))
 
 
-def connect(url, port, key="", session_timeout=0):
+def connect(url, port, key="", session_timeout=0, session_constructor_args=None):
     """Connect to RPC Server
 
     Parameters
@@ -397,15 +425,43 @@ def connect(url, port, key="", session_timeout=0):
         the connection when duration is longer than this value.
         When duration is zero, it means the request must always be kept alive.
 
+    session_constructor_args: List
+        List of additional arguments to passed as the remote session constructor.
+        The first element of the list is always a string specifying the name of
+        the session constructor, the following args are the positional args to that function.
+
     Returns
     -------
     sess : RPCSession
         The connected session.
+
+    Examples
+    --------
+    Normal usage
+    .. code-block:: python
+
+        client = rpc.connect(server_url, server_port, server_key)
+
+    Session_constructor can be used to customize the session in the remote
+    The following code connects to a remote internal server via a proxy
+    by constructing another RPCClientSession on the proxy machine and use that
+    as the serving session of the proxy endpoint.
+
+    .. code-block:: python
+
+        client_via_proxy = rpc.connect(
+            proxy_server_url, proxy_server_port, proxy_server_key,
+            session_constructor_args=[
+                "rpc.Connect", internal_url, internal_port, internal_key])
+
     """
     try:
         if session_timeout:
             key += " -timeout=%s" % str(session_timeout)
-        sess = base._Connect(url, port, key)
+        session_constructor_args = session_constructor_args if session_constructor_args else []
+        if not isinstance(session_constructor_args, (list, tuple)):
+            raise TypeError("Expect the session constructor to be a list or tuple")
+        sess = _ffi_api.Connect(url, port, key, *session_constructor_args)
     except NameError:
         raise RuntimeError("Please compile with USE_RPC=1")
     return RPCSession(sess)
diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py
new file mode 100644 (file)
index 0000000..760c536
--- /dev/null
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utils to path."""
+import os
+from tvm._ffi import libinfo
+from tvm.contrib import cc
+
+
+def find_minrpc_server_libpath(server="posix_popen_server"):
+    """Get the path of minrpc server libary.
+
+    Parameters
+    ----------
+    server : str
+        The kind of built in minrpc server.
+
+    Returns
+    -------
+    path : str
+        The path to the min server library.
+    """
+    curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+    source_dir = os.path.abspath(os.path.join(curr_dir, "..", "..", ".."))
+
+    path = os.path.join(
+        source_dir, "src", "runtime", "rpc", "minrpc", ("%s.cc" % server))
+
+    candidates = [path]
+    if not os.path.isfile(path):
+        raise RuntimeError("Cannot find minserver %s, in candidates %s" % (server, candidates))
+    return path
+
+
+def with_minrpc(compile_func,
+                server="posix_popen_server",
+                runtime="libtvm"):
+    """Attach the compiler function with minrpc related options.
+
+    Parameters
+    ----------
+    compile_func : Union[str, Callable[[str, str, Optional[str]], None]]
+        The compilation function to decorate.
+
+    server : str
+        The server type.
+
+    runtime : str
+        The runtime library.
+
+    Returns
+    -------
+    fcompile : function
+        The return compilation.
+    """
+    server_path = find_minrpc_server_libpath(server)
+    runtime_path = libinfo.find_lib_path(
+        [runtime, runtime + ".so", runtime + ".dylib"])[0]
+
+    runtime_dir = os.path.abspath(os.path.dirname(runtime_path))
+    options = ["-std=c++14"]
+    # Make sure the rpath to the libtvm is set so we can do local tests.
+    # Note that however, this approach won't work on remote.
+    # Always recommend to to link statically.
+    options += ["-Wl,-rpath=" + runtime_dir]
+    options += ["-I" + path for path in libinfo.find_include_path()]
+    fcompile = cc.cross_compiler(
+        compile_func,
+        options=options,
+        add_files=[server_path, runtime_path])
+    fcompile.__name__ = "with_minrpc"
+    fcompile.need_system_lib = True
+    return fcompile
index c3a3647..03746da 100644 (file)
@@ -42,6 +42,7 @@ except ImportError as error_msg:
     raise ImportError(
         "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg)
 
+from . import _ffi_api
 from . import base
 from .base import TrackerCode
 from .server import _server_env
@@ -549,7 +550,7 @@ def websocket_proxy_server(url, key=""):
             data = bytes(data)
             conn.write_message(data, binary=True)
             return len(data)
-        on_message = base._CreateEventDrivenServer(
+        on_message = _ffi_api.CreateEventDrivenServer(
             _fsend, "WebSocketProxyServer", "%toinit")
         return on_message
 
index 03749c1..15a3c7d 100644 (file)
@@ -43,6 +43,7 @@ from tvm._ffi.base import py_str
 from tvm._ffi.libinfo import find_lib_path
 from tvm.runtime.module import load_module as _load_module
 from tvm.contrib import util
+from . import _ffi_api
 from . import base
 from . base import TrackerCode
 
@@ -56,7 +57,7 @@ def _server_env(load_library, work_path=None):
         temp = util.tempdir()
 
     # pylint: disable=unused-variable
-    @tvm._ffi.register_func("tvm.rpc.server.workpath")
+    @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True)
     def get_workpath(path):
         return temp.relpath(path)
 
@@ -81,7 +82,7 @@ def _serve_loop(sock, addr, load_library, work_path=None):
     """Server loop"""
     sockfd = sock.fileno()
     temp = _server_env(load_library, work_path)
-    base._ServerLoop(sockfd)
+    _ffi_api.ServerLoop(sockfd)
     if not work_path:
         temp.remove()
     logger.info("Finish serving %s", addr)
@@ -330,7 +331,7 @@ class Server(object):
                  utvm_dev_config_args=None,
                  ):
         try:
-            if base._ServerLoop is None:
+            if _ffi_api.ServerLoop is None:
                 raise RuntimeError("Please compile with USE_RPC=1")
         except NameError:
             raise RuntimeError("Please compile with USE_RPC=1")
index 716f87f..b580e3f 100644 (file)
@@ -244,6 +244,7 @@ class Module(object):
     def export_library(self,
                        file_name,
                        fcompile=None,
+                       addons=None,
                        **kwargs):
         """Export the module and its imported device code one library.
 
@@ -283,7 +284,7 @@ class Module(object):
 
         modules = self._collect_dso_modules()
         temp = _util.tempdir()
-        files = []
+        files = addons if addons else []
         is_system_lib = False
         has_c_module = False
         llvm_target_triple = None
@@ -313,6 +314,9 @@ class Module(object):
         if llvm_target_triple is None and hasattr(fcompile, "get_target_triple"):
             llvm_target_triple = fcompile.get_target_triple()
 
+        if getattr(fcompile, "need_system_lib", False) and not is_system_lib:
+            raise ValueError("%s need --system-lib option" % str(fcompile))
+
         if self.imported_modules:
             if enabled("llvm") and llvm_target_triple:
                 path_obj = temp.relpath("devc.o")
index fb1f74d..32b3381 100644 (file)
@@ -460,6 +460,7 @@ int TVMFuncCall(TVMFunctionHandle func,
                 TVMValue* ret_val,
                 int* ret_type_code) {
   API_BEGIN();
+
   TVMRetValue rv;
   (*static_cast<const PackedFunc*>(func)).CallPacked(
       TVMArgs(args, arg_type_codes, num_args), &rv);
@@ -585,6 +586,42 @@ int TVMCbArgToReturn(TVMValue* value, int* code) {
   API_END();
 }
 
+
+int TVMDeviceAllocDataSpace(DLContext ctx,
+                            size_t nbytes,
+                            size_t alignment,
+                            DLDataType type_hint,
+                            void** out_data) {
+  API_BEGIN();
+  out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace(
+      ctx, nbytes, alignment, type_hint);
+  API_END();
+}
+
+int TVMDeviceFreeDataSpace(DLContext ctx, void* ptr) {
+  API_BEGIN();
+  DeviceAPIManager::Get(ctx)->FreeDataSpace(ctx, ptr);
+  API_END();
+}
+
+int TVMDeviceCopyDataFromTo(const void* from,
+                            size_t from_offset,
+                            void* to,
+                            size_t to_offset,
+                            size_t num_bytes,
+                            TVMContext ctx_from,
+                            TVMContext ctx_to,
+                            DLDataType type_hint,
+                            TVMStreamHandle stream) {
+  API_BEGIN();
+  TVMContext ctx = ctx_from.device_type != kDLCPU ? ctx_from : ctx_to;
+  DeviceAPIManager::Get(ctx)->CopyDataFromTo(
+      from, from_offset,
+      to, to_offset,
+      num_bytes, ctx_from, ctx_to, type_hint, stream);
+  API_END();
+}
+
 // set device api
 TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
 .set_body([](TVMArgs args, TVMRetValue *ret) {
index d2ed7ff..813a79d 100644 (file)
@@ -36,7 +36,7 @@ void ModuleNode::Import(Module other) {
   if (!std::strcmp(this->type_key(), "rpc")) {
     static const PackedFunc* fimport_ = nullptr;
     if (fimport_ == nullptr) {
-      fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
+      fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule");
       CHECK(fimport_ != nullptr);
     }
     (*fimport_)(GetRef<Module>(this), other);
index 4717d89..855a342 100644 (file)
@@ -37,7 +37,7 @@ struct Registry::Manager {
   // map storing the functions.
   // We delibrately used raw pointer
   // This is because PackedFunc can contain callbacks into the host languge(python)
-  // and the resource can become invalid because of indeterminstic order of destruction.
+  // and the resource can become invalid because of indeterminstic order of destruction and forking.
   // The resources will only be recycled during program exit.
   std::unordered_map<std::string, Registry*> fmap;
   // mutex
@@ -60,20 +60,18 @@ Registry& Registry::set_body(PackedFunc f) {  // NOLINT(*)
   return *this;
 }
 
-Registry& Registry::Register(const std::string& name, bool override) {  // NOLINT(*)
+Registry& Registry::Register(const std::string& name, bool can_override) {  // NOLINT(*)
   Manager* m = Manager::Global();
   std::lock_guard<std::mutex> lock(m->mutex);
-  auto it = m->fmap.find(name);
-  if (it == m->fmap.end()) {
-    Registry* r = new Registry();
-    r->name_ = name;
-    m->fmap[name] = r;
-    return *r;
-  } else {
-    CHECK(override)
-      << "Global PackedFunc " << name << " is already registered";
-    return *it->second;
+  if (m->fmap.count(name)) {
+    CHECK(can_override)
+        << "Global PackedFunc " << name << " is already registered";
   }
+
+  Registry* r = new Registry();
+  r->name_ = name;
+  m->fmap[name] = r;
+  return *r;
 }
 
 bool Registry::Remove(const std::string& name) {
diff --git a/src/runtime/rpc/minrpc/minrpc_server.h b/src/runtime/rpc/minrpc/minrpc_server.h
new file mode 100644 (file)
index 0000000..63ad359
--- /dev/null
@@ -0,0 +1,598 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file minrpc_server.h
+ * \brief Minimum RPC server implementation,
+ *        redirects all the calls to C runtime API.
+ *
+ * \note This file do not depend on c++ std or c std,
+ *       and only depends on TVM's C runtime API.
+ */
+#ifndef TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_
+#define TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_
+
+#include <dmlc/endian.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include "../rpc_protocol.h"
+#include "../../../support/arena.h"
+
+/*! \brief Whether or not to enable glog style DLOG */
+#ifndef TVM_MINRPC_ENABLE_LOGGING
+#define TVM_MINRPC_ENABLE_LOGGING 0
+#endif
+
+#ifndef MINRPC_CHECK
+#define MINRPC_CHECK(cond)                                      \
+  if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError);
+#endif
+
+#if TVM_MINRPC_ENABLE_LOGGING
+#include <dmlc/logging.h>
+#endif
+
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief A minimum RPC server that only depends on the tvm C runtime..
+ *
+ *  All the dependencies are provided by the io arguments.
+ *
+ * \tparam TIOHandler IO provider to provide io handling.
+ *         An IOHandler needs to provide the following functions:
+ *         - PosixWrite, PosixRead, Close: posix style, read, write, close API.
+ *         - Exit: exit with status code.
+ */
+template<typename TIOHandler>
+class MinRPCServer {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param io The IO handler.
+   */
+  explicit MinRPCServer(TIOHandler io)
+      : io_(io), arena_(PageAllocator(io)) {}
+
+  /*! \brief Run the server loop until shutdown signal is received. */
+  void ServerLoop() {
+    RPCCode code;
+    uint64_t packet_len;
+
+    while (true) {
+      arena_.RecycleAll();
+      allow_clean_shutdown_ = true;
+
+      this->Read(&packet_len);
+      if (packet_len == 0) continue;
+      this->Read(&code);
+
+      allow_clean_shutdown_ = false;
+
+      if (code >= RPCCode::kSyscallCodeStart) {
+        this->HandleSyscallFunc(code);
+      } else {
+        switch (code) {
+          case RPCCode::kCallFunc: {
+            HandleNormalCallFunc();
+            break;
+          }
+          case RPCCode::kInitServer: {
+            HandleInitServer();
+            break;
+          }
+          case RPCCode::kCopyFromRemote: {
+            HandleCopyFromRemote();
+            break;
+          }
+          case RPCCode::kCopyToRemote: {
+            HandleCopyToRemote();
+            break;
+          }
+          case RPCCode::kShutdown: {
+            this->Shutdown();
+            return;
+          }
+          default: {
+            this->ThrowError(RPCServerStatus::kUnknownRPCCode);
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  void Shutdown() {
+    arena_.FreeAll();
+    io_.Close();
+  }
+
+  void HandleNormalCallFunc() {
+    uint64_t call_handle;
+    TVMValue* values;
+    int* tcodes;
+    int num_args;
+    TVMValue ret_value[3];
+    int ret_tcode[3];
+
+    this->Read(&call_handle);
+    RecvPackedSeq(&values, &tcodes, &num_args);
+
+    int call_ecode = TVMFuncCall(
+        reinterpret_cast<void*>(call_handle),
+        values, tcodes, num_args,
+        &(ret_value[1]), &(ret_tcode[1]));
+
+    if (call_ecode == 0) {
+      // Return value encoding as in LocalSession
+      int rv_tcode = ret_tcode[1];
+      ret_tcode[0] = kDLInt;
+      ret_value[0].v_int64 = rv_tcode;
+      if (rv_tcode == kTVMNDArrayHandle) {
+        ret_tcode[1] = kTVMDLTensorHandle;
+        ret_value[2].v_handle = ret_value[1].v_handle;
+        ret_tcode[2] = kTVMOpaqueHandle;
+        this->ReturnPackedSeq(ret_value, ret_tcode, 3);
+      } else if (rv_tcode == kTVMPackedFuncHandle ||
+                 rv_tcode == kTVMModuleHandle) {
+        ret_tcode[1] = kTVMOpaqueHandle;
+        this->ReturnPackedSeq(ret_value, ret_tcode, 2);
+      } else {
+        this->ReturnPackedSeq(ret_value, ret_tcode, 2);
+      }
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void HandleCopyFromRemote() {
+    uint64_t handle, offset, num_bytes;
+    TVMContext ctx;
+    DLDataType type_hint;
+
+    this->Read(&handle);
+    this->Read(&offset);
+    this->Read(&num_bytes);
+    this->Read(&ctx);
+    this->Read(&type_hint);
+
+    uint8_t* data_ptr;
+    int call_ecode = 0;
+    if (ctx.device_type == kDLCPU) {
+      data_ptr = reinterpret_cast<uint8_t*>(handle) + offset;
+    } else {
+      data_ptr = this->ArenaAlloc<uint8_t>(num_bytes);
+      call_ecode = TVMDeviceCopyDataFromTo(
+              reinterpret_cast<void*>(handle), offset,
+              data_ptr, 0, num_bytes,
+              ctx, DLContext{kDLCPU, 0},
+              type_hint, nullptr);
+    }
+
+    if (call_ecode == 0) {
+      RPCCode code = RPCCode::kCopyAck;
+      uint64_t packet_nbytes = sizeof(code) + num_bytes;
+
+      this->Write(packet_nbytes);
+      this->Write(code);
+      this->WriteArray(data_ptr, num_bytes);
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void HandleCopyToRemote() {
+    uint64_t handle, offset, num_bytes;
+    TVMContext ctx;
+    DLDataType type_hint;
+
+    this->Read(&handle);
+    this->Read(&offset);
+    this->Read(&num_bytes);
+    this->Read(&ctx);
+    this->Read(&type_hint);
+    int call_ecode = 0;
+
+    if (ctx.device_type == kDLCPU) {
+      uint8_t* dptr = reinterpret_cast<uint8_t*>(handle) + offset;
+      this->ReadArray(dptr, num_bytes);
+    } else {
+      uint8_t* temp_data = this->ArenaAlloc<uint8_t>(num_bytes);
+      this->ReadArray(temp_data, num_bytes);
+
+      call_ecode = TVMDeviceCopyDataFromTo(
+              temp_data, 0,
+              reinterpret_cast<void*>(handle), offset,
+              num_bytes,
+              DLContext{kDLCPU, 0}, ctx,
+              type_hint, nullptr);
+    }
+
+    if (call_ecode == 0) {
+      this->ReturnVoid();
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void HandleSyscallFunc(RPCCode code) {
+    TVMValue* values;
+    int* tcodes;
+    int num_args;
+    RecvPackedSeq(&values, &tcodes, &num_args);
+    switch (code) {
+      case RPCCode::kFreeHandle: {
+        this->SyscallFreeHandle(values, tcodes, num_args);
+        break;
+      }
+      case RPCCode::kGetGlobalFunc: {
+        this->SyscallGetGlobalFunc(values, tcodes, num_args);
+        break;
+      }
+      case RPCCode::kDevSetDevice: {
+        this->ReturnException("SetDevice not supported");
+        break;
+      }
+      case RPCCode::kDevGetAttr: {
+        this->ReturnException("GetAttr not supported");
+        break;
+      }
+      case RPCCode::kDevAllocData: {
+        this->SyscallDevAllocData(values, tcodes, num_args);
+        break;
+      }
+      case RPCCode::kDevFreeData: {
+        this->SyscallDevFreeData(values, tcodes, num_args);
+        break;
+      }
+      case RPCCode::kDevStreamSync: {
+        this->SyscallDevStreamSync(values, tcodes, num_args);
+        break;
+      }
+      case RPCCode::kCopyAmongRemote: {
+        this->SyscallCopyAmongRemote(values, tcodes, num_args);
+        break;
+      }
+      default: {
+        this->ReturnException("Syscall not recognized");
+        break;
+      }
+    }
+  }
+
+  void HandleInitServer() {
+    uint64_t len;
+    this->Read(&len);
+    char* proto_ver = this->ArenaAlloc<char>(len + 1);
+    this->ReadArray(proto_ver, len);
+
+    TVMValue* values;
+    int* tcodes;
+    int num_args;
+    RecvPackedSeq(&values, &tcodes, &num_args);
+    MINRPC_CHECK(num_args == 0);
+    this->ReturnVoid();
+  }
+
+  void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 2);
+    MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle);
+    MINRPC_CHECK(tcodes[1] == kDLInt);
+
+    void* handle = values[0].v_handle;
+    int64_t type_code = values[1].v_int64;
+    int call_ecode;
+
+    if (type_code == kTVMNDArrayHandle) {
+      call_ecode = TVMArrayFree(static_cast<TVMArrayHandle>(handle));
+    } else if (type_code == kTVMPackedFuncHandle) {
+      call_ecode = TVMFuncFree(handle);
+    } else {
+      MINRPC_CHECK(type_code == kTVMModuleHandle);
+      call_ecode = TVMModFree(handle);
+    }
+
+    if (call_ecode == 0) {
+      this->ReturnVoid();
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 1);
+    MINRPC_CHECK(tcodes[0] == kTVMStr);
+
+    void* handle;
+    int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle);
+
+    if (call_ecode == 0) {
+      this->ReturnHandle(handle);
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void SyscallCopyAmongRemote(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 9);
+    // from, from_offset
+    MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle);
+    MINRPC_CHECK(tcodes[1] == kDLInt);
+    // to, to_offset
+    MINRPC_CHECK(tcodes[2] == kTVMOpaqueHandle);
+    MINRPC_CHECK(tcodes[3] == kDLInt);
+    // size
+    MINRPC_CHECK(tcodes[4] == kDLInt);
+    // ctx_from, ctx_to
+    MINRPC_CHECK(tcodes[5] == kTVMContext);
+    MINRPC_CHECK(tcodes[6] == kTVMContext);
+    // type_hint, stream
+    MINRPC_CHECK(tcodes[7] == kTVMDataType);
+    MINRPC_CHECK(tcodes[8] == kTVMOpaqueHandle);
+
+    void* from = values[0].v_handle;
+    int64_t from_offset = values[1].v_int64;
+    void* to = values[2].v_handle;
+    int64_t to_offset = values[3].v_int64;
+    int64_t size = values[4].v_int64;
+    TVMContext ctx_from = values[5].v_ctx;
+    TVMContext ctx_to = values[6].v_ctx;
+    DLDataType type_hint = values[7].v_type;
+    TVMStreamHandle stream = values[8].v_handle;
+
+    int call_ecode = TVMDeviceCopyDataFromTo(
+        from, from_offset,
+        to, to_offset, size,
+        ctx_from, ctx_to, type_hint, stream);
+
+    if (call_ecode == 0) {
+      this->ReturnVoid();
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void SyscallDevAllocData(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 4);
+    MINRPC_CHECK(tcodes[0] == kTVMContext);
+    MINRPC_CHECK(tcodes[1] == kDLInt);
+    MINRPC_CHECK(tcodes[2] == kDLInt);
+    MINRPC_CHECK(tcodes[3] == kTVMDataType);
+
+    TVMContext ctx = values[0].v_ctx;
+    int64_t nbytes = values[1].v_int64;
+    int64_t alignment = values[2].v_int64;
+    DLDataType type_hint = values[3].v_type;
+
+    void* handle;
+    int call_ecode = TVMDeviceAllocDataSpace(
+        ctx, nbytes, alignment, type_hint, &handle);
+
+    if (call_ecode == 0) {
+      this->ReturnHandle(handle);
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void SyscallDevFreeData(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 2);
+    MINRPC_CHECK(tcodes[0] == kTVMContext);
+    MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
+
+    TVMContext ctx = values[0].v_ctx;
+    void* handle = values[1].v_handle;
+
+    int call_ecode = TVMDeviceFreeDataSpace(ctx, handle);
+
+    if (call_ecode == 0) {
+      this->ReturnVoid();
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 2);
+    MINRPC_CHECK(tcodes[0] == kTVMContext);
+    MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
+
+    TVMContext ctx = values[0].v_ctx;
+    void* handle = values[1].v_handle;
+
+    int call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, handle);
+
+    if (call_ecode == 0) {
+      this->ReturnVoid();
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
+    io_.Exit(static_cast<int>(code));
+  }
+
+  template<typename T>
+  T* ArenaAlloc(int count) {
+    static_assert(std::is_pod<T>::value, "need to be trival");
+    return arena_.template allocate_<T>(count);
+  }
+
+  template<typename T>
+  void Read(T* data) {
+    static_assert(std::is_pod<T>::value, "need to be trival");
+    this->ReadRawBytes(data, sizeof(T));
+  }
+
+  template<typename T>
+  void ReadArray(T* data, size_t count) {
+    static_assert(std::is_pod<T>::value, "need to be trival");
+    return this->ReadRawBytes(data, sizeof(T) * count);
+  }
+
+  template<typename T>
+  void Write(const T& data) {
+    static_assert(std::is_pod<T>::value, "need to be trival");
+    return this->WriteRawBytes(&data, sizeof(T));
+  }
+
+  template<typename T>
+  void WriteArray(T* data, size_t count) {
+    static_assert(std::is_pod<T>::value, "need to be trival");
+    return this->WriteRawBytes(data, sizeof(T) * count);
+  }
+
+ private:
+  // Internal allocator that redirects alloc to TVM's C API.
+  class PageAllocator {
+   public:
+    using ArenaPageHeader = tvm::support::ArenaPageHeader;
+
+    explicit PageAllocator(TIOHandler io)
+        : io_(io) {}
+
+    ArenaPageHeader* allocate(size_t min_size) {
+      size_t npages = ((min_size + kPageSize - 1) / kPageSize);
+      void* data;
+
+      if (TVMDeviceAllocDataSpace(
+              DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign,
+              DLDataType{kDLInt, 1, 1}, &data) != 0) {
+        io_.Exit(static_cast<int>(RPCServerStatus::kAllocError));
+      }
+
+      ArenaPageHeader* header = static_cast<ArenaPageHeader*>(data);
+      header->size = npages * kPageSize;
+      header->offset = sizeof(ArenaPageHeader);
+      return header;
+    }
+
+    void deallocate(ArenaPageHeader* page) {
+      if (TVMDeviceFreeDataSpace(DLContext{kDLCPU, 0}, page) != 0) {
+        io_.Exit(static_cast<int>(RPCServerStatus::kAllocError));
+      }
+    }
+
+    static const constexpr int kPageSize = 2 << 10;
+    static const constexpr int kPageAlign = 8;
+
+   private:
+    TIOHandler io_;
+  };
+
+  void RecvPackedSeq(TVMValue** out_values,
+                     int** out_tcodes,
+                     int* out_num_args) {
+    RPCReference::RecvPackedSeq(
+        out_values, out_tcodes, out_num_args, this);
+  }
+
+  void ReturnVoid() {
+    int32_t num_args = 1;
+    int32_t tcode = kTVMNullptr;
+    RPCCode code = RPCCode::kReturn;
+
+    uint64_t packet_nbytes =
+        sizeof(code) + sizeof(num_args) + sizeof(tcode);
+
+    this->Write(packet_nbytes);
+    this->Write(code);
+    this->Write(num_args);
+    this->Write(tcode);
+  }
+
+  void ReturnHandle(void* handle) {
+    int32_t num_args = 1;
+    int32_t tcode = kTVMOpaqueHandle;
+    RPCCode code = RPCCode::kReturn;
+    uint64_t encode_handle = reinterpret_cast<uint64_t>(handle);
+
+    uint64_t packet_nbytes =
+        sizeof(code) + sizeof(num_args) +
+        sizeof(tcode) + sizeof(encode_handle);
+
+    this->Write(packet_nbytes);
+    this->Write(code);
+    this->Write(num_args);
+    this->Write(tcode);
+    this->Write(encode_handle);
+  }
+
+  void ReturnException(const char* msg) {
+    RPCReference::ReturnException(msg, this);
+  }
+
+  void ReturnPackedSeq(const TVMValue* arg_values,
+                       const int* type_codes,
+                       int num_args) {
+    RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this);
+  }
+
+  void ReturnLastTVMError() {
+    this->ReturnException(TVMGetLastError());
+  }
+
+  void ReadRawBytes(void* data, size_t size) {
+    uint8_t* buf = reinterpret_cast<uint8_t*>(data);
+    size_t ndone = 0;
+    while (ndone <  size) {
+      ssize_t ret = io_.PosixRead(buf, size - ndone);
+      if (ret == 0) {
+        if (allow_clean_shutdown_) {
+          this->Shutdown();
+          io_.Exit(0);
+        } else {
+          this->ThrowError(RPCServerStatus::kReadError);
+        }
+      }
+      if (ret == -1) {
+        this->ThrowError(RPCServerStatus::kReadError);
+      }
+      ndone += ret;
+      buf += ret;
+    }
+  }
+
+  void WriteRawBytes(const void* data, size_t size) {
+    const uint8_t *buf = reinterpret_cast<const uint8_t*>(data);
+    size_t ndone = 0;
+    while (ndone <  size) {
+      ssize_t ret = io_.PosixWrite(buf, size - ndone);
+      if (ret == 0 || ret == -1) {
+        this->ThrowError(RPCServerStatus::kWriteError);
+      }
+      buf += ret;
+      ndone += ret;
+    }
+  }
+
+  /*! \brief IO handler. */
+  TIOHandler io_;
+  /*! \brief internal arena. */
+  support::GenericArena<PageAllocator> arena_;
+  /*! \brief Whether we are in a state that allows clean shutdown. */
+  bool allow_clean_shutdown_{true};
+  static_assert(DMLC_LITTLE_ENDIAN, "MinRPC only works on little endian.");
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_
diff --git a/src/runtime/rpc/minrpc/posix_popen_server.cc b/src/runtime/rpc/minrpc/posix_popen_server.cc
new file mode 100644 (file)
index 0000000..fdc5711
--- /dev/null
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// Disable constructor to bring minimum dep on c++ABI.
+#define TVM_ARENA_HAS_DESTRUCTOR 0
+
+#include <unistd.h>
+#include <cstdlib>
+#include "minrpc_server.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief IOHandler based on posix API.
+ */
+class PosixIOHandler {
+ public:
+  explicit PosixIOHandler(int read_fd = 0, int write_fd = 1)
+      : read_fd_(read_fd), write_fd_(write_fd) {
+  }
+
+  ssize_t PosixRead(void* data, size_t size) {
+    return read(read_fd_, data, size);
+  }
+
+  ssize_t PosixWrite(const void* data, size_t size) {
+    return write(write_fd_, data, size);
+  }
+
+  void Exit(int code) {
+    exit(code);
+  }
+
+  void Close() {
+    if (read_fd_ != 0) close(read_fd_);
+    if (write_fd_ != 0) close(write_fd_);
+  }
+
+ private:
+  int read_fd_{0};
+  int write_fd_{1};
+};
+
+/*! \brief Type for the posix version of min rpc server. */
+using PosixMinRPCServer = MinRPCServer<PosixIOHandler>;
+
+}  // namespace runtime
+}  // namespace tvm
+
+int main(int argc, char* argv[]) {
+  if (argc != 3) return -1;
+  // pass the descriptor via arguments.
+  tvm::runtime::PosixIOHandler handler(atoi(argv[1]), atoi(argv[2]));
+  tvm::runtime::PosixMinRPCServer server(handler);
+  server.ServerLoop();
+  return 0;
+}
diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc
new file mode 100644 (file)
index 0000000..f8dc6e6
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_channel.cc
+ */
+#include <string>
+#include "rpc_channel.h"
+
+namespace tvm {
+namespace runtime {
+
+size_t CallbackChannel::Send(const void* data, size_t size) {
+  TVMByteArray bytes;
+  bytes.data = static_cast<const char*>(data);
+  bytes.size = size;
+  int64_t n = fsend_(bytes);
+  if (n == -1) {
+    LOG(FATAL) << "CallbackChannel::Send";
+  }
+  return static_cast<size_t>(n);
+}
+
+size_t CallbackChannel::Recv(void* data, size_t size) {
+  TVMRetValue ret = frecv_(size);
+
+  if (ret.type_code() != kTVMBytes) {
+    LOG(FATAL) << "CallbackChannel::Recv";
+  }
+  std::string* bytes = ret.ptr<std::string>();
+  memcpy(static_cast<char*>(data), bytes->c_str(), bytes->length());
+  return bytes->length();
+}
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/rpc/rpc_channel.h b/src/runtime/rpc/rpc_channel.h
new file mode 100644 (file)
index 0000000..be34a8b
--- /dev/null
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_channel.h
+ * \brief Communication endpoints to connect local and remote RPC sessions.
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_H_
+#define TVM_RUNTIME_RPC_RPC_CHANNEL_H_
+
+#include <tvm/runtime/packed_func.h>
+#include <utility>
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Abstract channel interface used to create RPCEndpoint.
+ */
+class RPCChannel {
+ public:
+  /*! \brief virtual destructor */
+  virtual ~RPCChannel() {}
+  /*!
+   * \brief Send data over to the channel.
+   * \param data The data pointer.
+   * \param size The size fo the data.
+   * \return The actual bytes sent.
+   */
+  virtual size_t Send(const void* data, size_t size) = 0;
+  /*!
+   * \brief Recv data from channel.
+   *
+   * \param data The data pointer.
+   * \param size The size fo the data.
+   * \return The actual bytes received.
+   */
+  virtual size_t Recv(void* data, size_t size) = 0;
+};
+
+/*!
+ * \brief RPC channel which callback
+ * frontend (Python/Java/etc.)'s send & recv function
+ */
+class CallbackChannel final : public RPCChannel {
+ public:
+  /*!
+   * \brief Constructor.
+   *
+   * \param fsend The send function, takes in a TVMByteArray and returns the
+   *              number of bytes sent in that array. Returns -1 if error happens.
+   * \param frecv The recv function, takes an expected maximum size, and return
+   *              a byte array with the actual amount of data received.
+   */
+  explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv)
+      : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {}
+
+  ~CallbackChannel() {}
+  /*!
+   * \brief Send data over to the channel.
+   * \param data The data pointer.
+   * \param size The size fo the data.
+   * \return The actual bytes sent.
+   */
+  size_t Send(const void* data, size_t size) final;
+  /*!
+   * \brief Recv data from channel.
+   *
+   * \param data The data pointer.
+   * \param size The size fo the data.
+   * \return The actual bytes received.
+   */
+  size_t Recv(void* data, size_t size) final;
+
+ private:
+  PackedFunc fsend_;
+  PackedFunc frecv_;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_RPC_RPC_CHANNEL_H_
index 9fd45ac..ade4d16 100644 (file)
@@ -23,6 +23,7 @@
 #include <dmlc/logging.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/device_api.h>
+#include <utility>
 #include "rpc_session.h"
 
 namespace tvm {
@@ -31,20 +32,24 @@ namespace runtime {
 class RPCDeviceAPI final : public DeviceAPI {
  public:
   void SetDevice(TVMContext ctx) final {
-    GetSess(ctx)->CallRemote(
-        RPCCode::kDevSetDevice, ctx);
+    auto remote_ctx = RemoveSessMask(ctx);
+    GetSess(ctx)->GetDeviceAPI(remote_ctx)->SetDevice(remote_ctx);
   }
+
   void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
-    *rv = GetSess(ctx)->CallRemote(
-        RPCCode::kDevGetAttr, ctx, static_cast<int>(kind));
+    auto remote_ctx = RemoveSessMask(ctx);
+    GetSess(ctx)->GetDeviceAPI(remote_ctx)->GetAttr(remote_ctx, kind, rv);
   }
+
   void* AllocDataSpace(TVMContext ctx,
                        size_t nbytes,
                        size_t alignment,
                        DLDataType type_hint) final {
     auto sess = GetSess(ctx);
-    void *data = sess->CallRemote(
-            RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint);
+    auto remote_ctx = RemoveSessMask(ctx);
+    void *data = sess->GetDeviceAPI(remote_ctx)->AllocDataSpace(
+        remote_ctx, nbytes, alignment, type_hint);
+
     RemoteSpace* space = new RemoteSpace();
     space->data = data;
     space->sess = std::move(sess);
@@ -52,9 +57,10 @@ class RPCDeviceAPI final : public DeviceAPI {
   }
   void FreeDataSpace(TVMContext ctx, void* ptr) final {
     RemoteSpace* space = static_cast<RemoteSpace*>(ptr);
+    auto remote_ctx = RemoveSessMask(ctx);
     try {
-      GetSess(ctx)->CallRemote(
-          RPCCode::kDevFreeData, ctx, space->data);
+      GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace(
+          remote_ctx, space->data);
     } catch (const dmlc::Error& e) {
       // fault tolerance to remote close.
     }
@@ -75,29 +81,35 @@ class RPCDeviceAPI final : public DeviceAPI {
         to_dev_type > kRPCSessMask) {
       CHECK(ctx_from.device_type == ctx_to.device_type)
           << "Cannot copy across two different remote session";
-      GetSess(ctx_from)->CallRemote(
-          RPCCode::kCopyAmongRemote,
-          static_cast<const RemoteSpace*>(from)->data, from_offset,
-          static_cast<const RemoteSpace*>(to)->data, to_offset,
-          size,  ctx_from, ctx_to, type_hint, stream);
+      auto remote_ctx_from = RemoveSessMask(ctx_from);
+      auto remote_ctx_to = RemoveSessMask(ctx_to);
+      auto remote_ctx = remote_ctx_from;
+      if (remote_ctx.device_type == kDLCPU) remote_ctx = remote_ctx_to;
+      GetSess(ctx_from)->GetDeviceAPI(remote_ctx)
+          ->CopyDataFromTo(static_cast<const RemoteSpace*>(from)->data, from_offset,
+                           static_cast<const RemoteSpace*>(to)->data, to_offset,
+                           size, remote_ctx_from, remote_ctx_to, type_hint, stream);
     } else if (from_dev_type > kRPCSessMask &&
                to_dev_type == kDLCPU) {
+      auto remote_ctx_from = RemoveSessMask(ctx_from);
       GetSess(ctx_from)->CopyFromRemote(
           static_cast<const RemoteSpace*>(from)->data, from_offset,
-          to, to_offset, size, ctx_from, type_hint);
+          to, to_offset, size, remote_ctx_from, type_hint);
     } else if (from_dev_type == kDLCPU &&
                to_dev_type > kRPCSessMask) {
+      auto remote_ctx_to = RemoveSessMask(ctx_to);
       GetSess(ctx_to)->CopyToRemote(
-          (void*)from, from_offset,  // NOLINT(*)
+          const_cast<void*>(from), from_offset,
           static_cast<const RemoteSpace*>(to)->data, to_offset,
-          size, ctx_to, type_hint);
+          size, remote_ctx_to, type_hint);
     } else {
       LOG(FATAL) << "expect copy from/to remote or between remote";
     }
   }
+
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
-    GetSess(ctx)->CallRemote(
-        RPCCode::kDevStreamSync, ctx, stream);
+    auto remote_ctx = RemoveSessMask(ctx);
+    GetSess(ctx)->GetDeviceAPI(remote_ctx)->StreamSync(ctx, stream);
   }
 
  private:
@@ -107,6 +119,11 @@ class RPCDeviceAPI final : public DeviceAPI {
     int tbl_index = dev_type / kRPCSessMask -  1;
     return RPCSession::Get(tbl_index);
   }
+
+  static TVMContext RemoveSessMask(TVMContext ctx) {
+    ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
+    return ctx;
+  }
 };
 
 TVM_REGISTER_GLOBAL("device_api.rpc")
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
new file mode 100644 (file)
index 0000000..916ecae
--- /dev/null
@@ -0,0 +1,1059 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_session.cc
+ * \brief RPC session for remote function call.
+ */
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/serializer.h>
+#include <memory>
+#include <array>
+#include <string>
+#include <chrono>
+#include <vector>
+#include <utility>
+#include <cmath>
+#include <algorithm>
+
+#include "rpc_endpoint.h"
+#include "rpc_local_session.h"
+#include "../object_internal.h"
+#include "../../support/ring_buffer.h"
+#include "../../support/arena.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * Event-driven state-machine based handlers for RPCEndpoint.
+ *
+ * Key functions:
+ *
+ * - SendPackedSeq: send the arguments over to the peer
+ * - HandleNextEvent: handle the next request from the peer(RPCCode followed by per code protocol).
+ */
+class RPCEndpoint::EventHandler : public dmlc::Stream {
+ public:
+  EventHandler(support::RingBuffer* reader,
+               support::RingBuffer* writer,
+               std::string name,
+               std::string* remote_key)
+      : reader_(reader),
+        writer_(writer),
+        name_(name),
+        remote_key_(remote_key) {
+    this->Clear();
+
+    if (*remote_key == "%toinit") {
+      state_ = kInitHeader;
+      remote_key_->resize(0);
+      pending_request_bytes_ = sizeof(int32_t);
+    }
+  }
+
+  /*!
+   * \brief Bytes needed to fulfill current request
+   */
+  size_t BytesNeeded() const {
+    if (reader_->bytes_available() < pending_request_bytes_) {
+      return pending_request_bytes_ - reader_->bytes_available();
+    } else {
+      return 0;
+    }
+  }
+
+  /*!
+   * \brief Request number of bytes from the reader.
+   * \param nbytes The number of bytes
+   */
+  void RequestBytes(size_t nbytes) {
+    pending_request_bytes_ += nbytes;
+    reader_->Reserve(pending_request_bytes_);
+  }
+
+  /*! \return Whether we are ready to handle next request. */
+  bool Ready() const {
+    return reader_->bytes_available() >= pending_request_bytes_;
+  }
+
+  /*! \return Whether we can perform a clean shutdown */
+  bool CanCleanShutdown() const {
+    return state_ == kRecvPacketNumBytes;
+  }
+
+  /*! \brief Finish the copy ack stage. */
+  void FinishCopyAck() {
+    this->SwitchToState(kRecvPacketNumBytes);
+  }
+
+  /*!
+   * \brief Enter the io loop until the next event.
+   * \param client_mode Whether we are in the client.
+   * \param setreturn The function to set the return value encoding.
+   * \return The function to set return values when there is a return event.
+   */
+  RPCCode HandleNextEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) {
+    std::swap(client_mode_, client_mode);
+
+    while (this->Ready()) {
+      switch (state_) {
+        case kInitHeader: HandleInitHeader(); break;
+        case kRecvPacketNumBytes: {
+          uint64_t packet_nbytes;
+          CHECK(this->Read(&packet_nbytes));
+          if (packet_nbytes != 0) {
+            this->SwitchToState(kProcessPacket);
+            this->RequestBytes(packet_nbytes);
+          } else {
+            this->SwitchToState(kRecvPacketNumBytes);
+          }
+          break;
+        }
+        case kProcessPacket: {
+          this->HandleProcessPacket(setreturn);
+          break;
+        }
+        case kReturnReceived: {
+          this->SwitchToState(kRecvPacketNumBytes);
+          std::swap(client_mode_, client_mode);
+          return RPCCode::kReturn;
+        }
+        case kCopyAckReceived: {
+          std::swap(client_mode_, client_mode);
+          return RPCCode::kCopyAck;
+        }
+        case kShutdownReceived: {
+          std::swap(client_mode_, client_mode);
+          return RPCCode::kShutdown;
+        }
+      }
+    }
+    std::swap(client_mode_, client_mode);
+    return RPCCode::kNone;
+  }
+
+  /*! \brief Clear all the states in the Handler.*/
+  void Clear() {
+    state_ = kRecvPacketNumBytes;
+    pending_request_bytes_ = sizeof(uint64_t);
+  }
+
+  /*!
+   * \brief Validate that the arguments can be sent through RPC.
+   * \param arg_values The argument values.
+   * \param type_codes The type codes.
+   */
+  void ValidateArguments(const TVMValue* arg_values,
+                         const int* type_codes,
+                         int num_args) {
+    TVMArgs args(arg_values, type_codes, num_args);
+    for (int i = 0; i < num_args; ++i) {
+      int tcode = type_codes[i];
+      if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) {
+        LOG(FATAL) << "ValueError: Cannot pass argument " << i
+                   << ", type " << args[i].AsObjectRef<ObjectRef>()->GetTypeKey()
+                   << " is not supported by RPC";
+      } else if (tcode == kTVMContext) {
+        DLContext ctx = args[i];
+        CHECK_LT(static_cast<int>(ctx.device_type), kRPCSessMask)
+            << "InternalError: cannot pass RPC context in the channel";
+      }
+    }
+  }
+
+  void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
+    LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code);
+  }
+
+  uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values,
+                                const int* type_codes,
+                                int num_args,
+                                bool client_mode) {
+    return RPCReference::PackedSeqGetNumBytes(
+        arg_values, type_codes, num_args, client_mode, this);
+  }
+
+  void SendPackedSeq(const TVMValue* arg_values,
+                     const int* type_codes,
+                     int num_args,
+                     bool client_mode) {
+    RPCReference::SendPackedSeq(
+        arg_values, type_codes, num_args, client_mode, this);
+  }
+
+  // Endian aware IO handling
+  using Stream::Read;
+  using Stream::Write;
+  using Stream::ReadArray;
+  using Stream::WriteArray;
+
+  bool Read(RPCCode* code) {
+    int32_t cdata;
+    if (!this->Read(&cdata)) return false;
+    *code = static_cast<RPCCode>(cdata);
+    return true;
+  }
+  void Write(RPCCode code) {
+    int32_t cdata = static_cast<int>(code);
+    this->Write(cdata);
+  }
+
+  template<typename T>
+  T* ArenaAlloc(int count) {
+    static_assert(std::is_pod<T>::value, "need to be trival");
+    return arena_.template allocate_<T>(count);
+  }
+
+ protected:
+  enum State {
+    kInitHeader,
+    kRecvPacketNumBytes,
+    kProcessPacket,
+    kReturnReceived,
+    kCopyAckReceived,
+    kShutdownReceived
+  };
+  // Current state;
+  State state_;
+  // Initialize remote header
+  bool init_header_step_{0};
+  // Whether current handler is client or server mode.
+  bool client_mode_{false};
+  // Internal arena
+  support::Arena arena_;
+
+  // State switcher
+  void SwitchToState(State state) {
+    // invariant
+    if (state != kCopyAckReceived) {
+      CHECK_EQ(pending_request_bytes_, 0U)
+          << "state=" << state;
+    }
+    state_ = state;
+    CHECK(state != kInitHeader)
+        << "cannot switch to init header";
+    if (state == kRecvPacketNumBytes) {
+      this->RequestBytes(sizeof(uint64_t));
+      // recycle arena for the next session.
+      arena_.RecycleAll();
+    }
+  }
+
+  // handler for initial header read
+  void HandleInitHeader() {
+    if (init_header_step_ == 0) {
+      int32_t len;
+      this->Read(&len);
+      remote_key_->resize(len);
+      init_header_step_ = 1;
+      this->RequestBytes(len);
+      return;
+    } else {
+      CHECK_EQ(init_header_step_, 1);
+      this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
+      this->SwitchToState(kRecvPacketNumBytes);
+    }
+  }
+
+  // Handler for read code.
+  void HandleProcessPacket(RPCSession::FEncodeReturn setreturn) {
+    RPCCode code = RPCCode::kNone;
+    this->Read(&code);
+
+    if (code >= RPCCode::kSyscallCodeStart) {
+      this->HandleSyscall(code);
+    } else {
+        switch (code) {
+          case RPCCode::kInitServer: {
+            this->HandleInitServer();
+            break;
+          }
+          case RPCCode::kCallFunc: {
+            this->HandleNormalCallFunc();
+            break;
+          }
+          case RPCCode::kCopyFromRemote: {
+            this->HandleCopyFromRemote();
+            break;
+          }
+          case RPCCode::kCopyToRemote: {
+            this->HandleCopyToRemote();
+            break;
+          }
+          case RPCCode::kException:
+          case RPCCode::kReturn: {
+            this->HandleReturn(code, setreturn);
+            break;
+          }
+          case RPCCode::kCopyAck: {
+            this->SwitchToState(kCopyAckReceived);
+            break;
+          }
+          case RPCCode::kShutdown: {
+            this->SwitchToState(kShutdownReceived);
+            break;
+          }
+          default: LOG(FATAL) << "Unknown event "  << static_cast<int>(code);
+        }
+    }
+  }
+
+  /*!
+   * \brief Recive incoming packed seq from the stream.
+   * \return The received argments.
+   * \note The TVMArgs is available until we switchstate.
+   */
+  TVMArgs RecvPackedSeq() {
+    TVMValue* values;
+    int* tcodes;
+    int num_args;
+    RPCReference::RecvPackedSeq(&values, &tcodes, &num_args, this);
+    return TVMArgs(values, tcodes, num_args);
+  }
+
+  /*!
+   * \brief Return exception to the remote.
+   * \param err_msg The error message.
+   */
+  void ReturnException(const char* err_msg) {
+    RPCReference::ReturnException(err_msg, this);
+  }
+
+  /*!
+   * \brief Return nullptr to the remote.
+   * \param err_msg The error message.
+   */
+  void ReturnVoid() {
+    RPCReference::ReturnVoid(this);
+  }
+
+  /*!
+   * \brief Return a packed sequence to the remote.
+   * \param args The arguments.
+   */
+  void ReturnPackedSeq(TVMArgs args) {
+    RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.size(), this);
+  }
+
+  /*!
+   * \brief Handle the case when return/exception value is received.
+   * \param code The RPC code.
+   * \param setreturn The function to encode return.
+   */
+  void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) {
+    TVMArgs args = RecvPackedSeq();
+
+    if (code == RPCCode::kException) {
+      // switch to the state before sending exception.
+      this->SwitchToState(kRecvPacketNumBytes);
+      std::string msg = args[0];
+      LOG(FATAL) << "RPCError: Error caught from RPC call:\n" <<  msg;
+    }
+
+    CHECK(setreturn != nullptr) << "fsetreturn not available";
+    setreturn(args);
+
+    this->SwitchToState(kReturnReceived);
+  }
+
+  void HandleSyscall(RPCCode code);
+
+  void HandleCopyFromRemote() {
+    uint64_t handle, offset, num_bytes;
+    TVMContext ctx;
+    DLDataType type_hint;
+    this->Read(&handle);
+    this->Read(&offset);
+    this->Read(&num_bytes);
+    this->Read(&ctx);
+    this->Read(&type_hint);
+    size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
+
+    char* data_ptr;
+
+    if (ctx.device_type == kDLCPU) {
+      data_ptr = reinterpret_cast<char*>(handle) + offset;
+      // endian aware handling
+      if (!DMLC_IO_NO_ENDIAN_SWAP) {
+        char* temp = this->ArenaAlloc<char>(num_bytes);
+        std::memcpy(temp, data_ptr, num_bytes);
+        dmlc::ByteSwap(temp, elem_bytes, num_bytes / elem_bytes);
+        data_ptr = temp;
+      }
+    } else {
+      try {
+        data_ptr = this->ArenaAlloc<char>(num_bytes);
+        GetServingSession()->CopyFromRemote(
+            reinterpret_cast<void*>(handle), offset,
+            data_ptr, 0,
+            num_bytes, ctx, type_hint);
+        // endian aware handling
+        if (!DMLC_IO_NO_ENDIAN_SWAP) {
+          dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes);
+        }
+      } catch (const std::runtime_error &e) {
+        this->ReturnException(e.what());
+        this->SwitchToState(kRecvPacketNumBytes);
+        return;
+      }
+    }
+    RPCCode code = RPCCode::kCopyAck;
+    uint64_t packet_nbytes = sizeof(code) + num_bytes;
+
+    // Return Copy Ack
+    this->Write(packet_nbytes);
+    this->Write(code);
+    this->WriteArray(data_ptr, num_bytes);
+
+    this->SwitchToState(kRecvPacketNumBytes);
+  }
+
+  void HandleCopyToRemote() {
+    uint64_t handle, offset, num_bytes;
+    TVMContext ctx;
+    DLDataType type_hint;
+
+    this->Read(&handle);
+    this->Read(&offset);
+    this->Read(&num_bytes);
+    this->Read(&ctx);
+    this->Read(&type_hint);
+
+    size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
+
+    if (ctx.device_type == kDLCPU) {
+       char* dptr = reinterpret_cast<char*>(handle) + offset;
+       this->ReadArray(dptr, num_bytes);
+
+        if (!DMLC_IO_NO_ENDIAN_SWAP) {
+          dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes);
+        }
+    } else {
+      char* temp_data = this->ArenaAlloc<char>(num_bytes);
+      this->ReadArray(temp_data, num_bytes);
+
+      if (!DMLC_IO_NO_ENDIAN_SWAP) {
+        dmlc::ByteSwap(temp_data, elem_bytes, num_bytes / elem_bytes);
+      }
+
+      try {
+        GetServingSession()->CopyToRemote(
+            temp_data, 0,
+            reinterpret_cast<void*>(handle), offset,
+            num_bytes, ctx, type_hint);
+      } catch (const std::runtime_error &e) {
+        this->ReturnException(e.what());
+        this->SwitchToState(kRecvPacketNumBytes);
+        return;
+      }
+    }
+
+    this->ReturnVoid();
+    this->SwitchToState(kRecvPacketNumBytes);
+  }
+
+  // Handle for packed call.
+  void HandleNormalCallFunc() {
+    uint64_t call_handle;
+
+    this->Read(&call_handle);
+    TVMArgs args = RecvPackedSeq();
+
+    try {
+      GetServingSession()->CallFunc(
+          reinterpret_cast<void*>(call_handle),
+          args.values, args.type_codes, args.size(),
+          [this](TVMArgs ret) { this->ReturnPackedSeq(ret); });
+    } catch (const std::runtime_error& e) {
+      this->ReturnException(e.what());
+    }
+
+    this->SwitchToState(kRecvPacketNumBytes);
+  }
+
+  void HandleInitServer() {
+    std::string client_protocol_ver;
+
+    uint64_t len;
+    this->Read(&len);
+    client_protocol_ver.resize(len);
+    this->Read(dmlc::BeginPtr(client_protocol_ver), len);
+
+    TVMArgs args = RecvPackedSeq();
+
+    try {
+      CHECK(serving_session_ == nullptr)
+          << "Server has already been initialized";
+
+      std::string server_protocol_ver = kRPCProtocolVer;
+      CHECK_EQ(client_protocol_ver, server_protocol_ver)
+          << "Server[" << name_ << "]: Client protocol version mismatch with the server "
+          << " server protocol=" << server_protocol_ver
+          << ", client protocol=" << client_protocol_ver;
+
+      if (args.size() == 0) {
+        serving_session_ = std::make_shared<LocalSession>();
+      } else {
+        std::string constructor_name = args[0];
+        auto* fconstructor = Registry::Get(constructor_name);
+        CHECK(fconstructor != nullptr)
+            << " Cannot find session constructor " << constructor_name;
+        TVMRetValue con_ret;
+
+        try {
+          fconstructor->CallPacked(
+              TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), &con_ret);
+        } catch (const dmlc::Error& e) {
+          LOG(FATAL) << "Server[" << name_ << "]:"
+                     << " Error caught from session constructor " << constructor_name
+                     << ":\n" << e.what();
+        }
+
+        CHECK_EQ(con_ret.type_code(), kTVMModuleHandle)
+            << "Server[" << name_ << "]:"
+            << " Constructor " << constructor_name
+            << " need to return an RPCModule";
+        Module mod = con_ret;
+        std::string tkey = mod->type_key();
+        CHECK_EQ(tkey, "rpc")
+            << "Constructor " << constructor_name << " to return an RPCModule";
+        serving_session_ = RPCModuleGetSession(mod);
+      }
+
+      this->ReturnVoid();
+    } catch (const std::runtime_error &e) {
+      this->ReturnException(e.what());
+    }
+
+    this->SwitchToState(kRecvPacketNumBytes);
+  }
+
+  // Handler for special syscalls that have a specific RPCCode.
+  template<typename F>
+  void SysCallHandler(F f) {
+    TVMArgs args = RecvPackedSeq();
+    try {
+      TVMRetValue rv;
+      f(GetServingSession(), args, &rv);
+      TVMValue ret_value;
+      int ret_tcode;
+      TVMArgsSetter setter(&ret_value, &ret_tcode);
+      setter(0, rv);
+
+      this->ReturnPackedSeq(TVMArgs(&ret_value, &ret_tcode, 1));
+    } catch (const std::runtime_error& e) {
+      this->ReturnException(e.what());
+    }
+    this->SwitchToState(kRecvPacketNumBytes);
+  }
+
+ private:
+  RPCSession* GetServingSession() const {
+    CHECK(serving_session_ != nullptr)
+        << "Need to call InitRemoteSession first before any further actions";
+    return serving_session_.get();
+  }
+  // Utility functions
+  // Internal read function, update pending_request_bytes_
+  size_t Read(void* data, size_t size) final {
+    CHECK_LE(size, pending_request_bytes_);
+    reader_->Read(data, size);
+    pending_request_bytes_ -= size;
+    return size;
+  }
+  // wriite the data to the channel.
+  void Write(const void* data, size_t size) final {
+    writer_->Write(data, size);
+  }
+  // Number of pending bytes requests
+  size_t pending_request_bytes_{0};
+  // The ring buffer to read data from.
+  support::RingBuffer* reader_;
+  // The ringr buffer to write reply to.
+  support::RingBuffer* writer_;
+  // The session used to serve the RPC requests.
+  std::shared_ptr<RPCSession> serving_session_;
+  // Name of endpoint.
+  std::string name_;
+  // remote key
+  std::string* remote_key_;
+};
+
+RPCCode RPCEndpoint::HandleUntilReturnEvent(
+    bool client_mode, RPCSession::FEncodeReturn setreturn) {
+  RPCCode code = RPCCode::kCallFunc;
+  while (code != RPCCode::kReturn &&
+         code != RPCCode::kShutdown &&
+         code != RPCCode::kCopyAck) {
+    while (writer_.bytes_available() != 0) {
+      writer_.ReadWithCallback([this](const void *data, size_t size) {
+          return channel_->Send(data, size);
+        }, writer_.bytes_available());
+    }
+    size_t bytes_needed = handler_->BytesNeeded();
+    if (bytes_needed != 0) {
+      size_t n = reader_.WriteWithCallback([this](void* data, size_t size) {
+          return channel_->Recv(data, size);
+        }, bytes_needed);
+      if (n == 0) {
+        if (handler_->CanCleanShutdown()) {
+          return RPCCode::kShutdown;
+        } else {
+          LOG(FATAL) << "Channel closes before we get neded bytes";
+        }
+      }
+    }
+    code = handler_->HandleNextEvent(client_mode, setreturn);
+  }
+  return code;
+}
+
+void RPCEndpoint::Init() {
+  // Event handler
+  handler_ = std::make_shared<EventHandler>(
+      &reader_, &writer_, name_, &remote_key_);
+  // Quick function to for syscall remote.
+  syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    RPCCode code = static_cast<RPCCode>(all_args[0].operator int());
+    TVMArgs args(all_args.values + 1, all_args.type_codes +1, all_args.num_args -1);
+
+    uint64_t packet_nbytes =
+        sizeof(code) +
+        handler_->PackedSeqGetNumBytes(
+            args.values, args.type_codes, args.num_args, true);
+
+    // All packet begins with packet nbytes
+    handler_->Write(packet_nbytes);
+    handler_->Write(code);
+    handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true);
+
+    code = HandleUntilReturnEvent(true, [rv](TVMArgs args) {
+      CHECK_EQ(args.size(), 1);
+      *rv = args[0];
+    });
+    CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
+  });
+}
+
+std::shared_ptr<RPCEndpoint> RPCEndpoint::Create(
+    std::unique_ptr<RPCChannel> channel,
+    std::string name,
+    std::string remote_key) {
+  std::shared_ptr<RPCEndpoint> endpt = std::make_shared<RPCEndpoint>();
+  endpt->channel_ = std::move(channel);
+  endpt->name_ = std::move(name);
+  endpt->remote_key_ = std::move(remote_key);
+  endpt->Init();
+  return endpt;
+}
+
+RPCEndpoint::~RPCEndpoint() {
+  this->Shutdown();
+}
+
+void RPCEndpoint::Shutdown() {
+  if (channel_ != nullptr) {
+    RPCCode code = RPCCode::kShutdown;
+    uint64_t packet_nbytes = sizeof(code);
+
+    handler_->Write(packet_nbytes);
+    handler_->Write(code);
+
+    // flush all writing buffer to output channel.
+    try {
+      while (writer_.bytes_available() != 0) {
+        size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
+            return channel_->Send(data, size);
+          }, writer_.bytes_available());
+        if (n == 0) break;
+      }
+    } catch (const dmlc::Error& e) {
+    }
+    channel_.reset(nullptr);
+  }
+}
+
+void RPCEndpoint::ServerLoop() {
+  if (const auto* f = Registry::Get("tvm.rpc.server.start")) {
+    (*f)();
+  }
+  TVMRetValue rv;
+  CHECK(HandleUntilReturnEvent(false, [](TVMArgs) {}) == RPCCode::kShutdown);
+  if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) {
+    (*f)();
+  }
+  channel_.reset(nullptr);
+}
+
+int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) {
+  RPCCode code = RPCCode::kNone;
+  if (in_bytes.length() != 0) {
+    reader_.Write(in_bytes.c_str(), in_bytes.length());
+    code = handler_->HandleNextEvent(false, [](TVMArgs) {});
+  }
+  if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
+    writer_.ReadWithCallback([this](const void *data, size_t size) {
+        return channel_->Send(data, size);
+      }, writer_.bytes_available());
+  }
+  CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
+  if (code == RPCCode::kShutdown) return 0;
+  if (writer_.bytes_available() != 0) return 2;
+  return 1;
+}
+
+void RPCEndpoint::InitRemoteSession(TVMArgs args) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  RPCCode code = RPCCode::kInitServer;
+  std::string protocol_ver = kRPCProtocolVer;
+  uint64_t length = protocol_ver.length();
+
+  uint64_t packet_nbytes =
+      sizeof(code) +
+      sizeof(length) +
+      length +
+      handler_->PackedSeqGetNumBytes(
+          args.values, args.type_codes, args.num_args, true);
+
+  // All packet begins with packet nbytes
+  handler_->Write(packet_nbytes);
+  handler_->Write(code);
+  handler_->Write(length);
+  handler_->WriteArray(protocol_ver.data(), length);
+  handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true);
+
+  code = HandleUntilReturnEvent(true, [](TVMArgs args) {});
+  CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
+}
+
+// Get remote function with name
+void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h,
+                           const TVMValue* arg_values,
+                           const int* arg_type_codes,
+                           int num_args,
+                           RPCSession::FEncodeReturn encode_return) {
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  handler_->ValidateArguments(arg_values, arg_type_codes, num_args);
+  RPCCode code = RPCCode::kCallFunc;
+  uint64_t handle = reinterpret_cast<uint64_t>(h);
+
+  uint64_t packet_nbytes =
+      sizeof(code) +
+      sizeof(handle) +
+      handler_->PackedSeqGetNumBytes(
+          arg_values, arg_type_codes, num_args, true);
+
+  handler_->Write(packet_nbytes);
+  handler_->Write(code);
+  handler_->Write(handle);
+  handler_->SendPackedSeq(
+      arg_values, arg_type_codes, num_args, true);
+
+  code = HandleUntilReturnEvent(true, encode_return);
+  CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
+}
+
+void RPCEndpoint::CopyToRemote(void* from,
+                               size_t from_offset,
+                               void* to,
+                               size_t to_offset,
+                               size_t data_size,
+                               TVMContext ctx_to,
+                               DLDataType type_hint) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  RPCCode code = RPCCode::kCopyToRemote;
+  uint64_t handle = reinterpret_cast<uint64_t>(to);
+  uint64_t offset = static_cast<uint64_t>(to_offset);
+  uint64_t size = static_cast<uint64_t>(data_size);
+
+  uint64_t packet_nbytes =
+      sizeof(code) +
+      sizeof(handle) +
+      sizeof(offset) +
+      sizeof(size) +
+      sizeof(ctx_to) +
+      sizeof(type_hint) +
+      data_size;
+
+  handler_->Write(packet_nbytes);
+  handler_->Write(code);
+  handler_->Write(handle);
+  handler_->Write(offset);
+  handler_->Write(size);
+  handler_->Write(ctx_to);
+  handler_->Write(type_hint);
+  handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size);
+
+  CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kReturn);
+}
+
+void RPCEndpoint::CopyFromRemote(void* from,
+                                size_t from_offset,
+                                void* to,
+                                size_t to_offset,
+                                size_t data_size,
+                                TVMContext ctx_from,
+                                DLDataType type_hint) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  RPCCode code = RPCCode::kCopyFromRemote;
+  uint64_t handle = reinterpret_cast<uint64_t>(from);
+  uint64_t offset = static_cast<uint64_t>(from_offset);
+  uint64_t size = static_cast<uint64_t>(data_size);
+
+  uint64_t packet_nbytes =
+      sizeof(code) +
+      sizeof(handle) +
+      sizeof(offset) +
+      sizeof(size) +
+      sizeof(ctx_from) +
+      sizeof(type_hint);
+
+  handler_->Write(packet_nbytes);
+  handler_->Write(code);
+  handler_->Write(handle);
+  handler_->Write(offset);
+  handler_->Write(size);
+  handler_->Write(ctx_from);
+  handler_->Write(type_hint);
+
+  TVMRetValue rv;
+  CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kCopyAck);
+  handler_->ReadArray(reinterpret_cast<char*>(to) + to_offset, data_size);
+  handler_->FinishCopyAck();
+}
+
+// SysCallEventHandler functions
+void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
+  std::string name = args[0];
+  *rv = handler->GetFunction(name);
+}
+
+void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+  void* handle = args[0];
+  int type_code = args[1];
+  handler->FreeHandle(handle, type_code);
+}
+
+void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+  TVMContext ctx = args[0];
+  handler->GetDeviceAPI(ctx)->SetDevice(ctx);
+}
+
+void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+  TVMContext ctx = args[0];
+  DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[1].operator int());
+  if (kind == kExist) {
+    DeviceAPI* api = handler->GetDeviceAPI(ctx, true);
+    if (api != nullptr) {
+      api->GetAttr(ctx, kind, rv);
+    } else {
+      *rv = 0;
+    }
+  } else {
+    handler->GetDeviceAPI(ctx)->GetAttr(
+        ctx, static_cast<DeviceAttrKind>(kind), rv);
+  }
+}
+
+void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+  TVMContext ctx = args[0];
+  uint64_t nbytes = args[1];
+  uint64_t alignment = args[2];
+  DLDataType type_hint = args[3];
+  void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace(
+      ctx, nbytes, alignment, type_hint);
+  *rv = data;
+}
+
+void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+  TVMContext ctx = args[0];
+  void* ptr = args[1];
+  handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr);
+}
+
+void RPCDevStreamSync(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+  TVMContext ctx = args[0];
+  TVMStreamHandle handle = args[1];
+  handler->GetDeviceAPI(ctx)->StreamSync(ctx, handle);
+}
+
+void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
+  void* from = args[0];
+  uint64_t from_offset = args[1];
+  void* to = args[2];
+  uint64_t to_offset = args[3];
+  uint64_t size = args[4];
+  TVMContext ctx_from = args[5];
+  TVMContext ctx_to = args[6];
+  DLDataType type_hint = args[7];
+  TVMStreamHandle stream = args[8];
+  TVMContext ctx = ctx_from;
+
+  if (ctx.device_type == kDLCPU) {
+    ctx = ctx_to;
+  } else {
+    CHECK(ctx_to.device_type == kDLCPU ||
+          ctx_to.device_type == ctx_from.device_type)
+        << "Can not copy across different ctx types directly";
+  }
+  handler->GetDeviceAPI(ctx)->CopyDataFromTo(
+      from, from_offset,
+      to, to_offset,
+      size, ctx_from, ctx_to, type_hint, stream);
+}
+
+void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
+  // Event handler sit at clean state at this point.
+  switch (code) {
+    // system functions
+    case RPCCode::kFreeHandle: SysCallHandler(RPCFreeHandle); break;
+    case RPCCode::kGetGlobalFunc: SysCallHandler(RPCGetGlobalFunc); break;
+    case RPCCode::kDevSetDevice: SysCallHandler(RPCDevSetDevice); break;
+    case RPCCode::kDevGetAttr: SysCallHandler(RPCDevGetAttr); break;
+    case RPCCode::kDevAllocData: SysCallHandler(RPCDevAllocData); break;
+    case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break;
+    case RPCCode::kDevStreamSync: SysCallHandler(RPCDevStreamSync); break;
+    case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break;
+    default: LOG(FATAL) << "Unknown event " << static_cast<int>(code);
+  }
+
+  CHECK_EQ(state_, kRecvPacketNumBytes);
+}
+
+/*!
+ * \brief RPC client session that proxies all calls to an endpoint.
+ */
+class RPCClientSession : public RPCSession,
+                         public DeviceAPI {
+ public:
+  /*!
+   * \brief param endpoint The client endpoint of the session.
+   */
+  explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint)
+      : endpoint_(endpoint) {}
+
+  // function overrides
+  PackedFuncHandle GetFunction(const std::string& name) final {
+    return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name);
+  }
+
+  void CallFunc(PackedFuncHandle func,
+                const TVMValue* arg_values,
+                const int* arg_type_codes,
+                int num_args,
+                const FEncodeReturn& fencode_return) final {
+    endpoint_->CallFunc(
+        func, arg_values, arg_type_codes, num_args, fencode_return);
+  }
+
+  void CopyToRemote(void* from,
+                    size_t from_offset,
+                    void* to,
+                    size_t to_offset,
+                    size_t nbytes,
+                    TVMContext ctx_to,
+                    DLDataType type_hint) final {
+    endpoint_->CopyToRemote(
+        from, from_offset, to, to_offset, nbytes, ctx_to, type_hint);
+  }
+
+  void CopyFromRemote(void* from,
+                      size_t from_offset,
+                      void* to,
+                      size_t to_offset,
+                      size_t nbytes,
+                      TVMContext ctx_from,
+                      DLDataType type_hint) final {
+    endpoint_->CopyFromRemote(
+        from, from_offset, to, to_offset, nbytes, ctx_from, type_hint);
+  }
+
+  void FreeHandle(void* handle, int type_code) final {
+    endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code);
+  }
+
+
+  void SetDevice(TVMContext ctx) final {
+    endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx);
+  }
+
+  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
+    if (ctx.device_type == kDLCPU && kind == kExist) {
+      // cpu always exists.
+      *rv = 1;
+    } else {
+      *rv = endpoint_->SysCallRemote(RPCCode::kDevGetAttr, ctx, static_cast<int>(kind));
+    }
+  }
+
+  void* AllocDataSpace(TVMContext ctx,
+                       size_t nbytes,
+                       size_t alignment,
+                       DLDataType type_hint) final {
+    return endpoint_->SysCallRemote(
+        RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint);
+  }
+
+  void FreeDataSpace(TVMContext ctx, void* ptr) final {
+    endpoint_->SysCallRemote(RPCCode::kDevFreeData, ctx, ptr);
+  }
+
+  void CopyDataFromTo(const void* from,
+                      size_t from_offset,
+                      void* to,
+                      size_t to_offset,
+                      size_t size,
+                      TVMContext ctx_from,
+                      TVMContext ctx_to,
+                      DLDataType type_hint,
+                      TVMStreamHandle stream) final {
+    endpoint_->SysCallRemote(
+        RPCCode::kCopyAmongRemote,
+        const_cast<void*>(from), from_offset,
+        to, to_offset,
+        size,
+        ctx_from, ctx_to,
+        type_hint, stream);
+  }
+
+  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
+    endpoint_->SysCallRemote(RPCCode::kDevStreamSync, ctx, stream);
+  }
+
+  DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final {
+    return this;
+  }
+
+ private:
+  std::shared_ptr<RPCEndpoint> endpoint_;
+};
+
+std::shared_ptr<RPCSession>
+CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
+  return std::make_shared<RPCClientSession>(endpoint);
+}
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h
new file mode 100644 (file)
index 0000000..9a6afcd
--- /dev/null
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_endpoint.h
+ * \brief Communication endpoints to connect local and remote RPC sessions.
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_ENDPOINT_H_
+#define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_
+
+#include <tvm/runtime/packed_func.h>
+#include <mutex>
+#include <string>
+#include <memory>
+#include <utility>
+#include "rpc_session.h"
+#include "rpc_channel.h"
+#include "rpc_protocol.h"
+#include "../../support/ring_buffer.h"
+
+namespace tvm {
+namespace runtime {
+
+// Magic header for RPC data plane
+const int kRPCMagic = 0xff271;
+// magic header for RPC tracker(control plane)
+const int kRPCTrackerMagic = 0x2f271;
+// sucess response
+const int kRPCSuccess = kRPCMagic + 0;
+// cannot found matched key in server
+const int kRPCMismatch = kRPCMagic + 2;
+
+/*! \brief Enumeration code for the RPC tracker */
+enum class TrackerCode : int {
+  kFail = -1,
+  kSuccess = 0,
+  kPing = 1,
+  kStop = 2,
+  kPut = 3,
+  kRequest = 4,
+  kUpdateInfo = 5,
+  kSummary = 6,
+  kGetPendingMatchKeys = 7
+};
+
+
+/*!
+ * \brief Communication endpoints to connect local and remote RPC sessions.
+ *        An endpoint can either be a client or a server.
+ */
+class RPCEndpoint {
+ public:
+  /*! \brief virtual destructor */
+  ~RPCEndpoint();
+  /*!
+   *  \brief The server loop that server runs to handle RPC calls.
+   */
+  void ServerLoop();
+  /*!
+   * \brief Message handling function for an async IO event driven server.
+   *
+   *  Called when the server receives a message or an IO event update.
+   *  Event driven handler will never call recv on the channel
+   *  and always relies on the ServerIOEventHandler to receive the data.
+   *
+   * \param in_bytes The incoming bytes.
+   * \param event_flag  1: read_available, 2: write_avaiable.
+   * \return State flag.
+   *     1: continue running, no need to write,
+   *     2: need to write
+   *     0: shutdown
+   */
+  int ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag);
+
+  /*!
+   * \brief Initalize the session on the remote that will be used to back all the RPC requests.
+   *
+   *  If no session constructor arguments is passed, LocalSession will be used in the remote.
+   *  Otherwise the remote serving session will be constructed using the arguments
+   *  specified in the session_constructor_args.
+   *
+   *  The construction rule can be summarized as follows:
+   *
+   * \code
+   *
+   *  auto args = session_constructor_args;
+   *  int n = args.size();
+   *  if (n != 0) {
+   *    std::string constructor = args[0];
+   *    server.serving_session_ = GetGlobalFunc(constructor)(
+   *        args[1], args[2] ... args[n - 1])
+   *  } else {
+   *    server.serving_session_ = LocalSession();
+   *  }
+   * \endcode
+   *
+   * \param session_constructor_args Optional sequence of the remote sesssion constructor.
+   */
+  void InitRemoteSession(TVMArgs session_constructor_args);
+
+  /*!
+   * \brief Call into remote function
+   * \param handle The function handle
+   * \param arg_values The argument values.
+   * \param arg_type_codes the type codes of the argument.
+   * \param num_args Number of arguments.
+   * \param fencode_return The function to receive return value encodings.
+   */
+  void CallFunc(RPCSession::PackedFuncHandle handle,
+                const TVMValue* arg_values,
+                const int* arg_type_codes,
+                int num_args,
+                RPCSession::FEncodeReturn encode_return);
+  /*!
+   * \brief Copy bytes into remote array content.
+   * \param from The source host data.
+   * \param from_offset The byte offeset in the from.
+   * \param to The target array.
+   * \param to_offset The byte offset in the to.
+   * \param nbytes The size of the memory in bytes.
+   * \param ctx_to The target context.
+   * \param type_hint Hint of content data type.
+   */
+  void CopyToRemote(void* from,
+                    size_t from_offset,
+                    void* to,
+                    size_t to_offset,
+                    size_t nbytes,
+                    TVMContext ctx_to,
+                    DLDataType type_hint);
+  /*!
+   * \brief Copy bytes from remote array content.
+   * \param from The source host data.
+   * \param from_offset The byte offeset in the from.
+   * \param to The target array.
+   * \param to_offset The byte offset in the to.
+   * \param nbytes The size of the memory in bytes.
+   * \param ctx_from The source context.
+   * \param type_hint Hint of content data type.
+   */
+  void CopyFromRemote(void* from,
+                      size_t from_offset,
+                      void* to,
+                      size_t to_offset,
+                      size_t nbytes,
+                      TVMContext ctx_from,
+                      DLDataType type_hint);
+
+  /*!
+   * \brief Call a remote defined system function with arguments.
+   * \param fcode The function code.
+   * \param args The arguments
+   * \return The returned remote value.
+   */
+  template<typename... Args>
+  inline TVMRetValue SysCallRemote(RPCCode fcode, Args&& ...args);
+  /*!
+   * \brief Create a RPC session with given channel.
+   * \param channel The communication channel.
+   * \param name The local name of the session, used for debug
+   * \param remote_key The remote key of the session
+   *   if remote_key equals "%toinit", we need to re-intialize
+   *   it by event handler.
+   */
+  static std::shared_ptr<RPCEndpoint> Create(
+      std::unique_ptr<RPCChannel> channel,
+      std::string name,
+      std::string remote_key);
+
+ private:
+  class EventHandler;
+  // Handle events until receives a return
+  // Also flushes channels so that the function advances.
+  RPCCode HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn);
+  // Initalization
+  void Init();
+  // Shutdown
+  void Shutdown();
+  // Internal channel.
+  std::unique_ptr<RPCChannel> channel_;
+  // Internal mutex
+  std::mutex mutex_;
+  // Internal ring buffer.
+  support::RingBuffer reader_, writer_;
+  // Event handler.
+  std::shared_ptr<EventHandler> handler_;
+  // syscall remote with specified function code.
+  PackedFunc syscall_remote_;
+  // The name of the session.
+  std::string name_;
+  // The remote key
+  std::string remote_key_;
+};
+
+/*!
+ * \brief Create an RPC client session from an RPC client endpoint.
+ * \param endpoint The endpoint.
+ * \return The created session.
+ */
+std::shared_ptr<RPCSession>
+CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint);
+
+// implementation of inline functions
+template<typename... Args>
+inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&& ...args) {
+  return syscall_remote_(static_cast<int>(code), std::forward<Args>(args)...);
+}
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_
index 29adb0f..284dca5 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 
 /*!
  * \file rpc_event_impl.cc
- * \brief Event based RPC server implementation.
+ * \brief Event driven RPC server implementation.
  */
 #include <tvm/runtime/registry.h>
 #include <memory>
-#include "rpc_session.h"
+#include "rpc_endpoint.h"
+#include "rpc_local_session.h"
 
 namespace tvm {
 namespace runtime {
@@ -35,16 +36,17 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend,
     LOG(FATAL) << "Do not allow explicit receive";
     return 0;
   });
+
   std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend, frecv));
-  std::shared_ptr<RPCSession> sess =
-      RPCSession::Create(std::move(ch), name, remote_key);
+  std::shared_ptr<RPCEndpoint> sess =
+      RPCEndpoint::Create(std::move(ch), name, remote_key);
   return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
-      int ret = sess->ServerEventHandler(args[0], args[1]);
+      int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]);
       *rv = ret;
     });
 }
 
-TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer")
+TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer")
 .set_body_typed(CreateEventDrivenServer);
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc
new file mode 100644 (file)
index 0000000..0a2809b
--- /dev/null
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file local_session.cc
+ * \brief Local session that directs requests to local API.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/device_api.h>
+#include <memory>
+#include "rpc_local_session.h"
+
+namespace tvm {
+namespace runtime {
+
+RPCSession::PackedFuncHandle
+LocalSession::GetFunction(const std::string& name) {
+  PackedFunc pf = this->GetFunctionInternal(name);
+  // return raw handl because the remote need to explicitly manage it.
+  if (pf != nullptr) return new PackedFunc(pf);
+  return nullptr;
+}
+
+void LocalSession::CallFunc(RPCSession::PackedFuncHandle func,
+                            const TVMValue* arg_values,
+                            const int* arg_type_codes,
+                            int num_args,
+                            const FEncodeReturn& encode_return) {
+  auto* pf = static_cast<PackedFunc*>(func);
+  TVMRetValue rv;
+
+  pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv);
+  int rv_tcode = rv.type_code();
+
+  // return value encoding.
+  TVMValue ret_value_pack[3];
+  int ret_tcode_pack[3];
+  TVMArgsSetter set_arg(ret_value_pack, ret_tcode_pack);
+  // first location always encode type code.
+  set_arg(0, rv_tcode);
+
+  if (rv_tcode == kTVMNDArrayHandle) {
+    // We follow a special protocol to return NDArray to client side
+    // The first pack value is the NDArray handle as DLTensor
+    // The second pack value is a customized deleter that deletes the NDArray.
+    rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
+    ret_tcode_pack[1] = kTVMDLTensorHandle;
+    ret_value_pack[2].v_handle = ret_value_pack[1].v_handle;
+    ret_tcode_pack[2] = kTVMOpaqueHandle;
+    encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3));
+  } else if (rv_tcode == kTVMPackedFuncHandle ||
+             rv_tcode == kTVMModuleHandle) {
+    // MoveToCHost means rv no longer manages the object.
+    // return handle instead.
+    rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
+    ret_tcode_pack[1] = kTVMOpaqueHandle;
+    encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2));
+  } else if (rv_tcode == kTVMBytes) {
+    TVMByteArray byte_arr;
+    auto* sptr = rv.ptr<std::string>();
+    byte_arr.data = sptr->data();
+    byte_arr.size = sptr->length();
+    set_arg(1, byte_arr);
+    encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2));
+  } else {
+    set_arg(1, rv);
+    encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2));
+  }
+}
+
+void LocalSession::CopyToRemote(void* from,
+                                size_t from_offset,
+                                void* to,
+                                size_t to_offset,
+                                size_t nbytes,
+                                TVMContext ctx_to,
+                                DLDataType type_hint) {
+  TVMContext cpu_ctx;
+  cpu_ctx.device_type = kDLCPU;
+  cpu_ctx.device_id = 0;
+  this->GetDeviceAPI(ctx_to)->CopyDataFromTo(
+      from, from_offset,
+      to, to_offset,
+      nbytes, cpu_ctx, ctx_to, type_hint, nullptr);
+}
+
+void LocalSession::CopyFromRemote(void* from,
+                                  size_t from_offset,
+                                  void* to,
+                                  size_t to_offset,
+                                  size_t nbytes,
+                                  TVMContext ctx_from,
+                                  DLDataType type_hint) {
+  TVMContext cpu_ctx;
+  cpu_ctx.device_type = kDLCPU;
+  cpu_ctx.device_id = 0;
+
+  this->GetDeviceAPI(ctx_from)->CopyDataFromTo(
+      from, from_offset,
+      to, to_offset,
+      nbytes, ctx_from, cpu_ctx, type_hint, nullptr);
+}
+
+void LocalSession::FreeHandle(void* handle, int type_code) {
+  TVMValue value;
+  value.v_handle = handle;
+  // will trigger deleter once the rv goes out of the scope.
+  TVMRetValue rv = TVMRetValue::MoveFromCHost(value, type_code);
+}
+
+DeviceAPI* LocalSession::GetDeviceAPI(TVMContext ctx, bool allow_missing) {
+  return DeviceAPI::Get(ctx, allow_missing);
+}
+
+PackedFunc LocalSession::GetFunctionInternal(const std::string& name) {
+  auto* fp = tvm::runtime::Registry::Get(name);
+  if (fp != nullptr) {
+    return *fp;
+  } else {
+    return nullptr;
+  }
+}
+
+TVM_REGISTER_GLOBAL("rpc.LocalSession")
+.set_body_typed([]() {
+  return CreateRPCSessionModule(std::make_shared<LocalSession>());
+});
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h
new file mode 100644 (file)
index 0000000..ebb3ea1
--- /dev/null
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_local_session.h
+ * \brief Local session that directs all request to the local runtime API.
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_
+#define TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_
+
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/device_api.h>
+#include <functional>
+#include <string>
+#include "rpc_session.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief A local session that directly use the handle repr of the
+ *        local tvm runtime objects on the same process.
+ */
+class LocalSession : public RPCSession {
+ public:
+  // function overrides
+  PackedFuncHandle GetFunction(const std::string& name) final;
+
+  void CallFunc(PackedFuncHandle func,
+                const TVMValue* arg_values,
+                const int* arg_type_codes,
+                int num_args,
+                const FEncodeReturn& fencode_return) final;
+
+  void CopyToRemote(void* from,
+                    size_t from_offset,
+                    void* to,
+                    size_t to_offset,
+                    size_t nbytes,
+                    TVMContext ctx_to,
+                    DLDataType type_hint) final;
+
+  void CopyFromRemote(void* from,
+                      size_t from_offset,
+                      void* to,
+                      size_t to_offset,
+                      size_t nbytes,
+                      TVMContext ctx_from,
+                      DLDataType type_hint) final;
+
+  void FreeHandle(void* handle, int type_code) final;
+
+  DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) final;
+
+ protected:
+  /*!
+   * \brief Internal implementation of GetFunction.
+   * \param name The name of the function.
+   * \return The corresponding PackedFunc.
+   */
+  virtual PackedFunc GetFunctionInternal(const std::string& name);
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_
index 0e48e6f..1062304 100644 (file)
  */
 
 /*!
- * \file rpc_device_api.cc
- * \brief RPC module.
+ * \file rpc_module.cc
+ * \brief RPC runtime module.
  */
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
 #include <memory>
 #include <cstring>
+#include "rpc_endpoint.h"
 #include "rpc_session.h"
 
 namespace tvm {
 namespace runtime {
 
-// Wrapped remote function to packed func.
-class RPCWrappedFunc {
+/*!
+ * \brief A wrapped remote function as a PackedFunc.
+ */
+class RPCWrappedFunc : public Object {
  public:
   RPCWrappedFunc(void* handle,
                  std::shared_ptr<RPCSession> sess)
       : handle_(handle), sess_(sess) {
-    fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
-        WrapRemote(sess, args, rv);
-      });
   }
 
-  void operator()(TVMArgs args, TVMRetValue *rv) const {
-    sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_);
+  void operator()(TVMArgs args, TVMRetValue* rv) const {
+    std::vector<TVMValue> values(args.values, args.values + args.size());
+    std::vector<int> type_codes(args.type_codes, args.type_codes + args.size());
+    std::vector<std::unique_ptr<DLTensor>> temp_dltensors;
+
+    // scan and check whether we need rewrite these arguments
+    // to their remote variant.
+    for (int i = 0; i < args.size(); ++i) {
+      int tcode = type_codes[i];
+
+      switch (tcode) {
+        case kTVMDLTensorHandle:
+        case kTVMNDArrayHandle: {
+          // Pass NDArray as DLTensor, NDArray and DLTensor
+          // are compatible to each other, just need to change the index.
+          type_codes[i] = kTVMDLTensorHandle;
+          // translate to a remote view of DLTensor
+          auto dptr = std::make_unique<DLTensor>(
+              *static_cast<DLTensor*>(values[i].v_handle));
+          dptr->ctx = RemoveSessMask(dptr->ctx);
+          dptr->data = static_cast<RemoteSpace*>(dptr->data)->data;
+          values[i].v_handle = dptr.get();
+          temp_dltensors.emplace_back(std::move(dptr));
+          break;
+        }
+        case kTVMContext: {
+          values[i].v_ctx = RemoveSessMask(values[i].v_ctx);
+          break;
+        }
+        case kTVMPackedFuncHandle:
+        case kTVMModuleHandle: {
+          values[i].v_handle = UnwrapRemoteValueToHandle(
+              TVMArgValue(values[i], tcode));
+          break;
+        }
+      }
+    }
+    auto set_return = [this, rv](TVMArgs args) {
+      this->WrapRemoteReturnToValue(args, rv);
+    };
+    sess_->CallFunc(handle_, values.data(), type_codes.data(),
+                    args.size(), set_return);
   }
+
   ~RPCWrappedFunc() {
     try {
-      sess_->CallRemote(RPCCode::kFreeFunc, handle_);
+      sess_->FreeHandle(handle_, kTVMPackedFuncHandle);
     } catch (const dmlc::Error& e) {
       // fault tolerance to remote close
     }
   }
 
-  static void WrapRemote(std::shared_ptr<RPCSession> sess,
-                         TVMArgs args,
-                         TVMRetValue* rv);
+ private:
+  // remote function handle
+  void* handle_{nullptr};
+  // pointer to the session.
+  std::shared_ptr<RPCSession> sess_;
 
-  static void* UnwrapRemote(int rpc_sess_table_index,
-                            const TVMArgValue& arg);
+  // unwrap a remote value to the underlying handle.
+  void* UnwrapRemoteValueToHandle(const TVMArgValue& arg) const;
+  // wrap a remote return via Set
+  void WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const;
+
+  // remove a remote session mask
+  TVMContext RemoveSessMask(TVMContext ctx) const {
+    int dev_type = ctx.device_type;
+    CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1)
+        << "Can not pass in local context or context with a different remote session";
+    ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
+    return ctx;
+  }
 
   // deleter of RPC remote array
   static void RemoteNDArrayDeleter(Object* obj) {
     auto* ptr = static_cast<NDArray::Container*>(obj);
     RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
-    space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
+    space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle);
     delete space;
     delete ptr;
   }
+
   // wrap return value as remote NDArray.
-  static NDArray WrapRemoteNDArray(std::shared_ptr<RPCSession> sess,
-                                   DLTensor* tensor,
-                                   void* nd_handle) {
+  NDArray WrapRemoteNDArray(DLTensor* tensor, void* nd_handle) const {
     NDArray::Container* data = new NDArray::Container();
     data->manager_ctx = nd_handle;
     data->SetDeleter(RemoteNDArrayDeleter);
     RemoteSpace* space = new RemoteSpace();
-    space->sess = sess;
+    space->sess = sess_;
     space->data = tensor->data;
     data->dl_tensor.data = space;
     NDArray ret(GetObjectPtr<Object>(data));
@@ -89,18 +143,13 @@ class RPCWrappedFunc {
     data->dl_tensor.ctx.device_id = tensor->ctx.device_id;
     data->dl_tensor.ctx.device_type = static_cast<DLDeviceType>(
         static_cast<int>(tensor->ctx.device_type) +
-        kRPCSessMask * (sess->table_index() + 1));
+        kRPCSessMask * (sess_->table_index() + 1));
     // check strides.
     CHECK(tensor->strides == nullptr);
     // setup byteoffset
     data->dl_tensor.byte_offset = tensor->byte_offset;
     return ret;
   }
-
- private:
-  PackedFunc fwrap_;
-  void* handle_{nullptr};
-  std::shared_ptr<RPCSession> sess_;
 };
 
 // RPC that represents a remote module session.
@@ -109,10 +158,11 @@ class RPCModuleNode final : public ModuleNode {
   RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
       : module_handle_(module_handle), sess_(sess) {
   }
+
   ~RPCModuleNode() {
     if (module_handle_ != nullptr) {
       try {
-        sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
+        sess_->FreeHandle(module_handle_, kTVMModuleHandle);
       } catch (const dmlc::Error& e) {
         // fault tolerance to remote close
       }
@@ -127,31 +177,56 @@ class RPCModuleNode final : public ModuleNode {
   PackedFunc GetFunction(
       const std::string& name,
       const ObjectPtr<Object>& sptr_to_self) final {
-    RPCFuncHandle handle = GetFuncHandle(name);
-    return WrapRemote(handle);
+    if (module_handle_ == nullptr) {
+      return WrapRemoteFunc(sess_->GetFunction(name));
+    } else {
+      InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction");
+      return remote_mod_get_function_(GetRef<Module>(this), name, false);
+    }
   }
 
   std::string GetSource(const std::string& format) final {
-    if (module_handle_ != nullptr) {
-      std::string ret =  sess_->CallRemote(
-          RPCCode::kModuleGetSource, module_handle_, format);
-    }
+    LOG(FATAL) << "GetSource for rpc Module is not supported";
     return "";
   }
 
-  std::shared_ptr<RPCSession>& sess() {
-    return sess_;
-  }
-
   PackedFunc GetTimeEvaluator(const std::string& name,
                               TVMContext ctx,
                               int number,
                               int repeat,
                               int min_repeat_ms) {
-    RPCFuncHandle handle = GetFuncHandle(name);
-    if (handle == nullptr) return PackedFunc();
-    handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat, min_repeat_ms);
-    return WrapRemote(handle);
+    InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator");
+    // Remove session mask because we pass ctx by parts.
+    int dev_type = ctx.device_type;
+    CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1)
+        << "ValueError: Need to pass the matched remote context to RPCModule.GetTimeEvaluator";
+    ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
+
+    if (module_handle_ != nullptr) {
+      return remote_get_time_evaluator_(
+          GetRef<Module>(this), name,
+          static_cast<int>(ctx.device_type), ctx.device_id,
+          number, repeat, min_repeat_ms);
+    } else {
+      return remote_get_time_evaluator_(
+          Optional<Module>(nullptr), name,
+          static_cast<int>(ctx.device_type), ctx.device_id,
+          number, repeat, min_repeat_ms);
+    }
+  }
+
+  Module LoadModule(std::string name) {
+    InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module");
+    return remote_load_module_(name);
+  }
+
+  void ImportModule(Module other) {
+    InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule");
+    remote_import_module_(GetRef<Module>(this), other);
+  }
+
+  const std::shared_ptr<RPCSession>& sess() {
+    return sess_;
   }
 
   void* module_handle() const {
@@ -159,7 +234,15 @@ class RPCModuleNode final : public ModuleNode {
   }
 
  private:
-  PackedFunc WrapRemote(RPCFuncHandle handle) {
+  template<typename FType>
+  void InitRemoteFunc(FType* func, const std::string& name) {
+    if (*func != nullptr) return;
+    RPCSession::PackedFuncHandle handle = sess_->GetFunction(name);
+    CHECK(handle != nullptr) << "Cannot found remote function " << name;
+    *func = WrapRemoteFunc(handle);
+  }
+
+  PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) {
     if (handle == nullptr) return PackedFunc();
     auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
     return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
@@ -167,33 +250,30 @@ class RPCModuleNode final : public ModuleNode {
       });
   }
 
-  RPCFuncHandle GetFuncHandle(const std::string& name) {
-    RPCFuncHandle handle = nullptr;
-    if (module_handle_ == nullptr) {
-      handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
-    } else {
-      handle = sess_->CallRemote(
-          RPCCode::kModuleGetFunc, module_handle_, name);
-    }
-    return handle;
-  }
   // The module handle
   void* module_handle_{nullptr};
   // The local channel
   std::shared_ptr<RPCSession> sess_;
-  // Wrap function to wrap remote module/function.
-  PackedFunc fwrap_;
+  // remote function to get time evaluator
+  TypedPackedFunc<PackedFunc(Optional<Module>, std::string, int, int, int, int, int)>
+  remote_get_time_evaluator_;
+  // remote function getter for modules.
+  TypedPackedFunc<PackedFunc(Module, std::string, bool)> remote_mod_get_function_;
+  // remote function getter for load module
+  TypedPackedFunc<Module(std::string)> remote_load_module_;
+  // remote function getter for load module
+  TypedPackedFunc<void(Module, Module)> remote_import_module_;
 };
 
-void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index,
-                                   const TVMArgValue& arg) {
+
+void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const {
   if (arg.type_code() == kTVMModuleHandle) {
     Module mod = arg;
     std::string tkey = mod->type_key();
     CHECK_EQ(tkey, "rpc")
         << "ValueError: Cannot pass a non-RPC module to remote";
     auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
-    CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index)
+    CHECK(rmod->sess() == sess_)
         << "ValueError: Cannot pass in module into a different remote session";
     return rmod->module_handle();
   } else {
@@ -204,93 +284,173 @@ void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index,
   }
 }
 
-void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
-                                TVMArgs args,
-                                TVMRetValue *rv) {
-  void* handle = args.values[0].v_handle;
-  int tcode = args.type_codes[0];
+void RPCWrappedFunc::WrapRemoteReturnToValue(
+    TVMArgs args,
+    TVMRetValue *rv) const {
+  int tcode = args[0];
 
-  if (handle == nullptr) return;
+  if (tcode == kTVMNullptr) return;
   if (tcode == kTVMPackedFuncHandle) {
-    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
+    CHECK_EQ(args.size(), 2);
+    void* handle = args[1];
+    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
     *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
-        return wf->operator()(args, rv);
-      });
+      return wf->operator()(args, rv);
+    });
   } else if (tcode == kTVMModuleHandle) {
-    auto n = make_object<RPCModuleNode>(handle, sess);
+    CHECK_EQ(args.size(), 2);
+    void* handle = args[1];
+    auto n = make_object<RPCModuleNode>(handle, sess_);
     *rv = Module(n);
   } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) {
-    CHECK_EQ(args.size(), 2);
-    DLTensor* tensor = args[0];
-    void* nd_handle = args[1];
-    *rv = WrapRemoteNDArray(sess, tensor, nd_handle);
+    CHECK_EQ(args.size(), 3);
+    DLTensor* tensor = args[1];
+    void* nd_handle = args[2];
+    *rv = WrapRemoteNDArray(tensor, nd_handle);
   } else {
-    LOG(FATAL) << "Cannot wrap tcode=" << tcode;
+    CHECK_EQ(args.size(), 2);
+    *rv = args[1];
   }
 }
 
-Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
+Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess) {
   auto n = make_object<RPCModuleNode>(nullptr, sess);
+  RPCSession::InsertToSessionTable(sess);
   return Module(n);
 }
 
+std::shared_ptr<RPCSession> RPCModuleGetSession(Module mod) {
+  std::string tkey = mod->type_key();
+  CHECK_EQ(tkey, "rpc")
+      << "ValueError: Cannot pass a non-RPC module to remote";
+  auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
+  return rmod->sess();
+}
+
+PackedFunc WrapTimeEvaluator(PackedFunc pf,
+                             TVMContext ctx,
+                             int number,
+                             int repeat,
+                             int min_repeat_ms) {
+  CHECK(pf != nullptr);
+
+  if (static_cast<int>(ctx.device_type) == static_cast<int>(kDLMicroDev)) {
+    auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator");
+    CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled";
+    return (*get_micro_time_evaluator)(pf, ctx, number, repeat);
+  }
+
+  auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv)
+                mutable {
+    TVMRetValue temp;
+    std::ostringstream os;
+    // skip first time call, to activate lazy compilation components.
+    pf.CallPacked(args, &temp);
+
+    DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
+
+    for (int i = 0; i < repeat; ++i) {
+      std::chrono::time_point<
+        std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
+      double duration_ms = 0.0;
+
+      do {
+        if (duration_ms > 0.0) {
+          number = static_cast<int>(
+              std::max((min_repeat_ms / (duration_ms / number) + 1),
+                       number * 1.618));   // 1.618 is chosen by random
+        }
+
+        tbegin = std::chrono::high_resolution_clock::now();
+        // start timing
+        for (int i = 0; i < number; ++i) {
+          pf.CallPacked(args, &temp);
+        }
+        DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
+        tend = std::chrono::high_resolution_clock::now();
+
+        duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
+            (tend - tbegin).count() * 1000;
+      } while (duration_ms < min_repeat_ms);
+
+      double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
+          tend - tbegin).count() / number;
+      os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
+    }
+
+    std::string blob = os.str();
+    TVMByteArray arr;
+    arr.size = blob.length();
+    arr.data = blob.data();
+    // return the time.
+    *rv = arr;
+  };
+  return PackedFunc(ftimer);
+}
+
+
 TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    Module m = args[0];
+.set_body_typed([](Optional<Module> opt_mod,
+                   std::string name,
+                   int device_type,
+                   int device_id,
+                   int number,
+                   int repeat,
+                   int min_repeat_ms) {
+  TVMContext ctx;
+  ctx.device_type = static_cast<DLDeviceType>(device_type);
+  ctx.device_id = device_id;
+  if (opt_mod.defined()) {
+    Module m = opt_mod.value();
     std::string tkey = m->type_key();
-    TVMContext ctx;
-    ctx.device_type = static_cast<DLDeviceType>(args[2].operator int());
-    ctx.device_id = args[3];
     if (tkey == "rpc") {
-      *rv = static_cast<RPCModuleNode*>(m.operator->())
-          ->GetTimeEvaluator(args[1], ctx, args[4], args[5], args[6]);
+      return static_cast<RPCModuleNode*>(m.operator->())
+          ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms);
     } else {
-      *rv = WrapTimeEvaluator(
-          m.GetFunction(args[1], false), ctx, args[4], args[5], args[6]);
+      return WrapTimeEvaluator(
+          m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms);
     }
-  });
+  } else {
+    auto* pf = runtime::Registry::Get(name);
+    CHECK(pf != nullptr) << "Cannot find " << name << " in the global function";
+    return WrapTimeEvaluator(
+        *pf, ctx, number, repeat, min_repeat_ms);
+  }
+});
 
-TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    Module m = args[0];
-    std::string tkey = m->type_key();
-    CHECK_EQ(tkey, "rpc");
-    auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess();
-    void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
-    auto n = make_object<RPCModuleNode>(mhandle, sess);
-    *rv = Module(n);
-  });
+// server function registration.
+TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule")
+.set_body_typed([](Module parent, Module child) {
+  parent->Import(child);
+});
 
-TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    Module parent = args[0];
-    Module child = args[1];
-    CHECK(!std::strcmp(parent->type_key(), "rpc") &&
-          !std::strcmp(child->type_key(), "rpc"));
-    auto* pmod = static_cast<RPCModuleNode*>(parent.operator->());
-    auto* cmod = static_cast<RPCModuleNode*>(child.operator->());
-    CHECK(pmod->sess().get() == cmod->sess().get())
-        << "Import of remote module need to belong to same session.";
-    pmod->sess()->CallRemote(RPCCode::kModuleImport,
-                             pmod->module_handle(),
-                             cmod->module_handle());
-  });
-
-TVM_REGISTER_GLOBAL("rpc._ModuleHandle")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    Module m = args[0];
-    std::string tkey = m->type_key();
-    CHECK_EQ(tkey, "rpc");
-    *rv = static_cast<RPCModuleNode*>(m.operator->())->module_handle();
-  });
+TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction")
+.set_body_typed([](Module parent, std::string name, bool query_imports) {
+  return parent->GetFunction(name, query_imports);
+});
+
+// functions to access an RPC module.
+TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule")
+.set_body_typed([](Module sess, std::string name) {
+  std::string tkey = sess->type_key();
+  CHECK_EQ(tkey, "rpc");
+  return static_cast<RPCModuleNode*>(sess.operator->())->LoadModule(name);
+});
 
-TVM_REGISTER_GLOBAL("rpc._SessTableIndex")
+TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule")
+.set_body_typed([](Module parent, Module child) {
+  std::string tkey = parent->type_key();
+  CHECK_EQ(tkey, "rpc");
+  static_cast<RPCModuleNode*>(parent.operator->())->ImportModule(child);
+});
+
+TVM_REGISTER_GLOBAL("rpc.SessTableIndex")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    Module m = args[0];
-    std::string tkey = m->type_key();
-    CHECK_EQ(tkey, "rpc");
-    *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
-  });
+  Module m = args[0];
+  std::string tkey = m->type_key();
+  CHECK_EQ(tkey, "rpc");
+  *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
+});
 
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc
new file mode 100644 (file)
index 0000000..376b8b5
--- /dev/null
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_pipe_impl.cc
+ * \brief Pipe-based RPC channel.
+ */
+// Linux only for now, as linux is the most common usecase.
+#if defined(__linux__) || defined(__ANDROID__)
+
+#include <sys/types.h>
+#include <unistd.h>
+#include <errno.h>
+#include <signal.h>
+
+#include <tvm/runtime/registry.h>
+#include <memory>
+#include <cstdlib>
+
+#include "rpc_endpoint.h"
+#include "rpc_local_session.h"
+#include "../../support/pipe.h"
+
+namespace tvm {
+namespace runtime {
+
+class PipeChannel final : public RPCChannel {
+ public:
+  explicit PipeChannel(int readfd, int writefd, pid_t child_pid)
+      : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) {
+  }
+
+  ~PipeChannel() {
+    Close();
+  }
+
+  size_t Send(const void* data, size_t size) final {
+    ssize_t n = write(writefd_, data, size);
+    if (n == -1) {
+      LOG(FATAL) << "Pipe write error";
+    }
+    return static_cast<size_t>(n);
+  }
+
+  size_t Recv(void* data, size_t size) final {
+    ssize_t n = read(readfd_, data, size);
+    if (n == -1) {
+      LOG(FATAL) << "Pipe read error";
+    }
+    return static_cast<size_t>(n);
+  }
+
+  void Close() {
+    close(readfd_);
+    close(writefd_);
+    kill(child_pid_, SIGKILL);
+  }
+
+ private:
+  int readfd_;
+  int writefd_;
+  pid_t child_pid_;
+};
+
+
+Module CreatePipeClient(std::vector<std::string> cmd) {
+  int parent2child[2];
+  int child2parent[2];
+  CHECK_EQ(pipe(parent2child), 0);
+  CHECK_EQ(pipe(child2parent), 0);
+
+  int parent_read = child2parent[0];
+  int parent_write = parent2child[1];
+  int child_read = parent2child[0];
+  int child_write = child2parent[1];
+
+  pid_t pid = fork();
+  if (pid == 0) {
+    // child process
+    close(parent_read);
+    close(parent_write);
+    std::string sread_pipe = std::to_string(child_read);
+    std::string swrite_pipe = std::to_string(child_write);
+    std::vector<char*> argv;
+    for (auto& str : cmd) {
+      argv.push_back(dmlc::BeginPtr(str));
+    }
+    argv.push_back(dmlc::BeginPtr(sread_pipe));
+    argv.push_back(dmlc::BeginPtr(swrite_pipe));
+    argv.push_back(nullptr);
+    execvp(argv[0], &argv[0]);
+  }
+  // parent process
+  close(child_read);
+  close(child_write);
+
+  auto endpt = RPCEndpoint::Create(
+      std::unique_ptr<PipeChannel>(
+          new PipeChannel(parent_read, parent_write, pid)),
+      "pipe", "pipe");
+  endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0));
+  return CreateRPCSessionModule(CreateClientSession(endpt));
+}
+
+TVM_REGISTER_GLOBAL("rpc.CreatePipeClient")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  std::vector<std::string> cmd;
+  for (int i = 0; i < args.size(); ++i) {
+    cmd.push_back(args[i].operator std::string());
+  }
+  *rv = CreatePipeClient(cmd);
+});
+
+
+}  // namespace runtime
+}  // namespace tvm
+#endif
diff --git a/src/runtime/rpc/rpc_protocol.h b/src/runtime/rpc/rpc_protocol.h
new file mode 100644 (file)
index 0000000..6221bfb
--- /dev/null
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_procotol.h
+ * \brief Common header defining the communication code used in the RPC protocol.
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_PROTOCOL_H_
+#define TVM_RUNTIME_RPC_RPC_PROTOCOL_H_
+
+namespace tvm {
+namespace runtime {
+
+/*! \brief The current RPC procotol version. */
+constexpr const char* kRPCProtocolVer = "0.7.0";
+
+/*! \brief The RPC code */
+enum class RPCCode : int {
+  kNone,
+  kShutdown,
+  kInitServer,
+  kCallFunc,
+  kReturn,
+  kException,
+  kCopyFromRemote,
+  kCopyToRemote,
+  kCopyAck,
+  // The following are syscall code that can send over CallRemote
+  kSyscallCodeStart,
+  kGetGlobalFunc = kSyscallCodeStart,
+  kFreeHandle,
+  kDevSetDevice,
+  kDevGetAttr,
+  kDevAllocData,
+  kDevFreeData,
+  kDevStreamSync,
+  kCopyAmongRemote,
+};
+
+/*!
+ * \brief List of potential error status during rpc communication.
+ */
+enum class RPCServerStatus : int {
+  kSuccess = 0,
+  kInvalidTypeCodeObject,
+  kInvalidTypeCodeNDArray,
+  kInvalidDLTensorFieldStride,
+  kInvalidDLTensorFieldByteOffset,
+  kUnknownTypeCode,
+  kUnknownRPCCode,
+  kRPCCodeNotSupported,
+  kUnknownRPCSyscall,
+  kCheckError,
+  kReadError,
+  kWriteError,
+  kAllocError
+};
+
+/*!
+ * \brief Convert RPC server status to string.
+ * \param status The status.
+ * \return The corresponding string.
+ */
+inline const char* RPCServerStatusToString(RPCServerStatus status) {
+  switch (status) {
+    case RPCServerStatus::kSuccess: return "kSuccess";
+    case RPCServerStatus::kInvalidTypeCodeObject: return "kInvalidTypeCodeObject";
+    case RPCServerStatus::kInvalidTypeCodeNDArray: return "kInvalidTypeCodeNDArray";
+    case RPCServerStatus::kInvalidDLTensorFieldStride: return "kInvalidDLTensorFieldStride";
+    case RPCServerStatus::kInvalidDLTensorFieldByteOffset: {
+      return "kInvalidDLTensorFieldByteOffset";
+    }
+    case RPCServerStatus::kUnknownTypeCode: return "kUnknownTypeCode";
+    case RPCServerStatus::kUnknownRPCCode: return "kUnknownRPCCode";
+    case RPCServerStatus::kRPCCodeNotSupported: return "RPCCodeNotSupported";
+    case RPCServerStatus::kUnknownRPCSyscall: return "kUnknownRPCSyscall";
+    case RPCServerStatus::kCheckError: return "kCheckError";
+    case RPCServerStatus::kReadError: return "kReadError";
+    case RPCServerStatus::kWriteError: return "kWriteError";
+    case RPCServerStatus::kAllocError: return "kAllocError";
+    default: return "";
+  }
+}
+
+/*!
+ * \brief Reference implementation of the communication protocol.
+ *
+ * \note The implementation is intentionally written via template
+ *       so it can be used in a dependency free setting.
+ *
+ * \sa src/runtime/rpc/device/min_rpc_server.h
+ */
+struct RPCReference {
+  /*!
+   * \brief Auxiliary class to get the packed sequence.
+   * \tparam TChannel The channel to throw errror.
+   */
+  template<typename TChannel>
+  struct PackedSeqNumBytesGetter {
+   public:
+    explicit PackedSeqNumBytesGetter(TChannel* channel)
+        : channel_(channel) {}
+
+    template <typename T>
+    void Write(const T& value) {
+      num_bytes_ += sizeof(T);
+    }
+
+    template <typename T>
+    void WriteArray(const T* value, size_t num) {
+      num_bytes_ += sizeof(T) * num;
+    }
+
+    void ThrowError(RPCServerStatus status) {
+      channel_->ThrowError(status);
+    }
+
+    uint64_t num_bytes() const {
+      return num_bytes_;
+    }
+
+   private:
+    TChannel* channel_;
+    uint64_t num_bytes_{0};
+  };
+
+  /*!
+   * \return the length of the str.
+   * \param str the string.
+   * \return The length.
+   */
+  static uint64_t StrLength(const char* str) {
+    uint64_t len = 0;
+    while (str[len] != '\0') ++len;
+    return len;
+  }
+
+  /*!
+   * \brief Get the total nbytes to be sent in the packed sequence.
+   *
+   * \param arg_values The values to be sent over.
+   * \param type_codes The type codes to be sent over.
+   * \param num_args Number of argument.
+   * \param client_mode Whether it is a client to server call.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   * \return The total number of bytes.
+   */
+  template<typename TChannel>
+  static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values,
+                                       const int* type_codes,
+                                       int num_args,
+                                       bool client_mode,
+                                       TChannel* channel) {
+    PackedSeqNumBytesGetter<TChannel> getter(channel);
+    SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter);
+    return getter.num_bytes();
+  }
+
+  /*!
+   * \brief Send packed argument sequnce to the other peer.
+   *
+   * This function serves as the foundational communication primitive between peers.
+   *
+   * TVMValue sequence encoding protocol(according to the type):
+   *
+   * - int/float/uint/bytes/str: Serialize all content.
+   * - DLTensor: send meta-data, send data handle as opaque handle(via uint64_t)
+   * - OpaqueHandle: send as uint64_t
+   * - ModuleHandle, PackedFuncHandle: send as uint64_t,
+   *   The support to Module/PackedFuncHandle are reserved for arguments
+   *   in the CallFunc from a client to server only.
+   *   Note that we cannot simply take these argument out(as the handle)
+   *   refers to a value on the remote(instead of local).
+   *
+   * \param arg_values The values to be sent over.
+   * \param type_codes The type codes to be sent over.
+   * \param num_args Number of argument.
+   * \param client_mode Whether it is a client to server call.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   */
+  template<typename TChannel>
+  static void SendPackedSeq(const TVMValue* arg_values,
+                            const int* type_codes,
+                            int num_args,
+                            bool client_mode,
+                            TChannel* channel) {
+    channel->Write(num_args);
+    channel->WriteArray(type_codes, num_args);
+
+    // Argument packing.
+    for (int i = 0; i < num_args; ++i) {
+      int tcode = type_codes[i];
+      TVMValue value = arg_values[i];
+      switch (tcode) {
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat: {
+          channel->template Write<int64_t>(value.v_int64);
+          break;
+        }
+        case kTVMDataType: {
+          channel->Write(value.v_type);
+          // padding
+          int32_t padding = 0;
+          channel->template Write<int32_t>(padding);
+          break;
+        }
+        case kTVMContext: {
+          channel->Write(value.v_ctx);
+          break;
+        }
+
+        case kTVMPackedFuncHandle:
+        case kTVMModuleHandle: {
+          if (!client_mode) {
+            channel->ThrowError(RPCServerStatus::kInvalidTypeCodeObject);
+          }
+          // always send handle in 64 bit.
+          uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
+          channel->Write(handle);
+          break;
+        }
+        case kTVMOpaqueHandle: {
+          // always send handle in 64 bit.
+          uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
+          channel->Write(handle);
+          break;
+        }
+        case kTVMNDArrayHandle: {
+          channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray);
+          break;
+        }
+        case kTVMDLTensorHandle: {
+          DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
+          TVMContext ctx;
+          uint64_t data;
+          // When we return NDArray, we directly return
+          // the space and the context
+          // The client will be further wrapping
+          ctx = arr->ctx;
+          data = reinterpret_cast<uint64_t>(arr->data);
+          channel->Write(data);
+          channel->Write(ctx);
+          channel->Write(arr->ndim);
+          channel->Write(arr->dtype);
+          channel->WriteArray(arr->shape, arr->ndim);
+          if (arr->strides != nullptr) {
+            channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride);
+          }
+          if (arr->byte_offset != 0) {
+            channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldByteOffset);
+          }
+          break;
+        }
+        case kTVMNullptr: break;
+        case kTVMStr: {
+          const char* s = value.v_str;
+          uint64_t len = StrLength(s);
+          channel->Write(len);
+          channel->WriteArray(s, len);
+          break;
+        }
+        case kTVMBytes: {
+          TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
+          uint64_t len = bytes->size;
+          channel->Write(len);
+          channel->WriteArray(bytes->data, len);
+          break;
+        }
+        default: {
+          channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
+          break;
+        }
+      }
+    }
+  }
+
+  /*!
+   * \brief Receive packed seq from the channel.
+   *
+   * \param out_arg_values The values to be received.
+   * \param out_tcodes The type codes to be received.
+   * \param out_num_args Number of argument.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   * \note The temporary space are populated via an arena inside channel.
+   */
+  template<typename TChannel>
+  static void RecvPackedSeq(TVMValue** out_values,
+                            int** out_tcodes,
+                            int* out_num_args,
+                            TChannel* channel) {
+    // receive number of args
+    int num_args;
+    channel->Read(&num_args);
+    *out_num_args = num_args;
+
+    if (num_args == 0) {
+      *out_values = nullptr;
+      *out_tcodes = nullptr;
+      return;
+    }
+
+    TVMValue* values = channel->template ArenaAlloc<TVMValue>(num_args);
+    int* tcodes = channel->template ArenaAlloc<int>(num_args);
+    *out_values = values;
+    *out_tcodes = tcodes;
+
+    // receive type code.
+    channel->ReadArray(tcodes, num_args);
+
+    // receive arguments
+    for (int i = 0; i < num_args; ++i) {
+      auto& value = values[i];
+      switch (tcodes[i]) {
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat: {
+          channel->template Read<int64_t>(&(value.v_int64));
+          break;
+        }
+        case kTVMDataType: {
+          channel->Read(&(value.v_type));
+          int32_t padding = 0;
+          channel->template Read<int32_t>(&padding);
+          break;
+        }
+        case kTVMContext: {
+          channel->Read(&(value.v_ctx));
+          break;
+        }
+        case kTVMPackedFuncHandle:
+        case kTVMModuleHandle:
+        case kTVMOpaqueHandle: {
+          // always send handle in 64 bit.
+          uint64_t handle;
+          channel->Read(&handle);
+          value.v_handle = reinterpret_cast<void*>(handle);
+          break;
+        }
+        case kTVMNullptr: {
+          value.v_handle = nullptr;
+          break;
+        }
+        case kTVMStr: {
+          uint64_t len;
+          channel->Read(&len);
+          char* str = channel->template ArenaAlloc<char>(len + 1);
+          str[len] = '\0';
+          channel->ReadArray(str, len);
+          value.v_str = str;
+          break;
+        }
+        case kTVMBytes: {
+          uint64_t len;
+          channel->Read(&len);
+          TVMByteArray* arr = channel->template ArenaAlloc<TVMByteArray>(1);
+          char* data = channel->template ArenaAlloc<char>(len);
+          arr->size = len;
+          arr->data = data;
+          channel->ReadArray(data, len);
+          value.v_handle = arr;
+          break;
+        }
+        case kTVMDLTensorHandle: {
+          uint64_t handle;
+          channel->Read(&handle);
+          DLTensor* arr = channel->template ArenaAlloc<DLTensor>(1);
+          DLTensor& tensor = *arr;
+          tensor.data = reinterpret_cast<void*>(handle);
+          channel->Read(&(tensor.ctx));
+          channel->Read(&(tensor.ndim));
+          channel->Read(&(tensor.dtype));
+          tensor.shape = channel->template ArenaAlloc<int64_t>(tensor.ndim);
+          channel->ReadArray(tensor.shape, tensor.ndim);
+          tensor.strides = nullptr;
+          tensor.byte_offset = 0;
+          value.v_handle = arr;
+          break;
+        }
+        default: {
+          channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
+          break;
+        }
+      }
+    }
+  }
+
+  /*!
+   * \brief Return an exception packet.
+   *
+   * \param msg The error message.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   */
+  template<typename TChannel>
+  static void ReturnException(const char* msg, TChannel* channel) {
+    RPCCode code = RPCCode::kException;
+    int32_t num_args = 1;
+    int32_t tcode = kTVMStr;
+    uint64_t len = StrLength(msg);
+
+    uint64_t packet_nbytes =
+        sizeof(code) +
+        sizeof(num_args) +
+        sizeof(tcode) +
+        sizeof(len) +
+        len;
+
+    channel->Write(packet_nbytes);
+    channel->Write(code);
+    channel->Write(num_args);
+    channel->Write(tcode);
+    channel->Write(len);
+    channel->WriteArray(msg, len);
+  }
+
+  /*!
+   * \brief Return a normal packed sequence packet.
+   *
+   * \param msg The error message.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   */
+  template<typename TChannel>
+  static void ReturnPackedSeq(const TVMValue* arg_values,
+                              const int* type_codes,
+                              int num_args,
+                              TChannel* channel) {
+    RPCCode code = RPCCode::kReturn;
+
+    uint64_t packet_nbytes =
+        sizeof(code) +
+        PackedSeqGetNumBytes(
+            arg_values, type_codes, num_args, false, channel);
+
+    channel->Write(packet_nbytes);
+    channel->Write(code);
+    SendPackedSeq(
+        arg_values, type_codes, num_args, false, channel);
+  }
+
+  /*!
+   * \brief Return a null(void) packet.
+   *
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   */
+  template<typename TChannel>
+  static void ReturnVoid(TChannel* channel) {
+    int32_t num_args = 1;
+    int32_t tcode = kTVMNullptr;
+    RPCCode code = RPCCode::kReturn;
+
+    uint64_t packet_nbytes =
+        sizeof(code) +
+        sizeof(num_args) +
+        sizeof(tcode);
+
+    channel->Write(packet_nbytes);
+    channel->Write(code);
+    channel->Write(num_args);
+    channel->Write(tcode);
+  }
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_RPC_RPC_PROTOCOL_H_
index f6a7fb6..612ca41 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -28,7 +28,8 @@ namespace tvm {
 namespace runtime {
 
 std::string RPCGetPath(const std::string& name) {
-  static const PackedFunc* f =
+  // do live lookup everytime as workpath can change.
+  const PackedFunc* f =
       runtime::Registry::Get("tvm.rpc.server.workpath");
   CHECK(f != nullptr) << "require tvm.rpc.server.workpath";
   return (*f)(name);
index ae293ab..dd0afa0 100644 (file)
  * \file rpc_session.cc
  * \brief RPC session for remote function call.
  */
-#include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/device_api.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/serializer.h>
-#include <memory>
+#include <mutex>
 #include <array>
-#include <string>
-#include <chrono>
-#include <vector>
-#include <utility>
-#include <cmath>
-#include <algorithm>
 #include "rpc_session.h"
-#include "../object_internal.h"
-#include "../../support/ring_buffer.h"
-#include "../../support/socket.h"
-#include "../micro/micro_session.h"
 
 namespace tvm {
 namespace runtime {
 
-// Temp buffer for data array
-struct RPCByteArrayBuffer {
-  TVMByteArray arr;
-  std::string data;
-};
-// Temp buffer for data array
-struct RPCDataArrayBuffer {
-  DLTensor tensor;
-  std::vector<int64_t> shape;
-};
-/*!
- * \brief Temporal argument buffer.
- */
-struct RPCArgBuffer {
-  // The argument values
-  std::vector<TVMValue> value;
-  // The type codes.
-  std::vector<int> tcode;
-  // Temporal resources.
-  std::vector<std::unique_ptr<RPCByteArrayBuffer> > temp_bytes;
-  // Temporal array
-  std::vector<std::unique_ptr<RPCDataArrayBuffer> > temp_array;
-  // convert buffer as TVMArgs
-  TVMArgs AsTVMArgs() const {
-    return TVMArgs(value.data(), tcode.data(), static_cast<int>(value.size()));
-  }
-};
-
-// Event handler for RPC events.
-class RPCSession::EventHandler : public dmlc::Stream {
- public:
-  EventHandler(support::RingBuffer* reader,
-               support::RingBuffer* writer,
-               int rpc_sess_table_index,
-               std::string name,
-               std::string* remote_key)
-      : reader_(reader),
-        writer_(writer),
-        rpc_sess_table_index_(rpc_sess_table_index),
-        name_(name),
-        remote_key_(remote_key) {
-    this->Clear();
-    if (*remote_key == "%toinit") {
-      state_ = kInitHeader;
-      remote_key_->resize(0);
-      pending_request_bytes_ = sizeof(int32_t);
-    }
-  }
-  // Bytes needed to fulfill current request
-  size_t BytesNeeded() {
-    if (reader_->bytes_available() < pending_request_bytes_) {
-      return pending_request_bytes_ - reader_->bytes_available();
-    } else {
-      return 0;
-    }
-  }
-  // Request number of bytes from reader.
-  void RequestBytes(size_t nbytes) {
-    pending_request_bytes_ += nbytes;
-    reader_->Reserve(pending_request_bytes_);
-  }
-  // Whether we are ready to handle next request.
-  bool Ready() {
-    return reader_->bytes_available() >= pending_request_bytes_;
-  }
-  bool CanCleanShutdown() const {
-    return state_ == kRecvCode;
-  }
-  void FinishCopyAck() {
-    this->SwitchToState(kRecvCode);
-  }
-  RPCCode HandleNextEvent(TVMRetValue* rv,
-                          bool client_mode,
-                          const PackedFunc* fwrap) {
-    std::swap(client_mode_, client_mode);
-    while (this->Ready()) {
-      switch (state_) {
-        case kInitHeader: HandleInitHeader(); break;
-        case kRecvCode: HandleRecvCode(); break;
-        case kRecvCallHandle: {
-          CHECK(this->Read(&call_handle_));
-          this->SwitchToState(kRecvPackedSeqNumArgs);
-          break;
-        }
-        case kRecvPackedSeqNumArgs: {
-          CHECK(this->Read(&num_packed_args_));
-          arg_buf_.reset(new RPCArgBuffer());
-          arg_buf_->value.resize(num_packed_args_);
-          arg_buf_->tcode.resize(num_packed_args_);
-          this->SwitchToState(kRecvPackedSeqTypeCode);
-          break;
-        }
-        case kRecvPackedSeqTypeCode: {
-          if (num_packed_args_ != 0) {
-            this->ReadArray(arg_buf_->tcode.data(), num_packed_args_);
-          }
-          arg_index_ = 0;
-          arg_recv_stage_ = 0;
-          this->SwitchToState(kRecvPackedSeqArg);
-          break;
-        }
-        case kRecvPackedSeqArg: {
-          this->HandleRecvPackedSeqArg();
-          break;
-        }
-        case kDoCopyFromRemote: {
-          this->HandleCopyFromRemote();
-          break;
-        }
-        case kDoCopyToRemote: {
-          this->HandleCopyToRemote();
-          break;
-        }
-        case kReturnReceived: {
-          CHECK_GE(arg_buf_->value.size(), 1U);
-
-          TVMArgValue argv = arg_buf_->AsTVMArgs()[0];
-          if (argv.type_code() == kTVMPackedFuncHandle ||
-              argv.type_code() == kTVMModuleHandle ||
-              argv.type_code() == kTVMDLTensorHandle) {
-            CHECK(fwrap != nullptr) << "function/module wrapper not available";
-            fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv);
-          } else {
-            CHECK_EQ(arg_buf_->value.size(), 1U);
-            *rv = argv;
-          }
-          arg_buf_.reset();
-          this->SwitchToState(kRecvCode);
-          std::swap(client_mode_, client_mode);
-          return RPCCode::kReturn;
-        }
-        case kCopyAckReceived: {
-          std::swap(client_mode_, client_mode);
-          return RPCCode::kCopyAck;
-        }
-        case kShutdownReceived: {
-          std::swap(client_mode_, client_mode);
-          return RPCCode::kShutdown;
-        }
-      }
-    }
-    std::swap(client_mode_, client_mode);
-    return RPCCode::kNone;
-  }
-  // Reset and clear all states.
-  void Clear() {
-    state_ = kRecvCode;
-    pending_request_bytes_ = sizeof(RPCCode);
-    arg_recv_stage_ = 0;
-    arg_buf_.reset();
-  }
-  // strip session on mask
-  TVMContext StripSessMask(TVMContext ctx) {
-    int dev_type = ctx.device_type;
-    CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1)
-        << "Can not pass in local context or context with a different remote session";
-    ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask);
-    return ctx;
-  }
-  // Send Packed sequence to writer.
-  //
-  // client_mode: whether we are in client mode.
-  //
-  // funwrap: auxiliary function to unwrap remote Object
-  //          when it is provided, we need to unwrap objects.
-  //
-  // return_ndarray is a special flag to handle returning of ndarray
-  //    In this case, we return the shape, context and data of the array,
-  //    as well as a customized PackedFunc that handles deletion of
-  //    the array in the remote.
-  void SendPackedSeq(const TVMValue* arg_values,
-                     const int* type_codes,
-                     int num_args,
-                     bool client_mode,
-                     FUnwrapRemoteObject funwrap = nullptr,
-                     bool return_ndarray = false) {
-    std::swap(client_mode_, client_mode);
-
-    this->Write(num_args);
-    for (int i = 0; i < num_args; ++i) {
-      int tcode = type_codes[i];
-      if (tcode == kTVMNDArrayHandle) tcode = kTVMDLTensorHandle;
-      this->Write(tcode);
-    }
-
-    // Argument packing.
-    for (int i = 0; i < num_args; ++i) {
-      int tcode = type_codes[i];
-      TVMValue value = arg_values[i];
-      switch (tcode) {
-        case kDLInt:
-        case kDLUInt:
-        case kDLFloat: {
-          this->Write<int64_t>(value.v_int64);
-          break;
-        }
-        case kTVMDataType: {
-          this->Write(value.v_type);
-          // padding
-          int32_t padding = 0;
-          this->Write<int32_t>(padding);
-          break;
-        }
-        case kTVMContext: {
-          value.v_ctx = StripSessMask(value.v_ctx);
-          this->Write(value.v_ctx);
-          break;
-        }
-        case kTVMPackedFuncHandle:
-        case kTVMModuleHandle: {
-          // always send handle in 64 bit.
-          uint64_t handle;
-          // allow pass module as argument to remote.
-          if (funwrap != nullptr) {
-            void* remote_handle = (*funwrap)(
-                rpc_sess_table_index_,
-                runtime::TVMArgValue(value, tcode));
-            handle = reinterpret_cast<uint64_t>(remote_handle);
-          } else {
-            CHECK(!client_mode_)
-                << "Cannot directly pass remote object as argument";
-            handle = reinterpret_cast<uint64_t>(value.v_handle);
-          }
-          this->Write(handle);
-          break;
-        }
-        case kTVMOpaqueHandle: {
-          // always send handle in 64 bit.
-          uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
-          this->Write(handle);
-          break;
-        }
-        case kTVMNDArrayHandle:
-        case kTVMDLTensorHandle: {
-          DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
-          TVMContext ctx;
-          uint64_t data;
-          if (!return_ndarray) {
-            // in the client mode
-            // ctx contains the remote table index
-            // the space is wrapped by an RemoteSpace
-            // that holds reference to the session.
-            ctx = StripSessMask(arr->ctx);
-            data = reinterpret_cast<uint64_t>(
-                static_cast<RemoteSpace*>(arr->data)->data);
-          } else {
-            // When we return NDArray, we directly return
-            // the space and the context
-            // The client will be further wrapping
-            ctx = arr->ctx;
-            data = reinterpret_cast<uint64_t>(arr->data);
-          }
-          this->Write(data);
-          this->Write(ctx);
-          this->Write(arr->ndim);
-          this->Write(arr->dtype);
-          this->WriteArray(arr->shape, arr->ndim);
-          CHECK(arr->strides == nullptr)
-              << "Do not support strided remote array";
-          CHECK_EQ(arr->byte_offset, 0)
-              << "Do not support send byte offset";
-          break;
-        }
-        case kTVMNullptr: break;
-        case kTVMStr: {
-          const char* s = value.v_str;
-          uint64_t len = strlen(s);
-          this->Write(len);
-          this->WriteArray(s, len);
-          break;
-        }
-        case kTVMBytes: {
-          TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
-          uint64_t len = bytes->size;
-          this->Write(len);
-          this->WriteArray(bytes->data, len);
-          break;
-        }
-        default: {
-          LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
-          break;
-        }
-      }
-    }
-    std::swap(client_mode_, client_mode);
-  }
-
-  // Endian aware IO handling
-  using Stream::Read;
-  using Stream::Write;
-  using Stream::ReadArray;
-  using Stream::WriteArray;
-
-  inline bool Read(RPCCode* code) {
-    int cdata;
-    if (!this->Read(&cdata)) return false;
-    *code = static_cast<RPCCode>(cdata);
-    return true;
-  }
-  inline void Write(RPCCode code) {
-    int cdata = static_cast<int>(code);
-    this->Write(cdata);
-  }
-
- protected:
-  enum State {
-    kInitHeader,
-    kRecvCode,
-    kRecvCallHandle,
-    kRecvPackedSeqNumArgs,
-    kRecvPackedSeqTypeCode,
-    kRecvPackedSeqArg,
-    kDoCopyFromRemote,
-    kDoCopyToRemote,
-    kReturnReceived,
-    kCopyAckReceived,
-    kShutdownReceived
-  };
-  // Current state;
-  State state_;
-  // The RPCCode to be read.
-  RPCCode code_;
-  // Handle for the remote function call.
-  uint64_t call_handle_;
-  // Initialize remote header
-  bool init_header_step_{0};
-  // Number of packed arguments.
-  int num_packed_args_;
-  // Current argument index.
-  int arg_index_;
-  // The stage of each argument receiver.
-  int arg_recv_stage_;
-  // Whether current handler is client or server mode.
-  bool client_mode_{false};
-  // Argument buffer
-  std::unique_ptr<RPCArgBuffer> arg_buf_;
-  // Temp byte buffer.
-  std::unique_ptr<RPCByteArrayBuffer> temp_bytes_;
-  // Temp array buffer.
-  std::unique_ptr<RPCDataArrayBuffer> temp_array_;
-  // Internal temporal data space.
-  std::string temp_data_;
-  // Temp variables for copy request state.
-  TVMContext copy_ctx_;
-  DLDataType copy_dtype_;
-  uint64_t copy_handle_, copy_offset_, copy_size_;
-  // State switcher
-  void SwitchToState(State state) {
-    // invariant
-    CHECK_EQ(pending_request_bytes_, 0U)
-        << "state=" << state;
-    state_ = state;
-    switch (state) {
-      case kInitHeader: {
-        LOG(FATAL) << "cannot switch to init header";
-        break;
-      }
-      case kRecvCode: {
-        this->RequestBytes(sizeof(RPCCode));
-        break;
-      }
-      case kRecvCallHandle: {
-        this->RequestBytes(sizeof(call_handle_));
-        break;
-      }
-      case kRecvPackedSeqNumArgs: {
-        this->RequestBytes(sizeof(num_packed_args_));
-        break;
-      }
-      case kRecvPackedSeqTypeCode: {
-        this->RequestBytes(sizeof(int) * num_packed_args_);
-        break;
-      }
-      case kRecvPackedSeqArg: {
-        CHECK_LE(arg_index_, num_packed_args_);
-        if (arg_index_ == num_packed_args_) {
-          // The function can change state_ again.
-          HandlePackedCall();
-        } else {
-          RequestRecvPackedSeqArg();
-        }
-        break;
-      }
-      case kDoCopyFromRemote: {
-        this->RequestBytes(sizeof(uint64_t) * 3);
-        this->RequestBytes(sizeof(TVMContext));
-        this->RequestBytes(sizeof(DLDataType));
-        break;
-      }
-      case kDoCopyToRemote: {
-        this->RequestBytes(sizeof(uint64_t) * 3);
-        this->RequestBytes(sizeof(TVMContext));
-        this->RequestBytes(sizeof(DLDataType));
-        break;
-      }
-      case kCopyAckReceived:
-      case kReturnReceived:
-      case kShutdownReceived: {
-        break;
-      }
-    }
-  }
-  // Requets bytes needed for next computation.
-  void RequestRecvPackedSeqArg() {
-    CHECK_EQ(arg_recv_stage_, 0);
-    int tcode = arg_buf_->tcode[arg_index_];
-    static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant");
-    switch (tcode) {
-      case kDLInt:
-      case kDLUInt:
-      case kDLFloat:
-      case kTVMDataType:
-      case kTVMOpaqueHandle:
-      case kTVMStr:
-      case kTVMBytes:
-      case kTVMModuleHandle:
-      case kTVMContext: {
-        this->RequestBytes(sizeof(TVMValue)); break;
-      }
-      case kTVMPackedFuncHandle: {
-        CHECK(client_mode_)
-            << "Only client can receive remote functions";
-        this->RequestBytes(sizeof(TVMValue)); break;
-      }
-      case kTVMNullptr: break;
-      case kTVMDLTensorHandle: {
-        this->RequestBytes(sizeof(uint64_t));
-        this->RequestBytes(sizeof(TVMContext));
-        this->RequestBytes(sizeof(int));
-        this->RequestBytes(sizeof(DLDataType));
-        break;
-      }
-      default: {
-        LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
-        break;
-      }
-    }
-  }
-  // Handler for packed sequence argument receive.
-  void HandleRecvPackedSeqArg() {
-    CHECK_LT(arg_index_, num_packed_args_);
-    int tcode = arg_buf_->tcode[arg_index_];
-    TVMValue& value = arg_buf_->value[arg_index_];
-    if (arg_recv_stage_ == 0) {
-      switch (tcode) {
-        case kDLInt:
-        case kDLUInt:
-        case kDLFloat: {
-          this->Read<int64_t>(&(value.v_int64));
-          ++arg_index_;
-          this->SwitchToState(kRecvPackedSeqArg);
-          break;
-        }
-        case kTVMDataType: {
-          this->Read(&(value.v_type));
-          int32_t padding = 0;
-          this->Read<int32_t>(&padding);
-          ++arg_index_;
-          this->SwitchToState(kRecvPackedSeqArg);
-          break;
-        }
-        case kTVMContext: {
-          this->Read(&(value.v_ctx));
-          ++arg_index_;
-          this->SwitchToState(kRecvPackedSeqArg);
-          break;
-        }
-        case kTVMPackedFuncHandle:
-        case kTVMModuleHandle:
-        case kTVMOpaqueHandle: {
-          // always send handle in 64 bit.
-          uint64_t handle;
-          this->Read(&handle);
-          value.v_handle = reinterpret_cast<void*>(handle);
-          ++arg_index_;
-          this->SwitchToState(kRecvPackedSeqArg);
-          break;
-        }
-        case kTVMNullptr: {
-          value.v_handle = nullptr;
-          ++arg_index_;
-          this->SwitchToState(kRecvPackedSeqArg);
-          break;
-        }
-        case kTVMStr:
-        case kTVMBytes: {
-          uint64_t len;
-          this->Read(&len);
-          temp_bytes_.reset( new RPCByteArrayBuffer());
-          temp_bytes_->data.resize(len);
-          arg_recv_stage_ = 1;
-          this->RequestBytes(len);
-          break;
-        }
-        case kTVMDLTensorHandle: {
-          temp_array_.reset(new RPCDataArrayBuffer());
-          uint64_t handle;
-          this->Read(&handle);
-          DLTensor& tensor = temp_array_->tensor;
-          tensor.data = reinterpret_cast<void*>(handle);
-          this->Read(&(tensor.ctx));
-          this->Read(&(tensor.ndim));
-          this->Read(&(tensor.dtype));
-          temp_array_->shape.resize(tensor.ndim);
-          tensor.shape = temp_array_->shape.data();
-          arg_recv_stage_ = 1;
-          tensor.strides = nullptr;
-          tensor.byte_offset = 0;
-          this->RequestBytes(sizeof(int64_t) * tensor.ndim);
-          break;
-        }
-        default: {
-          LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
-          break;
-        }
-      }
-    } else {
-      CHECK_EQ(arg_recv_stage_, 1);
-      if (tcode == kTVMStr || tcode == kTVMBytes) {
-        if (temp_bytes_->data.size() != 0) {
-          this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size());
-        }
-        if (tcode == kTVMStr) {
-          value.v_str = temp_bytes_->data.c_str();
-        } else {
-          temp_bytes_->arr.size = static_cast<size_t>(temp_bytes_->data.size());
-          temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data);
-          value.v_handle = &(temp_bytes_->arr);
-        }
-        arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_));
-      } else {
-        CHECK_EQ(tcode, kTVMDLTensorHandle);
-        DLTensor& tensor = temp_array_->tensor;
-        this->ReadArray(tensor.shape, tensor.ndim);
-        value.v_handle = &tensor;
-        arg_buf_->temp_array.emplace_back(std::move(temp_array_));
-      }
-      ++arg_index_;
-      arg_recv_stage_ = 0;
-      this->SwitchToState(kRecvPackedSeqArg);
-    }
-  }
-  // handler for initial header read
-  void HandleInitHeader() {
-    if (init_header_step_ == 0) {
-      int32_t len;
-      this->Read(&len);
-      remote_key_->resize(len);
-      init_header_step_ = 1;
-      this->RequestBytes(len);
-      return;
-    } else {
-      CHECK_EQ(init_header_step_, 1);
-      this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
-      this->SwitchToState(kRecvCode);
-    }
-  }
-  // Handler for read code.
-  void HandleRecvCode() {
-    this->Read(&code_);
-    if (code_ > RPCCode::kSystemFuncStart) {
-      SwitchToState(kRecvPackedSeqNumArgs);
-      return;
-    }
-    // invariant.
-    CHECK_EQ(arg_recv_stage_, 0);
-    switch (code_) {
-      case RPCCode::kCallFunc: {
-        SwitchToState(kRecvCallHandle);
-        break;
-      }
-      case RPCCode::kException:
-      case RPCCode::kReturn: {
-        SwitchToState(kRecvPackedSeqNumArgs);
-        break;
-      }
-      case RPCCode::kCopyFromRemote: {
-        SwitchToState(kDoCopyFromRemote);
-        break;
-      }
-      case RPCCode::kCopyToRemote: {
-        SwitchToState(kDoCopyToRemote);
-        break;
-      }
-      case RPCCode::kShutdown: {
-        SwitchToState(kShutdownReceived);
-        break;
-      }
-      case RPCCode::kCopyAck: {
-        SwitchToState(kCopyAckReceived);
-        break;
-      }
-      default: LOG(FATAL) << "Unknown event "  << static_cast<int>(code_);
-    }
-  }
-
-  void HandleCopyFromRemote() {
-    uint64_t handle, offset, num_bytes;
-    TVMContext ctx;
-    DLDataType type_hint;
-    this->Read(&handle);
-    this->Read(&offset);
-    this->Read(&num_bytes);
-    this->Read(&ctx);
-    this->Read(&type_hint);
-    size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
-
-    if (ctx.device_type == kDLCPU) {
-      RPCCode code = RPCCode::kCopyAck;
-      this->Write(code);
-      char* dptr = reinterpret_cast<char*>(handle) + offset;
-      if (!DMLC_IO_NO_ENDIAN_SWAP) {
-        temp_data_.resize(0);
-        temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes);
-        dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
-        this->WriteArray(temp_data_.data(), num_bytes);
-      } else {
-        this->WriteArray(dptr, num_bytes);
-      }
-    } else {
-      temp_data_.resize(num_bytes + 1);
-      try {
-        TVMContext cpu_ctx;
-        cpu_ctx.device_type = kDLCPU;
-        cpu_ctx.device_id = 0;
-        DeviceAPI::Get(ctx)->CopyDataFromTo(
-            reinterpret_cast<void*>(handle), offset,
-            dmlc::BeginPtr(temp_data_), 0,
-            num_bytes, ctx, cpu_ctx, type_hint, nullptr);
-        RPCCode code = RPCCode::kCopyAck;
-        this->Write(code);
-        if (!DMLC_IO_NO_ENDIAN_SWAP) {
-          dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
-        }
-        this->WriteArray(&temp_data_[0], num_bytes);
-      } catch (const std::runtime_error &e) {
-        RPCCode code = RPCCode::kException;
-        this->Write(code);
-        TVMValue ret_value;
-        ret_value.v_str = e.what();
-        int ret_tcode = kTVMStr;
-        SendPackedSeq(&ret_value, &ret_tcode, 1, false);
-      }
-    }
-    this->SwitchToState(kRecvCode);
-  }
-
-  void HandleCopyToRemote() {
-    // use static variable to persist state.
-    // This only works if next stage is immediately after this.
-    if (arg_recv_stage_ == 0) {
-      CHECK(this->Read(&copy_handle_));
-      CHECK(this->Read(&copy_offset_));
-      CHECK(this->Read(&copy_size_));
-      CHECK(this->Read(&copy_ctx_));
-      CHECK(this->Read(&copy_dtype_));
-      arg_recv_stage_ = 1;
-      CHECK_EQ(pending_request_bytes_, 0U);
-      this->RequestBytes(copy_size_);
-    } else {
-      CHECK_EQ(arg_recv_stage_, 1);
-      TVMValue ret_value;
-      ret_value.v_handle = nullptr;
-      int ret_tcode = kTVMNullptr;
-      RPCCode code = RPCCode::kReturn;
-      std::string errmsg;
-
-      size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8;
-      if (copy_ctx_.device_type == kDLCPU) {
-        char* dptr = reinterpret_cast<char*>(copy_handle_) + copy_offset_;
-        this->ReadArray(dptr, copy_size_);
-        if (!DMLC_IO_NO_ENDIAN_SWAP) {
-          dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes);
-        }
-      } else {
-        temp_data_.resize(copy_size_ + 1);
-        this->ReadArray(&temp_data_[0], copy_size_);
-        if (!DMLC_IO_NO_ENDIAN_SWAP) {
-          dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes);
-        }
-        try {
-          TVMContext cpu_ctx;
-          cpu_ctx.device_type = kDLCPU;
-          cpu_ctx.device_id = 0;
-          DeviceAPI::Get(copy_ctx_)->CopyDataFromTo(
-              temp_data_.data(), 0,
-              reinterpret_cast<void*>(copy_handle_), copy_offset_,
-              copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr);
-        } catch (const std::runtime_error &e) {
-          code = RPCCode::kException;
-          errmsg = e.what();
-          ret_value.v_str = errmsg.c_str();
-          ret_tcode = kTVMStr;
-        }
-      }
-      this->Write(code);
-      SendPackedSeq(&ret_value, &ret_tcode, 1, false);
-      arg_recv_stage_ = 0;
-      this->SwitchToState(kRecvCode);
-    }
-  }
-  // Handle for packed call.
-  void HandlePackedCall();
-
-  template<typename F>
-  void CallHandler(F f) {
-    TVMRetValue rv;
-    TVMValue ret_value;
-    int ret_tcode;
-    try {
-      // Need to move out, in case f itself need to call RecvPackedSeq
-      // Which will override argbuf again.
-      std::unique_ptr<RPCArgBuffer> args = std::move(arg_buf_);
-      f(args->AsTVMArgs(), &rv);
-      RPCCode code = RPCCode::kReturn;
-      this->Write(code);
-      if (rv.type_code() == kTVMStr) {
-        ret_value.v_str = rv.ptr<std::string>()->c_str();
-        ret_tcode = kTVMStr;
-        SendPackedSeq(&ret_value, &ret_tcode, 1, false);
-      } else if (rv.type_code() == kTVMBytes) {
-        std::string* bytes = rv.ptr<std::string>();
-        TVMByteArray arr;
-        arr.data = bytes->c_str();
-        arr.size = bytes->length();
-        ret_value.v_handle = &arr;
-        ret_tcode = kTVMBytes;
-        SendPackedSeq(&ret_value, &ret_tcode, 1, false);
-      } else if (rv.type_code() == kTVMPackedFuncHandle ||
-                 rv.type_code() == kTVMModuleHandle) {
-        // always send handle in 64 bit.
-        CHECK(!client_mode_)
-              << "Only server can send function and module handle back.";
-        rv.MoveToCHost(&ret_value, &ret_tcode);
-        SendPackedSeq(&ret_value, &ret_tcode, 1, false);
-      } else if (rv.type_code() == kTVMNDArrayHandle) {
-        // always send handle in 64 bit.
-        CHECK(!client_mode_)
-            << "Only server can send NDArray back";
-        // We follow a special protocol to return NDArray to client side
-        // The first pack value is the NDArray handle as DLTensor
-        // The second pack value is a customized deleter that deletes the NDArray.
-        TVMValue ret_value_pack[2];
-        int ret_tcode_pack[2];
-        rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);
-        ret_value_pack[1].v_handle = ret_value_pack[0].v_handle;
-        ret_tcode_pack[1] = kTVMOpaqueHandle;
-        SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true);
-      } else {
-        ret_value = rv.value();
-        ret_tcode = rv.type_code();
-        SendPackedSeq(&ret_value, &ret_tcode, 1, false);
-      }
-    } catch (const std::runtime_error& e) {
-      RPCCode code = RPCCode::kException;
-      this->Write(code);
-      ret_value.v_str = e.what();
-      ret_tcode = kTVMStr;
-      SendPackedSeq(&ret_value, &ret_tcode, 1, false);
-    }
-  }
-
- private:
-  // Utility functions
-  // Internal read function, update pending_request_bytes_
-  size_t Read(void* data, size_t size) final {
-    CHECK_LE(size, pending_request_bytes_);
-    reader_->Read(data, size);
-    pending_request_bytes_ -= size;
-    return size;
-  }
-  void Write(const void* data, size_t size) final {
-    writer_->Write(data, size);
-  }
-  // Number of pending bytes requests
-  size_t pending_request_bytes_;
-  // The ring buffer to read data from.
-  support::RingBuffer* reader_;
-  // The ringr buffer to write reply to.
-  support::RingBuffer* writer_;
-  // Session table index.
-  int rpc_sess_table_index_;
-  // Name of session.
-  std::string name_;
-  // remote key
-  std::string* remote_key_;
-};
-
-struct RPCSessTable {
+class RPCSessTable {
  public:
   static constexpr int kMaxRPCSession = 32;
   // Get global singleton
@@ -864,465 +63,13 @@ struct RPCSessTable {
   std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_;
 };
 
-RPCCode RPCSession::HandleUntilReturnEvent(
-    TVMRetValue* rv,  bool client_mode, const PackedFunc* fwrap) {
-  RPCCode code = RPCCode::kCallFunc;
-  while (code != RPCCode::kReturn &&
-         code != RPCCode::kShutdown &&
-         code != RPCCode::kCopyAck) {
-    while (writer_.bytes_available() != 0) {
-      writer_.ReadWithCallback([this](const void *data, size_t size) {
-          return channel_->Send(data, size);
-        }, writer_.bytes_available());
-    }
-    size_t bytes_needed = handler_->BytesNeeded();
-    if (bytes_needed != 0) {
-      size_t n = reader_.WriteWithCallback([this](void* data, size_t size) {
-          return channel_->Recv(data, size);
-        }, bytes_needed);
-      if (n == 0) {
-        if (handler_->CanCleanShutdown()) {
-          return RPCCode::kShutdown;
-        } else {
-          LOG(FATAL) << "Channel closes before we get neded bytes";
-        }
-      }
-    }
-    code = handler_->HandleNextEvent(rv, client_mode, fwrap);
-  }
-  return code;
-}
-
-void RPCSession::Init() {
-  // Event handler
-  handler_ = std::make_shared<EventHandler>(
-      &reader_, &writer_, table_index_, name_, &remote_key_);
-  // Quick function to call remote.
-  call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
-      handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true);
-      RPCCode code = HandleUntilReturnEvent(rv, true, nullptr);
-      CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
-    });
-}
-
-std::shared_ptr<RPCSession> RPCSession::Create(
-    std::unique_ptr<RPCChannel> channel,
-    std::string name,
-    std::string remote_key) {
-  std::shared_ptr<RPCSession> sess = std::make_shared<RPCSession>();
-  sess->channel_ = std::move(channel);
-  sess->name_ = std::move(name);
-  sess->remote_key_ = std::move(remote_key);
-  sess->table_index_ = RPCSessTable::Global()->Insert(sess);
-  sess->Init();
-  return sess;
-}
-
 std::shared_ptr<RPCSession> RPCSession::Get(int table_index) {
   return RPCSessTable::Global()->Get(table_index);
 }
 
-RPCSession::~RPCSession() {
-  this->Shutdown();
-}
-
-void RPCSession::Shutdown() {
-  if (channel_ != nullptr) {
-    RPCCode code = RPCCode::kShutdown;
-    handler_->Write(code);
-    // flush all writing buffer to output channel.
-    try {
-      while (writer_.bytes_available() != 0) {
-        size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
-            return channel_->Send(data, size);
-          }, writer_.bytes_available());
-        if (n == 0) break;
-      }
-    } catch (const dmlc::Error& e) {
-    }
-    channel_.reset(nullptr);
-  }
-}
-
-void RPCSession::ServerLoop() {
-  std::lock_guard<std::recursive_mutex> lock(mutex_);
-  if (const auto* f = Registry::Get("tvm.rpc.server.start")) {
-    (*f)();
-  }
-  TVMRetValue rv;
-  CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
-  if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) {
-    (*f)();
-  }
-  channel_.reset(nullptr);
-}
-
-int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
-  std::lock_guard<std::recursive_mutex> lock(mutex_);
-  RPCCode code = RPCCode::kNone;
-  if (bytes.length() != 0) {
-    reader_.Write(bytes.c_str(), bytes.length());
-    TVMRetValue rv;
-    code = handler_->HandleNextEvent(&rv, false, nullptr);
-  }
-  if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
-    writer_.ReadWithCallback([this](const void *data, size_t size) {
-        return channel_->Send(data, size);
-      }, writer_.bytes_available());
-  }
-  CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
-  if (code == RPCCode::kShutdown) return 0;
-  if (writer_.bytes_available() != 0) return 2;
-  return 1;
-}
-
-// Get remote function with name
-void RPCSession::CallFunc(void* h,
-                          TVMArgs args,
-                          TVMRetValue* rv,
-                          FUnwrapRemoteObject funwrap,
-                          const PackedFunc* fwrap) {
-  std::lock_guard<std::recursive_mutex> lock(mutex_);
-
-  RPCCode code = RPCCode::kCallFunc;
-  handler_->Write(code);
-  uint64_t handle = reinterpret_cast<uint64_t>(h);
-  handler_->Write(handle);
-  handler_->SendPackedSeq(
-      args.values, args.type_codes, args.num_args, true, funwrap);
-  code = HandleUntilReturnEvent(rv, true, fwrap);
-  CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
-}
-
-void RPCSession::CopyToRemote(void* from,
-                              size_t from_offset,
-                              void* to,
-                              size_t to_offset,
-                              size_t data_size,
-                              TVMContext ctx_to,
-                              DLDataType type_hint) {
-  std::lock_guard<std::recursive_mutex> lock(mutex_);
-  ctx_to = handler_->StripSessMask(ctx_to);
-  RPCCode code = RPCCode::kCopyToRemote;
-  handler_->Write(code);
-  uint64_t handle = reinterpret_cast<uint64_t>(to);
-  handler_->Write(handle);
-  uint64_t offset = static_cast<uint64_t>(to_offset);
-  handler_->Write(offset);
-  uint64_t size = static_cast<uint64_t>(data_size);
-  handler_->Write(size);
-  handler_->Write(ctx_to);
-  handler_->Write(type_hint);
-  handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size);
-  TVMRetValue rv;
-  CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn);
-}
-
-void RPCSession::CopyFromRemote(void* from,
-                                size_t from_offset,
-                                void* to,
-                                size_t to_offset,
-                                size_t data_size,
-                                TVMContext ctx_from,
-                                DLDataType type_hint) {
-  std::lock_guard<std::recursive_mutex> lock(mutex_);
-  ctx_from = handler_->StripSessMask(ctx_from);
-  RPCCode code = RPCCode::kCopyFromRemote;
-  handler_->Write(code);
-  uint64_t handle = reinterpret_cast<uint64_t>(from);
-  handler_->Write(handle);
-  uint64_t offset = static_cast<uint64_t>(from_offset);
-  handler_->Write(offset);
-  uint64_t size = static_cast<uint64_t>(data_size);
-  handler_->Write(size);
-  handler_->Write(ctx_from);
-  handler_->Write(type_hint);
-  TVMRetValue rv;
-  CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck);
-  reader_.Reserve(data_size);
-  handler_->RequestBytes(data_size);
-  while (!handler_->Ready()) {
-    size_t bytes_needed = handler_->BytesNeeded();
-    reader_.WriteWithCallback([this](void* data, size_t size) {
-        size_t n = channel_->Recv(data, size);
-        CHECK_NE(n, 0U) << "Channel closes before we get neded bytes";
-        return n;
-      }, bytes_needed);
-  }
-  handler_->ReadArray(reinterpret_cast<char*>(to) + to_offset, data_size);
-  handler_->FinishCopyAck();
-}
-
-RPCFuncHandle RPCSession::GetTimeEvaluator(
-    RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat, int min_repeat_ms) {
-  return this->CallRemote(
-      RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat, min_repeat_ms);
-}
-
-// Event handler functions
-void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) {
-  std::string name = args[0];
-  auto *fp = tvm::runtime::Registry::Get(name);
-  if (fp != nullptr) {
-    *rv = static_cast<void*>(new tvm::runtime::PackedFunc(*fp));
-  } else {
-    *rv = nullptr;
-  }
-}
-
-void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) {
-  void* handle = args[0];
-  delete static_cast<PackedFunc*>(handle);
-}
-
-void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) {
-  TVMContext ctx = args[0];
-  DeviceAPI::Get(ctx)->SetDevice(ctx);
-}
-
-void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) {
-  TVMContext ctx = args[0];
-  DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[1].operator int());
-  if (kind == kExist) {
-    DeviceAPI* api = DeviceAPI::Get(ctx, true);
-    if (api != nullptr) {
-      api->GetAttr(ctx, kind, rv);
-    } else {
-      *rv = 0;
-    }
-  } else {
-    DeviceAPI::Get(ctx)->GetAttr(
-        ctx, static_cast<DeviceAttrKind>(kind), rv);
-  }
-}
-
-void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) {
-  TVMContext ctx = args[0];
-  uint64_t nbytes = args[1];
-  uint64_t alignment = args[2];
-  DLDataType type_hint = args[3];
-  void* data = DeviceAPI::Get(ctx)->AllocDataSpace(
-      ctx, nbytes, alignment, type_hint);
-  *rv = data;
-}
-
-void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) {
-  TVMContext ctx = args[0];
-  void* ptr = args[1];
-  DeviceAPI::Get(ctx)->FreeDataSpace(ctx, ptr);
-}
-
-void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) {
-  TVMContext ctx = args[0];
-  TVMStreamHandle handle = args[1];
-  DeviceAPI::Get(ctx)->StreamSync(ctx, handle);
-}
-
-void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
-  void* from = args[0];
-  uint64_t from_offset = args[1];
-  void* to = args[2];
-  uint64_t to_offset = args[3];
-  uint64_t size = args[4];
-  TVMContext ctx_from = args[5];
-  TVMContext ctx_to = args[6];
-  DLDataType type_hint = args[7];
-  TVMStreamHandle stream = args[8];
-  TVMContext ctx = ctx_from;
-  if (ctx.device_type == kDLCPU) {
-    ctx = ctx_to;
-  } else {
-    CHECK(ctx_to.device_type == kDLCPU ||
-          ctx_to.device_type == ctx_from.device_type)
-        << "Can not copy across different ctx types directly";
-  }
-  DeviceAPI::Get(ctx)->CopyDataFromTo(
-      from, from_offset,
-      to, to_offset,
-      size, ctx_from, ctx_to, type_hint, stream);
-}
-
-void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
-  static const PackedFunc* fsys_load_ = nullptr;
-  if (fsys_load_ == nullptr) {
-    fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module");
-    CHECK(fsys_load_ != nullptr);
-  }
-  std::string file_name = args[0];
-  TVMRetValue ret = (*fsys_load_)(file_name);
-  // pass via void*
-  TVMValue value;
-  int rcode;
-  ret.MoveToCHost(&value, &rcode);
-  CHECK_EQ(rcode, kTVMModuleHandle);
-  *rv = static_cast<void*>(value.v_handle);
-}
-
-void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {
-  void* pmod = args[0];
-  void* cmod = args[1];
-  ObjectInternal::GetModuleNode(pmod)->Import(
-      GetRef<Module>(ObjectInternal::GetModuleNode(cmod)));
-}
-
-void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
-  void* mhandle = args[0];
-  ObjectInternal::ObjectFree(mhandle);
-}
-
-void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
-  void* mhandle = args[0];
-  PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction(
-      args[1], false);
-  if (pf != nullptr) {
-    *rv = static_cast<void*>(new PackedFunc(pf));
-  } else {
-    *rv = nullptr;
-  }
-}
-
-void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
-  void* mhandle = args[0];
-  std::string fmt = args[1];
-  *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt);
-}
-
-void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
-  void* handle = args[0];
-  static_cast<NDArray::Container*>(
-      reinterpret_cast<NDArray::ContainerBase*>(handle))->DecRef();
-}
-
-void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
-  PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*());
-  void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3], args[4]));
-  delete pf;
-  *rv = fhandle;
-}
-
-void RPCSession::EventHandler::HandlePackedCall() {
-  CHECK_EQ(pending_request_bytes_, 0U);
-  if (code_ == RPCCode::kReturn) {
-    state_ = kReturnReceived; return;
-  }
-  // reset state to clean init state
-  state_ = kRecvCode;
-  this->RequestBytes(sizeof(RPCCode));
-  // Event handler sit at clean state at this point.
-  switch (code_) {
-    case RPCCode::kCallFunc: {
-      PackedFunc* pf = reinterpret_cast<PackedFunc*>(call_handle_);
-      CallHandler([pf](TVMArgs args, TVMRetValue* rv) {
-          pf->CallPacked(args, rv);
-        });
-      break;
-    }
-    case RPCCode::kException: {
-      CHECK_EQ(arg_buf_->value.size(), 1U);
-      CHECK_EQ(arg_buf_->tcode[0], kTVMStr);
-      std::ostringstream os;
-      os << "Except caught from RPC call: " << arg_buf_->value[0].v_str;
-      arg_buf_.reset();
-      throw dmlc::Error(os.str());
-      break;
-    }
-    // system functions
-    case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break;
-    case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break;
-    case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break;
-    case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break;
-    case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break;
-    case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break;
-    case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break;
-    case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break;
-    case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break;
-    case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break;
-    case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break;
-    case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
-    case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
-    case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
-    case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break;
-    default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_);
-  }
-  CHECK_EQ(state_, kRecvCode);
-}
-
-PackedFunc WrapTimeEvaluator(PackedFunc pf,
-                             TVMContext ctx,
-                             int number,
-                             int repeat,
-                             int min_repeat_ms) {
-  if (static_cast<int>(ctx.device_type) == static_cast<int>(kDLMicroDev)) {
-    auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator");
-    CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled";
-    return (*get_micro_time_evaluator)(pf, ctx, number, repeat);
-  }
-
-  auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) mutable {
-    TVMRetValue temp;
-    std::ostringstream os;
-    // skip first time call, to activate lazy compilation components.
-    pf.CallPacked(args, &temp);
-    DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
-
-    for (int i = 0; i < repeat; ++i) {
-      std::chrono::time_point<
-        std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
-      double duration_ms = 0.0;
-
-      do {
-        if (duration_ms > 0.0) {
-          number = static_cast<int>(
-              std::max((min_repeat_ms / (duration_ms / number) + 1),
-                       number * 1.618));   // 1.618 is chosen by random
-        }
-
-        tbegin = std::chrono::high_resolution_clock::now();
-        // start timing
-        for (int i = 0; i < number; ++i) {
-          pf.CallPacked(args, &temp);
-        }
-        DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
-        tend = std::chrono::high_resolution_clock::now();
-
-        duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
-            (tend - tbegin).count() * 1000;
-      } while (duration_ms < min_repeat_ms);
-
-      double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
-          tend - tbegin).count() / number;
-      os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
-    }
-    std::string blob = os.str();
-    TVMByteArray arr;
-    arr.size = blob.length();
-    arr.data = blob.data();
-    // return the time.
-    *rv = arr;
-  };
-  return PackedFunc(ftimer);
-}
-
-size_t CallbackChannel::Send(const void* data, size_t size) {
-  TVMByteArray bytes;
-  bytes.data = static_cast<const char*>(data);
-  bytes.size = size;
-  int64_t n = fsend_(bytes);
-  if (n == -1) {
-    support::Socket::Error("CallbackChannel::Send");
-  }
-  return static_cast<size_t>(n);
-}
-
-size_t CallbackChannel::Recv(void* data, size_t size) {
-  TVMRetValue ret = frecv_(size);
-
-  if (ret.type_code() != kTVMBytes) {
-    support::Socket::Error("CallbackChannel::Recv");
-  }
-  std::string* bytes = ret.ptr<std::string>();
-  memcpy(static_cast<char*>(data), bytes->c_str(), bytes->length());
-  return bytes->length();
+void RPCSession::InsertToSessionTable(std::shared_ptr<RPCSession> sess) {
+  CHECK_EQ(sess->table_index_, 0);
+  sess->table_index_ = RPCSessTable::Global()->Insert(sess);
 }
 
 }  // namespace runtime
index db63be4..a715e7b 100644 (file)
 #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_
 #define TVM_RUNTIME_RPC_RPC_SESSION_H_
 
+
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/device_api.h>
-#include <mutex>
-#include <string>
+#include <functional>
 #include <memory>
-#include <utility>
-#include "../../support/ring_buffer.h"
+#include <string>
 
 namespace tvm {
 namespace runtime {
 
-// Magic header for RPC data plane
-const int kRPCMagic = 0xff271;
-// magic header for RPC tracker(control plane)
-const int kRPCTrackerMagic = 0x2f271;
-// sucess response
-const int kRPCSuccess = kRPCMagic + 0;
-// cannot found matched key in server
-const int kRPCMismatch = kRPCMagic + 2;
-
-/*! \brief Enumeration code for the RPC tracker */
-enum class TrackerCode : int {
-    kFail = -1,
-    kSuccess = 0,
-    kPing = 1,
-    kStop = 2,
-    kPut = 3,
-    kRequest = 4,
-    kUpdateInfo = 5,
-    kSummary = 6,
-    kGetPendingMatchKeys = 7
-};
-/*! \brief The remote functio handle */
-using RPCFuncHandle = void*;
-
-struct RPCArgBuffer;
-
-/*! \brief The RPC code */
-enum class RPCCode : int {
-  kNone,
-  kCallFunc,
-  kReturn,
-  kException,
-  kShutdown,
-  kCopyFromRemote,
-  kCopyToRemote,
-  kCopyAck,
-  // The following are code that can send over CallRemote
-  kSystemFuncStart,
-  kGetGlobalFunc,
-  kGetTimeEvaluator,
-  kFreeFunc,
-  kDevSetDevice,
-  kDevGetAttr,
-  kDevAllocData,
-  kDevFreeData,
-  kDevStreamSync,
-  kCopyAmongRemote,
-  kModuleLoad,
-  kModuleImport,
-  kModuleFree,
-  kModuleGetFunc,
-  kModuleGetSource,
-  kNDArrayFree
-};
-
 /*!
- * \brief Function that unwraps a remote object to its handle.
- * \param rpc_sess_table_index RPC session table index for validation.
- * \param obj Handle to the object argument.
- * \return The corresponding handle.
- */
-typedef void* (*FUnwrapRemoteObject)(
-    int rpc_sess_table_index,
-    const TVMArgValue& obj);
-
-/*!
- * \brief Abstract channel interface used to create RPCSession.
+ * \brief The interface of all remote RPC sessions.
+ *
+ *  It contains all the necessary interface to implement
+ *  remote call and resource management.
+ *
+ *  The interface is designed to allow easy proxy-chaining
+ *  by forward requests to another RPCSession.
  */
-class RPCChannel {
+class RPCSession {
  public:
-  /*! \brief virtual destructor */
-  virtual ~RPCChannel() {}
-  /*!
-   * \brief Send data over to the channel.
-   * \param data The data pointer.
-   * \param size The size fo the data.
-   * \return The actual bytes sent.
-   */
-  virtual size_t Send(const void* data, size_t size) = 0;
+  /*! \brief PackedFunc Handle in the remote. */
+  using PackedFuncHandle = void*;
+
+  /*! \brief Module handle in the remote. */
+  using ModuleHandle = void*;
+
+  /*! \brief NDArray handle in the remote. */
+  using NDArrayHandle = void*;
+
   /*!
-   * \brief Recv data from channel.
+   * \brief Callback to send an encoded return values via encode_args.
+   *
+   * \param encode_args The arguments that we can encode the return values into.
+   * \param ret_tcode The actual remote type code of the return value.
    *
-   * \param data The data pointer.
-   * \param size The size fo the data.
-   * \return The actual bytes received.
+   * Encoding convention (as list of arguments):
+   * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention.
+   * - PackedFunc/Module: [tcode: int, handle: void*]
+   * - NDArray: [tcode: int,  meta: DLTensor*, nd_handle: void*]
+   *            DLTensor* contains the meta-data as well as handle into the remote data.
+   *            nd_handle can be used for deletion.
    */
-  virtual size_t Recv(void* data, size_t size) = 0;
-};
+  using FEncodeReturn = std::function<void(TVMArgs encoded_args)>;
+
+  /*! \brief Destructor.*/
+  virtual ~RPCSession() {}
 
-// Bidirectional Communication Session of PackedRPC
-class RPCSession {
- public:
-  /*! \brief virtual destructor */
-  ~RPCSession();
   /*!
-   *  \brief The server loop that server runs to handle RPC calls.
+   * \brief Get function in the session.
+   * \param name The name of the function.
+   * \return The function handle.
    */
-  void ServerLoop();
+  virtual PackedFuncHandle GetFunction(const std::string& name) = 0;
+
   /*!
-   * \brief Message handling function for event driven server.
-   *  Called when the server receives a message.
-   *  Event driven handler will never call recv on the channel
-   *  and always relies on the ServerEventHandler.
-   *  to receive the data.
+   * \brief Call into a remote Packed function.
    *
-   * \param in_bytes The incoming bytes.
-   * \param event_flag  1: read_available, 2: write_avaiable.
-   * \return State flag.
-   *     1: continue running, no need to write,
-   *     2: need to write
-   *     0: shutdown
-   */
-  int ServerEventHandler(const std::string& in_bytes,
-                         int event_flag);
-  /*!
-   * \brief Call into remote function
-   * \param handle The function handle
-   * \param args The arguments
-   * \param rv The return value.
-   * \param funpwrap Function that takes a remote object and returns the raw handle.
-   * \param fwrap Wrapper function to turn Function/Module handle into real return.
+   *  Calling convention:
+   *
+   *  - type_code is follows the PackedFunc convention.
+   *  - int/float/string/bytes follows the PackedFunc convention, all data are local.
+   *  - PackedFunc/Module and future remote objects: pass remote handle instead.
+   *  - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor
+   *                      points to a remote data handle returned by the Device API.
+   *                      The meta-data of the DLTensor sits on local.
+   *
+   *  The caller populates the arguments and manages these arguments.
+   *
+   *  The callee can change the content of arg_values and arg_type_codes
+   *  if they want to do inplace modify and forward.
+   *
+   *  The callee need to store the return value into ret_value.
+   *  - PackedFunc/Module are stored as void*
+   *  - NDArray is stored as local NDArray, whose data field is a remote handle.
+   *    Notably the NDArray's deleter won't delete remote handle.
+   *    It is up to the user of the RPCSession to such wrapping.
+   *  - In short, remote handles are "moved" as return values
+   *    and the callee needs to explicitly manage them by calling
+   *    the deleter functions when they are no longer needed.
+   *
+   * \param func The function handle.
+   * \param arg_values The argument values.
+   * \param arg_type_codes the type codes of the argument.
+   * \param num_args Number of arguments.
+   * \param fencode_return The function to set the return value,
+   *                       if not called, return value is null.
    */
-  void CallFunc(RPCFuncHandle handle,
-                TVMArgs args,
-                TVMRetValue* rv,
-                FUnwrapRemoteObject funwrap,
-                const PackedFunc* fwrap);
+  virtual void CallFunc(PackedFuncHandle func,
+                        const TVMValue* arg_values,
+                        const int* arg_type_codes,
+                        int num_args,
+                        const FEncodeReturn& fencode_return) = 0;
+
   /*!
    * \brief Copy bytes into remote array content.
-   * \param from The source host data.
-   * \param from_offset The byte offeset in the from.
-   * \param to The target array.
-   * \param to_offset The byte offset in the to.
+   * \param local_from The source host data.
+   * \param local_from_offset The byte offeset in the from.
+   * \param remote_to The target array.
+   * \param remote_to_offset The byte offset in the to.
    * \param nbytes The size of the memory in bytes.
-   * \param ctx_to The target context.
+   * \param remote_ctx_to The target context.
    * \param type_hint Hint of content data type.
    */
-  void CopyToRemote(void* from,
-                    size_t from_offset,
-                    void* to,
-                    size_t to_offset,
-                    size_t nbytes,
-                    TVMContext ctx_to,
-                    DLDataType type_hint);
+  virtual void CopyToRemote(void* local_from,
+                            size_t local_from_offset,
+                            void* remote_to,
+                            size_t remote_to_offset,
+                            size_t nbytes,
+                            TVMContext remote_ctx_to,
+                            DLDataType type_hint) = 0;
   /*!
    * \brief Copy bytes from remote array content.
-   * \param from The source host data.
-   * \param from_offset The byte offeset in the from.
+   * \param remote_from The source host data.
+   * \param remote_from_offset The byte offeset in the from.
    * \param to The target array.
    * \param to_offset The byte offset in the to.
    * \param nbytes The size of the memory in bytes.
-   * \param ctx_from The source context.
+   * \param remote_ctx_from The source context in the remote.
    * \param type_hint Hint of content data type.
    */
-  void CopyFromRemote(void* from,
-                      size_t from_offset,
-                      void* to,
-                      size_t to_offset,
-                      size_t nbytes,
-                      TVMContext ctx_from,
-                      DLDataType type_hint);
+  virtual void CopyFromRemote(void* remote_from,
+                              size_t remote_from_offset,
+                              void* local_to,
+                              size_t local_to_offset,
+                              size_t nbytes,
+                              TVMContext remote_ctx_from,
+                              DLDataType type_hint) = 0;
+
   /*!
-   * \brief Get a remote timer function on ctx.
-   *  This function consumes fhandle, caller should not call Free on fhandle.
-   *
-   * \param fhandle The function handle.
-   * \param ctx The ctx to run measurement on.
-   * \param number The number of times to run this function for taking average.
-          We call these runs as one `repeat` of measurement.
-   * \param repeat The number of times to repeat the measurement.
-          In total, the function will be invoked (1 + number x repeat) times,
-          where the first one is warm up and will be discarded.
-          The returned result contains `repeat` costs,
-          each of which is an average of `number` costs.
-   * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
-          By default, one `repeat` contains `number` runs. If this parameter is set,
-          the parameters `number` will be dynamically adjusted to meet the
-          minimum duration requirement of one `repeat`.
-          i.e., When the run time of one `repeat` falls below this time,
-          the `number` parameter will be automatically increased.
-   * \return A remote timer function
+   * \brief Free a remote function.
+   * \param handle The remote handle, can be NDArray/PackedFunc/Module
+   * \param type_code The type code of the underlying type.
    */
-  RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
-                                 TVMContext ctx,
-                                 int number,
-                                 int repeat,
-                                 int min_repeat_ms);
+  virtual void FreeHandle(void* handle, int type_code) = 0;
+
   /*!
-   * \brief Call a remote defined system function with arguments.
-   * \param fcode The function code.
-   * \param args The arguments
-   * \return The returned remote value.
+   * \brief Get device API that represents the remote
+   *  actions that can be taken on the remote.
+   *
+   *  The caller can then call into the Alloc/Free functions
+   *  to allocate free spaces and taking the pointer as the handle.
+   *
+   *  The device API is guaranteed to be alive during the
+   *  lifetime of the Session.
+   *
+   * \param ctx The remote context.
+   * \param allow_missing Whether can we return nullptr if it is not available.
+   *
+   * \return The device API.
    */
-  template<typename... Args>
-  inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args);
+  virtual DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) = 0;
+
   /*!
    * \return The session table index of the session.
    */
   int table_index() const {
     return table_index_;
   }
-  /*!
-   * \brief Create a RPC session with given channel.
-   * \param channel The communication channel.
-   * \param name The local name of the session, used for debug
-   * \param remote_key The remote key of the session
-   *   if remote_key equals "%toinit", we need to re-intialize
-   *   it by event handler.
-   */
-  static std::shared_ptr<RPCSession> Create(
-      std::unique_ptr<RPCChannel> channel,
-      std::string name,
-      std::string remote_key);
+
   /*!
    * \brief Try get session from the global session table by table index.
    * \param table_index The table index of the session.
@@ -256,62 +192,25 @@ class RPCSession {
   static std::shared_ptr<RPCSession> Get(int table_index);
 
  private:
-  class EventHandler;
-  // Handle events until receives a return
-  // Also flushes channels so that the function advances.
-  RPCCode HandleUntilReturnEvent(
-      TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap);
-  // Initalization
-  void Init();
-  // Shutdown
-  void Shutdown();
-  // Internal channel.
-  std::unique_ptr<RPCChannel> channel_;
-  // Internal mutex
-  std::recursive_mutex mutex_;
-  // Internal ring buffer.
-  support::RingBuffer reader_, writer_;
-  // Event handler.
-  std::shared_ptr<EventHandler> handler_;
-  // call remote with specified function code.
-  PackedFunc call_remote_;
-  // The index of this session in RPC session table.
+  /*! \brief index of this session in RPC session table */
   int table_index_{0};
-  // The name of the session.
-  std::string name_;
-  // The remote key
-  std::string remote_key_;
+  /*! \brief Insert the current session to the session table.*/
+  static void InsertToSessionTable(std::shared_ptr<RPCSession> sess);
+  // friend declaration
+  friend Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess);
 };
 
 /*!
- * \brief RPC channel which callback
- * frontend (Python/Java/etc.)'s send & recv function
+ * \brief Remote space handle cell used by the RPC runtime API.
+ *
+ *  When we allocate space using a rpc context, the data pointer
+ *  points to an allocated RemoteSpace.
  */
-class CallbackChannel final : public RPCChannel {
- public:
-  explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv)
-      : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {}
-
-  ~CallbackChannel() {}
-  /*!
-   * \brief Send data over to the channel.
-   * \param data The data pointer.
-   * \param size The size fo the data.
-   * \return The actual bytes sent.
-   */
-  size_t Send(const void* data, size_t size) final;
-  /*!
-   * \brief Recv data from channel.
-   *
-   * \param data The data pointer.
-   * \param size The size fo the data.
-   * \return The actual bytes received.
-   */
-  size_t Recv(void* data, size_t size) final;
-
- private:
-  PackedFunc fsend_;
-  PackedFunc frecv_;
+struct RemoteSpace {
+  /*! \brief The remote data handle. */
+  void* data;
+  /*! \brief Reference to the underlying RPC session. */
+  std::shared_ptr<RPCSession> sess;
 };
 
 /*!
@@ -319,18 +218,18 @@ class CallbackChannel final : public RPCChannel {
  * \param f The function argument.
  * \param ctx The context.
  * \param number The number of times to run this function for taking average.
         We call these runs as one `repeat` of measurement.
*        We call these runs as one `repeat` of measurement.
  * \param repeat The number of times to repeat the measurement.
         In total, the function will be invoked (1 + number x repeat) times,
         where the first one is warm up and will be discarded.
         The returned result contains `repeat` costs,
         each of which is an average of `number` costs.
*        In total, the function will be invoked (1 + number x repeat) times,
*        where the first one is warm up and will be discarded.
*        The returned result contains `repeat` costs,
*        each of which is an average of `number` costs.
  * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
         By default, one `repeat` contains `number` runs. If this parameter is set,
         the parameters `number` will be dynamically adjusted to meet the
         minimum duration requirement of one `repeat`.
         i.e., When the run time of one `repeat` falls below this time,
         the `number` parameter will be automatically increased.
*        By default, one `repeat` contains `number` runs. If this parameter is set,
*        the parameters `number` will be dynamically adjusted to meet the
*        minimum duration requirement of one `repeat`.
*        i.e., When the run time of one `repeat` falls below this time,
*        the `number` parameter will be automatically increased.
  * \return f_timer A timer function.
  */
 PackedFunc WrapTimeEvaluator(PackedFunc f,
@@ -344,21 +243,15 @@ PackedFunc WrapTimeEvaluator(PackedFunc f,
  * \param sess The RPC session of the global module.
  * \return The created module.
  */
-Module CreateRPCModule(std::shared_ptr<RPCSession> sess);
+Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess);
 
-// Remote space pointer.
-struct RemoteSpace {
-  void* data;
-  std::shared_ptr<RPCSession> sess;
-};
+/*!
+ * \brief Get the session module from a RPC session Module.
+ * \param mod The input module(must be an RPCModule).
+ * \return The internal RPCSession.
+ */
+std::shared_ptr<RPCSession> RPCModuleGetSession(Module mod);
 
-// implementation of inline functions
-template<typename... Args>
-inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) {
-  std::lock_guard<std::recursive_mutex> lock(mutex_);
-  writer_.Write(&code, sizeof(code));
-  return call_remote_(std::forward<Args>(args)...);
-}
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_RUNTIME_RPC_RPC_SESSION_H_
index 642fbb8..f3a30dd 100644 (file)
  * \brief Socket based RPC implementation.
  */
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
 #include <memory>
+#include "rpc_endpoint.h"
 #include "rpc_session.h"
+#include "rpc_local_session.h"
 #include "../../support/socket.h"
 
 namespace tvm {
@@ -61,8 +64,8 @@ class SockChannel final : public RPCChannel {
   support::TCPSocket sock_;
 };
 
-std::shared_ptr<RPCSession>
-RPCConnect(std::string url, int port, std::string key) {
+std::shared_ptr<RPCEndpoint>
+RPCConnect(std::string url, int port, std::string key, TVMArgs init_seq) {
   support::TCPSocket sock;
   support::SockAddr addr(url.c_str(), port);
   sock.Create(addr.ss_family());
@@ -96,42 +99,56 @@ RPCConnect(std::string url, int port, std::string key) {
     remote_key.resize(keylen);
     CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
   }
-  return RPCSession::Create(
+  auto endpt = RPCEndpoint::Create(
       std::unique_ptr<SockChannel>(new SockChannel(sock)), key, remote_key);
+  endpt->InitRemoteSession(init_seq);
+  return endpt;
 }
 
-Module RPCClientConnect(std::string url, int port, std::string key) {
-  return CreateRPCModule(RPCConnect(url, port, "client:" + key));
+Module RPCClientConnect(std::string url,
+                        int port,
+                        std::string key,
+                        TVMArgs init_seq) {
+  auto endpt = RPCConnect(url, port, "client:" + key, init_seq);
+  return CreateRPCSessionModule(CreateClientSession(endpt));
 }
 
 // TVM_DLL needed for MSVC
 TVM_DLL void RPCServerLoop(int sockfd) {
   support::TCPSocket sock(
       static_cast<support::TCPSocket::SockType>(sockfd));
-  RPCSession::Create(
+  RPCEndpoint::Create(
       std::unique_ptr<SockChannel>(new SockChannel(sock)),
       "SockServerLoop", "")->ServerLoop();
 }
 
-void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) {
-  RPCSession::Create(std::unique_ptr<CallbackChannel>(
-      new CallbackChannel(fsend, frecv)),
+void RPCServerLoop(PackedFunc fsend,
+                   PackedFunc frecv) {
+  RPCEndpoint::Create(
+      std::unique_ptr<CallbackChannel>(new CallbackChannel(fsend, frecv)),
       "SockServerLoop", "")->ServerLoop();
 }
 
-TVM_REGISTER_GLOBAL("rpc._Connect")
-.set_body_typed(RPCClientConnect);
+TVM_REGISTER_GLOBAL("rpc.Connect")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  std::string url = args[0];
+  int port = args[1];
+  std::string key = args[2];
+  *rv = RPCClientConnect(
+      url, port, key,
+      TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3));
+});
 
-TVM_REGISTER_GLOBAL("rpc._ServerLoop")
+TVM_REGISTER_GLOBAL("rpc.ServerLoop")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    if (args.size() == 1) {
-      RPCServerLoop(args[0]);
-    } else {
-      CHECK_EQ(args.size(), 2);
-      RPCServerLoop(
-          args[0].operator tvm::runtime::PackedFunc(),
-          args[1].operator tvm::runtime::PackedFunc());
-    }
-  });
+  if (args[0].type_code() == kDLInt) {
+    RPCServerLoop(args[0]);
+  } else {
+    RPCServerLoop(
+        args[0].operator tvm::runtime::PackedFunc(),
+        args[1].operator tvm::runtime::PackedFunc());
+  }
+});
+
 }  // namespace runtime
 }  // namespace tvm
index 744ff4f..b062276 100644 (file)
 #ifndef TVM_SUPPORT_ARENA_H_
 #define TVM_SUPPORT_ARENA_H_
 
+#ifndef TVM_ARENA_HAS_DESTRUCTOR
+#define TVM_ARENA_HAS_DESTRUCTOR 1
+#endif
+
+#include <cstddef>
 #include <utility>
 #include <type_traits>
 
+
 namespace tvm {
 namespace support {
 
-const constexpr int kArenaPageSize = 16 << 10;
+/*!
+ * \brief An arena page header.
+ */
+struct ArenaPageHeader {
+  /*! \brief points to the next page. */
+  ArenaPageHeader* next;
+  /*!
+   * \brief Total size of the page.
+   */
+  size_t size;
+  /*! \brief memory allocator offset inside page. */
+  size_t offset;
+};
+
+/*!
+ * \brief Simple page allocator that uses new and delete.
+ */
+class SimplePageAllocator {
+ public:
+  /*!
+   * \brief Allocate a new page.
+   * \param min_size Minimum size of the page.
+   * \return The allocated page.
+   * \note This function can return a bigger page to meet the min_size requirement.
+   */
+  ArenaPageHeader* allocate(size_t min_size) {
+    size_t npages = ((min_size + kPageSize - 1) / kPageSize);
+    ArenaPageHeader* header = reinterpret_cast<ArenaPageHeader*>(new Page[npages]);
+    header->size = npages * kPageSize;
+    header->offset = sizeof(ArenaPageHeader);
+    return header;
+  }
+  /*!
+   * \brief De-allocate an allocate page.
+   * \param page The page to be de-allocated.
+   */
+  void deallocate(ArenaPageHeader* page) {
+    delete [] reinterpret_cast<Page*>(page);
+  }
+
+  static const constexpr int kPageSize = 16 << 10;
+  static const constexpr int kPageAlign = 1024;
+
+ private:
+  // page size 16 KB
+  // The page data type;
+  using Page = std::aligned_storage<kPageSize, kPageAlign>::type;
+};
 
 /*!
  * \brief Arena allocator that allocates memory from continuous
  *  chunk and frees them all only during destruction.
  */
-class Arena {
+template<typename PageAllocator>
+class GenericArena {
  public:
-  Arena() {
+  explicit GenericArena(PageAllocator alloc = PageAllocator())
+      : alloc_(alloc) {
     // eagerly allocate the first page.
-    head_ = reinterpret_cast<PageHeader*>(new Page());
+    head_ = tail_ = alloc_.allocate(1);
     head_->next = nullptr;
-    head_->ptr = sizeof(PageHeader);
   }
-  ~Arena() {
-    // delete all the allocated pages.
-    while (head_ != nullptr) {
-      Page* page = reinterpret_cast<Page*>(head_);
-      head_ = head_->next;
-      delete page;
-    }
+
+#if TVM_ARENA_HAS_DESTRUCTOR
+  ~GenericArena() {
+    this->FreeAll();
+  }
+#endif
+
+  /*! \brief Free all pages. */
+  void FreeAll() {
+    FreePageList(&head_);
+    FreePageList(&free_list_);
+  }
+  /*! \brief Recycle all the pages in the arena */
+  void RecycleAll() {
+    // put all the current list to the free list.
+    tail_->next = free_list_;
+    // allocate the first in the free list to head
+    free_list_ = head_->next;
+    head_->next = nullptr;
+    // Reset the head.
+    head_->offset = sizeof(ArenaPageHeader);
+    tail_ = head_;
   }
   /*!
    * \brief Allocate a space from Arena for type T
    * \param T the data type to be allocated
+   * \param count Numberof elements
    * \note The space of T is not initialized.
    */
   template<typename T>
-  T* allocate_() {
-    return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
+  T* allocate_(int count = 1) {
+    static_assert(PageAllocator::kPageAlign % alignof(T) == 0,
+                  "To large alignment");
+    return static_cast<T*>(Alloc(sizeof(T) * count, alignof(T)));
   }
   /*!
    * \brief Create a new instance of type T.
@@ -82,25 +154,21 @@ class Arena {
   }
 
  private:
-  // page size 16 KB
-  // The page data type;
-  using Page = std::aligned_storage<kArenaPageSize, 1024>::type;
-  /*! \brief Page header */
-  struct PageHeader {
-    /*! \brief points to the next page */
-    PageHeader* next;
-    /*! \brief memory allocator ptr inside page */
-    size_t ptr;
-  };
-  /* \brief The page header */
-  PageHeader* head_{nullptr};
+  /*! \brief internal page allocator. */
+  PageAllocator alloc_;
+  /* \brief The the head of the allocated list. */
+  ArenaPageHeader* head_{nullptr};
+  /*! \brief The tail of the allocated list. */
+  ArenaPageHeader* tail_{nullptr};
+  /* \brief List of free pages. */
+  ArenaPageHeader* free_list_{nullptr};
   /*!
    * \brief Align ptr by upper bound.
-   * \param ptr The pointer value.
+   * \param offset The offset value.
    * \param align The alignment requirement.
    */
-  size_t UpperAlign(size_t ptr, size_t align) {
-    return ptr + (align - (ptr % align)) % align;
+  size_t UpperAlign(size_t offset, size_t align) {
+    return offset + (align - (offset % align)) % align;
   }
   /*!
    * \brief Internal aligned alloc function.
@@ -108,22 +176,41 @@ class Arena {
    * \param align The alignment requirement.
    */
   void* Alloc(size_t size, size_t align) {
-    size_t ptr = UpperAlign(head_->ptr, align);
-    if (ptr + size <= kArenaPageSize) {
-      head_->ptr = ptr + size;
-      return reinterpret_cast<char*>(head_) + ptr;
+    size_t offset = UpperAlign(head_->offset, align);
+    if (offset + size <= head_->size) {
+      head_->offset = offset + size;
+      return reinterpret_cast<char*>(head_) + offset;
     } else {
-      PageHeader* new_head = reinterpret_cast<PageHeader*>(new Page());
+      ArenaPageHeader* new_head;
+      offset = UpperAlign(sizeof(ArenaPageHeader), align);
+      if (free_list_ != nullptr && offset + size <= free_list_-> size) {
+        new_head = free_list_;
+        free_list_ = free_list_->next;
+      } else {
+        new_head = alloc_.allocate(offset + size);
+      }
       new_head->next = head_;
-      ptr = UpperAlign(sizeof(PageHeader), align);
-      CHECK_LE(ptr + size, kArenaPageSize);
-      new_head->ptr = ptr + size;
+      new_head->offset = offset + size;
       head_ = new_head;
-      return reinterpret_cast<char*>(head_) + ptr;
+      return reinterpret_cast<char*>(head_) + offset;
+    }
+  }
+  /*!
+   * \brief Free all the pages in the list.
+   * \param ptr The head ptr.
+   */
+  void FreePageList(ArenaPageHeader** ptr) {
+    // delete all the allocated pages.
+    while (ptr[0] != nullptr) {
+      ArenaPageHeader* temp = ptr[0];
+      ptr[0] = ptr[0]->next;
+      alloc_.deallocate(temp);
     }
   }
 };
 
+using Arena = GenericArena<SimplePageAllocator>;
+
 /*!
  * \brief Link list node
  * \tparam T the content data type
index b61e6bb..091e942 100644 (file)
@@ -18,10 +18,12 @@ import tvm
 from tvm import te
 import tvm.testing
 import os
+import stat
 import logging
 import time
 import multiprocessing
 
+import pytest
 import numpy as np
 from tvm import rpc
 from tvm.contrib import util
@@ -77,11 +79,9 @@ def test_rpc_simple():
     f1 = client.get_function("rpc.test.addone")
     assert f1(10) == 11
     f3 = client.get_function("rpc.test.except")
-    try:
+
+    with pytest.raises(tvm.error.RPCError):
         f3("abc")
-        assert False
-    except tvm.error.TVMError as e:
-        assert "abc" in str(e)
 
     f2 = client.get_function("rpc.test.strcat")
     assert f2("abc", 11) == "abc:11"
@@ -101,6 +101,40 @@ def test_rpc_array():
     fremote = remote.get_function("rpc.test.remote_array_func")
     fremote(r_cpu)
 
+
+def test_rpc_echo():
+    def check(remote):
+        fecho = remote.get_function("testing.echo")
+        assert(fecho(1, 2, 3) == 1)
+        assert(fecho(100, 2, 3) == 100)
+        assert(fecho("xyz") == "xyz")
+        assert(bytes(fecho(bytearray(b"123"))) == b"123")
+
+        with pytest.raises(RuntimeError):
+            raise_err = remote.get_function(
+                "testing.test_raise_error_callback")("RuntimeError")
+            raise_err()
+
+    temp = rpc.server._server_env([])
+    server = rpc.Server("localhost")
+    client = rpc.connect(server.host, server.port)
+    check(rpc.LocalSession())
+
+    check(client)
+    # Test minrpc server.
+    temp = util.tempdir()
+    minrpc_exec = temp.relpath("minrpc")
+    tvm.rpc.with_minrpc("g++")(minrpc_exec, [])
+    check(rpc.PopenSession(minrpc_exec))
+    # minrpc on the remote
+    server = rpc.Server("localhost")
+    client = rpc.connect(
+        server.host, server.port,
+        session_constructor_args=["rpc.PopenSession",
+                             open(minrpc_exec, "rb").read()])
+    check(client)
+
+
 def test_rpc_file_exchange():
     if not tvm.runtime.enabled("rpc"):
         return
@@ -114,14 +148,15 @@ def test_rpc_file_exchange():
 def test_rpc_remote_module():
     if not tvm.runtime.enabled("rpc"):
         return
-    server = rpc.Server("localhost")
-    client = rpc.connect(server.host, server.port)
     # graph
-    n = tvm.runtime.convert(1024)
+    n = tvm.runtime.convert(102)
     A = te.placeholder((n,), name='A')
     B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
     s = te.create_schedule(B.op)
 
+    server = rpc.Server("localhost")
+    client = rpc.connect(server.host, server.port)
+
     def check_remote(remote):
         if not tvm.runtime.enabled("llvm"):
             print("Skip because llvm is not enabled")
@@ -133,13 +168,44 @@ def test_rpc_remote_module():
         f.export_library(path_dso)
         remote.upload(path_dso)
         f1 = remote.load_module("dev_lib.so")
-        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx)
         time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
         cost = time_f(a, b).mean
         print('%g secs/op' % cost)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
+    def check_minrpc():
+        if not tvm.runtime.enabled("llvm"):
+            print("Skip because llvm is not enabled")
+            return
+        if tvm.get_global_func("rpc.PopenSession", allow_missing=True) is None:
+            return
+        # export to minrpc
+        temp = util.tempdir()
+        f = tvm.build(s, [A, B], "llvm --system-lib", name="myadd")
+        path_minrpc = temp.relpath("dev_lib.minrpc")
+        f.export_library(path_minrpc, rpc.with_minrpc("g++"))
+
+        with pytest.raises(RuntimeError):
+            rpc.PopenSession("filenotexist")
+
+        # statrt the minrpc session.
+        remote = tvm.rpc.PopenSession(path_minrpc)
+        ctx = remote.cpu(0)
+        f1 = remote.system_lib()
+        a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx)
+        time_f = f1.time_evaluator("myadd", remote.cpu(0), number=1)
+        cost = time_f(a, b).mean
+        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+
+        # change to not executable
+        os.chmod(path_minrpc, stat.S_IRUSR)
+        with pytest.raises(RuntimeError):
+            rpc.PopenSession(path_minrpc)
+
+
     def check_remote_link_cl(remote):
         """Test function to run remote code such as cl
 
@@ -174,8 +240,8 @@ def test_rpc_remote_module():
         fhost = remote.load_module("myadd.o")
         fdev = remote.load_module("myadd.cl")
         fhost.import_module(fdev)
-        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx)
         fhost(a, b)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
         # Option 2: export library as a tar ball then handled by remote compiler
@@ -183,13 +249,15 @@ def test_rpc_remote_module():
         f.export_library(path_tar)
         remote.upload(path_tar)
         fhost = remote.load_module("myadd.tar")
-        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx)
         fhost(a, b)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
-    check_remote(client)
     check_remote(rpc.LocalSession())
+    check_remote(client)
+    check_minrpc()
+
 
 
 def test_rpc_return_func():
@@ -204,6 +272,37 @@ def test_rpc_return_func():
     assert fadd(12) == 22
 
 
+def test_rpc_session_constructor_args():
+    # start server
+    server0 = rpc.Server("localhost", key="x0")
+    server1 = rpc.Server("localhost", key="x1")
+
+    def check_multi_hop():
+        # use server0 as proxy to connect to server1
+        client = rpc.connect(
+            server0.host, server0.port, key="x0",
+            session_constructor_args=[
+                "rpc.Connect", server1.host, server1.port, "x1"])
+
+        fecho = client.get_function("testing.echo")
+        assert(fecho(1, 2, 3) == 1)
+        assert(fecho(100, 2, 3) == 100)
+        assert(fecho("xyz") == "xyz")
+        assert(bytes(fecho(bytearray(b"123"))) == b"123")
+
+        nd = tvm.nd.array([1,2,3], ctx=client.cpu(0))
+        assert(nd.asnumpy()[1] == 2)
+
+    def check_error_handling():
+        with pytest.raises(tvm.error.RPCError):
+            client = rpc.connect(
+                server0.host, server0.port, key="x0",
+                session_constructor_args=["rpc.NonExistingConstructor"])
+
+    check_multi_hop()
+    check_error_handling()
+
+
 def test_rpc_return_ndarray():
     # Use closure to check the ref counter correctness
     nd = tvm.nd.array(np.zeros(10).astype("float32"))
@@ -221,6 +320,7 @@ def test_rpc_return_ndarray():
     # start server
     server = rpc.Server("localhost", key="x1")
     client = rpc.connect(server.host, server.port, key="x1")
+
     m = client.get_function("rpc.test.remote_return_nd")
     get_arr = m("get_arr")
     ref_count = m("ref_count")
@@ -315,6 +415,7 @@ def test_rpc_tracker_request():
     time.sleep(0.5)
 
     summary = client.summary()
+
     assert summary['queue_info'][device_key]['free'] == 0
     assert summary['queue_info'][device_key]['pending'] == 1
 
@@ -334,6 +435,8 @@ def test_rpc_tracker_request():
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
+    test_rpc_echo()
+    test_rpc_session_constructor_args()
     test_rpc_return_ndarray()
     test_rpc_return_func()
     test_bigendian_rpc()
index b62b298..86ef59c 100644 (file)
@@ -907,7 +907,7 @@ var tvm_runtime = tvm_runtime || {};
 
         if (typeof systemFunc.fcreateServer === "undefined") {
           systemFunc.fcreateServer =
-            getGlobalFunc("rpc._CreateEventDrivenServer");
+            getGlobalFunc("rpc.CreateEventDrivenServer");
         }
         if (systemFunc.fcreateServer == null) {
           throwError("RPCServer is not included in runtime");