[RUNTIME] Enable auto conversion String->DLDataType (#6214)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 6 Aug 2020 01:57:25 +0000 (18:57 -0700)
committerGitHub <noreply@github.com>
Thu, 6 Aug 2020 01:57:25 +0000 (18:57 -0700)
include/tvm/runtime/container.h
include/tvm/runtime/packed_func.h
tests/python/unittest/test_node_reflection.py

index f8fa09d..423ea89 100644 (file)
@@ -27,7 +27,6 @@
 #include <dmlc/logging.h>
 #include <tvm/runtime/memory.h>
 #include <tvm/runtime/object.h>
-#include <tvm/runtime/packed_func.h>
 
 #include <algorithm>
 #include <cstring>
@@ -74,6 +73,9 @@ class StringRef;
 namespace tvm {
 namespace runtime {
 
+// Forward declare TVMArgValue
+class TVMArgValue;
+
 /*! \brief String-aware ObjectRef equal functor */
 struct ObjectHash {
   /*!
@@ -1289,9 +1291,7 @@ class String : public ObjectRef {
    * \param val The value to be checked
    * \return A boolean indicating if val can be converted to String
    */
-  static bool CanConvertFrom(const TVMArgValue& val) {
-    return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
-  }
+  inline static bool CanConvertFrom(const TVMArgValue& val);
 
   /*!
    * \brief Hash the binary bytes
@@ -1523,25 +1523,6 @@ inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) cons
   return false;
 }
 
-template <>
-struct PackedFuncValueConverter<::tvm::runtime::String> {
-  static String From(const TVMArgValue& val) {
-    if (val.IsObjectRef<tvm::runtime::String>()) {
-      return val.AsObjectRef<tvm::runtime::String>();
-    } else {
-      return tvm::runtime::String(val.operator std::string());
-    }
-  }
-
-  static String From(const TVMRetValue& val) {
-    if (val.IsObjectRef<tvm::runtime::String>()) {
-      return val.AsObjectRef<tvm::runtime::String>();
-    } else {
-      return tvm::runtime::String(val.operator std::string());
-    }
-  }
-};
-
 /*! \brief Helper to represent nullptr for optional. */
 struct NullOptType {};
 
@@ -1659,18 +1640,6 @@ class Optional : public ObjectRef {
   static constexpr bool _type_is_nullable = true;
 };
 
-template <typename T>
-struct PackedFuncValueConverter<Optional<T>> {
-  static Optional<T> From(const TVMArgValue& val) {
-    if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
-    return PackedFuncValueConverter<T>::From(val);
-  }
-  static Optional<T> From(const TVMRetValue& val) {
-    if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
-    return PackedFuncValueConverter<T>::From(val);
-  }
-};
-
 /*!
  * \brief An object representing a closure. This object is used by both the
  * Relay VM and interpreter.
index 3231217..d2450c4 100644 (file)
@@ -26,6 +26,7 @@
 
 #include <dmlc/logging.h>
 #include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/ndarray.h>
@@ -492,22 +493,6 @@ class TVMArgValue : public TVMPODValue_ {
       return std::string(value_.v_str);
     }
   }
-  operator DLDataType() const {
-    if (type_code_ == kTVMStr) {
-      return String2DLDataType(operator std::string());
-    }
-    // None type
-    if (type_code_ == kTVMNullptr) {
-      DLDataType t;
-      t.code = kTVMOpaqueHandle;
-      t.bits = 0;
-      t.lanes = 0;
-      return t;
-    }
-    TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
-    return value_.v_type;
-  }
-  operator DataType() const { return DataType(operator DLDataType()); }
   operator PackedFunc() const {
     if (type_code_ == kTVMNullptr) return PackedFunc();
     TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
@@ -521,6 +506,8 @@ class TVMArgValue : public TVMPODValue_ {
 
   template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
   inline operator T() const;
+  inline operator DLDataType() const;
+  inline operator DataType() const;
 };
 
 /*!
@@ -1473,6 +1460,60 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
   return (*this)->GetFunction(name, query_imports);
 }
 
+// specializations of PackedFuncValueConverter
+template <>
+struct PackedFuncValueConverter<::tvm::runtime::String> {
+  static String From(const TVMArgValue& val) {
+    if (val.IsObjectRef<tvm::runtime::String>()) {
+      return val.AsObjectRef<tvm::runtime::String>();
+    } else {
+      return tvm::runtime::String(val.operator std::string());
+    }
+  }
+
+  static String From(const TVMRetValue& val) {
+    if (val.IsObjectRef<tvm::runtime::String>()) {
+      return val.AsObjectRef<tvm::runtime::String>();
+    } else {
+      return tvm::runtime::String(val.operator std::string());
+    }
+  }
+};
+
+template <typename T>
+struct PackedFuncValueConverter<Optional<T>> {
+  static Optional<T> From(const TVMArgValue& val) {
+    if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
+    return PackedFuncValueConverter<T>::From(val);
+  }
+  static Optional<T> From(const TVMRetValue& val) {
+    if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
+    return PackedFuncValueConverter<T>::From(val);
+  }
+};
+
+inline bool String::CanConvertFrom(const TVMArgValue& val) {
+  return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
+}
+
+inline TVMArgValue::operator DLDataType() const {
+  if (String::CanConvertFrom(*this)) {
+    return String2DLDataType(PackedFuncValueConverter<String>::From(*this).operator std::string());
+  }
+  // None type
+  if (type_code_ == kTVMNullptr) {
+    DLDataType t;
+    t.code = kTVMOpaqueHandle;
+    t.bits = 0;
+    t.lanes = 0;
+    return t;
+  }
+  TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
+  return value_.v_type;
+}
+
+inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); }
+
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_RUNTIME_PACKED_FUNC_H_
index d375fa0..edf8b42 100644 (file)
@@ -65,6 +65,8 @@ def test_make_node():
     assert AA.op == A.op
     assert AA.value_index == A.value_index
 
+    y = tvm.ir.make_node("IntImm", dtype=tvm.runtime.String("int32"), value=10)
+
 
 def test_make_sum():
     A = te.placeholder((2, 10), name='A')