[RUNTIME] Store nullptr PackedFunc as nullptr for better error propagation (#5540)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 8 May 2020 03:10:58 +0000 (20:10 -0700)
committerGitHub <noreply@github.com>
Fri, 8 May 2020 03:10:58 +0000 (20:10 -0700)
include/tvm/runtime/packed_func.h
tests/python/unittest/test_runtime_rpc.py

index 0726292..dfc21fc 100644 (file)
@@ -736,7 +736,11 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(PackedFunc f) {
-    this->SwitchToClass(kTVMPackedFuncHandle, f);
+    if (f == nullptr) {
+      this->SwitchToPOD(kTVMNullptr);
+    } else {
+      this->SwitchToClass(kTVMPackedFuncHandle, f);
+    }
     return *this;
   }
   template<typename FType>
@@ -1185,8 +1189,13 @@ class TVMArgsSetter {
     type_codes_[i] = kTVMBytes;
   }
   TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const {
-    values_[i].v_handle = const_cast<PackedFunc*>(&value);
-    type_codes_[i] = kTVMPackedFuncHandle;
+    if (value != nullptr) {
+      values_[i].v_handle = const_cast<PackedFunc*>(&value);
+      type_codes_[i] = kTVMPackedFuncHandle;
+    } else {
+      values_[i].v_handle = nullptr;
+      type_codes_[i] = kTVMNullptr;
+    }
   }
   template<typename FType>
   TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
index 6a46eab..17321bd 100644 (file)
@@ -129,6 +129,9 @@ def test_rpc_echo():
             raise_err()
 
         remote.cpu().sync()
+        with pytest.raises(AttributeError):
+            f3 = remote.system_lib()["notexist"]
+
 
     temp = rpc.server._server_env([])
     server = rpc.Server("localhost")
@@ -214,6 +217,7 @@ def test_rpc_remote_module():
         remote = tvm.rpc.PopenSession(path_minrpc)
         ctx = remote.cpu(0)
         f1 = remote.system_lib()
+
         a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx)
         time_f = f1.time_evaluator("myadd", remote.cpu(0), number=1)