[REFACTOR] Separate ArgTypeCode from DLDataTypeCode (#5730)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 4 Jun 2020 22:04:17 +0000 (15:04 -0700)
committerGitHub <noreply@github.com>
Thu, 4 Jun 2020 22:04:17 +0000 (15:04 -0700)
We use a single enum(TypeCode) to represent ArgTypeCode and DLDataTypeCode.
However, as we start to expand more data types, it is clear that argument
type code(in the FFI convention) and data type code needs to evolve separately.
So that we can add first class for data types without having changing the FFI ABI.

This PR makes the distinction clear and refactored the code to separate the two.

- [PY] Separate ArgTypeCode from DataTypeCode
- [WEB] Separate ArgTypeCode from DataTypeCode
- [JAVA] Separate ArgTypeCode from DataTypeCode

38 files changed:
3rdparty/dlpack
include/tvm/runtime/c_runtime_api.h
include/tvm/runtime/data_type.h
include/tvm/runtime/packed_func.h
include/tvm/tir/op.h
jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java [moved from jvm/core/src/main/java/org/apache/tvm/TypeCode.java with 95% similarity]
jvm/core/src/main/java/org/apache/tvm/Function.java
jvm/core/src/main/java/org/apache/tvm/Module.java
jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java
jvm/core/src/main/java/org/apache/tvm/TVMValue.java
jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java
jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java
jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java
jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java
jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java
jvm/core/src/main/java/org/apache/tvm/TVMValueString.java
python/tvm/__init__.py
python/tvm/_ffi/_ctypes/object.py
python/tvm/_ffi/_ctypes/packed_func.py
python/tvm/_ffi/_ctypes/types.py
python/tvm/_ffi/_cython/base.pxi
python/tvm/_ffi/registry.py
python/tvm/_ffi/runtime_ctypes.py
python/tvm/runtime/__init__.py
python/tvm/runtime/ndarray.py
python/tvm/tir/expr.py
rust/common/src/packed_func.rs
rust/frontend/src/function.rs
rust/tvm-sys/src/packed_func.rs
src/runtime/micro/standalone/utvm_graph_runtime.cc
src/runtime/rpc/rpc_module.cc
src/target/datatype/registry.cc
src/target/datatype/registry.h
tests/python/unittest/test_runtime_extension.py
tests/python/unittest/test_runtime_ndarray.py
web/src/ctypes.ts
web/src/rpc_server.ts
web/src/runtime.ts

index 0acb731..3ec0443 160000 (submodule)
@@ -1 +1 @@
-Subproject commit 0acb731e0e43d15deee27b66f10e4c5b4e667913
+Subproject commit 3ec04430e89a6834e5a1b99471f415fa939bf642
index bb38ad8..be86563 100644 (file)
@@ -87,12 +87,14 @@ typedef enum {
 } TVMDeviceExtType;
 
 /*!
- * \brief The type code in used in the TVM FFI.
+ * \brief The type code in used in the TVM FFI for argument passing.
  */
 typedef enum {
   // The type code of other types are compatible with DLPack.
   // The next few fields are extension types
   // that is used by TVM API calls.
+  kTVMArgInt = kDLInt,
+  kTVMArgFloat = kDLFloat,
   kTVMOpaqueHandle = 3U,
   kTVMNullptr = 4U,
   kTVMDataType = 5U,
@@ -115,9 +117,7 @@ typedef enum {
   // The following section of code is used for non-reserved types.
   kTVMExtReserveEnd = 64U,
   kTVMExtEnd = 128U,
-  // The rest of the space is used for custom, user-supplied datatypes
-  kTVMCustomBegin = 129U,
-} TVMTypeCode;
+} TVMArgTypeCode;
 
 /*!
  * \brief The Device information, abstract away common device types.
index a10b83f..1d53810 100644 (file)
@@ -45,7 +45,8 @@ class DataType {
     kInt = kDLInt,
     kUInt = kDLUInt,
     kFloat = kDLFloat,
-    kHandle = TVMTypeCode::kTVMOpaqueHandle,
+    kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
+    kCustomBegin = 129
   };
   /*! \brief default constructor */
   DataType() {}
@@ -248,7 +249,7 @@ TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
  * \param type_code The type code .
  * \return The name of type code.
  */
-inline const char* TypeCode2Str(int type_code);
+inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code);
 
 /*!
  * \brief convert a string to TVM type.
@@ -265,38 +266,16 @@ inline DLDataType String2DLDataType(std::string s);
 inline std::string DLDataType2String(DLDataType t);
 
 // implementation details
-inline const char* TypeCode2Str(int type_code) {
-  switch (type_code) {
+inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
+  switch (static_cast<int>(type_code)) {
     case kDLInt:
       return "int";
     case kDLUInt:
       return "uint";
     case kDLFloat:
       return "float";
-    case kTVMStr:
-      return "str";
-    case kTVMBytes:
-      return "bytes";
-    case kTVMOpaqueHandle:
+    case DataType::kHandle:
       return "handle";
-    case kTVMNullptr:
-      return "NULL";
-    case kTVMDLTensorHandle:
-      return "ArrayHandle";
-    case kTVMDataType:
-      return "DLDataType";
-    case kTVMContext:
-      return "TVMContext";
-    case kTVMPackedFuncHandle:
-      return "FunctionHandle";
-    case kTVMModuleHandle:
-      return "ModuleHandle";
-    case kTVMNDArrayHandle:
-      return "NDArrayContainer";
-    case kTVMObjectHandle:
-      return "Object";
-    case kTVMObjectRValueRefArg:
-      return "ObjectRValueRefArg";
     default:
       LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
       return "";
@@ -311,8 +290,8 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) {  // NOLINT(*)
   if (DataType(t).is_void()) {
     return os << "void";
   }
-  if (t.code < kTVMCustomBegin) {
-    os << TypeCode2Str(t.code);
+  if (t.code < DataType::kCustomBegin) {
+    os << DLDataTypeCode2Str(static_cast<DLDataTypeCode>(t.code));
   } else {
     os << "custom[" << GetCustomTypeName(t.code) << "]";
   }
index 01f8e99..e82b97a 100644 (file)
@@ -327,9 +327,16 @@ class TVMArgs {
   inline TVMArgValue operator[](int i) const;
 };
 
+/*!
+ * \brief Convert argument type code to string.
+ * \param type_code The input type code.
+ * \return The corresponding string repr.
+ */
+inline const char* ArgTypeCode2Str(int type_code);
+
 // macro to check type code.
 #define TVM_CHECK_TYPE_CODE(CODE, T) \
