[RUNTIME] Enable auto conversion from str to runtime::String in PackedFunc, move...
authorTianqi Chen <tqchen@users.noreply.github.com>
Mon, 6 Apr 2020 20:25:00 +0000 (13:25 -0700)
committerGitHub <noreply@github.com>
Mon, 6 Apr 2020 20:25:00 +0000 (13:25 -0700)
include/tvm/runtime/data_type.h
include/tvm/runtime/packed_func.h
tests/cpp/packed_func_test.cc

index 6379a13..9e92db9 100644 (file)
@@ -27,6 +27,7 @@
 #include <tvm/runtime/c_runtime_api.h>
 #include <dmlc/logging.h>
 #include <type_traits>
+#include <string>
 
 namespace tvm {
 namespace runtime {
@@ -263,6 +264,141 @@ inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
 inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
   return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
 }
+
+/*!
+ * \brief Runtime utility for getting custom type name from code
+ * \param type_code Custom type code
+ * \return Custom type name
+ */
+TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
+
+/*!
+ * \brief Runtime utility for checking whether custom type is registered
+ * \param type_code Custom type code
+ * \return Bool representing whether type is registered
+ */
+TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
+
+/*!
+ * \brief Runtime utility for parsing string of the form "custom[<typename>]"
+ * \param s String to parse
+ * \param scan pointer to parsing pointer, which is scanning across s
+ * \return type code of custom type parsed
+ */
+TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
+
+/*!
+ * \brief Convert type code to its name
+ * \param type_code The type code .
+ * \return The name of type code.
+ */
+inline const char* TypeCode2Str(int type_code);
+
+/*!
+ * \brief convert a string to TVM type.
+ * \param s The string to be converted.
+ * \return The corresponding tvm type.
+ */
+inline DLDataType String2DLDataType(std::string s);
+
+/*!
+ * \brief convert a TVM type to string.
+ * \param t The type to be converted.
+ * \return The corresponding tvm type in string.
+ */
+inline std::string DLDataType2String(DLDataType t);
+
+// implementation details
+inline const char* TypeCode2Str(int type_code) {
+  switch (type_code) {
+    case kDLInt: return "int";
+    case kDLUInt: return "uint";
+    case kDLFloat: return "float";
+    case kTVMStr: return "str";
+    case kTVMBytes: return "bytes";
+    case kTVMOpaqueHandle: return "handle";
+    case kTVMNullptr: return "NULL";
+    case kTVMDLTensorHandle: return "ArrayHandle";
+    case kTVMDataType: return "DLDataType";
+    case kTVMContext: return "TVMContext";
+    case kTVMPackedFuncHandle: return "FunctionHandle";
+    case kTVMModuleHandle: return "ModuleHandle";
+    case kTVMNDArrayHandle: return "NDArrayContainer";
+    case kTVMObjectHandle: return "Object";
+    default: LOG(FATAL) << "unknown type_code="
+                        << static_cast<int>(type_code); return "";
+  }
+}
+
+inline std::ostream& operator<<(std::ostream& os, DLDataType t) {  // NOLINT(*)
+  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
+    os << "bool"; return os;
+  }
+  if (t.code < kTVMCustomBegin) {
+    os << TypeCode2Str(t.code);
+  } else {
+    os << "custom[" << GetCustomTypeName(t.code) << "]";
+  }
+  if (t.code == kTVMOpaqueHandle) return os;
+  os << static_cast<int>(t.bits);
+  if (t.lanes != 1) {
+    os << 'x' << static_cast<int>(t.lanes);
+  }
+  return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
+  return os << dtype.operator DLDataType();
+}
+
+inline std::string DLDataType2String(DLDataType t) {
+  if (t.bits == 0) return "";
+  std::ostringstream os;
+  os << t;
+  return os.str();
+}
+
+inline DLDataType String2DLDataType(std::string s) {
+  DLDataType t;
+  // handle None type
+  if (s.length() == 0) {
+    t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
+    return t;
+  }
+  t.bits = 32; t.lanes = 1;
+  const char* scan;
+  if (s.substr(0, 3) == "int") {
+    t.code = kDLInt;  scan = s.c_str() + 3;
+  } else if (s.substr(0, 4) == "uint") {
+    t.code = kDLUInt; scan = s.c_str() + 4;
+  } else if (s.substr(0, 5) == "float") {
+    t.code = kDLFloat; scan = s.c_str() + 5;
+  } else if (s.substr(0, 6) == "handle") {
+    t.code = kTVMOpaqueHandle;
+    t.bits = 64;  // handle uses 64 bit by default.
+    scan = s.c_str() + 6;
+  } else if (s == "bool") {
+    t.code = kDLUInt;
+    t.bits = 1;
+    t.lanes = 1;
+    return t;
+  } else if (s.substr(0, 6) == "custom") {
+    t.code = ParseCustomDatatype(s, &scan);
+  } else {
+    scan = s.c_str();
+    LOG(FATAL) << "unknown type " << s;
+  }
+  char* xdelim;  // emulate sscanf("%ux%u", bits, lanes)
+  uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
+  if (bits != 0) t.bits = bits;
+  char* endpt = xdelim;
+  if (*xdelim == 'x') {
+    t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
+  }
+  CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
+  return t;
+}
+
 }  // namespace runtime
 
 using DataType = runtime::DataType;
