[REFACTOR][PY] Establish tvm.runtime (#4818)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 5 Feb 2020 17:00:03 +0000 (09:00 -0800)
committerGitHub <noreply@github.com>
Wed, 5 Feb 2020 17:00:03 +0000 (09:00 -0800)
* [REFACTOR][PY] Establish tvm.runtime

This PR establishes the tvm.runtime namespace that contains the core runtime data structures.
The top-level API are kept inact for now via re-exporting.

We will followup later to cleanup some of the top-level APIs.

* Fix ndarray name

45 files changed:
python/tvm/__init__.py
python/tvm/_ffi/_ctypes/packed_func.py
python/tvm/_ffi/_cython/packed_func.pxi
python/tvm/_ffi/module.py [deleted file]
python/tvm/_ffi/runtime_ctypes.py
python/tvm/api.py
python/tvm/arith.py
python/tvm/attrs.py
python/tvm/build_module.py
python/tvm/container.py
python/tvm/contrib/debugger/debug_runtime.py
python/tvm/contrib/nvcc.py
python/tvm/datatype.py
python/tvm/expr.py
python/tvm/ir_builder.py
python/tvm/ndarray.py [deleted file]
python/tvm/relay/backend/graph_runtime_codegen.py
python/tvm/relay/backend/vm.py
python/tvm/relay/base.py
python/tvm/relay/expr.py
python/tvm/relay/memory_alloc.py
python/tvm/relay/op/annotation/annotation.py
python/tvm/relay/op/tensor.py
python/tvm/rpc/client.py
python/tvm/rpc/server.py
python/tvm/runtime/__init__.py [new file with mode: 0644]
python/tvm/runtime/container.py [new file with mode: 0644]
python/tvm/runtime/module.py [moved from python/tvm/module.py with 78% similarity]
python/tvm/runtime/ndarray.py [moved from python/tvm/_ffi/ndarray.py with 65% similarity]
python/tvm/runtime/object.py [moved from python/tvm/_ffi/object.py with 66% similarity]
python/tvm/runtime/object_generic.py [moved from python/tvm/_ffi/object_generic.py with 94% similarity]
python/tvm/runtime/packed_func.py [moved from python/tvm/_ffi/packed_func.py with 81% similarity]
python/tvm/schedule.py
python/tvm/stmt.py
python/tvm/target.py
python/tvm/tensor.py
python/tvm/tensor_intrin.py
tests/python/contrib/test_sparse.py
tests/python/relay/test_ir_parser.py
tests/python/unittest/test_codegen_cuda.py
tests/python/unittest/test_codegen_llvm.py
tests/python/unittest/test_hybrid_script.py
tests/python/unittest/test_pass_verify_gpu_code.py
tests/python/unittest/test_schedule_schedule_ops.py
topi/tests/python/test_topi_sparse.py

index ed42599..e765720 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=redefined-builtin, wildcard-import
-"""TVM: Low level DSL/IR stack for tensor computation."""
+"""TVM: Open Deep Learning Compiler Stack."""
 import multiprocessing
 import sys
 import traceback
 
-# import ffi related features
+# top-level alias
+# tvm._ffi
 from ._ffi.base import TVMError, __version__
-from ._ffi.runtime_ctypes import TypeCode, TVMType
-from ._ffi.ndarray import TVMContext
-from ._ffi.packed_func import PackedFunc as Function
+from ._ffi.runtime_ctypes import TypeCode, DataType
 from ._ffi.registry import register_object, register_func, register_extension
-from ._ffi.object import Object
 
+# top-level alias
+# tvm.runtime
+from .runtime.object import Object
+from .runtime.packed_func import PackedFunc as Function
+from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
+from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
+from .runtime import module
+from .runtime import ndarray
+# pylint: disable=reimported
+from .runtime import ndarray as nd
+
+# others
 from . import tensor
 from . import arith
 from . import expr
@@ -37,7 +47,7 @@ from . import ir_pass
 from . import codegen
 from . import container
 from . import schedule
-from . import module
+
 from . import attrs
 from . import ir_builder
 from . import target
@@ -47,9 +57,6 @@ from . import testing
 from . import error
 from . import datatype
 
-from . import ndarray as nd
-from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
-from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
 
 from .api import *
 from .intrin import *
index 5eaa738..6e1dbf5 100644 (file)
@@ -23,7 +23,7 @@ from numbers import Number, Integral
 
 from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
 from ..base import c_str, string_types
-from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
+from ..runtime_ctypes import DataType, TVMByteArray, TVMContext
 from . import ndarray as _nd
 from .ndarray import NDArrayBase, _make_array
 from .types import TVMValue, TypeCode
@@ -132,7 +132,7 @@ def _make_tvm_args(args, temp_args):
         elif isinstance(arg, Number):
             values[i].v_float64 = arg
             type_codes[i] = TypeCode.FLOAT
-        elif isinstance(arg, TVMType):
+        elif isinstance(arg, DataType):
             values[i].v_str = c_str(str(arg))
             type_codes[i] = TypeCode.STR
         elif isinstance(arg, TVMContext):
index 5630d72..9d13dbd 100644 (file)
@@ -20,7 +20,7 @@ import traceback
 from cpython cimport Py_INCREF, Py_DECREF
 from numbers import Number, Integral
 from ..base import string_types, py2cerror
-from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
+from ..runtime_ctypes import DataType, TVMContext, TVMByteArray
 
 
 cdef void tvm_callback_finalize(void* fhandle):
@@ -129,7 +129,7 @@ cdef inline int make_arg(object arg,
     elif isinstance(arg, Number):
         value[0].v_float64 = arg
         tcode[0] = kFloat
-    elif isinstance(arg, TVMType):
+    elif isinstance(arg, DataType):
         tstr = c_str(str(arg))
         value[0].v_str = tstr
         tcode[0] = kTVMStr
diff --git a/python/tvm/_ffi/module.py b/python/tvm/_ffi/module.py
deleted file mode 100644 (file)
index d6c81b3..0000000
+++ /dev/null
@@ -1,98 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=invalid-name, unused-import
-"""Runtime Module namespace."""
-import ctypes
-from .base import _LIB, check_call, c_str, string_types
-from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
-
-class ModuleBase(object):
-    """Base class for module"""
-    __slots__ = ["handle", "_entry", "entry_name"]
-
-    def __init__(self, handle):
-        self.handle = handle
-        self._entry = None
-        self.entry_name = "__tvm_main__"
-
-    def __del__(self):
-        check_call(_LIB.TVMModFree(self.handle))
-
-    def __hash__(self):
-        return ctypes.cast(self.handle, ctypes.c_void_p).value
-
-    @property
-    def entry_func(self):
-        """Get the entry function
-
-        Returns
-        -------
-        f : Function
-            The entry function if exist
-        """
-        if self._entry:
-            return self._entry
-        self._entry = self.get_function(self.entry_name)
-        return self._entry
-
-    def get_function(self, name, query_imports=False):
-        """Get function from the module.
-
-        Parameters
-        ----------
-        name : str
-            The name of the function
-
-        query_imports : bool
-            Whether also query modules imported by this module.
-
-        Returns
-        -------
-        f : Function
-            The result function.
-        """
-        ret_handle = PackedFuncHandle()
-        check_call(_LIB.TVMModGetFunction(
-            self.handle, c_str(name),
-            ctypes.c_int(query_imports),
-            ctypes.byref(ret_handle)))
-        if not ret_handle.value:
-            raise AttributeError(
-                "Module has no function '%s'" %  name)
-        return PackedFunc(ret_handle, False)
-
-    def import_module(self, module):
-        """Add module to the import list of current one.
-
-        Parameters
-        ----------
-        module : Module
-            The other module.
-        """
-        check_call(_LIB.TVMModImport(self.handle, module.handle))
-
-    def __getitem__(self, name):
-        if not isinstance(name, string_types):
-            raise ValueError("Can only take string as function name")
-        return self.get_function(name)
-
-    def __call__(self, *args):
-        if self._entry:
-            return self._entry(*args)
-        f = self.entry_func
-        return f(*args)
index d6d9b3a..f779b51 100644 (file)
@@ -48,7 +48,7 @@ class TVMByteArray(ctypes.Structure):
     _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
                 ("size", ctypes.c_size_t)]
 
-class TVMType(ctypes.Structure):
+class DataType(ctypes.Structure):
     """TVM datatype structure"""
     _fields_ = [("type_code", ctypes.c_uint8),
                 ("bits", ctypes.c_uint8),
@@ -60,7 +60,7 @@ class TVMType(ctypes.Structure):
         4 : 'handle'
     }
     def __init__(self, type_str):
-        super(TVMType, self).__init__()
+        super(DataType, self).__init__()
         if isinstance(type_str, np.dtype):
             type_str = str(type_str)
 
@@ -104,8 +104,8 @@ class TVMType(ctypes.Structure):
     def __repr__(self):
         if self.bits == 1 and self.lanes == 1:
             return "bool"
-        if self.type_code in TVMType.CODE2STR:
-            type_name = TVMType.CODE2STR[self.type_code]
+        if self.type_code in DataType.CODE2STR:
+            type_name = DataType.CODE2STR[self.type_code]
         else:
             type_name = "custom[%s]" % \
                         _api_internal._datatype_get_type_name(self.type_code)
@@ -263,7 +263,7 @@ class TVMArray(ctypes.Structure):
     _fields_ = [("data", ctypes.c_void_p),
                 ("ctx", TVMContext),
                 ("ndim", ctypes.c_int),
-                ("dtype", TVMType),
+                ("dtype", DataType),
                 ("shape", ctypes.POINTER(tvm_shape_index_t)),
                 ("strides", ctypes.POINTER(tvm_shape_index_t)),
                 ("byte_offset", ctypes.c_uint64)]
index 573732e..c536919 100644 (file)
@@ -20,10 +20,10 @@ from numbers import Integral as _Integral
 
 import tvm._ffi
 
+from tvm.runtime import convert, const, DataType
 from ._ffi.base import string_types, TVMError
-from ._ffi.object_generic import convert, const
 from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
-from ._ffi.runtime_ctypes import TVMType
+
 from . import _api_internal
 from . import make as _make
 from . import expr as _expr
index 7434aee..35cd6a3 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """Arithmetic data structure and utility"""
-from __future__ import absolute_import as _abs
 import tvm._ffi
+from tvm.runtime import Object
 
-from ._ffi.object import Object
 from . import _api_internal
 
 class IntSet(Object):
index 78d5b18..dc6ca72 100644 (file)
@@ -17,7 +17,7 @@
 """ TVM Attribute module, which is mainly used for defining attributes of operators"""
 import tvm._ffi
 
-from ._ffi.object import Object
+from tvm.runtime import Object
 from . import _api_internal
 
 
index c5097b2..a5bb3d0 100644 (file)
@@ -22,7 +22,7 @@ LoweredFunc and compiled Module.
 import warnings
 import tvm._ffi
 
-from ._ffi.object import Object
+from tvm.runtime import Object, ndarray
 from . import api
 from . import _api_internal
 from . import tensor
@@ -33,7 +33,6 @@ from . import stmt as _stmt
 from . import container
 from . import module
 from . import codegen
-from . import ndarray
 from . import target as _target
 from . import make
 
index b74cc04..094d7ca 100644 (file)
@@ -17,9 +17,9 @@
 """Container data structures used in TVM DSL."""
 import tvm._ffi
 
-from tvm import ndarray as _nd
+from tvm.runtime import Object, ObjectTypes
+from tvm.runtime.container import getitem_helper
 from . import _api_internal
-from ._ffi.object import Object, getitem_helper
 
 
 @tvm._ffi.register_object
@@ -31,23 +31,9 @@ class Array(Object):
     to Array during tvm function call.
     You may get Array in return values of TVM function call.
     """
-    def __getitem__(self, i):
-        if isinstance(i, slice):
-            start = i.start if i.start is not None else 0
-            stop = i.stop if i.stop is not None else len(self)
-            step = i.step if i.step is not None else 1
-            if start < 0:
-                start += len(self)
-            if stop < 0:
-                stop += len(self)
-            return [self[idx] for idx in range(start, stop, step)]
-
-        if i < -len(self) or i >= len(self):
-            raise IndexError("Array index out of range. Array size: {}, got index {}"
-                             .format(len(self), i))
-        if i < 0:
-            i += len(self)
-        return _api_internal._ArrayGetItem(self, i)
+    def __getitem__(self, idx):
+        return getitem_helper(
+            self, _api_internal._ArrayGetItem, len(self), idx)
 
     def __len__(self):
         return _api_internal._ArraySize(self)
@@ -133,7 +119,7 @@ class ADT(Object):
     """
     def __init__(self, tag, fields):
         for f in fields:
-            assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
+            assert isinstance(f, ObjectTypes), "Expect object or " \
             "tvm NDArray type, but received : {0}".format(type(f))
         self.__init_handle_by_constructor__(_ADT, tag, *fields)
 
@@ -164,7 +150,7 @@ def tuple_object(fields=None):
     """
     fields = fields if fields else []
     for f in fields:
-        assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
+        assert isinstance(f, ObjectTypes), "Expect object or tvm " \
         "NDArray type, but received : {0}".format(type(f))
     return _Tuple(*fields)
 
index a5f6e30..848d7f5 100644 (file)
@@ -23,7 +23,7 @@ import tvm._ffi
 
 from tvm._ffi.base import string_types
 from tvm.contrib import graph_runtime
-from tvm.ndarray import array
+from tvm.runtime.ndarray import array
 from . import debug_result
 
 _DUMP_ROOT_PREFIX = "tvmdbg_"
index 0e97ac1..c50a9ce 100644 (file)
@@ -21,8 +21,9 @@ from __future__ import absolute_import as _abs
 import subprocess
 import os
 import warnings
+from tvm.runtime import ndarray as nd
+
 from . import util
-from .. import ndarray as nd
 from ..api import register_func
 from .._ffi.base import py_str
 
index 809e435..8a93673 100644 (file)
@@ -20,7 +20,7 @@ import tvm._ffi
 from . import make as _make
 from .api import convert
 from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
-from ._ffi.runtime_ctypes import TVMType as _TVMType
+from ._ffi.runtime_ctypes import DataType
 from . import _api_internal
 
 
@@ -131,7 +131,7 @@ def create_lower_func(extern_func_name):
         width as the custom type is returned. Otherwise, the type is
         unchanged."""
         dtype = op.dtype
-        t = _TVMType(dtype)
+        t = DataType(dtype)
         if get_type_registered(t.type_code):
             dtype = "uint" + str(t.bits)
             if t.lanes > 1:
index 46a7eac..910061a 100644 (file)
@@ -31,12 +31,9 @@ For example, you can use addexp.a to get the left operand of an Add node.
   assert(y.a == x)
 """
 # pylint: disable=missing-docstring
-from __future__ import absolute_import as _abs
 import tvm._ffi
+from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode
 
-from ._ffi.object import Object
-from ._ffi.object_generic import ObjectGeneric
-from ._ffi.runtime_ctypes import TVMType, TypeCode
 from . import make as _make
 from . import generic as _generic
 from . import _api_internal
@@ -52,7 +49,7 @@ def _dtype_is_int(value):
     if isinstance(value, int):
         return True
     return (isinstance(value, ExprOp) and
-            TVMType(value.dtype).type_code == TypeCode.INT)
+            DataType(value.dtype).type_code == TypeCode.INT)
 
 
 class ExprOp(object):
index 8bd5892..c08b3a5 100644 (file)
@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """Developer API of IR node builder make function."""
-from __future__ import absolute_import as _abs
+from tvm.runtime import ObjectGeneric, DataType
+
+from ._ffi.base import string_types
 
 from . import api as _api
 from . import stmt as _stmt
@@ -23,9 +25,6 @@ from . import expr as _expr
 from . import make as _make
 from . import ir_pass as _pass
 from . import container as _container
-from ._ffi.base import string_types
-from ._ffi.object_generic import ObjectGeneric
-from ._ffi.runtime_ctypes import TVMType
 from .expr import Call as _Call
 
 class WithScope(object):
@@ -78,7 +77,7 @@ class BufferVar(ObjectGeneric):
         return self._content_type
 
     def __getitem__(self, index):
-        t = TVMType(self._content_type)
+        t = DataType(self._content_type)
         if t.lanes > 1:
             index = _make.Ramp(index * t.lanes, 1, t.lanes)
         return _make.Load(self._content_type, self._buffer_var, index)
@@ -89,7 +88,7 @@ class BufferVar(ObjectGeneric):
             raise ValueError(
                 "data type does not match content type %s vs %s" % (
                     value.dtype, self._content_type))
-        t = TVMType(self._content_type)
+        t = DataType(self._content_type)
         if t.lanes > 1:
             index = _make.Ramp(index * t.lanes, 1, t.lanes)
         self._builder.emit(_make.Store(self._buffer_var, value, index))
diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py
deleted file mode 100644 (file)
index 5d78fe4..0000000
+++ /dev/null
@@ -1,233 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Runtime NDArray API.
-
-tvm.ndarray provides a minimum runtime array API to test
-the correctness of the program.
-"""
-# pylint: disable=invalid-name,unused-import
-import tvm._ffi
-import numpy as _np
-
-from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
-from ._ffi.ndarray import context, empty, from_dlpack
-from ._ffi.ndarray import _set_class_ndarray
-
-
-@tvm._ffi.register_object
-class NDArray(NDArrayBase):
-    """Lightweight NDArray class of TVM runtime.
-
-    Strictly this is only an Array Container (a buffer object)
-    No arthimetic operations are defined.
-    All operations are performed by TVM functions.
-
-    The goal is not to re-build yet another array library.
-    Instead, this is a minimal data structure to demonstrate
-    how can we use TVM in existing project which might have their own array containers.
-    """
-
-
-def cpu(dev_id=0):
-    """Construct a CPU device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-    """
-    return TVMContext(1, dev_id)
-
-
-def gpu(dev_id=0):
-    """Construct a CPU device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-    """
-    return TVMContext(2, dev_id)
-
-def rocm(dev_id=0):
-    """Construct a ROCM device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-    """
-    return TVMContext(10, dev_id)
-
-
-def opencl(dev_id=0):
-    """Construct a OpenCL device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-    """
-    return TVMContext(4, dev_id)
-
-
-def metal(dev_id=0):
-    """Construct a metal device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-    """
-    return TVMContext(8, dev_id)
-
-
-def vpi(dev_id=0):
-    """Construct a VPI simulated device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-    """
-    return TVMContext(9, dev_id)
-
-
-def vulkan(dev_id=0):
-    """Construct a Vulkan device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-    """
-    return TVMContext(7, dev_id)
-
-
-def opengl(dev_id=0):
-    """Construct a OpenGL device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-    """
-    return TVMContext(11, dev_id)
-
-
-def ext_dev(dev_id=0):
-    """Construct a extension device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-
-    Note
-    ----
-    This API is reserved for quick testing of new
-    device by plugin device API as ext_dev.
-    """
-    return TVMContext(12, dev_id)
-
-
-def micro_dev(dev_id=0):
-    """Construct a micro device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx : TVMContext
-        The created context
-    """
-    return TVMContext(13, dev_id)
-
-
-cl = opencl
-mtl = metal
-
-
-def array(arr, ctx=cpu(0)):
-    """Create an array from source arr.
-
-    Parameters
-    ----------
-    arr : numpy.ndarray
-        The array to be copied from
-
-    ctx : TVMContext, optional
-        The device context to create the array
-
-    Returns
-    -------
-    ret : NDArray
-        The created array
-    """
-    if not isinstance(arr, (_np.ndarray, NDArray)):
-        arr = _np.array(arr)
-    return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
-
-_set_class_ndarray(NDArray)
index 73a700e..f58a9b0 100644 (file)
@@ -33,9 +33,7 @@ To connect to the graph runtime, we use a printer that converts our graph format
 into TVM's JSON format. The resulting string can be loaded by
 contrib.graph_runtime or any other TVM runtime compatible systems.
 """
-from __future__ import absolute_import
-
-from tvm.ndarray import empty
+from tvm.runtime.ndarray import empty
 from tvm.relay import _build_module
 from tvm import target as _target
 from tvm import expr as _expr
index 606e0cc..53b56d5 100644 (file)
@@ -23,9 +23,9 @@ Implements a Python interface to compiling and executing on the Relay VM.
 import numpy as np
 
 import tvm
-import tvm.ndarray as _nd
+import tvm.runtime.ndarray as _nd
+from tvm.runtime import Object
 from tvm import autotvm, container
-from tvm._ffi.object import Object
 from tvm.relay import expr as _expr
 from tvm._ffi.runtime_ctypes import TVMByteArray
 from tvm._ffi import base as _base
index a723eda..bc04125 100644 (file)
 """The base node types for the Relay language."""
 import tvm._ffi
 
-from .._ffi.object import Object
+from tvm.runtime import Object
 from . import _make
 from . import _expr
 from . import _base
 
-Object = Object
 
 def register_relay_node(type_key=None):
     """Register a Relay node type.
index e288c3d..97185ee 100644 (file)
@@ -20,14 +20,14 @@ from __future__ import absolute_import
 from numbers import Number as _Number
 
 import numpy as _np
+from tvm._ffi import base as _base
+from tvm.runtime import NDArray, convert, ndarray as _nd
+
 from .base import RelayNode, register_relay_node
 from . import _make
 from . import _expr
 from . import ty as _ty
-from .._ffi import base as _base
-from .. import nd as _nd
-from .. import convert
-from ..ndarray import NDArray
+
 
 # will be registered afterwards
 _op_make = None
index a8d1a30..f93aa9e 100644 (file)
@@ -23,7 +23,7 @@ from .expr_functor import ExprMutator
 from .scope_builder import ScopeBuilder
 from . import transform
 from . import op, ty, expr
-from .. import TVMType, register_func
+from .. import DataType, register_func
 from .backend import compile_engine
 
 
@@ -109,7 +109,7 @@ class ManifestAllocPass(ExprMutator):
         return expr.Tuple(new_fields)
 
     def compute_alignment(self, dtype):
-        dtype = TVMType(dtype)
+        dtype = DataType(dtype)
         align = (dtype.bits // 8) * dtype.lanes
         # MAGIC CONSTANT FROM device_api.h
         if align < 64:
@@ -118,7 +118,7 @@ class ManifestAllocPass(ExprMutator):
         return expr.const(align, dtype="int64")
 
     def compute_storage_in_relay(self, shape, dtype):
-        dtype = TVMType(dtype)
+        dtype = DataType(dtype)
         els = op.prod(shape)
         num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype)
         num = num + expr.const(7, self.compute_dtype)
@@ -126,7 +126,7 @@ class ManifestAllocPass(ExprMutator):
         return els * (num / div)
 
     def compute_storage(self, tensor_type):
-        dtype = TVMType(tensor_type.dtype)
+        dtype = DataType(tensor_type.dtype)
         shape = [int(sh) for sh in tensor_type.shape]
         size = 1
         for sh in shape:
index 9363925..586c300 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """Annotation operations."""
-from __future__ import absolute_import as _abs
+from tvm.runtime import ndarray as _nd
+from tvm.runtime import TVMContext as _TVMContext
+
 from . import _make
 from ..op import register_schedule, schedule_injective
-from .... import nd as _nd
-from .... import TVMContext as _TVMContext
+
 
 def on_device(data, device):
     """Annotate an expression with a certain device type.
index bf78ca9..0d3f098 100644 (file)
 # under the License.
 """Basic tensor operations."""
 # pylint: disable=redefined-builtin
-from __future__ import absolute_import as _abs
+from tvm.runtime import ndarray as _nd
+from tvm.runtime import TVMContext as _TVMContext
+
 from . import _make
 from ..expr import Tuple
-from ... import nd as _nd
-from ... import TVMContext as _TVMContext
+
 
 # We create a wrapper function for each operator in the
 # python side to call into the positional _make.OpName function.
index 314f1ab..ed57e0d 100644 (file)
@@ -22,12 +22,12 @@ import socket
 import struct
 import time
 import tvm._ffi
+from tvm.contrib import util
+from tvm._ffi.base import TVMError
+from tvm.runtime import ndarray as nd
+from tvm.runtime import load_module as _load_module
 
 from . import base
-from ..contrib import util
-from .._ffi.base import TVMError
-from .._ffi import ndarray as nd
-from ..module import load as _load_module
 
 
 class RPCSession(object):
index efebe8b..e79aa82 100644 (file)
@@ -38,10 +38,10 @@ import sys
 import signal
 import tvm._ffi
 
-from .._ffi.base import py_str
-from .._ffi.libinfo import find_lib_path
-from ..module import load as _load_module
-from ..contrib import util
+from tvm._ffi.base import py_str
+from tvm._ffi.libinfo import find_lib_path
+from tvm.runtime.module import load as _load_module
+from tvm.contrib import util
 from . import base
 from . base import TrackerCode
 
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
new file mode 100644 (file)
index 0000000..a6ca614
--- /dev/null
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""TVM runtime."""
+
+# class exposures
+from .packed_func import PackedFunc
+from .object import Object
+from .object_generic import ObjectGeneric, ObjectTypes
+from .ndarray import NDArray, DataType, TypeCode, TVMContext
+from .module import Module
+
+# function exposures
+from .object_generic import convert_to_object, convert, const
+from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
+from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
+from .module import load as load_module
+
+DataType = DataType
diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py
new file mode 100644 (file)
index 0000000..196291f
--- /dev/null
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Runtime container structures."""
+
+def getitem_helper(obj, elem_getter, length, idx):
+    """Helper function to implement a pythonic getitem function.
+
+    Parameters
+    ----------
+    obj: object
+        The original object
+
+    elem_getter : function
+        A simple function that takes index and return a single element.
+
+    length : int
+        The size of the array
+
+    idx : int or slice
+        The argument passed to getitem
+
+    Returns
+    -------
+    result : object
+        The result of getitem
+    """
+    if isinstance(idx, slice):
+        start = idx.start if idx.start is not None else 0
+        stop = idx.stop if idx.stop is not None else length
+        step = idx.step if idx.step is not None else 1
+        if start < 0:
+            start += length
+        if stop < 0:
+            stop += length
+        return [elem_getter(obj, i) for i in range(start, stop, step)]
+
+    if idx < -length or idx >= length:
+        raise IndexError("Index out of range. size: {}, got index {}"
+                         .format(length, idx))
+    if idx < 0:
+        idx += length
+    return elem_getter(obj, idx)
similarity index 78%
rename from python/tvm/module.py
rename to python/tvm/runtime/module.py
index d810af7..0d21c38 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Container of compiled functions of TVM."""
-from __future__ import absolute_import as _abs
 
+# pylint: disable=invalid-name, unused-import
+"""Runtime Module namespace."""
+import ctypes
 import struct
 from collections import namedtuple
+
 import tvm._ffi
+from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY
+from tvm._ffi.libinfo import find_include_path
+from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
 
-from ._ffi.module import ModuleBase, _set_class_module
-from ._ffi.libinfo import find_include_path
-from .contrib import cc as _cc, tar as _tar, util as _util
 
+# profile result of time evaluator
 ProfileResult = namedtuple("ProfileResult", ["mean", "results"])
 
 
-class Module(ModuleBase):
-    """Module container of all TVM generated functions"""
+class Module(object):
+    """Runtime Module."""
+    __slots__ = ["handle", "_entry", "entry_name"]
+
+    def __init__(self, handle):
+        self.handle = handle
+        self._entry = None
+        self.entry_name = "__tvm_main__"
+
+    def __del__(self):
+        check_call(_LIB.TVMModFree(self.handle))
+
+    def __hash__(self):
+        return ctypes.cast(self.handle, ctypes.c_void_p).value
+
+    @property
+    def entry_func(self):
+        """Get the entry function
+
+        Returns
+        -------
+        f : Function
+            The entry function if exist
+        """
+        if self._entry:
+            return self._entry
+        self._entry = self.get_function(self.entry_name)
+        return self._entry
+
+    def get_function(self, name, query_imports=False):
+        """Get function from the module.
+
+        Parameters
+        ----------
+        name : str
+            The name of the function
+
+        query_imports : bool
+            Whether also query modules imported by this module.
+
+        Returns
+        -------
+        f : Function
+            The result function.
+        """
+        ret_handle = PackedFuncHandle()
+        check_call(_LIB.TVMModGetFunction(
+            self.handle, c_str(name),
+            ctypes.c_int(query_imports),
+            ctypes.byref(ret_handle)))
+        if not ret_handle.value:
+            raise AttributeError(
+                "Module has no function '%s'" %  name)
+        return PackedFunc(ret_handle, False)
+
+    def import_module(self, module):
+        """Add module to the import list of current one.
+
+        Parameters
+        ----------
+        module : Module
+            The other module.
+        """
+        check_call(_LIB.TVMModImport(self.handle, module.handle))
+
+    def __getitem__(self, name):
+        if not isinstance(name, string_types):
+            raise ValueError("Can only take string as function name")
+        return self.get_function(name)
+
+    def __call__(self, *args):
+        if self._entry:
+            return self._entry(*args)
+        # pylint: disable=not-callable
+        return self.entry_func(*args)
+
 
     def __repr__(self):
         return "Module(%s, %x)" % (self.type_key, self.handle.value)
@@ -85,6 +162,83 @@ class Module(ModuleBase):
         """
         _SaveToFile(self, file_name, fmt)
 
+    def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
+        """Get an evaluator that measures time cost of running function.
+
+        Parameters
+        ----------
+        func_name: str
+            The name of the function in the module.
+
+        ctx: TVMContext
+            The context we should run this function on.
+
+        number: int
+            The number of times to run this function for taking average.
+            We call these runs as one `repeat` of measurement.
+
+        repeat: int, optional
+            The number of times to repeat the measurement.
+            In total, the function will be invoked (1 + number x repeat) times,
+            where the first one is warm up and will be discarded.
+            The returned result contains `repeat` costs,
+            each of which is an average of `number` costs.
+
+        min_repeat_ms: int, optional
+            The minimum duration of one `repeat` in milliseconds.
+            By default, one `repeat` contains `number` runs. If this parameter is set,
+            the parameters `number` will be dynamically adjusted to meet the
+            minimum duration requirement of one `repeat`.
+            i.e., When the run time of one `repeat` falls below this time, the `number` parameter
+            will be automatically increased.
+
+        Note
+        ----
+        The function will be invoked  (1 + number x repeat) times,
+        with the first call discarded in case there is lazy initialization.
+
+        Returns
+        -------
+        ftimer : Function
+            The function that takes same argument as func and returns a ProfileResult.
+            The ProfileResult reports `repeat` time costs in seconds.
+        """
+        try:
+            feval = _RPCTimeEvaluator(
+                self, func_name, ctx.device_type, ctx.device_id, number, repeat, min_repeat_ms)
+
+            def evaluator(*args):
+                """Internal wrapped evaluator."""
+                # Wrap feval so we can add more stats in future.
+                blob = feval(*args)
+                fmt = "@" + ("d" * repeat)
+                results = struct.unpack(fmt, blob)
+                mean = sum(results) / float(repeat)
+                return ProfileResult(mean=mean, results=results)
+
+            return evaluator
+        except NameError:
+            raise NameError("time_evaluate is only supported when RPC is enabled")
+
+    def _collect_dso_modules(self):
+        """Helper function to collect dso modules, then return it."""
+        visited, stack, dso_modules = set(), [], []
+        # append root module
+        visited.add(self)
+        stack.append(self)
+        while stack:
+            module = stack.pop()
+            if module._dso_exportable():
+                dso_modules.append(module)
+            for m in module.imported_modules:
+                if m not in visited:
+                    visited.add(m)
+                    stack.append(m)
+        return dso_modules
+
+    def _dso_exportable(self):
+        return self.type_key == "llvm" or self.type_key == "c"
+
     def export_library(self,
                        file_name,
                        fcompile=None,
@@ -107,7 +261,14 @@ class Module(ModuleBase):
         kwargs : dict, optional
             Additional arguments passed to fcompile
         """
+        # NOTE: this function depends on contrib library features
+        # which are only available in when TVM function is available.
+        if _RUNTIME_ONLY:
+            raise RuntimeError("Cannot call export_library in runtime only mode")
+        # Extra dependencies during runtime.
         from pathlib import Path
+        from tvm.contrib import cc as _cc, tar as _tar, util as _util
+
         if isinstance(file_name, Path):
             file_name = str(file_name)
 
@@ -172,83 +333,6 @@ class Module(ModuleBase):
 
         fcompile(file_name, files, **kwargs)
 
-    def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
-        """Get an evaluator that measures time cost of running function.
-
-        Parameters
-        ----------
-        func_name: str
-            The name of the function in the module.
-
-        ctx: TVMContext
-            The context we should run this function on.
-
-        number: int
-            The number of times to run this function for taking average.
-            We call these runs as one `repeat` of measurement.
-
-        repeat: int, optional
-            The number of times to repeat the measurement.
-            In total, the function will be invoked (1 + number x repeat) times,
-            where the first one is warm up and will be discarded.
-            The returned result contains `repeat` costs,
-            each of which is an average of `number` costs.
-
-        min_repeat_ms: int, optional
-            The minimum duration of one `repeat` in milliseconds.
-            By default, one `repeat` contains `number` runs. If this parameter is set,
-            the parameters `number` will be dynamically adjusted to meet the
-            minimum duration requirement of one `repeat`.
-            i.e., When the run time of one `repeat` falls below this time, the `number` parameter
-            will be automatically increased.
-
-        Note
-        ----
-        The function will be invoked  (1 + number x repeat) times,
-        with the first call discarded in case there is lazy initialization.
-
-        Returns
-        -------
-        ftimer : Function
-            The function that takes same argument as func and returns a ProfileResult.
-            The ProfileResult reports `repeat` time costs in seconds.
-        """
-        try:
-            feval = _RPCTimeEvaluator(
-                self, func_name, ctx.device_type, ctx.device_id, number, repeat, min_repeat_ms)
-
-            def evaluator(*args):
-                """Internal wrapped evaluator."""
-                # Wrap feval so we can add more stats in future.
-                blob = feval(*args)
-                fmt = "@" + ("d" * repeat)
-                results = struct.unpack(fmt, blob)
-                mean = sum(results) / float(repeat)
-                return ProfileResult(mean=mean, results=results)
-
-            return evaluator
-        except NameError:
-            raise NameError("time_evaluate is only supported when RPC is enabled")
-
-    def _collect_dso_modules(self):
-        """Helper function to collect dso modules, then return it."""
-        visited, stack, dso_modules = set(), [], []
-        # append root module
-        visited.add(self)
-        stack.append(self)
-        while stack:
-            module = stack.pop()
-            if module._dso_exportable():
-                dso_modules.append(module)
-            for m in module.imported_modules:
-                if m not in visited:
-                    visited.add(m)
-                    stack.append(m)
-        return dso_modules
-
-    def _dso_exportable(self):
-        return self.type_key == "llvm" or self.type_key == "c"
-
 
 def system_lib():
     """Get system-wide library module singleton.
@@ -296,9 +380,13 @@ def load(path, fmt=""):
     # High level handling for .o and .tar file.
     # We support this to be consistent with RPC module load.
     if path.endswith(".o"):
+        # Extra dependencies during runtime.
+        from tvm.contrib import cc as _cc
         _cc.create_shared(path + ".so", path)
         path += ".so"
     elif path.endswith(".tar"):
+        # Extra dependencies during runtime.
+        from tvm.contrib import cc as _cc, util as _util, tar as _tar
         tar_temp = _util.tempdir(custom_path=path.replace('.tar', ''))
         _tar.untar(path, tar_temp.temp_dir)
         files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
@@ -333,5 +421,6 @@ def enabled(target):
     return _Enabled(target)
 
 
-tvm._ffi._init_api("tvm.module")
 _set_class_module(Module)
+
+tvm._ffi._init_api("tvm.module", "tvm.runtime.module")
similarity index 65%
rename from python/tvm/_ffi/ndarray.py
rename to python/tvm/runtime/ndarray.py
index f526195..cce73e0 100644 (file)
 """Runtime NDArray api"""
 import ctypes
 import numpy as np
-from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
-from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
-from .runtime_ctypes import TypeCode, tvm_shape_index_t
+import tvm._ffi
+
+from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE
+from tvm._ffi.runtime_ctypes import DataType, TVMContext, TVMArray, TVMArrayHandle
+from tvm._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
 
 try:
     # pylint: disable=wrong-import-position
     if _FFI_MODE == "ctypes":
         raise ImportError()
-    from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
-    from ._cy3.core import NDArrayBase as _NDArrayBase
+    from tvm._ffi._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
+    from tvm._ffi._cy3.core import NDArrayBase
 except (RuntimeError, ImportError):
     # pylint: disable=wrong-import-position
-    from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
-    from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
-
-
-def context(dev_type, dev_id=0):
-    """Construct a TVM context with given device type and id.
-
-    Parameters
-    ----------
-    dev_type: int or str
-        The device type mask or name of the device.
-
-    dev_id : int, optional
-        The integer device id
-
-    Returns
-    -------
-    ctx: TVMContext
-        The corresponding context.
-
-    Examples
-    --------
-    Context can be used to create reflection of context by
-    string representation of the device type.
-
-    .. code-block:: python
-
-      assert tvm.context("cpu", 1) == tvm.cpu(1)
-      assert tvm.context("gpu", 0) == tvm.gpu(0)
-      assert tvm.context("cuda", 0) == tvm.gpu(0)
-    """
-    if isinstance(dev_type, string_types):
-        if '-device=micro_dev' in dev_type:
-            dev_type = 'micro_dev'
-        else:
-            dev_type = dev_type.split()[0]
-            if dev_type not in TVMContext.STR2MASK:
-                raise ValueError("Unknown device type %s" % dev_type)
-            dev_type = TVMContext.STR2MASK[dev_type]
-    return TVMContext(dev_type, dev_id)
-
-
-def numpyasarray(np_data):
-    """Return a TVMArray representation of a numpy array.
-    """
-    data = np_data
-    assert data.flags['C_CONTIGUOUS']
-    arr = TVMArray()
-    shape = c_array(tvm_shape_index_t, data.shape)
-    arr.data = data.ctypes.data_as(ctypes.c_void_p)
-    arr.shape = shape
-    arr.strides = None
-    arr.dtype = TVMType(np.dtype(data.dtype).name)
-    arr.ndim = data.ndim
-    # CPU device
-    arr.ctx = context(1, 0)
-    return arr, shape
-
-
-def empty(shape, dtype="float32", ctx=context(1, 0)):
-    """Create an empty array given shape and device
-
-    Parameters
-    ----------
-    shape : tuple of int
-        The shape of the array
-
-    dtype : type or str
-        The data type of the array.
-
-    ctx : TVMContext
-        The context of the array
-
-    Returns
-    -------
-    arr : tvm.nd.NDArray
-        The array tvm supported.
-    """
-    shape = c_array(tvm_shape_index_t, shape)
-    ndim = ctypes.c_int(len(shape))
-    handle = TVMArrayHandle()
-    dtype = TVMType(dtype)
-    check_call(_LIB.TVMArrayAlloc(
-        shape, ndim,
-        ctypes.c_int(dtype.type_code),
-        ctypes.c_int(dtype.bits),
-        ctypes.c_int(dtype.lanes),
-        ctx.device_type,
-        ctx.device_id,
-        ctypes.byref(handle)))
-    return _make_array(handle, False, False)
+    from tvm._ffi._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
+    from tvm._ffi._ctypes.ndarray import NDArrayBase
 
 
-def from_dlpack(dltensor):
-    """Produce an array from a DLPack tensor without memory copy.
-    Retreives the underlying DLPack tensor's pointer to create an array from the
-    data. Removes the original DLPack tensor's destructor as now the array is
-    responsible for destruction.
+@tvm._ffi.register_object
+class NDArray(NDArrayBase):
+    """Lightweight NDArray class of TVM runtime.
 
-    Parameters
-    ----------
-    dltensor : DLPack tensor
-        Input DLManagedTensor, can only be consumed once.
+    Strictly this is only an Array Container (a buffer object)
+    No arthimetic operations are defined.
+    All operations are performed by TVM functions.
 
-    Returns
-    -------
-    arr: tvm.nd.NDArray
-        The array view of the tensor data.
+    The goal is not to re-build yet another array library.
+    Instead, this is a minimal data structure to demonstrate
+    how can we use TVM in existing project which might have their own array containers.
     """
-    return _from_dlpack(dltensor)
-
-
-class NDArrayBase(_NDArrayBase):
-    """A simple Device/CPU Array object in runtime."""
 
     @property
     def dtype(self):
@@ -224,7 +128,7 @@ class NDArrayBase(_NDArrayBase):
                 raise TypeError('array must be an array_like data,' +
                                 'type %s is not supported' % str(type(source_array)))
 
-        t = TVMType(self.dtype)
+        t = DataType(self.dtype)
         shape, dtype = self.shape, self.dtype
         if t.lanes > 1:
             shape = shape + (t.lanes,)
@@ -257,7 +161,7 @@ class NDArrayBase(_NDArrayBase):
         np_arr : numpy.ndarray
             The corresponding numpy array.
         """
-        t = TVMType(self.dtype)
+        t = DataType(self.dtype)
         shape, dtype = self.shape, self.dtype
         if t.lanes > 1:
             shape = shape + (t.lanes,)
@@ -284,3 +188,303 @@ class NDArrayBase(_NDArrayBase):
             res = empty(self.shape, self.dtype, target)
             return self._copyto(res)
         raise ValueError("Unsupported target type %s" % str(type(target)))
+
+
+def context(dev_type, dev_id=0):
+    """Construct a TVM context with given device type and id.
+
+    Parameters
+    ----------
+    dev_type: int or str
+        The device type mask or name of the device.
+
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx: TVMContext
+        The corresponding context.
+
+    Examples
+    --------
+    Context can be used to create reflection of context by
+    string representation of the device type.
+
+    .. code-block:: python
+
+      assert tvm.context("cpu", 1) == tvm.cpu(1)
+      assert tvm.context("gpu", 0) == tvm.gpu(0)
+      assert tvm.context("cuda", 0) == tvm.gpu(0)
+    """
+    if isinstance(dev_type, string_types):
+        if '-device=micro_dev' in dev_type:
+            dev_type = 'micro_dev'
+        else:
+            dev_type = dev_type.split()[0]
+            if dev_type not in TVMContext.STR2MASK:
+                raise ValueError("Unknown device type %s" % dev_type)
+            dev_type = TVMContext.STR2MASK[dev_type]
+    return TVMContext(dev_type, dev_id)
+
+
+def numpyasarray(np_data):
+    """Return a TVMArray representation of a numpy array.
+    """
+    data = np_data
+    assert data.flags['C_CONTIGUOUS']
+    arr = TVMArray()
+    shape = c_array(tvm_shape_index_t, data.shape)
+    arr.data = data.ctypes.data_as(ctypes.c_void_p)
+    arr.shape = shape
+    arr.strides = None
+    arr.dtype = DataType(np.dtype(data.dtype).name)
+    arr.ndim = data.ndim
+    # CPU device
+    arr.ctx = context(1, 0)
+    return arr, shape
+
+
+def empty(shape, dtype="float32", ctx=context(1, 0)):
+    """Create an empty array given shape and device
+
+    Parameters
+    ----------
+    shape : tuple of int
+        The shape of the array
+
+    dtype : type or str
+        The data type of the array.
+
+    ctx : TVMContext
+        The context of the array
+
+    Returns
+    -------
+    arr : tvm.nd.NDArray
+        The array tvm supported.
+    """
+    shape = c_array(tvm_shape_index_t, shape)
+    ndim = ctypes.c_int(len(shape))
+    handle = TVMArrayHandle()
+    dtype = DataType(dtype)
+    check_call(_LIB.TVMArrayAlloc(
+        shape, ndim,
+        ctypes.c_int(dtype.type_code),
+        ctypes.c_int(dtype.bits),
+        ctypes.c_int(dtype.lanes),
+        ctx.device_type,
+        ctx.device_id,
+        ctypes.byref(handle)))
+    return _make_array(handle, False, False)
+
+
+def from_dlpack(dltensor):
+    """Produce an array from a DLPack tensor without memory copy.
+    Retreives the underlying DLPack tensor's pointer to create an array from the
+    data. Removes the original DLPack tensor's destructor as now the array is
+    responsible for destruction.
+
+    Parameters
+    ----------
+    dltensor : DLPack tensor
+        Input DLManagedTensor, can only be consumed once.
+
+    Returns
+    -------
+    arr: tvm.nd.NDArray
+        The array view of the tensor data.
+    """
+    return _from_dlpack(dltensor)
+
+
+def cpu(dev_id=0):
+    """Construct a CPU device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+    """
+    return TVMContext(1, dev_id)
+
+
+def gpu(dev_id=0):
+    """Construct a CPU device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+    """
+    return TVMContext(2, dev_id)
+
+def rocm(dev_id=0):
+    """Construct a ROCM device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+    """
+    return TVMContext(10, dev_id)
+
+
+def opencl(dev_id=0):
+    """Construct a OpenCL device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+    """
+    return TVMContext(4, dev_id)
+
+
+def metal(dev_id=0):
+    """Construct a metal device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+    """
+    return TVMContext(8, dev_id)
+
+
+def vpi(dev_id=0):
+    """Construct a VPI simulated device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+    """
+    return TVMContext(9, dev_id)
+
+
+def vulkan(dev_id=0):
+    """Construct a Vulkan device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+    """
+    return TVMContext(7, dev_id)
+
+
+def opengl(dev_id=0):
+    """Construct a OpenGL device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+    """
+    return TVMContext(11, dev_id)
+
+
+def ext_dev(dev_id=0):
+    """Construct a extension device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+
+    Note
+    ----
+    This API is reserved for quick testing of new
+    device by plugin device API as ext_dev.
+    """
+    return TVMContext(12, dev_id)
+
+
+def micro_dev(dev_id=0):
+    """Construct a micro device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+    """
+    return TVMContext(13, dev_id)
+
+
+cl = opencl
+mtl = metal
+
+
+def array(arr, ctx=cpu(0)):
+    """Create an array from source arr.
+
+    Parameters
+    ----------
+    arr : numpy.ndarray
+        The array to be copied from
+
+    ctx : TVMContext, optional
+        The device context to create the array
+
+    Returns
+    -------
+    ret : NDArray
+        The created array
+    """
+    if not isinstance(arr, (np.ndarray, NDArray)):
+        arr = np.array(arr)
+    return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
+
+# Register back to FFI
+_set_class_ndarray(NDArray)
similarity index 66%
rename from python/tvm/_ffi/object.py
rename to python/tvm/runtime/object.py
index a808580..6b4b77b 100644 (file)
 # pylint: disable=invalid-name, unused-import
 """Runtime Object API"""
 import ctypes
+
+from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
 from .. import _api_internal
-from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
 
 try:
     # pylint: disable=wrong-import-position,unused-import
     if _FFI_MODE == "ctypes":
         raise ImportError()
-    from ._cy3.core import _set_class_object, _set_class_object_generic
-    from ._cy3.core import ObjectBase
+    from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
+    from tvm._ffi._cy3.core import ObjectBase
 except (RuntimeError, ImportError):
     # pylint: disable=wrong-import-position,unused-import
-    from ._ctypes.packed_func import _set_class_object, _set_class_object_generic
-    from ._ctypes.object import ObjectBase
+    from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
+    from tvm._ffi._ctypes.object import ObjectBase
 
 
 def _new_object(cls):
@@ -91,44 +92,4 @@ class Object(ObjectBase):
         return self.__hash__() == other.__hash__()
 
 
-def getitem_helper(obj, elem_getter, length, idx):
-    """Helper function to implement a pythonic getitem function.
-
-    Parameters
-    ----------
-    obj: object
-        The original object
-
-    elem_getter : function
-        A simple function that takes index and return a single element.
-
-    length : int
-        The size of the array
-
-    idx : int or slice
-        The argument passed to getitem
-
-    Returns
-    -------
-    result : object
-        The result of getitem
-    """
-    if isinstance(idx, slice):
-        start = idx.start if idx.start is not None else 0
-        stop = idx.stop if idx.stop is not None else length
-        step = idx.step if idx.step is not None else 1
-        if start < 0:
-            start += length
-        if stop < 0:
-            stop += length
-        return [elem_getter(obj, i) for i in range(start, stop, step)]
-
-    if idx < -length or idx >= length:
-        raise IndexError("Index out of range. size: {}, got index {}"
-                         .format(length, idx))
-    if idx < 0:
-        idx += length
-    return elem_getter(obj, idx)
-
-
 _set_class_object(Object)
similarity index 94%
rename from python/tvm/_ffi/object_generic.py
rename to python/tvm/runtime/object_generic.py
index cbbca4d..499f1cb 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """Common implementation of object generic related logic"""
-# pylint: disable=unused-import
+# pylint: disable=unused-import, invalid-name
 from numbers import Number, Integral
-from .. import _api_internal
+from tvm._ffi.base import string_types
 
-from .base import string_types
+from .. import _api_internal
 from .object import ObjectBase, _set_class_object_generic
 from .ndarray import NDArrayBase
 from .packed_func import PackedFuncBase, convert_to_tvm_func
-from .module import ModuleBase
+from .module import Module
 
 
 class ObjectGeneric(object):
@@ -33,7 +33,7 @@ class ObjectGeneric(object):
         raise NotImplementedError()
 
 
-_CLASS_OBJECTS = (ObjectBase, NDArrayBase, ModuleBase)
+ObjectTypes = (ObjectBase, NDArrayBase, Module)
 
 
 def convert_to_object(value):
@@ -49,7 +49,7 @@ def convert_to_object(value):
     obj : Object
         The corresponding object value.
     """
-    if isinstance(value, _CLASS_OBJECTS):
+    if isinstance(value, ObjectTypes):
         return value
     if isinstance(value, bool):
         return const(value, 'uint1x1')
@@ -63,7 +63,7 @@ def convert_to_object(value):
     if isinstance(value, dict):
         vlist = []
         for item in value.items():
-            if (not isinstance(item[0], _CLASS_OBJECTS) and
+            if (not isinstance(item[0], ObjectTypes) and
                     not isinstance(item[0], string_types)):
                 raise ValueError("key of map must already been a container type")
             vlist.append(item[0])
similarity index 81%
rename from python/tvm/_ffi/packed_func.py
rename to python/tvm/runtime/packed_func.py
index d0917a8..05cdef4 100644 (file)
 # pylint: disable=invalid-name, unused-import
 """Packed Function namespace."""
 import ctypes
-from .base import _LIB, check_call, c_str, string_types, _FFI_MODE
+from tvm._ffi.base import _LIB, check_call, c_str, string_types, _FFI_MODE
 
 try:
     # pylint: disable=wrong-import-position
     if _FFI_MODE == "ctypes":
         raise ImportError()
-    from ._cy3.core import _set_class_packed_func, _set_class_module
-    from ._cy3.core import PackedFuncBase
-    from ._cy3.core import convert_to_tvm_func
+    from tvm._ffi._cy3.core import _set_class_packed_func, _set_class_module
+    from tvm._ffi._cy3.core import PackedFuncBase
+    from tvm._ffi._cy3.core import convert_to_tvm_func
 except (RuntimeError, ImportError):
     # pylint: disable=wrong-import-position
-    from ._ctypes.packed_func import _set_class_packed_func, _set_class_module
-    from ._ctypes.packed_func import PackedFuncBase
-    from ._ctypes.packed_func import convert_to_tvm_func
+    from tvm._ffi._ctypes.packed_func import _set_class_packed_func, _set_class_module
+    from tvm._ffi._ctypes.packed_func import PackedFuncBase
+    from tvm._ffi._ctypes.packed_func import convert_to_tvm_func
 
 
 PackedFuncHandle = ctypes.c_void_p
index bf4e75f..d779b59 100644 (file)
@@ -17,9 +17,8 @@
 """The computation schedule api of TVM."""
 import tvm._ffi
 
-from ._ffi.base import string_types
-from ._ffi.object import Object
-from ._ffi.object_generic import convert
+from tvm._ffi.base import string_types
+from tvm.runtime import Object, convert
 
 from . import _api_internal
 from . import tensor as _tensor
index 5908934..27bb4db 100644 (file)
@@ -30,7 +30,8 @@ Each statement node have subfields that can be visited from python side.
     assert(st.buffer_var == a)
 """
 import tvm._ffi
-from ._ffi.object import Object
+
+from tvm.runtime import Object
 from . import make as _make
 
 
index 45dbf5f..a505271 100644 (file)
@@ -57,8 +57,8 @@ We can also use other specific function in this module to create specific target
 import warnings
 import tvm._ffi
 
+from tvm.runtime import Object
 from ._ffi.base import _LIB_NAME
-from ._ffi.object import Object
 from . import _api_internal
 
 try:
index 522e901..525c2b1 100644 (file)
@@ -18,8 +18,7 @@
 # pylint: disable=invalid-name
 import tvm._ffi
 
-from ._ffi.object import Object
-from ._ffi.object_generic import ObjectGeneric, convert_to_object
+from tvm.runtime import Object, ObjectGeneric, convert_to_object
 
 from . import _api_internal
 from . import make as _make
@@ -129,7 +128,6 @@ class Tensor(Object, _expr.ExprOp):
         return "%s.v%d" % (op.name, self.value_index)
 
 
-
 class Operation(Object):
     """Represent an operation that generates a tensor"""
 
index 0b88af1..91f99b6 100644 (file)
@@ -17,6 +17,7 @@
 """Tensor intrinsics"""
 import tvm._ffi
 
+from tvm.runtime import Object
 from . import _api_internal
 from . import api as _api
 from . import expr as _expr
@@ -25,7 +26,6 @@ from . import make as _make
 from . import tensor as _tensor
 from . import schedule as _schedule
 from .build_module import current_build_config
-from ._ffi.object import Object
 
 
 def _get_region(tslice):
index 5425b19..7cc4a00 100644 (file)
@@ -16,7 +16,7 @@
 # under the License.
 import tvm
 import tvm.contrib.sparse as tvmsp
-import tvm.ndarray as _nd
+import tvm.runtime.ndarray as _nd
 import numpy as np
 from collections import namedtuple
 
index e9a24bf..a871ae1 100644 (file)
@@ -132,7 +132,7 @@ def test_comments():
 
 def test_int_literal():
     assert isinstance(parse_text("1"), relay.Constant)
-    assert isinstance(parse_text("1").data, tvm.ndarray.NDArray)
+    assert isinstance(parse_text("1").data, tvm.nd.NDArray)
 
     assert get_scalar(parse_text("1")) == 1
     assert get_scalar(parse_text("10")) == 10
index 5d05b6d..e4bad2f 100644 (file)
@@ -207,7 +207,7 @@ def test_cuda_shuffle():
         b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
         c_ = np.zeros((64, ), dtype='int32')
         ref = a_ +  np.array((list(range(4))) * 16, dtype='int32')
-        nda, ndb, ndc = [tvm.ndarray.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
+        nda, ndb, ndc = [tvm.nd.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
         module(nda, ndb, ndc)
         tvm.testing.assert_allclose(ndc.asnumpy(), ref)
 
index 42e8581..26a6f82 100644 (file)
@@ -657,9 +657,9 @@ def test_llvm_shuffle():
     with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
         ir = tvm.lower(sch, [a, b, c], simple_mode=True)
         module = tvm.build(sch, [a, b, c])
-        a_ = tvm.ndarray.array(np.arange(1, 9, dtype='int32'))
-        b_ = tvm.ndarray.array(np.arange(8, 0, -1, dtype='int32'))
-        c_ = tvm.ndarray.array(np.zeros((8, ), dtype='int32'))
+        a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
+        b_ = tvm.nd.array(np.arange(8, 0, -1, dtype='int32'))
+        c_ = tvm.nd.array(np.zeros((8, ), dtype='int32'))
         module(a_, b_, c_)
         tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
 
index 9eca902..87e5a26 100644 (file)
@@ -405,8 +405,8 @@ def test_math_intrin():
     func = tvm.build(sch, [a8, b8])
     assert func
     a = numpy.arange(2, 10).astype('float32')
-    tvm_a = tvm.ndarray.array(a)
-    tvm_b = tvm.ndarray.array(numpy.zeros((8, ), dtype='float32'))
+    tvm_a = tvm.nd.array(a)
+    tvm_b = tvm.nd.array(numpy.zeros((8, ), dtype='float32'))
     b = intrin_real(a)
     func(tvm_a, tvm_b)
     tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5)
@@ -423,8 +423,8 @@ def test_math_intrin():
     func = tvm.build(sch, [a1, b1])
     assert func
     a = numpy.array([114514]).astype('int32')
-    tvm_a = tvm.ndarray.array(a)
-    tvm_b = tvm.ndarray.array(numpy.array([0]).astype('int32'))
+    tvm_a = tvm.nd.array(a)
+    tvm_b = tvm.nd.array(numpy.array([0]).astype('int32'))
     b = intrin_int(a)
     func(tvm_a, tvm_b)
     assert tvm_b.asnumpy()[0] == b[0]
@@ -578,8 +578,8 @@ def test_const_param():
     np_b = 11
     np_c = numpy.zeros((11, )).astype('int32')
 
-    nd_a = tvm.ndarray.array(np_a)
-    nd_c = tvm.ndarray.array(numpy.zeros((11, )).astype('int32'))
+    nd_a = tvm.nd.array(np_a)
+    nd_c = tvm.nd.array(numpy.zeros((11, )).astype('int32'))
     module(nd_a, nd_c)
     ref = add_something(np_a, 11)
 
@@ -614,8 +614,8 @@ def test_value_index():
     np_b, np_c = kernel_a(np_a)
     ref = kernel_b(np_c, np_b)
 
-    res = tvm.ndarray.array(numpy.zeros((4, 4)).astype('int32'))
-    module(tvm.ndarray.array(np_a), res)
+    res = tvm.nd.array(numpy.zeros((4, 4)).astype('int32'))
+    module(tvm.nd.array(np_a), res)
     tvm.testing.assert_allclose(res.asnumpy(), ref)
 
 def test_func_call():
index bb646f4..76e5f0d 100644 (file)
@@ -28,7 +28,7 @@ def test_shared_memory():
         N = 1024
         M = 128
 
-        tvm_type = tvm.datatype._TVMType(dtype)
+        tvm_type = tvm.runtime.DataType(dtype)
         type_size = tvm_type.bits // 8 * tvm_type.lanes
 
         A = tvm.placeholder((N,), name='A', dtype=dtype)
index b102243..1030a99 100644 (file)
@@ -444,7 +444,7 @@ def test_reduction_and_dummy_fuse_split():
     axo, axi = s[Y.op].split(ax, nparts=20)
     f = tvm.build(s, [Y, X])
 
-    args = [tvm.nd.empty((), 'int32')] + [tvm.ndarray.array(np.ones((n,), dtype='int32'))]
+    args = [tvm.nd.empty((), 'int32')] + [tvm.nd.array(np.ones((n,), dtype='int32'))]
     f(*args)
     assert args[0].asnumpy() == n
 
@@ -456,8 +456,8 @@ def test_reduction_and_dummy_fuse_split():
     ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis)))
     f = tvm.build(s, [Y, X])
 
-    args = [tvm.ndarray.array(np.ones((n,), dtype='int32'))] + \
-        [tvm.ndarray.array(np.ones((n,), dtype='int32'))]
+    args = [tvm.nd.array(np.ones((n,), dtype='int32'))] + \
+        [tvm.nd.array(np.ones((n,), dtype='int32'))]
     f(*args)
     assert np.all(args[0].asnumpy() == n)
 
index 1b40b13..fc28859 100644 (file)
@@ -231,8 +231,8 @@ def test_sparse_dense_csr():
     Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
     s = tvm.create_schedule(Y.op)
     func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
-    Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
-    func(tvm.ndarray.array(X_np), tvm.ndarray.array(W_sp_np.data), tvm.ndarray.array(W_sp_np.indices), tvm.ndarray.array(W_sp_np.indptr), Y_tvm)
+    Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
+    func(tvm.nd.array(X_np), tvm.nd.array(W_sp_np.data), tvm.nd.array(W_sp_np.indices), tvm.nd.array(W_sp_np.indptr), Y_tvm)
     tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
 
 def test_sparse_transpose_csr():
@@ -246,17 +246,17 @@ def test_sparse_transpose_csr():
     X_data = tvm.placeholder(shape=X_sp.data.shape, dtype=str(X_sp.data.dtype))
     X_indices = tvm.placeholder(shape=X_sp.indices.shape, dtype=str(X_sp.indices.dtype))
     X_indptr = tvm.placeholder(shape=X_sp.indptr.shape, dtype=str(X_sp.indptr.dtype))
-    
+
     X_T_data, X_T_indices, X_T_indptr = topi.nn.sparse_transpose(X_data, X_indices, X_indptr)
     s = tvm.create_schedule([X_T_data.op, X_T_indices.op, X_T_indptr.op])
     func = tvm.build(s, [X_data, X_indices, X_indptr, X_T_data, X_T_indices, X_T_indptr])
 
 
-    X_T_data_tvm = tvm.ndarray.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
-    X_T_indices_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
-    X_T_indptr_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
+    X_T_data_tvm = tvm.nd.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
+    X_T_indices_tvm = tvm.nd.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
+    X_T_indptr_tvm = tvm.nd.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
 
-    func(tvm.ndarray.array(X_sp.data), tvm.ndarray.array(X_sp.indices), tvm.ndarray.array(X_sp.indptr),
+    func(tvm.nd.array(X_sp.data), tvm.nd.array(X_sp.indices), tvm.nd.array(X_sp.indptr),
         X_T_data_tvm,  X_T_indices_tvm, X_T_indptr_tvm)
 
     X_T_out = sp.csr_matrix((X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()), shape=(N,N)).todense()
@@ -295,11 +295,11 @@ def test_sparse_dense_bsr():
     Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
     s = tvm.create_schedule(Y.op)
     func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
-    Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
-    func(tvm.ndarray.array(X_np),
-         tvm.ndarray.array(W_sp_np.data),
-         tvm.ndarray.array(W_sp_np.indices),
-         tvm.ndarray.array(W_sp_np.indptr),
+    Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
+    func(tvm.nd.array(X_np),
+         tvm.nd.array(W_sp_np.data),
+         tvm.nd.array(W_sp_np.indices),
+         tvm.nd.array(W_sp_np.indptr),
          Y_tvm)
     tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
 
@@ -324,11 +324,11 @@ def test_sparse_dense_bsr_randomized():
         Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
         s = tvm.create_schedule(Y.op)
         func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
-        Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
-        func(tvm.ndarray.array(X_np),
-             tvm.ndarray.array(W_sp_np.data),
-             tvm.ndarray.array(W_sp_np.indices),
-             tvm.ndarray.array(W_sp_np.indptr),
+        Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
+        func(tvm.nd.array(X_np),
+             tvm.nd.array(W_sp_np.data),
+             tvm.nd.array(W_sp_np.indices),
+             tvm.nd.array(W_sp_np.indptr),
              Y_tvm)
         tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)