-  CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE)
+  CHECK_EQ(CODE, T) << " expected " << ArgTypeCode2Str(T) << " but get " << ArgTypeCode2Str(CODE)
 
 /*!
  * \brief Type traits for runtime type check during FFI conversion.
@@ -394,7 +401,7 @@ class TVMPODValue_ {
     } else {
       if (type_code_ == kTVMNullptr) return nullptr;
       LOG(FATAL) << "Expect "
-                 << "DLTensor* or NDArray but get " << TypeCode2Str(type_code_);
+                 << "DLTensor* or NDArray but get " << ArgTypeCode2Str(type_code_);
       return nullptr;
     }
   }
@@ -982,6 +989,44 @@ inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(
 inline PackedFunc::FType PackedFunc::body() const { return body_; }
 
 // internal namespace
+inline const char* ArgTypeCode2Str(int type_code) {
+  switch (type_code) {
+    case kDLInt:
+      return "int";
+    case kDLUInt:
+      return "uint";
+    case kDLFloat:
+      return "float";
+    case kTVMStr:
+      return "str";
+    case kTVMBytes:
+      return "bytes";
+    case kTVMOpaqueHandle:
+      return "handle";
+    case kTVMNullptr:
+      return "NULL";
+    case kTVMDLTensorHandle:
+      return "ArrayHandle";
+    case kTVMDataType:
+      return "DLDataType";
+    case kTVMContext:
+      return "TVMContext";
+    case kTVMPackedFuncHandle:
+      return "FunctionHandle";
+    case kTVMModuleHandle:
+      return "ModuleHandle";
+    case kTVMNDArrayHandle:
+      return "NDArrayContainer";
+    case kTVMObjectHandle:
+      return "Object";
+    case kTVMObjectRValueRefArg:
+      return "ObjectRValueRefArg";
+    default:
+      LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
+      return "";
+  }
+}
+
 namespace detail {
 
 template <bool stop, std::size_t I, typename F>
index 5884942..a4748d5 100644 (file)
@@ -740,7 +740,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
   // datatypes lowering pass, we will lower the value to its true representation in the format
   // specified by the datatype.
   // TODO(gus) when do we need to start worrying about doubles not being precise enough?
-  if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kTVMCustomBegin)) {
+  if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) {
     return FloatImm(t, static_cast<double>(value));
   }
   LOG(FATAL) << "cannot make const for type " << t;
 package org.apache.tvm;
 
 // Type code used in API calls
-public enum TypeCode {
+public enum ArgTypeCode {
   INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5),
   TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9),
   FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13);
 
   public final int id;
 
-  private TypeCode(int id) {
+  private ArgTypeCode(int id) {
     this.id = id;
   }
 
index a9ac707..df535a8 100644 (file)
@@ -80,7 +80,7 @@ public class Function extends TVMValue {
    * @param isResident Whether this is a resident function in jvm
    */
   Function(long handle, boolean isResident) {
-    super(TypeCode.FUNC_HANDLE);
+    super(ArgTypeCode.FUNC_HANDLE);
     this.handle = handle;
     this.isResident = isResident;
   }
@@ -187,7 +187,7 @@ public class Function extends TVMValue {
    * @return this
    */
   public Function pushArg(NDArrayBase arg) {
-    int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
+    int id = arg.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id;
     Base._LIB.tvmFuncPushArgHandle(arg.handle, id);
     return this;
   }
@@ -198,7 +198,7 @@ public class Function extends TVMValue {
    * @return this
    */
   public Function pushArg(Module arg) {
-    Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.MODULE_HANDLE.id);
+    Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.MODULE_HANDLE.id);
     return this;
   }
 
@@ -208,7 +208,7 @@ public class Function extends TVMValue {
    * @return this
    */
   public Function pushArg(Function arg) {
-    Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.FUNC_HANDLE.id);
+    Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.FUNC_HANDLE.id);
     return this;
   }
 