index 46fe1a1..d5c0175 100644 (file)
@@ -30,6 +30,7 @@
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/object.h>
+#include <tvm/runtime/container.h>
 #include <functional>
 #include <tuple>
 #include <vector>
@@ -52,28 +53,6 @@ class PrimExpr;
 
 namespace runtime {
 
-/*!
- * \brief Runtime utility for getting custom type name from code
- * \param type_code Custom type code
- * \return Custom type name
- */
-TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
-
-/*!
- * \brief Runtime utility for checking whether custom type is registered
- * \param type_code Custom type code
- * \return Bool representing whether type is registered
- */
-TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
-
-/*!
- * \brief Runtime utility for parsing string of the form "custom[<typename>]"
- * \param s String to parse
- * \param scan pointer to parsing pointer, which is scanning across s
- * \return type code of custom type parsed
- */
-TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
-
 // forward declarations
 class TVMArgs;
 class TVMArgValue;
@@ -359,27 +338,6 @@ class TVMArgs {
   inline TVMArgValue operator[](int i) const;
 };
 
-/*!
- * \brief Convert type code to its name
- * \param type_code The type code .
- * \return The name of type code.
- */
-inline const char* TypeCode2Str(int type_code);
-
-/*!
- * \brief convert a string to TVM type.
- * \param s The string to be converted.
- * \return The corresponding tvm type.
- */
-inline DLDataType String2DLDataType(std::string s);
-
-/*!
- * \brief convert a TVM type to string.
- * \param t The type to be converted.
- * \return The corresponding tvm type in string.
- */
-inline std::string DLDataType2String(DLDataType t);
-
 // macro to check type code.
 #define TVM_CHECK_TYPE_CODE(CODE, T)                           \
   CHECK_EQ(CODE, T) << " expected "                            \
@@ -554,6 +512,10 @@ class TVMArgValue : public TVMPODValue_ {
       return std::string(value_.v_str);
     }
   }
