[RUNTIME] Quick fix PackedFunc String passing (#5266)
authorTianqi Chen <tqchen@users.noreply.github.com>
Tue, 7 Apr 2020 23:33:12 +0000 (16:33 -0700)
committerGitHub <noreply@github.com>
Tue, 7 Apr 2020 23:33:12 +0000 (16:33 -0700)
include/tvm/runtime/packed_func.h
tests/cpp/packed_func_test.cc

index d5c0175..1b3ad57 100644 (file)
@@ -513,8 +513,11 @@ class TVMArgValue : public TVMPODValue_ {
     }
   }
   operator tvm::runtime::String() const {
-    // directly use the std::string constructor for now.
-    return tvm::runtime::String(operator std::string());
+    if (IsObjectRef<tvm::runtime::String>()) {
+      return AsObjectRef<tvm::runtime::String>();
+    } else {
+      return tvm::runtime::String(operator std::string());
+    }
   }
   operator DLDataType() const {
     if (type_code_ == kTVMStr) {
@@ -605,8 +608,11 @@ class TVMRetValue : public TVMPODValue_ {
     return *ptr<std::string>();
   }
   operator tvm::runtime::String() const {
-    // directly use the std::string constructor for now.
-    return tvm::runtime::String(operator std::string());
+    if (IsObjectRef<tvm::runtime::String>()) {
+      return AsObjectRef<tvm::runtime::String>();
+    } else {
+      return tvm::runtime::String(operator std::string());
+    }
   }
   operator DLDataType() const {
     if (type_code_ == kTVMStr) {
index 4a815ff..d0313c6 100644 (file)
@@ -95,6 +95,12 @@ TEST(PackedFunc, str) {
       CHECK(y == "hello");
       *rv = x;
     })("hello");
+
+  PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
+      CHECK(args.num_args == 1);
+      runtime::String s = args[0];
+      CHECK(s == "hello");
+  })(runtime::String("hello"));
 }