@@ -249,12 +249,12 @@ public class Function extends TVMValue {
       Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
     } else if (arg instanceof NDArrayBase) {
       NDArrayBase nd = (NDArrayBase) arg;
-      int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
+      int id = nd.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id;
       Base._LIB.tvmFuncPushArgHandle(nd.handle, id);
     } else if (arg instanceof Module) {
-      Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id);
+      Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id);
     } else if (arg instanceof Function) {
-      Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, TypeCode.FUNC_HANDLE.id);
+      Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id);
     } else if (arg instanceof TVMValue) {
       TVMValue tvmArg = (TVMValue) arg;
       switch (tvmArg.typeCode) {
index 1656f8d..874daa4 100644 (file)
@@ -45,7 +45,7 @@ public class Module extends TVMValue {
   }
 
   Module(long handle) {
-    super(TypeCode.MODULE_HANDLE);
+    super(ArgTypeCode.MODULE_HANDLE);
     this.handle = handle;
   }
 
@@ -138,7 +138,7 @@ public class Module extends TVMValue {
    */
   public static Module load(String path, String fmt) {
     TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke();
-    assert ret.typeCode == TypeCode.MODULE_HANDLE;
+    assert ret.typeCode == ArgTypeCode.MODULE_HANDLE;
     return ret.asModule();
   }
 
index 5ac630d..26bb735 100644 (file)
@@ -27,7 +27,7 @@ public class NDArrayBase extends TVMValue {
   private boolean isReleased = false;
 
   NDArrayBase(long handle, boolean isView) {
-    super(TypeCode.ARRAY_HANDLE);
+    super(ArgTypeCode.ARRAY_HANDLE);
     this.handle = handle;
     this.isView = isView;
   }
index 92c7623..d30cfcc 100644 (file)
@@ -18,9 +18,9 @@
 package org.apache.tvm;
 
 public class TVMValue {
-  public final TypeCode typeCode;
+  public final ArgTypeCode typeCode;
 
-  public TVMValue(TypeCode tc) {
+  public TVMValue(ArgTypeCode tc) {
     typeCode = tc;
   }
 
index 6c7c1c8..132d88f 100644 (file)
@@ -21,7 +21,7 @@ public class TVMValueBytes extends TVMValue {
   public final byte[] value;
 
   public TVMValueBytes(byte[] value) {
-    super(TypeCode.BYTES);
+    super(ArgTypeCode.BYTES);
     this.value = value;
   }
 
index d94b011..9db4c3b 100644 (file)
@@ -21,7 +21,7 @@ public class TVMValueDouble extends TVMValue {
   public final double value;
 
   public TVMValueDouble(double value) {
-    super(TypeCode.FLOAT);
+    super(ArgTypeCode.FLOAT);
     this.value = value;
   }
 
index 8ab7572..b91f55e 100644 (file)
 package org.apache.tvm;
 
 /**
- * Java class related to TVM handles (TypeCode.HANDLE)
+ * Java class related to TVM handles (ArgTypeCode.HANDLE)
  */
 public class TVMValueHandle extends TVMValue {
   public final long value;
 
   public TVMValueHandle(long value) {
-    super(TypeCode.HANDLE);
+    super(ArgTypeCode.HANDLE);
     this.value = value;
   }
 
index 5dba2fd..8a9b157 100644 (file)
@@ -21,7 +21,7 @@ public class TVMValueLong extends TVMValue {
   public final long value;
 
   public TVMValueLong(long value) {
-    super(TypeCode.INT);
+    super(ArgTypeCode.INT);
     this.value = value;
   }
 
index 03c0ea0..8c49ee5 100644 (file)
@@ -19,6 +19,6 @@ package org.apache.tvm;
 
 public class TVMValueNull extends TVMValue {
   public TVMValueNull() {
-    super(TypeCode.NULL);
+    super(ArgTypeCode.NULL);
   }
 }
index 260803e..46926e7 100644 (file)
@@ -21,7 +21,7 @@ public class TVMValueString extends TVMValue {
   public final String value;
 
   public TVMValueString(String value) {
-    super(TypeCode.STR);
+    super(ArgTypeCode.STR);
     this.value = value;
   }
 
index 6db8655..6cbc6d2 100644 (file)
@@ -23,7 +23,7 @@ import traceback
 # top-level alias
 # tvm._ffi
 from ._ffi.base import TVMError, __version__
-from ._ffi.runtime_ctypes import TypeCode, DataType
+from ._ffi.runtime_ctypes import DataTypeCode, DataType
 from ._ffi import register_object, register_func, register_extension, get_global_func
 
 # top-level alias
index 3dbb607..359b018 100644 (file)
@@ -18,7 +18,7 @@
 """Runtime Object api"""
 import ctypes
 from ..base import _LIB, check_call
-from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
+from .types import ArgTypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
 from .ndarray import _register_ndarray, NDArrayBase
 
 
@@ -60,12 +60,12 @@ def _return_object(x):
     obj.handle = handle
     return obj
 
-RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
-C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
-    _return_object, TypeCode.OBJECT_HANDLE)
+RETURN_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _return_object
+C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _wrap_arg_func(
+    _return_object, ArgTypeCode.OBJECT_HANDLE)
 
-C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
-    _return_object, TypeCode.OBJECT_RVALUE_REF_ARG)
+C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
+    _return_object, ArgTypeCode.OBJECT_RVALUE_REF_ARG)
 
 
 class PyNativeObject:
index b17174a..8a2f49a 100644 (file)
@@ -26,7 +26,7 @@ from ..base import c_str, string_types
 from ..runtime_ctypes import DataType, TVMByteArray, TVMContext, ObjectRValueRef
 from . import ndarray as _nd
 from .ndarray import NDArrayBase, _make_array
-from .types import TVMValue, TypeCode
+from .types import TVMValue, ArgTypeCode
 from .types import TVMPackedCFunc, TVMCFuncFinalizer
 from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
 from .object import ObjectBase, PyNativeObject, _set_class_object
@@ -115,32 +115,32 @@ def _make_tvm_args(args, temp_args):
     for i, arg in enumerate(args):
         if isinstance(arg, ObjectBase):
             values[i].v_handle = arg.handle
-            type_codes[i] = TypeCode.OBJECT_HANDLE
+            type_codes[i] = ArgTypeCode.OBJECT_HANDLE
         elif arg is None:
             values[i].v_handle = None
-            type_codes[i] = TypeCode.NULL
+            type_codes[i] = ArgTypeCode.NULL
         elif isinstance(arg, NDArrayBase):
             values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
-            type_codes[i] = (TypeCode.NDARRAY_HANDLE
-                             if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
+            type_codes[i] = (ArgTypeCode.NDARRAY_HANDLE
+                             if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE)
         elif isinstance(arg, PyNativeObject):
             values[i].v_handle = arg.__tvm_object__.handle
-            type_codes[i] = TypeCode.OBJECT_HANDLE
+            type_codes[i] = ArgTypeCode.OBJECT_HANDLE
         elif isinstance(arg, _nd._TVM_COMPATS):
             values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
             type_codes[i] = arg.__class__._tvm_tcode
         elif isinstance(arg, Integral):
             values[i].v_int64 = arg
-            type_codes[i] = TypeCode.INT
+            type_codes[i] = ArgTypeCode.INT
         elif isinstance(arg, Number):
             values[i].v_float64 = arg
-            type_codes[i] = TypeCode.FLOAT
+            type_codes[i] = ArgTypeCode.FLOAT
         elif isinstance(arg, DataType):
             values[i].v_str = c_str(str(arg))
-            type_codes[i] = TypeCode.STR
+            type_codes[i] = ArgTypeCode.STR
         elif isinstance(arg, TVMContext):
             values[i].v_int64 = _ctx_to_int64(arg)
-            type_codes[i] = TypeCode.TVM_CONTEXT
+            type_codes[i] = ArgTypeCode.TVM_CONTEXT
         elif isinstance(arg, (bytearray, bytes)):
             # from_buffer only taeks in bytearray.
             if isinstance(arg, bytes):
@@ -155,31 +155,31 @@ def _make_tvm_args(args, temp_args):
             arr.size = len(arg)
             values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
             temp_args.append(arr)
-            type_codes[i] = TypeCode.BYTES
+            type_codes[i] = ArgTypeCode.BYTES
         elif isinstance(arg, string_types):
             values[i].v_str = c_str(arg)
-            type_codes[i] = TypeCode.STR
+            type_codes[i] = ArgTypeCode.STR
         elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
             arg = _FUNC_CONVERT_TO_OBJECT(arg)
             values[i].v_handle = arg.handle
-            type_codes[i] = TypeCode.OBJECT_HANDLE
+            type_codes[i] = ArgTypeCode.OBJECT_HANDLE
             temp_args.append(arg)
         elif isinstance(arg, _CLASS_MODULE):
             values[i].v_handle = arg.handle
-            type_codes[i] = TypeCode.MODULE_HANDLE
+            type_codes[i] = ArgTypeCode.MODULE_HANDLE
         elif isinstance(arg, PackedFuncBase):
             values[i].v_handle = arg.handle
-            type_codes[i] = TypeCode.PACKED_FUNC_HANDLE
+            type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE
         elif isinstance(arg, ctypes.c_void_p):
             values[i].v_handle = arg
-            type_codes[i] = TypeCode.HANDLE
+            type_codes[i] = ArgTypeCode.HANDLE
         elif isinstance(arg, ObjectRValueRef):
             values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p)
-            type_codes[i] = TypeCode.OBJECT_RVALUE_REF_ARG
+            type_codes[i] = ArgTypeCode.OBJECT_RVALUE_REF_ARG
         elif callable(arg):
             arg = convert_to_tvm_func(arg)
             values[i].v_handle = arg.handle
-            type_codes[i] = TypeCode.PACKED_FUNC_HANDLE
+            type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE
             temp_args.append(arg)
         else:
             raise TypeError("Don't know how to handle type %s" % type(arg))
@@ -240,7 +240,7 @@ def __init_handle_by_constructor__(fconstructor, args):
         raise get_last_ffi_error()
     _ = temp_args
     _ = args
-    assert ret_tcode.value == TypeCode.OBJECT_HANDLE
+    assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE
     handle = ret_val.v_handle
     return handle
 
@@ -275,15 +275,15 @@ def _get_global_func(name, allow_missing=False):
 
 # setup return handle for function type
 _object.__init_by_constructor__ = __init_handle_by_constructor__
-RETURN_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _handle_return_func
-RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
-RETURN_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
-C_TO_PY_ARG_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func(
-    _handle_return_func, TypeCode.PACKED_FUNC_HANDLE)
-C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
-    _return_module, TypeCode.MODULE_HANDLE)
-C_TO_PY_ARG_SWITCH[TypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False)
-C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
+RETURN_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _handle_return_func
+RETURN_SWITCH[ArgTypeCode.MODULE_HANDLE] = _return_module
+RETURN_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
+C_TO_PY_ARG_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func(
+    _handle_return_func, ArgTypeCode.PACKED_FUNC_HANDLE)
+C_TO_PY_ARG_SWITCH[ArgTypeCode.MODULE_HANDLE] = _wrap_arg_func(
+    _return_module, ArgTypeCode.MODULE_HANDLE)
+C_TO_PY_ARG_SWITCH[ArgTypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False)
+C_TO_PY_ARG_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
 
 _CLASS_MODULE = None
 _CLASS_PACKED_FUNC = None
index 20be30a..d4e7b36 100644 (file)
@@ -19,7 +19,7 @@
 import ctypes
 import struct
 from ..base import py_str, check_call, _LIB
-from ..runtime_ctypes import TVMByteArray, TypeCode, TVMContext
+from ..runtime_ctypes import TVMByteArray, ArgTypeCode, TVMContext
 
 class TVMValue(ctypes.Union):
     """TVMValue in C API"""
@@ -86,21 +86,21 @@ def _ctx_to_int64(ctx):
 
 
 RETURN_SWITCH = {
-    TypeCode.INT: lambda x: x.v_int64,
-    TypeCode.FLOAT: lambda x: x.v_float64,
-    TypeCode.HANDLE: _return_handle,
-    TypeCode.NULL: lambda x: None,
-    TypeCode.STR: lambda x: py_str(x.v_str),
-    TypeCode.BYTES: _return_bytes,
-    TypeCode.TVM_CONTEXT: _return_context
+    ArgTypeCode.INT: lambda x: x.v_int64,
+    ArgTypeCode.FLOAT: lambda x: x.v_float64,
+    ArgTypeCode.HANDLE: _return_handle,
+    ArgTypeCode.NULL: lambda x: None,
+    ArgTypeCode.STR: lambda x: py_str(x.v_str),
+    ArgTypeCode.BYTES: _return_bytes,
+    ArgTypeCode.TVM_CONTEXT: _return_context
 }
 
 C_TO_PY_ARG_SWITCH = {
-    TypeCode.INT: lambda x: x.v_int64,
-    TypeCode.FLOAT: lambda x: x.v_float64,
-    TypeCode.HANDLE: _return_handle,
-    TypeCode.NULL: lambda x: None,
-    TypeCode.STR: lambda x: py_str(x.v_str),
-    TypeCode.BYTES: _return_bytes,
-    TypeCode.TVM_CONTEXT: _return_context
+    ArgTypeCode.INT: lambda x: x.v_int64,
+    ArgTypeCode.FLOAT: lambda x: x.v_float64,
+    ArgTypeCode.HANDLE: _return_handle,
+    ArgTypeCode.NULL: lambda x: None,
+    ArgTypeCode.STR: lambda x: py_str(x.v_str),
+    ArgTypeCode.BYTES: _return_bytes,
+    ArgTypeCode.TVM_CONTEXT: _return_context
 }
index 0da66ac..8c9e413 100644 (file)
@@ -22,7 +22,7 @@ from cpython cimport pycapsule
 from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t
 import ctypes
 
-cdef enum TVMTypeCode:
+cdef enum TVMArgTypeCode:
     kInt = 0
     kUInt = 1
     kFloat = 2
index e4b8b18..0942ccb 100644 (file)
@@ -122,7 +122,7 @@ def register_extension(cls, fcreate=None):
 
        @tvm.register_extension
        class MyTensor(object):
-           _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
+           _tvm_tcode = tvm.ArgTypeCode.ARRAY_HANDLE
 
            def __init__(self):
                self.handle = _LIB.NewDLTensor()
@@ -132,8 +132,8 @@ def register_extension(cls, fcreate=None):
                return self.handle.value
     """
     assert hasattr(cls, "_tvm_tcode")
-    if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
-        raise ValueError("Cannot register create when extension tcode is same as buildin")
+    if fcreate:
+        raise ValueError("Extension with fcreate is no longer supported")
     _reg_extension(cls, fcreate)
     return cls
 
index db89854..2e498e3 100644 (file)
@@ -23,7 +23,7 @@ from .base import _LIB, check_call
 
 tvm_shape_index_t = ctypes.c_int64
 
-class TypeCode(object):
+class ArgTypeCode(object):
     """Type code used in API calls"""
     INT = 0
     UINT = 1
@@ -42,23 +42,30 @@ class TypeCode(object):
     OBJECT_RVALUE_REF_ARG = 14
     EXT_BEGIN = 15
 
-
 class TVMByteArray(ctypes.Structure):
     """Temp data structure for byte array."""
     _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
                 ("size", ctypes.c_size_t)]
 
 
+class DataTypeCode(object):
+    """DataType code in DLTensor."""
+    INT = 0
+    UINT = 1
+    FLOAT = 2
+    HANDLE = 3
+
+
 class DataType(ctypes.Structure):
     """TVM datatype structure"""
     _fields_ = [("type_code", ctypes.c_uint8),
                 ("bits", ctypes.c_uint8),
                 ("lanes", ctypes.c_uint16)]
     CODE2STR = {
-        0 : 'int',
-        1 : 'uint',
-        2 : 'float',
-        4 : 'handle'
+        DataTypeCode.INT : 'int',
+        DataTypeCode.UINT : 'uint',
+        DataTypeCode.FLOAT : 'float',
+        DataTypeCode.HANDLE : 'handle'
     }
     def __init__(self, type_str):
         super(DataType, self).__init__()
@@ -67,7 +74,7 @@ class DataType(ctypes.Structure):
 
         if type_str == "bool":
             self.bits = 1
-            self.type_code = 1
+            self.type_code = DataTypeCode.UINT
             self.lanes = 1
             return
 
@@ -77,16 +84,16 @@ class DataType(ctypes.Structure):
         bits = 32
 
         if head.startswith("int"):
-            self.type_code = 0
+            self.type_code = DataTypeCode.INT
             head = head[3:]
         elif head.startswith("uint"):
-            self.type_code = 1
+            self.type_code = DataTypeCode.UINT
             head = head[4:]
         elif head.startswith("float"):
-            self.type_code = 2
+            self.type_code = DataTypeCode.FLOAT
             head = head[5:]
         elif head.startswith("handle"):
-            self.type_code = 4
+            self.type_code = DataTypeCode.HANDLE
             bits = 64
             head = ""
         elif head.startswith("custom"):
index 59574e6..21c06c5 100644 (file)
@@ -20,7 +20,7 @@
 from .packed_func import PackedFunc
 from .object import Object
 from .object_generic import ObjectGeneric, ObjectTypes
-from .ndarray import NDArray, DataType, TypeCode, TVMContext
+from .ndarray import NDArray, DataType, DataTypeCode, TVMContext
 from .module import Module
 
 # function exposures
index 6629cc6..060673d 100644 (file)
@@ -22,7 +22,7 @@ import tvm._ffi
 
 from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE
 from tvm._ffi.runtime_ctypes import DataType, TVMContext, TVMArray, TVMArrayHandle
-from tvm._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
+from tvm._ffi.runtime_ctypes import DataTypeCode, tvm_shape_index_t
 
 try:
     # pylint: disable=wrong-import-position
index 4cbece3..aca5e5a 100644 (file)
@@ -29,7 +29,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
 """
 import tvm._ffi
 
-from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode, const
+from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
 from tvm.ir import PrimExpr
 import tvm.ir._ffi_api
 from . import generic as _generic
@@ -47,13 +47,13 @@ def _dtype_is_int(value):
     if isinstance(value, int):
         return True
     return (isinstance(value, ExprOp) and
-            DataType(value.dtype).type_code == TypeCode.INT)
+            DataType(value.dtype).type_code == DataTypeCode.INT)
 
 def _dtype_is_float(value):
     if isinstance(value, float):
         return True
     return (isinstance(value, ExprOp) and
-            DataType(value.dtype).type_code == TypeCode.FLOAT)
+            DataType(value.dtype).type_code == DataTypeCode.FLOAT)
 
 class ExprOp(object):
     """Operator overloading for Expr like expressions."""
index f3bac39..65434b9 100644 (file)
@@ -94,52 +94,52 @@ macro_rules! TVMPODValue {
                         DLDataTypeCode_kDLInt => Int($value.v_int64),
                         DLDataTypeCode_kDLUInt => UInt($value.v_int64),
                         DLDataTypeCode_kDLFloat => Float($value.v_float64),
-                        TVMTypeCode_kTVMNullptr => Null,
-                        TVMTypeCode_kTVMDataType => DataType($value.v_type),
-                        TVMTypeCode_kTVMContext => Context($value.v_ctx),
-                        TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
-                        TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
-                        TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
-                        TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
-                        TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
-                        TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMNullptr => Null,
+                        TVMArgTypeCode_kTVMDataType => DataType($value.v_type),
+                        TVMArgTypeCode_kTVMContext => Context($value.v_ctx),
+                        TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
+                        TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
+                        TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
                         $( $tvm_type => { $from_tvm_type } ),+
                         _ => unimplemented!("{}", type_code),
                     }
                 }
             }
 
-            pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) {
+            pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) {
                 use $name::*;
                 match self {
                     Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
                     UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
                     Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
-                    Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr),
-                    DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType),
-                    Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
+                    Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr),
+                    DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType),
+                    Context(val) => (TVMValue { v_ctx: val.clone() }, TVMArgTypeCode_kTVMContext),
                     String(val) => {
                         (
                             TVMValue { v_handle: val.as_ptr() as *mut c_void },
-                            TVMTypeCode_kTVMStr,
+                            TVMArgTypeCode_kTVMStr,
                         )
                     }
