[REFACTOR] Replace TensorObj and TensorValue with NDArray (#4643)
authorZhi <5145158+zhiics@users.noreply.github.com>
Sat, 11 Jan 2020 00:44:16 +0000 (16:44 -0800)
committerJared Roesch <jroesch@octoml.ai>
Sat, 11 Jan 2020 00:44:16 +0000 (16:44 -0800)
* replace TensorObj and TensorValue with NDArray

* NodeBase to Object in Python

* rebase

55 files changed:
docs/api/python/dev.rst
docs/dev/codebase_walkthrough.rst
include/tvm/relay/interpreter.h
include/tvm/runtime/vm.h
python/tvm/__init__.py
python/tvm/_ffi/_ctypes/function.py
python/tvm/_ffi/_ctypes/object.py
python/tvm/_ffi/_cython/function.pxi
python/tvm/_ffi/_cython/object.pxi
python/tvm/_ffi/function.py
python/tvm/_ffi/node.py [deleted file]
python/tvm/_ffi/object.py
python/tvm/_ffi/object_generic.py [moved from python/tvm/_ffi/node_generic.py with 82% similarity]
python/tvm/api.py
python/tvm/arith.py
python/tvm/attrs.py
python/tvm/build_module.py
python/tvm/container.py
python/tvm/expr.py
python/tvm/ir_builder.py
python/tvm/object.py [moved from python/tvm/node.py with 93% similarity]
python/tvm/relay/_module.pyi
python/tvm/relay/adt.py
python/tvm/relay/backend/compile_engine.py
python/tvm/relay/backend/interpreter.py
python/tvm/relay/backend/vm.py
python/tvm/relay/backend/vmobj.py
python/tvm/relay/base.py
python/tvm/relay/expr.pyi
python/tvm/relay/quantize/quantize.py
python/tvm/relay/testing/py_converter.py
python/tvm/relay/transform.pyi
python/tvm/relay/ty.pyi
python/tvm/schedule.py
python/tvm/stmt.py
python/tvm/target.py
python/tvm/tensor.py
python/tvm/tensor_intrin.py
src/relay/backend/interpreter.cc
src/relay/backend/vm/compiler.cc
src/relay/pass/fold_constant.cc
src/relay/pass/partial_eval.cc
src/runtime/vm/executable.cc
src/runtime/vm/object.cc
src/runtime/vm/vm.cc
tests/python/frontend/tensorflow/test_control_flow.py
tests/python/frontend/tensorflow/test_forward.py
tests/python/relay/test_adt.py
tests/python/relay/test_backend_interpreter.py
tests/python/relay/test_py_converter.py
tests/python/relay/test_vm.py
tests/python/relay/test_vm_object.py
tests/python/unittest/test_pass_inject_double_buffer.py
tests/python/unittest/test_pass_inject_vthread.py
tests/python/unittest/test_pass_storage_flatten.py

index 7bb938c..8a0a705 100644 (file)
@@ -20,17 +20,14 @@ Developer API
 This page contains modules that are used by developers of TVM.
 Many of these APIs are PackedFunc registered in C++ backend.
 
-tvm.node
-~~~~~~~~
-.. automodule:: tvm.node
-
-.. autoclass:: tvm.node.NodeBase
-    :members:
+tvm.object
+~~~~~~~~~~
+.. automodule:: tvm.object
 
-.. autoclass:: tvm.node.Node
+.. autoclass:: tvm.object.Object
     :members:
 
-.. autofunction:: tvm.register_node
+.. autofunction:: tvm.register_object
 
 tvm.expr
 ~~~~~~~~
index 19f185e..0732c26 100644 (file)
@@ -55,18 +55,18 @@ We use a simple example that uses the low level TVM API directly. The example is
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
 
-Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``.
+Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
 
 ::
 
-   @register_node
-   class Tensor(NodeBase, _expr.ExprOp):
+   @register_object
+   class Tensor(Object, _expr.ExprOp):
        """Tensor object, to construct, see function.Tensor"""
 
        def __call__(self, *indices):
           ...
 
-The Node system is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
+The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
 
 ``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:
 
index 8ef7f6e..dc35fc2 100644 (file)
 #include <tvm/build_module.h>
 #include <tvm/relay/module.h>
 #include <tvm/relay/expr.h>
+#include <tvm/runtime/object.h>
 
 namespace tvm {
 namespace relay {
 
 /*!
- * \brief A Relay value.
- */
-class Value;
-
-/*!
  *\brief Create a Interpreter function that can
  *  evaluate an expression and produce a value.
  *
@@ -65,39 +61,21 @@ class Value;
  * \param target Compiler target flag to compile the functions on the context.
  * \return A function that takes in an expression and returns a value.
  */
-runtime::TypedPackedFunc<Value(Expr)>
+runtime::TypedPackedFunc<ObjectRef(Expr)>
 CreateInterpreter(Module mod, DLContext context, Target target);
 
-/*! \brief The base container type of Relay values. */
-class ValueNode : public RelayNode {
- public:
-  static constexpr const char* _type_key = "relay.Value";
-  TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode);
-};
-
-class Value : public ObjectRef {
- public:
-  Value() {}
-  explicit Value(ObjectPtr<Object> n) : ObjectRef(n) {}
-  const ValueNode* operator->() const {
-    return static_cast<const ValueNode*>(get());
-  }
-
-  using ContainerType = ValueNode;
-};
-
 /*! \brief A Relay closure, i.e a scope and a function. */
 class Closure;
 
 /*! \brief The container type of Closures. */
-class ClosureNode : public ValueNode {
+class ClosureNode : public Object {
  public:
   /*! \brief The set of free variables in the closure.
    *
    * These are the captured variables which are required for
    * evaluation when we call the closure.
    */
-  tvm::Map<Var, Value> env;
+  tvm::Map<Var, ObjectRef> env;
   /*! \brief The function which implements the closure.
    *
    * \note May reference the variables contained in the env.
@@ -111,22 +89,22 @@ class ClosureNode : public ValueNode {
     v->Visit("func", &func);
   }
 
-  TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
+  TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);
 
   static constexpr const char* _type_key = "relay.Closure";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
 };
 
-class Closure : public Value {
+class Closure : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
 };
 
 /*! \brief A Relay Recursive Closure. A closure that has a name. */
 class RecClosure;
 
 /*! \brief The container type of RecClosure. */
-class RecClosureNode : public ValueNode {
+class RecClosureNode : public Object {
  public:
   /*! \brief The closure. */
   Closure clos;
@@ -143,64 +121,41 @@ class RecClosureNode : public ValueNode {
   TVM_DLL static RecClosure make(Closure clos, Var bind);
 
   static constexpr const char* _type_key = "relay.RecClosure";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
 };
 
-class RecClosure : public Value {
+class RecClosure : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
 };
 
 /*! \brief A tuple value. */
 class TupleValue;
 
 /*! \brief Tuple (x, ... y). */
-struct TupleValueNode : ValueNode {
-  tvm::Array<Value> fields;
+struct TupleValueNode : Object {
+  tvm::Array<ObjectRef> fields;
 
   TupleValueNode() {}
 
   void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
 
-  TVM_DLL static TupleValue make(tvm::Array<Value> value);
+  TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);
 
   static constexpr const char* _type_key = "relay.TupleValue";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode);
-};
-
-class TupleValue : public Value {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode);
-};
-
-/*! \brief A tensor value. */
-class TensorValue;
-
-/*! \brief The tensor value container, wrapping an NDArray. */
-struct TensorValueNode : ValueNode {
-  runtime::NDArray data;
-
-  TensorValueNode() {}
-
-  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }
-
-  /*! \brief Build a value from an NDArray. */
-  TVM_DLL static TensorValue make(runtime::NDArray data);
-
-  static constexpr const char* _type_key = "relay.TensorValue";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
 };
 
-class TensorValue : public Value {
+class TupleValue : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
 };
 
 /*! \brief A reference value. */
 class RefValue;
 
-struct RefValueNode : ValueNode {
-  mutable Value value;
+struct RefValueNode : Object {
+  mutable ObjectRef value;
 
   RefValueNode() {}
 
@@ -208,24 +163,24 @@ struct RefValueNode : ValueNode {
     v->Visit("value", &value);
   }
 
-  TVM_DLL static RefValue make(Value val);
+  TVM_DLL static RefValue make(ObjectRef val);
 
   static constexpr const char* _type_key = "relay.RefValue";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
 };
 
-class RefValue : public Value {
+class RefValue : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
 };
 
 /*! \brief An ADT constructor value. */
 class ConstructorValue;
 
-struct ConstructorValueNode : ValueNode {
+struct ConstructorValueNode : Object {
   int32_t tag;
 
-  tvm::Array<Value> fields;
+  tvm::Array<ObjectRef> fields;
 
   /*! \brief Optional field tracking ADT constructor. */
   Constructor constructor;
@@ -237,16 +192,16 @@ struct ConstructorValueNode : ValueNode {
   }
 
   TVM_DLL static ConstructorValue make(int32_t tag,
-                                       tvm::Array<Value> fields,
+                                       tvm::Array<ObjectRef> fields,
                                        Constructor construtor = {});
 
   static constexpr const char* _type_key = "relay.ConstructorValue";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
 };
 
-class ConstructorValue : public Value {
+class ConstructorValue : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
 };
 
 }  // namespace relay
index 59e9ae8..990ecf5 100644 (file)
@@ -36,25 +36,6 @@ namespace tvm {
 namespace runtime {
 namespace vm {
 
-/*! \brief An object containing an NDArray. */
-class TensorObj : public Object {
- public:
-  /*! \brief The NDArray. */
-  NDArray data;
-
-  static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
-  static constexpr const char* _type_key = "vm.Tensor";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
-};
-
-/*! \brief reference to tensor. */
-class Tensor : public ObjectRef {
- public:
-  explicit Tensor(NDArray data);
-
-  TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
-};
-
 /*! \brief An object representing a closure. */
 class ClosureObj : public Object {
  public:
index 9e3eb0f..b2a4ca3 100644 (file)
@@ -34,7 +34,7 @@ from . import codegen
 from . import container
 from . import schedule
 from . import module
-from . import node
+from . import object
 from . import attrs
 from . import ir_builder
 from . import target
@@ -55,7 +55,7 @@ from ._ffi.base import TVMError, __version__
 from .api import *
 from .intrin import *
 from .tensor_intrin import decl_tensor_intrin
-from .node import register_node
+from .object import register_object
 from .ndarray import register_extension
 from .schedule import create_schedule
 from .build_module import build, lower, build_config
index 2f0b5ba..45048c5 100644 (file)
@@ -25,14 +25,14 @@ from numbers import Number, Integral
 
 from ..base import _LIB, get_last_ffi_error, py2cerror
 from ..base import c_str, string_types
-from ..node_generic import convert_to_node, NodeGeneric
+from ..object_generic import convert_to_object, ObjectGeneric
 from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
 from . import ndarray as _nd
 from .ndarray import NDArrayBase, _make_array
 from .types import TVMValue, TypeCode
 from .types import TVMPackedCFunc, TVMCFuncFinalizer
 from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
-from .object import ObjectBase, _set_class_node
+from .object import ObjectBase, _set_class_object
 from . import object as _object
 
 FunctionHandle = ctypes.c_void_p
@@ -144,8 +144,8 @@ def _make_tvm_args(args, temp_args):
         elif isinstance(arg, string_types):
             values[i].v_str = c_str(arg)
             type_codes[i] = TypeCode.STR
-        elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
-            arg = convert_to_node(arg)
+        elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
+            arg = convert_to_object(arg)
             values[i].v_handle = arg.handle
             type_codes[i] = TypeCode.OBJECT_HANDLE
             temp_args.append(arg)
@@ -256,7 +256,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl
 
 _CLASS_MODULE = None
 _CLASS_FUNCTION = None
-_CLASS_OBJECT = None
 
 def _set_class_module(module_class):
     """Initialize the module."""
@@ -266,7 +265,3 @@ def _set_class_module(module_class):
 def _set_class_function(func_class):
     global _CLASS_FUNCTION
     _CLASS_FUNCTION = func_class
-
-def _set_class_object(obj_class):
-    global _CLASS_OBJECT
-    _CLASS_OBJECT = obj_class
index b8b8aef..8a2fb1b 100644 (file)
@@ -30,11 +30,11 @@ __init_by_constructor__ = None
 """Maps object type to its constructor"""
 OBJECT_TYPE = {}
 
-_CLASS_NODE = None
+_CLASS_OBJECT = None
 
-def _set_class_node(node_class):
-    global _CLASS_NODE
-    _CLASS_NODE = node_class
+def _set_class_object(object_class):
+    global _CLASS_OBJECT
+    _CLASS_OBJECT = object_class
 
 
 def _register_object(index, cls):
@@ -51,7 +51,7 @@ def _return_object(x):
         handle = ObjectHandle(handle)
     tindex = ctypes.c_uint()
     check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
-    cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE)
+    cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
     # Avoid calling __init__ of cls, instead directly call __new__
     # This allows child class to implement their own __init__
     obj = cls.__new__(cls)
index a236042..7789769 100644 (file)
@@ -20,7 +20,7 @@ import traceback
 from cpython cimport Py_INCREF, Py_DECREF
 from numbers import Number, Integral
 from ..base import string_types, py2cerror
-from ..node_generic import convert_to_node, NodeGeneric
+from ..object_generic import convert_to_object, ObjectGeneric
 from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
 
 
@@ -149,8 +149,8 @@ cdef inline int make_arg(object arg,
         value[0].v_str = tstr
         tcode[0] = kStr
         temp_args.append(tstr)
-    elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
-        arg = convert_to_node(arg)
+    elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
+        arg = convert_to_object(arg)
         value[0].v_handle = (<ObjectBase>arg).chandle
         tcode[0] = kObjectHandle
         temp_args.append(arg)
@@ -308,7 +308,6 @@ cdef class FunctionBase:
 _CLASS_FUNCTION = None
 _CLASS_MODULE = None
 _CLASS_OBJECT = None
-_CLASS_NODE = None
 
 def _set_class_module(module_class):
     """Initialize the module."""
@@ -322,7 +321,3 @@ def _set_class_function(func_class):
 def _set_class_object(obj_class):
     global _CLASS_OBJECT
     _CLASS_OBJECT = obj_class
-
-def _set_class_node(node_class):
-    global _CLASS_NODE
-    _CLASS_NODE = node_class
index 6d20723..1392f99 100644 (file)
@@ -32,7 +32,7 @@ def _register_object(int index, object cls):
 
 cdef inline object make_ret_object(void* chandle):
     global OBJECT_TYPE
-    global _CLASS_NODE
+    global _CLASS_OBJECT
     cdef unsigned tindex
     cdef object cls
     cdef object handle
@@ -44,11 +44,9 @@ cdef inline object make_ret_object(void* chandle):
         if cls is not None:
             obj = cls.__new__(cls)
         else:
-            # default use node base class
-            # TODO(tqchen) change to object after Node unifies with Object
-            obj = _CLASS_NODE.__new__(_CLASS_NODE)
+            obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
     else:
-        obj = _CLASS_NODE.__new__(_CLASS_NODE)
+        obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
     (<ObjectBase>obj).chandle = chandle
     return obj
 
index 23d95eb..22e0356 100644 (file)
@@ -22,7 +22,7 @@ from __future__ import absolute_import
 import sys
 import ctypes
 from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
-from .node_generic import _set_class_objects
+from .object_generic import _set_class_objects
 
 IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
 
diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py
deleted file mode 100644 (file)
index c6c151a..0000000
+++ /dev/null
@@ -1,89 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Node namespace"""
-# pylint: disable=unused-import
-from __future__ import absolute_import
-
-import ctypes
-import sys
-from .. import _api_internal
-from .object import Object, register_object, _set_class_node
-from .node_generic import NodeGeneric, convert_to_node, const
-
-
-def _new_object(cls):
-    """Helper function for pickle"""
-    return cls.__new__(cls)
-
-
-class NodeBase(Object):
-    """NodeBase is the base class of all TVM language AST object."""
-    def __repr__(self):
-        return _api_internal._format_str(self)
-
-    def __dir__(self):
-        fnames = _api_internal._NodeListAttrNames(self)
-        size = fnames(-1)
-        return [fnames(i) for i in range(size)]
-
-    def __getattr__(self, name):
-        try:
-            return _api_internal._NodeGetAttr(self, name)
-        except AttributeError:
-            raise AttributeError(
-                "%s has no attribute %s" % (str(type(self)), name))
-
-    def __hash__(self):
-        return _api_internal._raw_ptr(self)
-
-    def __eq__(self, other):
-        return self.same_as(other)
-
-    def __ne__(self, other):
-        return not self.__eq__(other)
-
-    def __reduce__(self):
-        cls = type(self)
-        return (_new_object, (cls, ), self.__getstate__())
-
-    def __getstate__(self):
-        handle = self.handle
-        if handle is not None:
-            return {'handle': _api_internal._save_json(self)}
-        return {'handle': None}
-
-    def __setstate__(self, state):
-        # pylint: disable=assigning-non-slot
-        handle = state['handle']
-        if handle is not None:
-            json_str = handle
-            other = _api_internal._load_json(json_str)
-            self.handle = other.handle
-            other.handle = None
-        else:
-            self.handle = None
-
-    def same_as(self, other):
-        """check object identity equality"""
-        if not isinstance(other, NodeBase):
-            return False
-        return self.__hash__() == other.__hash__()
-
-
-# pylint: disable=invalid-name
-register_node = register_object
-_set_class_node(NodeBase)
index 002fd27..83d4129 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, unused-import
 """Runtime Object API"""
 from __future__ import absolute_import
 
 import sys
 import ctypes
+from .. import _api_internal
 from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
+from .object_generic import ObjectGeneric, convert_to_object, const
 
 IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
 
@@ -29,23 +31,77 @@ try:
     if _FFI_MODE == "ctypes":
         raise ImportError()
     if sys.version_info >= (3, 0):
-        from ._cy3.core import _set_class_object, _set_class_node
+        from ._cy3.core import _set_class_object
         from ._cy3.core import ObjectBase as _ObjectBase
         from ._cy3.core import _register_object
     else:
-        from ._cy2.core import _set_class_object, _set_class_node
+        from ._cy2.core import _set_class_object
         from ._cy2.core import ObjectBase as _ObjectBase
         from ._cy2.core import _register_object
 except IMPORT_EXCEPT:
     # pylint: disable=wrong-import-position,unused-import
-    from ._ctypes.function import _set_class_object, _set_class_node
+    from ._ctypes.function import _set_class_object
     from ._ctypes.object import ObjectBase as _ObjectBase
     from ._ctypes.object import _register_object
 
 
+def _new_object(cls):
+    """Helper function for pickle"""
+    return cls.__new__(cls)
+
+
 class Object(_ObjectBase):
     """Base class for all tvm's runtime objects."""
-    pass
+    def __repr__(self):
+        return _api_internal._format_str(self)
+
+    def __dir__(self):
+        fnames = _api_internal._NodeListAttrNames(self)
+        size = fnames(-1)
+        return [fnames(i) for i in range(size)]
+
+    def __getattr__(self, name):
+        try:
+            return _api_internal._NodeGetAttr(self, name)
+        except AttributeError:
+            raise AttributeError(
+                "%s has no attribute %s" % (str(type(self)), name))
+
+    def __hash__(self):
+        return _api_internal._raw_ptr(self)
+
+    def __eq__(self, other):
+        return self.same_as(other)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __reduce__(self):
+        cls = type(self)
+        return (_new_object, (cls, ), self.__getstate__())
+
+    def __getstate__(self):
+        handle = self.handle
+        if handle is not None:
+            return {'handle': _api_internal._save_json(self)}
+        return {'handle': None}
+
+    def __setstate__(self, state):
+        # pylint: disable=assigning-non-slot
+        handle = state['handle']
+        if handle is not None:
+            json_str = handle
+            other = _api_internal._load_json(json_str)
+            self.handle = other.handle
+            other.handle = None
+        else:
+            self.handle = None
+
+    def same_as(self, other):
+        """check object identity equality"""
+        if not isinstance(other, Object):
+            return False
+        return self.__hash__() == other.__hash__()
 
 
 def register_object(type_key=None):
similarity index 82%
rename from python/tvm/_ffi/node_generic.py
rename to python/tvm/_ffi/object_generic.py
index 8ee7fc5..92e73ad 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Common implementation of Node generic related logic"""
+"""Common implementation of object generic related logic"""
 # pylint: disable=unused-import
 from __future__ import absolute_import
 
@@ -22,7 +22,7 @@ from numbers import Number, Integral
 from .. import _api_internal
 from .base import string_types
 
-# Node base class
+# Object base class
 _CLASS_OBJECTS = None
 
 def _set_class_objects(cls):
@@ -47,15 +47,15 @@ def _scalar_type_inference(value):
     return dtype
 
 
-class NodeGeneric(object):
-    """Base class for all classes that can be converted to node."""
-    def asnode(self):
-        """Convert value to node"""
+class ObjectGeneric(object):
+    """Base class for all classes that can be converted to object."""
+    def asobject(self):
+        """Convert value to object"""
         raise NotImplementedError()
 
 
-def convert_to_node(value):
-    """Convert a python value to corresponding node type.
+def convert_to_object(value):
+    """Convert a python value to corresponding object type.
 
     Parameters
     ----------
@@ -64,8 +64,8 @@ def convert_to_node(value):
 
     Returns
     -------
-    node : Node
-        The corresponding node value.
+    obj : Object
+        The corresponding object value.
     """
     if isinstance(value, _CLASS_OBJECTS):
         return value
@@ -76,7 +76,7 @@ def convert_to_node(value):
     if isinstance(value, string_types):
         return _api_internal._str(value)
     if isinstance(value, (list, tuple)):
-        value = [convert_to_node(x) for x in value]
+        value = [convert_to_object(x) for x in value]
         return _api_internal._Array(*value)
     if isinstance(value, dict):
         vlist = []
@@ -85,14 +85,14 @@ def convert_to_node(value):
                     not isinstance(item[0], string_types)):
                 raise ValueError("key of map must already been a container type")
             vlist.append(item[0])
-            vlist.append(convert_to_node(item[1]))
+            vlist.append(convert_to_object(item[1]))
         return _api_internal._Map(*vlist)
-    if isinstance(value, NodeGeneric):
-        return value.asnode()
+    if isinstance(value, ObjectGeneric):
+        return value.asobject()
     if value is None:
         return None
 
-    raise ValueError("don't know how to convert type %s to node" % type(value))
+    raise ValueError("don't know how to convert type %s to object" % type(value))
 
 
 def const(value, dtype=None):
index 4d0e347..7395d35 100644 (file)
@@ -22,9 +22,8 @@ from numbers import Integral as _Integral
 
 from ._ffi.base import string_types
 from ._ffi.object import register_object, Object
-from ._ffi.node import register_node, NodeBase
-from ._ffi.node import convert_to_node as _convert_to_node
-from ._ffi.node_generic import _scalar_type_inference
+from ._ffi.object import convert_to_object as _convert_to_object
+from ._ffi.object_generic import _scalar_type_inference
 from ._ffi.function import Function
 from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
 from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
@@ -111,7 +110,7 @@ def get_env_func(name):
 
     Note
     ----
-    EnvFunc is a Node wrapper around
+    EnvFunc is a Object wrapper around
     global function that can be serialized via its name.
     This can be used to serialize function field in the language.
     """
@@ -127,16 +126,16 @@ def convert(value):
 
     Returns
     -------
-    tvm_val : Node or Function
+    tvm_val : Object or Function
         Converted value in TVM
     """
-    if isinstance(value, (Function, NodeBase)):
+    if isinstance(value, (Function, Object)):
         return value
 
     if callable(value):
         return _convert_tvm_func(value)
 
-    return _convert_to_node(value)
+    return _convert_to_object(value)
 
 
 def load_json(json_str):
@@ -149,7 +148,7 @@ def load_json(json_str):
 
     Returns
     -------
-    node : Node
+    node : Object
         The loaded tvm node.
     """
     return _api_internal._load_json(json_str)
@@ -160,8 +159,8 @@ def save_json(node):
 
     Parameters
     ----------
-    node : Node
-        A TVM Node object to be saved.
+    node : Object
+        A TVM object to be saved.
 
     Returns
     -------
index 4c3c05f..81f478c 100644 (file)
 """Arithmetic data structure and utility"""
 from __future__ import absolute_import as _abs
 
-from ._ffi.node import NodeBase, register_node
+from ._ffi.object import Object, register_object
 from ._ffi.function import _init_api
 from . import _api_internal
 
-class IntSet(NodeBase):
+class IntSet(Object):
     """Represent a set of integer in one dimension."""
     def is_nothing(self):
         """Whether the set represent nothing"""
@@ -32,7 +32,7 @@ class IntSet(NodeBase):
         return _api_internal._IntSetIsEverything(self)
 
 
-@register_node("arith.IntervalSet")
+@register_object("arith.IntervalSet")
 class IntervalSet(IntSet):
     """Represent set of continuous interval [min_value, max_value]
 
@@ -49,16 +49,16 @@ class IntervalSet(IntSet):
             _make_IntervalSet, min_value, max_value)
 
 
-@register_node("arith.ModularSet")
-class ModularSet(NodeBase):
+@register_object("arith.ModularSet")
+class ModularSet(Object):
     """Represent range of (coeff * x + base) for x in Z """
     def __init__(self, coeff, base):
         self.__init_handle_by_constructor__(
             _make_ModularSet, coeff, base)
 
 
-@register_node("arith.ConstIntBound")
-class ConstIntBound(NodeBase):
+@register_object("arith.ConstIntBound")
+class ConstIntBound(Object):
     """Represent constant integer bound
 
     Parameters
@@ -245,7 +245,7 @@ class Analyzer:
         var : tvm.Var
             The variable.
 
-        info : tvm.NodeBase
+        info : tvm.Object
             Related information.
 
         override : bool
index e2a2732..2963a0e 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """ TVM Attribute module, which is mainly used for defining attributes of operators"""
-from ._ffi.node import NodeBase, register_node as _register_tvm_node
+from ._ffi.object import Object, register_object
 from ._ffi.function import _init_api
 from . import _api_internal
 
 
-@_register_tvm_node
-class Attrs(NodeBase):
+@register_object
+class Attrs(Object):
     """Attribute node, which is mainly use for defining attributes of relay operators.
 
     Used by function registered in python side, such as compute, schedule and alter_layout.
index f96e283..85d2b85 100644 (file)
@@ -23,7 +23,7 @@ from __future__ import absolute_import as _abs
 import warnings
 
 from ._ffi.function import Function
-from ._ffi.node import NodeBase, register_node
+from ._ffi.object import Object, register_object
 from . import api
 from . import _api_internal
 from . import tensor
@@ -115,22 +115,22 @@ class DumpIR(object):
         DumpIR.scope_level -= 1
 
 
-@register_node
-class BuildConfig(NodeBase):
+@register_object
+class BuildConfig(Object):
     """Configuration scope to set a build config option.
 
     Note
     ----
-    This object is backed by node system in C++, with arguments that can be
+    This object is backed by object protocol in C++, with arguments that can be
     exchanged between python and C++.
 
     Do not construct directly, use build_config instead.
 
-    The fields that are backed by the C++ node are immutable once an instance
-    is constructed. See _node_defaults for the fields.
+    The fields that are backed by the C++ object are immutable once an instance
+    is constructed. See _object_defaults for the fields.
     """
 
-    _node_defaults = {
+    _object_defaults = {
         "auto_unroll_max_step": 0,
         "auto_unroll_max_depth": 8,
         "auto_unroll_max_extent": 0,
@@ -191,7 +191,7 @@ class BuildConfig(NodeBase):
         _api_internal._ExitBuildConfigScope(self)
 
     def __setattr__(self, name, value):
-        if name in BuildConfig._node_defaults:
+        if name in BuildConfig._object_defaults:
             raise AttributeError(
                 "'%s' object cannot set attribute '%s'" % (str(type(self)), name))
         return super(BuildConfig, self).__setattr__(name, value)
@@ -257,7 +257,7 @@ def build_config(**kwargs):
         The build configuration
     """
     node_args = {k: v if k not in kwargs else kwargs[k]
-                 for k, v in BuildConfig._node_defaults.items()}
+                 for k, v in BuildConfig._object_defaults.items()}
     config = make.node("BuildConfig", **node_args)
 
     if "add_lower_pass" in kwargs:
index aedbe95..274fc1f 100644 (file)
 # under the License.
 """Container data structures used in TVM DSL."""
 from __future__ import absolute_import as _abs
-from ._ffi.node import NodeBase, register_node
+from ._ffi.object import Object, register_object
 from . import _api_internal
 
-@register_node
-class Array(NodeBase):
+@register_object
+class Array(Object):
     """Array container of TVM.
 
     You do not need to create Array explicitly.
@@ -50,8 +50,8 @@ class Array(NodeBase):
         return _api_internal._ArraySize(self)
 
 
-@register_node
-class EnvFunc(NodeBase):
+@register_object
+class EnvFunc(Object):
     """Environment function.
 
     This is a global function object that can be serialized by its name.
@@ -64,13 +64,13 @@ class EnvFunc(NodeBase):
         return _api_internal._EnvFuncGetPackedFunc(self)
 
 
-@register_node
-class Map(NodeBase):
+@register_object
+class Map(Object):
     """Map container of TVM.
 
     You do not need to create Map explicitly.
     Normally python dict will be converted automaticall to Map during tvm function call.
-    You can use convert to create a dict[NodeBase-> NodeBase] into a Map
+    You can use convert to create a dict[Object-> Object] into a Map
     """
     def __getitem__(self, k):
         return _api_internal._MapGetItem(self, k)
@@ -87,11 +87,11 @@ class Map(NodeBase):
         return _api_internal._MapSize(self)
 
 
-@register_node
+@register_object
 class StrMap(Map):
     """A special map container that has str as key.
 
-    You can use convert to create a dict[str->NodeBase] into a Map.
+    You can use convert to create a dict[str->Object] into a Map.
     """
     def items(self):
         """Get the items from the map"""
@@ -99,8 +99,8 @@ class StrMap(Map):
         return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
 
 
-@register_node
-class Range(NodeBase):
+@register_object
+class Range(Object):
     """Represent a range in TVM.
 
     You do not need to create a Range explicitly.
@@ -108,8 +108,8 @@ class Range(NodeBase):
     """
 
 
-@register_node
-class LoweredFunc(NodeBase):
+@register_object
+class LoweredFunc(Object):
     """Represent a LoweredFunc in TVM."""
     MixedFunc = 0
     HostFunc = 1
index c6b3d9b..71c0aec 100644 (file)
@@ -32,7 +32,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
 """
 # pylint: disable=missing-docstring
 from __future__ import absolute_import as _abs
-from ._ffi.node import NodeBase, NodeGeneric, register_node
+from ._ffi.object import Object, register_object, ObjectGeneric
 from ._ffi.runtime_ctypes import TVMType, TypeCode
 from . import make as _make
 from . import generic as _generic
@@ -178,11 +178,11 @@ class ExprOp(object):
         return _generic.cast(self, dtype)
 
 
-class EqualOp(NodeGeneric, ExprOp):
+class EqualOp(ObjectGeneric, ExprOp):
     """Deferred equal operator.
 
     This is used to support sugar that a == b can either
-    mean NodeBase.same_as or NodeBase.equal.
+    mean Object.same_as or Object.equal.
 
     Parameters
     ----------
@@ -205,16 +205,16 @@ class EqualOp(NodeGeneric, ExprOp):
     def __bool__(self):
         return self.__nonzero__()
 
-    def asnode(self):
-        """Convert node."""
+    def asobject(self):
+        """Convert object."""
         return _make._OpEQ(self.a, self.b)
 
 
-class NotEqualOp(NodeGeneric, ExprOp):
+class NotEqualOp(ObjectGeneric, ExprOp):
     """Deferred NE operator.
 
     This is used to support sugar that a != b can either
-    mean not NodeBase.same_as or make.NE.
+    mean not Object.same_as or make.NE.
 
     Parameters
     ----------
@@ -237,16 +237,16 @@ class NotEqualOp(NodeGeneric, ExprOp):
     def __bool__(self):
         return self.__nonzero__()
 
-    def asnode(self):
-        """Convert node."""
+    def asobject(self):
+        """Convert object."""
         return _make._OpNE(self.a, self.b)
 
 
-class PrimExpr(ExprOp, NodeBase):
+class PrimExpr(ExprOp, Object):
     """Base class of all tvm Expressions"""
     # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
     # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
-    __hash__ = NodeBase.__hash__
+    __hash__ = Object.__hash__
 
 
 class ConstExpr(PrimExpr):
@@ -261,7 +261,7 @@ class CmpExpr(PrimExpr):
 class LogicalExpr(PrimExpr):
     pass
 
-@register_node("Variable")
+@register_object("Variable")
 class Var(PrimExpr):
     """Symbolic variable.
 
@@ -278,7 +278,7 @@ class Var(PrimExpr):
             _api_internal._Var, name, dtype)
 
 
-@register_node
+@register_object
 class Reduce(PrimExpr):
     """Reduce node.
 
@@ -305,7 +305,7 @@ class Reduce(PrimExpr):
             condition, value_index)
 
 
-@register_node
+@register_object
 class FloatImm(ConstExpr):
     """Float constant.
 
@@ -321,7 +321,7 @@ class FloatImm(ConstExpr):
         self.__init_handle_by_constructor__(
             _make.FloatImm, dtype, value)
 
-@register_node
+@register_object
 class IntImm(ConstExpr):
     """Int constant.
 
@@ -341,7 +341,7 @@ class IntImm(ConstExpr):
         return self.value
 
 
-@register_node
+@register_object
 class UIntImm(ConstExpr):
     """UInt constant.
 
@@ -358,7 +358,7 @@ class UIntImm(ConstExpr):
             _make.UIntImm, dtype, value)
 
 
-@register_node
+@register_object
 class StringImm(ConstExpr):
     """String constant.
 
@@ -382,7 +382,7 @@ class StringImm(ConstExpr):
         return self.value != other
 
 
-@register_node
+@register_object
 class Cast(PrimExpr):
     """Cast expression.
 
@@ -399,7 +399,7 @@ class Cast(PrimExpr):
             _make.Cast, dtype, value)
 
 
-@register_node
+@register_object
 class Add(BinaryOpExpr):
     """Add node.
 
@@ -416,7 +416,7 @@ class Add(BinaryOpExpr):
             _make.Add, a, b)
 
 
-@register_node
+@register_object
 class Sub(BinaryOpExpr):
     """Sub node.
 
@@ -433,7 +433,7 @@ class Sub(BinaryOpExpr):
             _make.Sub, a, b)
 
 
-@register_node
+@register_object
 class Mul(BinaryOpExpr):
     """Mul node.
 
@@ -450,7 +450,7 @@ class Mul(BinaryOpExpr):
             _make.Mul, a, b)
 
 
-@register_node
+@register_object
 class Div(BinaryOpExpr):
     """Div node.
 
@@ -467,7 +467,7 @@ class Div(BinaryOpExpr):
             _make.Div, a, b)
 
 
-@register_node
+@register_object
 class Mod(BinaryOpExpr):
     """Mod node.
 
@@ -484,7 +484,7 @@ class Mod(BinaryOpExpr):
             _make.Mod, a, b)
 
 
-@register_node
+@register_object
 class FloorDiv(BinaryOpExpr):
     """FloorDiv node.
 
@@ -501,7 +501,7 @@ class FloorDiv(BinaryOpExpr):
             _make.FloorDiv, a, b)
 
 
-@register_node
+@register_object
 class FloorMod(BinaryOpExpr):
     """FloorMod node.
 
@@ -518,7 +518,7 @@ class FloorMod(BinaryOpExpr):
             _make.FloorMod, a, b)
 
 
-@register_node
+@register_object
 class Min(BinaryOpExpr):
     """Min node.
 
@@ -535,7 +535,7 @@ class Min(BinaryOpExpr):
             _make.Min, a, b)
 
 
-@register_node
+@register_object
 class Max(BinaryOpExpr):
     """Max node.
 
@@ -552,7 +552,7 @@ class Max(BinaryOpExpr):
             _make.Max, a, b)
 
 
-@register_node
+@register_object
 class EQ(CmpExpr):
     """EQ node.
 
@@ -569,7 +569,7 @@ class EQ(CmpExpr):
             _make.EQ, a, b)
 
 
-@register_node
+@register_object
 class NE(CmpExpr):
     """NE node.
 
@@ -586,7 +586,7 @@ class NE(CmpExpr):
             _make.NE, a, b)
 
 
-@register_node
+@register_object
 class LT(CmpExpr):
     """LT node.
 
@@ -603,7 +603,7 @@ class LT(CmpExpr):
             _make.LT, a, b)
 
 
-@register_node
+@register_object
 class LE(CmpExpr):
     """LE node.
 
@@ -620,7 +620,7 @@ class LE(CmpExpr):
             _make.LE, a, b)
 
 
-@register_node
+@register_object
 class GT(CmpExpr):
     """GT node.
 
@@ -637,7 +637,7 @@ class GT(CmpExpr):
             _make.GT, a, b)
 
 
-@register_node
+@register_object
 class GE(CmpExpr):
     """GE node.
 
@@ -654,7 +654,7 @@ class GE(CmpExpr):
             _make.GE, a, b)
 
 
-@register_node
+@register_object
 class And(LogicalExpr):
     """And node.
 
@@ -671,7 +671,7 @@ class And(LogicalExpr):
             _make.And, a, b)
 
 
-@register_node
+@register_object
 class Or(LogicalExpr):
     """Or node.
 
@@ -688,7 +688,7 @@ class Or(LogicalExpr):
             _make.Or, a, b)
 
 
-@register_node
+@register_object
 class Not(LogicalExpr):
     """Not node.
 
@@ -702,7 +702,7 @@ class Not(LogicalExpr):
             _make.Not, a)
 
 
-@register_node
+@register_object
 class Select(PrimExpr):
     """Select node.
 
@@ -730,7 +730,7 @@ class Select(PrimExpr):
             _make.Select, condition, true_value, false_value)
 
 
-@register_node
+@register_object
 class Load(PrimExpr):
     """Load node.
 
@@ -753,7 +753,7 @@ class Load(PrimExpr):
             _make.Load, dtype, buffer_var, index, predicate)
 
 
-@register_node
+@register_object
 class Ramp(PrimExpr):
     """Ramp node.
 
