From 2942278a8e0e672cbd9c23e00aefdb39db5aaf66 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 7 Apr 2020 16:33:12 -0700 Subject: [PATCH] [RUNTIME] Quick fix PackedFunc String passing (#5266) --- include/tvm/runtime/packed_func.h | 14 ++++++++++---- tests/cpp/packed_func_test.cc | 6 ++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index d5c0175..1b3ad57 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -513,8 +513,11 @@ class TVMArgValue : public TVMPODValue_ { } } operator tvm::runtime::String() const { - // directly use the std::string constructor for now. - return tvm::runtime::String(operator std::string()); + if (IsObjectRef()) { + return AsObjectRef(); + } else { + return tvm::runtime::String(operator std::string()); + } } operator DLDataType() const { if (type_code_ == kTVMStr) { @@ -605,8 +608,11 @@ class TVMRetValue : public TVMPODValue_ { return *ptr(); } operator tvm::runtime::String() const { - // directly use the std::string constructor for now. - return tvm::runtime::String(operator std::string()); + if (IsObjectRef()) { + return AsObjectRef(); + } else { + return tvm::runtime::String(operator std::string()); + } } operator DLDataType() const { if (type_code_ == kTVMStr) { diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 4a815ff..d0313c6 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -95,6 +95,12 @@ TEST(PackedFunc, str) { CHECK(y == "hello"); *rv = x; })("hello"); + + PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK(args.num_args == 1); + runtime::String s = args[0]; + CHECK(s == "hello"); + })(runtime::String("hello")); } -- 2.7.4