From 2b02d567c65dcf3b9464e66da203fe159f3a4351 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 7 May 2020 20:10:58 -0700 Subject: [PATCH] [RUNTIME] Store nullptr PackedFunc as nullptr for better error propagation (#5540) --- include/tvm/runtime/packed_func.h | 15 ++++++++++++--- tests/python/unittest/test_runtime_rpc.py | 4 ++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 0726292..dfc21fc 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -736,7 +736,11 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(PackedFunc f) { - this->SwitchToClass(kTVMPackedFuncHandle, f); + if (f == nullptr) { + this->SwitchToPOD(kTVMNullptr); + } else { + this->SwitchToClass(kTVMPackedFuncHandle, f); + } return *this; } template @@ -1185,8 +1189,13 @@ class TVMArgsSetter { type_codes_[i] = kTVMBytes; } TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const { - values_[i].v_handle = const_cast(&value); - type_codes_[i] = kTVMPackedFuncHandle; + if (value != nullptr) { + values_[i].v_handle = const_cast(&value); + type_codes_[i] = kTVMPackedFuncHandle; + } else { + values_[i].v_handle = nullptr; + type_codes_[i] = kTVMNullptr; + } } template TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc& value) const { diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 6a46eab..17321bd 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -129,6 +129,9 @@ def test_rpc_echo(): raise_err() remote.cpu().sync() + with pytest.raises(AttributeError): + f3 = remote.system_lib()["notexist"] + temp = rpc.server._server_env([]) server = rpc.Server("localhost") @@ -214,6 +217,7 @@ def test_rpc_remote_module(): remote = tvm.rpc.PopenSession(path_minrpc) ctx = remote.cpu(0) f1 = remote.system_lib() + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) time_f = f1.time_evaluator("myadd", remote.cpu(0), number=1) -- 2.7.4