@@ -773,7 +773,7 @@ class Ramp(PrimExpr):
             _make.Ramp, base, stride, lanes)
 
 
-@register_node
+@register_object
 class Broadcast(PrimExpr):
     """Broadcast node.
 
@@ -790,7 +790,7 @@ class Broadcast(PrimExpr):
             _make.Broadcast, value, lanes)
 
 
-@register_node
+@register_object
 class Shuffle(PrimExpr):
     """Shuffle node.
 
@@ -807,7 +807,7 @@ class Shuffle(PrimExpr):
             _make.Shuffle, vectors, indices)
 
 
-@register_node
+@register_object
 class Call(PrimExpr):
     """Call node.
 
@@ -842,7 +842,7 @@ class Call(PrimExpr):
             _make.Call, dtype, name, args, call_type, func, value_index)
 
 
-@register_node
+@register_object
 class Let(PrimExpr):
     """Let node.
 
index bf41c98..ede17a1 100644 (file)
@@ -24,7 +24,7 @@ from . import make as _make
 from . import ir_pass as _pass
 from . import container as _container
 from ._ffi.base import string_types
-from ._ffi.node import NodeGeneric
+from ._ffi.object import ObjectGeneric
 from ._ffi.runtime_ctypes import TVMType
 from .expr import Call as _Call
 
@@ -41,7 +41,7 @@ class WithScope(object):
         self._exit_cb()
 
 
-class BufferVar(NodeGeneric):
+class BufferVar(ObjectGeneric):
     """Buffer variable with content type, makes load store easily.
 
     Do not create it directly, create use IRBuilder.
@@ -70,7 +70,7 @@ class BufferVar(NodeGeneric):
         self._buffer_var = buffer_var
         self._content_type = content_type
 
-    def asnode(self):
+    def asobject(self):
         return self._buffer_var
 
     @property
similarity index 93%
rename from python/tvm/node.py
rename to python/tvm/object.py
index 1d5b506..9659d3c 100644 (file)
@@ -20,6 +20,4 @@ Normally user do not need to touch this api.
 """
 # pylint: disable=unused-import
 from __future__ import absolute_import as _abs
