* [RUNTIME] Refactor object python FFI to new protocol.
This is a pre-req to bring the Node system under object protocol.
Most of the code reflects the current code in the Node system.
- Use new instead of init so subclass can define their own constructors
- Allow register via name, besides type idnex
- Introduce necessary runtime C API functions
- Refactored Tensor and Datatype to directly use constructor.
* address review comments
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
- kObjectCell = 14U,
+ kObjectHandle = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
TVMStreamHandle dst);
/*!
- * \brief Get the tag from an object.
+ * \brief Get the type_index from an object.
*
* \param obj The object handle.
- * \param tag The tag of object.
+ * \param out_tindex the output type index.
* \return 0 when success, -1 when failure happens
*/
-TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
+TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
+
+/*!
+ * \brief Convert type key to type index.
+ * \param type_key The key of the type.
+ * \param out_tindex the corresponding type index.
+ * \return 0 when success, -1 when failure happens
+ */
+TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
+
+/*!
+ * \brief Free the object.
+ *
+ * \param obj The object handle.
+ * \note Internally we decrease the reference counter of the object.
+ * The object will be freed when every reference to the object are removed.
+ * \return 0 when success, -1 when failure happens
+ */
+TVM_DLL int TVMObjectFree(TVMObjectHandle obj);
#ifdef __cplusplus
} // TVM_EXTERN_C
template<typename>
friend class ObjectPtr;
friend class TVMRetValue;
+ friend class TVMObjectCAPI;
};
/*!
}
operator ObjectRef() const {
if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
- TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
+ TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator TVMContext() const {
}
TVMRetValue& operator=(ObjectRef other) {
this->Clear();
- type_code_ = kObjectCell;
+ type_code_ = kObjectHandle;
// move the handle out
value_.v_handle = other.data_.data_;
other.data_.data_ = nullptr;
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
- case kObjectCell: {
+ case kObjectHandle: {
*this = other.operator ObjectRef();
break;
}
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
- case kObjectCell: {
+ case kObjectHandle: {
static_cast<Object*>(value_.v_handle)->DecRef();
break;
}
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
- case kObjectCell: return "ObjectCell";
+ case kObjectHandle: return "ObjectCell";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
values_[i].v_handle = value.data_.data_;
- type_codes_[i] = kObjectCell;
+ type_codes_[i] = kObjectHandle;
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .node import NodeBase
+from . import object as _object
from . import node as _node
FunctionHandle = ctypes.c_void_p
temp_args.append(arg)
elif isinstance(arg, _CLASS_OBJECT):
values[i].v_handle = arg.handle
- type_codes[i] = TypeCode.OBJECT_CELL
+ type_codes[i] = TypeCode.OBJECT_HANDLE
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
raise get_last_ffi_error()
_ = temp_args
_ = args
- assert ret_tcode.value == TypeCode.NODE_HANDLE
+ assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE)
handle = ret_val.v_handle
return handle
# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
+_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Runtime Object api"""
+from __future__ import absolute_import
+
+import ctypes
+from ..base import _LIB, check_call
+from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
+
+
+ObjectHandle = ctypes.c_void_p
+__init_by_constructor__ = None
+
+"""Maps object type to its constructor"""
+OBJECT_TYPE = {}
+
+def _register_object(index, cls):
+ """register object class"""
+ OBJECT_TYPE[index] = cls
+
+
+def _return_object(x):
+ handle = x.v_handle
+ if not isinstance(handle, ObjectHandle):
+ handle = ObjectHandle(handle)
+ tindex = ctypes.c_uint()
+ check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
+ cls = OBJECT_TYPE.get(tindex.value, ObjectBase)
+ # Avoid calling __init__ of cls, instead directly call __new__
+ # This allows child class to implement their own __init__
+ obj = cls.__new__(cls)
+ obj.handle = handle
+ return obj
+
+RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
+C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
+ _return_object, TypeCode.OBJECT_HANDLE)
+
+
+class ObjectBase(object):
+ """Base object for all object types"""
+ __slots__ = ["handle"]
+
+ def __del__(self):
+ if _LIB is not None:
+ check_call(_LIB.TVMObjectFree(self.handle))
+
+ def __init_handle_by_constructor__(self, fconstructor, *args):
+ """Initialize the handle by calling constructor function.
+
+ Parameters
+ ----------
+ fconstructor : Function
+ Constructor function.
+
+ args: list of objects
+ The arguments to the constructor
+
+ Note
+ ----
+ We have a special calling convention to call constructor functions.
+ So the return handle is directly set into the Node object
+ instead of creating a new Node.
+ """
+ # assign handle first to avoid error raising
+ self.handle = None
+ handle = __init_by_constructor__(fconstructor, args)
+ if not isinstance(handle, ObjectHandle):
+ handle = ObjectHandle(handle)
+ self.handle = handle
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name
-"""Runtime Object api"""
-from __future__ import absolute_import
-
-import ctypes
-from ..base import _LIB, check_call
-from .types import TypeCode, RETURN_SWITCH
-
-ObjectHandle = ctypes.c_void_p
-
-"""Maps object type to its constructor"""
-OBJECT_TYPE = {}
-
-def _register_object(index, cls):
- """register object class"""
- OBJECT_TYPE[index] = cls
-
-
-def _return_object(x):
- handle = x.v_handle
- if not isinstance(handle, ObjectHandle):
- handle = ObjectHandle(handle)
- tag = ctypes.c_int()
- check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag)))
- cls = OBJECT_TYPE.get(tag.value, ObjectBase)
- obj = cls(handle)
- return obj
-
-RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object
-
-
-class ObjectBase(object):
- __slots__ = ["handle"]
-
- def __init__(self, handle):
- self.handle = handle
kStr = 11
kBytes = 12
kNDArrayContainer = 13
- kObjectCell = 14
+ kObjectHandle = 14
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
- int TVMGetObjectTag(ObjectHandle obj, int* tag)
+ int TVMObjectFree(ObjectHandle obj)
+ int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index)
+
cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
# under the License.
include "./base.pxi"
+include "./object.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
-include "./vmobj.pxi"
+
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
- tcode == kObjectCell or
+ tcode == kObjectHandle or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
temp_args.append(arg)
+ elif isinstance(arg, _CLASS_OBJECT):
+ value[0].v_handle = (<ObjectBase>arg).chandle
+ tcode[0] = kObjectHandle
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
- elif isinstance(arg, _CLASS_OBJECT):
- value[0].v_handle = c_handle(arg.handle)
- tcode[0] = kObjectCell
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
"""convert result to return value."""
if tcode == kNodeHandle:
return make_ret_node(value.v_handle)
+ elif tcode == kObjectHandle:
+ return make_ret_object(value.v_handle)
elif tcode == kNull:
return None
elif tcode == kInt:
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
- elif tcode == kObjectCell:
- return make_ret_object(value.v_handle)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
OBJECT_TYPE = []
def _register_object(int index, object cls):
- """register node class"""
+ """register object class"""
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls
cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
- cdef int tag
+ cdef unsigned tindex
cdef list object_type
cdef object cls
cdef object handle
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
- CALL(TVMGetObjectTag(chandle, &tag))
- if tag < len(object_type):
- cls = object_type[tag]
+ CALL(TVMObjectGetTypeIndex(chandle, &tindex))
+ if tindex < len(object_type):
+ cls = object_type[tindex]
if cls is not None:
- obj = cls(handle)
+ obj = cls.__new__(cls)
else:
- obj = ObjectBase(handle)
+ obj = ObjectBase.__new__(ObjectBase)
else:
- obj = ObjectBase(handle)
+ obj = ObjectBase.__new__(ObjectBase)
+ (<ObjectBase>obj).chandle = chandle
return obj
cdef class ObjectBase:
- cdef ObjectHandle chandle
+ cdef void* chandle
cdef inline _set_handle(self, handle):
+ cdef unsigned long long ptr
if handle is None:
self.chandle = NULL
else:
- self.chandle = c_handle(handle)
+ ptr = handle.value
+ self.chandle = <void*>(ptr)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
- return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
+ return ctypes_handle(self.chandle)
+
def __set__(self, value):
self._set_handle(value)
- def __init__(self, handle):
- self._set_handle(handle)
+ def __dealloc__(self):
+ CALL(TVMObjectFree(self.chandle))
+
+ def __init_handle_by_constructor__(self, fconstructor, *args):
+ """Initialize the handle by calling constructor function.
+
+ Parameters
+ ----------
+ fconstructor : Function
+ Constructor function.
+
+ args: list of objects
+ The arguments to the constructor
+
+ Note
+ ----
+ We have a special calling convention to call constructor functions.
+ So the return handle is directly set into the Node object
+ instead of creating a new Node.
+ """
+ # avoid error raised during construction.
+ self.chandle = NULL
+ cdef void* chandle
+ ConstructorCall(
+ (<FunctionBase>fconstructor).chandle,
+ kObjectHandle, args, &chandle)
+ self.chandle = chandle
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
-from . import vmobj as _vmobj
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Runtime Object API"""
+from __future__ import absolute_import
+
+import sys
+import ctypes
+from .base import _FFI_MODE, check_call, _LIB, c_str
+
+IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
+
+try:
+ # pylint: disable=wrong-import-position
+ if _FFI_MODE == "ctypes":
+ raise ImportError()
+ if sys.version_info >= (3, 0):
+ from ._cy3.core import _set_class_object
+ from ._cy3.core import ObjectBase as _ObjectBase
+ from ._cy3.core import _register_object
+ else:
+ from ._cy2.core import _set_class_object
+ from ._cy2.core import ObjectBase as _ObjectBase
+ from ._cy2.core import _register_object
+except IMPORT_EXCEPT:
+ # pylint: disable=wrong-import-position
+ from ._ctypes.function import _set_class_object
+ from ._ctypes.object import ObjectBase as _ObjectBase
+ from ._ctypes.object import _register_object
+
+
+class Object(_ObjectBase):
+ """Base class for all tvm's runtime objects."""
+ pass
+
+
+def register_object(type_key=None):
+ """register object type.
+
+ Parameters
+ ----------
+ type_key : str or cls
+ The type key of the node
+
+ Examples
+ --------
+ The following code registers MyObject
+ using type key "test.MyObject"
+
+ .. code-block:: python
+
+ @tvm.register_object("test.MyObject")
+ class MyObject(Object):
+ pass
+ """
+ object_name = type_key if isinstance(type_key, str) else type_key.__name__
+
+ def register(cls):
+ """internal register function"""
+ if hasattr(cls, "_type_index"):
+ tindex = cls._type_index
+ else:
+ tidx = ctypes.c_uint()
+ check_call(_LIB.TVMObjectTypeKey2Index(
+ c_str(object_name), ctypes.byref(tidx)))
+ tindex = tidx.value
+ _register_object(tindex, cls)
+ return cls
+
+ if isinstance(type_key, str):
+ return register
+
+ return register(type_key)
+
+
+def getitem_helper(obj, elem_getter, length, idx):
+ """Helper function to implement a pythonic getitem function.
+
+ Parameters
+ ----------
+ obj: object
+ The original object
+
+ elem_getter : function
+ A simple function that takes index and return a single element.
+
+ length : int
+ The size of the array
+
+ idx : int or slice
+ The argument passed to getitem
+
+ Returns
+ -------
+ result : object
+ The result of getitem
+ """
+ if isinstance(idx, slice):
+ start = idx.start if idx.start is not None else 0
+ stop = idx.stop if idx.stop is not None else length
+ step = idx.step if idx.step is not None else 1
+ if start < 0:
+ start += length
+ if stop < 0:
+ stop += length
+ return [elem_getter(obj, i) for i in range(start, stop, step)]
+
+ if idx < -length or idx >= length:
+ raise IndexError("Index out of range. size: {}, got index {}"
+ .format(length, idx))
+ if idx < 0:
+ idx += length
+ return elem_getter(obj, idx)
+
+
+_set_class_object(Object)
STR = 11
BYTES = 12
NDARRAY_CONTAINER = 13
- OBJECT_CELL = 14
+ OBJECT_HANDLE = 14
EXT_BEGIN = 15
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name
-"""Runtime Object api"""
-from __future__ import absolute_import
-
-import sys
-from .base import _FFI_MODE
-
-IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
-
-try:
- # pylint: disable=wrong-import-position
- if _FFI_MODE == "ctypes":
- raise ImportError()
- if sys.version_info >= (3, 0):
- from ._cy3.core import _set_class_object
- from ._cy3.core import ObjectBase as _ObjectBase
- from ._cy3.core import _register_object
- else:
- from ._cy2.core import _set_class_object
- from ._cy2.core import ObjectBase as _ObjectBase
- from ._cy2.core import _register_object
-except IMPORT_EXCEPT:
- # pylint: disable=wrong-import-position
- from ._ctypes.function import _set_class_object
- from ._ctypes.vmobj import ObjectBase as _ObjectBase
- from ._ctypes.vmobj import _register_object
-
-
-class ObjectTag(object):
- """Type code used in API calls"""
- TENSOR = 1
- CLOSURE = 2
- DATATYPE = 3
-
-
-class Object(_ObjectBase):
- """The VM Object used in Relay virtual machine."""
-
-
-def register_object(cls):
- _register_object(cls.tag, cls)
- return cls
-
-
-_set_class_object(Object)
from numbers import Integral as _Integral
from ._ffi.base import string_types
+from ._ffi.object import register_object, Object
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.node_generic import _scalar_type_inference
from . import vmobj as _obj
from .interpreter import Executor
+Tensor = _obj.Tensor
+Datatype = _obj.Datatype
+
def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
- cargs.append(_obj.tensor_object(arg))
+ cargs.append(_obj.Tensor(arg))
elif isinstance(arg, (tuple, list)):
field_args = []
for field in arg:
from __future__ import absolute_import as _abs
import numpy as _np
-from tvm._ffi.vmobj import Object, ObjectTag, register_object
+from tvm._ffi.object import Object, register_object, getitem_helper
from tvm import ndarray as _nd
from . import _vmobj
-# TODO(@icemelon9): Add ClosureObject
-@register_object
-class TensorObject(Object):
- """Tensor object."""
- tag = ObjectTag.TENSOR
+@register_object("vm.Tensor")
+class Tensor(Object):
+ """Tensor object.
- def __init__(self, handle):
- """Constructs a Tensor object
-
- Parameters
- ----------
- handle : object
- Object handle
+ Parameters
+ ----------
+ arr : numpy.ndarray or tvm.nd.NDArray
+ The source array.
- Returns
- -------
- obj : TensorObject
- A tensor object.
- """
- super(TensorObject, self).__init__(handle)
- self.data = _vmobj.GetTensorData(self)
+ ctx : TVMContext, optional
+ The device context to create the array
+ """
+ def __init__(self, arr, ctx=None):
+ if isinstance(arr, _np.ndarray):
+ ctx = ctx if ctx else _nd.cpu(0)
+ self.__init_handle_by_constructor__(
+ _vmobj.Tensor, _nd.array(arr, ctx=ctx))
+ elif isinstance(arr, _nd.NDArray):
+ self.__init_handle_by_constructor__(
+ _vmobj.Tensor, arr)
+ else:
+ raise RuntimeError("Unsupported type for tensor object.")
+
+ @property
+ def data(self):
+ return _vmobj.GetTensorData(self)
def asnumpy(self):
"""Convert data to numpy array
return self.data.asnumpy()
-@register_object
-class DatatypeObject(Object):
- """Datatype object."""
- tag = ObjectTag.DATATYPE
+@register_object("vm.Datatype")
+class Datatype(Object):
+ """Datatype object.
- def __init__(self, handle):
- """Constructs a Datatype object
+ Parameters
+ ----------
+ tag : int
+ The tag of datatype.
- Parameters
- ----------
- handle : object
- Object handle
+ fields : list[Object] or tuple[Object]
+ The source tuple.
+ """
+ def __init__(self, tag, fields):
+ for f in fields:
+ assert isinstance(f, Object)
+ self.__init_handle_by_constructor__(
+ _vmobj.Datatype, tag, *fields)
- Returns
- -------
- obj : DatatypeObject
- A Datatype object.
- """
- super(DatatypeObject, self).__init__(handle)
- self.tag = _vmobj.GetDatatypeTag(self)
- num_fields = _vmobj.GetDatatypeNumberOfFields(self)
- self.fields = []
- for i in range(num_fields):
- self.fields.append(_vmobj.GetDatatypeFields(self, i))
+ @property
+ def tag(self):
+ return _vmobj.GetDatatypeTag(self)
def __getitem__(self, idx):
- return self.fields[idx]
+ return getitem_helper(
+ self, _vmobj.GetDatatypeFields, len(self), idx)
def __len__(self):
- return len(self.fields)
-
- def __iter__(self):
- return iter(self.fields)
-
-# TODO(icemelon9): Add closure object
-
-def tensor_object(arr, ctx=_nd.cpu(0)):
- """Create a tensor object from source arr.
-
- Parameters
- ----------
- arr : numpy.ndarray or tvm.nd.NDArray
- The source array.
-
- ctx : TVMContext, optional
- The device context to create the array
-
- Returns
- -------
- ret : TensorObject
- The created object.
- """
- if isinstance(arr, _np.ndarray):
- tensor = _vmobj.Tensor(_nd.array(arr, ctx))
- elif isinstance(arr, _nd.NDArray):
- tensor = _vmobj.Tensor(arr)
- else:
- raise RuntimeError("Unsupported type for tensor object.")
- return tensor
+ return _vmobj.GetDatatypeNumberOfFields(self)
def tuple_object(fields):
Returns
-------
- ret : DatatypeObject
+ ret : Datatype
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Tuple(*fields)
-
-
-def datatype_object(tag, fields):
- """Create a datatype object from tag and source fields.
-
- Parameters
- ----------
- tag : int
- The tag of datatype.
-
- fields : list[Object] or tuple[Object]
- The source tuple.
-
- Returns
- -------
- ret : DatatypeObject
- The created object.
- """
- for f in fields:
- assert isinstance(f, Object)
- return _vmobj.Datatype(tag, *fields)
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <string>
#include <vector>
#include <unordered_map>
+#include "runtime_base.h"
namespace tvm {
namespace runtime {
uint32_t Object::TypeKey2Index(const char* key) {
return TypeContext::Global()->TypeKey2Index(key);
}
+
+class TVMObjectCAPI {
+ public:
+ static void Free(TVMObjectHandle obj) {
+ static_cast<Object*>(obj)->DecRef();
+ }
+
+ static uint32_t TypeKey2Index(const char* type_key) {
+ return Object::TypeKey2Index(type_key);
+ }
+};
} // namespace runtime
} // namespace tvm
+
+int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
+ API_BEGIN();
+ out_tindex[0] = static_cast<tvm::runtime::Object*>(obj)->type_index();
+ API_END();
+}
+
+int TVMObjectFree(TVMObjectHandle obj) {
+ API_BEGIN();
+ tvm::runtime::TVMObjectCAPI::Free(obj);
+ API_END();
+}
+
+int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
+ API_BEGIN();
+ out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index(
+ type_key);
+ API_END();
+}
return x
def vmobj_to_list(o):
- if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
+ if isinstance(o, tvm.relay.backend.vmobj.Tensor):
return [o.asnumpy().tolist()]
- elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
+ elif isinstance(o, tvm.relay.backend.vmobj.Datatype):
result = []
for f in o:
result.extend(vmobj_to_list(f))
return ret
def vmobj_to_list(o):
- if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
+ if isinstance(o, tvm.relay.backend.vm.Tensor):
return [o.asnumpy().tolist()]
- elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
+ elif isinstance(o, tvm.relay.backend.vm.Datatype):
result = []
for f in o:
result.extend(vmobj_to_list(f))
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import tvm
+from tvm.relay import vm
+
+def test_tensor():
+ arr = tvm.nd.array([1,2,3])
+ x = vm.Tensor(arr)
+ assert isinstance(x, vm.Tensor)
+ assert x.asnumpy()[0] == 1
+ assert x.asnumpy()[-1] == 3
+ assert isinstance(x.data, tvm.nd.NDArray)
+
+
+def test_datatype():
+ arr = tvm.nd.array([1,2,3])
+ x = vm.Tensor(arr)
+ y = vm.Datatype(0, [x, x])
+
+ assert len(y) == 2
+ assert isinstance(y, vm.Datatype)
+ y[0:1][-1].data == x.data
+ assert y.tag == 0
+ assert isinstance(x.data, tvm.nd.NDArray)
+
+
+
+if __name__ == "__main__":
+ test_tensor()
+ test_datatype()