[REFACTOR][PY] tvm._ffi (#4813)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 5 Feb 2020 01:01:01 +0000 (17:01 -0800)
committerGitHub <noreply@github.com>
Wed, 5 Feb 2020 01:01:01 +0000 (17:01 -0800)
* [REFACTOR][PY] tvm._ffi

- Remove from __future__ import absolute_import in the related files as they are no longer needed if the code only runs in python3
- Remove reverse dependency of _ctypes _cython to object_generic.
- function.py -> packed_func.py
- Function -> PackedFunc
- all registry related logics goes to tvm._ffi.registry
- Use absolute references for FFI related calls.
  - tvm._ffi.register_object
  - tvm._ffi.register_func
  - tvm._ffi.get_global_func

* Move get global func to the ffi side

81 files changed:
python/tvm/__init__.py
python/tvm/_ffi/__init__.py
python/tvm/_ffi/_ctypes/ndarray.py
python/tvm/_ffi/_ctypes/object.py
python/tvm/_ffi/_ctypes/packed_func.py [moved from python/tvm/_ffi/_ctypes/function.py with 85% similarity]
python/tvm/_ffi/_ctypes/types.py
python/tvm/_ffi/_cython/base.pxi
python/tvm/_ffi/_cython/core.pyx
python/tvm/_ffi/_cython/object.pxi
python/tvm/_ffi/_cython/packed_func.pxi [moved from python/tvm/_ffi/_cython/function.pxi with 87% similarity]
python/tvm/_ffi/_pyversion.py [moved from python/tvm/_pyversion.py with 93% similarity]
python/tvm/_ffi/base.py
python/tvm/_ffi/libinfo.py
python/tvm/_ffi/module.py [new file with mode: 0644]
python/tvm/_ffi/ndarray.py
python/tvm/_ffi/object.py
python/tvm/_ffi/object_generic.py
python/tvm/_ffi/packed_func.py [new file with mode: 0644]
python/tvm/_ffi/registry.py [moved from python/tvm/_ffi/function.py with 53% similarity]
python/tvm/_ffi/runtime_ctypes.py
python/tvm/api.py
python/tvm/arith.py
python/tvm/attrs.py
python/tvm/build_module.py
python/tvm/codegen.py
python/tvm/container.py
python/tvm/contrib/debugger/debug_runtime.py
python/tvm/contrib/graph_runtime.py
python/tvm/contrib/nnpack.py
python/tvm/contrib/random.py
python/tvm/contrib/tflite_runtime.py
python/tvm/datatype.py
python/tvm/expr.py
python/tvm/hybrid/__init__.py
python/tvm/intrin.py
python/tvm/ir_builder.py
python/tvm/ir_pass.py
python/tvm/make.py
python/tvm/micro/base.py
python/tvm/module.py
python/tvm/ndarray.py
python/tvm/object.py [deleted file]
python/tvm/relay/_analysis.py
python/tvm/relay/_base.py
python/tvm/relay/_build_module.py
python/tvm/relay/_expr.py
python/tvm/relay/_make.py
python/tvm/relay/_module.py
python/tvm/relay/_transform.py
python/tvm/relay/backend/_backend.py
python/tvm/relay/backend/_vm.py
python/tvm/relay/backend/vm.py
python/tvm/relay/base.py
python/tvm/relay/op/_make.py
python/tvm/relay/op/annotation/_make.py
python/tvm/relay/op/contrib/_make.py
python/tvm/relay/op/image/_make.py
python/tvm/relay/op/memory/_make.py
python/tvm/relay/op/nn/_make.py
python/tvm/relay/op/op.py
python/tvm/relay/op/vision/_make.py
python/tvm/relay/qnn/op/_make.py
python/tvm/relay/quantize/_annotate.py
python/tvm/relay/quantize/_quantize.py
python/tvm/rpc/base.py
python/tvm/rpc/client.py
python/tvm/rpc/server.py
python/tvm/schedule.py
python/tvm/stmt.py
python/tvm/target.py
python/tvm/tensor.py
python/tvm/tensor_intrin.py
topi/python/topi/cpp/cuda.py
topi/python/topi/cpp/generic.py
topi/python/topi/cpp/impl.py
topi/python/topi/cpp/nn.py
topi/python/topi/cpp/rocm.py
topi/python/topi/cpp/util.py
topi/python/topi/cpp/vision/__init__.py
topi/python/topi/cpp/vision/yolo.py
topi/python/topi/cpp/x86.py

index b2a4ca3..ed42599 100644 (file)
 # under the License.
 # pylint: disable=redefined-builtin, wildcard-import
 """TVM: Low level DSL/IR stack for tensor computation."""
-from __future__ import absolute_import as _abs
-
 import multiprocessing
 import sys
 import traceback
 
-from . import _pyversion
+# import ffi related features
+from ._ffi.base import TVMError, __version__
+from ._ffi.runtime_ctypes import TypeCode, TVMType
+from ._ffi.ndarray import TVMContext
+from ._ffi.packed_func import PackedFunc as Function
+from ._ffi.registry import register_object, register_func, register_extension
+from ._ffi.object import Object
 
 from . import tensor
 from . import arith
@@ -34,7 +38,6 @@ from . import codegen
 from . import container
 from . import schedule
 from . import module
-from . import object
 from . import attrs
 from . import ir_builder
 from . import target
@@ -48,15 +51,9 @@ from . import ndarray as nd
 from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
 from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
 
-from ._ffi.runtime_ctypes import TypeCode, TVMType
-from ._ffi.ndarray import TVMContext
-from ._ffi.function import Function
-from ._ffi.base import TVMError, __version__
 from .api import *
 from .intrin import *
 from .tensor_intrin import decl_tensor_intrin
-from .object import register_object
-from .ndarray import register_extension
 from .schedule import create_schedule
 from .build_module import build, lower, build_config
 from .tag import tag_scope
index f19851c..1b2fc58 100644 (file)
@@ -24,3 +24,7 @@ be used via ctypes function calls.
 Some performance critical functions are implemented by cython
 and have a ctypes fallback implementation.
 """
+from . import _pyversion
+from .base import register_error
+from .registry import register_object, register_func, register_extension
+from .registry import _init_api, get_global_func
index c572947..949cc8b 100644 (file)
@@ -16,8 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name
 """Runtime NDArray api"""
-from __future__ import absolute_import
-
 import ctypes
 from ..base import _LIB, check_call, c_str
 from ..runtime_ctypes import TVMArrayHandle
index 8a2fb1b..907b7dd 100644 (file)
@@ -16,8 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name
 """Runtime Object api"""
-from __future__ import absolute_import
-
 import ctypes
 from ..base import _LIB, check_call
 from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
similarity index 85%
rename from python/tvm/_ffi/_ctypes/function.py
rename to python/tvm/_ffi/_ctypes/packed_func.py
index ee3dead..5eaa738 100644 (file)
 # coding: utf-8
 # pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import
 """Function configuration API."""
-from __future__ import absolute_import
-
 import ctypes
 import traceback
 from numbers import Number, Integral
 
-from ..base import _LIB, get_last_ffi_error, py2cerror
+from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
 from ..base import c_str, string_types
-from ..object_generic import convert_to_object, ObjectGeneric
 from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
 from . import ndarray as _nd
 from .ndarray import NDArrayBase, _make_array
@@ -35,7 +32,7 @@ from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_in
 from .object import ObjectBase, _set_class_object
 from . import object as _object
 
-FunctionHandle = ctypes.c_void_p
+PackedFuncHandle = ctypes.c_void_p
 ModuleHandle = ctypes.c_void_p
 ObjectHandle = ctypes.c_void_p
 TVMRetValueHandle = ctypes.c_void_p
@@ -49,6 +46,15 @@ def _ctypes_free_resource(rhandle):
 TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
 ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
 
+
+def _make_packed_func(handle, is_global):
+    """Make a packed function class"""
+    obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
+    obj.is_global = is_global
+    obj.handle = handle
+    return obj
+
+
 def convert_to_tvm_func(pyfunc):
     """Convert a python function to TVM function
 
@@ -89,7 +95,7 @@ def convert_to_tvm_func(pyfunc):
             _ = rv
         return 0
 
-    handle = FunctionHandle()
+    handle = PackedFuncHandle()
     f = TVMPackedCFunc(cfun)
     # NOTE: We will need to use python-api to increase ref count of the f
     # TVM_FREE_PYOBJ will be called after it is no longer needed.
@@ -98,7 +104,7 @@ def convert_to_tvm_func(pyfunc):
     if _LIB.TVMFuncCreateFromCFunc(
             f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
         raise get_last_ffi_error()
-    return _CLASS_FUNCTION(handle, False)
+    return _make_packed_func(handle, False)
 
 
 def _make_tvm_args(args, temp_args):
@@ -144,15 +150,15 @@ def _make_tvm_args(args, temp_args):
         elif isinstance(arg, string_types):
             values[i].v_str = c_str(arg)
             type_codes[i] = TypeCode.STR
-        elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
-            arg = convert_to_object(arg)
+        elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
+            arg = _FUNC_CONVERT_TO_OBJECT(arg)
             values[i].v_handle = arg.handle
             type_codes[i] = TypeCode.OBJECT_HANDLE
             temp_args.append(arg)
         elif isinstance(arg, _CLASS_MODULE):
             values[i].v_handle = arg.handle
             type_codes[i] = TypeCode.MODULE_HANDLE
-        elif isinstance(arg, FunctionBase):
+        elif isinstance(arg, PackedFuncBase):
             values[i].v_handle = arg.handle
             type_codes[i] = TypeCode.PACKED_FUNC_HANDLE
         elif isinstance(arg, ctypes.c_void_p):
@@ -168,7 +174,7 @@ def _make_tvm_args(args, temp_args):
     return values, type_codes, num_args
 
 
-class FunctionBase(object):
+class PackedFuncBase(object):
     """Function base."""
     __slots__ = ["handle", "is_global"]
     # pylint: disable=no-member
@@ -177,7 +183,7 @@ class FunctionBase(object):
 
         Parameters
         ----------
-        handle : FunctionHandle
+        handle : PackedFuncHandle
             the handle to the underlying function.
 
         is_global : bool
@@ -238,9 +244,22 @@ def _return_module(x):
 def _handle_return_func(x):
     """Return function"""
     handle = x.v_handle
-    if not isinstance(handle, FunctionHandle):
-        handle = FunctionHandle(handle)
-    return _CLASS_FUNCTION(handle, False)
+    if not isinstance(handle, PackedFuncHandle):
+        handle = PackedFuncHandle(handle)
+    return _CLASS_PACKED_FUNC(handle, False)
+
+
+def _get_global_func(name, allow_missing=False):
+    handle = PackedFuncHandle()
+    check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
+
+    if handle.value:
+        return _make_packed_func(handle, False)
+
+    if allow_missing:
+        return None
+
+    raise ValueError("Cannot find global function %s" % name)
 
 # setup return handle for function type
 _object.__init_by_constructor__ = __init_handle_by_constructor__
@@ -255,13 +274,22 @@ C_TO_PY_ARG_SWITCH[TypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle,
 C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
 
 _CLASS_MODULE = None
-_CLASS_FUNCTION = None
+_CLASS_PACKED_FUNC = None
+_CLASS_OBJECT_GENERIC = None
+_FUNC_CONVERT_TO_OBJECT = None
+
 
 def _set_class_module(module_class):
     """Initialize the module."""
     global _CLASS_MODULE
     _CLASS_MODULE = module_class
 
-def _set_class_function(func_class):
-    global _CLASS_FUNCTION
-    _CLASS_FUNCTION = func_class
+def _set_class_packed_func(packed_func_class):
+    global _CLASS_PACKED_FUNC
+    _CLASS_PACKED_FUNC = packed_func_class
+
+def _set_class_object_generic(object_generic_class, func_convert_to_object):
+    global _CLASS_OBJECT_GENERIC
+    global _FUNC_CONVERT_TO_OBJECT
+    _CLASS_OBJECT_GENERIC = object_generic_class
+    _FUNC_CONVERT_TO_OBJECT = func_convert_to_object
index 31c4786..f45748f 100644 (file)
@@ -16,8 +16,6 @@
 # under the License.
 """The C Types used in API."""
 # pylint: disable=invalid-name
-from __future__ import absolute_import as _abs
-
 import ctypes
 import struct
 from ..base import py_str, check_call, _LIB
index 420ec62..ad281d7 100644 (file)
@@ -75,7 +75,7 @@ ctypedef int64_t tvm_index_t
 ctypedef DLTensor* DLTensorHandle
 ctypedef void* TVMStreamHandle
 ctypedef void* TVMRetValueHandle
-ctypedef void* TVMFunctionHandle
+ctypedef void* TVMPackedFuncHandle
 ctypedef void* ObjectHandle
 
 ctypedef struct TVMObject:
@@ -96,13 +96,15 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
 cdef extern from "tvm/runtime/c_runtime_api.h":
     void TVMAPISetLastError(const char* msg)
     const char *TVMGetLastError()
-    int TVMFuncCall(TVMFunctionHandle func,
+    int TVMFuncGetGlobal(const char* name,
+                         TVMPackedFuncHandle* out);
+    int TVMFuncCall(TVMPackedFuncHandle func,
                     TVMValue* arg_values,
                     int* type_codes,
                     int num_args,
                     TVMValue* ret_val,
                     int* ret_type_code)
-    int TVMFuncFree(TVMFunctionHandle func)
+    int TVMFuncFree(TVMPackedFuncHandle func)
     int TVMCFuncSetReturn(TVMRetValueHandle ret,
                           TVMValue* value,
                           int* type_code,
@@ -110,7 +112,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
     int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
                                void* resource_handle,
                                TVMPackedCFuncFinalizer fin,
-                               TVMFunctionHandle *out)
+                               TVMPackedFuncHandle *out)
     int TVMCbArgToReturn(TVMValue* value, int code)
     int TVMArrayAlloc(tvm_index_t* shape,
                       tvm_index_t ndim,
index cbf9d58..730f8fc 100644 (file)
@@ -17,7 +17,5 @@
 
 include "./base.pxi"
 include "./object.pxi"
-# include "./node.pxi"
-include "./function.pxi"
+include "./packed_func.pxi"
 include "./ndarray.pxi"
-
index 494c3ff..25a9c3f 100644 (file)
@@ -96,6 +96,6 @@ cdef class ObjectBase:
         self.chandle = NULL
         cdef void* chandle
         ConstructorCall(
-            (<FunctionBase>fconstructor).chandle,
+            (<PackedFuncBase>fconstructor).chandle,
             kTVMObjectHandle, args, &chandle)
         self.chandle = chandle
similarity index 87%
rename from python/tvm/_ffi/_cython/function.pxi
rename to python/tvm/_ffi/_cython/packed_func.pxi
index bde672f..5630d72 100644 (file)
@@ -20,7 +20,6 @@ import traceback
 from cpython cimport Py_INCREF, Py_DECREF
 from numbers import Number, Integral
 from ..base import string_types, py2cerror
-from ..object_generic import convert_to_object, ObjectGeneric
 from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
 
 
@@ -67,6 +66,13 @@ cdef int tvm_callback(TVMValue* args,
     return 0
 
 
+cdef object make_packed_func(TVMPackedFuncHandle chandle, int is_global):
+    obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
+    (<PackedFuncBase>obj).chandle = chandle
+    (<PackedFuncBase>obj).is_global = is_global
+    return obj
+
+
 def convert_to_tvm_func(object pyfunc):
     """Convert a python function to TVM function
 
@@ -80,15 +86,13 @@ def convert_to_tvm_func(object pyfunc):
     tvmfunc: tvm.Function
         The converted tvm function.
     """
-    cdef TVMFunctionHandle chandle
+    cdef TVMPackedFuncHandle chandle
     Py_INCREF(pyfunc)
     CALL(TVMFuncCreateFromCFunc(tvm_callback,
                                 <void*>(pyfunc),
                                 tvm_callback_finalize,
                                 &chandle))
-    ret = _CLASS_FUNCTION(None, False)
-    (<FunctionBase>ret).chandle = chandle
-    return ret
+    return make_packed_func(chandle, False)
 
 
 cdef inline int make_arg(object arg,
@@ -149,29 +153,30 @@ cdef inline int make_arg(object arg,
         value[0].v_str = tstr
         tcode[0] = kTVMStr
         temp_args.append(tstr)
-    elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
-        arg = convert_to_object(arg)
+    elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
+        arg = _FUNC_CONVERT_TO_OBJECT(arg)
         value[0].v_handle = (<ObjectBase>arg).chandle
         tcode[0] = kTVMObjectHandle
         temp_args.append(arg)
     elif isinstance(arg, _CLASS_MODULE):
         value[0].v_handle = c_handle(arg.handle)
         tcode[0] = kTVMModuleHandle
-    elif isinstance(arg, FunctionBase):
-        value[0].v_handle = (<FunctionBase>arg).chandle
+    elif isinstance(arg, PackedFuncBase):
+        value[0].v_handle = (<PackedFuncBase>arg).chandle
         tcode[0] = kTVMPackedFuncHandle
     elif isinstance(arg, ctypes.c_void_p):
         value[0].v_handle = c_handle(arg)
         tcode[0] = kTVMOpaqueHandle
     elif callable(arg):
         arg = convert_to_tvm_func(arg)
-        value[0].v_handle = (<FunctionBase>arg).chandle
+        value[0].v_handle = (<PackedFuncBase>arg).chandle
         tcode[0] = kTVMPackedFuncHandle
         temp_args.append(arg)
     else:
         raise TypeError("Don't know how to handle type %s" % type(arg))
     return 0
 
+
 cdef inline bytearray make_ret_bytes(void* chandle):
     handle = ctypes_handle(chandle)
     arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
@@ -182,6 +187,7 @@ cdef inline bytearray make_ret_bytes(void* chandle):
         raise RuntimeError('memmove failed')
     return res
 
+
 cdef inline object make_ret(TVMValue value, int tcode):
     """convert result to return value."""
     if tcode == kTVMObjectHandle:
@@ -205,9 +211,7 @@ cdef inline object make_ret(TVMValue value, int tcode):
     elif tcode == kTVMModuleHandle:
         return _CLASS_MODULE(ctypes_handle(value.v_handle))
     elif tcode == kTVMPackedFuncHandle:
-        fobj = _CLASS_FUNCTION(None, False)
-        (<FunctionBase>fobj).chandle = value.v_handle
-        return fobj
+        return make_packed_func(value.v_handle, False)
     elif tcode in _TVM_EXT_RET:
         return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
 
@@ -264,8 +268,8 @@ cdef inline int ConstructorCall(void* constructor_handle,
     return 0
 
 
-cdef class FunctionBase:
-    cdef TVMFunctionHandle chandle
+cdef class PackedFuncBase:
+    cdef TVMPackedFuncHandle chandle
     cdef int is_global
 
     cdef inline _set_handle(self, handle):
@@ -305,19 +309,39 @@ cdef class FunctionBase:
         return make_ret(ret_val, ret_tcode)
 
 
-_CLASS_FUNCTION = None
+def _get_global_func(name, allow_missing):
+    cdef TVMPackedFuncHandle chandle
+    CALL(TVMFuncGetGlobal(c_str(name), &chandle))
+    if chandle != NULL:
+        return make_packed_func(chandle, True)
+
+    if allow_missing:
+       return None
+
+    raise ValueError("Cannot find global function %s" % name)
+
+
+_CLASS_PACKED_FUNC = None
 _CLASS_MODULE = None
 _CLASS_OBJECT = None
+_CLASS_OBJECT_GENERIC = None
+_FUNC_CONVERT_TO_OBJECT = None
 
 def _set_class_module(module_class):
     """Initialize the module."""
     global _CLASS_MODULE
     _CLASS_MODULE = module_class
 
-def _set_class_function(func_class):
-    global _CLASS_FUNCTION
-    _CLASS_FUNCTION = func_class
+def _set_class_packed_func(func_class):
+    global _CLASS_PACKED_FUNC
+    _CLASS_PACKED_FUNC = func_class
 
 def _set_class_object(obj_class):
     global _CLASS_OBJECT
     _CLASS_OBJECT = obj_class
+
+def _set_class_object_generic(object_generic_class, func_convert_to_object):
+    global _CLASS_OBJECT_GENERIC
+    global _FUNC_CONVERT_TO_OBJECT
+    _CLASS_OBJECT_GENERIC = object_generic_class
+    _FUNC_CONVERT_TO_OBJECT = func_convert_to_object
similarity index 93%
rename from python/tvm/_pyversion.py
rename to python/tvm/_ffi/_pyversion.py
index a46b220..67591b3 100644 (file)
@@ -18,6 +18,9 @@
 """
 import sys
 
+#----------------------------
+# Python3 version.
+#----------------------------
 if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 5):
     PY3STATEMENT = """TVM project proudly dropped support of Python2.
     The minimal Python requirement is Python 3.5
index 36effa3..ddc942a 100644 (file)
@@ -17,8 +17,6 @@
 # coding: utf-8
 # pylint: disable=invalid-name
 """Base library for TVM FFI."""
-from __future__ import absolute_import
-
 import sys
 import os
 import ctypes
@@ -28,27 +26,22 @@ from . import libinfo
 #----------------------------
 # library loading
 #----------------------------
-if sys.version_info[0] == 3:
-    string_types = (str,)
-    integer_types = (int, np.int32)
-    numeric_types = integer_types + (float, np.float32)
-    # this function is needed for python3
-    # to convert ctypes.char_p .value back to python str
-    if sys.platform == "win32":
-        def _py_str(x):
-            try:
-                return x.decode('utf-8')
-            except UnicodeDecodeError:
-                encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP())
-                return x.decode(encoding)
-        py_str = _py_str
-    else:
-        py_str = lambda x: x.decode('utf-8')
+string_types = (str,)
+integer_types = (int, np.int32)
+numeric_types = integer_types + (float, np.float32)
+
+# this function is needed for python3
+# to convert ctypes.char_p .value back to python str
+if sys.platform == "win32":
+    def _py_str(x):
+        try:
+            return x.decode('utf-8')
+        except UnicodeDecodeError:
+            encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP())
+        return x.decode(encoding)
+    py_str = _py_str
 else:
-    string_types = (basestring,)
-    integer_types = (int, long, np.int32)
-    numeric_types = integer_types + (float, np.float32)
-    py_str = lambda x: x
+    py_str = lambda x: x.decode('utf-8')
 
 
 def _load_lib():
index e87a336..c026a7a 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Library information."""
-from __future__ import absolute_import
 import sys
 import os
 
@@ -39,6 +38,7 @@ def split_env_var(env_var, split):
         return [p.strip() for p in os.environ[env_var].split(split)]
     return []
 
+
 def find_lib_path(name=None, search_path=None, optional=False):
     """Find dynamic library files.
 
diff --git a/python/tvm/_ffi/module.py b/python/tvm/_ffi/module.py
new file mode 100644 (file)
index 0000000..d6c81b3
--- /dev/null
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, unused-import
+"""Runtime Module namespace."""
+import ctypes
+from .base import _LIB, check_call, c_str, string_types
+from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
+
+class ModuleBase(object):
+    """Base class for module"""
+    __slots__ = ["handle", "_entry", "entry_name"]
+
+    def __init__(self, handle):
+        self.handle = handle
+        self._entry = None
+        self.entry_name = "__tvm_main__"
+
+    def __del__(self):
+        check_call(_LIB.TVMModFree(self.handle))
+
+    def __hash__(self):
+        return ctypes.cast(self.handle, ctypes.c_void_p).value
+
+    @property
+    def entry_func(self):
+        """Get the entry function
+
+        Returns
+        -------
+        f : Function
+            The entry function if exist
+        """
+        if self._entry:
+            return self._entry
+        self._entry = self.get_function(self.entry_name)
+        return self._entry
+
+    def get_function(self, name, query_imports=False):
+        """Get function from the module.
+
+        Parameters
+        ----------
+        name : str
+            The name of the function
+
+        query_imports : bool
+            Whether also query modules imported by this module.
+
+        Returns
+        -------
+        f : Function
+            The result function.
+        """
+        ret_handle = PackedFuncHandle()
+        check_call(_LIB.TVMModGetFunction(
+            self.handle, c_str(name),
+            ctypes.c_int(query_imports),
+            ctypes.byref(ret_handle)))
+        if not ret_handle.value:
+            raise AttributeError(
+                "Module has no function '%s'" %  name)
+        return PackedFunc(ret_handle, False)
+
+    def import_module(self, module):
+        """Add module to the import list of current one.
+
+        Parameters
+        ----------
+        module : Module
+            The other module.
+        """
+        check_call(_LIB.TVMModImport(self.handle, module.handle))
+
+    def __getitem__(self, name):
+        if not isinstance(name, string_types):
+            raise ValueError("Can only take string as function name")
+        return self.get_function(name)
+
+    def __call__(self, *args):
+        if self._entry:
+            return self._entry(*args)
+        f = self.entry_func
+        return f(*args)
index 650f01d..f526195 100644 (file)
 # under the License.
 # pylint: disable=invalid-name, unused-import
 """Runtime NDArray api"""
-from __future__ import absolute_import
-
-import sys
 import ctypes
 import numpy as np
 from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
 from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
 from .runtime_ctypes import TypeCode, tvm_shape_index_t
 
-
-IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
-
 try:
     # pylint: disable=wrong-import-position
     if _FFI_MODE == "ctypes":
         raise ImportError()
-    if sys.version_info >= (3, 0):
-        from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
-        from ._cy3.core import NDArrayBase as _NDArrayBase
-        from ._cy3.core import _reg_extension
-    else:
-        from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
-        from ._cy2.core import NDArrayBase as _NDArrayBase
-        from ._cy2.core import _reg_extension
-except IMPORT_EXCEPT:
+    from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
+    from ._cy3.core import NDArrayBase as _NDArrayBase
+except (RuntimeError, ImportError):
     # pylint: disable=wrong-import-position
     from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
     from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
-    from ._ctypes.ndarray import _reg_extension
 
 
 def context(dev_type, dev_id=0):
@@ -297,59 +284,3 @@ class NDArrayBase(_NDArrayBase):
             res = empty(self.shape, self.dtype, target)
             return self._copyto(res)
         raise ValueError("Unsupported target type %s" % str(type(target)))
-
-
-def register_extension(cls, fcreate=None):
-    """Register a extension class to TVM.
-
-    After the class is registered, the class will be able
-    to directly pass as Function argument generated by TVM.
-
-    Parameters
-    ----------
-    cls : class
-        The class object to be registered as extension.
-
-    fcreate : function, optional
-        The creation function to create a class object given handle value.
-
-    Note
-    ----
-    The registered class is requires one property: _tvm_handle.
-
-    If the registered class is a subclass of NDArray,
-    it is required to have a class attribute _array_type_code.
-    Otherwise, it is required to have a class attribute _tvm_tcode.
-
-    - ```_tvm_handle``` returns integer represents the address of the handle.
-    - ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
-      code of the class.
-
-    Returns
-    -------
-    cls : class
-        The class being registered.
-
-    Example
-    -------
-    The following code registers user defined class
-    MyTensor to be DLTensor compatible.
-
-    .. code-block:: python
-
-       @tvm.register_extension
-       class MyTensor(object):
-           _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
-
-           def __init__(self):
-               self.handle = _LIB.NewDLTensor()
-
-           @property
-           def _tvm_handle(self):
-               return self.handle.value
-    """
-    assert hasattr(cls, "_tvm_tcode")
-    if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
-        raise ValueError("Cannot register create when extension tcode is same as buildin")
-    _reg_extension(cls, fcreate)
-    return cls
index 83d4129..a808580 100644 (file)
 # under the License.
 # pylint: disable=invalid-name, unused-import
 """Runtime Object API"""
-from __future__ import absolute_import
-
-import sys
 import ctypes
 from .. import _api_internal
 from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
-from .object_generic import ObjectGeneric, convert_to_object, const
-
-IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
 
 try:
     # pylint: disable=wrong-import-position,unused-import
     if _FFI_MODE == "ctypes":
         raise ImportError()
-    if sys.version_info >= (3, 0):
-        from ._cy3.core import _set_class_object
-        from ._cy3.core import ObjectBase as _ObjectBase
-        from ._cy3.core import _register_object
-    else:
-        from ._cy2.core import _set_class_object
-        from ._cy2.core import ObjectBase as _ObjectBase
-        from ._cy2.core import _register_object
-except IMPORT_EXCEPT:
+    from ._cy3.core import _set_class_object, _set_class_object_generic
+    from ._cy3.core import ObjectBase
+except (RuntimeError, ImportError):
     # pylint: disable=wrong-import-position,unused-import
-    from ._ctypes.function import _set_class_object
-    from ._ctypes.object import ObjectBase as _ObjectBase
-    from ._ctypes.object import _register_object
+    from ._ctypes.packed_func import _set_class_object, _set_class_object_generic
+    from ._ctypes.object import ObjectBase
 
 
 def _new_object(cls):
@@ -50,7 +37,7 @@ def _new_object(cls):
     return cls.__new__(cls)
 
 
-class Object(_ObjectBase):
+class Object(ObjectBase):
     """Base class for all tvm's runtime objects."""
     def __repr__(self):
         return _api_internal._format_str(self)
@@ -104,52 +91,6 @@ class Object(_ObjectBase):
         return self.__hash__() == other.__hash__()
 
 
-def register_object(type_key=None):
-    """register object type.
-
-    Parameters
-    ----------
-    type_key : str or cls
-        The type key of the node
-
-    Examples
-    --------
-    The following code registers MyObject
-    using type key "test.MyObject"
-
-    .. code-block:: python
-
-      @tvm.register_object("test.MyObject")
-      class MyObject(Object):
-          pass
-    """
-    object_name = type_key if isinstance(type_key, str) else type_key.__name__
-
-    def register(cls):
-        """internal register function"""
-        if hasattr(cls, "_type_index"):
-            tindex = cls._type_index
-        else:
-            tidx = ctypes.c_uint()
-            if not _RUNTIME_ONLY:
-                check_call(_LIB.TVMObjectTypeKey2Index(
-                    c_str(object_name), ctypes.byref(tidx)))
-            else:
-                # directly skip unknown objects during runtime.
-                ret = _LIB.TVMObjectTypeKey2Index(
-                    c_str(object_name), ctypes.byref(tidx))
-                if ret != 0:
-                    return cls
-            tindex = tidx.value
-        _register_object(tindex, cls)
-        return cls
-
-    if isinstance(type_key, str):
-        return register
-
-    return register(type_key)
-
-
 def getitem_helper(obj, elem_getter, length, idx):
     """Helper function to implement a pythonic getitem function.
 
index 92e73ad..cbbca4d 100644 (file)
 # under the License.
 """Common implementation of object generic related logic"""
 # pylint: disable=unused-import
-from __future__ import absolute_import
-
 from numbers import Number, Integral
 from .. import _api_internal
-from .base import string_types
-
-# Object base class
-_CLASS_OBJECTS = None
-
-def _set_class_objects(cls):
-    global _CLASS_OBJECTS
-    _CLASS_OBJECTS = cls
 
-
-def _scalar_type_inference(value):
-    if hasattr(value, 'dtype'):
-        dtype = str(value.dtype)
-    elif isinstance(value, bool):
-        dtype = 'bool'
-    elif isinstance(value, float):
-        # We intentionally convert the float to float32 since it's more common in DL.
-        dtype = 'float32'
-    elif isinstance(value, int):
-        # We intentionally convert the python int to int32 since it's more common in DL.
-        dtype = 'int32'
-    else:
-        raise NotImplementedError('Cannot automatically inference the type.'
-                                  ' value={}'.format(value))
-    return dtype
+from .base import string_types
+from .object import ObjectBase, _set_class_object_generic
+from .ndarray import NDArrayBase
+from .packed_func import PackedFuncBase, convert_to_tvm_func
+from .module import ModuleBase
 
 
 class ObjectGeneric(object):
@@ -54,6 +33,9 @@ class ObjectGeneric(object):
         raise NotImplementedError()
 
 
+_CLASS_OBJECTS = (ObjectBase, NDArrayBase, ModuleBase)
+
+
 def convert_to_object(value):
     """Convert a python value to corresponding object type.
 
@@ -95,22 +77,65 @@ def convert_to_object(value):
     raise ValueError("don't know how to convert type %s to object" % type(value))
 
 
+def convert(value):
+    """Convert value to TVM object or function.
+
+    Parameters
+    ----------
+    value : python value
+
+    Returns
+    -------
+    tvm_val : Object or Function
+        Converted value in TVM
+    """
+    if isinstance(value, (PackedFuncBase, ObjectBase)):
+        return value
+
+    if callable(value):
+        return convert_to_tvm_func(value)
+
+    return convert_to_object(value)
+
+
+def _scalar_type_inference(value):
+    if hasattr(value, 'dtype'):
+        dtype = str(value.dtype)
+    elif isinstance(value, bool):
+        dtype = 'bool'
+    elif isinstance(value, float):
+        # We intentionally convert the float to float32 since it's more common in DL.
+        dtype = 'float32'
+    elif isinstance(value, int):
+        # We intentionally convert the python int to int32 since it's more common in DL.
+        dtype = 'int32'
+    else:
+        raise NotImplementedError('Cannot automatically inference the type.'
+                                  ' value={}'.format(value))
+    return dtype
+
 def const(value, dtype=None):
-    """Construct a constant value for a given type.
+    """construct a constant
 
     Parameters
     ----------
-    value : int or float
-        The input value
+    value : number
+        The content of the constant number.
 
     dtype : str or None, optional
         The data type.
 
     Returns
     -------
-    expr : Expr
-        Constant expression corresponds to the value.
+    const_val: tvm.Expr
+        The result expression.
     """
     if dtype is None:
         dtype = _scalar_type_inference(value)
+    if dtype == "uint64" and value >= (1 << 63):
+        return _api_internal._LargeUIntImm(
+            dtype, value & ((1 << 32) - 1), value >> 32)
     return _api_internal._const(value, dtype)
+
+
+_set_class_object_generic(ObjectGeneric, convert_to_object)
diff --git a/python/tvm/_ffi/packed_func.py b/python/tvm/_ffi/packed_func.py
new file mode 100644 (file)
index 0000000..d0917a8
--- /dev/null
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, unused-import
+"""Packed Function namespace."""
+import ctypes
+from .base import _LIB, check_call, c_str, string_types, _FFI_MODE
+
+try:
+    # pylint: disable=wrong-import-position
+    if _FFI_MODE == "ctypes":
+        raise ImportError()
+    from ._cy3.core import _set_class_packed_func, _set_class_module
+    from ._cy3.core import PackedFuncBase
+    from ._cy3.core import convert_to_tvm_func
+except (RuntimeError, ImportError):
+    # pylint: disable=wrong-import-position
+    from ._ctypes.packed_func import _set_class_packed_func, _set_class_module
+    from ._ctypes.packed_func import PackedFuncBase
+    from ._ctypes.packed_func import convert_to_tvm_func
+
+
+PackedFuncHandle = ctypes.c_void_p
+
+class PackedFunc(PackedFuncBase):
+    """The PackedFunc object used in TVM.
+
+    Function plays an key role to bridge front and backend in TVM.
+    Function provide a type-erased interface, you can call function with positional arguments.
+
+    The compiled module returns Function.
+    TVM backend also registers and exposes its API as Functions.
+    For example, the developer function exposed in tvm.ir_pass are actually
+    C++ functions that are registered as PackedFunc
+
+    The following are list of common usage scenario of tvm.Function.
+
+    - Automatic exposure of C++ API into python
+    - To call PackedFunc from python side
+    - To call python callbacks to inspect results in generated code
+    - Bring python hook into C++ backend
+
+    See Also
+    --------
+    tvm.register_func: How to register global function.
+    tvm.get_global_func: How to get global function.
+    """
+
+_set_class_packed_func(PackedFunc)
similarity index 53%
rename from python/tvm/_ffi/function.py
rename to python/tvm/_ffi/registry.py
index 22e0356..be15785 100644 (file)
 # under the License.
 
 # pylint: disable=invalid-name, unused-import
-"""Function namespace."""
-from __future__ import absolute_import
-
+"""FFI registry to register function and objects."""
 import sys
 import ctypes
-from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
-from .object_generic import _set_class_objects
+from .. import _api_internal
 
-IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
+from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY
 
 try:
-    # pylint: disable=wrong-import-position
+    # pylint: disable=wrong-import-position,unused-import
     if _FFI_MODE == "ctypes":
         raise ImportError()
-    if sys.version_info >= (3, 0):
-        from ._cy3.core import _set_class_function, _set_class_module
-        from ._cy3.core import FunctionBase as _FunctionBase
-        from ._cy3.core import NDArrayBase as _NDArrayBase
-        from ._cy3.core import ObjectBase as _ObjectBase
-        from ._cy3.core import convert_to_tvm_func
-    else:
-        from ._cy2.core import _set_class_function, _set_class_module
-        from ._cy2.core import FunctionBase as _FunctionBase
-        from ._cy2.core import NDArrayBase as _NDArrayBase
-        from ._cy2.core import ObjectBase as _ObjectBase
-        from ._cy2.core import convert_to_tvm_func
-except IMPORT_EXCEPT:
-    # pylint: disable=wrong-import-position
-    from ._ctypes.function import _set_class_function, _set_class_module
-    from ._ctypes.function import FunctionBase as _FunctionBase
-    from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
-    from ._ctypes.object import ObjectBase as _ObjectBase
-    from ._ctypes.function import convert_to_tvm_func
-
-FunctionHandle = ctypes.c_void_p
-
-class Function(_FunctionBase):
-    """The PackedFunc object used in TVM.
-
-    Function plays an key role to bridge front and backend in TVM.
-    Function provide a type-erased interface, you can call function with positional arguments.
-
-    The compiled module returns Function.
-    TVM backend also registers and exposes its API as Functions.
-    For example, the developer function exposed in tvm.ir_pass are actually
-    C++ functions that are registered as PackedFunc
-
-    The following are list of common usage scenario of tvm.Function.
-
-    - Automatic exposure of C++ API into python
-    - To call PackedFunc from python side
-    - To call python callbacks to inspect results in generated code
-    - Bring python hook into C++ backend
-
-    See Also
+    from ._cy3.core import _register_object
+    from ._cy3.core import _reg_extension
+    from ._cy3.core import convert_to_tvm_func, _get_global_func, PackedFuncBase
+except (RuntimeError, ImportError):
+    # pylint: disable=wrong-import-position,unused-import
+    from ._ctypes.object import _register_object
+    from ._ctypes.ndarray import _reg_extension
+    from ._ctypes.packed_func import convert_to_tvm_func, _get_global_func, PackedFuncBase
+
+
+def register_object(type_key=None):
+    """register object type.
+
+    Parameters
+    ----------
+    type_key : str or cls
+        The type key of the node
+
+    Examples
     --------
-    tvm.register_func: How to register global function.
-    tvm.get_global_func: How to get global function.
+    The following code registers MyObject
+    using type key "test.MyObject"
+
+    .. code-block:: python
+
+      @tvm.register_object("test.MyObject")
+      class MyObject(Object):
+          pass
     """
+    object_name = type_key if isinstance(type_key, str) else type_key.__name__
+
+    def register(cls):
+        """internal register function"""
+        if hasattr(cls, "_type_index"):
+            tindex = cls._type_index
+        else:
+            tidx = ctypes.c_uint()
+            if not _RUNTIME_ONLY:
+                check_call(_LIB.TVMObjectTypeKey2Index(
+                    c_str(object_name), ctypes.byref(tidx)))
+            else:
+                # directly skip unknown objects during runtime.
+                ret = _LIB.TVMObjectTypeKey2Index(
+                    c_str(object_name), ctypes.byref(tidx))
+                if ret != 0:
+                    return cls
+            tindex = tidx.value
+        _register_object(tindex, cls)
+        return cls
+
+    if isinstance(type_key, str):
+        return register
+
+    return register(type_key)
+
+
+def register_extension(cls, fcreate=None):
+    """Register a extension class to TVM.
+
+    After the class is registered, the class will be able
+    to directly pass as Function argument generated by TVM.
 
+    Parameters
+    ----------
+    cls : class
+        The class object to be registered as extension.
+
+    fcreate : function, optional
+        The creation function to create a class object given handle value.
 
-class ModuleBase(object):
-    """Base class for module"""
-    __slots__ = ["handle", "_entry", "entry_name"]
-
-    def __init__(self, handle):
-        self.handle = handle
-        self._entry = None
-        self.entry_name = "__tvm_main__"
-
-    def __del__(self):
-        check_call(_LIB.TVMModFree(self.handle))
-
-    def __hash__(self):
-        return ctypes.cast(self.handle, ctypes.c_void_p).value
-
-    @property
-    def entry_func(self):
-        """Get the entry function
-
-        Returns
-        -------
-        f : Function
-            The entry function if exist
-        """
-        if self._entry:
-            return self._entry
-        self._entry = self.get_function(self.entry_name)
-        return self._entry
-
-    def get_function(self, name, query_imports=False):
-        """Get function from the module.
-
-        Parameters
-        ----------
-        name : str
-            The name of the function
-
-        query_imports : bool
-            Whether also query modules imported by this module.
-
-        Returns
-        -------
-        f : Function
-            The result function.
-        """
-        ret_handle = FunctionHandle()
-        check_call(_LIB.TVMModGetFunction(
-            self.handle, c_str(name),
-            ctypes.c_int(query_imports),
-            ctypes.byref(ret_handle)))
-        if not ret_handle.value:
-            raise AttributeError(
-                "Module has no function '%s'" %  name)
-        return Function(ret_handle, False)
-
-    def import_module(self, module):
-        """Add module to the import list of current one.
-
-        Parameters
-        ----------
-        module : Module
-            The other module.
-        """
-        check_call(_LIB.TVMModImport(self.handle, module.handle))
-
-    def __getitem__(self, name):
-        if not isinstance(name, string_types):
-            raise ValueError("Can only take string as function name")
-        return self.get_function(name)
-
-    def __call__(self, *args):
-        if self._entry:
-            return self._entry(*args)
-        f = self.entry_func
-        return f(*args)
+    Note
+    ----
+    The registered class is requires one property: _tvm_handle.
+
+    If the registered class is a subclass of NDArray,
+    it is required to have a class attribute _array_type_code.
+    Otherwise, it is required to have a class attribute _tvm_tcode.
+
+    - ```_tvm_handle``` returns integer represents the address of the handle.
+    - ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
+      code of the class.
+
+    Returns
+    -------
+    cls : class
+        The class being registered.
+
+    Example
+    -------
+    The following code registers user defined class
+    MyTensor to be DLTensor compatible.
+
+    .. code-block:: python
+
+       @tvm.register_extension
+       class MyTensor(object):
+           _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
+
+           def __init__(self):
+               self.handle = _LIB.NewDLTensor()
+
+           @property
+           def _tvm_handle(self):
+               return self.handle.value
+    """
+    assert hasattr(cls, "_tvm_tcode")
+    if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
+        raise ValueError("Cannot register create when extension tcode is same as buildin")
+    _reg_extension(cls, fcreate)
+    return cls
 
 
 def register_func(func_name, f=None, override=False):
@@ -189,7 +174,7 @@ def register_func(func_name, f=None, override=False):
           return 10
       # Get it out from global function table
       f = tvm.get_global_func("my_packed_func")
-      assert isinstance(f, tvm.nd.Function)
+      assert isinstance(f, tvm.PackedFunc)
       y = f(*targs)
       assert y == 10
     """
@@ -203,7 +188,7 @@ def register_func(func_name, f=None, override=False):
     ioverride = ctypes.c_int(override)
     def register(myf):
         """internal register function"""
-        if not isinstance(myf, Function):
+        if not isinstance(myf, PackedFuncBase):
             myf = convert_to_tvm_func(myf)
         check_call(_LIB.TVMFuncRegisterGlobal(
             c_str(func_name), myf.handle, ioverride))
@@ -226,19 +211,10 @@ def get_global_func(name, allow_missing=False):
 
     Returns
     -------
-    func : tvm.Function
+    func : PackedFunc
         The function to be returned, None if function is missing.
     """
-    handle = FunctionHandle()
-    check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
-    if handle.value:
-        return Function(handle, False)
-
-    if allow_missing:
-        return None
-
-    raise ValueError("Cannot find global function %s" % name)
-
+    return _get_global_func(name, allow_missing)
 
 
 def list_global_func_names():
@@ -290,6 +266,7 @@ def _get_api(f):
     flocal.is_global = True
     return flocal
 
+
 def _init_api(namespace, target_module_name=None):
     """Initialize api for a given module name
 
@@ -330,6 +307,3 @@ def _init_api_prefix(module_name, prefix):
         ff.__name__ = fname
         ff.__doc__ = ("TVM PackedFunc %s. " % fname)
         setattr(target_module, ff.__name__, ff)
-
-_set_class_function(Function)
-_set_class_objects((_ObjectBase, _NDArrayBase, ModuleBase))
index 93c260a..d6d9b3a 100644 (file)
@@ -16,8 +16,6 @@
 # under the License.
 """Common runtime ctypes."""
 # pylint: disable=invalid-name
-from __future__ import absolute_import
-
 import ctypes
 import json
 import numpy as np
index 46faae3..573732e 100644 (file)
 # under the License.
 """Functions defined in TVM."""
 # pylint: disable=invalid-name,unused-import,redefined-builtin
-from __future__ import absolute_import as _abs
-
 from numbers import Integral as _Integral
 
+import tvm._ffi
+
 from ._ffi.base import string_types, TVMError
-from ._ffi.object import register_object, Object
-from ._ffi.object import convert_to_object as _convert_to_object
-from ._ffi.object_generic import _scalar_type_inference
-from ._ffi.function import Function
-from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
-from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
+from ._ffi.object_generic import convert, const
+from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
 from ._ffi.runtime_ctypes import TVMType
 from . import _api_internal
 from . import make as _make
@@ -75,30 +71,6 @@ def max_value(dtype):
     return _api_internal._max_value(dtype)
 
 
-def const(value, dtype=None):
-    """construct a constant
-
-    Parameters
-    ----------
-    value : number
-        The content of the constant number.
-
-    dtype : str or None, optional
-        The data type.
-
-    Returns
-    -------
-    const_val: tvm.Expr
-        The result expression.
-    """
-    if dtype is None:
-        dtype = _scalar_type_inference(value)
-    if dtype == "uint64" and value >= (1 << 63):
-        return _api_internal._LargeUIntImm(
-            dtype, value & ((1 << 32) - 1), value >> 32)
-    return _api_internal._const(value, dtype)
-
-
 def get_env_func(name):
     """Get an EnvFunc by a global name.
 
@@ -121,27 +93,6 @@ def get_env_func(name):
     return _api_internal._EnvFuncGet(name)
 
 
-def convert(value):
-    """Convert value to TVM node or function.
-
-    Parameters
-    ----------
-    value : python value
-
-    Returns
-    -------
-    tvm_val : Object or Function
-        Converted value in TVM
-    """
-    if isinstance(value, (Function, Object)):
-        return value
-
-    if callable(value):
-        return _convert_tvm_func(value)
-
-    return _convert_to_object(value)
-
-
 def load_json(json_str):
     """Load tvm object from json_str.
 
@@ -1073,10 +1024,9 @@ def floormod(a, b):
     """
     return _make._OpFloorMod(a, b)
 
-
-_init_api("tvm.api")
-
 #pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
 max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')
+
+tvm._ffi._init_api("tvm.api")
index 81f478c..7434aee 100644 (file)
@@ -16,9 +16,9 @@
 # under the License.
 """Arithmetic data structure and utility"""
 from __future__ import absolute_import as _abs
+import tvm._ffi
 
-from ._ffi.object import Object, register_object
-from ._ffi.function import _init_api
+from ._ffi.object import Object
 from . import _api_internal
 
 class IntSet(Object):
@@ -32,7 +32,7 @@ class IntSet(Object):
         return _api_internal._IntSetIsEverything(self)
 
 
-@register_object("arith.IntervalSet")
+@tvm._ffi.register_object("arith.IntervalSet")
 class IntervalSet(IntSet):
     """Represent set of continuous interval [min_value, max_value]
 
@@ -49,7 +49,7 @@ class IntervalSet(IntSet):
             _make_IntervalSet, min_value, max_value)
 
 
-@register_object("arith.ModularSet")
+@tvm._ffi.register_object("arith.ModularSet")
 class ModularSet(Object):
     """Represent range of (coeff * x + base) for x in Z """
     def __init__(self, coeff, base):
@@ -57,7 +57,7 @@ class ModularSet(Object):
             _make_ModularSet, coeff, base)
 
 
-@register_object("arith.ConstIntBound")
+@tvm._ffi.register_object("arith.ConstIntBound")
 class ConstIntBound(Object):
     """Represent constant integer bound
 
@@ -258,4 +258,4 @@ class Analyzer:
                 "Do not know how to handle type {}".format(type(info)))
 
 
-_init_api("tvm.arith")
+tvm._ffi._init_api("tvm.arith")
index 2963a0e..78d5b18 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """ TVM Attribute module, which is mainly used for defining attributes of operators"""
-from ._ffi.object import Object, register_object
-from ._ffi.function import _init_api
+import tvm._ffi
+
+from ._ffi.object import Object
 from . import _api_internal
 
 
-@register_object
+@tvm._ffi.register_object
 class Attrs(Object):
     """Attribute node, which is mainly use for defining attributes of relay operators.
 
@@ -92,4 +93,4 @@ class Attrs(Object):
         return self.__getattr__(item)
 
 
-_init_api("tvm.attrs")
+tvm._ffi._init_api("tvm.attrs")
index 85d2b85..c5097b2 100644 (file)
 This module provides the functions to transform schedule to
 LoweredFunc and compiled Module.
 """
-from __future__ import absolute_import as _abs
 import warnings
+import tvm._ffi
 
-from ._ffi.function import Function
-from ._ffi.object import Object, register_object
+from ._ffi.object import Object
 from . import api
 from . import _api_internal
 from . import tensor
@@ -115,7 +114,7 @@ class DumpIR(object):
         DumpIR.scope_level -= 1
 
 
-@register_object
+@tvm._ffi.register_object
 class BuildConfig(Object):
     """Configuration scope to set a build config option.
 
index 61ee1f7..7dc7bea 100644 (file)
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Code generation related functions."""
-from ._ffi.function import _init_api
+import tvm._ffi
 
 def build_module(lowered_func, target):
     """Build lowered_func into Module.
@@ -35,4 +35,4 @@ def build_module(lowered_func, target):
     """
     return _Build(lowered_func, target)
 
-_init_api("tvm.codegen")
+tvm._ffi._init_api("tvm.codegen")
index 673afb4..b74cc04 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """Container data structures used in TVM DSL."""
-from __future__ import absolute_import as _abs
+import tvm._ffi
+
 from tvm import ndarray as _nd
 from . import _api_internal
-from ._ffi.object import Object, register_object, getitem_helper
-from ._ffi.function import _init_api
+from ._ffi.object import Object, getitem_helper
+
 
-@register_object
+@tvm._ffi.register_object
 class Array(Object):
     """Array container of TVM.
 
@@ -52,7 +53,7 @@ class Array(Object):
         return _api_internal._ArraySize(self)
 
 
-@register_object
+@tvm._ffi.register_object
 class EnvFunc(Object):
     """Environment function.
 
@@ -66,7 +67,7 @@ class EnvFunc(Object):
         return _api_internal._EnvFuncGetPackedFunc(self)
 
 
-@register_object
+@tvm._ffi.register_object
 class Map(Object):
     """Map container of TVM.
 
@@ -89,7 +90,7 @@ class Map(Object):
         return _api_internal._MapSize(self)
 
 
-@register_object
+@tvm._ffi.register_object
 class StrMap(Map):
     """A special map container that has str as key.
 
@@ -101,7 +102,7 @@ class StrMap(Map):
         return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
 
 
-@register_object
+@tvm._ffi.register_object
 class Range(Object):
     """Represent a range in TVM.
 
@@ -110,7 +111,7 @@ class Range(Object):
     """
 
 
-@register_object
+@tvm._ffi.register_object
 class LoweredFunc(Object):
     """Represent a LoweredFunc in TVM."""
     MixedFunc = 0
@@ -118,7 +119,7 @@ class LoweredFunc(Object):
     DeviceFunc = 2
 
 
-@register_object("vm.ADT")
+@tvm._ffi.register_object("vm.ADT")
 class ADT(Object):
     """Algebatic data type(ADT) object.
 
@@ -168,4 +169,4 @@ def tuple_object(fields=None):
     return _Tuple(*fields)
 
 
-_init_api("tvm.container")
+tvm._ffi._init_api("tvm.container")
index 7d150c7..a5f6e30 100644 (file)
@@ -19,8 +19,9 @@
 import os
 import tempfile
 import shutil
+import tvm._ffi
+
 from tvm._ffi.base import string_types
-from tvm._ffi.function import get_global_func
 from tvm.contrib import graph_runtime
 from tvm.ndarray import array
 from . import debug_result
@@ -64,7 +65,7 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
             fcreate = ctx[0]._rpc_sess.get_function(
                 "tvm.graph_runtime_debug.create")
         else:
-            fcreate = get_global_func("tvm.graph_runtime_debug.create")
+            fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_debug.create")
     except ValueError:
         raise ValueError(
             "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
index 2c945d2..6b7c099 100644 (file)
@@ -16,9 +16,9 @@
 # under the License.
 """Minimum graph runtime that executes graph containing TVM PackedFunc."""
 import numpy as np
+import tvm._ffi
 
 from .._ffi.base import string_types
-from .._ffi.function import get_global_func
 from .._ffi.runtime_ctypes import TVMContext
 from ..rpc import base as rpc_base
 
@@ -54,7 +54,7 @@ def create(graph_json_str, libmod, ctx):
     if num_rpc_ctx == len(ctx):
         fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
     else:
-        fcreate = get_global_func("tvm.graph_runtime.create")
+        fcreate = tvm._ffi.get_global_func("tvm.graph_runtime.create")
 
     return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
 
index aceab6d..3e2132e 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to NNPACK libraries."""
-from __future__ import absolute_import as _abs
+import tvm._ffi
 
 from .. import api as _api
 from .. import intrin as _intrin
-from .._ffi.function import _init_api
+
 
 def is_available():
     """Check whether NNPACK is available, that is, `nnp_initialize()`
@@ -202,4 +202,4 @@ def convolution_inference_weight_transform(
             "tvm.contrib.nnpack.convolution_inference_weight_transform",
             ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
 
-_init_api("tvm.contrib.nnpack")
+tvm._ffi._init_api("tvm.contrib.nnpack")
index a57fac0..059bf23 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to random library."""
-from __future__ import absolute_import as _abs
+import tvm._ffi
 
 from .. import api as _api
 from .. import intrin as _intrin
-from .._ffi.function import _init_api
 
 
 def randint(low, high, size, dtype='int32'):
@@ -96,4 +95,4 @@ def normal(loc, scale, size):
         "tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
 
 
-_init_api("tvm.contrib.random")
+tvm._ffi._init_api("tvm.contrib.random")
index 5ff30a1..985c747 100644 (file)
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """TFLite runtime that load and run tflite models."""
-from .._ffi.function import get_global_func
+import tvm._ffi
 from ..rpc import base as rpc_base
 
 def create(tflite_model_bytes, ctx, runtime_target='cpu'):
@@ -44,7 +44,7 @@ def create(tflite_model_bytes, ctx, runtime_target='cpu'):
     if device_type >= rpc_base.RPC_SESS_MASK:
         fcreate = ctx._rpc_sess.get_function(runtime_func)
     else:
-        fcreate = get_global_func(runtime_func)
+        fcreate = tvm._ffi.get_global_func(runtime_func)
 
     return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
 
index df3e3a6..809e435 100644 (file)
@@ -15,9 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Custom datatype functionality"""
-from __future__ import absolute_import as _abs
+import tvm._ffi
 
-from ._ffi.function import register_func as _register_func
 from . import make as _make
 from .api import convert
 from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
@@ -111,7 +110,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None):
     else:
         lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
                           + type_name
-    _register_func(lower_func_name, lower_func)
+    tvm._ffi.register_func(lower_func_name, lower_func)
 
 
 def create_lower_func(extern_func_name):
index 20d9d89..46a7eac 100644 (file)
@@ -32,7 +32,10 @@ For example, you can use addexp.a to get the left operand of an Add node.
 """
 # pylint: disable=missing-docstring
 from __future__ import absolute_import as _abs
-from ._ffi.object import Object, register_object, ObjectGeneric
+import tvm._ffi
+
+from ._ffi.object import Object
+from ._ffi.object_generic import ObjectGeneric
 from ._ffi.runtime_ctypes import TVMType, TypeCode
 from . import make as _make
 from . import generic as _generic
@@ -261,7 +264,7 @@ class CmpExpr(PrimExpr):
 class LogicalExpr(PrimExpr):
     pass
 
-@register_object("Variable")
+@tvm._ffi.register_object("Variable")
 class Var(PrimExpr):
     """Symbolic variable.
 
@@ -278,7 +281,7 @@ class Var(PrimExpr):
             _api_internal._Var, name, dtype)
 
 
-@register_object
+@tvm._ffi.register_object
 class SizeVar(Var):
     """Symbolic variable to represent a tensor index size
        which is greater or equal to zero
@@ -297,7 +300,7 @@ class SizeVar(Var):
             _api_internal._SizeVar, name, dtype)
 
 
-@register_object
+@tvm._ffi.register_object
 class Reduce(PrimExpr):
     """Reduce node.
 
@@ -324,7 +327,7 @@ class Reduce(PrimExpr):
             condition, value_index)
 
 
-@register_object
+@tvm._ffi.register_object
 class FloatImm(ConstExpr):
     """Float constant.
 
@@ -340,7 +343,7 @@ class FloatImm(ConstExpr):
         self.__init_handle_by_constructor__(
             _make.FloatImm, dtype, value)
 
-@register_object
+@tvm._ffi.register_object
 class IntImm(ConstExpr):
     """Int constant.
 
@@ -360,7 +363,7 @@ class IntImm(ConstExpr):
         return self.value
 
 
-@register_object
+@tvm._ffi.register_object
 class StringImm(ConstExpr):
     """String constant.
 
@@ -384,7 +387,7 @@ class StringImm(ConstExpr):
         return self.value != other
 
 
-@register_object
+@tvm._ffi.register_object
 class Cast(PrimExpr):
     """Cast expression.
 
@@ -401,7 +404,7 @@ class Cast(PrimExpr):
             _make.Cast, dtype, value)
 
 
-@register_object
+@tvm._ffi.register_object
 class Add(BinaryOpExpr):
     """Add node.
 
@@ -418,7 +421,7 @@ class Add(BinaryOpExpr):
             _make.Add, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class Sub(BinaryOpExpr):
     """Sub node.
 
@@ -435,7 +438,7 @@ class Sub(BinaryOpExpr):
             _make.Sub, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class Mul(BinaryOpExpr):
     """Mul node.
 
@@ -452,7 +455,7 @@ class Mul(BinaryOpExpr):
             _make.Mul, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class Div(BinaryOpExpr):
     """Div node.
 
@@ -469,7 +472,7 @@ class Div(BinaryOpExpr):
             _make.Div, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class Mod(BinaryOpExpr):
     """Mod node.
 
@@ -486,7 +489,7 @@ class Mod(BinaryOpExpr):
             _make.Mod, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class FloorDiv(BinaryOpExpr):
     """FloorDiv node.
 
@@ -503,7 +506,7 @@ class FloorDiv(BinaryOpExpr):
             _make.FloorDiv, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class FloorMod(BinaryOpExpr):
     """FloorMod node.
 
@@ -520,7 +523,7 @@ class FloorMod(BinaryOpExpr):
             _make.FloorMod, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class Min(BinaryOpExpr):
     """Min node.
 
@@ -537,7 +540,7 @@ class Min(BinaryOpExpr):
             _make.Min, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class Max(BinaryOpExpr):
     """Max node.
 
@@ -554,7 +557,7 @@ class Max(BinaryOpExpr):
             _make.Max, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class EQ(CmpExpr):
     """EQ node.
 
@@ -571,7 +574,7 @@ class EQ(CmpExpr):
             _make.EQ, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class NE(CmpExpr):
     """NE node.
 
@@ -588,7 +591,7 @@ class NE(CmpExpr):
             _make.NE, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class LT(CmpExpr):
     """LT node.
 
@@ -605,7 +608,7 @@ class LT(CmpExpr):
             _make.LT, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class LE(CmpExpr):
     """LE node.
 
@@ -622,7 +625,7 @@ class LE(CmpExpr):
             _make.LE, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class GT(CmpExpr):
     """GT node.
 
@@ -639,7 +642,7 @@ class GT(CmpExpr):
             _make.GT, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class GE(CmpExpr):
     """GE node.
 
@@ -656,7 +659,7 @@ class GE(CmpExpr):
             _make.GE, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class And(LogicalExpr):
     """And node.
 
@@ -673,7 +676,7 @@ class And(LogicalExpr):
             _make.And, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class Or(LogicalExpr):
     """Or node.
 
@@ -690,7 +693,7 @@ class Or(LogicalExpr):
             _make.Or, a, b)
 
 
-@register_object
+@tvm._ffi.register_object
 class Not(LogicalExpr):
     """Not node.
 
@@ -704,7 +707,7 @@ class Not(LogicalExpr):
             _make.Not, a)
 
 
-@register_object
+@tvm._ffi.register_object
 class Select(PrimExpr):
     """Select node.
 
@@ -732,7 +735,7 @@ class Select(PrimExpr):
             _make.Select, condition, true_value, false_value)
 
 
-@register_object
+@tvm._ffi.register_object
 class Load(PrimExpr):
     """Load node.
 
@@ -755,7 +758,7 @@ class Load(PrimExpr):
             _make.Load, dtype, buffer_var, index, predicate)
 
 
-@register_object
+@tvm._ffi.register_object
 class Ramp(PrimExpr):
     """Ramp node.
 
@@ -775,7 +778,7 @@ class Ramp(PrimExpr):
             _make.Ramp, base, stride, lanes)
 
 
-@register_object
+@tvm._ffi.register_object
 class Broadcast(PrimExpr):
     """Broadcast node.
 
@@ -792,7 +795,7 @@ class Broadcast(PrimExpr):
             _make.Broadcast, value, lanes)
 
 
-@register_object
+@tvm._ffi.register_object
 class Shuffle(PrimExpr):
     """Shuffle node.
 
@@ -809,7 +812,7 @@ class Shuffle(PrimExpr):
             _make.Shuffle, vectors, indices)
 
 