-from ._ffi.node import NodeBase, register_node
-
-Node = NodeBase
+from ._ffi.object import Object, register_object
index ae2d199..66c994e 100644 (file)
@@ -16,7 +16,7 @@
 # under the License.
 
 from typing import Union, Tuple, Dict, List
-from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
+from relay.ir import GlobalId, OperatorId, Item, Object, Span, FileId
 from relay.ir import ShapeExtension, Operator, Defn
 
-class Module(NodeBase): ...
+class Module(Object): ...
index 30db22c..7f7496b 100644 (file)
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """Algebraic data types in Relay."""
-from .base import RelayNode, register_relay_node, NodeBase
+from .base import RelayNode, register_relay_node, Object
 from . import _make
 from .ty import Type
 from .expr import Expr, Call
@@ -184,7 +184,7 @@ class TypeData(Type):
 
 
 @register_relay_node
-class Clause(NodeBase):
+class Clause(Object):
     """Clause for pattern matching in Relay."""
 
     def __init__(self, lhs, rhs):
index 6c690a9..956ad55 100644 (file)
 """Backend code generation engine."""
 from __future__ import absolute_import
 
-from ..base import register_relay_node, NodeBase
+from ..base import register_relay_node, Object
 from ... import target as _target
 from .. import expr as _expr
 from . import _backend
 
 @register_relay_node
-class CachedFunc(NodeBase):
+class CachedFunc(Object):
     """Low-level tensor function to back a relay primitive function.
     """
 
 
 @register_relay_node
-class CCacheKey(NodeBase):
+class CCacheKey(Object):
     """Key in the CompileEngine.
 
     Parameters
@@ -46,7 +46,7 @@ class CCacheKey(NodeBase):
 
 
 @register_relay_node
-class CCacheValue(NodeBase):
+class CCacheValue(Object):
     """Value in the CompileEngine, including usage statistics.
     """
 
@@ -64,7 +64,7 @@ def _get_cache_key(source_func, target):
 
 
 @register_relay_node
-class CompileEngine(NodeBase):
+class CompileEngine(Object):
     """CompileEngine to get lowered code.
     """
     def __init__(self):
index 1d53f6a..59d9a8f 100644 (file)
@@ -23,27 +23,13 @@ import numpy as np
 from . import _backend
 from .. import _make, analysis, transform
 from .. import module
-from ... import register_func, nd
-from ..base import NodeBase, register_relay_node
+from ... import nd
+from ..base import Object, register_relay_node
 from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
 from ..scope_builder import ScopeBuilder
-from . import _vm
-
-class Value(NodeBase):
-    """Base class of all values.
-    """
-    @staticmethod
-    @register_func("relay.from_scalar")
-    def from_scalar(value, dtype=None):
-        """Convert a Python scalar to a Relay scalar."""
-        return TensorValue(const(value, dtype).data)
-
-    def to_vm(self):
-        return _vm._ValueToVM(self)
-
 
 @register_relay_node
-class TupleValue(Value):
+class TupleValue(Object):
     """A tuple value produced by the interpreter."""
     def __init__(self, *fields):
         self.__init_handle_by_constructor__(
@@ -68,60 +54,32 @@ class TupleValue(Value):
 
 
 @register_relay_node
-class Closure(Value):
+class Closure(Object):
     """A closure produced by the interpreter."""
 
 
 @register_relay_node
-class RecClosure(Value):
+class RecClosure(Object):
     """A recursive closure produced by the interpreter."""
 
 
 @register_relay_node
-class ConstructorValue(Value):
+class ConstructorValue(Object):
     def __init__(self, tag, fields, constructor):
         self.__init_handle_by_constructor__(
             _make.ConstructorValue, tag, fields, constructor)
 
 
 @register_relay_node
-class TensorValue(Value):
-    """A Tensor value produced by the interpreter."""
-
-    def __init__(self, data):
-        """Allocate a new TensorValue and copy the data from `array` into
-           the new array.
-        """
-        if isinstance(data, np.ndarray):
-            data = nd.array(data)
-
-        self.__init_handle_by_constructor__(
-            _make.TensorValue, data)
-
-    def asnumpy(self):
-        """Convert a Relay TensorValue into a numpy.ndarray."""
-        return self.data.asnumpy()
-
-    def __eq__(self, other):
-        return self.data == other.data
-
-    def __repr__(self):
-        return repr(self.data)
-
-    def __str__(self):
-        return str(self.data)
-
-
-@register_relay_node
-class RefValue(Value):
+class RefValue(Object):
     def __init__(self, value):
         self.__init_handle_by_constructor__(
             _make.RefValue, value)
 
 
 def _arg_to_ast(mod, arg):
-    if isinstance(arg, TensorValue):
-        return Constant(arg.data.copyto(nd.cpu(0)))
+    if isinstance(arg, nd.NDArray):
+        return Constant(arg.copyto(nd.cpu(0)))
     elif isinstance(arg, TupleValue):
         return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
     elif isinstance(arg, tuple):
@@ -231,7 +189,7 @@ class Executor(object):
 
         Returns
         -------
-        val : Union[function, Value]
+        val : Union[function, Object]
             The evaluation result.
         """
         if binds:
index bad4ac2..aba55ef 100644 (file)
@@ -31,16 +31,18 @@ from . import _vm
 from . import vmobj as _obj
 from .interpreter import Executor
 
-Tensor = _obj.Tensor
 ADT = _obj.ADT
 
 def _convert(arg, cargs):
     if isinstance(arg, _expr.Constant):
-        cargs.append(_obj.Tensor(arg.data))
+        cargs.append(arg.data)
     elif isinstance(arg, _obj.Object):
         cargs.append(arg)
-    elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
-        cargs.append(_obj.Tensor(arg))
+    elif isinstance(arg, np.ndarray):
+        nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0))
+        cargs.append(nd_arr)
+    elif isinstance(arg, tvm.nd.NDArray):
+        cargs.append(arg)
     elif isinstance(arg, (tuple, list)):
         field_args = []
         for field in arg:
@@ -48,7 +50,7 @@ def _convert(arg, cargs):
         cargs.append(_obj.tuple_object(field_args))
     elif isinstance(arg, (_base.numeric_types, bool)):
         dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
-        value = _obj.Tensor(np.array(arg, dtype=dtype))
+        value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0))
         cargs.append(value)
     else:
         raise TypeError("Unsupported type: %s" % (type(arg)))
index f3fdb76..330257f 100644 (file)
 # under the License.
 """TVM Runtime Object API."""
 from __future__ import absolute_import as _abs
-import numpy as _np
 
 from tvm._ffi.object import Object, register_object, getitem_helper
 from tvm import ndarray as _nd
 from . import _vmobj
 
 
-@register_object("vm.Tensor")
-class Tensor(Object):
-    """Tensor object.
-
-    Parameters
-    ----------
-    arr : numpy.ndarray or tvm.nd.NDArray
-        The source array.
-
-    ctx :  TVMContext, optional
-        The device context to create the array
-    """
-    def __init__(self, arr, ctx=None):
-        if isinstance(arr, _np.ndarray):
-            ctx = ctx if ctx else _nd.cpu(0)
-            self.__init_handle_by_constructor__(
-                _vmobj.Tensor, _nd.array(arr, ctx=ctx))
-        elif isinstance(arr, _nd.NDArray):
-            self.__init_handle_by_constructor__(
-                _vmobj.Tensor, arr)
-        else:
-            raise RuntimeError("Unsupported type for tensor object.")
-
-    @property
-    def data(self):
-        return _vmobj.GetTensorData(self)
-
-    def asnumpy(self):
-        """Convert data to numpy array
-
-        Returns
-        -------
-        np_arr : numpy.ndarray
-            The corresponding numpy array.
-        """
-        return self.data.asnumpy()
-
-
 @register_object("vm.ADT")
 class ADT(Object):
     """Algebatic data type(ADT) object.
@@ -75,7 +36,8 @@ class ADT(Object):
     """
     def __init__(self, tag, fields):
         for f in fields:
-            assert isinstance(f, Object)
+            assert isinstance(f, (Object, _nd.NDArray)), "Expect object or "
+            "tvm NDArray type, but received : {0}".format(type(f))
         self.__init_handle_by_constructor__(
             _vmobj.ADT, tag, *fields)
 
@@ -105,5 +67,6 @@ def tuple_object(fields):
         The created object.
     """
     for f in fields:
-        assert isinstance(f, Object)
+        assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm "
+        "NDArray type, but received : {0}".format(type(f))
     return _vmobj.Tuple(*fields)
index eb604a4..d389803 100644 (file)
 # pylint: disable=no-else-return, unidiomatic-typecheck
 """The base node types for the Relay language."""
 from __future__ import absolute_import as _abs
-from .._ffi.node import NodeBase, register_node as _register_tvm_node
+from .._ffi.object import register_object as _register_tvm_node
+from .._ffi.object import Object
 from . import _make
 from . import _expr
 from . import _base
 
-NodeBase = NodeBase
+Object = Object
 
 def register_relay_node(type_key=None):
     """Register a Relay node type.
@@ -52,7 +53,7 @@ def register_relay_attr_node(type_key=None):
     return _register_tvm_node(type_key)
 
 
-class RelayNode(NodeBase):
+class RelayNode(Object):
     """Base class of all Relay nodes."""
     def astext(self, show_meta_data=True, annotate=None):
         """Get the text format of the expression.
@@ -102,7 +103,7 @@ class SourceName(RelayNode):
         self.__init_handle_by_constructor__(_make.SourceName, name)
 
 @register_relay_node
-class Id(NodeBase):
+class Id(Object):
     """Unique identifier(name) used in Var.
        Guaranteed to be stable across all passes.
     """
index d264e99..d2d0172 100644 (file)
 
 from typing import List
 import tvm
-from .base import Span, NodeBase
+from .base import Span, Object
 from .ty import Type, TypeParam
 from ._analysis import _get_checked_type
 
 
-class Expr(NodeBase):
+class Expr(Object):
     def checked_type(self):
         ...
 
index ac5387c..a9d877c 100644 (file)
@@ -22,7 +22,7 @@ from ._calibrate import calibrate
 from .. import expr as _expr
 from .. import transform as _transform
 from ... import make as _make
-from ..base import NodeBase, register_relay_node
+from ..base import Object, register_relay_node
 
 
 class QAnnotateKind(object):
@@ -53,7 +53,7 @@ def _forward_op(ref_call, args):
 
 
 @register_relay_node("relay.quantize.QConfig")
-class QConfig(NodeBase):
+class QConfig(Object):
     """Configure the quantization behavior by setting config variables.
 
     Note
index d7b5992..1edb27a 100644 (file)
@@ -32,15 +32,16 @@ OUTPUT_VAR_NAME = '_py_out'
 #     import numpy
 #     import tvm
 #     from tvm import relay
-#     from tvm.relay.backend.interpreter import RefValue, TupleValue, TensorValue, ConstructorValue
+#     from tvm import nd
+#     from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue
 PROLOGUE = [
     ast.Import([alias('numpy', None)]),
     ast.Import([alias('tvm', None)]),
     ast.ImportFrom('tvm', [alias('relay', None)], 0),
+    ast.ImportFrom('tvm', [alias('nd', None)], 0),
     ast.ImportFrom('tvm.relay.backend.interpreter',
                    [alias('RefValue', None),
                     alias('TupleValue', None),
-                    alias('TensorValue', None),
                     alias('ConstructorValue', None)],
                    0)
 ]
@@ -245,7 +246,7 @@ class PythonConverter(ExprFunctor):
                a tensor or tuple (returns list of inputs to the lowered op call)"""
             # equivalent: input.data
             if isinstance(arg_type, relay.TensorType):
-                return [ast.Attribute(py_input, 'data', Load())]
+                return [py_input]
             assert isinstance(arg_type, relay.TupleType)
             # convert each input.fields[i]
             ret = []
@@ -265,15 +266,13 @@ class PythonConverter(ExprFunctor):
                 output_var_name = self.generate_var_name('_out')
                 output_var = Name(output_var_name, Load())
                 shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load())
-                # create a new TensorValue of the right shape and dtype
+                # create a new NDArray of the right shape and dtype
                 assign_output = Assign(
                     [Name(output_var_name, Store())],
-                    self.create_call('TensorValue', [
+                    self.create_call('nd.array', [
                         self.create_call('numpy.empty', [shape, Str(ret_type.dtype)])
                     ]))
-                # we pass the data field as an argument
-                extra_arg = ast.Attribute(output_var, 'data', Load())
-                return ([assign_output], [extra_arg], output_var)
+                return ([assign_output], [output_var], output_var)
             assert isinstance(ret_type, relay.TupleType)
             assignments = []
             extra_args = []
@@ -459,7 +458,7 @@ class PythonConverter(ExprFunctor):
         true_body, true_defs = self.visit(if_block.true_branch)
         false_body, false_defs = self.visit(if_block.false_branch)
 
-        # need to get the value out of a TensorValue to check the condition
+        # need to get the value out of a NDArray to check the condition
         # equvialent to: val.asnumpy()
         cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], [])
         ret = ast.IfExp(cond_check, true_body, false_body)
@@ -474,7 +473,7 @@ class PythonConverter(ExprFunctor):
         const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()),
                               [self.parse_numpy_array(value)],
                               [ast.keyword('dtype', Str(constant.checked_type.dtype))])
-        return (self.create_call('TensorValue', [const_expr]), [])
+        return (self.create_call('nd.array', [const_expr]), [])
 
 
     def visit_function(self, func: Expr):
index 343e899..2c466b0 100644 (file)
 # under the License.
 
 import tvm
-from .base import NodeBase
+from .base import Object
 
 
-class PassContext(NodeBase):
+class PassContext(Object):
     def __init__(self):
         ...
 
-class PassInfo(NodeBase):
+class PassInfo(Object):
     name = ...  # type: str
     opt_level = ... # type: int
     required = ... # type: list
@@ -32,7 +32,7 @@ class PassInfo(NodeBase):
         # type: (str, int, list) -> None
 
 
-class Pass(NodeBase):
+class Pass(Object):
     def __init__(self):
         ...
 
index 5a7ecff..cde8511 100644 (file)
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """The type nodes of the Relay language."""
 from enum import IntEnum
-from .base import NodeBase, register_relay_node
+from .base import Object, register_relay_node
 from . import _make
 
 
-class Type(NodeBase):
+class Type(Object):
     """The base type for all Relay types."""
 
     def __eq__(self, other):
index 6b577c4..c8fcd7c 100644 (file)
@@ -17,8 +17,8 @@
 """The computation schedule api of TVM."""
 from __future__ import absolute_import as _abs
 from ._ffi.base import string_types
-from ._ffi.node import NodeBase, register_node
-from ._ffi.node import convert_to_node as _convert_to_node
+from ._ffi.object import Object, register_object
+from ._ffi.object import convert_to_object as _convert_to_object
 from ._ffi.function import _init_api, Function
 from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
 from . import _api_internal
@@ -27,7 +27,7 @@ from . import expr as _expr
 from . import container as _container
 
 def convert(value):
-    """Convert value to TVM node or function.
+    """Convert value to TVM object or function.
 
     Parameters
     ----------
@@ -35,19 +35,19 @@ def convert(value):
 
     Returns
     -------
-    tvm_val : Node or Function
+    tvm_val : Object or Function
         Converted value in TVM
     """