-                    Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle),
+                    Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle),
                     ArrayHandle(val) => {
                         (
                             TVMValue { v_handle: *val as *const _ as *mut c_void },
-                            TVMTypeCode_kTVMNDArrayHandle,
+                            TVMArgTypeCode_kTVMNDArrayHandle,
                         )
                     },
-                    ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle),
+                    ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle),
                     ModuleHandle(val) =>
-                        (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle),
+                        (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle),
                     FuncHandle(val) => (
                         TVMValue { v_handle: *val },
-                        TVMTypeCode_kTVMPackedFuncHandle
+                        TVMArgTypeCode_kTVMPackedFuncHandle
                     ),
                     NDArrayHandle(val) =>
-                        (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
+                        (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle),
                     $( $self_type($val) => { $from_self_type } ),+
                 }
             }
@@ -155,14 +155,14 @@ TVMPODValue! {
         Str(&'a CStr),
     },
     match value {
-        TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
-        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
+        TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
+        TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
     },
     match &self {
         Bytes(val) => {
-            (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes)
+            (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes)
         }
-        Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) }
+        Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) }
     }
 }
 
@@ -188,14 +188,14 @@ TVMPODValue! {
         Str(&'static CStr),
     },
     match value {
-        TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
-        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
+        TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
+        TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
     },
     match &self {
         Bytes(val) =>
-            { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) }
+            { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) }
         Str(val) =>
-            { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) }
+            { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) }
     }
 }
 
index 8411b03..88d6cc8 100644 (file)
@@ -204,7 +204,7 @@ impl<'a, 'm> Builder<'a, 'm> {
         ensure!(self.func.is_some(), errors::FunctionNotFoundError);
 
         let num_args = self.arg_buf.len();
-        let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
+        let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMArgTypeCode>) =
             self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
 
         let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() };
@@ -257,9 +257,9 @@ unsafe extern "C" fn tvm_callback(
     for i in 0..len {
         value = args_list[i];
         tcode = type_codes_list[i];
-        if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int
-            || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
-            || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
+        if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int
+            || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int
+            || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int
         {
             check_call!(ffi::TVMCbArgToReturn(
                 &mut value as *mut _,
index e4b2739..a326aa1 100644 (file)
@@ -95,52 +95,52 @@ macro_rules! TVMPODValue {
                         DLDataTypeCode_kDLInt => Int($value.v_int64),
                         DLDataTypeCode_kDLUInt => UInt($value.v_int64),
                         DLDataTypeCode_kDLFloat => Float($value.v_float64),
-                        TVMTypeCode_kTVMNullptr => Null,
-                        TVMTypeCode_kTVMDataType => DataType($value.v_type),
-                        TVMTypeCode_kTVMContext => Context($value.v_ctx),
-                        TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
-                        TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
-                        TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
-                        TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
-                        TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
-                        TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMNullptr => Null,
+                        TVMArgTypeCode_kTVMDataType => DataType($value.v_type),
+                        TVMArgTypeCode_kTVMContext => Context($value.v_ctx),
+                        TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
+                        TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
+                        TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
                         $( $tvm_type => { $from_tvm_type } ),+
                         _ => unimplemented!("{}", type_code),
                     }
                 }
             }
 
-            pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) {
+            pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) {
                 use $name::*;
                 match self {
                     Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
                     UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
                     Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
-                    Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr),
-                    DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType),
-                    Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
+                    Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr),
+                    DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType),
+                    Context(val) => (TVMValue { v_ctx: val.clone() }, TVMArgTypeCode_kTVMContext),
                     String(val) => {
                         (
                             TVMValue { v_handle: val.as_ptr() as *mut c_void },
-                            TVMTypeCode_kTVMStr,
+                            TVMArgTypeCode_kTVMStr,
                         )
                     }