-@register_object
+@tvm._ffi.register_object
 class Call(PrimExpr):
     """Call node.
 
@@ -844,7 +847,7 @@ class Call(PrimExpr):
             _make.Call, dtype, name, args, call_type, func, value_index)
 
 
-@register_object
+@tvm._ffi.register_object
 class Let(PrimExpr):
     """Let node.
 
index 11ecbc8..55c33e5 100644 (file)
@@ -28,13 +28,10 @@ HalideIR.
 # TODO(@were): Make this module more complete.
 # 1. Support HalideIR dumping to Hybrid Script
 # 2. Support multi-level HalideIR
-
-from __future__ import absolute_import as _abs
-
 import inspect
+import tvm._ffi
 
 from .._ffi.base import decorate
-from .._ffi.function import _init_api
 from ..build_module import form_body
 
 from .module import HybridModule
@@ -97,4 +94,4 @@ def build(sch, inputs, outputs, name="hybrid_func"):
     return HybridModule(src, name)
 
 
-_init_api("tvm.hybrid")
+tvm._ffi._init_api("tvm.hybrid")
index fd7131e..6146a71 100644 (file)
@@ -16,9 +16,9 @@
 # under the License.
 """Expression Intrinsics and math functions in TVM."""
 # pylint: disable=redefined-builtin
