From: Tianqi Chen Date: Wed, 16 Oct 2019 22:24:23 +0000 (-0700) Subject: [RUNTIME] Refactor object python FFI to new protocol. (#4128) X-Git-Tag: upstream/0.7.0~1775 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=02c1e11716a5afdfc0159a1b21e45a64f7875473;p=platform%2Fupstream%2Ftvm.git [RUNTIME] Refactor object python FFI to new protocol. (#4128) * [RUNTIME] Refactor object python FFI to new protocol. This is a pre-req to bring the Node system under object protocol. Most of the code reflects the current code in the Node system. - Use new instead of init so subclass can define their own constructors - Allow register via name, besides type idnex - Introduce necessary runtime C API functions - Refactored Tensor and Datatype to directly use constructor. * address review comments --- diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 54e6f98..b058fd6 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -104,7 +104,7 @@ typedef enum { kStr = 11U, kBytes = 12U, kNDArrayContainer = 13U, - kObjectCell = 14U, + kObjectHandle = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. @@ -549,13 +549,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type, TVMStreamHandle dst); /*! - * \brief Get the tag from an object. + * \brief Get the type_index from an object. * * \param obj The object handle. - * \param tag The tag of object. + * \param out_tindex the output type index. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag); +TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); + +/*! + * \brief Convert type key to type index. + * \param type_key The key of the type. + * \param out_tindex the corresponding type index. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + +/*! + * \brief Free the object. + * + * \param obj The object handle. + * \note Internally we decrease the reference counter of the object. + * The object will be freed when every reference to the object are removed. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMObjectFree(TVMObjectHandle obj); #ifdef __cplusplus } // TVM_EXTERN_C diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 7b0653a..0693b1f 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -253,6 +253,7 @@ class Object { template friend class ObjectPtr; friend class TVMRetValue; + friend class TVMObjectCAPI; }; /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 5b71bbc..2bfa332 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -491,7 +491,7 @@ class TVMPODValue_ { } operator ObjectRef() const { if (type_code_ == kNull) return ObjectRef(ObjectPtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kObjectCell); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); return ObjectRef(ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { @@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ { } TVMRetValue& operator=(ObjectRef other) { this->Clear(); - type_code_ = kObjectCell; + type_code_ = kObjectHandle; // move the handle out value_.v_handle = other.data_.data_; other.data_.data_ = nullptr; @@ -862,7 +862,7 @@ class TVMRetValue : public TVMPODValue_ { kNodeHandle, *other.template ptr >()); break; } - case kObjectCell: { + case kObjectHandle: { *this = other.operator ObjectRef(); break; } @@ -913,7 +913,7 @@ class TVMRetValue : public TVMPODValue_ { static_cast(value_.v_handle)->DecRef(); break; } - case kObjectCell: { + case kObjectHandle: { static_cast(value_.v_handle)->DecRef(); break; } @@ -946,7 +946,7 @@ inline const char* TypeCode2Str(int type_code) { case kFuncHandle: return "FunctionHandle"; case kModuleHandle: return "ModuleHandle"; case kNDArrayContainer: return "NDArrayContainer"; - case kObjectCell: return "ObjectCell"; + case kObjectHandle: return "ObjectCell"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; } @@ -1164,7 +1164,7 @@ class TVMArgsSetter { } void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) values_[i].v_handle = value.data_.data_; - type_codes_[i] = kObjectCell; + type_codes_[i] = kObjectHandle; } void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) if (value.type_code() == kStr) { diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 895c72d..22fb6c3 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -33,6 +33,7 @@ from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 from .node import NodeBase +from . import object as _object from . import node as _node FunctionHandle = ctypes.c_void_p @@ -165,7 +166,7 @@ def _make_tvm_args(args, temp_args): temp_args.append(arg) elif isinstance(arg, _CLASS_OBJECT): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_CELL + type_codes[i] = TypeCode.OBJECT_HANDLE else: raise TypeError("Don't know how to handle type %s" % type(arg)) return values, type_codes, num_args @@ -225,7 +226,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value == TypeCode.NODE_HANDLE + assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE) handle = ret_val.v_handle return handle @@ -247,6 +248,7 @@ def _handle_return_func(x): # setup return handle for function type _node.__init_by_constructor__ = __init_handle_by_constructor__ +_object.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True) diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py new file mode 100644 index 0000000..5ddceb1 --- /dev/null +++ b/python/tvm/_ffi/_ctypes/object.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Runtime Object api""" +from __future__ import absolute_import + +import ctypes +from ..base import _LIB, check_call +from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func + + +ObjectHandle = ctypes.c_void_p +__init_by_constructor__ = None + +"""Maps object type to its constructor""" +OBJECT_TYPE = {} + +def _register_object(index, cls): + """register object class""" + OBJECT_TYPE[index] = cls + + +def _return_object(x): + handle = x.v_handle + if not isinstance(handle, ObjectHandle): + handle = ObjectHandle(handle) + tindex = ctypes.c_uint() + check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) + cls = OBJECT_TYPE.get(tindex.value, ObjectBase) + # Avoid calling __init__ of cls, instead directly call __new__ + # This allows child class to implement their own __init__ + obj = cls.__new__(cls) + obj.handle = handle + return obj + +RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object +C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( + _return_object, TypeCode.OBJECT_HANDLE) + + +class ObjectBase(object): + """Base object for all object types""" + __slots__ = ["handle"] + + def __del__(self): + if _LIB is not None: + check_call(_LIB.TVMObjectFree(self.handle)) + + def __init_handle_by_constructor__(self, fconstructor, *args): + """Initialize the handle by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return handle is directly set into the Node object + instead of creating a new Node. + """ + # assign handle first to avoid error raising + self.handle = None + handle = __init_by_constructor__(fconstructor, args) + if not isinstance(handle, ObjectHandle): + handle = ObjectHandle(handle) + self.handle = handle diff --git a/python/tvm/_ffi/_ctypes/vmobj.py b/python/tvm/_ffi/_ctypes/vmobj.py deleted file mode 100644 index 59930e5..0000000 --- a/python/tvm/_ffi/_ctypes/vmobj.py +++ /dev/null @@ -1,52 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Runtime Object api""" -from __future__ import absolute_import - -import ctypes -from ..base import _LIB, check_call -from .types import TypeCode, RETURN_SWITCH - -ObjectHandle = ctypes.c_void_p - -"""Maps object type to its constructor""" -OBJECT_TYPE = {} - -def _register_object(index, cls): - """register object class""" - OBJECT_TYPE[index] = cls - - -def _return_object(x): - handle = x.v_handle - if not isinstance(handle, ObjectHandle): - handle = ObjectHandle(handle) - tag = ctypes.c_int() - check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag))) - cls = OBJECT_TYPE.get(tag.value, ObjectBase) - obj = cls(handle) - return obj - -RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object - - -class ObjectBase(object): - __slots__ = ["handle"] - - def __init__(self, handle): - self.handle = handle diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 63130ef..76fa963 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -37,7 +37,7 @@ cdef enum TVMTypeCode: kStr = 11 kBytes = 12 kNDArrayContainer = 13 - kObjectCell = 14 + kObjectHandle = 14 kExtBegin = 15 cdef extern from "tvm/runtime/c_runtime_api.h": @@ -130,7 +130,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h": int TVMArrayToDLPack(DLTensorHandle arr_from, DLManagedTensor** out) void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) - int TVMGetObjectTag(ObjectHandle obj, int* tag) + int TVMObjectFree(ObjectHandle obj) + int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index) + cdef extern from "tvm/c_dsl_api.h": int TVMNodeFree(NodeHandle handle) diff --git a/python/tvm/_ffi/_cython/core.pyx b/python/tvm/_ffi/_cython/core.pyx index 4b8536c..a934933 100644 --- a/python/tvm/_ffi/_cython/core.pyx +++ b/python/tvm/_ffi/_cython/core.pyx @@ -16,7 +16,8 @@ # under the License. include "./base.pxi" +include "./object.pxi" include "./node.pxi" include "./function.pxi" include "./ndarray.pxi" -include "./vmobj.pxi" + diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index cf1884c..ceacf74 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args, if (tcode == kNodeHandle or tcode == kFuncHandle or tcode == kModuleHandle or - tcode == kObjectCell or + tcode == kObjectHandle or tcode > kExtBegin): CALL(TVMCbArgToReturn(&value, tcode)) @@ -155,12 +155,12 @@ cdef inline int make_arg(object arg, value[0].v_handle = (arg).chandle tcode[0] = kNodeHandle temp_args.append(arg) + elif isinstance(arg, _CLASS_OBJECT): + value[0].v_handle = (arg).chandle + tcode[0] = kObjectHandle elif isinstance(arg, _CLASS_MODULE): value[0].v_handle = c_handle(arg.handle) tcode[0] = kModuleHandle - elif isinstance(arg, _CLASS_OBJECT): - value[0].v_handle = c_handle(arg.handle) - tcode[0] = kObjectCell elif isinstance(arg, FunctionBase): value[0].v_handle = (arg).chandle tcode[0] = kFuncHandle @@ -190,6 +190,8 @@ cdef inline object make_ret(TVMValue value, int tcode): """convert result to return value.""" if tcode == kNodeHandle: return make_ret_node(value.v_handle) + elif tcode == kObjectHandle: + return make_ret_object(value.v_handle) elif tcode == kNull: return None elif tcode == kInt: @@ -212,8 +214,6 @@ cdef inline object make_ret(TVMValue value, int tcode): fobj = _CLASS_FUNCTION(None, False) (fobj).chandle = value.v_handle return fobj - elif tcode == kObjectCell: - return make_ret_object(value.v_handle) elif tcode in _TVM_EXT_RET: return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle)) diff --git a/python/tvm/_ffi/_cython/vmobj.pxi b/python/tvm/_ffi/_cython/object.pxi similarity index 53% rename from python/tvm/_ffi/_cython/vmobj.pxi rename to python/tvm/_ffi/_cython/object.pxi index 9b48756..90be6a9 100644 --- a/python/tvm/_ffi/_cython/vmobj.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -19,7 +19,7 @@ OBJECT_TYPE = [] def _register_object(int index, object cls): - """register node class""" + """register object class""" while len(OBJECT_TYPE) <= index: OBJECT_TYPE.append(None) OBJECT_TYPE[index] = cls @@ -27,41 +27,70 @@ def _register_object(int index, object cls): cdef inline object make_ret_object(void* chandle): global OBJECT_TYPE - cdef int tag + cdef unsigned tindex cdef list object_type cdef object cls cdef object handle object_type = OBJECT_TYPE handle = ctypes_handle(chandle) - CALL(TVMGetObjectTag(chandle, &tag)) - if tag < len(object_type): - cls = object_type[tag] + CALL(TVMObjectGetTypeIndex(chandle, &tindex)) + if tindex < len(object_type): + cls = object_type[tindex] if cls is not None: - obj = cls(handle) + obj = cls.__new__(cls) else: - obj = ObjectBase(handle) + obj = ObjectBase.__new__(ObjectBase) else: - obj = ObjectBase(handle) + obj = ObjectBase.__new__(ObjectBase) + (obj).chandle = chandle return obj cdef class ObjectBase: - cdef ObjectHandle chandle + cdef void* chandle cdef inline _set_handle(self, handle): + cdef unsigned long long ptr if handle is None: self.chandle = NULL else: - self.chandle = c_handle(handle) + ptr = handle.value + self.chandle = (ptr) property handle: def __get__(self): if self.chandle == NULL: return None else: - return ctypes.cast(self.chandle, ctypes.c_void_p) + return ctypes_handle(self.chandle) + def __set__(self, value): self._set_handle(value) - def __init__(self, handle): - self._set_handle(handle) + def __dealloc__(self): + CALL(TVMObjectFree(self.chandle)) + + def __init_handle_by_constructor__(self, fconstructor, *args): + """Initialize the handle by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return handle is directly set into the Node object + instead of creating a new Node. + """ + # avoid error raised during construction. + self.chandle = NULL + cdef void* chandle + ConstructorCall( + (fconstructor).chandle, + kObjectHandle, args, &chandle) + self.chandle = chandle diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index 4bb3182..60e7aeb 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -22,7 +22,6 @@ from __future__ import absolute_import import sys import ctypes from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE -from . import vmobj as _vmobj IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py new file mode 100644 index 0000000..be8b086 --- /dev/null +++ b/python/tvm/_ffi/object.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Runtime Object API""" +from __future__ import absolute_import + +import sys +import ctypes +from .base import _FFI_MODE, check_call, _LIB, c_str + +IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError + +try: + # pylint: disable=wrong-import-position + if _FFI_MODE == "ctypes": + raise ImportError() + if sys.version_info >= (3, 0): + from ._cy3.core import _set_class_object + from ._cy3.core import ObjectBase as _ObjectBase + from ._cy3.core import _register_object + else: + from ._cy2.core import _set_class_object + from ._cy2.core import ObjectBase as _ObjectBase + from ._cy2.core import _register_object +except IMPORT_EXCEPT: + # pylint: disable=wrong-import-position + from ._ctypes.function import _set_class_object + from ._ctypes.object import ObjectBase as _ObjectBase + from ._ctypes.object import _register_object + + +class Object(_ObjectBase): + """Base class for all tvm's runtime objects.""" + pass + + +def register_object(type_key=None): + """register object type. + + Parameters + ---------- + type_key : str or cls + The type key of the node + + Examples + -------- + The following code registers MyObject + using type key "test.MyObject" + + .. code-block:: python + + @tvm.register_object("test.MyObject") + class MyObject(Object): + pass + """ + object_name = type_key if isinstance(type_key, str) else type_key.__name__ + + def register(cls): + """internal register function""" + if hasattr(cls, "_type_index"): + tindex = cls._type_index + else: + tidx = ctypes.c_uint() + check_call(_LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx))) + tindex = tidx.value + _register_object(tindex, cls) + return cls + + if isinstance(type_key, str): + return register + + return register(type_key) + + +def getitem_helper(obj, elem_getter, length, idx): + """Helper function to implement a pythonic getitem function. + + Parameters + ---------- + obj: object + The original object + + elem_getter : function + A simple function that takes index and return a single element. + + length : int + The size of the array + + idx : int or slice + The argument passed to getitem + + Returns + ------- + result : object + The result of getitem + """ + if isinstance(idx, slice): + start = idx.start if idx.start is not None else 0 + stop = idx.stop if idx.stop is not None else length + step = idx.step if idx.step is not None else 1 + if start < 0: + start += length + if stop < 0: + stop += length + return [elem_getter(obj, i) for i in range(start, stop, step)] + + if idx < -length or idx >= length: + raise IndexError("Index out of range. size: {}, got index {}" + .format(length, idx)) + if idx < 0: + idx += length + return elem_getter(obj, idx) + + +_set_class_object(Object) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 0d28abd..00e1945 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -42,7 +42,7 @@ class TypeCode(object): STR = 11 BYTES = 12 NDARRAY_CONTAINER = 13 - OBJECT_CELL = 14 + OBJECT_HANDLE = 14 EXT_BEGIN = 15 diff --git a/python/tvm/_ffi/vmobj.py b/python/tvm/_ffi/vmobj.py deleted file mode 100644 index ea3431a..0000000 --- a/python/tvm/_ffi/vmobj.py +++ /dev/null @@ -1,61 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Runtime Object api""" -from __future__ import absolute_import - -import sys -from .base import _FFI_MODE - -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError - -try: - # pylint: disable=wrong-import-position - if _FFI_MODE == "ctypes": - raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_object - from ._cy3.core import ObjectBase as _ObjectBase - from ._cy3.core import _register_object - else: - from ._cy2.core import _set_class_object - from ._cy2.core import ObjectBase as _ObjectBase - from ._cy2.core import _register_object -except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.function import _set_class_object - from ._ctypes.vmobj import ObjectBase as _ObjectBase - from ._ctypes.vmobj import _register_object - - -class ObjectTag(object): - """Type code used in API calls""" - TENSOR = 1 - CLOSURE = 2 - DATATYPE = 3 - - -class Object(_ObjectBase): - """The VM Object used in Relay virtual machine.""" - - -def register_object(cls): - _register_object(cls.tag, cls) - return cls - - -_set_class_object(Object) diff --git a/python/tvm/api.py b/python/tvm/api.py index e7523bd..f0261be 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs from numbers import Integral as _Integral from ._ffi.base import string_types +from ._ffi.object import register_object, Object from ._ffi.node import register_node, NodeBase from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.node_generic import _scalar_type_inference diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index e54629d..c24b16c 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -30,9 +30,12 @@ from . import _vm from . import vmobj as _obj from .interpreter import Executor +Tensor = _obj.Tensor +Datatype = _obj.Datatype + def _convert(arg, cargs): if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): - cargs.append(_obj.tensor_object(arg)) + cargs.append(_obj.Tensor(arg)) elif isinstance(arg, (tuple, list)): field_args = [] for field in arg: diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py index 4c92e9b..939b122 100644 --- a/python/tvm/relay/backend/vmobj.py +++ b/python/tvm/relay/backend/vmobj.py @@ -18,32 +18,37 @@ from __future__ import absolute_import as _abs import numpy as _np -from tvm._ffi.vmobj import Object, ObjectTag, register_object +from tvm._ffi.object import Object, register_object, getitem_helper from tvm import ndarray as _nd from . import _vmobj -# TODO(@icemelon9): Add ClosureObject -@register_object -class TensorObject(Object): - """Tensor object.""" - tag = ObjectTag.TENSOR +@register_object("vm.Tensor") +class Tensor(Object): + """Tensor object. - def __init__(self, handle): - """Constructs a Tensor object - - Parameters - ---------- - handle : object - Object handle + Parameters + ---------- + arr : numpy.ndarray or tvm.nd.NDArray + The source array. - Returns - ------- - obj : TensorObject - A tensor object. - """ - super(TensorObject, self).__init__(handle) - self.data = _vmobj.GetTensorData(self) + ctx : TVMContext, optional + The device context to create the array + """ + def __init__(self, arr, ctx=None): + if isinstance(arr, _np.ndarray): + ctx = ctx if ctx else _nd.cpu(0) + self.__init_handle_by_constructor__( + _vmobj.Tensor, _nd.array(arr, ctx=ctx)) + elif isinstance(arr, _nd.NDArray): + self.__init_handle_by_constructor__( + _vmobj.Tensor, arr) + else: + raise RuntimeError("Unsupported type for tensor object.") + + @property + def data(self): + return _vmobj.GetTensorData(self) def asnumpy(self): """Convert data to numpy array @@ -56,65 +61,34 @@ class TensorObject(Object): return self.data.asnumpy() -@register_object -class DatatypeObject(Object): - """Datatype object.""" - tag = ObjectTag.DATATYPE +@register_object("vm.Datatype") +class Datatype(Object): + """Datatype object. - def __init__(self, handle): - """Constructs a Datatype object + Parameters + ---------- + tag : int + The tag of datatype. - Parameters - ---------- - handle : object - Object handle + fields : list[Object] or tuple[Object] + The source tuple. + """ + def __init__(self, tag, fields): + for f in fields: + assert isinstance(f, Object) + self.__init_handle_by_constructor__( + _vmobj.Datatype, tag, *fields) - Returns - ------- - obj : DatatypeObject - A Datatype object. - """ - super(DatatypeObject, self).__init__(handle) - self.tag = _vmobj.GetDatatypeTag(self) - num_fields = _vmobj.GetDatatypeNumberOfFields(self) - self.fields = [] - for i in range(num_fields): - self.fields.append(_vmobj.GetDatatypeFields(self, i)) + @property + def tag(self): + return _vmobj.GetDatatypeTag(self) def __getitem__(self, idx): - return self.fields[idx] + return getitem_helper( + self, _vmobj.GetDatatypeFields, len(self), idx) def __len__(self): - return len(self.fields) - - def __iter__(self): - return iter(self.fields) - -# TODO(icemelon9): Add closure object - -def tensor_object(arr, ctx=_nd.cpu(0)): - """Create a tensor object from source arr. - - Parameters - ---------- - arr : numpy.ndarray or tvm.nd.NDArray - The source array. - - ctx : TVMContext, optional - The device context to create the array - - Returns - ------- - ret : TensorObject - The created object. - """ - if isinstance(arr, _np.ndarray): - tensor = _vmobj.Tensor(_nd.array(arr, ctx)) - elif isinstance(arr, _nd.NDArray): - tensor = _vmobj.Tensor(arr) - else: - raise RuntimeError("Unsupported type for tensor object.") - return tensor + return _vmobj.GetDatatypeNumberOfFields(self) def tuple_object(fields): @@ -127,30 +101,9 @@ def tuple_object(fields): Returns ------- - ret : DatatypeObject + ret : Datatype The created object. """ for f in fields: assert isinstance(f, Object) return _vmobj.Tuple(*fields) - - -def datatype_object(tag, fields): - """Create a datatype object from tag and source fields. - - Parameters - ---------- - tag : int - The tag of datatype. - - fields : list[Object] or tuple[Object] - The source tuple. - - Returns - ------- - ret : DatatypeObject - The created object. - """ - for f in fields: - assert isinstance(f, Object) - return _vmobj.Datatype(tag, *fields) diff --git a/src/runtime/c_dsl_api.cc b/src/runtime/c_dsl_api.cc index e45c89a..bf90926 100644 --- a/src/runtime/c_dsl_api.cc +++ b/src/runtime/c_dsl_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 20793b4..74f0f3e 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 5248da0..a52a9b3 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -26,6 +26,7 @@ #include #include #include +#include "runtime_base.h" namespace tvm { namespace runtime { @@ -184,5 +185,35 @@ std::string Object::TypeIndex2Key(uint32_t tindex) { uint32_t Object::TypeKey2Index(const char* key) { return TypeContext::Global()->TypeKey2Index(key); } + +class TVMObjectCAPI { + public: + static void Free(TVMObjectHandle obj) { + static_cast(obj)->DecRef(); + } + + static uint32_t TypeKey2Index(const char* type_key) { + return Object::TypeKey2Index(type_key); + } +}; } // namespace runtime } // namespace tvm + +int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { + API_BEGIN(); + out_tindex[0] = static_cast(obj)->type_index(); + API_END(); +} + +int TVMObjectFree(TVMObjectHandle obj) { + API_BEGIN(); + tvm::runtime::TVMObjectCAPI::Free(obj); + API_END(); +} + +int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { + API_BEGIN(); + out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index( + type_key); + API_END(); +} diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 6432bbd..c2cbbff 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -47,9 +47,9 @@ def convert_to_list(x): return x def vmobj_to_list(o): - if isinstance(o, tvm.relay.backend.vmobj.TensorObject): + if isinstance(o, tvm.relay.backend.vmobj.Tensor): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject): + elif isinstance(o, tvm.relay.backend.vmobj.Datatype): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 5289fe9..cedbc4f 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -59,9 +59,9 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): return ret def vmobj_to_list(o): - if isinstance(o, tvm.relay.backend.vmobj.TensorObject): + if isinstance(o, tvm.relay.backend.vm.Tensor): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject): + elif isinstance(o, tvm.relay.backend.vm.Datatype): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_vm_object.py b/tests/python/relay/test_vm_object.py new file mode 100644 index 0000000..ad21fff --- /dev/null +++ b/tests/python/relay/test_vm_object.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import tvm +from tvm.relay import vm + +def test_tensor(): + arr = tvm.nd.array([1,2,3]) + x = vm.Tensor(arr) + assert isinstance(x, vm.Tensor) + assert x.asnumpy()[0] == 1 + assert x.asnumpy()[-1] == 3 + assert isinstance(x.data, tvm.nd.NDArray) + + +def test_datatype(): + arr = tvm.nd.array([1,2,3]) + x = vm.Tensor(arr) + y = vm.Datatype(0, [x, x]) + + assert len(y) == 2 + assert isinstance(y, vm.Datatype) + y[0:1][-1].data == x.data + assert y.tag == 0 + assert isinstance(x.data, tvm.nd.NDArray) + + + +if __name__ == "__main__": + test_tensor() + test_datatype()