-                    Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle),
+                    Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle),
                     ArrayHandle(val) => {
                         (
                             TVMValue { v_handle: *val as *const _ as *mut c_void },
-                            TVMTypeCode_kTVMNDArrayHandle,
+                            TVMArgTypeCode_kTVMNDArrayHandle,
                         )
                     },
-                    ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle),
+                    ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle),
                     ModuleHandle(val) =>
-                        (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle),
+                        (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle),
                     FuncHandle(val) => (
                         TVMValue { v_handle: *val },
-                        TVMTypeCode_kTVMPackedFuncHandle
+                        TVMArgTypeCode_kTVMPackedFuncHandle
                     ),
                     NDArrayHandle(val) =>
-                        (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
+                        (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle),
                     $( $self_type($val) => { $from_self_type } ),+
                 }
             }
@@ -156,14 +156,14 @@ TVMPODValue! {
         Str(&'a CStr),
     },
     match value {
-        TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
-        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
+        TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
+        TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
     },
     match &self {
         Bytes(val) => {
-            (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes)
+            (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes)
         }
-        Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) }
+        Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) }
     }
 }
 
@@ -189,14 +189,14 @@ TVMPODValue! {
         Str(&'static CStr),
     },
     match value {
-        TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
-        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
+        TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
+        TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
     },
     match &self {
         Bytes(val) =>
-            { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) }
+            { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) }
         Str(val) =>
-            { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) }
+            { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) }
     }
 }
 