-from __future__ import absolute_import as _abs
+import tvm._ffi
+import tvm.codegen
 
-from ._ffi.function import register_func as _register_func
 from . import make as _make
 from .api import convert, const
 from .expr import Call as _Call
@@ -189,7 +189,6 @@ def call_llvm_intrin(dtype, name, *args):
     call : Expr
         The call expression.
     """
-    import tvm
     llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
     assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
     return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
@@ -596,7 +595,7 @@ def register_intrin_rule(target, intrin, f=None, override=False):
 
         register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
     """
-    return _register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
+    return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
 
 
 def _rule_float_suffix(op):
@@ -650,7 +649,7 @@ def _rule_float_direct(op):
         return call_pure_extern(op.dtype, op.name, *op.args)
     return None
 
-@_register_func("tvm.default_trace_action")
+@tvm._ffi.register_func("tvm.default_trace_action")
 def _tvm_default_trace_action(*args):
     print(list(args))
 
index ede17a1..8bd5892 100644 (file)
@@ -24,7 +24,7 @@ from . import make as _make
 from . import ir_pass as _pass
 from . import container as _container
 from ._ffi.base import string_types
-from ._ffi.object import ObjectGeneric
+from ._ffi.object_generic import ObjectGeneric
 from ._ffi.runtime_ctypes import TVMType
 from .expr import Call as _Call
 