-    if isinstance(value, (Function, NodeBase)):
+    if isinstance(value, (Function, Object)):
         return value
 
     if callable(value):
         return _convert_tvm_func(value)
 
-    return _convert_to_node(value)
+    return _convert_to_object(value)
 
-@register_node
-class Buffer(NodeBase):
+@register_object
+class Buffer(Object):
     """Symbolic data buffer in TVM.
 
     Buffer provide a way to represent data layout
@@ -156,23 +156,23 @@ class Buffer(NodeBase):
         return _api_internal._BufferVStore(self, begin, value)
 
 
-@register_node
-class Split(NodeBase):
+@register_object
+class Split(Object):
     """Split operation on axis."""
 
 
-@register_node
-class Fuse(NodeBase):
+@register_object
+class Fuse(Object):
     """Fuse operation on axis."""
 
 
-@register_node
-class Singleton(NodeBase):
+@register_object
+class Singleton(Object):
     """Singleton axis."""
 
 
-@register_node
-class IterVar(NodeBase, _expr.ExprOp):
+@register_object
+class IterVar(Object, _expr.ExprOp):
     """Represent iteration variable.
 
     IterVar is normally created by Operation, to represent
@@ -214,8 +214,8 @@ def create_schedule(ops):
     return _api_internal._CreateSchedule(ops)
 
 
-@register_node
-class Schedule(NodeBase):
+@register_object
+class Schedule(Object):
     """Schedule for all the stages."""
     def __getitem__(self, k):
         if isinstance(k, _tensor.Tensor):
@@ -348,8 +348,8 @@ class Schedule(NodeBase):
         return factored[0] if len(factored) == 1 else factored
 
 
-@register_node
-class Stage(NodeBase):
+@register_object
+class Stage(Object):
     """A Stage represents schedule for one operation."""
     def split(self, parent, factor=None, nparts=None):
         """Split the stage either by factor providing outer scope, or both
index 64628d1..6b87fcb 100644 (file)
@@ -30,14 +30,14 @@ Each statement node have subfields that can be visited from python side.
     assert(st.buffer_var == a)
 """
 from __future__ import absolute_import as _abs
-from ._ffi.node import NodeBase, register_node
+from ._ffi.object import Object, register_object
 from . import make as _make
 
 
-class Stmt(NodeBase):
+class Stmt(Object):
     pass
 
-@register_node
+@register_object
 class LetStmt(Stmt):
     """LetStmt node.
 
@@ -57,7 +57,7 @@ class LetStmt(Stmt):
             _make.LetStmt, var, value, body)
 
 
-@register_node
+@register_object
 class AssertStmt(Stmt):
     """AssertStmt node.
 
@@ -77,7 +77,7 @@ class AssertStmt(Stmt):
             _make.AssertStmt, condition, message, body)
 
 
-@register_node
+@register_object
 class ProducerConsumer(Stmt):
     """ProducerConsumer node.
 
@@ -97,7 +97,7 @@ class ProducerConsumer(Stmt):
             _make.ProducerConsumer, func, is_producer, body)
 
 
-@register_node
+@register_object
 class For(Stmt):
     """For node.
 
@@ -137,7 +137,7 @@ class For(Stmt):
             for_type, device_api, body)
 
 
-@register_node
+@register_object
 class Store(Stmt):
     """Store node.
 
@@ -160,7 +160,7 @@ class Store(Stmt):
             _make.Store, buffer_var, value, index, predicate)
 
 
-@register_node
+@register_object
 class Provide(Stmt):
     """Provide node.
 
@@ -183,7 +183,7 @@ class Provide(Stmt):
             _make.Provide, func, value_index, value, args)
 
 
-@register_node
+@register_object
 class Allocate(Stmt):
     """Allocate node.
 
@@ -215,7 +215,7 @@ class Allocate(Stmt):
             extents, condition, body)
 
 
-@register_node
+@register_object
 class AttrStmt(Stmt):
     """AttrStmt node.
 
@@ -238,7 +238,7 @@ class AttrStmt(Stmt):
             _make.AttrStmt, node, attr_key, value, body)
 
 
-@register_node
+@register_object
 class Free(Stmt):
     """Free node.
 
@@ -252,7 +252,7 @@ class Free(Stmt):
             _make.Free, buffer_var)
 
 
-@register_node
+@register_object
 class Realize(Stmt):
     """Realize node.
 
@@ -288,7 +288,7 @@ class Realize(Stmt):
             bounds, condition, body)
 
 
-@register_node
+@register_object
 class SeqStmt(Stmt):
     """Sequence of statements.
 
@@ -308,7 +308,7 @@ class SeqStmt(Stmt):
         return len(self.seq)
 
 
-@register_node
+@register_object
 class IfThenElse(Stmt):
     """IfThenElse node.
 
@@ -328,7 +328,7 @@ class IfThenElse(Stmt):
             _make.IfThenElse, condition, then_case, else_case)
 
 
-@register_node
+@register_object
 class Evaluate(Stmt):
     """Evaluate node.
 
@@ -342,7 +342,7 @@ class Evaluate(Stmt):
             _make.Evaluate, value)
 
 
-@register_node
+@register_object
 class Prefetch(Stmt):
     """Prefetch node.
 
index afddd5f..c2d3752 100644 (file)
@@ -59,7 +59,7 @@ from __future__ import absolute_import
 import warnings
 
 from ._ffi.base import _LIB_NAME
-from ._ffi.node import NodeBase, register_node
+from ._ffi.object import Object, register_object
 from . import _api_internal
 
 try:
@@ -80,8 +80,8 @@ def _merge_opts(opts, new_opts):
     return opts
 
 
-@register_node
-class Target(NodeBase):
+@register_object
+class Target(Object):
     """Target device information, use through TVM API.
 
     Note
@@ -97,7 +97,7 @@ class Target(NodeBase):
     """
     def __new__(cls):
         # Always override new to enable class
-        obj = NodeBase.__new__(cls)
+        obj = Object.__new__(cls)
         obj._keys = None
         obj._options = None
         obj._libs = None
@@ -146,8 +146,8 @@ class Target(NodeBase):
         _api_internal._ExitTargetScope(self)
 
 
-@register_node
-class GenericFunc(NodeBase):
+@register_object
+class GenericFunc(Object):
     """GenericFunc node reference. This represents a generic function
     that may be specialized for different targets. When this object is
     called, a specialization is chosen based on the current target.
index e4a2f4f..e4c36c1 100644 (file)
 """Tensor and Operation class for computation declaration."""
 # pylint: disable=invalid-name
 from __future__ import absolute_import as _abs
-from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node
+from ._ffi.object import Object, register_object, ObjectGeneric, \
+        convert_to_object
 from . import _api_internal
 from . import make as _make
 from . import expr as _expr
 
 
-class TensorSlice(NodeGeneric, _expr.ExprOp):
+class TensorSlice(ObjectGeneric, _expr.ExprOp):
     """Auxiliary data structure for enable slicing syntax from tensor."""
 
     def __init__(self, tensor, indices):
@@ -37,8 +38,8 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
             indices = (indices,)
         return TensorSlice(self.tensor, self.indices + indices)
 
-    def asnode(self):
-        """Convert slice to node."""
+    def asobject(self):
+        """Convert slice to object."""
         return self.tensor(*self.indices)
 
     @property
@@ -46,23 +47,23 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
         """Data content of the tensor."""
         return self.tensor.dtype
 
-@register_node
-class TensorIntrinCall(NodeBase):
+@register_object
+class TensorIntrinCall(Object):
     """Intermediate structure for calling a tensor intrinsic."""
 
 
 itervar_cls = None
 
 
-@register_node
-class Tensor(NodeBase, _expr.ExprOp):
+@register_object
+class Tensor(Object, _expr.ExprOp):
     """Tensor object, to construct, see function.Tensor"""
 
     def __call__(self, *indices):
         ndim = self.ndim
         if len(indices) != ndim:
             raise ValueError("Need to provide %d index in tensor slice" % ndim)
-        indices = convert_to_node(indices)
+        indices = convert_to_object(indices)
         args = []
         for x in indices:
             if isinstance(x, _expr.PrimExpr):
@@ -127,7 +128,7 @@ class Tensor(NodeBase, _expr.ExprOp):
 
 
 
-class Operation(NodeBase):
+class Operation(Object):
     """Represent an operation that generates a tensor"""
 
     def output(self, index):
@@ -156,12 +157,12 @@ class Operation(NodeBase):
         return _api_internal._OpInputTensors(self)
 
 
-@register_node
+@register_object
 class PlaceholderOp(Operation):
     """Placeholder operation."""
 
 
-@register_node
+@register_object
 class BaseComputeOp(Operation):
     """Compute operation."""
     @property
@@ -175,18 +176,18 @@ class BaseComputeOp(Operation):
         return self.__getattr__("reduce_axis")
 
 
-@register_node
+@register_object
 class ComputeOp(BaseComputeOp):
     """Scalar operation."""
     pass
 
 
-@register_node
+@register_object
 class TensorComputeOp(BaseComputeOp):
     """Tensor operation."""
 
 
-@register_node
+@register_object
 class ScanOp(Operation):
     """Scan operation."""
     @property
@@ -195,12 +196,12 @@ class ScanOp(Operation):
         return self.__getattr__("scan_axis")
 
 
-@register_node
+@register_object
 class ExternOp(Operation):
     """External operation."""
 
 
-@register_node
+@register_object
 class HybridOp(Operation):
     """Hybrid operation."""
     @property
@@ -209,8 +210,8 @@ class HybridOp(Operation):
         return self.__getattr__("axis")
 
 
-@register_node
-class Layout(NodeBase):
+@register_object
+class Layout(Object):
     """Layout is composed of upper cases, lower cases and numbers,
     where upper case indicates a primal axis and
     the corresponding lower case with factor size indicates the subordinate axis.
@@ -269,8 +270,8 @@ class Layout(NodeBase):
         return _api_internal._LayoutFactorOf(self, axis)
 
 
-@register_node
-class BijectiveLayout(NodeBase):
+@register_object
+class BijectiveLayout(Object):
     """Bijective mapping for two layouts (src-layout and dst-layout).
     It provides shape and index conversion between each other.
 
index 378cfe5..4665ccf 100644 (file)
@@ -24,7 +24,7 @@ from . import make as _make
 from . import tensor as _tensor
 from . import schedule as _schedule
 from .build_module import current_build_config
-from ._ffi.node import NodeBase, register_node
+from ._ffi.object import Object, register_object
 
 
 def _get_region(tslice):
@@ -41,8 +41,8 @@ def _get_region(tslice):
             region.append(_make.range_by_min_extent(begin, 1))
     return region
 
-@register_node
-class TensorIntrin(NodeBase):
+@register_object
+class TensorIntrin(Object):
     """Tensor intrinsic functions for certain computation.
 
     See Also
index c1e4fd5..432ad29 100644 (file)
@@ -43,8 +43,8 @@ inline const PackedFunc& GetPackedFunc(const std::string& name) {
   return *pf;
 }
 
-/* Value Implementation */
-Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
+/* Object Implementation */
+Closure ClosureNode::make(tvm::Map<Var, ObjectRef> env, Function func) {
   ObjectPtr<ClosureNode> n = make_object<ClosureNode>();
   n->env = std::move(env);
   n->func = std::move(func);
@@ -62,7 +62,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 
 
 // TODO(@jroesch): this doesn't support mutual letrec
-/* Value Implementation */
+/* Object Implementation */
 RecClosure RecClosureNode::make(Closure clos, Var bind) {
   ObjectPtr<RecClosureNode> n = make_object<RecClosureNode>();
   n->clos = std::move(clos);
@@ -79,7 +79,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
     p->stream << "RecClosureNode(" << node->clos << ")";
   });
 
-TupleValue TupleValueNode::make(tvm::Array<Value> value) {
+TupleValue TupleValueNode::make(tvm::Array<ObjectRef> value) {
   ObjectPtr<TupleValueNode> n = make_object<TupleValueNode>();
   n->fields = value;
   return TupleValue(n);
@@ -94,24 +94,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
     p->stream << "TupleValueNode(" << node->fields << ")";
   });
 
