[RUNTIME] Refactor object python FFI to new protocol. (#4128)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 16 Oct 2019 22:24:23 +0000 (15:24 -0700)
committerGitHub <noreply@github.com>
Wed, 16 Oct 2019 22:24:23 +0000 (15:24 -0700)
* [RUNTIME] Refactor object python FFI to new protocol.

This is a pre-req to bring the Node system under object protocol.
Most of the code reflects the current code in the Node system.

- Use new instead of init so subclass can define their own constructors
- Allow register via name, besides type idnex
- Introduce necessary runtime C API functions
- Refactored Tensor and Datatype to directly use constructor.

* address review comments

23 files changed:
include/tvm/runtime/c_runtime_api.h
include/tvm/runtime/object.h
include/tvm/runtime/packed_func.h
python/tvm/_ffi/_ctypes/function.py
python/tvm/_ffi/_ctypes/object.py [new file with mode: 0644]
python/tvm/_ffi/_ctypes/vmobj.py [deleted file]
python/tvm/_ffi/_cython/base.pxi
python/tvm/_ffi/_cython/core.pyx
python/tvm/_ffi/_cython/function.pxi
python/tvm/_ffi/_cython/object.pxi [moved from python/tvm/_ffi/_cython/vmobj.pxi with 53% similarity]
python/tvm/_ffi/function.py
python/tvm/_ffi/object.py [new file with mode: 0644]
python/tvm/_ffi/runtime_ctypes.py
python/tvm/_ffi/vmobj.py [deleted file]
python/tvm/api.py
python/tvm/relay/backend/vm.py
python/tvm/relay/backend/vmobj.py
src/runtime/c_dsl_api.cc
src/runtime/c_runtime_api.cc
src/runtime/object.cc
tests/python/frontend/tensorflow/test_forward.py
tests/python/relay/test_vm.py
tests/python/relay/test_vm_object.py [new file with mode: 0644]

index 54e6f98..b058fd6 100644 (file)
@@ -104,7 +104,7 @@ typedef enum {
   kStr = 11U,
   kBytes = 12U,
   kNDArrayContainer = 13U,
-  kObjectCell = 14U,
+  kObjectHandle = 14U,
   // Extension codes for other frameworks to integrate TVM PackedFunc.
   // To make sure each framework's id do not conflict, use first and
   // last sections to mark ranges.
@@ -549,13 +549,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
                                        TVMStreamHandle dst);
 
 /*!
- * \brief Get the tag from an object.
+ * \brief Get the type_index from an object.
  *
  * \param obj The object handle.
- * \param tag The tag of object.
+ * \param out_tindex the output type index.
  * \return 0 when success, -1 when failure happens
  */
-TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
+TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
+
+/*!
+ * \brief Convert type key to type index.
+ * \param type_key The key of the type.
+ * \param out_tindex the corresponding type index.
+ * \return 0 when success, -1 when failure happens
+ */
+TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
+
+/*!
+ * \brief Free the object.
+ *
+ * \param obj The object handle.
+ * \note Internally we decrease the reference counter of the object.
+ *       The object will be freed when every reference to the object are removed.
+ * \return 0 when success, -1 when failure happens
+ */
+TVM_DLL int TVMObjectFree(TVMObjectHandle obj);
 
 #ifdef __cplusplus
 }  // TVM_EXTERN_C
