#include <dmlc/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
-#include <tvm/runtime/packed_func.h>
#include <algorithm>
#include <cstring>
namespace tvm {
namespace runtime {
+// Forward declare TVMArgValue
+class TVMArgValue;
+
/*! \brief String-aware ObjectRef equal functor */
struct ObjectHash {
/*!
* \param val The value to be checked
* \return A boolean indicating if val can be converted to String
*/
- static bool CanConvertFrom(const TVMArgValue& val) {
- return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
- }
+ inline static bool CanConvertFrom(const TVMArgValue& val);
/*!
* \brief Hash the binary bytes
return false;
}
-template <>
-struct PackedFuncValueConverter<::tvm::runtime::String> {
- static String From(const TVMArgValue& val) {
- if (val.IsObjectRef<tvm::runtime::String>()) {
- return val.AsObjectRef<tvm::runtime::String>();
- } else {
- return tvm::runtime::String(val.operator std::string());
- }
- }
-
- static String From(const TVMRetValue& val) {
- if (val.IsObjectRef<tvm::runtime::String>()) {
- return val.AsObjectRef<tvm::runtime::String>();
- } else {
- return tvm::runtime::String(val.operator std::string());
- }
- }
-};
-
/*! \brief Helper to represent nullptr for optional. */
struct NullOptType {};
static constexpr bool _type_is_nullable = true;
};
-template <typename T>
-struct PackedFuncValueConverter<Optional<T>> {
- static Optional<T> From(const TVMArgValue& val) {
- if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
- return PackedFuncValueConverter<T>::From(val);
- }
- static Optional<T> From(const TVMRetValue& val) {
- if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
- return PackedFuncValueConverter<T>::From(val);
- }
-};
-
/*!
* \brief An object representing a closure. This object is used by both the
* Relay VM and interpreter.
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
return std::string(value_.v_str);
}
}
- operator DLDataType() const {
- if (type_code_ == kTVMStr) {
- return String2DLDataType(operator std::string());
- }
- // None type
- if (type_code_ == kTVMNullptr) {
- DLDataType t;
- t.code = kTVMOpaqueHandle;
- t.bits = 0;
- t.lanes = 0;
- return t;
- }
- TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
- return value_.v_type;
- }
- operator DataType() const { return DataType(operator DLDataType()); }
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
inline operator T() const;
+ inline operator DLDataType() const;
+ inline operator DataType() const;
};
/*!
return (*this)->GetFunction(name, query_imports);
}
+// specializations of PackedFuncValueConverter
+template <>
+struct PackedFuncValueConverter<::tvm::runtime::String> {
+ static String From(const TVMArgValue& val) {
+ if (val.IsObjectRef<tvm::runtime::String>()) {
+ return val.AsObjectRef<tvm::runtime::String>();
+ } else {
+ return tvm::runtime::String(val.operator std::string());
+ }
+ }
+
+ static String From(const TVMRetValue& val) {
+ if (val.IsObjectRef<tvm::runtime::String>()) {
+ return val.AsObjectRef<tvm::runtime::String>();
+ } else {
+ return tvm::runtime::String(val.operator std::string());
+ }
+ }
+};
+
+template <typename T>
+struct PackedFuncValueConverter<Optional<T>> {
+ static Optional<T> From(const TVMArgValue& val) {
+ if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
+ return PackedFuncValueConverter<T>::From(val);
+ }
+ static Optional<T> From(const TVMRetValue& val) {
+ if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
+ return PackedFuncValueConverter<T>::From(val);
+ }
+};
+
+inline bool String::CanConvertFrom(const TVMArgValue& val) {
+ return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
+}
+
+inline TVMArgValue::operator DLDataType() const {
+ if (String::CanConvertFrom(*this)) {
+ return String2DLDataType(PackedFuncValueConverter<String>::From(*this).operator std::string());
+ }
+ // None type
+ if (type_code_ == kTVMNullptr) {
+ DLDataType t;
+ t.code = kTVMOpaqueHandle;
+ t.bits = 0;
+ t.lanes = 0;
+ return t;
+ }
+ TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
+ return value_.v_type;
+}
+
+inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); }
+
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_