From: Tianqi Chen Date: Thu, 6 Aug 2020 01:57:25 +0000 (-0700) Subject: [RUNTIME] Enable auto conversion String->DLDataType (#6214) X-Git-Tag: upstream/0.7.0~301 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=95045d1bf95d3f61f1ba7f9a3c22e0a06aa8065b;p=platform%2Fupstream%2Ftvm.git [RUNTIME] Enable auto conversion String->DLDataType (#6214) --- diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index f8fa09d..423ea89 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include @@ -74,6 +73,9 @@ class StringRef; namespace tvm { namespace runtime { +// Forward declare TVMArgValue +class TVMArgValue; + /*! \brief String-aware ObjectRef equal functor */ struct ObjectHash { /*! @@ -1289,9 +1291,7 @@ class String : public ObjectRef { * \param val The value to be checked * \return A boolean indicating if val can be converted to String */ - static bool CanConvertFrom(const TVMArgValue& val) { - return val.type_code() == kTVMStr || val.IsObjectRef(); - } + inline static bool CanConvertFrom(const TVMArgValue& val); /*! * \brief Hash the binary bytes @@ -1523,25 +1523,6 @@ inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) cons return false; } -template <> -struct PackedFuncValueConverter<::tvm::runtime::String> { - static String From(const TVMArgValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); - } else { - return tvm::runtime::String(val.operator std::string()); - } - } - - static String From(const TVMRetValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); - } else { - return tvm::runtime::String(val.operator std::string()); - } - } -}; - /*! \brief Helper to represent nullptr for optional. */ struct NullOptType {}; @@ -1659,18 +1640,6 @@ class Optional : public ObjectRef { static constexpr bool _type_is_nullable = true; }; -template -struct PackedFuncValueConverter> { - static Optional From(const TVMArgValue& val) { - if (val.type_code() == kTVMNullptr) return Optional(nullptr); - return PackedFuncValueConverter::From(val); - } - static Optional From(const TVMRetValue& val) { - if (val.type_code() == kTVMNullptr) return Optional(nullptr); - return PackedFuncValueConverter::From(val); - } -}; - /*! * \brief An object representing a closure. This object is used by both the * Relay VM and interpreter. diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 3231217..d2450c4 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -492,22 +493,6 @@ class TVMArgValue : public TVMPODValue_ { return std::string(value_.v_str); } } - operator DLDataType() const { - if (type_code_ == kTVMStr) { - return String2DLDataType(operator std::string()); - } - // None type - if (type_code_ == kTVMNullptr) { - DLDataType t; - t.code = kTVMOpaqueHandle; - t.bits = 0; - t.lanes = 0; - return t; - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); - return value_.v_type; - } - operator DataType() const { return DataType(operator DLDataType()); } operator PackedFunc() const { if (type_code_ == kTVMNullptr) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); @@ -521,6 +506,8 @@ class TVMArgValue : public TVMPODValue_ { template ::value>::type> inline operator T() const; + inline operator DLDataType() const; + inline operator DataType() const; }; /*! @@ -1473,6 +1460,60 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import return (*this)->GetFunction(name, query_imports); } +// specializations of PackedFuncValueConverter +template <> +struct PackedFuncValueConverter<::tvm::runtime::String> { + static String From(const TVMArgValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); + } + } + + static String From(const TVMRetValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); + } + } +}; + +template +struct PackedFuncValueConverter> { + static Optional From(const TVMArgValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } + static Optional From(const TVMRetValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } +}; + +inline bool String::CanConvertFrom(const TVMArgValue& val) { + return val.type_code() == kTVMStr || val.IsObjectRef(); +} + +inline TVMArgValue::operator DLDataType() const { + if (String::CanConvertFrom(*this)) { + return String2DLDataType(PackedFuncValueConverter::From(*this).operator std::string()); + } + // None type + if (type_code_ == kTVMNullptr) { + DLDataType t; + t.code = kTVMOpaqueHandle; + t.bits = 0; + t.lanes = 0; + return t; + } + TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); + return value_.v_type; +} + +inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); } + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index d375fa0..edf8b42 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -65,6 +65,8 @@ def test_make_node(): assert AA.op == A.op assert AA.value_index == A.value_index + y = tvm.ir.make_node("IntImm", dtype=tvm.runtime.String("int32"), value=10) + def test_make_sum(): A = te.placeholder((2, 10), name='A')