index db55634..e19ee34 100644 (file)
@@ -327,7 +327,7 @@ std::function<void()> CreateTVMOp(const DSOModule& module, const TVMOpParam& par
   } TVMValue;
   /*typedef*/ enum {
     kTVMDLTensorHandle = 7U,
-  } /*TVMTypeCode*/;
+  } /*TVMArgTypeCode*/;
   struct OpArgs {
     DynArray<DLTensor> args;
     DynArray<TVMValue> arg_values;
index 8c46269..b9cdc2c 100644 (file)
@@ -251,7 +251,7 @@ void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const {
         << "ValueError: Cannot pass in module into a different remote session";
     return rmod->module_handle();
   } else {
-    LOG(FATAL) << "ValueError: Cannot pass type " << runtime::TypeCode2Str(arg.type_code())
+    LOG(FATAL) << "ValueError: Cannot pass type " << runtime::ArgTypeCode2Str(arg.type_code())
                << " as an argument to the remote";
     return nullptr;
   }
index 99d6bee..5ed3ce4 100644 (file)
@@ -49,8 +49,8 @@ Registry* Registry::Global() {
 }
 
 void Registry::Register(const std::string& type_name, uint8_t type_code) {
-  CHECK(type_code >= kTVMCustomBegin)
-      << "Please choose a type code >= kTVMCustomBegin for custom types";
+  CHECK(type_code >= DataType::kCustomBegin)
+      << "Please choose a type code >= DataType::kCustomBegin for custom types";
   code_to_name_[type_code] = type_name;
   name_to_code_[type_name] = type_code;
 }
@@ -78,7 +78,7 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t
   if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
     ss << datatype::Registry::Global()->GetTypeName(type_code);
   } else {
-    ss << runtime::TypeCode2Str(type_code);
+    ss << runtime::DLDataTypeCode2Str(static_cast<DLDataTypeCode>(type_code));
   }
 
   ss << ".";
@@ -86,7 +86,7 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t
   if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) {
     ss << datatype::Registry::Global()->GetTypeName(src_type_code);
   } else {
-    ss << runtime::TypeCode2Str(src_type_code);
+    ss << runtime::DLDataTypeCode2Str(static_cast<DLDataTypeCode>(src_type_code));
   }
   return runtime::Registry::Get(ss.str());
 }
