Allow RPCWrappedFunc to rewrite runtime::String as std::string (#5796)
authorJunru Shao <junrushao1994@gmail.com>
Sun, 14 Jun 2020 23:28:06 +0000 (16:28 -0700)
committerGitHub <noreply@github.com>
Sun, 14 Jun 2020 23:28:06 +0000 (16:28 -0700)
src/runtime/rpc/rpc_module.cc
tests/python/unittest/test_runtime_rpc.py

index b9cdc2c..89f3e7c 100644 (file)
@@ -48,8 +48,13 @@ class RPCWrappedFunc : public Object {
     // scan and check whether we need rewrite these arguments
     // to their remote variant.
     for (int i = 0; i < args.size(); ++i) {
+      if (args[i].IsObjectRef<String>()) {
+        String str = args[i];
+        type_codes[i] = kTVMStr;
+        values[i].v_str = str.c_str();
+        continue;
+      }
       int tcode = type_codes[i];
-
       switch (tcode) {
         case kTVMDLTensorHandle:
         case kTVMNDArrayHandle: {
index dfbb3c5..7f01f88 100644 (file)
@@ -86,6 +86,22 @@ def test_rpc_simple():
     f2 = client.get_function("rpc.test.strcat")
     assert f2("abc", 11) == "abc:11"
 
+
+def test_rpc_runtime_string():
+    if not tvm.runtime.enabled("rpc"):
+        return
+    @tvm.register_func("rpc.test.runtime_str_concat")
+    def strcat(x, y):
+        return x + y
+
+    server = rpc.Server("localhost", key="x1")
+    client = rpc.connect(server.host, server.port, key="x1")
+    func = client.get_function("rpc.test.runtime_str_concat")
+    x = tvm.runtime.container.String("abc")
+    y = tvm.runtime.container.String("def")
+    assert str(func(x, y)) == "abcdef"
+
+
 def test_rpc_array():
     if not tvm.runtime.enabled("rpc"):
         return