index 59354e2..9d7f340 100644 (file)
@@ -23,6 +23,6 @@ Each api is a PackedFunc that can be called in a positional argument manner.
 You can read "include/tvm/tir/ir_pass.h" for the function signature and
 "src/api/api_pass.cc" for the PackedFunc's body of these functions.
 """
-from ._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("tvm.ir_pass")
+tvm._ffi._init_api("tvm.ir_pass")
index 241edd6..7f94d10 100644 (file)
@@ -22,8 +22,7 @@ The functions are automatically exported from C++ side via PackedFunc.
 Each api is a PackedFunc that can be called in a positional argument manner.
 You can use make function to build the IR node.
 """
-from __future__ import absolute_import as _abs
-from ._ffi.function import _init_api
+import tvm._ffi
 
 
 def range_by_min_extent(min_value, extent):
@@ -85,4 +84,4 @@ def node(type_key, **kwargs):
     return _Node(*args)
 
 
-_init_api("tvm.make")
+tvm._ffi._init_api("tvm.make")
index e2e1329..a46d1bb 100644 (file)
@@ -23,9 +23,11 @@ import sys
 from enum import Enum
 
 import tvm
+import tvm._ffi
+
 from tvm.contrib import util as _util
 from tvm.contrib import cc as _cc
-from .._ffi.function import _init_api
+
 
 class LibType(Enum):
     """Enumeration of library types that can be compiled and loaded onto a device"""