index c043592..5df8ef8 100644 (file)
@@ -61,7 +61,7 @@ class Registry {
    * same code. Generally, this should be straightforward, as the user will be manually registering
    * all of their custom types.
    * \param type_name The name of the type, e.g. "bfloat"
-   * \param type_code The type code, which should be greater than TVMTypeCode::kTVMExtEnd
+   * \param type_code The type code, which should be greater than TVMArgTypeCode::kTVMExtEnd
    */
   void Register(const std::string& type_name, uint8_t type_code);
 
index 48eaf7d..2207eb3 100644 (file)
@@ -18,9 +18,10 @@ import tvm
 from tvm import te
 import numpy as np
 
+
 @tvm.register_extension
 class MyTensorView(object):
-    _tvm_tcode = tvm.TypeCode.DLTENSOR_HANDLE
+    _tvm_tcode = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE
     def __init__(self, arr):
         self.arr = arr
 
index e314379..3631295 100644 (file)
@@ -72,6 +72,13 @@ def test_fp16_conversion():
 
         tvm.testing.assert_allclose(expected, real)
 
+
+def test_dtype():
+    dtype = tvm.DataType("handle")
+    assert dtype.type_code == tvm.DataTypeCode.HANDLE
+
+
 if __name__ == "__main__":
     test_nd_create()
     test_fp16_conversion()
+    test_dtype()
index f533b4e..66c46fe 100644 (file)
@@ -208,9 +208,9 @@ export const enum SizeOf {
 }
 
 /**
- * Type code in TVM FFI.
+ * Argument Type code in TVM FFI.
  */
-export const enum TypeCode {
+export const enum ArgTypeCode {
   Int = 0,
   UInt = 1,
   Float = 2,
@@ -226,4 +226,4 @@ export const enum TypeCode {
   TVMBytes = 12,
   TVMNDArrayHandle = 13,
   TVMObjectRValueRefArg = 14
-}
\ No newline at end of file
+}
index 50227dc..542558a 100644 (file)
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-import { SizeOf, TypeCode } from "./ctypes";
+import { SizeOf, ArgTypeCode } from "./ctypes";
 import { assert, StringToUint8Array, Uint8ArrayToString } from "./support";
 import { detectGPUDevice } from "./webgpu";
 import * as compact from "./compact";
@@ -216,10 +216,10 @@ export class RPCServer {
 
       for (let i = 0; i < nargs; ++i) {
         const tcode = tcodes[i];
-        if (tcode == TypeCode.TVMStr) {
+        if (tcode == ArgTypeCode.TVMStr) {
           const str = Uint8ArrayToString(reader.readByteArray());
           args.push(str);
-        } else if (tcode == TypeCode.TVMBytes) {
+        } else if (tcode == ArgTypeCode.TVMBytes) {
           args.push(reader.readByteArray());
         } else {
           throw new Error("cannot support type code " + tcode);
index bcf7be7..5c9b9d8 100644 (file)
@@ -20,7 +20,7 @@
 /**
  * TVM JS Wasm Runtime library.
  */
-import { Pointer, PtrOffset, SizeOf, TypeCode } from "./ctypes";
+import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes";
 import { Disposable } from "./types";
 import { Memory, CachedCallStack } from "./memory";
 import { assert, StringToUint8Array } from "./support";
@@ -234,12 +234,21 @@ export class DLContext {
     );
   }
 }
+/**
+ * The data type code in DLDataType
+ */
+export const enum DLDataTypeCode {
+  Int = 0,
+  UInt = 1,
+  Float = 2,
+  OpaqueHandle = 3
+}
 
 const DLDataTypeCodeToStr: Record<number, string> = {
   0: "int",
   1: "uint",
   2: "float",
-  4: "handle",
+  3: "handle",
 };
 
 /**
@@ -866,16 +875,16 @@ export class Instance implements Disposable {
         lanes = 1;
       if (pattern.substring(0, 5) == "float") {
         pattern = pattern.substring(5, pattern.length);
-        code = TypeCode.Float;
+        code = DLDataTypeCode.Float;
       } else if (pattern.substring(0, 3) == "int") {
         pattern = pattern.substring(3, pattern.length);
-        code = TypeCode.Int;
+        code = DLDataTypeCode.Int;
       } else if (pattern.substring(0, 4) == "uint") {
         pattern = pattern.substring(4, pattern.length);
-        code = TypeCode.UInt;
+        code = DLDataTypeCode.UInt;
       } else if (pattern.substring(0, 6) == "handle") {
         pattern = pattern.substring(5, pattern.length);
-        code = TypeCode.TVMOpaqueHandle;
+        code = DLDataTypeCode.OpaqueHandle;
         bits = 64;
       } else {
         throw new Error("Unknown dtype " + dtype);
@@ -1140,47 +1149,47 @@ export class Instance implements Disposable {
       const codeOffset = argsCode + i * SizeOf.I32;
       if (val instanceof NDArray) {
         stack.storePtr(valueOffset, val.handle);
-        stack.storeI32(codeOffset, TypeCode.TVMNDArrayHandle);
+        stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle);
       } else if (val instanceof Scalar) {
         if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) {
           stack.storeI64(valueOffset, val.value);
-          stack.storeI32(codeOffset, TypeCode.Int);
+          stack.storeI32(codeOffset, ArgTypeCode.Int);
         } else if (val.dtype.startsWith("float")) {
           stack.storeF64(valueOffset, val.value);
-          stack.storeI32(codeOffset, TypeCode.Float);
+          stack.storeI32(codeOffset, ArgTypeCode.Float);
         } else {
           assert(val.dtype == "handle", "Expect handle");
           stack.storePtr(valueOffset, val.value);
-          stack.storeI32(codeOffset, TypeCode.TVMOpaqueHandle);
+          stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle);
         }
       } else if (val instanceof DLContext) {
         stack.storeI32(valueOffset, val.deviceType);
         stack.storeI32(valueOffset + SizeOf.I32, val.deviceType);
-        stack.storeI32(codeOffset, TypeCode.TVMContext);
+        stack.storeI32(codeOffset, ArgTypeCode.TVMContext);
       } else if (tp == "number") {
         stack.storeF64(valueOffset, val);
-        stack.storeI32(codeOffset, TypeCode.Float);
+        stack.storeI32(codeOffset, ArgTypeCode.Float);
         // eslint-disable-next-line no-prototype-builtins
       } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) {
         stack.storePtr(valueOffset, val._tvmPackedCell.handle);
-        stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle);
+        stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle);
       } else if (val === null || val == undefined) {
         stack.storePtr(valueOffset, 0);
-        stack.storeI32(codeOffset, TypeCode.Null);
+        stack.storeI32(codeOffset, ArgTypeCode.Null);
       } else if (tp == "string") {
         stack.allocThenSetArgString(valueOffset, val);
-        stack.storeI32(codeOffset, TypeCode.TVMStr);
+        stack.storeI32(codeOffset, ArgTypeCode.TVMStr);
       } else if (val instanceof Uint8Array) {
         stack.allocThenSetArgBytes(valueOffset, val);
-        stack.storeI32(codeOffset, TypeCode.TVMBytes);
+        stack.storeI32(codeOffset, ArgTypeCode.TVMBytes);
       } else if (val instanceof Function) {
         val = this.toPackedFunc(val);
         stack.tempArgs.push(val);
         stack.storePtr(valueOffset, val._tvmPackedCell.handle);
-        stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle);
+        stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle);
       } else if (val instanceof Module) {
         stack.storePtr(valueOffset, val.handle);
-        stack.storeI32(codeOffset, TypeCode.TVMModuleHandle);
+        stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle);
       } else {
         throw new Error("Unsupported argument type " + tp);
       }
@@ -1204,10 +1213,10 @@ export class Instance implements Disposable {
         let tcode = lib.memory.loadI32(codePtr);
 
         if (
-          tcode == TypeCode.TVMObjectHandle ||
-          tcode == TypeCode.TVMObjectRValueRefArg ||
-          tcode == TypeCode.TVMPackedFuncHandle ||
-          tcode == TypeCode.TVMModuleHandle
+          tcode == ArgTypeCode.TVMObjectHandle ||
+          tcode == ArgTypeCode.TVMObjectRValueRefArg ||
+          tcode == ArgTypeCode.TVMPackedFuncHandle ||
+          tcode == ArgTypeCode.TVMModuleHandle
         ) {
           lib.checkCall(
             (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)(
@@ -1290,25 +1299,25 @@ export class Instance implements Disposable {
 
   private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any {
     switch (tcode) {
-      case TypeCode.Int:
-      case TypeCode.UInt:
+      case ArgTypeCode.Int:
+      case ArgTypeCode.UInt:
         return this.memory.loadI64(rvaluePtr);
-      case TypeCode.Float:
+      case ArgTypeCode.Float:
         return this.memory.loadF64(rvaluePtr);
-        case TypeCode.TVMOpaqueHandle: {
+        case ArgTypeCode.TVMOpaqueHandle: {
           return this.memory.loadPointer(rvaluePtr);
         }
-      case TypeCode.TVMNDArrayHandle: {
+      case ArgTypeCode.TVMNDArrayHandle: {
         return new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib);
       }
-      case TypeCode.TVMDLTensorHandle: {
+      case ArgTypeCode.TVMDLTensorHandle: {
         assert(callbackArg);
         return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib);
       }
-      case TypeCode.TVMPackedFuncHandle: {
+      case ArgTypeCode.TVMPackedFuncHandle: {
         return this.makePackedFunc(this.memory.loadPointer(rvaluePtr));
       }
-      case TypeCode.TVMModuleHandle: {
+      case ArgTypeCode.TVMModuleHandle: {
         return new Module(
           this.memory.loadPointer(rvaluePtr),
           this.lib,
@@ -1317,17 +1326,17 @@ export class Instance implements Disposable {
           }
         );
       }
-      case TypeCode.Null: return undefined;
-      case TypeCode.TVMContext: {
+      case ArgTypeCode.Null: return undefined;
+      case ArgTypeCode.TVMContext: {
         const deviceType = this.memory.loadI32(rvaluePtr);
         const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32);
         return this.context(deviceType, deviceId);
       }
-      case TypeCode.TVMStr: {
+      case ArgTypeCode.TVMStr: {
         const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr));
         return ret;
       }
-      case TypeCode.TVMBytes: {
+      case ArgTypeCode.TVMBytes: {
         return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr));
       }
       default: