#include <tvm/runtime/c_runtime_api.h>
#include <dmlc/logging.h>
#include <type_traits>
+#include <string>
namespace tvm {
namespace runtime {
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
+
+/*!
+ * \brief Runtime utility for getting custom type name from code
+ * \param type_code Custom type code
+ * \return Custom type name
+ */
+TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
+
+/*!
+ * \brief Runtime utility for checking whether custom type is registered
+ * \param type_code Custom type code
+ * \return Bool representing whether type is registered
+ */
+TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
+
+/*!
+ * \brief Runtime utility for parsing string of the form "custom[<typename>]"
+ * \param s String to parse
+ * \param scan pointer to parsing pointer, which is scanning across s
+ * \return type code of custom type parsed
+ */
+TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
+
+/*!
+ * \brief Convert type code to its name
+ * \param type_code The type code .
+ * \return The name of type code.
+ */
+inline const char* TypeCode2Str(int type_code);
+
+/*!
+ * \brief convert a string to TVM type.
+ * \param s The string to be converted.
+ * \return The corresponding tvm type.
+ */
+inline DLDataType String2DLDataType(std::string s);
+
+/*!
+ * \brief convert a TVM type to string.
+ * \param t The type to be converted.
+ * \return The corresponding tvm type in string.
+ */
+inline std::string DLDataType2String(DLDataType t);
+
+// implementation details
+inline const char* TypeCode2Str(int type_code) {
+ switch (type_code) {
+ case kDLInt: return "int";
+ case kDLUInt: return "uint";
+ case kDLFloat: return "float";
+ case kTVMStr: return "str";
+ case kTVMBytes: return "bytes";
+ case kTVMOpaqueHandle: return "handle";
+ case kTVMNullptr: return "NULL";
+ case kTVMDLTensorHandle: return "ArrayHandle";
+ case kTVMDataType: return "DLDataType";
+ case kTVMContext: return "TVMContext";
+ case kTVMPackedFuncHandle: return "FunctionHandle";
+ case kTVMModuleHandle: return "ModuleHandle";
+ case kTVMNDArrayHandle: return "NDArrayContainer";
+ case kTVMObjectHandle: return "Object";
+ default: LOG(FATAL) << "unknown type_code="
+ << static_cast<int>(type_code); return "";
+ }
+}
+
+inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
+ if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
+ os << "bool"; return os;
+ }
+ if (t.code < kTVMCustomBegin) {
+ os << TypeCode2Str(t.code);
+ } else {
+ os << "custom[" << GetCustomTypeName(t.code) << "]";
+ }
+ if (t.code == kTVMOpaqueHandle) return os;
+ os << static_cast<int>(t.bits);
+ if (t.lanes != 1) {
+ os << 'x' << static_cast<int>(t.lanes);
+ }
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
+ return os << dtype.operator DLDataType();
+}
+
+inline std::string DLDataType2String(DLDataType t) {
+ if (t.bits == 0) return "";
+ std::ostringstream os;
+ os << t;
+ return os.str();
+}
+
+inline DLDataType String2DLDataType(std::string s) {
+ DLDataType t;
+ // handle None type
+ if (s.length() == 0) {
+ t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
+ return t;
+ }
+ t.bits = 32; t.lanes = 1;
+ const char* scan;
+ if (s.substr(0, 3) == "int") {
+ t.code = kDLInt; scan = s.c_str() + 3;
+ } else if (s.substr(0, 4) == "uint") {
+ t.code = kDLUInt; scan = s.c_str() + 4;
+ } else if (s.substr(0, 5) == "float") {
+ t.code = kDLFloat; scan = s.c_str() + 5;
+ } else if (s.substr(0, 6) == "handle") {
+ t.code = kTVMOpaqueHandle;
+ t.bits = 64; // handle uses 64 bit by default.
+ scan = s.c_str() + 6;
+ } else if (s == "bool") {
+ t.code = kDLUInt;
+ t.bits = 1;
+ t.lanes = 1;
+ return t;
+ } else if (s.substr(0, 6) == "custom") {
+ t.code = ParseCustomDatatype(s, &scan);
+ } else {
+ scan = s.c_str();
+ LOG(FATAL) << "unknown type " << s;
+ }
+ char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
+ uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
+ if (bits != 0) t.bits = bits;
+ char* endpt = xdelim;
+ if (*xdelim == 'x') {
+ t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
+ }
+ CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
+ return t;
+}
+
} // namespace runtime
using DataType = runtime::DataType;
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
+#include <tvm/runtime/container.h>
#include <functional>
#include <tuple>
#include <vector>
namespace runtime {
-/*!
- * \brief Runtime utility for getting custom type name from code
- * \param type_code Custom type code
- * \return Custom type name
- */
-TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
-
-/*!
- * \brief Runtime utility for checking whether custom type is registered
- * \param type_code Custom type code
- * \return Bool representing whether type is registered
- */
-TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
-
-/*!
- * \brief Runtime utility for parsing string of the form "custom[<typename>]"
- * \param s String to parse
- * \param scan pointer to parsing pointer, which is scanning across s
- * \return type code of custom type parsed
- */
-TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
-
// forward declarations
class TVMArgs;
class TVMArgValue;
inline TVMArgValue operator[](int i) const;
};
-/*!
- * \brief Convert type code to its name
- * \param type_code The type code .
- * \return The name of type code.
- */
-inline const char* TypeCode2Str(int type_code);
-
-/*!
- * \brief convert a string to TVM type.
- * \param s The string to be converted.
- * \return The corresponding tvm type.
- */
-inline DLDataType String2DLDataType(std::string s);
-
-/*!
- * \brief convert a TVM type to string.
- * \param t The type to be converted.
- * \return The corresponding tvm type in string.
- */
-inline std::string DLDataType2String(DLDataType t);
-
// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
return std::string(value_.v_str);
}
}
+ operator tvm::runtime::String() const {
+ // directly use the std::string constructor for now.
+ return tvm::runtime::String(operator std::string());
+ }
operator DLDataType() const {
if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string());
TVM_CHECK_TYPE_CODE(type_code_, kTVMStr);
return *ptr<std::string>();
}
+ operator tvm::runtime::String() const {
+ // directly use the std::string constructor for now.
+ return tvm::runtime::String(operator std::string());
+ }
operator DLDataType() const {
if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string());
} \
}
-// implementation details
-inline const char* TypeCode2Str(int type_code) {
- switch (type_code) {
- case kDLInt: return "int";
- case kDLUInt: return "uint";
- case kDLFloat: return "float";
- case kTVMStr: return "str";
- case kTVMBytes: return "bytes";
- case kTVMOpaqueHandle: return "handle";
- case kTVMNullptr: return "NULL";
- case kTVMDLTensorHandle: return "ArrayHandle";
- case kTVMDataType: return "DLDataType";
- case kTVMContext: return "TVMContext";
- case kTVMPackedFuncHandle: return "FunctionHandle";
- case kTVMModuleHandle: return "ModuleHandle";
- case kTVMNDArrayHandle: return "NDArrayContainer";
- case kTVMObjectHandle: return "Object";
- default: LOG(FATAL) << "unknown type_code="
- << static_cast<int>(type_code); return "";
- }
-}
-
-inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
- if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
- os << "bool"; return os;
- }
- if (t.code < kTVMCustomBegin) {
- os << TypeCode2Str(t.code);
- } else {
- os << "custom[" << GetCustomTypeName(t.code) << "]";
- }
- if (t.code == kTVMOpaqueHandle) return os;
- os << static_cast<int>(t.bits);
- if (t.lanes != 1) {
- os << 'x' << static_cast<int>(t.lanes);
- }
- return os;
-}
-
-inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
- return os << dtype.operator DLDataType();
-}
-
-inline std::string DLDataType2String(DLDataType t) {
- if (t.bits == 0) return "";
- std::ostringstream os;
- os << t;
- return os.str();
-}
-
-inline DLDataType String2DLDataType(std::string s) {
- DLDataType t;
- // handle None type
- if (s.length() == 0) {
- t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
- return t;
- }
- t.bits = 32; t.lanes = 1;
- const char* scan;
- if (s.substr(0, 3) == "int") {
- t.code = kDLInt; scan = s.c_str() + 3;
- } else if (s.substr(0, 4) == "uint") {
- t.code = kDLUInt; scan = s.c_str() + 4;
- } else if (s.substr(0, 5) == "float") {
- t.code = kDLFloat; scan = s.c_str() + 5;
- } else if (s.substr(0, 6) == "handle") {
- t.code = kTVMOpaqueHandle;
- t.bits = 64; // handle uses 64 bit by default.
- scan = s.c_str() + 6;
- } else if (s == "bool") {
- t.code = kDLUInt;
- t.bits = 1;
- t.lanes = 1;
- return t;
- } else if (s.substr(0, 6) == "custom") {
- t.code = ParseCustomDatatype(s, &scan);
- } else {
- scan = s.c_str();
- LOG(FATAL) << "unknown type " << s;
- }
- char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
- uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
- if (bits != 0) t.bits = bits;
- char* endpt = xdelim;
- if (*xdelim == 'x') {
- t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
- }
- CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
- return t;
-}
inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)