index 7b0653a..0693b1f 100644 (file)
@@ -253,6 +253,7 @@ class Object {
   template<typename>
   friend class ObjectPtr;
   friend class TVMRetValue;
+  friend class TVMObjectCAPI;
 };
 
 /*!
index 5b71bbc..2bfa332 100644 (file)
@@ -491,7 +491,7 @@ class TVMPODValue_ {
   }
   operator ObjectRef() const {
     if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
-    TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
+    TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
     return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
   }
   operator TVMContext() const {
@@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
   }
   TVMRetValue& operator=(ObjectRef other) {
     this->Clear();
-    type_code_ = kObjectCell;
+    type_code_ = kObjectHandle;
     // move the handle out
     value_.v_handle = other.data_.data_;
     other.data_.data_ = nullptr;
@@ -862,7 +862,7 @@ class TVMRetValue : public TVMPODValue_ {
             kNodeHandle, *other.template ptr<NodePtr<Node> >());
         break;
       }
-      case kObjectCell: {
+      case kObjectHandle: {
         *this = other.operator ObjectRef();
         break;
       }
@@ -913,7 +913,7 @@ class TVMRetValue : public TVMPODValue_ {
         static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
         break;
       }
-      case kObjectCell: {
+      case kObjectHandle: {
         static_cast<Object*>(value_.v_handle)->DecRef();
         break;
       }
@@ -946,7 +946,7 @@ inline const char* TypeCode2Str(int type_code) {
     case kFuncHandle: return "FunctionHandle";
     case kModuleHandle: return "ModuleHandle";
     case kNDArrayContainer: return "NDArrayContainer";
-    case kObjectCell: return "ObjectCell";
+    case kObjectHandle: return "ObjectCell";
     default: LOG(FATAL) << "unknown type_code="
                         << static_cast<int>(type_code); return "";
   }
@@ -1164,7 +1164,7 @@ class TVMArgsSetter {
   }
   void operator()(size_t i, const ObjectRef& value) const {  // NOLINT(*)
     values_[i].v_handle = value.data_.data_;
-    type_codes_[i] = kObjectCell;
+    type_codes_[i] = kObjectHandle;
   }
   void operator()(size_t i, const TVMRetValue& value) const {  // NOLINT(*)
     if (value.type_code() == kStr) {
index 895c72d..22fb6c3 100644 (file)
@@ -33,6 +33,7 @@ from .types import TVMValue, TypeCode
 from .types import TVMPackedCFunc, TVMCFuncFinalizer
 from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
 from .node import NodeBase
+from . import object as _object
 from . import node as _node
 
 FunctionHandle = ctypes.c_void_p
@@ -165,7 +166,7 @@ def _make_tvm_args(args, temp_args):
             temp_args.append(arg)
         elif isinstance(arg, _CLASS_OBJECT):
             values[i].v_handle = arg.handle
-            type_codes[i] = TypeCode.OBJECT_CELL
+            type_codes[i] = TypeCode.OBJECT_HANDLE
         else:
             raise TypeError("Don't know how to handle type %s" % type(arg))
     return values, type_codes, num_args
@@ -225,7 +226,7 @@ def __init_handle_by_constructor__(fconstructor, args):
         raise get_last_ffi_error()
     _ = temp_args
     _ = args
-    assert ret_tcode.value == TypeCode.NODE_HANDLE
+    assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE)
     handle = ret_val.v_handle
     return handle
 
@@ -247,6 +248,7 @@ def _handle_return_func(x):
 
 # setup return handle for function type
 _node.__init_by_constructor__ = __init_handle_by_constructor__
+_object.__init_by_constructor__ = __init_handle_by_constructor__
 RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
 RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
 RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py
new file mode 100644 (file)
index 0000000..5ddceb1
--- /dev/null
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Runtime Object api"""
+from __future__ import absolute_import
+
+import ctypes
+from ..base import _LIB, check_call
+from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
+
+
+ObjectHandle = ctypes.c_void_p
+__init_by_constructor__ = None
+
+"""Maps object type to its constructor"""
+OBJECT_TYPE = {}
+
+def _register_object(index, cls):
+    """register object class"""
+    OBJECT_TYPE[index] = cls
+
+
+def _return_object(x):
+    handle = x.v_handle
+    if not isinstance(handle, ObjectHandle):
+        handle = ObjectHandle(handle)
+    tindex = ctypes.c_uint()
+    check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
+    cls = OBJECT_TYPE.get(tindex.value, ObjectBase)
+    # Avoid calling __init__ of cls, instead directly call __new__
+    # This allows child class to implement their own __init__
+    obj = cls.__new__(cls)
+    obj.handle = handle
+    return obj
+
+RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
+C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
+    _return_object, TypeCode.OBJECT_HANDLE)
+
+
+class ObjectBase(object):
+    """Base object for all object types"""
+    __slots__ = ["handle"]
+
+    def __del__(self):
+        if _LIB is not None:
+            check_call(_LIB.TVMObjectFree(self.handle))
+
+    def __init_handle_by_constructor__(self, fconstructor, *args):
+        """Initialize the handle by calling constructor function.
+
+        Parameters
+        ----------
+        fconstructor : Function
+            Constructor function.
+
+        args: list of objects
+            The arguments to the constructor
+
+        Note
+        ----
+        We have a special calling convention to call constructor functions.
+        So the return handle is directly set into the Node object
+        instead of creating a new Node.
+        """
+        # assign handle first to avoid error raising
+        self.handle = None
+        handle = __init_by_constructor__(fconstructor, args)
+        if not isinstance(handle, ObjectHandle):
+            handle = ObjectHandle(handle)
+        self.handle = handle
diff --git a/python/tvm/_ffi/_ctypes/vmobj.py b/python/tvm/_ffi/_ctypes/vmobj.py
deleted file mode 100644 (file)
index 59930e5..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name
-"""Runtime Object api"""
-from __future__ import absolute_import
-
-import ctypes
-from ..base import _LIB, check_call
-from .types import TypeCode, RETURN_SWITCH
-
-ObjectHandle = ctypes.c_void_p
-
-"""Maps object type to its constructor"""
-OBJECT_TYPE = {}
-
-def _register_object(index, cls):
-    """register object class"""
-    OBJECT_TYPE[index] = cls
-
-
-def _return_object(x):
-    handle = x.v_handle
-    if not isinstance(handle, ObjectHandle):
-        handle = ObjectHandle(handle)
-    tag = ctypes.c_int()
-    check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag)))
-    cls = OBJECT_TYPE.get(tag.value, ObjectBase)
-    obj = cls(handle)
-    return obj
-
-RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object
-
-
-class ObjectBase(object):
-    __slots__ = ["handle"]
-
-    def __init__(self, handle):
-        self.handle = handle
index 63130ef..76fa963 100644 (file)
@@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
     kStr = 11
     kBytes = 12
     kNDArrayContainer = 13
-    kObjectCell = 14
+    kObjectHandle = 14
     kExtBegin = 15
 
 cdef extern from "tvm/runtime/c_runtime_api.h":
@@ -130,7 +130,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
     int TVMArrayToDLPack(DLTensorHandle arr_from,
                          DLManagedTensor** out)
     void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
-    int TVMGetObjectTag(ObjectHandle obj, int* tag)
+    int TVMObjectFree(ObjectHandle obj)
+    int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index)
+
 
 cdef extern from "tvm/c_dsl_api.h":
     int TVMNodeFree(NodeHandle handle)
index 4b8536c..a934933 100644 (file)
@@ -16,7 +16,8 @@
 # under the License.
 
 include "./base.pxi"
+include "./object.pxi"
 include "./node.pxi"
 include "./function.pxi"
 include "./ndarray.pxi"
-include "./vmobj.pxi"
+
index cf1884c..ceacf74 100644 (file)
@@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
         if (tcode == kNodeHandle or
             tcode == kFuncHandle or
             tcode == kModuleHandle or
-            tcode == kObjectCell or
+            tcode == kObjectHandle or
             tcode > kExtBegin):
             CALL(TVMCbArgToReturn(&value, tcode))
 
@@ -155,12 +155,12 @@ cdef inline int make_arg(object arg,
         value[0].v_handle = (<NodeBase>arg).chandle
         tcode[0] = kNodeHandle
         temp_args.append(arg)
+    elif isinstance(arg, _CLASS_OBJECT):
+        value[0].v_handle = (<ObjectBase>arg).chandle
+        tcode[0] = kObjectHandle
     elif isinstance(arg, _CLASS_MODULE):
         value[0].v_handle = c_handle(arg.handle)
         tcode[0] = kModuleHandle
-    elif isinstance(arg, _CLASS_OBJECT):
-        value[0].v_handle = c_handle(arg.handle)
-        tcode[0] = kObjectCell
     elif isinstance(arg, FunctionBase):
         value[0].v_handle = (<FunctionBase>arg).chandle
         tcode[0] = kFuncHandle
@@ -190,6 +190,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
     """convert result to return value."""
     if tcode == kNodeHandle:
         return make_ret_node(value.v_handle)
+    elif tcode == kObjectHandle:
+        return make_ret_object(value.v_handle)
     elif tcode == kNull:
         return None
     elif tcode == kInt:
@@ -212,8 +214,6 @@ cdef inline object make_ret(TVMValue value, int tcode):
         fobj = _CLASS_FUNCTION(None, False)
         (<FunctionBase>fobj).chandle = value.v_handle
         return fobj
-    elif tcode == kObjectCell:
-        return make_ret_object(value.v_handle)
     elif tcode in _TVM_EXT_RET:
         return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
 
similarity index 53%
rename from python/tvm/_ffi/_cython/vmobj.pxi
rename to python/tvm/_ffi/_cython/object.pxi
index 9b48756..90be6a9 100644 (file)
@@ -19,7 +19,7 @@
 OBJECT_TYPE = []
 
 def _register_object(int index, object cls):
-    """register node class"""
+    """register object class"""
     while len(OBJECT_TYPE) <= index:
         OBJECT_TYPE.append(None)
     OBJECT_TYPE[index] = cls
@@ -27,41 +27,70 @@ def _register_object(int index, object cls):
 
 cdef inline object make_ret_object(void* chandle):
     global OBJECT_TYPE
-    cdef int tag
+    cdef unsigned tindex
     cdef list object_type
     cdef object cls
     cdef object handle
     object_type = OBJECT_TYPE
     handle = ctypes_handle(chandle)
-    CALL(TVMGetObjectTag(chandle, &tag))
-    if tag < len(object_type):
-        cls = object_type[tag]
+    CALL(TVMObjectGetTypeIndex(chandle, &tindex))
+    if tindex < len(object_type):
+        cls = object_type[tindex]
         if cls is not None:
-            obj = cls(handle)
+            obj = cls.__new__(cls)
         else:
-            obj = ObjectBase(handle)
+            obj = ObjectBase.__new__(ObjectBase)
     else:
-        obj = ObjectBase(handle)
+        obj = ObjectBase.__new__(ObjectBase)
+    (<ObjectBase>obj).chandle = chandle
     return obj
 
 
 cdef class ObjectBase:
-    cdef ObjectHandle chandle
+    cdef void* chandle
 
     cdef inline _set_handle(self, handle):
+        cdef unsigned long long ptr
         if handle is None:
             self.chandle = NULL
         else:
-            self.chandle = c_handle(handle)
+            ptr = handle.value
+            self.chandle = <void*>(ptr)
 
     property handle:
         def __get__(self):
             if self.chandle == NULL:
                 return None
             else:
-                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
+                return ctypes_handle(self.chandle)
+
         def __set__(self, value):
             self._set_handle(value)
 
-    def __init__(self, handle):
-        self._set_handle(handle)
+    def __dealloc__(self):
+        CALL(TVMObjectFree(self.chandle))
+
+    def __init_handle_by_constructor__(self, fconstructor, *args):
+        """Initialize the handle by calling constructor function.
+
+        Parameters
+        ----------
+        fconstructor : Function
+            Constructor function.
+
+        args: list of objects
+            The arguments to the constructor
+
+        Note
+        ----
+        We have a special calling convention to call constructor functions.
+        So the return handle is directly set into the Node object
+        instead of creating a new Node.
+        """
+        # avoid error raised during construction.
+        self.chandle = NULL
+        cdef void* chandle
+        ConstructorCall(
+            (<FunctionBase>fconstructor).chandle,
+            kObjectHandle, args, &chandle)
+        self.chandle = chandle
index 4bb3182..60e7aeb 100644 (file)
@@ -22,7 +22,6 @@ from __future__ import absolute_import
 import sys
 import ctypes
 from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
-from . import vmobj as _vmobj
 
 IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
 
diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py
new file mode 100644 (file)
index 0000000..be8b086
--- /dev/null
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Runtime Object API"""
+from __future__ import absolute_import
+
+import sys
+import ctypes
+from .base import _FFI_MODE, check_call, _LIB, c_str
+
+IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
+
+try:
+    # pylint: disable=wrong-import-position
+    if _FFI_MODE == "ctypes":
+        raise ImportError()
+    if sys.version_info >= (3, 0):
+        from ._cy3.core import _set_class_object
+        from ._cy3.core import ObjectBase as _ObjectBase
+        from ._cy3.core import _register_object
+    else:
+        from ._cy2.core import _set_class_object
+        from ._cy2.core import ObjectBase as _ObjectBase
+        from ._cy2.core import _register_object
+except IMPORT_EXCEPT:
+    # pylint: disable=wrong-import-position
+    from ._ctypes.function import _set_class_object
+    from ._ctypes.object import ObjectBase as _ObjectBase
+    from ._ctypes.object import _register_object
+
+
+class Object(_ObjectBase):
+    """Base class for all tvm's runtime objects."""
+    pass
+
+
+def register_object(type_key=None):
+    """register object type.
+
+    Parameters
+    ----------
+    type_key : str or cls
+        The type key of the node
+
+    Examples
+    --------
+    The following code registers MyObject
+    using type key "test.MyObject"
+
+    .. code-block:: python
+
+      @tvm.register_object("test.MyObject")
+      class MyObject(Object):
+          pass
+    """
+    object_name = type_key if isinstance(type_key, str) else type_key.__name__
+
+    def register(cls):
+        """internal register function"""
+        if hasattr(cls, "_type_index"):
+            tindex = cls._type_index
+        else:
+            tidx = ctypes.c_uint()
+            check_call(_LIB.TVMObjectTypeKey2Index(
+                c_str(object_name), ctypes.byref(tidx)))
+            tindex = tidx.value
+        _register_object(tindex, cls)
+        return cls
+
+    if isinstance(type_key, str):
+        return register
+
+    return register(type_key)
+
+
+def getitem_helper(obj, elem_getter, length, idx):
+    """Helper function to implement a pythonic getitem function.
+
+    Parameters
+    ----------
+    obj: object
+        The original object
+
+    elem_getter : function
+        A simple function that takes index and return a single element.
+
+    length : int
+        The size of the array
+
+    idx : int or slice
+        The argument passed to getitem
+
+    Returns
+    -------
+    result : object
+        The result of getitem
+    """
+    if isinstance(idx, slice):
+        start = idx.start if idx.start is not None else 0
+        stop = idx.stop if idx.stop is not None else length
+        step = idx.step if idx.step is not None else 1
+        if start < 0:
+            start += length
+        if stop < 0:
+            stop += length
+        return [elem_getter(obj, i) for i in range(start, stop, step)]
+
+    if idx < -length or idx >= length:
+        raise IndexError("Index out of range. size: {}, got index {}"
+                         .format(length, idx))
+    if idx < 0:
+        idx += length
+    return elem_getter(obj, idx)
+
+
+_set_class_object(Object)
index 0d28abd..00e1945 100644 (file)
@@ -42,7 +42,7 @@ class TypeCode(object):
     STR = 11
     BYTES = 12
     NDARRAY_CONTAINER = 13
-    OBJECT_CELL = 14
+    OBJECT_HANDLE = 14
     EXT_BEGIN = 15
 
 
diff --git a/python/tvm/_ffi/vmobj.py b/python/tvm/_ffi/vmobj.py
deleted file mode 100644 (file)
index ea3431a..0000000
+++ /dev/null
@@ -1,61 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name
-"""Runtime Object api"""
-from __future__ import absolute_import
-
-import sys
-from .base import _FFI_MODE
-
-IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
-
-try:
-    # pylint: disable=wrong-import-position
-    if _FFI_MODE == "ctypes":
-        raise ImportError()
-    if sys.version_info >= (3, 0):
-        from ._cy3.core import _set_class_object
-        from ._cy3.core import ObjectBase as _ObjectBase
-        from ._cy3.core import _register_object
-    else:
-        from ._cy2.core import _set_class_object
-        from ._cy2.core import ObjectBase as _ObjectBase
-        from ._cy2.core import _register_object
-except IMPORT_EXCEPT:
-    # pylint: disable=wrong-import-position
-    from ._ctypes.function import _set_class_object
-    from ._ctypes.vmobj import ObjectBase as _ObjectBase
-    from ._ctypes.vmobj import _register_object
-
-
-class ObjectTag(object):
-    """Type code used in API calls"""
-    TENSOR = 1
-    CLOSURE = 2
-    DATATYPE = 3
-
-
-class Object(_ObjectBase):
-    """The VM Object used in Relay virtual machine."""
-
-
-def register_object(cls):
-    _register_object(cls.tag, cls)
-    return cls
-
-
-_set_class_object(Object)
index e7523bd..f0261be 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
 from numbers import Integral as _Integral
 
 from ._ffi.base import string_types
+from ._ffi.object import register_object, Object
 from ._ffi.node import register_node, NodeBase
 from ._ffi.node import convert_to_node as _convert_to_node
 from ._ffi.node_generic import _scalar_type_inference
index e54629d..c24b16c 100644 (file)
@@ -30,9 +30,12 @@ from . import _vm
 from . import vmobj as _obj
 from .interpreter import Executor
 
+Tensor = _obj.Tensor
+Datatype = _obj.Datatype
+
 def _convert(arg, cargs):
     if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
-        cargs.append(_obj.tensor_object(arg))
+        cargs.append(_obj.Tensor(arg))
     elif isinstance(arg, (tuple, list)):
         field_args = []
         for field in arg:
index 4c92e9b..939b122 100644 (file)
 from __future__ import absolute_import as _abs
 import numpy as _np
 
-from tvm._ffi.vmobj import Object, ObjectTag, register_object
+from tvm._ffi.object import Object, register_object, getitem_helper
 from tvm import ndarray as _nd
 from . import _vmobj
 
-# TODO(@icemelon9): Add ClosureObject
 
-@register_object
-class TensorObject(Object):
-    """Tensor object."""
-    tag = ObjectTag.TENSOR
+@register_object("vm.Tensor")
+class Tensor(Object):
+    """Tensor object.
 
-    def __init__(self, handle):
-        """Constructs a Tensor object
-
-        Parameters
-        ----------
-        handle : object
-            Object handle
+    Parameters
+    ----------
+    arr : numpy.ndarray or tvm.nd.NDArray
+        The source array.
 
-        Returns
-        -------
-        obj : TensorObject
-            A tensor object.
-        """
-        super(TensorObject, self).__init__(handle)
-        self.data = _vmobj.GetTensorData(self)
+    ctx :  TVMContext, optional
+        The device context to create the array
+    """
+    def __init__(self, arr, ctx=None):
+        if isinstance(arr, _np.ndarray):
+            ctx = ctx if ctx else _nd.cpu(0)
+            self.__init_handle_by_constructor__(
+                _vmobj.Tensor, _nd.array(arr, ctx=ctx))
+        elif isinstance(arr, _nd.NDArray):
+            self.__init_handle_by_constructor__(
+                _vmobj.Tensor, arr)
+        else:
+            raise RuntimeError("Unsupported type for tensor object.")
+
+    @property
+    def data(self):
+        return _vmobj.GetTensorData(self)
 
     def asnumpy(self):
         """Convert data to numpy array
@@ -56,65 +61,34 @@ class TensorObject(Object):
         return self.data.asnumpy()
 
 
-@register_object
-class DatatypeObject(Object):
-    """Datatype object."""
-    tag = ObjectTag.DATATYPE
+@register_object("vm.Datatype")
+class Datatype(Object):
+    """Datatype object.
 
-    def __init__(self, handle):
-        """Constructs a Datatype object
+    Parameters
+    ----------
+    tag : int
+        The tag of datatype.
 
-        Parameters
-        ----------
-        handle : object
-            Object handle
+    fields : list[Object] or tuple[Object]
+        The source tuple.
+    """
+    def __init__(self, tag, fields):
+        for f in fields:
+            assert isinstance(f, Object)
+        self.__init_handle_by_constructor__(
+            _vmobj.Datatype, tag, *fields)
 
-        Returns
-        -------
-        obj : DatatypeObject
-            A Datatype object.
-        """
-        super(DatatypeObject, self).__init__(handle)
-        self.tag = _vmobj.GetDatatypeTag(self)
-        num_fields = _vmobj.GetDatatypeNumberOfFields(self)
-        self.fields = []
-        for i in range(num_fields):
-            self.fields.append(_vmobj.GetDatatypeFields(self, i))
+    @property
+    def tag(self):
+        return _vmobj.GetDatatypeTag(self)
 
     def __getitem__(self, idx):
-        return self.fields[idx]
+        return getitem_helper(
+            self, _vmobj.GetDatatypeFields, len(self), idx)
 
     def __len__(self):
-        return len(self.fields)
-
-    def __iter__(self):
-        return iter(self.fields)
-
-# TODO(icemelon9): Add closure object
-
-def tensor_object(arr, ctx=_nd.cpu(0)):
-    """Create a tensor object from source arr.
-
-    Parameters
-    ----------
-    arr : numpy.ndarray or tvm.nd.NDArray
-        The source array.
-
-    ctx :  TVMContext, optional
-        The device context to create the array
-
-    Returns
-    -------
-    ret : TensorObject
-        The created object.
-    """
-    if isinstance(arr, _np.ndarray):
-        tensor = _vmobj.Tensor(_nd.array(arr, ctx))
-    elif isinstance(arr, _nd.NDArray):
-        tensor = _vmobj.Tensor(arr)
-    else:
-        raise RuntimeError("Unsupported type for tensor object.")
-    return tensor
+        return _vmobj.GetDatatypeNumberOfFields(self)
 
 
 def tuple_object(fields):
@@ -127,30 +101,9 @@ def tuple_object(fields):
 
     Returns
     -------
-    ret : DatatypeObject
+    ret : Datatype
         The created object.
     """
     for f in fields:
         assert isinstance(f, Object)
     return _vmobj.Tuple(*fields)
-
-
-def datatype_object(tag, fields):
-    """Create a datatype object from tag and source fields.
-
-    Parameters
-    ----------
-    tag : int
-        The tag of datatype.
-
-    fields : list[Object] or tuple[Object]
-        The source tuple.
-
-    Returns
-    -------
-    ret : DatatypeObject
-        The created object.
-    """
-    for f in fields:
-        assert isinstance(f, Object)
-    return _vmobj.Datatype(tag, *fields)
index e45c89a..bf90926 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
index 20793b4..74f0f3e 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
index 5248da0..a52a9b3 100644 (file)
@@ -26,6 +26,7 @@
 #include <string>
 #include <vector>
 #include <unordered_map>
+#include "runtime_base.h"
 
 namespace tvm {
 namespace runtime {
@@ -184,5 +185,35 @@ std::string Object::TypeIndex2Key(uint32_t tindex) {
 uint32_t Object::TypeKey2Index(const char* key) {
   return TypeContext::Global()->TypeKey2Index(key);
 }
+
+class TVMObjectCAPI {
+ public:
+  static void Free(TVMObjectHandle obj) {
+    static_cast<Object*>(obj)->DecRef();
+  }
+
+  static uint32_t TypeKey2Index(const char* type_key) {
+    return Object::TypeKey2Index(type_key);
+  }
+};
 }  // namespace runtime
 }  // namespace tvm
+
+int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
+  API_BEGIN();
+  out_tindex[0] = static_cast<tvm::runtime::Object*>(obj)->type_index();
+  API_END();
+}
+
+int TVMObjectFree(TVMObjectHandle obj) {
+  API_BEGIN();
+  tvm::runtime::TVMObjectCAPI::Free(obj);
+  API_END();
+}
+
+int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
+  API_BEGIN();
+  out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index(
+      type_key);
+  API_END();
+}
index 6432bbd..c2cbbff 100644 (file)
@@ -47,9 +47,9 @@ def convert_to_list(x):
     return x
 
 def vmobj_to_list(o):
-    if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
+    if isinstance(o, tvm.relay.backend.vmobj.Tensor):
         return [o.asnumpy().tolist()]
-    elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
+    elif isinstance(o, tvm.relay.backend.vmobj.Datatype):
         result = []
         for f in o:
             result.extend(vmobj_to_list(f))
index 5289fe9..cedbc4f 100644 (file)
@@ -59,9 +59,9 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
         return ret
 
 def vmobj_to_list(o):
-    if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
+    if isinstance(o, tvm.relay.backend.vm.Tensor):
         return [o.asnumpy().tolist()]
-    elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
+    elif isinstance(o, tvm.relay.backend.vm.Datatype):
         result = []
         for f in o:
             result.extend(vmobj_to_list(f))
diff --git a/tests/python/relay/test_vm_object.py b/tests/python/relay/test_vm_object.py
new file mode 100644 (file)
index 0000000..ad21fff
--- /dev/null
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import tvm
+from tvm.relay import vm
+
+def test_tensor():
+    arr = tvm.nd.array([1,2,3])
+    x = vm.Tensor(arr)
+    assert isinstance(x, vm.Tensor)
+    assert x.asnumpy()[0] == 1
+    assert x.asnumpy()[-1] == 3
+    assert isinstance(x.data, tvm.nd.NDArray)
+
+
+def test_datatype():
+    arr = tvm.nd.array([1,2,3])
+    x = vm.Tensor(arr)
+    y = vm.Datatype(0, [x, x])
+
+    assert len(y) == 2
+    assert isinstance(y, vm.Datatype)
+    y[0:1][-1].data == x.data
+    assert y.tag == 0
+    assert isinstance(x.data, tvm.nd.NDArray)
+
+
+
+if __name__ == "__main__":
+    test_tensor()
+    test_datatype()