@@ -222,4 +224,4 @@ def get_micro_device_dir():
     return micro_device_dir
 
 
-_init_api("tvm.micro", "tvm.micro.base")
+tvm._ffi._init_api("tvm.micro", "tvm.micro.base")
index f8ad0a4..d810af7 100644 (file)
@@ -19,9 +19,9 @@ from __future__ import absolute_import as _abs
 
 import struct
 from collections import namedtuple
+import tvm._ffi
 
-from ._ffi.function import ModuleBase, _set_class_module
-from ._ffi.function import _init_api
+from ._ffi.module import ModuleBase, _set_class_module
 from ._ffi.libinfo import find_include_path
 from .contrib import cc as _cc, tar as _tar, util as _util
 
@@ -333,5 +333,5 @@ def enabled(target):
     return _Enabled(target)
 
 
-_init_api("tvm.module")
+tvm._ffi._init_api("tvm.module")
 _set_class_module(Module)
index b7fe780..5d78fe4 100644 (file)
@@ -20,17 +20,15 @@ tvm.ndarray provides a minimum runtime array API to test
 the correctness of the program.
 """
 # pylint: disable=invalid-name,unused-import
-from __future__ import absolute_import as _abs
+import tvm._ffi
 import numpy as _np
 
 from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
 from ._ffi.ndarray import context, empty, from_dlpack
 from ._ffi.ndarray import _set_class_ndarray
-from ._ffi.ndarray import register_extension
-from ._ffi.object import register_object
 
 
-@register_object
+@tvm._ffi.register_object
 class NDArray(NDArrayBase):
     """Lightweight NDArray class of TVM runtime.
 
diff --git a/python/tvm/object.py b/python/tvm/object.py
deleted file mode 100644 (file)
index 9659d3c..0000000
+++ /dev/null
@@ -1,23 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Node is the base class of all TVM AST.
-
-Normally user do not need to touch this api.
-"""
-# pylint: disable=unused-import
-from __future__ import absolute_import as _abs
-from ._ffi.object import Object, register_object
index 32a7324..050fcce 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """FFI exposing the passes for Relay program analysis."""
+import tvm._ffi
 
-from tvm._ffi.function import _init_api
-
-_init_api("relay._analysis", __name__)
+tvm._ffi._init_api("relay._analysis", __name__)
index d7ecaa8..f86aa70 100644 (file)
@@ -16,6 +16,6 @@
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
 """The interface of expr function exposed from C++."""
-from tvm._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay._base", __name__)
+tvm._ffi._init_api("relay._base", __name__)
index bdbcbef..9ee92e0 100644 (file)
@@ -16,6 +16,6 @@
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
 """The interface for building Relay functions exposed from C++."""
-from tvm._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay.build_module", __name__)
+tvm._ffi._init_api("relay.build_module", __name__)
index 07ef7e0..70c13ce 100644 (file)
@@ -16,6 +16,6 @@
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
 """The interface of expr function exposed from C++."""
-from tvm._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay._expr", __name__)
+tvm._ffi._init_api("relay._expr", __name__)
index 6081b26..351f7c6 100644 (file)
@@ -20,6 +20,6 @@ The constructors for all Relay AST nodes exposed from C++.
 This module includes MyPy type signatures for all of the
 exposed modules.
 """
-from .._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay._make", __name__)
+tvm._ffi._init_api("relay._make", __name__)
index 365c827..aedb74a 100644 (file)
@@ -16,6 +16,6 @@
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
 """The interface to the Module exposed from C++."""
-from tvm._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay._module", __name__)
+tvm._ffi._init_api("relay._module", __name__)
index 273d97e..a4168df 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """FFI exposing the Relay type inference and checking."""
+import tvm._ffi
 
-from tvm._ffi.function import _init_api
-
-_init_api("relay._transform", __name__)
+tvm._ffi._init_api("relay._transform", __name__)
index 860788a..1db70c3 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """The interface of expr function exposed from C++."""
-from __future__ import absolute_import
+import tvm._ffi
 
 from ... import build_module as _build
 from ... import container as _container
-from ..._ffi.function import _init_api, register_func
 
 
-@register_func("relay.backend.lower")
+@tvm._ffi.register_func("relay.backend.lower")
 def lower(sch, inputs, func_name, source_func):
     """Backend function for lowering.
 
@@ -61,7 +60,7 @@ def lower(sch, inputs, func_name, source_func):
         f, (_container.Array, tuple, list)) else [f]
 
 
-@register_func("relay.backend.build")
+@tvm._ffi.register_func("relay.backend.build")
 def build(funcs, target, target_host=None):
     """Backend build function.
 
@@ -88,14 +87,14 @@ def build(funcs, target, target_host=None):
     return _build.build(funcs, target=target, target_host=target_host)
 
 
-@register_func("relay._tensor_value_repr")
+@tvm._ffi.register_func("relay._tensor_value_repr")
 def _tensor_value_repr(tvalue):
     return str(tvalue.data.asnumpy())
 
 
-@register_func("relay._constant_repr")
+@tvm._ffi.register_func("relay._constant_repr")
 def _tensor_constant_repr(tvalue):
     return str(tvalue.data.asnumpy())
 
 
-_init_api("relay.backend", __name__)
+tvm._ffi._init_api("relay.backend", __name__)
index e88f02a..cffbbdc 100644 (file)
@@ -16,6 +16,6 @@
 # under the License.
 """The Relay virtual machine FFI namespace.
 """
-from tvm._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay._vm", __name__)
+tvm._ffi._init_api("relay._vm", __name__)
index f1cdefc..606e0cc 100644 (file)
@@ -25,7 +25,7 @@ import numpy as np
 import tvm
 import tvm.ndarray as _nd
 from tvm import autotvm, container
-from tvm.object import Object
+from tvm._ffi.object import Object
 from tvm.relay import expr as _expr
 from tvm._ffi.runtime_ctypes import TVMByteArray
 from tvm._ffi import base as _base
index d389803..a723eda 100644 (file)
@@ -16,8 +16,8 @@
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck
 """The base node types for the Relay language."""
-from __future__ import absolute_import as _abs
-from .._ffi.object import register_object as _register_tvm_node
+import tvm._ffi
+
 from .._ffi.object import Object
 from . import _make
 from . import _expr
@@ -34,9 +34,9 @@ def register_relay_node(type_key=None):
         The type key of the node.
     """
     if not isinstance(type_key, str):
-        return _register_tvm_node(
+        return tvm._ffi.register_object(
             "relay." + type_key.__name__)(type_key)
-    return _register_tvm_node(type_key)
+    return tvm._ffi.register_object(type_key)
 
 
 def register_relay_attr_node(type_key=None):
@@ -48,9 +48,9 @@ def register_relay_attr_node(type_key=None):
         The type key of the node.
     """
     if not isinstance(type_key, str):
-        return _register_tvm_node(
+        return tvm._ffi.register_object(
             "relay.attrs." + type_key.__name__)(type_key)
-    return _register_tvm_node(type_key)
+    return tvm._ffi.register_object(type_key)
 
 
 class RelayNode(Object):
index d51fee7..85c2368 100644 (file)
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Constructor APIs"""
-from ..._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay.op._make", __name__)
+tvm._ffi._init_api("relay.op._make", __name__)
index ae909eb..12ece52 100644 (file)
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay.op.annotation._make", __name__)
+tvm._ffi._init_api("relay.op.annotation._make", __name__)
index 42d7175..9d3369e 100644 (file)
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay.op.contrib._make", __name__)
+tvm._ffi._init_api("relay.op.contrib._make", __name__)
index 747684b..1d5e028 100644 (file)
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay.op.image._make", __name__)
+tvm._ffi._init_api("relay.op.image._make", __name__)
index cdf2dcc..52a3777 100644 (file)
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay.op.memory._make", __name__)
+tvm._ffi._init_api("relay.op.memory._make", __name__)
index 7249685..15ae43b 100644 (file)
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay.op.nn._make", __name__)
+tvm._ffi._init_api("relay.op.nn._make", __name__)
index 382f667..f9bc853 100644 (file)
@@ -17,8 +17,7 @@
 #pylint: disable=unused-argument
 """The base node types for the Relay language."""
 import topi
-
-from ..._ffi.function import _init_api
+import tvm._ffi
 
 from ..base import register_relay_node
 from ..expr import Expr
@@ -283,8 +282,6 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
     get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
     return register(op_name, "FShapeFunc", shape_func, level)
 
-_init_api("relay.op", __name__)
-
 @register_func("relay.op.compiler._lower")
 def _lower(name, schedule, inputs, outputs):
     return lower(schedule, list(inputs) + list(outputs), name=name)
@@ -320,3 +317,5 @@ def debug(expr, debug_func=None):
         name = ''
 
     return _make.debug(expr, name)
+
+tvm._ffi._init_api("relay.op", __name__)
index f0e3170..eddca15 100644 (file)
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay.op.vision._make", __name__)
+tvm._ffi._init_api("relay.op.vision._make", __name__)
index 07b3dd1..4472bc7 100644 (file)
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay.qnn.op._make", __name__)
+tvm._ffi._init_api("relay.qnn.op._make", __name__)
index ab98f3c..ba100d8 100644 (file)
 # under the License.
 #pylint: disable=unused-argument,inconsistent-return-statements
 """Internal module for registering attribute for annotation."""
-from __future__ import absolute_import
 import warnings
-
 import topi
-from ..._ffi.function import register_func
+import tvm._ffi
+
 from .. import expr as _expr
 from .. import analysis as _analysis
 from .. import op as _op
@@ -144,7 +143,8 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
     qctx.qnode_map[key] = qnode
     return qnode
 
-register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
+tvm._ffi.register_func(
+    "relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
 
 
 @register_annotate_function("nn.contrib_conv2d_NCHWc")
index 6f5c75f..7b27b7a 100644 (file)
@@ -16,7 +16,6 @@
 # under the License.
 #pylint: disable=unused-argument
 """Internal module for quantization."""
-from __future__ import absolute_import
-from tvm._ffi.function import _init_api
+import tvm._ffi
 
-_init_api("relay._quantize", __name__)
+tvm._ffi._init_api("relay._quantize", __name__)
index a1e837c..bc81534 100644 (file)
@@ -26,8 +26,8 @@ import errno
 import struct
 import random
 import logging
+import tvm._ffi
 
-from .._ffi.function import _init_api
 from .._ffi.base import py_str
 
 # Magic header for RPC data plane
@@ -179,4 +179,4 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
 
 
 # Still use tvm.rpc for the foreign functions
-_init_api("tvm.rpc", "tvm.rpc.base")
+tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base")
index 9c0dea5..314f1ab 100644 (file)
@@ -21,11 +21,11 @@ import os
 import socket
 import struct
 import time
+import tvm._ffi
 
 from . import base
 from ..contrib import util
 from .._ffi.base import TVMError
-from .._ffi import function
 from .._ffi import ndarray as nd
 from ..module import load as _load_module
 
@@ -185,7 +185,7 @@ class LocalSession(RPCSession):
     def __init__(self):
         # pylint: disable=super-init-not-called
         self.context = nd.context
-        self.get_function = function.get_global_func
+        self.get_function = tvm._ffi.get_global_func
         self._temp = util.tempdir()
 
     def upload(self, data, target=None):
index 9e03097..efebe8b 100644 (file)
@@ -25,9 +25,6 @@ Server is TCP based with the following protocol:
    - {server|client}:device-type[:random-key] [-timeout=timeout]
 """
 # pylint: disable=invalid-name
-
-from __future__ import absolute_import
-
 import os
 import ctypes
 import socket
@@ -39,8 +36,8 @@ import subprocess
 import time
 import sys
 import signal
+import tvm._ffi
 
-from .._ffi.function import register_func
 from .._ffi.base import py_str
 from .._ffi.libinfo import find_lib_path
 from ..module import load as _load_module
@@ -58,11 +55,11 @@ def _server_env(load_library, work_path=None):
         temp = util.tempdir()
 
     # pylint: disable=unused-variable
-    @register_func("tvm.rpc.server.workpath")
+    @tvm._ffi.register_func("tvm.rpc.server.workpath")
     def get_workpath(path):
         return temp.relpath(path)
 
-    @register_func("tvm.rpc.server.load_module", override=True)
+    @tvm._ffi.register_func("tvm.rpc.server.load_module", override=True)
     def load_module(file_name):
         """Load module from remote side."""
         path = temp.relpath(file_name)
index c8fcd7c..bf4e75f 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """The computation schedule api of TVM."""
-from __future__ import absolute_import as _abs
+import tvm._ffi
+
 from ._ffi.base import string_types
-from ._ffi.object import Object, register_object
-from ._ffi.object import convert_to_object as _convert_to_object
-from ._ffi.function import _init_api, Function
-from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
+from ._ffi.object import Object
+from ._ffi.object_generic import convert
+
 from . import _api_internal
 from . import tensor as _tensor
 from . import expr as _expr
 from . import container as _container
 
-def convert(value):
-    """Convert value to TVM object or function.
-
-    Parameters
-    ----------
-    value : python value
-
-    Returns
-    -------
-    tvm_val : Object or Function
-        Converted value in TVM
-    """
-    if isinstance(value, (Function, Object)):
-        return value
-
-    if callable(value):
-        return _convert_tvm_func(value)
-
-    return _convert_to_object(value)
 
-@register_object
+@tvm._ffi.register_object
 class Buffer(Object):
     """Symbolic data buffer in TVM.
 
@@ -156,22 +137,22 @@ class Buffer(Object):
         return _api_internal._BufferVStore(self, begin, value)
 
 
-@register_object
+@tvm._ffi.register_object
 class Split(Object):
     """Split operation on axis."""
 
 
-@register_object
+@tvm._ffi.register_object
 class Fuse(Object):
     """Fuse operation on axis."""
 
 
-@register_object
+@tvm._ffi.register_object
 class Singleton(Object):
     """Singleton axis."""
 
 
-@register_object
+@tvm._ffi.register_object
 class IterVar(Object, _expr.ExprOp):
     """Represent iteration variable.
 
@@ -214,7 +195,7 @@ def create_schedule(ops):
     return _api_internal._CreateSchedule(ops)
 
 
-@register_object
+@tvm._ffi.register_object
 class Schedule(Object):
     """Schedule for all the stages."""
     def __getitem__(self, k):
@@ -348,7 +329,7 @@ class Schedule(Object):
         return factored[0] if len(factored) == 1 else factored
 
 
-@register_object
+@tvm._ffi.register_object
 class Stage(Object):
     """A Stage represents schedule for one operation."""
     def split(self, parent, factor=None, nparts=None):
@@ -670,4 +651,4 @@ class Stage(Object):
         """
         _api_internal._StageOpenGL(self)
 
-_init_api("tvm.schedule")
+tvm._ffi._init_api("tvm.schedule")
index 6b87fcb..5908934 100644 (file)
@@ -29,15 +29,15 @@ Each statement node have subfields that can be visited from python side.
     assert isinstance(st, tvm.stmt.Store)
     assert(st.buffer_var == a)
 """
-from __future__ import absolute_import as _abs
-from ._ffi.object import Object, register_object
+import tvm._ffi
+from ._ffi.object import Object
 from . import make as _make
 
 
 class Stmt(Object):
     pass
 
-@register_object
+@tvm._ffi.register_object
 class LetStmt(Stmt):
     """LetStmt node.
 
@@ -57,7 +57,7 @@ class LetStmt(Stmt):
             _make.LetStmt, var, value, body)
 
 
-@register_object
+@tvm._ffi.register_object
 class AssertStmt(Stmt):
     """AssertStmt node.
 
@@ -77,7 +77,7 @@ class AssertStmt(Stmt):
             _make.AssertStmt, condition, message, body)
 
 
-@register_object
+@tvm._ffi.register_object
 class ProducerConsumer(Stmt):
     """ProducerConsumer node.
 
@@ -97,7 +97,7 @@ class ProducerConsumer(Stmt):
             _make.ProducerConsumer, func, is_producer, body)
 
 
-@register_object
+@tvm._ffi.register_object
 class For(Stmt):
     """For node.
 
@@ -137,7 +137,7 @@ class For(Stmt):
             for_type, device_api, body)
 
 
-@register_object
+@tvm._ffi.register_object
 class Store(Stmt):
     """Store node.
 
@@ -160,7 +160,7 @@ class Store(Stmt):
             _make.Store, buffer_var, value, index, predicate)
 
 
-@register_object
+@tvm._ffi.register_object
 class Provide(Stmt):
     """Provide node.
 
@@ -183,7 +183,7 @@ class Provide(Stmt):
             _make.Provide, func, value_index, value, args)
 
 
-@register_object
+@tvm._ffi.register_object
 class Allocate(Stmt):
     """Allocate node.
 
@@ -215,7 +215,7 @@ class Allocate(Stmt):
             extents, condition, body)
 
 
-@register_object
+@tvm._ffi.register_object
 class AttrStmt(Stmt):
     """AttrStmt node.
 
@@ -238,7 +238,7 @@ class AttrStmt(Stmt):
             _make.AttrStmt, node, attr_key, value, body)
 
 
-@register_object
+@tvm._ffi.register_object
 class Free(Stmt):
     """Free node.
 
@@ -252,7 +252,7 @@ class Free(Stmt):
             _make.Free, buffer_var)
 
 
-@register_object
+@tvm._ffi.register_object
 class Realize(Stmt):
     """Realize node.
 
@@ -288,7 +288,7 @@ class Realize(Stmt):
             bounds, condition, body)
 
 
-@register_object
+@tvm._ffi.register_object
 class SeqStmt(Stmt):
     """Sequence of statements.
 
@@ -308,7 +308,7 @@ class SeqStmt(Stmt):
         return len(self.seq)
 
 
-@register_object
+@tvm._ffi.register_object
 class IfThenElse(Stmt):
     """IfThenElse node.
 
@@ -328,7 +328,7 @@ class IfThenElse(Stmt):
             _make.IfThenElse, condition, then_case, else_case)
 
 
-@register_object
+@tvm._ffi.register_object
 class Evaluate(Stmt):
     """Evaluate node.
 
@@ -342,7 +342,7 @@ class Evaluate(Stmt):
             _make.Evaluate, value)
 
 
-@register_object
+@tvm._ffi.register_object
 class Prefetch(Stmt):
     """Prefetch node.
 
index c2d3752..45dbf5f 100644 (file)
@@ -54,12 +54,11 @@ The list of options include:
 We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
 We can also use other specific function in this module to create specific targets.
 """
-from __future__ import absolute_import
-
 import warnings
+import tvm._ffi
 
 from ._ffi.base import _LIB_NAME
-from ._ffi.object import Object, register_object
+from ._ffi.object import Object
 from . import _api_internal
 
 try:
@@ -80,7 +79,7 @@ def _merge_opts(opts, new_opts):
     return opts
 
 
-@register_object
+@tvm._ffi.register_object
 class Target(Object):
     """Target device information, use through TVM API.
 
@@ -146,7 +145,7 @@ class Target(Object):
         _api_internal._ExitTargetScope(self)
 
 
-@register_object
+@tvm._ffi.register_object
 class GenericFunc(Object):
     """GenericFunc node reference. This represents a generic function
     that may be specialized for different targets. When this object is
index e4c36c1..522e901 100644 (file)
 # under the License.
 """Tensor and Operation class for computation declaration."""
 # pylint: disable=invalid-name
-from __future__ import absolute_import as _abs
-from ._ffi.object import Object, register_object, ObjectGeneric, \
-        convert_to_object
+import tvm._ffi
+
+from ._ffi.object import Object
+from ._ffi.object_generic import ObjectGeneric, convert_to_object
+
 from . import _api_internal
 from . import make as _make
 from . import expr as _expr
@@ -47,7 +49,7 @@ class TensorSlice(ObjectGeneric, _expr.ExprOp):
         """Data content of the tensor."""
         return self.tensor.dtype
 
-@register_object
+@tvm._ffi.register_object
 class TensorIntrinCall(Object):
     """Intermediate structure for calling a tensor intrinsic."""
 
@@ -55,7 +57,7 @@ class TensorIntrinCall(Object):
 itervar_cls = None
 
 
-@register_object
+@tvm._ffi.register_object
 class Tensor(Object, _expr.ExprOp):
     """Tensor object, to construct, see function.Tensor"""
 
@@ -157,12 +159,12 @@ class Operation(Object):
         return _api_internal._OpInputTensors(self)
 
 
-@register_object
+@tvm._ffi.register_object
 class PlaceholderOp(Operation):
     """Placeholder operation."""
 
 
-@register_object
+@tvm._ffi.register_object
 class BaseComputeOp(Operation):
     """Compute operation."""
     @property
@@ -176,18 +178,18 @@ class BaseComputeOp(Operation):
         return self.__getattr__("reduce_axis")
 
 
-@register_object
+@tvm._ffi.register_object
 class ComputeOp(BaseComputeOp):
     """Scalar operation."""
     pass
 
 
-@register_object
+@tvm._ffi.register_object
 class TensorComputeOp(BaseComputeOp):
     """Tensor operation."""
 
 
-@register_object
+@tvm._ffi.register_object
 class ScanOp(Operation):
     """Scan operation."""
     @property
@@ -196,12 +198,12 @@ class ScanOp(Operation):
         return self.__getattr__("scan_axis")
 
 
-@register_object
+@tvm._ffi.register_object
 class ExternOp(Operation):
     """External operation."""
 
 
-@register_object
+@tvm._ffi.register_object
 class HybridOp(Operation):
     """Hybrid operation."""
     @property
@@ -210,7 +212,7 @@ class HybridOp(Operation):
         return self.__getattr__("axis")
 
 
-@register_object
+@tvm._ffi.register_object
 class Layout(Object):
     """Layout is composed of upper cases, lower cases and numbers,
     where upper case indicates a primal axis and
@@ -270,7 +272,7 @@ class Layout(Object):
         return _api_internal._LayoutFactorOf(self, axis)
 
 
-@register_object
+@tvm._ffi.register_object
 class BijectiveLayout(Object):
     """Bijective mapping for two layouts (src-layout and dst-layout).
     It provides shape and index conversion between each other.
index 4665ccf..0b88af1 100644 (file)
@@ -15,7 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Tensor intrinsics"""
-from __future__ import absolute_import as _abs
+import tvm._ffi
+
 from . import _api_internal
 from . import api as _api
 from . import expr as _expr
@@ -24,7 +25,7 @@ from . import make as _make
 from . import tensor as _tensor
 from . import schedule as _schedule
 from .build_module import current_build_config
-from ._ffi.object import Object, register_object
+from ._ffi.object import Object
 
 
 def _get_region(tslice):
@@ -41,7 +42,7 @@ def _get_region(tslice):
             region.append(_make.range_by_min_extent(begin, 1))
     return region
 
-@register_object
+@tvm._ffi.register_object
 class TensorIntrin(Object):
     """Tensor intrinsic functions for certain computation.
 
index 920b271..efc31e8 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """FFI for CUDA TOPI ops and schedules"""
+import tvm._ffi
 
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.cuda", "topi.cuda")
+tvm._ffi._init_api("topi.cuda", "topi.cpp.cuda")
index a8a7165..e6bf250 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """FFI for generic TOPI ops and schedules"""
+import tvm._ffi
 
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.generic", "topi.generic")
+tvm._ffi._init_api("topi.generic", "topi.cpp.generic")
index 9ae407d..1081baa 100644 (file)
@@ -18,8 +18,8 @@
 import sys
 import os
 import ctypes
+import tvm._ffi
 
-from tvm._ffi.function import _init_api_prefix
 from tvm._ffi import libinfo
 
 def _get_lib_names():
@@ -41,4 +41,4 @@ def _load_lib():
 
 _LIB, _LIB_NAME = _load_lib()
 
-_init_api_prefix("topi.cpp", "topi")
+tvm._ffi._init_api("topi", "topi.cpp")
index 59bf147..d11aa27 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """FFI for NN TOPI ops and schedules"""
+import tvm._ffi
 
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.nn", "topi.nn")
+tvm._ffi._init_api("topi.nn", "topi.cpp.nn")
index d57ce3e..c001a61 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """FFI for Rocm TOPI ops and schedules"""
+import tvm._ffi
 
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.rocm", "topi.rocm")
+tvm._ffi._init_api("topi.rocm", "topi.cpp.rocm")
index 90264bc..cc76dd9 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """FFI for TOPI utility functions"""
+import tvm._ffi
 
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.util", "topi.util")
+tvm._ffi._init_api("topi.util", "topi.cpp.util")
index bcdfc8c..6034e27 100644 (file)
@@ -16,9 +16,8 @@
 # under the License.
 
 """FFI for vision TOPI ops and schedules"""
-
-from tvm._ffi.function import _init_api_prefix
+import tvm._ffi
 
 from . import yolo
 
-_init_api_prefix("topi.cpp.vision", "topi.vision")
+tvm._ffi._init_api("topi.vision", "topi.cpp.vision")
index 072ab29..ff12498 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """FFI for Yolo TOPI ops and schedules"""
+import tvm._ffi
 
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.vision.yolo", "topi.vision.yolo")
+tvm._ffi._init_api("topi.vision.yolo", "topi.cpp.vision.yolo")
index a6db26e..0681ffe 100644 (file)
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """FFI for x86 TOPI ops and schedules"""
+import tvm._ffi
 
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.x86", "topi.x86")
+tvm._ffi._init_api("topi.x86", "topi.cpp.x86")