--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Makefile to compile RPC Server.
+TVM_ROOT=$(shell cd ../..; pwd)
+DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core
+TVM_RUNTIME_DIR?=
+OS?=
+
+# Android can not link pthrad, but Linux need.
+ifeq ($(OS), Linux)
+LINK_PTHREAD=-lpthread
+else
+LINK_PTHREAD=
+endif
+
+PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\
+ -I${TVM_ROOT}/include\
+ -I${DMLC_CORE}/include\
+ -I${TVM_ROOT}/3rdparty/dlpack/include
+
+PKG_LDFLAGS = -L$(TVM_RUNTIME_DIR) $(LINK_PTHREAD) -ltvm_runtime -ldl -Wl,-R$(TVM_RUNTIME_DIR)
+
+ifeq ($(USE_GLOG), 1)
+ PKG_CFLAGS += -DDMLC_USE_GLOG=1
+ PKG_LDFLAGS += -lglog
+endif
+
+.PHONY: clean all
+
+all: tvm_rpc
+
+# Build rule for all in one TVM package library
+tvm_rpc: *.cc
+ @mkdir -p $(@D)
+ $(CXX) $(PKG_CFLAGS) -o $@ $(filter %.cc %.o %.a, $^) $(PKG_LDFLAGS)
+
+clean:
+ -rm -f tvm_rpc
\ No newline at end of file
--- /dev/null
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# TVM RPC Server
+This folder contains a simple recipe to make RPC server in c++.
+
+## Usage
+- Build tvm runtime
+- Make the rpc executable [Makefile](Makefile).
+ `make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux`
+ if you want to compile it for embedded Linux, you should add `OS=Linux`.
+ if the target os is Android, you doesn't need to pass OS argument.
+ You could cross compile the TVM runtime like this:
+```
+ cd tvm
+ mkdir arm_runtime
+ cp cmake/config.cmake arm_runtime
+ cd arm_runtime
+ cmake .. -DCMAKE_CXX_COMPILER="/path/to/cross compiler g++/"
+ make runtime
+```
+- Use `./tvm_rpc server` to start the RPC server
+
+## How it works
+- The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library.
+
+```
+Command line usage
+ server - Start the server
+--host - The hostname of the server, Default=0.0.0.0
+--port - The port of the RPC, Default=9090
+--port-end - The end search port of the RPC, Default=9199
+--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=""
+--key - The key used to identify the device type in tracker. Default=""
+--custom-addr - Custom IP Address to Report to RPC Tracker. Default=""
+--silent - Whether to run in silent mode. Default=False
+ Example
+ ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 --tracker=127.0.0.1:9190 --key=rasp
+```
+
+## Note
+Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently.
\ No newline at end of file
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_server.cc
+ * \brief RPC Server for TVM.
+ */
+#include <stdlib.h>
+#include <signal.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <dmlc/logging.h>
+#include <iostream>
+#include <cstring>
+#include <vector>
+#include <sstream>
+
+#include "../../src/common/util.h"
+#include "../../src/common/socket.h"
+#include "rpc_server.h"
+
+using namespace std;
+using namespace tvm::runtime;
+using namespace tvm::common;
+
+static const string kUSAGE = \
+"Command line usage\n" \
+" server - Start the server\n" \
+"--host - The hostname of the server, Default=0.0.0.0\n" \
+"--port - The port of the RPC, Default=9090\n" \
+"--port-end - The end search port of the RPC, Default=9199\n" \
+"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \
+"--key - The key used to identify the device type in tracker. Default=\"\"\n" \
+"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \
+"--silent - Whether to run in silent mode. Default=False\n" \
+"\n" \
+" Example\n" \
+" ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 "
+" --tracker=127.0.0.1:9190 --key=rasp" \
+"\n";
+
+/*!
+ * \brief RpcServerArgs.
+ * \arg host The hostname of the server, Default=0.0.0.0
+ * \arg port The port of the RPC, Default=9090
+ * \arg port_end The end search port of the RPC, Default=9199
+ * \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \arg key The key used to identify the device type in tracker. Default=""
+ * \arg custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ * \arg silent Whether run in silent mode. Default=False
+ */
+struct RpcServerArgs {
+ string host = "0.0.0.0";
+ int port = 9090;
+ int port_end = 9099;
+ string tracker;
+ string key;
+ string custom_addr;
+ bool silent = false;
+};
+
+/*!
+ * \brief PrintArgs print the contents of RpcServerArgs
+ * \param args RpcServerArgs structure
+ */
+void PrintArgs(struct RpcServerArgs args) {
+ LOG(INFO) << "host = " << args.host;
+ LOG(INFO) << "port = " << args.port;
+ LOG(INFO) << "port_end = " << args.port_end;
+ LOG(INFO) << "tracker = " << args.tracker;
+ LOG(INFO) << "key = " << args.key;
+ LOG(INFO) << "custom_addr = " << args.custom_addr;
+ LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False"));
+}
+
+/*!
+ * \brief CtrlCHandler, exits if Ctrl+C is pressed
+ * \param s signal
+ */
+void CtrlCHandler(int s) {
+ LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
+ exit(1);
+}
+
+/*!
+ * \brief HandleCtrlC Register for handling Ctrl+C event.
+ */
+void HandleCtrlC() {
+ // Ctrl+C handler
+ struct sigaction sigIntHandler;
+ sigIntHandler.sa_handler = CtrlCHandler;
+ sigemptyset(&sigIntHandler.sa_mask);
+ sigIntHandler.sa_flags = 0;
+ sigaction(SIGINT, &sigIntHandler, nullptr);
+}
+
+/*!
+ * \brief GetCmdOption Parse and find the command option.
+ * \param argc arg counter
+ * \param argv arg values
+ * \param option command line option to search for.
+ * \param key whether the option itself is key
+ * \return value corresponding to option.
+ */
+string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
+ string cmd;
+ for (int i = 1; i < argc; ++i) {
+ string arg = argv[i];
+ if (arg.find(option) == 0) {
+ if (key) {
+ cmd = argv[i];
+ return cmd;
+ }
+ // We assume "=" is the end of option.
+ CHECK_EQ(*option.rbegin(), '=');
+ cmd = arg.substr(arg.find("=") + 1);
+ return cmd;
+ }
+ }
+ return cmd;
+}
+
+/*!
+ * \brief ValidateTracker Check the tracker address format is correct and changes the format.
+ * \param tracker The tracker input.
+ * \return result of operation.
+ */
+bool ValidateTracker(string &tracker) {
+ vector<string> list = Split(tracker, ':');
+ if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) {
+ return false;
+ }
+ ostringstream ss;
+ ss << "('" << list[0] << "', " << list[1] << ")";
+ tracker = ss.str();
+ return true;
+}
+
+/*!
+ * \brief ParseCmdArgs parses the command line arguments.
+ * \param argc arg counter
+ * \param argv arg values
+ * \param args, the output structure which holds the parsed values
+ */
+void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
+ string silent = GetCmdOption(argc, argv, "--silent", true);
+ if (!silent.empty()) {
+ args.silent = true;
+ // Only errors and fatal is logged
+ dmlc::InitLogging("--minloglevel=2");
+ }
+
+ string host = GetCmdOption(argc, argv, "--host=");
+ if (!host.empty()) {
+ if (!ValidateIP(host)) {
+ LOG(WARNING) << "Wrong host address format.";
+ LOG(INFO) << kUSAGE;
+ exit(1);
+ }
+ args.host = host;
+ }
+
+ string port = GetCmdOption(argc, argv, "--port=");
+ if (!port.empty()) {
+ if (!IsNumber(port) || stoi(port) > 65535) {
+ LOG(WARNING) << "Wrong port number.";
+ LOG(INFO) << kUSAGE;
+ exit(1);
+ }
+ args.port = stoi(port);
+ }
+
+ string port_end = GetCmdOption(argc, argv, "--port_end=");
+ if (!port_end.empty()) {
+ if (!IsNumber(port_end) || stoi(port_end) > 65535) {
+ LOG(WARNING) << "Wrong port_end number.";
+ LOG(INFO) << kUSAGE;
+ exit(1);
+ }
+ args.port_end = stoi(port_end);
+ }
+
+ string tracker = GetCmdOption(argc, argv, "--tracker=");
+ if (!tracker.empty()) {
+ if (!ValidateTracker(tracker)) {
+ LOG(WARNING) << "Wrong tracker address format.";
+ LOG(INFO) << kUSAGE;
+ exit(1);
+ }
+ args.tracker = tracker;
+ }
+
+ string key = GetCmdOption(argc, argv, "--key=");
+ if (!key.empty()) {
+ args.key = key;
+ }
+
+ string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
+ if (!custom_addr.empty()) {
+ if (!ValidateIP(custom_addr)) {
+ LOG(WARNING) << "Wrong custom address format.";
+ LOG(INFO) << kUSAGE;
+ exit(1);
+ }
+ args.custom_addr = custom_addr;
+ }
+}
+
+/*!
+ * \brief RpcServer Starts the RPC server.
+ * \param argc arg counter
+ * \param argv arg values
+ * \return result of operation.
+ */
+int RpcServer(int argc, char * argv[]) {
+ struct RpcServerArgs args;
+
+ /* parse the command line args */
+ ParseCmdArgs(argc, argv, args);
+ PrintArgs(args);
+
+ // Ctrl+C handler
+ LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
+ HandleCtrlC();
+ tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
+ args.key, args.custom_addr, args.silent);
+ return 0;
+}
+
+/*!
+ * \brief main The main function.
+ * \param argc arg counter
+ * \param argv arg values
+ * \return result of operation.
+ */
+int main(int argc, char * argv[]) {
+ if (argc <= 1) {
+ LOG(INFO) << kUSAGE;
+ return 0;
+ }
+
+ if (0 == strcmp(argv[1], "server")) {
+ RpcServer(argc, argv);
+ } else {
+ LOG(INFO) << kUSAGE;
+ }
+
+ return 0;
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file rpc_env.cc
+ * \brief Server environment of the RPC.
+ */
+#include <tvm/runtime/registry.h>
+#include <errno.h>
+#ifndef _MSC_VER
+#include <sys/stat.h>
+#include <dirent.h>
+#include <unistd.h>
+#else
+#include <Windows.h>
+#endif
+#include <fstream>
+#include <vector>
+#include <iostream>
+#include <string>
+#include <cstring>
+
+#include "rpc_env.h"
+#include "../../src/common/util.h"
+#include "../../src/runtime/file_util.h"
+
+namespace tvm {
+namespace runtime {
+
+RPCEnv::RPCEnv() {
+ #if defined(__linux__) || defined(__ANDROID__)
+ base_ = "./rpc";
+ mkdir(&base_[0], 0777);
+
+ TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ static RPCEnv env;
+ *rv = env.GetPath(args[0]);
+ });
+
+ TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
+ .set_body([](TVMArgs args, TVMRetValue *rv) {
+ static RPCEnv env;
+ std::string file_name = env.GetPath(args[0]);
+ *rv = Load(&file_name, "");
+ LOG(INFO) << "Load module from " << file_name << " ...";
+ });
+ #else
+ LOG(FATAL) << "Only support RPC in linux environment";
+ #endif
+}
+/*!
+ * \brief GetPath To get the workpath from packed function
+ * \param name The file name
+ * \return The full path of file.
+ */
+std::string RPCEnv::GetPath(std::string file_name) {
+ // we assume file_name has "/" means file_name is the exact path
+ // and does not create /.rpc/
+ if (file_name.find("/") != std::string::npos) {
+ return file_name;
+ } else {
+ return base_ + "/" + file_name;
+ }
+}
+/*!
+ * \brief Remove The RPC Environment cleanup function
+ */
+void RPCEnv::CleanUp() {
+ #if defined(__linux__) || defined(__ANDROID__)
+ CleanDir(&base_[0]);
+ int ret = rmdir(&base_[0]);
+ if (ret != 0) {
+ LOG(WARNING) << "Remove directory " << base_ << " failed";
+ }
+ #else
+ LOG(FATAL) << "Only support RPC in linux environment";
+ #endif
+}
+
+/*!
+ * \brief ListDir get the list of files in a directory
+ * \param dirname The root directory name
+ * \return vector Files in directory.
+ */
+std::vector<std::string> ListDir(const std::string &dirname) {
+ std::vector<std::string> vec;
+ #ifndef _MSC_VER
+ DIR *dp = opendir(dirname.c_str());
+ if (dp == nullptr) {
+ int errsv = errno;
+ LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv);
+ }
+ dirent *d;
+ while ((d = readdir(dp)) != nullptr) {
+ std::string filename = d->d_name;
+ if (filename != "." && filename != "..") {
+ std::string f = dirname;
+ if (f[f.length() - 1] != '/') {
+ f += '/';
+ }
+ f += d->d_name;
+ vec.push_back(f);
+ }
+ }
+ closedir(dp);
+ #else
+ WIN32_FIND_DATA fd;
+ std::string pattern = dirname + "/*";
+ HANDLE handle = FindFirstFile(pattern.c_str(), &fd);
+ if (handle == INVALID_HANDLE_VALUE) {
+ int errsv = GetLastError();
+ LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv);
+ }
+ do {
+ if (fd.cFileName != "." && fd.cFileName != "..") {
+ std::string f = dirname;
+ char clast = f[f.length() - 1];
+ if (f == ".") {
+ f = fd.cFileName;
+ } else if (clast != '/' && clast != '\\') {
+ f += '/';
+ f += fd.cFileName;
+ }
+ vec.push_back(f);
+ }
+ } while (FindNextFile(handle, &fd));
+ FindClose(handle);
+ #endif
+ return vec;
+}
+
+/*!
+ * \brief LinuxShared Creates a linux shared library
+ * \param output The output file name
+ * \param files The files for building
+ * \param options The compiler options
+ * \param cc The compiler
+ */
+void LinuxShared(const std::string output,
+ const std::vector<std::string> &files,
+ std::string options = "",
+ std::string cc = "g++") {
+ std::string cmd = cc;
+ cmd += " -shared -fPIC ";
+ cmd += " -o " + output;
+ for (auto f = files.begin(); f != files.end(); ++f) {
+ cmd += " " + *f;
+ }
+ cmd += " " + options;
+ std::string err_msg;
+ auto executed_status = common::Execute(cmd, &err_msg);
+ if (executed_status) {
+ LOG(FATAL) << err_msg;
+ }
+}
+
+/*!
+ * \brief CreateShared Creates a shared library
+ * \param output The output file name
+ * \param files The files for building
+ */
+void CreateShared(const std::string output, const std::vector<std::string> &files) {
+ #if defined(__linux__) || defined(__ANDROID__)
+ LinuxShared(output, files);
+ #else
+ LOG(FATAL) << "Do not support creating shared library";
+ #endif
+}
+
+/*!
+ * \brief Load Load module from file
+ This function will automatically call
+ cc.create_shared if the path is in format .o or .tar
+ High level handling for .o and .tar file.
+ We support this to be consistent with RPC module load.
+ * \param fileIn The input file, file name will be updated
+ * \param fmt The format of file
+ * \return Module The loaded module
+ */
+Module Load(std::string *fileIn, const std::string fmt) {
+ std::string file = *fileIn;
+ if (common::EndsWith(file, ".so")) {
+ return Module::LoadFromFile(file, fmt);
+ }
+
+ #if defined(__linux__) || defined(__ANDROID__)
+ std::string file_name = file + ".so";
+ if (common::EndsWith(file, ".o")) {
+ std::vector<std::string> files;
+ files.push_back(file);
+ CreateShared(file_name, files);
+ } else if (common::EndsWith(file, ".tar")) {
+ std::string tmp_dir = "./rpc/tmp/";
+ mkdir(&tmp_dir[0], 0777);
+ std::string cmd = "tar -C " + tmp_dir + " -zxf " + file;
+ std::string err_msg;
+ int executed_status = common::Execute(cmd, &err_msg);
+ if (executed_status) {
+ LOG(FATAL) << err_msg;
+ }
+ CreateShared(file_name, ListDir(tmp_dir));
+ CleanDir(tmp_dir);
+ rmdir(&tmp_dir[0]);
+ } else {
+ file_name = file;
+ }
+ *fileIn = file_name;
+ return Module::LoadFromFile(file_name, fmt);
+ #else
+ LOG(FATAL) << "Do not support creating shared library";
+ #endif
+}
+
+/*!
+ * \brief CleanDir Removes the files from the directory
+ * \param dirname The name of the directory
+ */
+void CleanDir(const std::string &dirname) {
+ #if defined(__linux__) || defined(__ANDROID__)
+ DIR *dp = opendir(dirname.c_str());
+ dirent *d;
+ while ((d = readdir(dp)) != nullptr) {
+ std::string filename = d->d_name;
+ if (filename != "." && filename != "..") {
+ filename = dirname + "/" + d->d_name;
+ int ret = std::remove(&filename[0]);
+ if (ret != 0) {
+ LOG(WARNING) << "Remove file " << filename << " failed";
+ }
+ }
+ }
+ #else
+ LOG(FATAL) << "Only support RPC in linux environment";
+ #endif
+}
+
+} // namespace runtime
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_env.h
+ * \brief Server environment of the RPC.
+ */
+#ifndef TVM_APPS_CPP_RPC_ENV_H_
+#define TVM_APPS_CPP_RPC_ENV_H_
+
+#include <tvm/runtime/registry.h>
+#include <string>
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Load Load module from file
+ This function will automatically call
+ cc.create_shared if the path is in format .o or .tar
+ High level handling for .o and .tar file.
+ We support this to be consistent with RPC module load.
+ * \param file The input file
+ * \param file The format of file
+ * \return Module The loaded module
+ */
+Module Load(std::string *path, const std::string fmt = "");
+
+/*!
+ * \brief CleanDir Removes the files from the directory
+ * \param dirname THe name of the directory
+ */
+void CleanDir(const std::string &dirname);
+
+/*!
+ * \brief RPCEnv The RPC Environment parameters for c++ rpc server
+ */
+struct RPCEnv {
+ public:
+ /*!
+ * \brief Constructor Init The RPC Environment initialize function
+ */
+ RPCEnv();
+ /*!
+ * \brief GetPath To get the workpath from packed function
+ * \param name The file name
+ * \return The full path of file.
+ */
+ std::string GetPath(std::string file_name);
+ /*!
+ * \brief The RPC Environment cleanup function
+ */
+ void CleanUp();
+
+ private:
+ /*!
+ * \brief Holds the environment path.
+ */
+ std::string base_;
+}; // RPCEnv
+
+} // namespace runtime
+} // namespace tvm
+#endif // TVM_APPS_CPP_RPC_ENV_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_server.cc
+ * \brief RPC Server implementation.
+ */
+#include <tvm/runtime/registry.h>
+
+#if defined(__linux__) || defined(__ANDROID__)
+#include <sys/select.h>
+#include <sys/wait.h>
+#endif
+#include <set>
+#include <iostream>
+#include <future>
+#include <thread>
+#include <chrono>
+#include <string>
+
+#include "rpc_server.h"
+#include "rpc_env.h"
+#include "rpc_tracker_client.h"
+#include "../../src/runtime/rpc/rpc_session.h"
+#include "../../src/runtime/rpc/rpc_socket_impl.h"
+#include "../../src/common/socket.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief wait the child process end.
+ * \param status status value
+ */
+#if defined(__linux__) || defined(__ANDROID__)
+static pid_t waitPidEintr(int *status) {
+ pid_t pid = 0;
+ while ((pid = waitpid(-1, status, 0)) == -1) {
+ if (errno == EINTR) {
+ continue;
+ } else {
+ perror("waitpid");
+ abort();
+ }
+ }
+ return pid;
+}
+#endif
+
+/*!
+ * \brief RPCServer RPC Server class.
+ * \param host The hostname of the server, Default=0.0.0.0
+ * \param port The port of the RPC, Default=9090
+ * \param port_end The end search port of the RPC, Default=9199
+ * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \param key The key used to identify the device type in tracker. Default=""
+ * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ */
+class RPCServer {
+ public:
+ /*!
+ * \brief Constructor.
+ */
+ RPCServer(const std::string &host,
+ int port,
+ int port_end,
+ const std::string &tracker_addr,
+ const std::string &key,
+ const std::string &custom_addr) {
+ // Init the values
+ host_ = host;
+ port_ = port;
+ port_end_ = port_end;
+ tracker_addr_ = tracker_addr;
+ key_ = key;
+ custom_addr_ = custom_addr;
+ }
+
+ /*!
+ * \brief Destructor.
+ */
+ ~RPCServer() {
+ // Free the resources
+ tracker_sock_.Close();
+ listen_sock_.Close();
+ }
+
+ /*!
+ * \brief Start Creates the RPC listen process and execution.
+ */
+ void Start() {
+ listen_sock_.Create();
+ my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_);
+ LOG(INFO) << "bind to " << host_ << ":" << my_port_;
+ listen_sock_.Listen(1);
+ std::future<void> proc(std::async(std::launch::async, &RPCServer::ListenLoopProc, this));
+ proc.get();
+ // Close the listen socket
+ listen_sock_.Close();
+ }
+
+ private:
+ /*!
+ * \brief ListenLoopProc The listen process.
+ */
+ void ListenLoopProc() {
+ TrackerClient tracker(tracker_addr_, key_, custom_addr_);
+ while (true) {
+ common::TCPSocket conn;
+ common::SockAddr addr("0.0.0.0", 0);
+ std::string opts;
+ try {
+ // step 1: setup tracker and report to tracker
+ tracker.TryConnect();
+ // step 2: wait for in-coming connections
+ AcceptConnection(&tracker, &conn, &addr, &opts);
+ }
+ catch (const char* msg) {
+ LOG(WARNING) << "Socket exception: " << msg;
+ // close tracker resource
+ tracker.Close();
+ continue;
+ }
+ catch (std::exception& e) {
+ // Other errors
+ LOG(WARNING) << "Exception standard: " << e.what();
+ continue;
+ }
+
+ int timeout = GetTimeOutFromOpts(opts);
+ #if defined(__linux__) || defined(__ANDROID__)
+ // step 3: serving
+ if (timeout != 0) {
+ const pid_t timer_pid = fork();
+ if (timer_pid == 0) {
+ // Timer process
+ sleep(timeout);
+ exit(0);
+ }
+
+ const pid_t worker_pid = fork();
+ if (worker_pid == 0) {
+ // Worker process
+ ServerLoopProc(conn, addr);
+ exit(0);
+ }
+
+ int status = 0;
+ const pid_t finished_first = waitPidEintr(&status);
+ if (finished_first == timer_pid) {
+ kill(worker_pid, SIGKILL);
+ } else if (finished_first == worker_pid) {
+ kill(timer_pid, SIGKILL);
+ } else {
+ LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue.";
+ }
+
+ int status_second = 0;
+ waitPidEintr(&status_second);
+
+ // Logging.
+ if (finished_first == timer_pid) {
+ LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout
+ << "), Process status = " << status_second;
+ } else if (finished_first == worker_pid) {
+ LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second;
+ }
+ } else {
+ auto pid = fork();
+ if (pid == 0) {
+ ServerLoopProc(conn, addr);
+ exit(0);
+ }
+ // Wait for the result
+ int status = 0;
+ wait(&status);
+ LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status;
+ }
+ #else
+ // step 3: serving
+ std::future<void> proc(std::async(std::launch::async,
+ &RPCServer::ServerLoopProc, this, conn, addr));
+ // wait until server process finish or timeout
+ if (timeout != 0) {
+ // Autoterminate after timeout
+ proc.wait_for(std::chrono::seconds(timeout));
+ } else {
+ // Wait for the result
+ proc.get();
+ }
+ #endif
+ // close from our side.
+ LOG(INFO) << "Socket Connection Closed";
+ conn.Close();
+ }
+ }
+
+
+ /*!
+ * \brief AcceptConnection Accepts the RPC Server connection.
+ * \param tracker Tracker details.
+ * \param conn New connection information.
+ * \param addr New connection address information.
+ * \param opts Parsed options for socket
+ * \param ping_period Timeout for select call waiting
+ */
+ void AcceptConnection(TrackerClient* tracker,
+ common::TCPSocket* conn_sock,
+ common::SockAddr* addr,
+ std::string* opts,
+ int ping_period = 2) {
+ std::set <std::string> old_keyset;
+ std::string matchkey;
+
+ // Report resource to tracker and get key
+ tracker->ReportResourceAndGetKey(my_port_, &matchkey);
+
+ while (true) {
+ tracker->WaitConnectionAndUpdateKey(listen_sock_, my_port_, ping_period, &matchkey);
+ common::TCPSocket conn = listen_sock_.Accept(addr);
+
+ int code = kRPCMagic;
+ CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code));
+ if (code != kRPCMagic) {
+ conn.Close();
+ LOG(FATAL) << "Client connected is not TVM RPC server";
+ continue;
+ }
+
+ int keylen = 0;
+ CHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
+
+ const char* CLIENT_HEADER = "client:";
+ const char* SERVER_HEADER = "server:";
+ std::string expect_header = CLIENT_HEADER + matchkey;
+ std::string server_key = SERVER_HEADER + key_;
+ if (size_t(keylen) < expect_header.length()) {
+ conn.Close();
+ LOG(INFO) << "Wrong client header length";
+ continue;
+ }
+
+ CHECK_NE(keylen, 0);
+ std::string remote_key;
+ remote_key.resize(keylen);
+ CHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen);
+
+ std::stringstream ssin(remote_key);
+ std::string arg0;
+ ssin >> arg0;
+ if (arg0 != expect_header) {
+ code = kRPCMismatch;
+ CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
+ conn.Close();
+ LOG(WARNING) << "Mismatch key from" << addr->AsString();
+ continue;
+ } else {
+ code = kRPCSuccess;
+ CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
+ keylen = server_key.length();
+ CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
+ CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen);
+ LOG(INFO) << "Connection success " << addr->AsString();
+ ssin >> *opts;
+ *conn_sock = conn;
+ return;
+ }
+ }
+ }
+
+ /*!
+ * \brief ServerLoopProc The Server loop process.
+ * \param sock The socket information
+ * \param addr The socket address information
+ */
+ void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) {
+ // Server loop
+ auto env = RPCEnv();
+ RPCServerLoop(sock.sockfd);
+ LOG(INFO) << "Finish serving " << addr.AsString();
+ env.CleanUp();
+ }
+
+ /*!
+ * \brief GetTimeOutFromOpts Parse and get the timeout option.
+ * \param opts The option string
+ * \param timeout value after parsing.
+ */
+ int GetTimeOutFromOpts(std::string opts) {
+ std::string cmd;
+ std::string option = "-timeout=";
+
+ if (opts.find(option) == 0) {
+ cmd = opts.substr(opts.find_last_of(option) + 1);
+ CHECK(common::IsNumber(cmd)) << "Timeout is not valid";
+ return std::stoi(cmd);
+ }
+ return 0;
+ }
+
+ std::string host_;
+ int port_;
+ int my_port_;
+ int port_end_;
+ std::string tracker_addr_;
+ std::string key_;
+ std::string custom_addr_;
+ common::TCPSocket listen_sock_;
+ common::TCPSocket tracker_sock_;
+};
+
+/*!
+ * \brief RPCServerCreate Creates the RPC Server.
+ * \param host The hostname of the server, Default=0.0.0.0
+ * \param port The port of the RPC, Default=9090
+ * \param port_end The end search port of the RPC, Default=9199
+ * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \param key The key used to identify the device type in tracker. Default=""
+ * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ * \param silent Whether run in silent mode. Default=True
+ */
+void RPCServerCreate(std::string host,
+ int port,
+ int port_end,
+ std::string tracker_addr,
+ std::string key,
+ std::string custom_addr,
+ bool silent) {
+ if (silent) {
+ // Only errors and fatal is logged
+ dmlc::InitLogging("--minloglevel=2");
+ }
+ // Start the rpc server
+ RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr);
+ rpc.Start();
+}
+
+TVM_REGISTER_GLOBAL("rpc._ServerCreate")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
+ });
+} // namespace runtime
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_server.h
+ * \brief RPC Server implementation.
+ */
+#ifndef TVM_APPS_CPP_RPC_SERVER_H_
+#define TVM_APPS_CPP_RPC_SERVER_H_
+
+#include <string>
+#include "tvm/runtime/c_runtime_api.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief RPCServerCreate Creates the RPC Server.
+ * \param host The hostname of the server, Default=0.0.0.0
+ * \param port The port of the RPC, Default=9090
+ * \param port_end The end search port of the RPC, Default=9199
+ * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \param key The key used to identify the device type in tracker. Default=""
+ * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ * \param silent Whether run in silent mode. Default=True
+ */
+TVM_DLL void RPCServerCreate(std::string host = "",
+ int port = 9090,
+ int port_end = 9099,
+ std::string tracker_addr = "",
+ std::string key = "",
+ std::string custom_addr = "",
+ bool silent = true);
+} // namespace runtime
+} // namespace tvm
+#endif // TVM_APPS_CPP_RPC_SERVER_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_tracker_client.h
+ * \brief RPC Tracker client to report resources.
+ */
+#ifndef TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
+#define TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
+
+#include <set>
+#include <iostream>
+#include <chrono>
+#include <random>
+#include <vector>
+#include <string>
+
+#include "../../src/runtime/rpc/rpc_session.h"
+#include "../../src/common/socket.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief TrackerClient Tracker client class.
+ * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \param key The key used to identify the device type in tracker. Default=""
+ * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ */
+class TrackerClient {
+ public:
+ /*!
+ * \brief Constructor.
+ */
+ TrackerClient(const std::string& tracker_addr,
+ const std::string& key,
+ const std::string& custom_addr)
+ : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr),
+ gen_(std::random_device{}()), dis_(0.0, 1.0) {
+ }
+ /*!
+ * \brief Destructor.
+ */
+ ~TrackerClient() {
+ // Free the resources
+ Close();
+ }
+ /*!
+ * \brief IsValid Check tracker is valid.
+ */
+ bool IsValid() {
+ return (!tracker_addr_.empty() && !tracker_sock_.IsClosed());
+ }
+ /*!
+ * \brief TryConnect Connect to tracker if the tracker address is valid.
+ */
+ void TryConnect() {
+ if (!tracker_addr_.empty() && (tracker_sock_.IsClosed())) {
+ tracker_sock_ = ConnectWithRetry();
+
+ int code = kRPCTrackerMagic;
+ CHECK_EQ(tracker_sock_.SendAll(&code, sizeof(code)), sizeof(code));
+ CHECK_EQ(tracker_sock_.RecvAll(&code, sizeof(code)), sizeof(code));
+ CHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker";
+
+ std::ostringstream ss;
+ ss << "[" << static_cast<int>(TrackerCode::kUpdateInfo)
+ << ", {\"key\": \"server:"<< key_ << "\"}]";
+ tracker_sock_.SendBytes(ss.str());
+
+ // Receive status and validate
+ std::string remote_status = tracker_sock_.RecvBytes();
+ CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
+ }
+ }
+ /*!
+ * \brief Close Clean up tracker resources.
+ */
+ void Close() {
+ // close tracker resource
+ if (!tracker_sock_.IsClosed()) {
+ tracker_sock_.Close();
+ }
+ }
+ /*!
+ * \brief ReportResourceAndGetKey Report resource to tracker.
+ * \param port listening port.
+ * \param matchkey Random match key output.
+ */
+ void ReportResourceAndGetKey(int port,
+ std::string *matchkey) {
+ if (!tracker_sock_.IsClosed()) {
+ *matchkey = RandomKey(key_ + ":", old_keyset_);
+ if (custom_addr_.empty()) {
+ custom_addr_ = "null";
+ }
+
+ std::ostringstream ss;
+ ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", ["
+ << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
+
+ tracker_sock_.SendBytes(ss.str());
+
+ // Receive status and validate
+ std::string remote_status = tracker_sock_.RecvBytes();
+ CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
+ } else {
+ *matchkey = key_;
+ }
+ }
+
+ /*!
+ * \brief ReportResourceAndGetKey Report resource to tracker.
+ * \param listen_sock Listen socket details for select.
+ * \param port listening port.
+ * \param ping_period Select wait time.
+ * \param matchkey Random match key output.
+ */
+ void WaitConnectionAndUpdateKey(common::TCPSocket listen_sock,
+ int port,
+ int ping_period,
+ std::string *matchkey) {
+ int unmatch_period_count = 0;
+ int unmatch_timeout = 4;
+ while (true) {
+ if (!tracker_sock_.IsClosed()) {
+ common::PollHelper poller;
+ poller.WatchRead(listen_sock.sockfd);
+ poller.Poll(ping_period * 1000);
+ if (!poller.CheckRead(listen_sock.sockfd)) {
+ std::ostringstream ss;
+ ss << "[" << int(TrackerCode::kGetPendingMatchKeys) << "]";
+ tracker_sock_.SendBytes(ss.str());
+
+ // Receive status and validate
+ std::string pending_keys = tracker_sock_.RecvBytes();
+ old_keyset_.insert(*matchkey);
+
+ // if match key not in pending key set
+ // it means the key is acquired by a client but not used.
+ if (pending_keys.find(*matchkey) == std::string::npos) {
+ unmatch_period_count += 1;
+ } else {
+ unmatch_period_count = 0;
+ }
+ // regenerate match key if key is acquired but not used for a while
+ if (unmatch_period_count * ping_period > unmatch_timeout + ping_period) {
+ LOG(INFO) << "no incoming connections, regenerate key ...";
+
+ *matchkey = RandomKey(key_ + ":", old_keyset_);
+
+ std::ostringstream ss;
+ ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", ["
+ << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
+ tracker_sock_.SendBytes(ss.str());
+
+ std::string remote_status = tracker_sock_.RecvBytes();
+ CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
+ unmatch_period_count = 0;
+ }
+ continue;
+ }
+ }
+ break;
+ }
+ }
+
+ private:
+ /*!
+ * \brief Connect to a RPC address with retry.
+ This function is only reliable to short period of server restart.
+ * \param timeout Timeout during retry
+ * \param retry_period Number of seconds before we retry again.
+ * \return TCPSocket The socket information if connect is success.
+ */
+ common::TCPSocket ConnectWithRetry(int timeout = 60, int retry_period = 5) {
+ auto tbegin = std::chrono::system_clock::now();
+ while (true) {
+ common::SockAddr addr(tracker_addr_);
+ common::TCPSocket sock;
+ sock.Create();
+ LOG(INFO) << "Tracker connecting to " << addr.AsString();
+ if (sock.Connect(addr)) {
+ return sock;
+ }
+
+ auto period = (std::chrono::duration_cast<std::chrono::seconds>(
+ std::chrono::system_clock::now() - tbegin)).count();
+ CHECK(period < timeout) << "Failed to connect to server" << addr.AsString();
+ LOG(WARNING) << "Cannot connect to tracker " << addr.AsString()
+ << " retry in " << retry_period << " seconds.";
+ std::this_thread::sleep_for(std::chrono::seconds(retry_period));
+ }
+ }
+ /*!
+ * \brief Random Generate a random number between 0 and 1.
+ * \return random float value.
+ */
+ float Random() {
+ return dis_(gen_);
+ }
+ /*!
+ * \brief Generate a random key.
+ * \param prefix The string prefix.
+ * \return cmap The conflict map set.
+ */
+ std::string RandomKey(const std::string& prefix, const std::set <std::string> &cmap) {
+ if (!cmap.empty()) {
+ while (true) {
+ std::string key = prefix + std::to_string(Random());
+ if (cmap.find(key) == cmap.end()) {
+ return key;
+ }
+ }
+ }
+ return prefix + std::to_string(Random());
+ }
+
+ std::string tracker_addr_;
+ std::string key_;
+ std::string custom_addr_;
+ common::TCPSocket tracker_sock_;
+ std::set <std::string> old_keyset_;
+ std::mt19937 gen_;
+ std::uniform_real_distribution<float> dis_;
+
+};
+} // namespace runtime
+} // namespace tvm
+#endif // TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
+#include <sys/select.h>
#include <sys/ioctl.h>
#endif
#include <dmlc/logging.h>
#include <string>
#include <cstring>
+#include <vector>
+#include <unordered_map>
+#include "../common/util.h"
+#if defined(_WIN32)
+static inline int poll(struct pollfd *pfd, int nfds,
+ int timeout) {
+ return WSAPoll(pfd, nfds, timeout);
+}
+static inline int inet_pton(int family, const char* addr_str, void* addr_buf) {
+ return InetPton(family, addr_str, addr_buf);
+}
+#else
+#include <sys/poll.h>
+#endif // defined(_WIN32)
namespace tvm {
namespace common {
}
/*!
+ * \brief ValidateIP validates an ip address.
+ * \param ip The ip address in string format localhost or x.x.x.x format
+ * \return result of operation.
+ */
+inline bool ValidateIP(std::string ip) {
+ if (ip == "localhost") {
+ return true;
+ }
+ struct sockaddr_in sa_ipv4;
+ struct sockaddr_in6 sa_ipv6;
+ bool is_ipv4 = inet_pton(AF_INET, ip.c_str(), &(sa_ipv4.sin_addr));
+ bool is_ipv6 = inet_pton(AF_INET6, ip.c_str(), &(sa_ipv6.sin6_addr));
+ return is_ipv4 || is_ipv6;
+}
+
+/*!
* \brief Common data structure for network address.
*/
struct SockAddr {
SockAddr(const char *url, int port) {
this->Set(url, port);
}
+
+ /*!
+ * \brief SockAddr Get the socket address from tracker.
+ * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090)
+ * \return SockAddr parsed from url.
+ */
+ explicit SockAddr(const std::string &url) {
+ size_t sep = url.find(",");
+ std::string host = url.substr(2, sep - 3);
+ std::string port = url.substr(sep + 1, url.length() - 1);
+ CHECK(ValidateIP(host)) << "Url address is not valid " << url;
+ if (host == "localhost") {
+ host = "127.0.0.1";
+ }
+ this->Set(host.c_str(), std::stoi(port));
+ }
+
/*!
* \brief set the address
* \param host the url of the address
}
/*!
* \brief try bind the socket to host, from start_port to end_port
+ * \param host host address to bind the socket
* \param start_port starting port number to try
* \param end_port ending port number to try
* \return the port successfully bind to, return -1 if failed to bind any port
*/
- inline int TryBindHost(int start_port, int end_port) {
+ inline int TryBindHost(std::string host, int start_port, int end_port) {
for (int port = start_port; port < end_port; ++port) {
- SockAddr addr("0.0.0.0", port);
+ SockAddr addr(host.c_str(), port);
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
(addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
sizeof(sockaddr_in))) == 0) {
return port;
+ } else {
+ LOG(WARNING) << "Bind failed to " << host << ":" << port;
}
#if defined(_WIN32)
if (WSAGetLastError() != WSAEADDRINUSE) {
return TCPSocket(newfd);
}
/*!
+ * \brief get a new connection
+ * \param addr client address from which connection accepted
+ * \return The accepted socket connection.
+ */
+ TCPSocket Accept(SockAddr *addr) {
+ socklen_t addrlen = sizeof(addr->addr);
+ SockType newfd = accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr),
+ &addrlen);
+ if (newfd == INVALID_SOCKET) {
+ Socket::Error("Accept");
+ }
+ return TCPSocket(newfd);
+ }
+ /*!
* \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occurred
*/
}
return ndone;
}
+ /*!
+ * \brief Send the data to remote.
+ * \param data The data to be sent.
+ */
+ void SendBytes(std::string data) {
+ int datalen = data.length();
+ CHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen));
+ CHECK_EQ(SendAll(data.c_str(), datalen), datalen);
+ }
+ /*!
+ * \brief Receive the data to remote.
+ * \return The data received.
+ */
+ std::string RecvBytes() {
+ int datalen = 0;
+ CHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen));
+ std::string data;
+ data.resize(datalen);
+ CHECK_EQ(RecvAll(&data[0], datalen), datalen);
+ return data;
+ }
};
+
+/*! \brief helper data structure to perform poll */
+struct PollHelper {
+ public:
+ /*!
+ * \brief add file descriptor to watch for read
+ * \param fd file descriptor to be watched
+ */
+ inline void WatchRead(TCPSocket::SockType fd) {
+ auto& pfd = fds[fd];
+ pfd.fd = fd;
+ pfd.events |= POLLIN;
+ }
+ /*!
+ * \brief add file descriptor to watch for write
+ * \param fd file descriptor to be watched
+ */
+ inline void WatchWrite(TCPSocket::SockType fd) {
+ auto& pfd = fds[fd];
+ pfd.fd = fd;
+ pfd.events |= POLLOUT;
+ }
+ /*!
+ * \brief add file descriptor to watch for exception
+ * \param fd file descriptor to be watched
+ */
+ inline void WatchException(TCPSocket::SockType fd) {
+ auto& pfd = fds[fd];
+ pfd.fd = fd;
+ pfd.events |= POLLPRI;
+ }
+ /*!
+ * \brief Check if the descriptor is ready for read
+ * \param fd file descriptor to check status
+ */
+ inline bool CheckRead(TCPSocket::SockType fd) const {
+ const auto& pfd = fds.find(fd);
+ return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
+ }
+ /*!
+ * \brief Check if the descriptor is ready for write
+ * \param fd file descriptor to check status
+ */
+ inline bool CheckWrite(TCPSocket::SockType fd) const {
+ const auto& pfd = fds.find(fd);
+ return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
+ }
+ /*!
+ * \brief Check if the descriptor has any exception
+ * \param fd file descriptor to check status
+ */
+ inline bool CheckExcept(TCPSocket::SockType fd) const {
+ const auto& pfd = fds.find(fd);
+ return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
+ }
+ /*!
+ * \brief wait for exception event on a single descriptor
+ * \param fd the file descriptor to wait the event for
+ * \param timeout the timeout counter, can be negative, which means wait until the event happen
+ * \return 1 if success, 0 if timeout, and -1 if error occurs
+ */
+ inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*)
+ pollfd pfd;
+ pfd.fd = fd;
+ pfd.events = POLLPRI;
+ return poll(&pfd, 1, timeout);
+ }
+
+ /*!
+ * \brief peform poll on the set defined, read, write, exception
+ * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
+ * \return
+ */
+ inline void Poll(long timeout = -1) { // NOLINT(*)
+ std::vector<pollfd> fdset;
+ fdset.reserve(fds.size());
+ for (auto kv : fds) {
+ fdset.push_back(kv.second);
+ }
+ int ret = poll(fdset.data(), fdset.size(), timeout);
+ if (ret == -1) {
+ Socket::Error("Poll");
+ } else {
+ for (auto& pfd : fdset) {
+ auto revents = pfd.revents & pfd.events;
+ if (!revents) {
+ fds.erase(pfd.fd);
+ } else {
+ fds[pfd.fd].events = revents;
+ }
+ }
+ }
+ }
+
+ std::unordered_map<TCPSocket::SockType, pollfd> fds;
+};
+
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_SOCKET_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file util.h
+ * \brief Defines some common utility function..
+ */
+#ifndef TVM_COMMON_UTIL_H_
+#define TVM_COMMON_UTIL_H_
+
+#include <stdio.h>
+#ifndef _WIN32
+#include <sys/wait.h>
+#include <sys/types.h>
+#endif
+#include <vector>
+#include <string>
+#include <sstream>
+#include <algorithm>
+#include <array>
+#include <memory>
+
+namespace tvm {
+namespace common {
+/*!
+ * \brief TVMPOpen wrapper of popen between windows / unix.
+ * \param command executed command
+ * \param type "r" is for reading or "w" for writing.
+ * \return normal standard stream
+ */
+inline FILE* TVMPOpen(const char* command, const char* type) {
+#if defined(_WIN32)
+ return _popen(command, type);
+#else
+ return popen(command, type);
+#endif
+}
+
+/*!
+ * \brief TVMPClose wrapper of pclose between windows / linux
+ * \param stream the stream needed to be close.
+ * \return exit status
+ */
+inline int TVMPClose(FILE* stream) {
+#if defined(_WIN32)
+ return _pclose(stream);
+#else
+ return pclose(stream);
+#endif
+}
+
+/*!
+ * \brief TVMWifexited wrapper of WIFEXITED between windows / linux
+ * \param status The status field that was filled in by the wait or waitpid function
+ * \return the exit code of the child process
+ */
+inline int TVMWifexited(int status) {
+#if defined(_WIN32)
+ return (status != 3);
+#else
+ return WIFEXITED(status);
+#endif
+}
+
+/*!
+ * \brief TVMWexitstatus wrapper of WEXITSTATUS between windows / linux
+ * \param status The status field that was filled in by the wait or waitpid function.
+ * \return the child process exited normally or not
+ */
+inline int TVMWexitstatus(int status) {
+#if defined(_WIN32)
+ return status;
+#else
+ return WEXITSTATUS(status);
+#endif
+}
+
+
+/*!
+ * \brief IsNumber check whether string is a number.
+ * \param str input string
+ * \return result of operation.
+ */
+inline bool IsNumber(const std::string& str) {
+ return !str.empty() && std::find_if(str.begin(),
+ str.end(), [](char c) { return !std::isdigit(c); }) == str.end();
+}
+
+/*!
+ * \brief split Split the string based on delimiter
+ * \param str Input string
+ * \param delim The delimiter.
+ * \return vector of strings which are splitted.
+ */
+inline std::vector<std::string> Split(const std::string& str, char delim) {
+ std::string item;
+ std::istringstream is(str);
+ std::vector<std::string> ret;
+ while (std::getline(is, item, delim)) {
+ ret.push_back(item);
+ }
+ return ret;
+}
+
+/*!
+ * \brief EndsWith check whether the strings ends with
+ * \param value The full string
+ * \param end The end substring
+ * \return bool The result.
+ */
+inline bool EndsWith(std::string const& value, std::string const& end) {
+ if (end.size() <= value.size()) {
+ return std::equal(end.rbegin(), end.rend(), value.rbegin());
+ }
+ return false;
+}
+
+/*!
+ * \brief Execute the command
+ * \param cmd The command we want to execute
+ * \param err_msg The error message if we have
+ * \return executed output status
+ */
+inline int Execute(std::string cmd, std::string* err_msg) {
+ std::array<char, 128> buffer;
+ std::string result;
+ cmd += " 2>&1";
+ FILE* fd = TVMPOpen(cmd.c_str(), "r");
+ while (fgets(buffer.data(), buffer.size(), fd) != nullptr) {
+ *err_msg += buffer.data();
+ }
+ int status = TVMPClose(fd);
+ if (TVMWifexited(status)) {
+ return TVMWexitstatus(status);
+ }
+ return 255;
+}
+
+} // namespace common
+} // namespace tvm
+#endif // TVM_COMMON_UTIL_H_
namespace tvm {
namespace runtime {
+// Magic header for RPC data plane
const int kRPCMagic = 0xff271;
+// magic header for RPC tracker(control plane)
+const int kRPCTrackerMagic = 0x2f271;
+// sucess response
+const int kRPCSuccess = kRPCMagic + 0;
+// cannot found matched key in server
+const int kRPCMismatch = kRPCMagic + 2;
+/*! \brief Enumeration code for the RPC tracker */
+enum class TrackerCode : int {
+ kFail = -1,
+ kSuccess = 0,
+ kPing = 1,
+ kStop = 2,
+ kPut = 3,
+ kRequest = 4,
+ kUpdateInfo = 5,
+ kSummary = 6,
+ kGetPendingMatchKeys = 7
+};
/*! \brief The remote functio handle */
using RPCFuncHandle = void*;
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file rpc_socket_impl.h
+ * \brief Socket based RPC implementation.
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_
+#define TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief RPCServerLoop Start the rpc server loop.
+ * \param sockfd Socket file descriptor
+ */
+void RPCServerLoop(int sockfd);
+
+} // namespace runtime
+} // namespace tvm
+#endif // TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_