[RUTNIME] Support C++ RPC (#4281)
authorZhao Wu <wuzhaozju@gmail.com>
Sun, 10 Nov 2019 22:56:44 +0000 (06:56 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sun, 10 Nov 2019 22:56:44 +0000 (14:56 -0800)
12 files changed:
apps/cpp_rpc/Makefile [new file with mode: 0644]
apps/cpp_rpc/README.md [new file with mode: 0644]
apps/cpp_rpc/main.cc [new file with mode: 0644]
apps/cpp_rpc/rpc_env.cc [new file with mode: 0644]
apps/cpp_rpc/rpc_env.h [new file with mode: 0644]
apps/cpp_rpc/rpc_server.cc [new file with mode: 0644]
apps/cpp_rpc/rpc_server.h [new file with mode: 0644]
apps/cpp_rpc/rpc_tracker_client.h [new file with mode: 0644]
src/common/socket.h
src/common/util.h [new file with mode: 0644]
src/runtime/rpc/rpc_session.h
src/runtime/rpc/rpc_socket_impl.h [new file with mode: 0644]

diff --git a/apps/cpp_rpc/Makefile b/apps/cpp_rpc/Makefile
new file mode 100644 (file)
index 0000000..9cd39b4
--- /dev/null
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Makefile to compile RPC Server.
+TVM_ROOT=$(shell cd ../..; pwd)
+DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core
+TVM_RUNTIME_DIR?=
+OS?=
+
+# Android can not link pthrad, but Linux need.
+ifeq ($(OS), Linux)
+LINK_PTHREAD=-lpthread
+else
+LINK_PTHREAD=
+endif
+
+PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\
+       -I${TVM_ROOT}/include\
+       -I${DMLC_CORE}/include\
+       -I${TVM_ROOT}/3rdparty/dlpack/include
+
+PKG_LDFLAGS = -L$(TVM_RUNTIME_DIR) $(LINK_PTHREAD) -ltvm_runtime -ldl -Wl,-R$(TVM_RUNTIME_DIR)
+
+ifeq ($(USE_GLOG), 1)
+        PKG_CFLAGS += -DDMLC_USE_GLOG=1
+        PKG_LDFLAGS += -lglog
+endif
+
+.PHONY: clean all
+
+all: tvm_rpc
+
+# Build rule for all in one TVM package library
+tvm_rpc: *.cc
+       @mkdir -p $(@D)
+       $(CXX) $(PKG_CFLAGS) -o $@ $(filter %.cc %.o %.a, $^) $(PKG_LDFLAGS)
+
+clean:
+       -rm -f tvm_rpc
\ No newline at end of file
diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md
new file mode 100644 (file)
index 0000000..4baecaf
--- /dev/null
@@ -0,0 +1,56 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# TVM RPC Server
+This folder contains a simple recipe to make RPC server in c++.
+
+## Usage
+- Build tvm runtime
+- Make the rpc executable [Makefile](Makefile).
+  `make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux`
+  if you want to compile it for embedded Linux, you should add `OS=Linux`.
+  if the target os is Android, you doesn't need to pass OS argument.
+  You could cross compile the TVM runtime like this:
+```
+  cd tvm
+  mkdir arm_runtime
+  cp cmake/config.cmake arm_runtime
+  cd arm_runtime
+  cmake .. -DCMAKE_CXX_COMPILER="/path/to/cross compiler g++/"
+  make runtime
+```
+- Use `./tvm_rpc server` to start the RPC server
+
+## How it works
+- The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library.
+
+```
+Command line usage
+ server       - Start the server
+--host        - The hostname of the server, Default=0.0.0.0
+--port        - The port of the RPC, Default=9090
+--port-end    - The end search port of the RPC, Default=9199
+--tracker     - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=""
+--key         - The key used to identify the device type in tracker. Default=""
+--custom-addr - Custom IP Address to Report to RPC Tracker. Default=""
+--silent      - Whether to run in silent mode. Default=False
+  Example
+  ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 --tracker=127.0.0.1:9190 --key=rasp
+```
+
+## Note
+Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently.
\ No newline at end of file
diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc
new file mode 100644 (file)
index 0000000..3cf2ed6
--- /dev/null
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_server.cc
+ * \brief RPC Server for TVM.
+ */
+#include <stdlib.h>
+#include <signal.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <dmlc/logging.h>
+#include <iostream>
+#include <cstring>
+#include <vector>
+#include <sstream>
+
+#include "../../src/common/util.h"
+#include "../../src/common/socket.h"
+#include "rpc_server.h"
+
+using namespace std;
+using namespace tvm::runtime;
+using namespace tvm::common;
+
+static const string kUSAGE = \
+"Command line usage\n" \
+" server       - Start the server\n" \
+"--host        - The hostname of the server, Default=0.0.0.0\n" \
+"--port        - The port of the RPC, Default=9090\n" \
+"--port-end    - The end search port of the RPC, Default=9199\n" \
+"--tracker     - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \
+"--key         - The key used to identify the device type in tracker. Default=\"\"\n" \
+"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \
+"--silent      - Whether to run in silent mode. Default=False\n" \
+"\n" \
+"  Example\n" \
+"  ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 "
+" --tracker=127.0.0.1:9190 --key=rasp" \
+"\n";
+
+/*!
+ * \brief RpcServerArgs.
+ * \arg host The hostname of the server, Default=0.0.0.0
+ * \arg port The port of the RPC, Default=9090
+ * \arg port_end The end search port of the RPC, Default=9199
+ * \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \arg key The key used to identify the device type in tracker. Default=""
+ * \arg custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ * \arg silent Whether run in silent mode. Default=False
+ */
+struct RpcServerArgs {
+  string host = "0.0.0.0";
+  int port = 9090;
+  int port_end = 9099;
+  string tracker;
+  string key;
+  string custom_addr;
+  bool silent = false;
+};
+
+/*!
+ * \brief PrintArgs print the contents of RpcServerArgs
+ * \param args RpcServerArgs structure
+ */
+void PrintArgs(struct RpcServerArgs args) {
+  LOG(INFO) << "host        = " << args.host;
+  LOG(INFO) << "port        = " << args.port;
+  LOG(INFO) << "port_end    = " << args.port_end;
+  LOG(INFO) << "tracker     = " << args.tracker;
+  LOG(INFO) << "key         = " << args.key;
+  LOG(INFO) << "custom_addr = " << args.custom_addr;
+  LOG(INFO) << "silent      = " << ((args.silent) ? ("True"): ("False"));
+}
+
+/*!
+ * \brief CtrlCHandler, exits if Ctrl+C is pressed
+ * \param s signal
+ */
+void CtrlCHandler(int s) {
+  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
+  exit(1);
+}
+
+/*!
+ * \brief HandleCtrlC Register for handling Ctrl+C event.
+ */
+void HandleCtrlC() {
+  // Ctrl+C handler
+  struct sigaction sigIntHandler;
+  sigIntHandler.sa_handler = CtrlCHandler;
+  sigemptyset(&sigIntHandler.sa_mask);
+  sigIntHandler.sa_flags = 0;
+  sigaction(SIGINT, &sigIntHandler, nullptr);
+}
+
+/*!
+ * \brief GetCmdOption Parse and find the command option.
+ * \param argc arg counter
+ * \param argv arg values
+ * \param option command line option to search for.
+ * \param key whether the option itself is key
+ * \return value corresponding to option.
+ */
+string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
+  string cmd;
+  for (int i = 1; i < argc; ++i) {
+    string arg = argv[i];
+    if (arg.find(option) == 0) {
+      if (key) {
+        cmd = argv[i];
+        return cmd;
+      }
+      // We assume "=" is the end of option.
+      CHECK_EQ(*option.rbegin(), '=');
+      cmd = arg.substr(arg.find("=") + 1);
+      return cmd;
+    }
+  }
+  return cmd;
+}
+
+/*!
+ * \brief ValidateTracker Check the tracker address format is correct and changes the format.
+ * \param tracker The tracker input.
+ * \return result of operation.
+ */
+bool ValidateTracker(string &tracker) {
+  vector<string> list = Split(tracker, ':');
+  if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) {
+    return false;
+  }
+  ostringstream ss;
+  ss << "('" << list[0] << "', " << list[1] << ")";
+  tracker = ss.str();
+  return true;
+}
+
+/*!
+ * \brief ParseCmdArgs parses the command line arguments.
+ * \param argc arg counter
+ * \param argv arg values
+ * \param args, the output structure which holds the parsed values
+ */
+void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
+  string silent = GetCmdOption(argc, argv, "--silent", true);
+  if (!silent.empty()) {
+    args.silent = true;
+    // Only errors and fatal is logged
+    dmlc::InitLogging("--minloglevel=2");
+  }
+
+  string host = GetCmdOption(argc, argv, "--host=");
+  if (!host.empty()) {
+    if (!ValidateIP(host)) {
+      LOG(WARNING) << "Wrong host address format.";
+      LOG(INFO) << kUSAGE;
+      exit(1);
+    }
+    args.host = host;
+  }
+
+  string port = GetCmdOption(argc, argv, "--port=");
+  if (!port.empty()) {
+    if (!IsNumber(port) || stoi(port) > 65535) {
+      LOG(WARNING) << "Wrong port number.";
+      LOG(INFO) << kUSAGE;
+      exit(1);
+    }
+    args.port = stoi(port);
+  }
+
+  string port_end = GetCmdOption(argc, argv, "--port_end=");
+  if (!port_end.empty()) {
+    if (!IsNumber(port_end) || stoi(port_end) > 65535) {
+      LOG(WARNING) << "Wrong port_end number.";
+      LOG(INFO) << kUSAGE;
+      exit(1);
+    }
+    args.port_end = stoi(port_end);
+  }
+
+  string tracker = GetCmdOption(argc, argv, "--tracker=");
+  if (!tracker.empty()) {
+    if (!ValidateTracker(tracker)) {
+      LOG(WARNING) << "Wrong tracker address format.";
+      LOG(INFO) << kUSAGE;
+      exit(1);
+    }
+    args.tracker = tracker;
+  }
+
+  string key = GetCmdOption(argc, argv, "--key=");
+  if (!key.empty()) {
+    args.key = key;
+  }
+
+  string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
+  if (!custom_addr.empty()) {
+    if (!ValidateIP(custom_addr)) {
+      LOG(WARNING) << "Wrong custom address format.";
+      LOG(INFO) << kUSAGE;
+      exit(1);
+    }
+    args.custom_addr = custom_addr;
+  }
+}
+
+/*!
+ * \brief RpcServer Starts the RPC server.
+ * \param argc arg counter
+ * \param argv arg values
+ * \return result of operation.
+ */
+int RpcServer(int argc, char * argv[]) {
+  struct RpcServerArgs args;
+
+  /* parse the command line args */
+  ParseCmdArgs(argc, argv, args);
+  PrintArgs(args);
+
+  // Ctrl+C handler
+  LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
+  HandleCtrlC();
+  tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
+                                args.key, args.custom_addr, args.silent);
+  return 0;
+}
+
+/*!
+ * \brief main The main function.
+ * \param argc arg counter
+ * \param argv arg values
+ * \return result of operation.
+ */
+int main(int argc, char * argv[]) {
+  if (argc <= 1) {
+    LOG(INFO) << kUSAGE;
+    return 0;
+  }
+
+  if (0 == strcmp(argv[1], "server")) {
+    RpcServer(argc, argv);
+  } else {
+    LOG(INFO) << kUSAGE;
+  }
+
+  return 0;
+}
diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc
new file mode 100644 (file)
index 0000000..44f848d
--- /dev/null
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file rpc_env.cc
+ * \brief Server environment of the RPC.
+ */
+#include <tvm/runtime/registry.h>
+#include <errno.h>
+#ifndef _MSC_VER
+#include <sys/stat.h>
+#include <dirent.h>
+#include <unistd.h>
+#else
+#include <Windows.h>
+#endif
+#include <fstream>
+#include <vector>
+#include <iostream>
+#include <string>
+#include <cstring>
+
+#include "rpc_env.h"
+#include "../../src/common/util.h"
+#include "../../src/runtime/file_util.h"
+
+namespace tvm {
+namespace runtime {
+
+RPCEnv::RPCEnv() {
+  #if defined(__linux__) || defined(__ANDROID__)
+    base_ = "./rpc";
+    mkdir(&base_[0], 0777);
+
+    TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+        static RPCEnv env;
+        *rv = env.GetPath(args[0]);
+      });
+
+    TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
+    .set_body([](TVMArgs args, TVMRetValue *rv) {
+        static RPCEnv env;
+        std::string file_name = env.GetPath(args[0]);
+        *rv = Load(&file_name, "");
+        LOG(INFO) << "Load module from " << file_name << " ...";
+      });
+  #else
+    LOG(FATAL) << "Only support RPC in linux environment";
+  #endif
+}
+/*!
+ * \brief GetPath To get the workpath from packed function
+ * \param name The file name
+ * \return The full path of file.
+ */
+std::string RPCEnv::GetPath(std::string file_name) {
+  // we assume file_name has "/" means file_name is the exact path
+  // and does not create /.rpc/
+  if (file_name.find("/") != std::string::npos) {
+    return file_name;
+  } else {
+    return base_ + "/" + file_name;
+  }
+}
+/*!
+ * \brief Remove The RPC Environment cleanup function
+ */
+void RPCEnv::CleanUp() {
+  #if defined(__linux__) || defined(__ANDROID__)
+    CleanDir(&base_[0]);
+    int ret = rmdir(&base_[0]);
+    if (ret != 0) {
+      LOG(WARNING) << "Remove directory " << base_ << " failed";
+    }
+  #else
+    LOG(FATAL) << "Only support RPC in linux environment";
+  #endif
+}
+
+/*!
+ * \brief ListDir get the list of files in a directory
+ * \param dirname The root directory name
+ * \return vector Files in directory.
+ */
+std::vector<std::string> ListDir(const std::string &dirname) {
+  std::vector<std::string> vec;
+  #ifndef _MSC_VER
+    DIR *dp = opendir(dirname.c_str());
+    if (dp == nullptr) {
+      int errsv = errno;
+      LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv);
+    }
+    dirent *d;
+    while ((d = readdir(dp)) != nullptr) {
+      std::string filename = d->d_name;
+      if (filename != "." && filename != "..") {
+        std::string f = dirname;
+        if (f[f.length() - 1] != '/') {
+          f += '/';
+        }
+        f += d->d_name;
+        vec.push_back(f);
+      }
+    }
+    closedir(dp);
+  #else
+    WIN32_FIND_DATA fd;
+    std::string pattern = dirname + "/*";
+    HANDLE handle = FindFirstFile(pattern.c_str(), &fd);
+    if (handle == INVALID_HANDLE_VALUE) {
+      int errsv = GetLastError();
+      LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv);
+    }
+    do {
+      if (fd.cFileName != "." && fd.cFileName != "..") {
+        std::string  f = dirname;
+        char clast = f[f.length() - 1];
+        if (f == ".") {
+          f = fd.cFileName;
+        } else if (clast != '/' && clast != '\\') {
+          f += '/';
+          f += fd.cFileName;
+        }
+        vec.push_back(f);
+      }
+    }  while (FindNextFile(handle, &fd));
+    FindClose(handle);
+  #endif
+  return vec;
+}
+
+/*!
+ * \brief LinuxShared Creates a linux shared library
+ * \param output The output file name
+ * \param files The files for building
+ * \param options The compiler options
+ * \param cc The compiler
+ */
+void LinuxShared(const std::string output,
+                 const std::vector<std::string> &files,
+                 std::string options = "",
+                 std::string cc = "g++") {
+    std::string cmd = cc;
+    cmd += " -shared -fPIC ";
+    cmd += " -o " + output;
+    for (auto f = files.begin(); f != files.end(); ++f) {
+     cmd += " " + *f;
+    }
+    cmd += " " + options;
+    std::string err_msg;
+    auto executed_status = common::Execute(cmd, &err_msg);
+    if (executed_status) {
+      LOG(FATAL) << err_msg;
+    }
+}
+
+/*!
+ * \brief CreateShared Creates a shared library
+ * \param output The output file name
+ * \param files The files for building
+ */
+void CreateShared(const std::string output, const std::vector<std::string> &files) {
+  #if defined(__linux__) || defined(__ANDROID__)
+    LinuxShared(output, files);
+  #else
+    LOG(FATAL) << "Do not support creating shared library";
+  #endif
+}
+
+/*!
+ * \brief Load Load module from file
+          This function will automatically call
+          cc.create_shared if the path is in format .o or .tar
+          High level handling for .o and .tar file.
+          We support this to be consistent with RPC module load.
+ * \param fileIn The input file, file name will be updated
+ * \param fmt The format of file
+ * \return Module The loaded module
+ */
+Module Load(std::string *fileIn, const std::string fmt) {
+  std::string file = *fileIn;
+  if (common::EndsWith(file, ".so")) {
+      return Module::LoadFromFile(file, fmt);
+  }
+
+  #if defined(__linux__) || defined(__ANDROID__)
+    std::string file_name = file + ".so";
+    if (common::EndsWith(file, ".o")) {
+      std::vector<std::string> files;
+      files.push_back(file);
+      CreateShared(file_name, files);
+    } else if (common::EndsWith(file, ".tar")) {
+      std::string tmp_dir = "./rpc/tmp/";
+      mkdir(&tmp_dir[0], 0777);
+      std::string cmd = "tar -C " + tmp_dir + " -zxf " + file;
+      std::string err_msg;
+      int executed_status = common::Execute(cmd, &err_msg);
+      if (executed_status) {
+        LOG(FATAL) << err_msg;
+      }
+      CreateShared(file_name, ListDir(tmp_dir));
+      CleanDir(tmp_dir);
+      rmdir(&tmp_dir[0]);
+    } else {
+      file_name = file;
+    }
+    *fileIn = file_name;
+    return Module::LoadFromFile(file_name, fmt);
+  #else
+    LOG(FATAL) << "Do not support creating shared library";
+  #endif
+}
+
+/*!
+ * \brief CleanDir Removes the files from the directory
+ * \param dirname The name of the directory
+ */
+void CleanDir(const std::string &dirname) {
+  #if defined(__linux__) || defined(__ANDROID__)
+    DIR *dp = opendir(dirname.c_str());
+    dirent *d;
+    while ((d = readdir(dp)) != nullptr) {
+      std::string filename = d->d_name;
+      if (filename != "." && filename != "..") {
+        filename = dirname + "/" + d->d_name;
+        int ret = std::remove(&filename[0]);
+        if (ret != 0) {
+          LOG(WARNING) << "Remove file " << filename << " failed";
+        }
+      }
+    }
+  #else
+    LOG(FATAL) << "Only support RPC in linux environment";
+  #endif
+}
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h
new file mode 100644 (file)
index 0000000..82409ba
--- /dev/null
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_env.h
+ * \brief Server environment of the RPC.
+ */
+#ifndef TVM_APPS_CPP_RPC_ENV_H_
+#define TVM_APPS_CPP_RPC_ENV_H_
+
+#include <tvm/runtime/registry.h>
+#include <string>
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Load Load module from file
+          This function will automatically call
+          cc.create_shared if the path is in format .o or .tar
+          High level handling for .o and .tar file.
+          We support this to be consistent with RPC module load.
+ * \param file The input file
+ * \param file The format of file
+ * \return Module The loaded module
+ */
+Module Load(std::string *path, const std::string fmt = "");
+
+/*!
+ * \brief CleanDir Removes the files from the directory
+ * \param dirname THe name of the directory
+ */
+void CleanDir(const std::string &dirname);
+
+/*!
+ * \brief RPCEnv The RPC Environment parameters for c++ rpc server
+ */
+struct RPCEnv {
+ public:
+  /*!
+   * \brief Constructor Init The RPC Environment initialize function
+   */
+  RPCEnv();
+  /*!
+   * \brief GetPath To get the workpath from packed function
+   * \param name The file name
+   * \return The full path of file.
+   */
+  std::string GetPath(std::string file_name);
+  /*!
+   * \brief The RPC Environment cleanup function
+   */
+  void CleanUp();
+
+ private:
+  /*!
+   * \brief Holds the environment path.
+   */
+  std::string base_;
+};  // RPCEnv
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_APPS_CPP_RPC_ENV_H_
diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc
new file mode 100644 (file)
index 0000000..b35a63b
--- /dev/null
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_server.cc
+ * \brief RPC Server implementation.
+ */
+#include <tvm/runtime/registry.h>
+
+#if defined(__linux__) || defined(__ANDROID__)
+#include <sys/select.h>
+#include <sys/wait.h>
+#endif
+#include <set>
+#include <iostream>
+#include <future>
+#include <thread>
+#include <chrono>
+#include <string>
+
+#include "rpc_server.h"
+#include "rpc_env.h"
+#include "rpc_tracker_client.h"
+#include "../../src/runtime/rpc/rpc_session.h"
+#include "../../src/runtime/rpc/rpc_socket_impl.h"
+#include "../../src/common/socket.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief wait the child process end.
+ * \param status status value
+ */
+#if defined(__linux__) || defined(__ANDROID__)
+static pid_t waitPidEintr(int *status) {
+  pid_t pid = 0;
+  while ((pid = waitpid(-1, status, 0)) == -1) {
+    if (errno == EINTR) {
+      continue;
+    } else {
+      perror("waitpid");
+      abort();
+    }
+  }
+  return pid;
+}
+#endif
+
+/*!
+ * \brief RPCServer RPC Server class.
+ * \param host The hostname of the server, Default=0.0.0.0
+ * \param port The port of the RPC, Default=9090
+ * \param port_end The end search port of the RPC, Default=9199
+ * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \param key The key used to identify the device type in tracker. Default=""
+ * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ */
+class RPCServer {
+ public:
+  /*!
+   * \brief Constructor.
+  */
+  RPCServer(const std::string &host,
+            int port,
+            int port_end,
+            const std::string &tracker_addr,
+            const std::string &key,
+            const std::string &custom_addr) {
+    // Init the values
+    host_ = host;
+    port_ = port;
+    port_end_ = port_end;
+    tracker_addr_ = tracker_addr;
+    key_ = key;
+    custom_addr_ = custom_addr;
+  }
+
+  /*!
+   * \brief Destructor.
+  */
+  ~RPCServer() {
+    // Free the resources
+    tracker_sock_.Close();
+    listen_sock_.Close();
+  }
+
+  /*!
+   * \brief Start Creates the RPC listen process and execution.
+  */
+  void Start() {
+    listen_sock_.Create();
+    my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_);
+    LOG(INFO) << "bind to " << host_ << ":" << my_port_;
+    listen_sock_.Listen(1);
+    std::future<void> proc(std::async(std::launch::async, &RPCServer::ListenLoopProc, this));
+    proc.get();
+    // Close the listen socket
+    listen_sock_.Close();
+  }
+
+ private:
+  /*!
+   * \brief ListenLoopProc The listen process.
+   */
+  void ListenLoopProc() {
+    TrackerClient tracker(tracker_addr_, key_, custom_addr_);
+    while (true) {
+      common::TCPSocket conn;
+      common::SockAddr addr("0.0.0.0", 0);
+      std::string opts;
+      try {
+        // step 1: setup tracker and report to tracker
+        tracker.TryConnect();
+        // step 2: wait for in-coming connections
+        AcceptConnection(&tracker, &conn, &addr, &opts);
+      }
+      catch (const char* msg) {
+        LOG(WARNING) << "Socket exception: " << msg;
+        // close tracker resource
+        tracker.Close();
+        continue;
+      }
+      catch (std::exception& e) {
+        // Other errors
+        LOG(WARNING) << "Exception standard: " << e.what();
+        continue;
+      }
+
+      int timeout = GetTimeOutFromOpts(opts);
+      #if defined(__linux__) || defined(__ANDROID__)
+        // step 3: serving
+        if (timeout != 0) {
+          const pid_t timer_pid = fork();
+          if (timer_pid == 0) {
+            // Timer process
+            sleep(timeout);
+            exit(0);
+          }
+
+          const pid_t worker_pid = fork();
+          if (worker_pid == 0) {
+            // Worker process
+            ServerLoopProc(conn, addr);
+            exit(0);
+          }
+
+          int status = 0;
+          const pid_t finished_first = waitPidEintr(&status);
+          if (finished_first == timer_pid) {
+            kill(worker_pid, SIGKILL);
+          } else if (finished_first == worker_pid) {
+            kill(timer_pid, SIGKILL);
+          } else {
+            LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue.";
+          }
+
+          int status_second = 0;
+          waitPidEintr(&status_second);
+
+          // Logging.
+          if (finished_first == timer_pid) {
+            LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout
+                      << "), Process status = " << status_second;
+          } else if (finished_first == worker_pid) {
+            LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second;
+          }
+        } else {
+          auto pid = fork();
+          if (pid == 0) {
+            ServerLoopProc(conn, addr);
+            exit(0);
+          }
+          // Wait for the result
+          int status = 0;
+          wait(&status);
+          LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status;
+        }
+      #else
+        // step 3: serving
+        std::future<void> proc(std::async(std::launch::async,
+                                          &RPCServer::ServerLoopProc, this, conn, addr));
+        // wait until server process finish or timeout
+        if (timeout != 0) {
+          // Autoterminate after timeout
+          proc.wait_for(std::chrono::seconds(timeout));
+        } else {
+          // Wait for the result
+          proc.get();
+        }
+      #endif
+      // close from our side.
+      LOG(INFO) << "Socket Connection Closed";
+      conn.Close();
+    }
+  }
+
+
+  /*!
+   * \brief AcceptConnection Accepts the RPC Server connection.
+   * \param tracker Tracker details.
+   * \param conn New connection information.
+   * \param addr New connection address information.
+   * \param opts Parsed options for socket
+   * \param ping_period Timeout for select call waiting
+   */
+  void AcceptConnection(TrackerClient* tracker,
+                        common::TCPSocket* conn_sock,
+                        common::SockAddr* addr,
+                        std::string* opts,
+                        int ping_period = 2) {
+    std::set <std::string> old_keyset;
+    std::string matchkey;
+
+    // Report resource to tracker and get key
+    tracker->ReportResourceAndGetKey(my_port_, &matchkey);
+
+    while (true) {
+      tracker->WaitConnectionAndUpdateKey(listen_sock_, my_port_, ping_period, &matchkey);
+      common::TCPSocket conn = listen_sock_.Accept(addr);
+
+      int code = kRPCMagic;
+      CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code));
+      if (code != kRPCMagic) {
+        conn.Close();
+        LOG(FATAL) << "Client connected is not TVM RPC server";
+        continue;
+      }
+
+      int keylen = 0;
+      CHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
+
+      const char* CLIENT_HEADER = "client:";
+      const char* SERVER_HEADER = "server:";
+      std::string expect_header = CLIENT_HEADER + matchkey;
+      std::string server_key = SERVER_HEADER + key_;
+      if (size_t(keylen) < expect_header.length()) {
+        conn.Close();
+        LOG(INFO) << "Wrong client header length";
+        continue;
+      }
+
+      CHECK_NE(keylen, 0);
+      std::string remote_key;
+      remote_key.resize(keylen);
+      CHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen);
+
+      std::stringstream ssin(remote_key);
+      std::string arg0;
+      ssin >> arg0;
+      if (arg0 != expect_header) {
+          code = kRPCMismatch;
+          CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
+          conn.Close();
+          LOG(WARNING) << "Mismatch key from" << addr->AsString();
+          continue;
+      } else {
+        code = kRPCSuccess;
+        CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
+        keylen = server_key.length();
+        CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
+        CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen);
+        LOG(INFO) << "Connection success " << addr->AsString();
+        ssin >> *opts;
+        *conn_sock = conn;
+        return;
+      }
+    }
+  }
+
+  /*!
+   * \brief ServerLoopProc The Server loop process.
+   * \param sock The socket information
+   * \param addr The socket address information
+   */
+  void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) {
+      // Server loop
+      auto env = RPCEnv();
+      RPCServerLoop(sock.sockfd);
+      LOG(INFO) << "Finish serving " << addr.AsString();
+      env.CleanUp();
+  }
+
+  /*!
+   * \brief GetTimeOutFromOpts Parse and get the timeout option.
+   * \param opts The option string
+   * \param timeout value after parsing.
+   */
+  int GetTimeOutFromOpts(std::string opts) {
+    std::string cmd;
+    std::string option = "-timeout=";
+
+    if (opts.find(option) == 0) {
+      cmd = opts.substr(opts.find_last_of(option) + 1);
+      CHECK(common::IsNumber(cmd)) << "Timeout is not valid";
+      return std::stoi(cmd);
+    }
+    return 0;
+  }
+
+  std::string host_;
+  int port_;
+  int my_port_;
+  int port_end_;
+  std::string tracker_addr_;
+  std::string key_;
+  std::string custom_addr_;
+  common::TCPSocket listen_sock_;
+  common::TCPSocket tracker_sock_;
+};
+
+/*!
+ * \brief RPCServerCreate Creates the RPC Server.
+ * \param host The hostname of the server, Default=0.0.0.0
+ * \param port The port of the RPC, Default=9090
+ * \param port_end The end search port of the RPC, Default=9199
+ * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \param key The key used to identify the device type in tracker. Default=""
+ * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ * \param silent Whether run in silent mode. Default=True
+ */
+void RPCServerCreate(std::string host,
+                     int port,
+                     int port_end,
+                     std::string tracker_addr,
+                     std::string key,
+                     std::string custom_addr,
+                     bool silent) {
+  if (silent) {
+    // Only errors and fatal is logged
+    dmlc::InitLogging("--minloglevel=2");
+  }
+  // Start the rpc server
+  RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr);
+  rpc.Start();
+}
+
+TVM_REGISTER_GLOBAL("rpc._ServerCreate")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
+  });
+}  // namespace runtime
+}  // namespace tvm
diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h
new file mode 100644 (file)
index 0000000..205182e
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_server.h
+ * \brief RPC Server implementation.
+ */
+#ifndef TVM_APPS_CPP_RPC_SERVER_H_
+#define TVM_APPS_CPP_RPC_SERVER_H_
+
+#include <string>
+#include "tvm/runtime/c_runtime_api.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief RPCServerCreate Creates the RPC Server.
+ * \param host The hostname of the server, Default=0.0.0.0
+ * \param port The port of the RPC, Default=9090
+ * \param port_end The end search port of the RPC, Default=9199
+ * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \param key The key used to identify the device type in tracker. Default=""
+ * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ * \param silent Whether run in silent mode. Default=True
+ */
+TVM_DLL void RPCServerCreate(std::string host = "",
+                             int port = 9090,
+                             int port_end = 9099,
+                             std::string tracker_addr = "",
+                             std::string key = "",
+                             std::string custom_addr = "",
+                             bool silent = true);
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_APPS_CPP_RPC_SERVER_H_
diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h
new file mode 100644 (file)
index 0000000..89424c7
--- /dev/null
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_tracker_client.h
+ * \brief RPC Tracker client to report resources.
+ */
+#ifndef TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
+#define TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
+
+#include <set>
+#include <iostream>
+#include <chrono>
+#include <random>
+#include <vector>
+#include <string>
+
+#include "../../src/runtime/rpc/rpc_session.h"
+#include "../../src/common/socket.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief TrackerClient Tracker client class.
+ * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \param key The key used to identify the device type in tracker. Default=""
+ * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ */
+class TrackerClient {
+ public:
+  /*!
+   * \brief Constructor.
+  */
+  TrackerClient(const std::string& tracker_addr,
+                const std::string& key,
+                const std::string& custom_addr)
+      : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr),
+        gen_(std::random_device{}()), dis_(0.0, 1.0) {
+  }
+  /*!
+   * \brief Destructor.
+  */
+  ~TrackerClient() {
+    // Free the resources
+    Close();
+  }
+  /*!
+   * \brief IsValid Check tracker is valid.
+  */
+  bool IsValid() {
+    return (!tracker_addr_.empty() && !tracker_sock_.IsClosed());
+  }
+  /*!
+   * \brief TryConnect Connect to tracker if the tracker address is valid.
+  */
+  void TryConnect() {
+    if (!tracker_addr_.empty() && (tracker_sock_.IsClosed())) {
+      tracker_sock_ = ConnectWithRetry();
+
+      int code = kRPCTrackerMagic;
+      CHECK_EQ(tracker_sock_.SendAll(&code, sizeof(code)), sizeof(code));
+      CHECK_EQ(tracker_sock_.RecvAll(&code, sizeof(code)), sizeof(code));
+      CHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker";
+
+      std::ostringstream ss;
+      ss << "[" << static_cast<int>(TrackerCode::kUpdateInfo)
+         << ", {\"key\": \"server:"<< key_ << "\"}]";
+      tracker_sock_.SendBytes(ss.str());
+
+      // Receive status and validate
+      std::string remote_status = tracker_sock_.RecvBytes();
+      CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
+    }
+  }
+  /*!
+   * \brief Close Clean up tracker resources.
+  */
+  void Close() {
+    // close tracker resource
+    if (!tracker_sock_.IsClosed()) {
+      tracker_sock_.Close();
+    }
+  }
+ /*!
+  * \brief ReportResourceAndGetKey Report resource to tracker.
+  * \param port listening port.
+  * \param matchkey Random match key output.
+ */
+  void ReportResourceAndGetKey(int port,
+                               std::string *matchkey) {
+    if (!tracker_sock_.IsClosed()) {
+      *matchkey = RandomKey(key_ + ":", old_keyset_);
+      if (custom_addr_.empty()) {
+        custom_addr_ = "null";
+      }
+
+      std::ostringstream ss;
+      ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", ["
+         << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
+
+      tracker_sock_.SendBytes(ss.str());
+
+      // Receive status and validate
+      std::string remote_status = tracker_sock_.RecvBytes();
+      CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
+    } else {
+        *matchkey = key_;
+    }
+  }
+
+  /*!
+   * \brief ReportResourceAndGetKey Report resource to tracker.
+   * \param listen_sock Listen socket details for select.
+   * \param port listening port.
+   * \param ping_period Select wait time.
+   * \param matchkey Random match key output.
+  */
+  void WaitConnectionAndUpdateKey(common::TCPSocket listen_sock,
+                                  int port,
+                                  int ping_period,
+                                  std::string *matchkey) {
+    int unmatch_period_count = 0;
+    int unmatch_timeout = 4;
+    while (true) {
+      if (!tracker_sock_.IsClosed()) {
+        common::PollHelper poller;
+        poller.WatchRead(listen_sock.sockfd);
+        poller.Poll(ping_period * 1000);
+        if (!poller.CheckRead(listen_sock.sockfd)) {
+          std::ostringstream ss;
+          ss << "[" << int(TrackerCode::kGetPendingMatchKeys) << "]";
+          tracker_sock_.SendBytes(ss.str());
+
+          // Receive status and validate
+          std::string pending_keys = tracker_sock_.RecvBytes();
+          old_keyset_.insert(*matchkey);
+
+          // if match key not in pending key set
+          // it means the key is acquired by a client but not used.
+          if (pending_keys.find(*matchkey) == std::string::npos) {
+              unmatch_period_count += 1;
+          } else {
+              unmatch_period_count = 0;
+          }
+          // regenerate match key if key is acquired but not used for a while
+          if (unmatch_period_count * ping_period > unmatch_timeout + ping_period) {
+            LOG(INFO) << "no incoming connections, regenerate key ...";
+
+            *matchkey = RandomKey(key_ + ":", old_keyset_);
+
+            std::ostringstream ss;
+            ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", ["
+               << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
+            tracker_sock_.SendBytes(ss.str());
+
+            std::string remote_status = tracker_sock_.RecvBytes();
+            CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
+            unmatch_period_count = 0;
+          }
+          continue;
+        }
+      }
+      break;
+    }
+  }
+
+ private:
+  /*!
+   * \brief Connect to a RPC address with retry.
+            This function is only reliable to short period of server restart.
+   * \param timeout Timeout during retry
+   * \param retry_period Number of seconds before we retry again.
+   * \return TCPSocket The socket information if connect is success.
+   */
+  common::TCPSocket ConnectWithRetry(int timeout = 60, int retry_period = 5) {
+    auto tbegin = std::chrono::system_clock::now();
+    while (true) {
+      common::SockAddr addr(tracker_addr_);
+      common::TCPSocket sock;
+      sock.Create();
+      LOG(INFO) << "Tracker connecting to " << addr.AsString();
+      if (sock.Connect(addr)) {
+        return sock;
+      }
+
+      auto period = (std::chrono::duration_cast<std::chrono::seconds>(
+                  std::chrono::system_clock::now() - tbegin)).count();
+      CHECK(period < timeout) << "Failed to connect to server" << addr.AsString();
+      LOG(WARNING) << "Cannot connect to tracker " << addr.AsString()
+                   << " retry in " << retry_period << " seconds.";
+      std::this_thread::sleep_for(std::chrono::seconds(retry_period));
+    }
+  }
+  /*!
+  * \brief Random Generate a random number between 0 and 1.
+  * \return random float value.
+  */
+  float Random() {
+    return dis_(gen_);
+  }
+  /*!
+   * \brief Generate a random key.
+   * \param prefix The string prefix.
+   * \return cmap The conflict map set.
+   */
+  std::string RandomKey(const std::string& prefix, const std::set <std::string> &cmap) {
+    if (!cmap.empty()) {
+      while (true) {
+        std::string key = prefix + std::to_string(Random());
+        if (cmap.find(key) == cmap.end()) {
+          return key;
+        }
+      }
+    }
+    return prefix + std::to_string(Random());
+  }
+
+  std::string tracker_addr_;
+  std::string key_;
+  std::string custom_addr_;
+  common::TCPSocket tracker_sock_;
+  std::set <std::string> old_keyset_;
+  std::mt19937 gen_;
+  std::uniform_real_distribution<float> dis_;
+
+};
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
index 39bcff8..616991d 100644 (file)
@@ -43,12 +43,27 @@ using ssize_t = int;
 #include <arpa/inet.h>
 #include <netinet/in.h>
 #include <sys/socket.h>
+#include <sys/select.h>
 #include <sys/ioctl.h>
 #endif
 #include <dmlc/logging.h>
 #include <string>
 #include <cstring>
+#include <vector>
+#include <unordered_map>
+#include "../common/util.h"
 
+#if defined(_WIN32)
+static inline int poll(struct pollfd *pfd, int nfds,
+                       int timeout) {
+  return WSAPoll(pfd, nfds, timeout);
+}
+static inline int inet_pton(int family, const char* addr_str, void* addr_buf) {
+  return InetPton(family, addr_str, addr_buf);
+}
+#else
+#include <sys/poll.h>
+#endif  // defined(_WIN32)
 
 namespace tvm {
 namespace common {
@@ -63,6 +78,22 @@ inline std::string GetHostName() {
 }
 
 /*!
+ * \brief ValidateIP validates an ip address.
+ * \param ip The ip address in string format localhost or x.x.x.x format
+ * \return result of operation.
+ */
+inline bool ValidateIP(std::string ip) {
+  if (ip == "localhost") {
+    return true;
+  }
+  struct sockaddr_in sa_ipv4;
+  struct sockaddr_in6 sa_ipv6;
+  bool is_ipv4 = inet_pton(AF_INET, ip.c_str(), &(sa_ipv4.sin_addr));
+  bool is_ipv6 = inet_pton(AF_INET6, ip.c_str(), &(sa_ipv6.sin6_addr));
+  return is_ipv4 || is_ipv6;
+}
+
+/*!
  * \brief Common data structure for network address.
  */
 struct SockAddr {
@@ -76,6 +107,23 @@ struct SockAddr {
   SockAddr(const char *url, int port) {
     this->Set(url, port);
   }
+
+  /*!
+  * \brief SockAddr Get the socket address from tracker.
+  * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090)
+  * \return SockAddr parsed from url.
+  */
+  explicit SockAddr(const std::string &url) {
+    size_t sep = url.find(",");
+    std::string host = url.substr(2, sep - 3);
+    std::string port = url.substr(sep + 1, url.length() - 1);
+    CHECK(ValidateIP(host)) << "Url address is not valid " << url;
+    if (host == "localhost") {
+      host = "127.0.0.1";
+    }
+    this->Set(host.c_str(), std::stoi(port));
+  }
+
   /*!
    * \brief set the address
    * \param host the url of the address
@@ -203,17 +251,20 @@ class Socket {
   }
   /*!
    * \brief try bind the socket to host, from start_port to end_port
+   * \param host host address to bind the socket
    * \param start_port starting port number to try
    * \param end_port ending port number to try
    * \return the port successfully bind to, return -1 if failed to bind any port
    */
-  inline int TryBindHost(int start_port, int end_port) {
+  inline int TryBindHost(std::string host, int start_port, int end_port) {
     for (int port = start_port; port < end_port; ++port) {
-      SockAddr addr("0.0.0.0", port);
+      SockAddr addr(host.c_str(), port);
       if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
                (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
                                                   sizeof(sockaddr_in))) == 0) {
         return port;
+      } else {
+        LOG(WARNING) << "Bind failed to " << host << ":" << port;
       }
 #if defined(_WIN32)
       if (WSAGetLastError() != WSAEADDRINUSE) {
@@ -374,6 +425,20 @@ class TCPSocket : public Socket {
     return TCPSocket(newfd);
   }
   /*!
+  * \brief get a new connection
+  * \param addr client address from which connection accepted
+  * \return The accepted socket connection.
+  */
+  TCPSocket Accept(SockAddr *addr) {
+    socklen_t addrlen = sizeof(addr->addr);
+    SockType newfd = accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr),
+                            &addrlen);
+    if (newfd == INVALID_SOCKET) {
+      Socket::Error("Accept");
+    }
+    return TCPSocket(newfd);
+  }
+  /*!
    * \brief decide whether the socket is at OOB mark
    * \return 1 if at mark, 0 if not, -1 if an error occurred
    */
@@ -468,7 +533,125 @@ class TCPSocket : public Socket {
     }
     return ndone;
   }
+  /*!
+   * \brief Send the data to remote.
+   * \param data The data to be sent.
+   */
+  void SendBytes(std::string data) {
+    int datalen = data.length();
+    CHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen));
+    CHECK_EQ(SendAll(data.c_str(), datalen), datalen);
+  }
+  /*!
+   * \brief Receive the data to remote.
+   * \return The data received.
+   */
+  std::string RecvBytes() {
+    int datalen = 0;
+    CHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen));
+    std::string data;
+    data.resize(datalen);
+    CHECK_EQ(RecvAll(&data[0], datalen), datalen);
+    return data;
+  }
 };
+
+/*! \brief helper data structure to perform poll */
+struct PollHelper {
+ public:
+  /*!
+   * \brief add file descriptor to watch for read
+   * \param fd file descriptor to be watched
+   */
+  inline void WatchRead(TCPSocket::SockType fd) {
+    auto& pfd = fds[fd];
+    pfd.fd = fd;
+    pfd.events |= POLLIN;
+  }
+  /*!
+   * \brief add file descriptor to watch for write
+   * \param fd file descriptor to be watched
+   */
+  inline void WatchWrite(TCPSocket::SockType fd) {
+    auto& pfd = fds[fd];
+    pfd.fd = fd;
+    pfd.events |= POLLOUT;
+  }
+  /*!
+   * \brief add file descriptor to watch for exception
+   * \param fd file descriptor to be watched
+   */
+  inline void WatchException(TCPSocket::SockType fd) {
+    auto& pfd = fds[fd];
+    pfd.fd = fd;
+    pfd.events |= POLLPRI;
+  }
+  /*!
+   * \brief Check if the descriptor is ready for read
+   * \param fd file descriptor to check status
+   */
+  inline bool CheckRead(TCPSocket::SockType fd) const {
+    const auto& pfd = fds.find(fd);
+    return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
+  }
+  /*!
+   * \brief Check if the descriptor is ready for write
+   * \param fd file descriptor to check status
+   */
+  inline bool CheckWrite(TCPSocket::SockType fd) const {
+    const auto& pfd = fds.find(fd);
+    return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
+  }
+  /*!
+   * \brief Check if the descriptor has any exception
+   * \param fd file descriptor to check status
+   */
+  inline bool CheckExcept(TCPSocket::SockType fd) const {
+    const auto& pfd = fds.find(fd);
+    return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
+  }
+  /*!
+   * \brief wait for exception event on a single descriptor
+   * \param fd the file descriptor to wait the event for
+   * \param timeout the timeout counter, can be negative, which means wait until the event happen
+   * \return 1 if success, 0 if timeout, and -1 if error occurs
+   */
+  inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*)
+    pollfd pfd;
+    pfd.fd = fd;
+    pfd.events = POLLPRI;
+    return poll(&pfd, 1, timeout);
+  }
+
+  /*!
+   * \brief peform poll on the set defined, read, write, exception
+   * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
+   * \return
+   */
+  inline void Poll(long timeout = -1) {  // NOLINT(*)
+    std::vector<pollfd> fdset;
+    fdset.reserve(fds.size());
+    for (auto kv : fds) {
+      fdset.push_back(kv.second);
+    }
+    int ret = poll(fdset.data(), fdset.size(), timeout);
+    if (ret == -1) {
+      Socket::Error("Poll");
+    } else {
+      for (auto& pfd : fdset) {
+        auto revents = pfd.revents & pfd.events;
+        if (!revents) {
+          fds.erase(pfd.fd);
+        } else {
+          fds[pfd.fd].events = revents;
+        }
+      }
+    }
+  }
+
+  std::unordered_map<TCPSocket::SockType, pollfd> fds;
+};
+
 }  // namespace common
 }  // namespace tvm
 #endif  // TVM_COMMON_SOCKET_H_
diff --git a/src/common/util.h b/src/common/util.h
new file mode 100644 (file)
index 0000000..93f32f4
--- /dev/null
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file util.h
+ * \brief Defines some common utility function..
+ */
+#ifndef TVM_COMMON_UTIL_H_
+#define TVM_COMMON_UTIL_H_
+
+#include <stdio.h>
+#ifndef _WIN32
+#include <sys/wait.h>
+#include <sys/types.h>
+#endif
+#include <vector>
+#include <string>
+#include <sstream>
+#include <algorithm>
+#include <array>
+#include <memory>
+
+namespace tvm {
+namespace common {
+/*!
+ * \brief TVMPOpen wrapper of popen between windows / unix.
+ * \param command executed command
+ * \param type "r" is for reading or "w" for writing.
+ * \return normal standard stream
+ */
+inline FILE* TVMPOpen(const char* command, const char* type) {
+#if defined(_WIN32)
+  return _popen(command, type);
+#else
+  return popen(command, type);
+#endif
+}
+
+/*!
+ * \brief TVMPClose wrapper of pclose between windows / linux
+ * \param stream the stream needed to be close.
+ * \return exit status
+ */
+inline int TVMPClose(FILE* stream) {
+#if defined(_WIN32)
+  return _pclose(stream);
+#else
+  return pclose(stream);
+#endif
+}
+
+/*!
+ * \brief TVMWifexited wrapper of WIFEXITED between windows / linux
+ * \param status The status field that was filled in by the wait or waitpid function
+ * \return the exit code of the child process
+ */
+inline int TVMWifexited(int status) {
+#if defined(_WIN32)
+  return (status != 3);
+#else
+  return WIFEXITED(status);
+#endif
+}
+
+/*!
+ * \brief TVMWexitstatus wrapper of WEXITSTATUS between windows / linux
+ * \param status The status field that was filled in by the wait or waitpid function.
+ * \return the child process exited normally or not
+ */
+inline int TVMWexitstatus(int status) {
+#if defined(_WIN32)
+  return status;
+#else
+  return WEXITSTATUS(status);
+#endif
+}
+
+
+/*!
+ * \brief IsNumber check whether string is a number.
+ * \param str input string
+ * \return result of operation.
+ */
+inline bool IsNumber(const std::string& str) {
+  return !str.empty() && std::find_if(str.begin(),
+      str.end(), [](char c) { return !std::isdigit(c); }) == str.end();
+}
+
+/*!
+ * \brief split Split the string based on delimiter
+ * \param str Input string
+ * \param delim The delimiter.
+ * \return vector of strings which are splitted.
+ */
+inline std::vector<std::string> Split(const std::string& str, char delim) {
+  std::string item;
+  std::istringstream is(str);
+  std::vector<std::string> ret;
+  while (std::getline(is, item, delim)) {
+    ret.push_back(item);
+  }
+  return ret;
+}
+
+/*!
+ * \brief EndsWith check whether the strings ends with
+ * \param value The full string
+ * \param end The end substring
+ * \return bool The result.
+ */
+inline bool EndsWith(std::string const& value, std::string const& end) {
+  if (end.size() <= value.size()) {
+    return std::equal(end.rbegin(), end.rend(), value.rbegin());
+  }
+  return false;
+}
+
+/*!
+ * \brief Execute the command
+ * \param cmd The command we want to execute
+ * \param err_msg The error message if we have
+ * \return executed output status
+ */
+inline int Execute(std::string cmd, std::string* err_msg) {
+  std::array<char, 128> buffer;
+  std::string result;
+  cmd += " 2>&1";
+  FILE* fd = TVMPOpen(cmd.c_str(), "r");
+  while (fgets(buffer.data(), buffer.size(), fd) != nullptr) {
+    *err_msg += buffer.data();
+  }
+  int status = TVMPClose(fd);
+  if (TVMWifexited(status)) {
+    return TVMWexitstatus(status);
+  }
+  return 255;
+}
+
+}  // namespace common
+}  // namespace tvm
+#endif  // TVM_COMMON_UTIL_H_
index d982f68..3518455 100644 (file)
 namespace tvm {
 namespace runtime {
 
+// Magic header for RPC data plane
 const int kRPCMagic = 0xff271;
+// magic header for RPC tracker(control plane)
+const int kRPCTrackerMagic = 0x2f271;
+// sucess response
+const int kRPCSuccess = kRPCMagic + 0;
+// cannot found matched key in server
+const int kRPCMismatch = kRPCMagic + 2;
 
+/*! \brief Enumeration code for the RPC tracker */
+enum class TrackerCode : int {
+    kFail = -1,
+    kSuccess = 0,
+    kPing = 1,
+    kStop = 2,
+    kPut = 3,
+    kRequest = 4,
+    kUpdateInfo = 5,
+    kSummary = 6,
+    kGetPendingMatchKeys = 7
+};
 /*! \brief The remote functio handle */
 using RPCFuncHandle = void*;
 
diff --git a/src/runtime/rpc/rpc_socket_impl.h b/src/runtime/rpc/rpc_socket_impl.h
new file mode 100644 (file)
index 0000000..ea7c839
--- /dev/null
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file rpc_socket_impl.h
+ * \brief Socket based RPC implementation.
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_
+#define TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief RPCServerLoop Start the rpc server loop.
+ * \param sockfd Socket file descriptor
+ */
+void RPCServerLoop(int sockfd);
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_