From: Tianqi Chen Date: Thu, 4 Jun 2020 22:04:17 +0000 (-0700) Subject: [REFACTOR] Separate ArgTypeCode from DLDataTypeCode (#5730) X-Git-Tag: upstream/0.7.0~609 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8a98782cdfe6a5dbe6e61fc19d5526c16827aba6;p=platform%2Fupstream%2Ftvm.git [REFACTOR] Separate ArgTypeCode from DLDataTypeCode (#5730) We use a single enum(TypeCode) to represent ArgTypeCode and DLDataTypeCode. However, as we start to expand more data types, it is clear that argument type code(in the FFI convention) and data type code needs to evolve separately. So that we can add first class for data types without having changing the FFI ABI. This PR makes the distinction clear and refactored the code to separate the two. - [PY] Separate ArgTypeCode from DataTypeCode - [WEB] Separate ArgTypeCode from DataTypeCode - [JAVA] Separate ArgTypeCode from DataTypeCode --- diff --git a/3rdparty/dlpack b/3rdparty/dlpack index 0acb731..3ec0443 160000 --- a/3rdparty/dlpack +++ b/3rdparty/dlpack @@ -1 +1 @@ -Subproject commit 0acb731e0e43d15deee27b66f10e4c5b4e667913 +Subproject commit 3ec04430e89a6834e5a1b99471f415fa939bf642 diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index bb38ad8..be86563 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -87,12 +87,14 @@ typedef enum { } TVMDeviceExtType; /*! - * \brief The type code in used in the TVM FFI. + * \brief The type code in used in the TVM FFI for argument passing. */ typedef enum { // The type code of other types are compatible with DLPack. // The next few fields are extension types // that is used by TVM API calls. + kTVMArgInt = kDLInt, + kTVMArgFloat = kDLFloat, kTVMOpaqueHandle = 3U, kTVMNullptr = 4U, kTVMDataType = 5U, @@ -115,9 +117,7 @@ typedef enum { // The following section of code is used for non-reserved types. kTVMExtReserveEnd = 64U, kTVMExtEnd = 128U, - // The rest of the space is used for custom, user-supplied datatypes - kTVMCustomBegin = 129U, -} TVMTypeCode; +} TVMArgTypeCode; /*! * \brief The Device information, abstract away common device types. diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index a10b83f..1d53810 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -45,7 +45,8 @@ class DataType { kInt = kDLInt, kUInt = kDLUInt, kFloat = kDLFloat, - kHandle = TVMTypeCode::kTVMOpaqueHandle, + kHandle = TVMArgTypeCode::kTVMOpaqueHandle, + kCustomBegin = 129 }; /*! \brief default constructor */ DataType() {} @@ -248,7 +249,7 @@ TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); * \param type_code The type code . * \return The name of type code. */ -inline const char* TypeCode2Str(int type_code); +inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code); /*! * \brief convert a string to TVM type. @@ -265,38 +266,16 @@ inline DLDataType String2DLDataType(std::string s); inline std::string DLDataType2String(DLDataType t); // implementation details -inline const char* TypeCode2Str(int type_code) { - switch (type_code) { +inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { + switch (static_cast(type_code)) { case kDLInt: return "int"; case kDLUInt: return "uint"; case kDLFloat: return "float"; - case kTVMStr: - return "str"; - case kTVMBytes: - return "bytes"; - case kTVMOpaqueHandle: + case DataType::kHandle: return "handle"; - case kTVMNullptr: - return "NULL"; - case kTVMDLTensorHandle: - return "ArrayHandle"; - case kTVMDataType: - return "DLDataType"; - case kTVMContext: - return "TVMContext"; - case kTVMPackedFuncHandle: - return "FunctionHandle"; - case kTVMModuleHandle: - return "ModuleHandle"; - case kTVMNDArrayHandle: - return "NDArrayContainer"; - case kTVMObjectHandle: - return "Object"; - case kTVMObjectRValueRefArg: - return "ObjectRValueRefArg"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; @@ -311,8 +290,8 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) if (DataType(t).is_void()) { return os << "void"; } - if (t.code < kTVMCustomBegin) { - os << TypeCode2Str(t.code); + if (t.code < DataType::kCustomBegin) { + os << DLDataTypeCode2Str(static_cast(t.code)); } else { os << "custom[" << GetCustomTypeName(t.code) << "]"; } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 01f8e99..e82b97a 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -327,9 +327,16 @@ class TVMArgs { inline TVMArgValue operator[](int i) const; }; +/*! + * \brief Convert argument type code to string. + * \param type_code The input type code. + * \return The corresponding string repr. + */ +inline const char* ArgTypeCode2Str(int type_code); + // macro to check type code. #define TVM_CHECK_TYPE_CODE(CODE, T) \ - CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) + CHECK_EQ(CODE, T) << " expected " << ArgTypeCode2Str(T) << " but get " << ArgTypeCode2Str(CODE) /*! * \brief Type traits for runtime type check during FFI conversion. @@ -394,7 +401,7 @@ class TVMPODValue_ { } else { if (type_code_ == kTVMNullptr) return nullptr; LOG(FATAL) << "Expect " - << "DLTensor* or NDArray but get " << TypeCode2Str(type_code_); + << "DLTensor* or NDArray but get " << ArgTypeCode2Str(type_code_); return nullptr; } } @@ -982,6 +989,44 @@ inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_( inline PackedFunc::FType PackedFunc::body() const { return body_; } // internal namespace +inline const char* ArgTypeCode2Str(int type_code) { + switch (type_code) { + case kDLInt: + return "int"; + case kDLUInt: + return "uint"; + case kDLFloat: + return "float"; + case kTVMStr: + return "str"; + case kTVMBytes: + return "bytes"; + case kTVMOpaqueHandle: + return "handle"; + case kTVMNullptr: + return "NULL"; + case kTVMDLTensorHandle: + return "ArrayHandle"; + case kTVMDataType: + return "DLDataType"; + case kTVMContext: + return "TVMContext"; + case kTVMPackedFuncHandle: + return "FunctionHandle"; + case kTVMModuleHandle: + return "ModuleHandle"; + case kTVMNDArrayHandle: + return "NDArrayContainer"; + case kTVMObjectHandle: + return "Object"; + case kTVMObjectRValueRefArg: + return "ObjectRValueRefArg"; + default: + LOG(FATAL) << "unknown type_code=" << static_cast(type_code); + return ""; + } +} + namespace detail { template diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 5884942..a4748d5 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -740,7 +740,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. // TODO(gus) when do we need to start worrying about doubles not being precise enough? - if (static_cast(t.code()) >= static_cast(kTVMCustomBegin)) { + if (static_cast(t.code()) >= static_cast(DataType::kCustomBegin)) { return FloatImm(t, static_cast(value)); } LOG(FATAL) << "cannot make const for type " << t; diff --git a/jvm/core/src/main/java/org/apache/tvm/TypeCode.java b/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java similarity index 95% rename from jvm/core/src/main/java/org/apache/tvm/TypeCode.java rename to jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java index 2d21e4a..b3b3da5 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TypeCode.java +++ b/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java @@ -18,14 +18,14 @@ package org.apache.tvm; // Type code used in API calls -public enum TypeCode { +public enum ArgTypeCode { INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5), TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9), FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13); public final int id; - private TypeCode(int id) { + private ArgTypeCode(int id) { this.id = id; } diff --git a/jvm/core/src/main/java/org/apache/tvm/Function.java b/jvm/core/src/main/java/org/apache/tvm/Function.java index a9ac707..df535a8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Function.java +++ b/jvm/core/src/main/java/org/apache/tvm/Function.java @@ -80,7 +80,7 @@ public class Function extends TVMValue { * @param isResident Whether this is a resident function in jvm */ Function(long handle, boolean isResident) { - super(TypeCode.FUNC_HANDLE); + super(ArgTypeCode.FUNC_HANDLE); this.handle = handle; this.isResident = isResident; } @@ -187,7 +187,7 @@ public class Function extends TVMValue { * @return this */ public Function pushArg(NDArrayBase arg) { - int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id; + int id = arg.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; Base._LIB.tvmFuncPushArgHandle(arg.handle, id); return this; } @@ -198,7 +198,7 @@ public class Function extends TVMValue { * @return this */ public Function pushArg(Module arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.MODULE_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.MODULE_HANDLE.id); return this; } @@ -208,7 +208,7 @@ public class Function extends TVMValue { * @return this */ public Function pushArg(Function arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.FUNC_HANDLE.id); return this; } @@ -249,12 +249,12 @@ public class Function extends TVMValue { Base._LIB.tvmFuncPushArgBytes((byte[]) arg); } else if (arg instanceof NDArrayBase) { NDArrayBase nd = (NDArrayBase) arg; - int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id; + int id = nd.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; Base._LIB.tvmFuncPushArgHandle(nd.handle, id); } else if (arg instanceof Module) { - Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id); } else if (arg instanceof Function) { - Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, TypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id); } else if (arg instanceof TVMValue) { TVMValue tvmArg = (TVMValue) arg; switch (tvmArg.typeCode) { diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 1656f8d..874daa4 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -45,7 +45,7 @@ public class Module extends TVMValue { } Module(long handle) { - super(TypeCode.MODULE_HANDLE); + super(ArgTypeCode.MODULE_HANDLE); this.handle = handle; } @@ -138,7 +138,7 @@ public class Module extends TVMValue { */ public static Module load(String path, String fmt) { TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke(); - assert ret.typeCode == TypeCode.MODULE_HANDLE; + assert ret.typeCode == ArgTypeCode.MODULE_HANDLE; return ret.asModule(); } diff --git a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java index 5ac630d..26bb735 100644 --- a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java +++ b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java @@ -27,7 +27,7 @@ public class NDArrayBase extends TVMValue { private boolean isReleased = false; NDArrayBase(long handle, boolean isView) { - super(TypeCode.ARRAY_HANDLE); + super(ArgTypeCode.ARRAY_HANDLE); this.handle = handle; this.isView = isView; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java index 92c7623..d30cfcc 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java @@ -18,9 +18,9 @@ package org.apache.tvm; public class TVMValue { - public final TypeCode typeCode; + public final ArgTypeCode typeCode; - public TVMValue(TypeCode tc) { + public TVMValue(ArgTypeCode tc) { typeCode = tc; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java index 6c7c1c8..132d88f 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java @@ -21,7 +21,7 @@ public class TVMValueBytes extends TVMValue { public final byte[] value; public TVMValueBytes(byte[] value) { - super(TypeCode.BYTES); + super(ArgTypeCode.BYTES); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java index d94b011..9db4c3b 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java @@ -21,7 +21,7 @@ public class TVMValueDouble extends TVMValue { public final double value; public TVMValueDouble(double value) { - super(TypeCode.FLOAT); + super(ArgTypeCode.FLOAT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java index 8ab7572..b91f55e 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java @@ -18,13 +18,13 @@ package org.apache.tvm; /** - * Java class related to TVM handles (TypeCode.HANDLE) + * Java class related to TVM handles (ArgTypeCode.HANDLE) */ public class TVMValueHandle extends TVMValue { public final long value; public TVMValueHandle(long value) { - super(TypeCode.HANDLE); + super(ArgTypeCode.HANDLE); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java index 5dba2fd..8a9b157 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java @@ -21,7 +21,7 @@ public class TVMValueLong extends TVMValue { public final long value; public TVMValueLong(long value) { - super(TypeCode.INT); + super(ArgTypeCode.INT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java index 03c0ea0..8c49ee5 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java @@ -19,6 +19,6 @@ package org.apache.tvm; public class TVMValueNull extends TVMValue { public TVMValueNull() { - super(TypeCode.NULL); + super(ArgTypeCode.NULL); } } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java index 260803e..46926e7 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java @@ -21,7 +21,7 @@ public class TVMValueString extends TVMValue { public final String value; public TVMValueString(String value) { - super(TypeCode.STR); + super(ArgTypeCode.STR); this.value = value; } diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 6db8655..6cbc6d2 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -23,7 +23,7 @@ import traceback # top-level alias # tvm._ffi from ._ffi.base import TVMError, __version__ -from ._ffi.runtime_ctypes import TypeCode, DataType +from ._ffi.runtime_ctypes import DataTypeCode, DataType from ._ffi import register_object, register_func, register_extension, get_global_func # top-level alias diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 3dbb607..359b018 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -18,7 +18,7 @@ """Runtime Object api""" import ctypes from ..base import _LIB, check_call -from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func +from .types import ArgTypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .ndarray import _register_ndarray, NDArrayBase @@ -60,12 +60,12 @@ def _return_object(x): obj.handle = handle return obj -RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object -C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( - _return_object, TypeCode.OBJECT_HANDLE) +RETURN_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _return_object +C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _wrap_arg_func( + _return_object, ArgTypeCode.OBJECT_HANDLE) -C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func( - _return_object, TypeCode.OBJECT_RVALUE_REF_ARG) +C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func( + _return_object, ArgTypeCode.OBJECT_RVALUE_REF_ARG) class PyNativeObject: diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index b17174a..8a2f49a 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -26,7 +26,7 @@ from ..base import c_str, string_types from ..runtime_ctypes import DataType, TVMByteArray, TVMContext, ObjectRValueRef from . import ndarray as _nd from .ndarray import NDArrayBase, _make_array -from .types import TVMValue, TypeCode +from .types import TVMValue, ArgTypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 from .object import ObjectBase, PyNativeObject, _set_class_object @@ -115,32 +115,32 @@ def _make_tvm_args(args, temp_args): for i, arg in enumerate(args): if isinstance(arg, ObjectBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE elif arg is None: values[i].v_handle = None - type_codes[i] = TypeCode.NULL + type_codes[i] = ArgTypeCode.NULL elif isinstance(arg, NDArrayBase): values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) - type_codes[i] = (TypeCode.NDARRAY_HANDLE - if not arg.is_view else TypeCode.DLTENSOR_HANDLE) + type_codes[i] = (ArgTypeCode.NDARRAY_HANDLE + if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE) elif isinstance(arg, PyNativeObject): values[i].v_handle = arg.__tvm_object__.handle - type_codes[i] = TypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode elif isinstance(arg, Integral): values[i].v_int64 = arg - type_codes[i] = TypeCode.INT + type_codes[i] = ArgTypeCode.INT elif isinstance(arg, Number): values[i].v_float64 = arg - type_codes[i] = TypeCode.FLOAT + type_codes[i] = ArgTypeCode.FLOAT elif isinstance(arg, DataType): values[i].v_str = c_str(str(arg)) - type_codes[i] = TypeCode.STR + type_codes[i] = ArgTypeCode.STR elif isinstance(arg, TVMContext): values[i].v_int64 = _ctx_to_int64(arg) - type_codes[i] = TypeCode.TVM_CONTEXT + type_codes[i] = ArgTypeCode.TVM_CONTEXT elif isinstance(arg, (bytearray, bytes)): # from_buffer only taeks in bytearray. if isinstance(arg, bytes): @@ -155,31 +155,31 @@ def _make_tvm_args(args, temp_args): arr.size = len(arg) values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr)) temp_args.append(arr) - type_codes[i] = TypeCode.BYTES + type_codes[i] = ArgTypeCode.BYTES elif isinstance(arg, string_types): values[i].v_str = c_str(arg) - type_codes[i] = TypeCode.STR + type_codes[i] = ArgTypeCode.STR elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)): arg = _FUNC_CONVERT_TO_OBJECT(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.MODULE_HANDLE + type_codes[i] = ArgTypeCode.MODULE_HANDLE elif isinstance(arg, PackedFuncBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.PACKED_FUNC_HANDLE + type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE elif isinstance(arg, ctypes.c_void_p): values[i].v_handle = arg - type_codes[i] = TypeCode.HANDLE + type_codes[i] = ArgTypeCode.HANDLE elif isinstance(arg, ObjectRValueRef): values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p) - type_codes[i] = TypeCode.OBJECT_RVALUE_REF_ARG + type_codes[i] = ArgTypeCode.OBJECT_RVALUE_REF_ARG elif callable(arg): arg = convert_to_tvm_func(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.PACKED_FUNC_HANDLE + type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE temp_args.append(arg) else: raise TypeError("Don't know how to handle type %s" % type(arg)) @@ -240,7 +240,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value == TypeCode.OBJECT_HANDLE + assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE handle = ret_val.v_handle return handle @@ -275,15 +275,15 @@ def _get_global_func(name, allow_missing=False): # setup return handle for function type _object.__init_by_constructor__ = __init_handle_by_constructor__ -RETURN_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _handle_return_func -RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module -RETURN_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) -C_TO_PY_ARG_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func( - _handle_return_func, TypeCode.PACKED_FUNC_HANDLE) -C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( - _return_module, TypeCode.MODULE_HANDLE) -C_TO_PY_ARG_SWITCH[TypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False) -C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) +RETURN_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _handle_return_func +RETURN_SWITCH[ArgTypeCode.MODULE_HANDLE] = _return_module +RETURN_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) +C_TO_PY_ARG_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func( + _handle_return_func, ArgTypeCode.PACKED_FUNC_HANDLE) +C_TO_PY_ARG_SWITCH[ArgTypeCode.MODULE_HANDLE] = _wrap_arg_func( + _return_module, ArgTypeCode.MODULE_HANDLE) +C_TO_PY_ARG_SWITCH[ArgTypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False) +C_TO_PY_ARG_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) _CLASS_MODULE = None _CLASS_PACKED_FUNC = None diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index 20be30a..d4e7b36 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -19,7 +19,7 @@ import ctypes import struct from ..base import py_str, check_call, _LIB -from ..runtime_ctypes import TVMByteArray, TypeCode, TVMContext +from ..runtime_ctypes import TVMByteArray, ArgTypeCode, TVMContext class TVMValue(ctypes.Union): """TVMValue in C API""" @@ -86,21 +86,21 @@ def _ctx_to_int64(ctx): RETURN_SWITCH = { - TypeCode.INT: lambda x: x.v_int64, - TypeCode.FLOAT: lambda x: x.v_float64, - TypeCode.HANDLE: _return_handle, - TypeCode.NULL: lambda x: None, - TypeCode.STR: lambda x: py_str(x.v_str), - TypeCode.BYTES: _return_bytes, - TypeCode.TVM_CONTEXT: _return_context + ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.FLOAT: lambda x: x.v_float64, + ArgTypeCode.HANDLE: _return_handle, + ArgTypeCode.NULL: lambda x: None, + ArgTypeCode.STR: lambda x: py_str(x.v_str), + ArgTypeCode.BYTES: _return_bytes, + ArgTypeCode.TVM_CONTEXT: _return_context } C_TO_PY_ARG_SWITCH = { - TypeCode.INT: lambda x: x.v_int64, - TypeCode.FLOAT: lambda x: x.v_float64, - TypeCode.HANDLE: _return_handle, - TypeCode.NULL: lambda x: None, - TypeCode.STR: lambda x: py_str(x.v_str), - TypeCode.BYTES: _return_bytes, - TypeCode.TVM_CONTEXT: _return_context + ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.FLOAT: lambda x: x.v_float64, + ArgTypeCode.HANDLE: _return_handle, + ArgTypeCode.NULL: lambda x: None, + ArgTypeCode.STR: lambda x: py_str(x.v_str), + ArgTypeCode.BYTES: _return_bytes, + ArgTypeCode.TVM_CONTEXT: _return_context } diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 0da66ac..8c9e413 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -22,7 +22,7 @@ from cpython cimport pycapsule from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t import ctypes -cdef enum TVMTypeCode: +cdef enum TVMArgTypeCode: kInt = 0 kUInt = 1 kFloat = 2 diff --git a/python/tvm/_ffi/registry.py b/python/tvm/_ffi/registry.py index e4b8b18..0942ccb 100644 --- a/python/tvm/_ffi/registry.py +++ b/python/tvm/_ffi/registry.py @@ -122,7 +122,7 @@ def register_extension(cls, fcreate=None): @tvm.register_extension class MyTensor(object): - _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE + _tvm_tcode = tvm.ArgTypeCode.ARRAY_HANDLE def __init__(self): self.handle = _LIB.NewDLTensor() @@ -132,8 +132,8 @@ def register_extension(cls, fcreate=None): return self.handle.value """ assert hasattr(cls, "_tvm_tcode") - if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: - raise ValueError("Cannot register create when extension tcode is same as buildin") + if fcreate: + raise ValueError("Extension with fcreate is no longer supported") _reg_extension(cls, fcreate) return cls diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index db89854..2e498e3 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -23,7 +23,7 @@ from .base import _LIB, check_call tvm_shape_index_t = ctypes.c_int64 -class TypeCode(object): +class ArgTypeCode(object): """Type code used in API calls""" INT = 0 UINT = 1 @@ -42,23 +42,30 @@ class TypeCode(object): OBJECT_RVALUE_REF_ARG = 14 EXT_BEGIN = 15 - class TVMByteArray(ctypes.Structure): """Temp data structure for byte array.""" _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), ("size", ctypes.c_size_t)] +class DataTypeCode(object): + """DataType code in DLTensor.""" + INT = 0 + UINT = 1 + FLOAT = 2 + HANDLE = 3 + + class DataType(ctypes.Structure): """TVM datatype structure""" _fields_ = [("type_code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16)] CODE2STR = { - 0 : 'int', - 1 : 'uint', - 2 : 'float', - 4 : 'handle' + DataTypeCode.INT : 'int', + DataTypeCode.UINT : 'uint', + DataTypeCode.FLOAT : 'float', + DataTypeCode.HANDLE : 'handle' } def __init__(self, type_str): super(DataType, self).__init__() @@ -67,7 +74,7 @@ class DataType(ctypes.Structure): if type_str == "bool": self.bits = 1 - self.type_code = 1 + self.type_code = DataTypeCode.UINT self.lanes = 1 return @@ -77,16 +84,16 @@ class DataType(ctypes.Structure): bits = 32 if head.startswith("int"): - self.type_code = 0 + self.type_code = DataTypeCode.INT head = head[3:] elif head.startswith("uint"): - self.type_code = 1 + self.type_code = DataTypeCode.UINT head = head[4:] elif head.startswith("float"): - self.type_code = 2 + self.type_code = DataTypeCode.FLOAT head = head[5:] elif head.startswith("handle"): - self.type_code = 4 + self.type_code = DataTypeCode.HANDLE bits = 64 head = "" elif head.startswith("custom"): diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 59574e6..21c06c5 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -20,7 +20,7 @@ from .packed_func import PackedFunc from .object import Object from .object_generic import ObjectGeneric, ObjectTypes -from .ndarray import NDArray, DataType, TypeCode, TVMContext +from .ndarray import NDArray, DataType, DataTypeCode, TVMContext from .module import Module # function exposures diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 6629cc6..060673d 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -22,7 +22,7 @@ import tvm._ffi from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE from tvm._ffi.runtime_ctypes import DataType, TVMContext, TVMArray, TVMArrayHandle -from tvm._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t +from tvm._ffi.runtime_ctypes import DataTypeCode, tvm_shape_index_t try: # pylint: disable=wrong-import-position diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 4cbece3..aca5e5a 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -29,7 +29,7 @@ For example, you can use addexp.a to get the left operand of an Add node. """ import tvm._ffi -from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode, const +from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const from tvm.ir import PrimExpr import tvm.ir._ffi_api from . import generic as _generic @@ -47,13 +47,13 @@ def _dtype_is_int(value): if isinstance(value, int): return True return (isinstance(value, ExprOp) and - DataType(value.dtype).type_code == TypeCode.INT) + DataType(value.dtype).type_code == DataTypeCode.INT) def _dtype_is_float(value): if isinstance(value, float): return True return (isinstance(value, ExprOp) and - DataType(value.dtype).type_code == TypeCode.FLOAT) + DataType(value.dtype).type_code == DataTypeCode.FLOAT) class ExprOp(object): """Operator overloading for Expr like expressions.""" diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index f3bac39..65434b9 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -94,52 +94,52 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMTypeCode_kTVMNullptr => Null, - TVMTypeCode_kTVMDataType => DataType($value.v_type), - TVMTypeCode_kTVMContext => Context($value.v_ctx), - TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), - TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), - TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), - TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), - TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), + TVMArgTypeCode_kTVMNullptr => Null, + TVMArgTypeCode_kTVMDataType => DataType($value.v_type), + TVMArgTypeCode_kTVMContext => Context($value.v_ctx), + TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), + TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), + TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), + TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), + TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), + TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), $( $tvm_type => { $from_tvm_type } ),+ _ => unimplemented!("{}", type_code), } } } - pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) { + pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) { use $name::*; match self { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr), - DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType), - Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext), + Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), + DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), + Context(val) => (TVMValue { v_ctx: val.clone() }, TVMArgTypeCode_kTVMContext), String(val) => { ( TVMValue { v_handle: val.as_ptr() as *mut c_void }, - TVMTypeCode_kTVMStr, + TVMArgTypeCode_kTVMStr, ) } - Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle), + Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle), ArrayHandle(val) => { ( TVMValue { v_handle: *val as *const _ as *mut c_void }, - TVMTypeCode_kTVMNDArrayHandle, + TVMArgTypeCode_kTVMNDArrayHandle, ) }, - ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle), + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle), ModuleHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle), FuncHandle(val) => ( TVMValue { v_handle: *val }, - TVMTypeCode_kTVMPackedFuncHandle + TVMArgTypeCode_kTVMPackedFuncHandle ), NDArrayHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle), $( $self_type($val) => { $from_self_type } ),+ } } @@ -155,14 +155,14 @@ TVMPODValue! { Str(&'a CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } }, match &self { Bytes(val) => { - (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes) + (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes) } - Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) } + Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) } } } @@ -188,14 +188,14 @@ TVMPODValue! { Str(&'static CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } }, match &self { Bytes(val) => - { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) } + { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) } Str(val) => - { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) } + { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) } } } diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 8411b03..88d6cc8 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -204,7 +204,7 @@ impl<'a, 'm> Builder<'a, 'm> { ensure!(self.func.is_some(), errors::FunctionNotFoundError); let num_args = self.arg_buf.len(); - let (mut values, mut type_codes): (Vec, Vec) = + let (mut values, mut type_codes): (Vec, Vec) = self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() }; @@ -257,9 +257,9 @@ unsafe extern "C" fn tvm_callback( for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int - || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int - || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int + if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int { check_call!(ffi::TVMCbArgToReturn( &mut value as *mut _, diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index e4b2739..a326aa1 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -95,52 +95,52 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMTypeCode_kTVMNullptr => Null, - TVMTypeCode_kTVMDataType => DataType($value.v_type), - TVMTypeCode_kTVMContext => Context($value.v_ctx), - TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), - TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), - TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), - TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), - TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), + TVMArgTypeCode_kTVMNullptr => Null, + TVMArgTypeCode_kTVMDataType => DataType($value.v_type), + TVMArgTypeCode_kTVMContext => Context($value.v_ctx), + TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), + TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), + TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), + TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), + TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), + TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), $( $tvm_type => { $from_tvm_type } ),+ _ => unimplemented!("{}", type_code), } } } - pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) { + pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) { use $name::*; match self { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr), - DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType), - Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext), + Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), + DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), + Context(val) => (TVMValue { v_ctx: val.clone() }, TVMArgTypeCode_kTVMContext), String(val) => { ( TVMValue { v_handle: val.as_ptr() as *mut c_void }, - TVMTypeCode_kTVMStr, + TVMArgTypeCode_kTVMStr, ) } - Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle), + Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle), ArrayHandle(val) => { ( TVMValue { v_handle: *val as *const _ as *mut c_void }, - TVMTypeCode_kTVMNDArrayHandle, + TVMArgTypeCode_kTVMNDArrayHandle, ) }, - ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle), + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle), ModuleHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle), FuncHandle(val) => ( TVMValue { v_handle: *val }, - TVMTypeCode_kTVMPackedFuncHandle + TVMArgTypeCode_kTVMPackedFuncHandle ), NDArrayHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle), $( $self_type($val) => { $from_self_type } ),+ } } @@ -156,14 +156,14 @@ TVMPODValue! { Str(&'a CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } }, match &self { Bytes(val) => { - (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes) + (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes) } - Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) } + Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) } } } @@ -189,14 +189,14 @@ TVMPODValue! { Str(&'static CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } }, match &self { Bytes(val) => - { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) } + { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) } Str(val) => - { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) } + { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) } } } diff --git a/src/runtime/micro/standalone/utvm_graph_runtime.cc b/src/runtime/micro/standalone/utvm_graph_runtime.cc index db55634..e19ee34 100644 --- a/src/runtime/micro/standalone/utvm_graph_runtime.cc +++ b/src/runtime/micro/standalone/utvm_graph_runtime.cc @@ -327,7 +327,7 @@ std::function CreateTVMOp(const DSOModule& module, const TVMOpParam& par } TVMValue; /*typedef*/ enum { kTVMDLTensorHandle = 7U, - } /*TVMTypeCode*/; + } /*TVMArgTypeCode*/; struct OpArgs { DynArray args; DynArray arg_values; diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 8c46269..b9cdc2c 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -251,7 +251,7 @@ void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const { << "ValueError: Cannot pass in module into a different remote session"; return rmod->module_handle(); } else { - LOG(FATAL) << "ValueError: Cannot pass type " << runtime::TypeCode2Str(arg.type_code()) + LOG(FATAL) << "ValueError: Cannot pass type " << runtime::ArgTypeCode2Str(arg.type_code()) << " as an argument to the remote"; return nullptr; } diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index 99d6bee..5ed3ce4 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -49,8 +49,8 @@ Registry* Registry::Global() { } void Registry::Register(const std::string& type_name, uint8_t type_code) { - CHECK(type_code >= kTVMCustomBegin) - << "Please choose a type code >= kTVMCustomBegin for custom types"; + CHECK(type_code >= DataType::kCustomBegin) + << "Please choose a type code >= DataType::kCustomBegin for custom types"; code_to_name_[type_code] = type_name; name_to_code_[type_name] = type_code; } @@ -78,7 +78,7 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { ss << datatype::Registry::Global()->GetTypeName(type_code); } else { - ss << runtime::TypeCode2Str(type_code); + ss << runtime::DLDataTypeCode2Str(static_cast(type_code)); } ss << "."; @@ -86,7 +86,7 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) { ss << datatype::Registry::Global()->GetTypeName(src_type_code); } else { - ss << runtime::TypeCode2Str(src_type_code); + ss << runtime::DLDataTypeCode2Str(static_cast(src_type_code)); } return runtime::Registry::Get(ss.str()); } diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h index c043592..5df8ef8 100644 --- a/src/target/datatype/registry.h +++ b/src/target/datatype/registry.h @@ -61,7 +61,7 @@ class Registry { * same code. Generally, this should be straightforward, as the user will be manually registering * all of their custom types. * \param type_name The name of the type, e.g. "bfloat" - * \param type_code The type code, which should be greater than TVMTypeCode::kTVMExtEnd + * \param type_code The type code, which should be greater than TVMArgTypeCode::kTVMExtEnd */ void Register(const std::string& type_name, uint8_t type_code); diff --git a/tests/python/unittest/test_runtime_extension.py b/tests/python/unittest/test_runtime_extension.py index 48eaf7d..2207eb3 100644 --- a/tests/python/unittest/test_runtime_extension.py +++ b/tests/python/unittest/test_runtime_extension.py @@ -18,9 +18,10 @@ import tvm from tvm import te import numpy as np + @tvm.register_extension class MyTensorView(object): - _tvm_tcode = tvm.TypeCode.DLTENSOR_HANDLE + _tvm_tcode = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE def __init__(self, arr): self.arr = arr diff --git a/tests/python/unittest/test_runtime_ndarray.py b/tests/python/unittest/test_runtime_ndarray.py index e314379..3631295 100644 --- a/tests/python/unittest/test_runtime_ndarray.py +++ b/tests/python/unittest/test_runtime_ndarray.py @@ -72,6 +72,13 @@ def test_fp16_conversion(): tvm.testing.assert_allclose(expected, real) + +def test_dtype(): + dtype = tvm.DataType("handle") + assert dtype.type_code == tvm.DataTypeCode.HANDLE + + if __name__ == "__main__": test_nd_create() test_fp16_conversion() + test_dtype() diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index f533b4e..66c46fe 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -208,9 +208,9 @@ export const enum SizeOf { } /** - * Type code in TVM FFI. + * Argument Type code in TVM FFI. */ -export const enum TypeCode { +export const enum ArgTypeCode { Int = 0, UInt = 1, Float = 2, @@ -226,4 +226,4 @@ export const enum TypeCode { TVMBytes = 12, TVMNDArrayHandle = 13, TVMObjectRValueRefArg = 14 -} \ No newline at end of file +} diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts index 50227dc..542558a 100644 --- a/web/src/rpc_server.ts +++ b/web/src/rpc_server.ts @@ -17,7 +17,7 @@ * under the License. */ -import { SizeOf, TypeCode } from "./ctypes"; +import { SizeOf, ArgTypeCode } from "./ctypes"; import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; import { detectGPUDevice } from "./webgpu"; import * as compact from "./compact"; @@ -216,10 +216,10 @@ export class RPCServer { for (let i = 0; i < nargs; ++i) { const tcode = tcodes[i]; - if (tcode == TypeCode.TVMStr) { + if (tcode == ArgTypeCode.TVMStr) { const str = Uint8ArrayToString(reader.readByteArray()); args.push(str); - } else if (tcode == TypeCode.TVMBytes) { + } else if (tcode == ArgTypeCode.TVMBytes) { args.push(reader.readByteArray()); } else { throw new Error("cannot support type code " + tcode); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index bcf7be7..5c9b9d8 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -20,7 +20,7 @@ /** * TVM JS Wasm Runtime library. */ -import { Pointer, PtrOffset, SizeOf, TypeCode } from "./ctypes"; +import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes"; import { Disposable } from "./types"; import { Memory, CachedCallStack } from "./memory"; import { assert, StringToUint8Array } from "./support"; @@ -234,12 +234,21 @@ export class DLContext { ); } } +/** + * The data type code in DLDataType + */ +export const enum DLDataTypeCode { + Int = 0, + UInt = 1, + Float = 2, + OpaqueHandle = 3 +} const DLDataTypeCodeToStr: Record = { 0: "int", 1: "uint", 2: "float", - 4: "handle", + 3: "handle", }; /** @@ -866,16 +875,16 @@ export class Instance implements Disposable { lanes = 1; if (pattern.substring(0, 5) == "float") { pattern = pattern.substring(5, pattern.length); - code = TypeCode.Float; + code = DLDataTypeCode.Float; } else if (pattern.substring(0, 3) == "int") { pattern = pattern.substring(3, pattern.length); - code = TypeCode.Int; + code = DLDataTypeCode.Int; } else if (pattern.substring(0, 4) == "uint") { pattern = pattern.substring(4, pattern.length); - code = TypeCode.UInt; + code = DLDataTypeCode.UInt; } else if (pattern.substring(0, 6) == "handle") { pattern = pattern.substring(5, pattern.length); - code = TypeCode.TVMOpaqueHandle; + code = DLDataTypeCode.OpaqueHandle; bits = 64; } else { throw new Error("Unknown dtype " + dtype); @@ -1140,47 +1149,47 @@ export class Instance implements Disposable { const codeOffset = argsCode + i * SizeOf.I32; if (val instanceof NDArray) { stack.storePtr(valueOffset, val.handle); - stack.storeI32(codeOffset, TypeCode.TVMNDArrayHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); } else if (val instanceof Scalar) { if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { stack.storeI64(valueOffset, val.value); - stack.storeI32(codeOffset, TypeCode.Int); + stack.storeI32(codeOffset, ArgTypeCode.Int); } else if (val.dtype.startsWith("float")) { stack.storeF64(valueOffset, val.value); - stack.storeI32(codeOffset, TypeCode.Float); + stack.storeI32(codeOffset, ArgTypeCode.Float); } else { assert(val.dtype == "handle", "Expect handle"); stack.storePtr(valueOffset, val.value); - stack.storeI32(codeOffset, TypeCode.TVMOpaqueHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle); } } else if (val instanceof DLContext) { stack.storeI32(valueOffset, val.deviceType); stack.storeI32(valueOffset + SizeOf.I32, val.deviceType); - stack.storeI32(codeOffset, TypeCode.TVMContext); + stack.storeI32(codeOffset, ArgTypeCode.TVMContext); } else if (tp == "number") { stack.storeF64(valueOffset, val); - stack.storeI32(codeOffset, TypeCode.Float); + stack.storeI32(codeOffset, ArgTypeCode.Float); // eslint-disable-next-line no-prototype-builtins } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) { stack.storePtr(valueOffset, val._tvmPackedCell.handle); - stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); } else if (val === null || val == undefined) { stack.storePtr(valueOffset, 0); - stack.storeI32(codeOffset, TypeCode.Null); + stack.storeI32(codeOffset, ArgTypeCode.Null); } else if (tp == "string") { stack.allocThenSetArgString(valueOffset, val); - stack.storeI32(codeOffset, TypeCode.TVMStr); + stack.storeI32(codeOffset, ArgTypeCode.TVMStr); } else if (val instanceof Uint8Array) { stack.allocThenSetArgBytes(valueOffset, val); - stack.storeI32(codeOffset, TypeCode.TVMBytes); + stack.storeI32(codeOffset, ArgTypeCode.TVMBytes); } else if (val instanceof Function) { val = this.toPackedFunc(val); stack.tempArgs.push(val); stack.storePtr(valueOffset, val._tvmPackedCell.handle); - stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); } else if (val instanceof Module) { stack.storePtr(valueOffset, val.handle); - stack.storeI32(codeOffset, TypeCode.TVMModuleHandle); + stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle); } else { throw new Error("Unsupported argument type " + tp); } @@ -1204,10 +1213,10 @@ export class Instance implements Disposable { let tcode = lib.memory.loadI32(codePtr); if ( - tcode == TypeCode.TVMObjectHandle || - tcode == TypeCode.TVMObjectRValueRefArg || - tcode == TypeCode.TVMPackedFuncHandle || - tcode == TypeCode.TVMModuleHandle + tcode == ArgTypeCode.TVMObjectHandle || + tcode == ArgTypeCode.TVMObjectRValueRefArg || + tcode == ArgTypeCode.TVMPackedFuncHandle || + tcode == ArgTypeCode.TVMModuleHandle ) { lib.checkCall( (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)( @@ -1290,25 +1299,25 @@ export class Instance implements Disposable { private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any { switch (tcode) { - case TypeCode.Int: - case TypeCode.UInt: + case ArgTypeCode.Int: + case ArgTypeCode.UInt: return this.memory.loadI64(rvaluePtr); - case TypeCode.Float: + case ArgTypeCode.Float: return this.memory.loadF64(rvaluePtr); - case TypeCode.TVMOpaqueHandle: { + case ArgTypeCode.TVMOpaqueHandle: { return this.memory.loadPointer(rvaluePtr); } - case TypeCode.TVMNDArrayHandle: { + case ArgTypeCode.TVMNDArrayHandle: { return new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib); } - case TypeCode.TVMDLTensorHandle: { + case ArgTypeCode.TVMDLTensorHandle: { assert(callbackArg); return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib); } - case TypeCode.TVMPackedFuncHandle: { + case ArgTypeCode.TVMPackedFuncHandle: { return this.makePackedFunc(this.memory.loadPointer(rvaluePtr)); } - case TypeCode.TVMModuleHandle: { + case ArgTypeCode.TVMModuleHandle: { return new Module( this.memory.loadPointer(rvaluePtr), this.lib, @@ -1317,17 +1326,17 @@ export class Instance implements Disposable { } ); } - case TypeCode.Null: return undefined; - case TypeCode.TVMContext: { + case ArgTypeCode.Null: return undefined; + case ArgTypeCode.TVMContext: { const deviceType = this.memory.loadI32(rvaluePtr); const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32); return this.context(deviceType, deviceId); } - case TypeCode.TVMStr: { + case ArgTypeCode.TVMStr: { const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); return ret; } - case TypeCode.TVMBytes: { + case ArgTypeCode.TVMBytes: { return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); } default: