[RUNTIME] Improved Packed FFI for optional. (#5478)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 30 Apr 2020 15:07:00 +0000 (08:07 -0700)
committerGitHub <noreply@github.com>
Thu, 30 Apr 2020 15:07:00 +0000 (08:07 -0700)
Allows Optional<NDArray> and module to be passed with the right type code.

include/tvm/runtime/packed_func.h
tests/cpp/build_module_test.cc
tests/cpp/container_test.cc

index c566007..cf6d5fa 100644 (file)
@@ -1346,16 +1346,16 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
 // We use type traits to eliminate un-necessary checks.
 template<typename T>
 inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
-  using TObjectRef = typename std::remove_reference<T>::type;
+  using ContainerType = typename std::remove_reference<T>::type::ContainerType;
   if (value.defined()) {
     Object* ptr = value.data_.data_;
-    if (std::is_base_of<NDArray, TObjectRef>::value ||
-        (std::is_base_of<TObjectRef, NDArray>::value &&
+    if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
+        (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
          ptr->IsInstance<NDArray::ContainerType>())) {
       values_[i].v_handle = NDArray::FFIGetHandle(value);
       type_codes_[i] = kTVMNDArrayHandle;
-    } else if (std::is_base_of<Module, TObjectRef>::value ||
-               (std::is_base_of<TObjectRef, Module>::value &&
+    } else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
+               (std::is_base_of<ContainerType, Module::ContainerType>::value &&
                 ptr->IsInstance<Module::ContainerType>())) {
       values_[i].v_handle = ptr;
       type_codes_[i] = kTVMModuleHandle;
@@ -1375,12 +1375,12 @@ template<typename TObjectRef, typename>
 inline bool TVMPODValue_::IsObjectRef() const {
   using ContainerType = typename TObjectRef::ContainerType;
   // NOTE: the following code can be optimized by constant folding.
-  if (std::is_base_of<NDArray, TObjectRef>::value) {
+  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
     return type_code_ == kTVMNDArrayHandle &&
         TVMArrayHandleToObjectHandle(
             static_cast<TVMArrayHandle>(value_.v_handle))->IsInstance<ContainerType>();
   }
-  if (std::is_base_of<Module, TObjectRef>::value) {
+  if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
     return type_code_ == kTVMModuleHandle &&
         static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
   }
@@ -1390,8 +1390,10 @@ inline bool TVMPODValue_::IsObjectRef() const {
         *static_cast<Object**>(value_.v_handle));
   }
   return
-      (std::is_base_of<TObjectRef, NDArray>::value && type_code_ == kTVMNDArrayHandle) ||
-      (std::is_base_of<TObjectRef, Module>::value && type_code_ == kTVMModuleHandle) ||
+      (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
+       type_code_ == kTVMNDArrayHandle) ||
+      (std::is_base_of<ContainerType, Module::ContainerType>::value &&
+       type_code_ == kTVMModuleHandle) ||
       (type_code_ == kTVMObjectHandle &&
        ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
 }
@@ -1402,13 +1404,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
       std::is_base_of<ObjectRef, TObjectRef>::value,
       "Conversion only works for ObjectRef");
   using ContainerType = typename TObjectRef::ContainerType;
+
   if (type_code_ == kTVMNullptr) {
     CHECK(TObjectRef::_type_is_nullable)
         << "Expect a not null value of " << ContainerType::_type_key;
     return TObjectRef(ObjectPtr<Object>(nullptr));
   }
   // NOTE: the following code can be optimized by constant folding.
-  if (std::is_base_of<NDArray, TObjectRef>::value) {
+  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
     // Casting to a sub-class of NDArray
     TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
     ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
@@ -1417,7 +1420,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
         << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
     return TObjectRef(data);
   }
-  if (std::is_base_of<Module, TObjectRef>::value) {
+  if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
     // Casting to a sub-class of Module
     TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
     ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
@@ -1438,13 +1441,13 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
         << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
         << " but get " << ptr->GetTypeKey();
     return TObjectRef(GetObjectPtr<Object>(ptr));
-  } else if (std::is_base_of<TObjectRef, NDArray>::value &&
+  } else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
              type_code_ == kTVMNDArrayHandle) {
     // Casting to a base class that NDArray can sub-class
     ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
         static_cast<TVMArrayHandle>(value_.v_handle));
     return TObjectRef(data);
-  } else if (std::is_base_of<TObjectRef, Module>::value &&
+  } else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
              type_code_ == kTVMModuleHandle) {
     // Casting to a base class that Module can sub-class
     return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
@@ -1456,15 +1459,16 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
 
 template<typename TObjectRef, typename>
 inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
+  using ContainerType = typename TObjectRef::ContainerType;
   const Object* ptr = other.get();
   if (ptr != nullptr) {
-    if (std::is_base_of<NDArray, TObjectRef>::value ||
-        (std::is_base_of<TObjectRef, NDArray>::value &&
+    if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
+        (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
          ptr->IsInstance<NDArray::ContainerType>())) {
       return operator=(NDArray(std::move(other.data_)));
     }
-    if (std::is_base_of<Module, TObjectRef>::value ||
-        (std::is_base_of<TObjectRef, Module>::value &&
+    if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
+        (std::is_base_of<ContainerType, Module::ContainerType>::value &&
          ptr->IsInstance<Module::ContainerType>())) {
       return operator=(Module(std::move(other.data_)));
     }
index 4913731..6ea0b21 100644 (file)
@@ -177,6 +177,16 @@ TEST(BuildModule, Heterogeneous) {
   runtime::Module mod = (*graph_runtime)(
       json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id);
 
+  // test FFI for module.
+  auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+    int tcode = args[1];
+    CHECK_EQ(args[0].type_code(), tcode);
+  });
+
+  test_ffi(runtime::Module(mod), static_cast<int>(kTVMModuleHandle));
+  test_ffi(Optional<runtime::Module>(mod), static_cast<int>(kTVMModuleHandle));
+
+
   PackedFunc set_input = mod.GetFunction("set_input", false);
   PackedFunc run = mod.GetFunction("run", false);
   PackedFunc get_output = mod.GetFunction("get_output", false);
index c67df63..c89f815 100644 (file)
@@ -468,6 +468,18 @@ TEST(Optional, PackedCall) {
   CHECK(packedfunc("xyz", false).operator String() == "xyz");
   CHECK(packedfunc("xyz", false).operator Optional<String>() == "xyz");
   CHECK(packedfunc(nullptr, true).operator Optional<String>() == nullptr);
+
+  // test FFI convention.
+  auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+    int tcode = args[1];
+    CHECK_EQ(args[0].type_code(), tcode);
+  });
+  String s = "xyz";
+  auto nd = NDArray::Empty({0, 1}, DataType::Float(32), DLContext{kDLCPU, 0});
+  test_ffi(Optional<NDArray>(nd), static_cast<int>(kTVMNDArrayHandle));
+  test_ffi(Optional<String>(s), static_cast<int>(kTVMObjectRValueRefArg));
+  test_ffi(s, static_cast<int>(kTVMObjectHandle));
+  test_ffi(String(s), static_cast<int>(kTVMObjectRValueRefArg));
 }
 
 int main(int argc, char** argv) {