-TensorValue TensorValueNode::make(runtime::NDArray data) {
-  ObjectPtr<TensorValueNode> n = make_object<TensorValueNode>();
-  n->data = std::move(data);
-  return TensorValue(n);
-}
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorValueNode>([](const ObjectRef& ref, NodePrinter* p) {
-    auto* node = static_cast<const TensorValueNode*>(ref.get());
-    auto to_str = GetPackedFunc("relay._tensor_value_repr");
-    std::string data_str = to_str(GetRef<TensorValue>(node));
-    p->stream << "TensorValueNode(" << data_str << ")";
-  });
-
-TVM_REGISTER_GLOBAL("relay._make.TensorValue")
-.set_body_typed(TensorValueNode::make);
 
-RefValue RefValueNode::make(Value value) {
+RefValue RefValueNode::make(ObjectRef value) {
   ObjectPtr<RefValueNode> n = make_object<RefValueNode>();
   n->value = value;
   return RefValue(n);
@@ -129,7 +113,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 ConstructorValue ConstructorValueNode::make(int32_t tag,
-                                            tvm::Array<Value> fields,
+                                            tvm::Array<ObjectRef> fields,
                                             Constructor constructor) {
   ObjectPtr<ConstructorValueNode> n = make_object<ConstructorValueNode>();
   n->tag = tag;
@@ -153,13 +137,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 /*!
  * \brief A stack frame in the Relay interpreter.
  *
- * Contains a mapping from relay::Var to relay::Value.
+ * Contains a mapping from relay::Var to relay::ObjectRef.
  */
 struct Frame {
   /*! \brief The set of local variables and arguments for the frame. */
-  tvm::Map<Var, Value> locals;
+  tvm::Map<Var, ObjectRef> locals;
 
-  explicit Frame(tvm::Map<Var, Value> locals) : locals(locals) {}
+  explicit Frame(tvm::Map<Var, ObjectRef> locals) : locals(locals) {}
 };
 
 /*!
@@ -175,7 +159,7 @@ struct Stack {
 
   Frame& current_frame() { return frames.back(); }
 
-  Value Lookup(const Var& local) {
+  ObjectRef Lookup(const Var& local) {
     for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) {
       auto elem = frame->locals.find(local);
       if (elem != frame->locals.end()) {
@@ -185,7 +169,7 @@ struct Stack {
 
     LOG(FATAL) << "could not find variable binding for " << local
                << "address= " << local.operator->();
-    return Value();
+    return ObjectRef();
   }
   /*!
    * A wrapper around Frame to add RAII semantics to pushing and popping
@@ -206,7 +190,7 @@ class InterpreterState;
 /*! \brief A container capturing the state of the interpreter. */
 class InterpreterStateNode : public Object {
  public:
-  using Frame = tvm::Map<Var, Value>;
+  using Frame = tvm::Map<Var, ObjectRef>;
   using Stack = tvm::Array<Frame>;
 
   /*! \brief The current expression under evaluation. */
@@ -246,8 +230,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
 //
 // Conversion to ANF is recommended before running the interpretation.
 class Interpreter :
-      public ExprFunctor<Value(const Expr& n)>,
-             PatternFunctor<bool(const Pattern& p, const Value& v)> {
+      public ExprFunctor<ObjectRef(const Expr& n)>,
+             PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
  public:
   Interpreter(Module mod, DLContext context, Target target)
       : mod_(mod),
@@ -264,56 +248,56 @@ class Interpreter :
     return f();
   }
 
-  void extend(const Var& id, Value v) {
+  void extend(const Var& id, ObjectRef v) {
     stack_.current_frame().locals.Set(id, v);
   }
 
-  Value Lookup(const Var& local) {
+  ObjectRef Lookup(const Var& local) {
     return stack_.Lookup(local);
   }
 
-  Value Eval(const Expr& expr) {
+  ObjectRef Eval(const Expr& expr) {
     return VisitExpr(expr);
   }
 
-  Value VisitExpr(const Expr& expr) final {
-    auto ret = ExprFunctor<Value(const Expr& n)>::VisitExpr(expr);
+  ObjectRef VisitExpr(const Expr& expr) final {
+    auto ret = ExprFunctor<ObjectRef(const Expr& n)>::VisitExpr(expr);
     return ret;
   }
 
-  Value VisitExpr_(const VarNode* var_node) final {
+  ObjectRef VisitExpr_(const VarNode* var_node) final {
     return Lookup(GetRef<Var>(var_node));
   }
 
-  Value VisitExpr_(const GlobalVarNode* op) final {
+  ObjectRef VisitExpr_(const GlobalVarNode* op) final {
     return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
   }
 
-  Value VisitExpr_(const OpNode* id) override {
+  ObjectRef VisitExpr_(const OpNode* id) override {
     // TODO(@jroesch): Eta-expand and return in this case.
     LOG(FATAL) << "internal error, need to wrap intrinsic into call synthetic call node "
                << "in "
                << "this case, eta expand";
-    return Value();
+    return ObjectRef();
   }
 
-  Value VisitExpr_(const ConstantNode* op) final {
-    return TensorValueNode::make(op->data.CopyTo(context_));
+  ObjectRef VisitExpr_(const ConstantNode* op) final {
+    return op->data.CopyTo(context_);
   }
 
-  Value VisitExpr_(const TupleNode* op) final {
-    std::vector<Value> values;
+  ObjectRef VisitExpr_(const TupleNode* op) final {
+    std::vector<ObjectRef> values;
 
     for (const auto& field : op->fields) {
-      Value field_value = Eval(field);
+      ObjectRef field_value = Eval(field);
       values.push_back(field_value);
     }
 
     return TupleValueNode::make(values);
   }
 
-  Value MakeClosure(const Function& func, Var letrec_name = Var()) {
-    tvm::Map<Var, Value> captured_mod;
+  ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) {
+    tvm::Map<Var, ObjectRef> captured_mod;
     Array<Var> free_vars = FreeVars(func);
 
     for (const auto& var : free_vars) {
@@ -334,13 +318,13 @@ class Interpreter :
     return std::move(closure);
   }
 
-  Value VisitExpr_(const FunctionNode* func_node) final {
+  ObjectRef VisitExpr_(const FunctionNode* func_node) final {
     auto func = GetRef<Function>(func_node);
     return MakeClosure(func);
   }
 
   Array<Shape> ComputeDynamicShape(const Function& func,
-                                   const Array<Value>& args) {
+                                   const Array<ObjectRef>& args) {
     auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
     auto cfunc = engine_->LowerShapeFunc(key);
     size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
@@ -355,11 +339,10 @@ class Interpreter :
     cpu_ctx.device_type = kDLCPU;
     cpu_ctx.device_id = 0;
 
-    auto fset_input = [&](size_t i, Value val, bool need_shape) {
-        const TensorValueNode* tv = val.as<TensorValueNode>();
-        CHECK(tv != nullptr) << "expect Tensor argument";
+    auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) {
+        auto nd_array = Downcast<NDArray>(val);
         if (need_shape) {
-          int64_t ndim = tv->data.Shape().size();
+          int64_t ndim = nd_array.Shape().size();
           NDArray shape_arr;
           if (ndim == 0) {
             shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx);
@@ -367,13 +350,13 @@ class Interpreter :
             shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx);
             int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
             for (auto j = 0; j < ndim; ++j) {
-              data[j] = tv->data.Shape()[j];
+              data[j] = nd_array.Shape()[j];
             }
           }
           inputs[i] = shape_arr;
           setter(i, shape_arr);
         } else {
-          auto arr = tv->data.CopyTo(cpu_ctx);
+          auto arr = nd_array.CopyTo(cpu_ctx);
           inputs[i] = arr;
           setter(i, arr);
         }
@@ -384,7 +367,7 @@ class Interpreter :
       auto arg = args[i];
       auto param = func->params[i];
       int state = cfunc->shape_func_param_states[i]->value;
-      if (arg.as<TensorValueNode>()) {
+      if (arg->IsInstance<runtime::NDArray::ContainerType>()) {
         if (state & kNeedInputData) {
           fset_input(arg_counter++, arg, false);
         }
@@ -457,8 +440,8 @@ class Interpreter :
     return out_shapes;
   }
 
-  Value InvokePrimitiveOp(const Function& func,
-                          const Array<Value>& args) {
+  ObjectRef InvokePrimitiveOp(const Function& func,
+                          const Array<ObjectRef>& args) {
     const auto* call_node = func->body.as<CallNode>();
 
     if (call_node && call_node->op == debug_op_) {
@@ -478,7 +461,7 @@ class Interpreter :
     // Handle tuple input/output by flattening them.
     size_t arg_len = 0;
     for (size_t i = 0; i < args.size(); ++i) {
-      if (args[i].as<TensorValueNode>()) {
+      if (args[i]->IsInstance<NDArray::ContainerType>()) {
         ++arg_len;
       } else {
         const auto* tvalue = args[i].as<TupleValueNode>();
@@ -497,11 +480,10 @@ class Interpreter :
     std::vector<int> codes(arg_len);
     TVMArgsSetter setter(values.data(), codes.data());
 
-    auto fset_input = [&](size_t i, Value val) {
-      const TensorValueNode* tv = val.as<TensorValueNode>();
-      CHECK(tv != nullptr) << "expect Tensor argument";
-      setter(i, tv->data);
-      DLContext arg_ctx = tv->data->ctx;
+    auto fset_input = [&](size_t i, ObjectRef val) {
+      const auto nd_array = Downcast<NDArray>(val);
+      setter(i, nd_array);
+      DLContext arg_ctx = nd_array->ctx;
       CHECK(arg_ctx.device_type ==  context_.device_type &&
             arg_ctx.device_id == context_.device_id)
         << "Interpreter expect context to be "
@@ -509,8 +491,8 @@ class Interpreter :
     };
 
     int arg_counter = 0;
-    for (Value arg : args) {
-      if (arg.as<TensorValueNode>()) {
+    for (ObjectRef arg : args) {
+      if (arg->IsInstance<NDArray::ContainerType>()) {
         fset_input(arg_counter++,  arg);
       } else {
         const TupleValueNode* tuple = arg.as<TupleValueNode>();
@@ -536,10 +518,9 @@ class Interpreter :
         shape.push_back(ivalue[0]);
       }
       DLDataType dtype = rtype->dtype;
-      auto out_tensor = TensorValueNode::make(
-          NDArray::Empty(shape, dtype, context_));
-      setter(num_inputs + i, out_tensor->data);
-      return out_tensor;
+      NDArray nd_array = NDArray::Empty(shape, dtype, context_);
+      setter(num_inputs + i, nd_array);
+      return nd_array;
     };
 
     Array<Shape> out_shapes;
@@ -560,7 +541,7 @@ class Interpreter :
     TVMRetValue rv;
     if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
       CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
-      Array<Value> fields;
+      Array<ObjectRef> fields;
       for (size_t i = 0; i < rtype->fields.size(); ++i) {
         if (is_dyn) {
           auto sh = out_shapes[i];
@@ -573,7 +554,7 @@ class Interpreter :
       packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
       return TupleValueNode::make(fields);
     } else {
-      Value out_tensor;
+      ObjectRef out_tensor;
       if (is_dyn) {
         CHECK_EQ(out_shapes.size(), 1);
         auto sh = out_shapes[0];
@@ -588,14 +569,16 @@ class Interpreter :
   }
 
   // Invoke the closure
-  Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
+  ObjectRef Invoke(const Closure& closure,
+                   const tvm::Array<ObjectRef>& args,
+                   const Var& bind = Var()) {
     // Get a reference to the function inside the closure.
     if (closure->func->IsPrimitive()) {
       return InvokePrimitiveOp(closure->func, args);
     }
     auto func = closure->func;
     // Allocate a frame with the parameters and free variables.
-    tvm::Map<Var, Value> locals;
+    tvm::Map<Var, ObjectRef> locals;
 
     CHECK_EQ(func->params.size(), args.size());
 
@@ -614,11 +597,11 @@ class Interpreter :
       locals.Set(bind, RecClosureNode::make(closure, bind));
     }
 
-    return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
+    return WithFrame<ObjectRef>(Frame(locals), [&]() { return Eval(func->body); });
   }
 
-  Value VisitExpr_(const CallNode* call) final {
-    tvm::Array<Value> args;
+  ObjectRef VisitExpr_(const CallNode* call) final {
+    tvm::Array<ObjectRef> args;
     for (auto arg : call->args) {
       args.push_back(Eval(arg));
     }
@@ -636,7 +619,7 @@ class Interpreter :
       return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
     }
     // Now we just evaluate and expect to find a closure.
-    Value fn_val = Eval(call->op);
+    ObjectRef fn_val = Eval(call->op);
     if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
       auto closure = GetRef<Closure>(closure_node);
       return this->Invoke(closure, args);
@@ -645,11 +628,11 @@ class Interpreter :
     } else {
       LOG(FATAL) << "internal error: type error, expected function value in the call "
                  << "position";
-      return Value();
+      return ObjectRef();
     }
   }
 
-  Value VisitExpr_(const LetNode* let) final {
+  ObjectRef VisitExpr_(const LetNode* let) final {
     if (auto func = let->value.as<FunctionNode>()) {
       auto clo = MakeClosure(GetRef<Function>(func), let->var);
       this->extend(let->var, clo);
@@ -661,8 +644,8 @@ class Interpreter :
     return Eval(let->body);
   }
 
-  Value VisitExpr_(const TupleGetItemNode* op) final {
-    Value val = Eval(op->tuple);
+  ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
+    ObjectRef val = Eval(op->tuple);
     auto product_node = val.as<TupleValueNode>();
     CHECK(product_node)
       << "interal error: when evaluating TupleGetItem expected a tuple value";
@@ -671,13 +654,14 @@ class Interpreter :
     return product_node->fields[op->index];
   }
 
-  Value VisitExpr_(const IfNode* op) final {
-    Value v = Eval(op->cond);
-    if (const TensorValueNode* bv = v.as<TensorValueNode>()) {
+  ObjectRef VisitExpr_(const IfNode* op) final {
+    ObjectRef v = Eval(op->cond);
+    if (v->IsInstance<NDArray::ContainerType>()) {
+      auto nd_array = Downcast<NDArray>(v);
       DLContext cpu_ctx;
       cpu_ctx.device_type = kDLCPU;
       cpu_ctx.device_id = 0;
-      NDArray cpu_array = bv->data.CopyTo(cpu_ctx);
+      NDArray cpu_array = nd_array.CopyTo(cpu_ctx);
       CHECK_EQ(DataType(cpu_array->dtype), DataType::Bool());
       // TODO(@jroesch, @MK): Refactor code into helper from DCE.
       if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
@@ -687,47 +671,47 @@ class Interpreter :
       }
     } else {
       LOG(FATAL) << "type error, type system should have caught this";
-      return Value();
+      return ObjectRef();
     }
   }
 
-  Value VisitExpr_(const RefWriteNode* op) final {
-    Value r = Eval(op->ref);
+  ObjectRef VisitExpr_(const RefWriteNode* op) final {
+    ObjectRef r = Eval(op->ref);
     if (const RefValueNode* rv = r.as<RefValueNode>()) {
       rv->value = Eval(op->value);
       return TupleValueNode::make({});
     } else {
       LOG(FATAL) << "type error, type system should have caught this";
-      return Value();
+      return ObjectRef();
     }
   }
 
-  Value VisitExpr_(const RefCreateNode* op) final {
+  ObjectRef VisitExpr_(const RefCreateNode* op) final {
     return RefValueNode::make(Eval(op->value));
   }
 
-  Value VisitExpr_(const RefReadNode* op) final {
-    Value r = Eval(op->ref);
+  ObjectRef VisitExpr_(const RefReadNode* op) final {
+    ObjectRef r = Eval(op->ref);
     if (const RefValueNode* rv = r.as<RefValueNode>()) {
       return rv->value;
     } else {
       LOG(FATAL) << "type error, type system should have caught this";
-      return Value();
+      return ObjectRef();
     }
   }
 
-  Value VisitExpr_(const MatchNode* op) final {
-    Value v = Eval(op->data);
+  ObjectRef VisitExpr_(const MatchNode* op) final {
+    ObjectRef v = Eval(op->data);
     for (const Clause& c : op->clauses) {
       if (VisitPattern(c->lhs, v)) {
         return VisitExpr(c->rhs);
       }
     }
     LOG(FATAL) << "did not find any match";
-    return Value();
+    return ObjectRef();
   }
 
-  bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final {
+  bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final {
     const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
     CHECK(cvn) << "need to be a constructor for match";
     CHECK_NE(op->constructor->tag, -1);
@@ -744,7 +728,7 @@ class Interpreter :
     return false;
   }
 
-  bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
+  bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final {
     const TupleValueNode* tvn = v.as<TupleValueNode>();
     CHECK(tvn) << "need to be a tuple for match";
     CHECK_EQ(op->patterns.size(), tvn->fields.size());
@@ -756,11 +740,11 @@ class Interpreter :
     return true;
   }
 
-  bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
+  bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final {
     return true;
   }
 
-  bool VisitPattern_(const PatternVarNode* op, const Value& v) final {
+  bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final {
     extend(op->var, v);
     return true;
   }
@@ -783,7 +767,7 @@ class Interpreter :
   DLContext context_;
   // Target parameter being used by the interpreter.
   Target target_;
-  // Value stack.
+  // Object stack.
   Stack stack_;
   // Backend compile engine.
   CompileEngine engine_;
@@ -793,7 +777,7 @@ class Interpreter :
 };
 
 
-TypedPackedFunc<Value(Expr)>
+TypedPackedFunc<ObjectRef(Expr)>
 CreateInterpreter(
     Module mod,
     DLContext context,
@@ -814,7 +798,7 @@ CreateInterpreter(
     CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
     return intrp->Eval(expr);
   };
-  return TypedPackedFunc<Value(Expr)>(packed);
+  return TypedPackedFunc<ObjectRef(Expr)>(packed);
 }
 
 TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter")
@@ -822,7 +806,6 @@ TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter")
 
 TVM_REGISTER_NODE_TYPE(ClosureNode);
 TVM_REGISTER_NODE_TYPE(TupleValueNode);
-TVM_REGISTER_NODE_TYPE(TensorValueNode);
 
 }  // namespace relay
 }  // namespace tvm
index 5d262a0..bb47685 100644 (file)
@@ -854,7 +854,7 @@ void VMCompiler::Lower(Module mod,
 
   // populate constants
   for (auto data : context_.constants) {
-    exec_->constants.push_back(vm::Tensor(data));
+    exec_->constants.push_back(data);
   }
 
   // update global function map
index d36733d..7f00c71 100644 (file)
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/transform.h>
-#include "./pattern_util.h"
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/ndarray.h>
+#include "pattern_util.h"
 
 namespace tvm {
 namespace relay {
 
-using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
+using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;
 
 class ConstantChecker : private ExprVisitor {
  public:
@@ -177,17 +179,18 @@ class ConstantFolder : public ExprMutator {
   const Op& cast_op_;
 
   // Convert value to expression.
-  Expr ValueToExpr(Value value) {
-    if (const auto* val = value.as<TensorValueNode>()) {
-      for (auto dim : val->data.Shape()) {
+  Expr ObjectToExpr(const ObjectRef& value) {
+    if (value->IsInstance<runtime::NDArray::ContainerType>()) {
+      auto nd_array = Downcast<runtime::NDArray>(value);
+      for (auto dim : nd_array.Shape()) {
         CHECK_GT(dim, 0)
           << "invalid dimension after constant eval";
       }
-      return ConstantNode::make(val->data);
+      return ConstantNode::make(nd_array);
     } else if (const auto* val = value.as<TupleValueNode>()) {
       Array<Expr> fields;
-      for (Value field : val->fields) {
-        fields.push_back(ValueToExpr(field));
+      for (ObjectRef field : val->fields) {
+        fields.push_back(ObjectToExpr(field));
       }
       return TupleNode::make(fields);
     } else {
@@ -216,7 +219,7 @@ class ConstantFolder : public ExprMutator {
     mod = seq(mod);
     auto entry_func = mod->Lookup("main");
     expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
-    return ValueToExpr(executor_(expr));
+    return ObjectToExpr(executor_(expr));
   }
 
   // Evaluate a call to the shape_of operator for tensors with constant
@@ -258,7 +261,7 @@ class ConstantFolder : public ExprMutator {
       }
     }
 
-    Constant shape = Downcast<Constant>(ValueToExpr(TensorValueNode::make(value)));
+    Constant shape = Downcast<Constant>(ObjectToExpr(value));
 
     if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
       auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx);
@@ -283,8 +286,7 @@ Expr FoldConstant(const Expr& expr, const Module& mod) {
   // in case we are already in a build context.
   With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
 
-  return ConstantFolder(CreateInterpreter(
-      mod, ctx, target), mod).Mutate(expr);
+  return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
 }
 
 namespace transform {
index a6b8671..b066803 100644 (file)
@@ -403,7 +403,7 @@ Fuel MkFTop() {
 /*!
  * \brief A stack frame in the Relay interpreter.
  *
- * Contains a mapping from relay::Var to relay::Value.
+ * Contains a mapping from relay::Var to relay::Object.
  */
 struct Frame {
   /*! \brief The set of local variables and arguments for the frame. */
@@ -554,7 +554,7 @@ bool StatefulOp(const Expr& e) {
   return sov.stateful;
 }
 
-using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
+using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;
 
 DLContext CPUContext() {
   DLContext ctx;
@@ -925,13 +925,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     }
   }
 
-  PStatic Reify(const Value& v, LetList* ll) const {
-    if (const TensorValueNode* op = v.as<TensorValueNode>()) {
-      return HasStatic(MkSTensor(op->data), ll->Push(ConstantNode::make(op->data)));
+  PStatic Reify(const ObjectRef& v, LetList* ll) const {
+    if (v->IsInstance<runtime::NDArray::ContainerType>()) {
+      auto nd_array = Downcast<runtime::NDArray>(v);
+      return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array)));
     } else if (const TupleValueNode* op = v.as<TupleValueNode>()) {
       std::vector<PStatic> fields;
       tvm::Array<Expr> fields_dyn;
-      for (const Value& field : op->fields) {
+      for (const ObjectRef& field : op->fields) {
         PStatic ps = Reify(field, ll);
         fields.push_back(ps);
         fields_dyn.push_back(ps->dynamic);
index 3714425..bd65066 100644 (file)
@@ -150,10 +150,8 @@ std::string Executable::Stats() const {
   // Get the number of constants and the shape of each of them.
   oss << "  Constant shapes (# " << constants.size() << "): [";
   for (const auto& it : constants) {
-    const auto* cell = it.as<TensorObj>();
-    CHECK(cell);
-    runtime::NDArray data = cell->data;
-    const auto& shape = data.Shape();
+    const auto constant = Downcast<NDArray>(it);
+    const auto& shape = constant.Shape();
 
     // Scalar
     if (shape.empty()) {
@@ -250,10 +248,8 @@ void Executable::SaveGlobalSection(dmlc::Stream* strm) {
 void Executable::SaveConstantSection(dmlc::Stream* strm) {
   std::vector<DLTensor*> arrays;
   for (const auto& obj : this->constants) {
-    const auto* cell = obj.as<runtime::vm::TensorObj>();
-    CHECK(cell != nullptr);
-    runtime::NDArray data = cell->data;
-    arrays.push_back(const_cast<DLTensor*>(data.operator->()));
+    const auto cell = Downcast<runtime::NDArray>(obj);
+    arrays.push_back(const_cast<DLTensor*>(cell.operator->()));
   }
   strm->Write(static_cast<uint64_t>(this->constants.size()));
   for (const auto& it : arrays) {
@@ -513,8 +509,7 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) {
   for (size_t i = 0; i < size; i++) {
     runtime::NDArray constant;
     STREAM_CHECK(constant.Load(strm), "constant");
-    runtime::ObjectRef obj = runtime::vm::Tensor(constant);
-    this->constants.push_back(obj);
+    this->constants.push_back(constant);
   }
 }
 
index 988ba5d..d7760d5 100644 (file)
@@ -34,12 +34,6 @@ namespace tvm {
 namespace runtime {
 namespace vm {
 
-Tensor::Tensor(NDArray data) {
-  auto ptr = make_object<TensorObj>();
-  ptr->data = std::move(data);
-  data_ = std::move(ptr);
-}
-
 Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
   auto ptr = make_object<ClosureObj>();
   ptr->func_index = func_index;
@@ -48,14 +42,6 @@ Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
 }
 
 
-TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-  ObjectRef obj = args[0];
-  const auto* cell = obj.as<TensorObj>();
-  CHECK(cell != nullptr);
-  *rv = cell->data;
-});
-
 TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   ObjectRef obj = args[0];
@@ -80,11 +66,6 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
   *rv = adt[idx];
 });
 
-TVM_REGISTER_GLOBAL("_vmobj.Tensor")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-*rv = Tensor(args[0].operator NDArray());
-});
-
 TVM_REGISTER_GLOBAL("_vmobj.Tuple")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   std::vector<ObjectRef> fields;
@@ -105,7 +86,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT")
   *rv = ADT(tag, fields);
 });
 
-TVM_REGISTER_OBJECT_TYPE(TensorObj);
 TVM_REGISTER_OBJECT_TYPE(ADTObj);
 TVM_REGISTER_OBJECT_TYPE(ClosureObj);
 }  // namespace vm
index 10b27d1..49aba7f 100644 (file)
@@ -613,18 +613,14 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
   return os;
 }
 
-ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
-  if (const TensorObj* obj = src.as<TensorObj>()) {
-    auto tensor = obj->data;
-    if (tensor->ctx.device_type != ctx.device_type) {
-      auto copy = tensor.CopyTo(ctx);
-      return Tensor(copy);
-    } else {
-      return src;
+inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
+  if (src->IsInstance<NDArray::ContainerType>()) {
+    auto nd_array = Downcast<NDArray>(src);
+    if (nd_array->ctx.device_type != ctx.device_type) {
+      return nd_array.CopyTo(ctx);
     }
-  } else {
-    return src;
   }
+  return src;
 }
 
 PackedFunc VirtualMachine::GetFunction(const std::string& name,
@@ -770,16 +766,12 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
     if (const auto* dt_cell = args[i].as<ADTObj>()) {
       for (size_t fi = 0; fi < dt_cell->size; ++fi) {
         auto obj = (*dt_cell)[fi];
-        const auto* tensor = obj.as<TensorObj>();
-        CHECK(tensor != nullptr) << "Expect tensor object, but received: "
-                                 << obj->GetTypeKey();
-        setter(idx++, tensor->data);
+        auto nd_array = Downcast<NDArray>(obj);
+        setter(idx++, nd_array);
       }
     } else {
-      const auto* tensor = args[i].as<TensorObj>();
-      CHECK(tensor != nullptr) << "Expect tensor object, but received: "
-                               << args[i]->GetTypeKey();
-      setter(idx++, tensor->data);
+      auto nd_array = Downcast<NDArray>(args[i]);
+      setter(idx++, nd_array);
     }
   }
 
@@ -824,10 +816,8 @@ inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
 inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
   int32_t result;
   const auto& obj = ReadRegister(r);
-  const auto* tensor = obj.as<TensorObj>();
-  CHECK(tensor != nullptr) << "Expect tensor object, but received: "
-                           << obj->GetTypeKey();
-  NDArray array = tensor->data.CopyTo({kDLCPU, 0});
+  auto nd_array = Downcast<NDArray>(obj);
+  NDArray array = nd_array.CopyTo({kDLCPU, 0});
 
   if (array->dtype.bits <= 8) {
     result = reinterpret_cast<int8_t*>(array->data)[0];
@@ -883,7 +873,7 @@ void VirtualMachine::RunLoop() {
       case Opcode::LoadConsti: {
         auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
         reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
-        WriteRegister(instr.dst, Tensor(tensor));
+        WriteRegister(instr.dst, tensor);
         pc_++;
         goto main_loop;
       }
@@ -943,7 +933,7 @@ void VirtualMachine::RunLoop() {
         auto tag = adt.tag();
         auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
         reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
-        WriteRegister(instr.dst, Tensor(tag_tensor));
+        WriteRegister(instr.dst, tag_tensor);
         pc_++;
         goto main_loop;
       }
@@ -974,9 +964,8 @@ void VirtualMachine::RunLoop() {
 
         auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
         auto storage = Downcast<Storage>(storage_obj);
-        auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);
+        auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);
 
-        auto obj = Tensor(data);
         WriteRegister(instr.dst, obj);
         pc_++;
         goto main_loop;
@@ -986,10 +975,8 @@ void VirtualMachine::RunLoop() {
         cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
         auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
-        const auto* tensor = shape_tensor_obj.as<TensorObj>();
-        CHECK(tensor != nullptr) << "Expect tensor object, but received: "
-                                 << shape_tensor_obj->GetTypeKey();
-        NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
+        const auto shape_arr = Downcast<NDArray>(shape_tensor_obj);
+        NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx);
         const DLTensor* dl_tensor = shape_tensor.operator->();
         CHECK_EQ(dl_tensor->dtype.code, 0u);
         CHECK_LE(dl_tensor->dtype.bits, 64);
@@ -1000,9 +987,8 @@ void VirtualMachine::RunLoop() {
 
         auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage);
         auto storage = Downcast<Storage>(storage_obj);
-        auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);
+        auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);
 
-        auto obj = Tensor(data);
         WriteRegister(instr.dst, obj);
         pc_++;
         goto main_loop;
index 612347d..e39c41e 100644 (file)
@@ -18,6 +18,7 @@
 import pytest
 import tensorflow as tf
 import numpy as np
+from tvm import nd
 from tvm import relay
 from tvm.relay.frontend.tensorflow import from_tensorflow
 
@@ -26,7 +27,7 @@ def check_equal(graph, tf_out):
     mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
     ex = relay.create_executor('vm', mod=mod)
     relay_out = ex.evaluate()(**params)
-    if isinstance(relay_out, relay.vmobj.Tensor):
+    if isinstance(relay_out, nd.NDArray):
         np.testing.assert_allclose(tf_out, relay_out.asnumpy())
     else:
         if not isinstance(tf_out, list):
index 97557d3..b394081 100644 (file)
@@ -60,7 +60,7 @@ tf_dtypes = {
 }
 
 def vmobj_to_list(o):
-    if isinstance(o, tvm.relay.backend.vmobj.Tensor):
+    if isinstance(o, tvm.nd.NDArray):
         return [o.asnumpy().tolist()]
     elif isinstance(o, tvm.relay.backend.vmobj.ADT):
         result = []
@@ -87,8 +87,6 @@ def vmobj_to_list(o):
         else:
             raise RuntimeError("Unknown object type: %s" %
                                o.constructor.name_hint)
-    elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
-        return [o.data.asnumpy()]
     else:
         raise RuntimeError("Unknown object type: %s" % type(o))
 
index c0185e4..8e304bd 100644 (file)
@@ -115,10 +115,8 @@ def tree_to_dict(t):
 
 
 def vmobj_to_list(o, dtype="float32"):
-    if isinstance(o, tvm.relay.backend.vmobj.Tensor):
+    if isinstance(o, tvm.nd.NDArray):
         return [o.asnumpy().tolist()]
-    elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
-        return [o.asnumpy()]
     elif isinstance(o, tvm.relay.backend.vmobj.ADT):
         if len(o) == 0:
             tensor_nil = p.get_var("tensor_nil", dtype=dtype)
index c1a19c4..85bba44 100644 (file)
@@ -17,8 +17,9 @@
 import numpy as np
 import tvm
 import tvm.testing
+from tvm import nd
 from tvm import relay
-from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue
+from tvm.relay.backend.interpreter import TupleValue
 from tvm.relay.backend.interpreter import RefValue, ConstructorValue
 from tvm.relay.scope_builder import ScopeBuilder
 from tvm.relay import testing, create_executor
@@ -37,18 +38,11 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
             result.asnumpy(), expected_result, rtol=rtol)
 
 
-def test_from_scalar():
-    np.testing.assert_allclose(Value.from_scalar(1, 'int32').asnumpy(), 1)
-    np.testing.assert_allclose(Value.from_scalar(10.0, 'float32').asnumpy(), 10.0)
-    np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True)
-
-
 def test_tuple_value():
-    tv = TupleValue(Value.from_scalar(
-        1), Value.from_scalar(2), Value.from_scalar(3))
-    np.testing.assert_allclose(tv[0].asnumpy(), 1)
-    np.testing.assert_allclose(tv[1].asnumpy(), 2)
-    np.testing.assert_allclose(tv[2].asnumpy(), 3)
+    tv = TupleValue(relay.const(1), relay.const(2), relay.const(3))
+    np.testing.assert_allclose(tv[0].data.asnumpy(), 1)
+    np.testing.assert_allclose(tv[1].data.asnumpy(), 2)
+    np.testing.assert_allclose(tv[2].data.asnumpy(), 3)
 
 
 def test_tuple_getitem():
@@ -158,12 +152,6 @@ def test_binds():
     tvm.testing.assert_allclose(xx + xx, res)
 
 
-def test_tensor_value():
-    x = relay.var("x", shape=(1, 10))
-    xx = np.ones((1, 10)).astype("float32")
-    check_eval(relay.Function([x], x), [TensorValue(xx)], xx)
-
-
 def test_kwargs_params():
     x = relay.var("x", shape=(1, 10))
     y = relay.var("y", shape=(1, 10))
@@ -174,7 +162,7 @@ def test_kwargs_params():
     z_data = np.random.rand(1, 10).astype('float32')
     params = { 'y': y_data, 'z': z_data }
     intrp = create_executor("debug")
-    res = intrp.evaluate(f)(x_data, **params).data
+    res = intrp.evaluate(f)(x_data, **params)
     tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
 
 
@@ -185,13 +173,13 @@ def test_function_taking_adt_ref_tuple():
 
     nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil)
     cons_value = ConstructorValue(prelude.cons.tag, [
-        TensorValue(np.random.rand(1, 10).astype('float32')),
+        nd.array(np.random.rand(1, 10).astype('float32')),
         nil_value
     ], prelude.cons)
 
-    ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
+    ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32')))
     tuple_value = TupleValue(*[
-        TensorValue(np.random.rand(1, 10).astype('float32')) for _ in range(10)
+        nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10)
     ])
 
     id_func = intrp.evaluate(prelude.id)
@@ -236,9 +224,7 @@ def test_tuple_passing():
     out = f((10, 8))
     tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
     # Second use a tuple value.
-    value_tuple = TupleValue(
-        TensorValue(np.array(11)),
-        TensorValue(np.array(12)))
+    value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12)))
     out = f(value_tuple)
     tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
 
@@ -252,7 +238,6 @@ if __name__ == "__main__":
     test_binds()
     test_kwargs_params()
     test_ref()
-    test_tensor_value()
     test_tuple_value()
     test_tuple_getitem()
     test_function_taking_adt_ref_tuple()
index 2a07e95..f87f90a 100644 (file)
@@ -19,7 +19,7 @@ import tvm
 from tvm import relay
 from tvm.relay.testing import to_python, run_as_python
 from tvm.relay.prelude import Prelude
-from tvm.relay.backend.interpreter import TensorValue, TupleValue, RefValue, ConstructorValue
+from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue
 
 # helper: uses a dummy let binding to sequence a list
 # of expressions: expr1; expr2; expr3, etc.
@@ -39,9 +39,9 @@ def init_box_adt(mod):
     return (box, box_ctor)
 
 
-# assert that the candidate is a TensorValue with value val
+# assert that the candidate is a NDArray with value val
 def assert_tensor_value(candidate, val):
-    assert isinstance(candidate, TensorValue)
+    assert isinstance(candidate, tvm.nd.NDArray)
     assert np.array_equal(candidate.asnumpy(), np.array(val))
 
 
@@ -68,6 +68,7 @@ def test_create_empty_tuple():
 def test_create_scalar():
     scalar = relay.const(1)
     tensor_val = run_as_python(scalar)
+    print(type(tensor_val))
     assert_tensor_value(tensor_val, 1)
 
 
@@ -544,7 +545,7 @@ def test_batch_norm():
 
         # there will be a change in accuracy so we need to check
         # approximate equality
-        assert isinstance(call_val, TensorValue)
+        assert isinstance(call_val, tvm.nd.NDArray)
         tvm.testing.assert_allclose(call_val.asnumpy(), ref_res, atol=eps, rtol=eps)
 
     verify_batch_norm([(10, 20), (20,), (20,), (20,), (20,)])
index 8a160b1..d53360d 100644 (file)
@@ -56,7 +56,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
     return vm.invoke("main", *args)
 
 def vmobj_to_list(o):
-    if isinstance(o, tvm.relay.backend.vm.Tensor):
+    if isinstance(o, tvm.nd.NDArray):
         return [o.asnumpy().tolist()]
     elif isinstance(o, tvm.relay.backend.vm.ADT):
         result = []
index 12d263d..82a2b11 100644 (file)
@@ -19,28 +19,16 @@ import numpy as np
 import tvm
 from tvm.relay import vm
 
-def test_tensor():
-    arr = tvm.nd.array([1,2,3])
-    x = vm.Tensor(arr)
-    assert isinstance(x, vm.Tensor)
-    assert x.asnumpy()[0] == 1
-    assert x.asnumpy()[-1] == 3
-    assert isinstance(x.data, tvm.nd.NDArray)
-
-
 def test_adt():
     arr = tvm.nd.array([1,2,3])
-    x = vm.Tensor(arr)
-    y = vm.ADT(0, [x, x])
+    y = vm.ADT(0, [arr, arr])
 
     assert len(y) == 2
     assert isinstance(y, vm.ADT)
-    y[0:1][-1].data == x.data
+    y[0:1][-1] == arr
     assert y.tag == 0
-    assert isinstance(x.data, tvm.nd.NDArray)
-
+    assert isinstance(arr, tvm.nd.NDArray)
 
 
 if __name__ == "__main__":
-    test_tensor()
     test_adt()
index dc517e2..aa569ce 100644 (file)
@@ -28,7 +28,7 @@ def test_double_buffer():
     with ib.for_range(0, n) as i:
         B = ib.allocate("float32", m, name="B", scope="shared")
         with ib.new_scope():
-            ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
+            ib.scope_attr(B.asobject(), "double_buffer_scope", 1)
             with ib.for_range(0, m) as j:
                 B[j] = A[i * 4 + j]
         with ib.for_range(0, m) as j:
@@ -39,7 +39,7 @@ def test_double_buffer():
     stmt = tvm.ir_pass.Simplify(stmt)
     assert isinstance(stmt.body.body, tvm.stmt.Allocate)
     assert stmt.body.body.extents[0].value == 2
-    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True)
+    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
     f = tvm.ir_pass.ThreadSync(f, "shared")
     count = [0]
     def count_sync(op):
index a3d0597..08e261b 100644 (file)
@@ -32,7 +32,7 @@ def test_vthread():
             ib.scope_attr(ty, "virtual_thread", nthread)
             B = ib.allocate("float32", m, name="B", scope="shared")
             B[i] = A[i * nthread + tx]
-            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
+            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
             ib.emit(tvm.call_extern("int32", "Run",
                                     bbuffer.access_ptr("r"),
                                     tvm.call_pure_intrin("int32", "tvm_context_id")))
@@ -60,9 +60,9 @@ def test_vthread_extern():
             A = ib.allocate("float32", m, name="A", scope="shared")
             B = ib.allocate("float32", m, name="B", scope="shared")
             C = ib.allocate("float32", m, name="C", scope="shared")
-            cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode())
-            abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode())
-            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
+            cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asobject())
+            abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asobject())
+            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
             A[tx] = tx + 1.0
             B[ty] = ty + 1.0
             ib.emit(tvm.call_extern("int32", "Run",
index 02edfe7..da32f60 100644 (file)
@@ -79,7 +79,7 @@ def test_flatten_double_buffer():
     with ib.for_range(0, n) as i:
         B = ib.allocate("float32", m, name="B", scope="shared")
         with ib.new_scope():
-            ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
+            ib.scope_attr(B.asobject(), "double_buffer_scope", 1)
             with ib.for_range(0, m) as j:
                 B[j] = A[i * 4 + j]
         with ib.for_range(0, m) as j:
@@ -91,7 +91,7 @@ def test_flatten_double_buffer():
     stmt = tvm.ir_pass.Simplify(stmt)
     assert isinstance(stmt.body.body, tvm.stmt.Allocate)
     assert stmt.body.body.extents[0].value == 2
-    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True)
+    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
     f = tvm.ir_pass.ThreadSync(f, "shared")
     count = [0]
     def count_sync(op):