+  operator tvm::runtime::String() const {
+    // directly use the std::string constructor for now.
+    return tvm::runtime::String(operator std::string());
+  }
   operator DLDataType() const {
     if (type_code_ == kTVMStr) {
       return String2DLDataType(operator std::string());
@@ -642,6 +604,10 @@ class TVMRetValue : public TVMPODValue_ {
     TVM_CHECK_TYPE_CODE(type_code_, kTVMStr);
     return *ptr<std::string>();
   }
+  operator tvm::runtime::String() const {
+    // directly use the std::string constructor for now.
+    return tvm::runtime::String(operator std::string());
+  }
   operator DLDataType() const {
     if (type_code_ == kTVMStr) {
       return String2DLDataType(operator std::string());
@@ -994,96 +960,6 @@ class TVMRetValue : public TVMPODValue_ {
     }                                                                   \
   }
 
-// implementation details
-inline const char* TypeCode2Str(int type_code) {
-  switch (type_code) {
-    case kDLInt: return "int";
-    case kDLUInt: return "uint";
-    case kDLFloat: return "float";
-    case kTVMStr: return "str";
-    case kTVMBytes: return "bytes";
-    case kTVMOpaqueHandle: return "handle";
-    case kTVMNullptr: return "NULL";
-    case kTVMDLTensorHandle: return "ArrayHandle";
-    case kTVMDataType: return "DLDataType";
-    case kTVMContext: return "TVMContext";
-    case kTVMPackedFuncHandle: return "FunctionHandle";
-    case kTVMModuleHandle: return "ModuleHandle";
-    case kTVMNDArrayHandle: return "NDArrayContainer";
-    case kTVMObjectHandle: return "Object";
-    default: LOG(FATAL) << "unknown type_code="
-                        << static_cast<int>(type_code); return "";
-  }
-}
-
-inline std::ostream& operator<<(std::ostream& os, DLDataType t) {  // NOLINT(*)
-  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
-    os << "bool"; return os;
-  }
-  if (t.code < kTVMCustomBegin) {
-    os << TypeCode2Str(t.code);
-  } else {
-    os << "custom[" << GetCustomTypeName(t.code) << "]";
-  }
-  if (t.code == kTVMOpaqueHandle) return os;
-  os << static_cast<int>(t.bits);
-  if (t.lanes != 1) {
-    os << 'x' << static_cast<int>(t.lanes);
-  }
-  return os;
-}
-
-inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
-  return os << dtype.operator DLDataType();
-}
-
-inline std::string DLDataType2String(DLDataType t) {
-  if (t.bits == 0) return "";
-  std::ostringstream os;
-  os << t;
-  return os.str();
-}
-
-inline DLDataType String2DLDataType(std::string s) {
-  DLDataType t;
-  // handle None type
-  if (s.length() == 0) {
-    t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
-    return t;
-  }
-  t.bits = 32; t.lanes = 1;
-  const char* scan;
-  if (s.substr(0, 3) == "int") {
-    t.code = kDLInt;  scan = s.c_str() + 3;
-  } else if (s.substr(0, 4) == "uint") {
-    t.code = kDLUInt; scan = s.c_str() + 4;
-  } else if (s.substr(0, 5) == "float") {
-    t.code = kDLFloat; scan = s.c_str() + 5;
-  } else if (s.substr(0, 6) == "handle") {
-    t.code = kTVMOpaqueHandle;
-    t.bits = 64;  // handle uses 64 bit by default.
-    scan = s.c_str() + 6;
-  } else if (s == "bool") {
-    t.code = kDLUInt;
-    t.bits = 1;
-    t.lanes = 1;
-    return t;
-  } else if (s.substr(0, 6) == "custom") {
-    t.code = ParseCustomDatatype(s, &scan);
-  } else {
-    scan = s.c_str();
-    LOG(FATAL) << "unknown type " << s;
-  }
-  char* xdelim;  // emulate sscanf("%ux%u", bits, lanes)
-  uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
-  if (bits != 0) t.bits = bits;
-  char* endpt = xdelim;
-  if (*xdelim == 'x') {
-    t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
-  }
-  CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
-  return t;
-}
 
 inline TVMArgValue TVMArgs::operator[](int i) const {
   CHECK_LT(i, num_args)
index 8357a70..4a815ff 100644 (file)
@@ -91,6 +91,8 @@ TEST(PackedFunc, str) {
       CHECK(args.num_args == 1);
       std::string x = args[0];
       CHECK(x == "hello");
+      String y = args[0];
+      CHECK(y == "hello");
       *rv = x;
     })("hello");
 }