# under the License.
# pylint: disable=redefined-builtin, wildcard-import
"""TVM: Low level DSL/IR stack for tensor computation."""
-from __future__ import absolute_import as _abs
-
import multiprocessing
import sys
import traceback
-from . import _pyversion
+# import ffi related features
+from ._ffi.base import TVMError, __version__
+from ._ffi.runtime_ctypes import TypeCode, TVMType
+from ._ffi.ndarray import TVMContext
+from ._ffi.packed_func import PackedFunc as Function
+from ._ffi.registry import register_object, register_func, register_extension
+from ._ffi.object import Object
from . import tensor
from . import arith
from . import container
from . import schedule
from . import module
-from . import object
from . import attrs
from . import ir_builder
from . import target
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
-from ._ffi.runtime_ctypes import TypeCode, TVMType
-from ._ffi.ndarray import TVMContext
-from ._ffi.function import Function
-from ._ffi.base import TVMError, __version__
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
-from .object import register_object
-from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
Some performance critical functions are implemented by cython
and have a ctypes fallback implementation.
"""
+from . import _pyversion
+from .base import register_error
+from .registry import register_object, register_func, register_extension
+from .registry import _init_api, get_global_func
# under the License.
# pylint: disable=invalid-name
"""Runtime NDArray api"""
-from __future__ import absolute_import
-
import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
-from __future__ import absolute_import
-
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import
"""Function configuration API."""
-from __future__ import absolute_import
-
import ctypes
import traceback
from numbers import Number, Integral
-from ..base import _LIB, get_last_ffi_error, py2cerror
+from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types
-from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .object import ObjectBase, _set_class_object
from . import object as _object
-FunctionHandle = ctypes.c_void_p
+PackedFuncHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
ObjectHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
+
+def _make_packed_func(handle, is_global):
+ """Make a packed function class"""
+ obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
+ obj.is_global = is_global
+ obj.handle = handle
+ return obj
+
+
def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
_ = rv
return 0
- handle = FunctionHandle()
+ handle = PackedFuncHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
if _LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
raise get_last_ffi_error()
- return _CLASS_FUNCTION(handle, False)
+ return _make_packed_func(handle, False)
def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
- elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
- arg = convert_to_object(arg)
+ elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
+ arg = _FUNC_CONVERT_TO_OBJECT(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE
- elif isinstance(arg, FunctionBase):
+ elif isinstance(arg, PackedFuncBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.PACKED_FUNC_HANDLE
elif isinstance(arg, ctypes.c_void_p):
return values, type_codes, num_args
-class FunctionBase(object):
+class PackedFuncBase(object):
"""Function base."""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
Parameters
----------
- handle : FunctionHandle
+ handle : PackedFuncHandle
the handle to the underlying function.
is_global : bool
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
- if not isinstance(handle, FunctionHandle):
- handle = FunctionHandle(handle)
- return _CLASS_FUNCTION(handle, False)
+ if not isinstance(handle, PackedFuncHandle):
+ handle = PackedFuncHandle(handle)
+ return _CLASS_PACKED_FUNC(handle, False)
+
+
+def _get_global_func(name, allow_missing=False):
+ handle = PackedFuncHandle()
+ check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
+
+ if handle.value:
+ return _make_packed_func(handle, False)
+
+ if allow_missing:
+ return None
+
+ raise ValueError("Cannot find global function %s" % name)
# setup return handle for function type
_object.__init_by_constructor__ = __init_handle_by_constructor__
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
_CLASS_MODULE = None
-_CLASS_FUNCTION = None
+_CLASS_PACKED_FUNC = None
+_CLASS_OBJECT_GENERIC = None
+_FUNC_CONVERT_TO_OBJECT = None
+
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
-def _set_class_function(func_class):
- global _CLASS_FUNCTION
- _CLASS_FUNCTION = func_class
+def _set_class_packed_func(packed_func_class):
+ global _CLASS_PACKED_FUNC
+ _CLASS_PACKED_FUNC = packed_func_class
+
+def _set_class_object_generic(object_generic_class, func_convert_to_object):
+ global _CLASS_OBJECT_GENERIC
+ global _FUNC_CONVERT_TO_OBJECT
+ _CLASS_OBJECT_GENERIC = object_generic_class
+ _FUNC_CONVERT_TO_OBJECT = func_convert_to_object
# under the License.
"""The C Types used in API."""
# pylint: disable=invalid-name
-from __future__ import absolute_import as _abs
-
import ctypes
import struct
from ..base import py_str, check_call, _LIB
ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
-ctypedef void* TVMFunctionHandle
+ctypedef void* TVMPackedFuncHandle
ctypedef void* ObjectHandle
ctypedef struct TVMObject:
cdef extern from "tvm/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg)
const char *TVMGetLastError()
- int TVMFuncCall(TVMFunctionHandle func,
+ int TVMFuncGetGlobal(const char* name,
+ TVMPackedFuncHandle* out);
+ int TVMFuncCall(TVMPackedFuncHandle func,
TVMValue* arg_values,
int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code)
- int TVMFuncFree(TVMFunctionHandle func)
+ int TVMFuncFree(TVMPackedFuncHandle func)
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue* value,
int* type_code,
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
- TVMFunctionHandle *out)
+ TVMPackedFuncHandle *out)
int TVMCbArgToReturn(TVMValue* value, int code)
int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim,
include "./base.pxi"
include "./object.pxi"
-# include "./node.pxi"
-include "./function.pxi"
+include "./packed_func.pxi"
include "./ndarray.pxi"
-
self.chandle = NULL
cdef void* chandle
ConstructorCall(
- (<FunctionBase>fconstructor).chandle,
+ (<PackedFuncBase>fconstructor).chandle,
kTVMObjectHandle, args, &chandle)
self.chandle = chandle
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
-from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
return 0
+cdef object make_packed_func(TVMPackedFuncHandle chandle, int is_global):
+ obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
+ (<PackedFuncBase>obj).chandle = chandle
+ (<PackedFuncBase>obj).is_global = is_global
+ return obj
+
+
def convert_to_tvm_func(object pyfunc):
"""Convert a python function to TVM function
tvmfunc: tvm.Function
The converted tvm function.
"""
- cdef TVMFunctionHandle chandle
+ cdef TVMPackedFuncHandle chandle
Py_INCREF(pyfunc)
CALL(TVMFuncCreateFromCFunc(tvm_callback,
<void*>(pyfunc),
tvm_callback_finalize,
&chandle))
- ret = _CLASS_FUNCTION(None, False)
- (<FunctionBase>ret).chandle = chandle
- return ret
+ return make_packed_func(chandle, False)
cdef inline int make_arg(object arg,
value[0].v_str = tstr
tcode[0] = kTVMStr
temp_args.append(tstr)
- elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
- arg = convert_to_object(arg)
+ elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
+ arg = _FUNC_CONVERT_TO_OBJECT(arg)
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kTVMObjectHandle
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kTVMModuleHandle
- elif isinstance(arg, FunctionBase):
- value[0].v_handle = (<FunctionBase>arg).chandle
+ elif isinstance(arg, PackedFuncBase):
+ value[0].v_handle = (<PackedFuncBase>arg).chandle
tcode[0] = kTVMPackedFuncHandle
elif isinstance(arg, ctypes.c_void_p):
value[0].v_handle = c_handle(arg)
tcode[0] = kTVMOpaqueHandle
elif callable(arg):
arg = convert_to_tvm_func(arg)
- value[0].v_handle = (<FunctionBase>arg).chandle
+ value[0].v_handle = (<PackedFuncBase>arg).chandle
tcode[0] = kTVMPackedFuncHandle
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return 0
+
cdef inline bytearray make_ret_bytes(void* chandle):
handle = ctypes_handle(chandle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
raise RuntimeError('memmove failed')
return res
+
cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value."""
if tcode == kTVMObjectHandle:
elif tcode == kTVMModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kTVMPackedFuncHandle:
- fobj = _CLASS_FUNCTION(None, False)
- (<FunctionBase>fobj).chandle = value.v_handle
- return fobj
+ return make_packed_func(value.v_handle, False)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
return 0
-cdef class FunctionBase:
- cdef TVMFunctionHandle chandle
+cdef class PackedFuncBase:
+ cdef TVMPackedFuncHandle chandle
cdef int is_global
cdef inline _set_handle(self, handle):
return make_ret(ret_val, ret_tcode)
-_CLASS_FUNCTION = None
+def _get_global_func(name, allow_missing):
+ cdef TVMPackedFuncHandle chandle
+ CALL(TVMFuncGetGlobal(c_str(name), &chandle))
+ if chandle != NULL:
+ return make_packed_func(chandle, True)
+
+ if allow_missing:
+ return None
+
+ raise ValueError("Cannot find global function %s" % name)
+
+
+_CLASS_PACKED_FUNC = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
+_CLASS_OBJECT_GENERIC = None
+_FUNC_CONVERT_TO_OBJECT = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
-def _set_class_function(func_class):
- global _CLASS_FUNCTION
- _CLASS_FUNCTION = func_class
+def _set_class_packed_func(func_class):
+ global _CLASS_PACKED_FUNC
+ _CLASS_PACKED_FUNC = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
+
+def _set_class_object_generic(object_generic_class, func_convert_to_object):
+ global _CLASS_OBJECT_GENERIC
+ global _FUNC_CONVERT_TO_OBJECT
+ _CLASS_OBJECT_GENERIC = object_generic_class
+ _FUNC_CONVERT_TO_OBJECT = func_convert_to_object
"""
import sys
+#----------------------------
+# Python3 version.
+#----------------------------
if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 5):
PY3STATEMENT = """TVM project proudly dropped support of Python2.
The minimal Python requirement is Python 3.5
# coding: utf-8
# pylint: disable=invalid-name
"""Base library for TVM FFI."""
-from __future__ import absolute_import
-
import sys
import os
import ctypes
#----------------------------
# library loading
#----------------------------
-if sys.version_info[0] == 3:
- string_types = (str,)
- integer_types = (int, np.int32)
- numeric_types = integer_types + (float, np.float32)
- # this function is needed for python3
- # to convert ctypes.char_p .value back to python str
- if sys.platform == "win32":
- def _py_str(x):
- try:
- return x.decode('utf-8')
- except UnicodeDecodeError:
- encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP())
- return x.decode(encoding)
- py_str = _py_str
- else:
- py_str = lambda x: x.decode('utf-8')
+string_types = (str,)
+integer_types = (int, np.int32)
+numeric_types = integer_types + (float, np.float32)
+
+# this function is needed for python3
+# to convert ctypes.char_p .value back to python str
+if sys.platform == "win32":
+ def _py_str(x):
+ try:
+ return x.decode('utf-8')
+ except UnicodeDecodeError:
+ encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP())
+ return x.decode(encoding)
+ py_str = _py_str
else:
- string_types = (basestring,)
- integer_types = (int, long, np.int32)
- numeric_types = integer_types + (float, np.float32)
- py_str = lambda x: x
+ py_str = lambda x: x.decode('utf-8')
def _load_lib():
# specific language governing permissions and limitations
# under the License.
"""Library information."""
-from __future__ import absolute_import
import sys
import os
return [p.strip() for p in os.environ[env_var].split(split)]
return []
+
def find_lib_path(name=None, search_path=None, optional=False):
"""Find dynamic library files.
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, unused-import
+"""Runtime Module namespace."""
+import ctypes
+from .base import _LIB, check_call, c_str, string_types
+from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
+
+class ModuleBase(object):
+ """Base class for module"""
+ __slots__ = ["handle", "_entry", "entry_name"]
+
+ def __init__(self, handle):
+ self.handle = handle
+ self._entry = None
+ self.entry_name = "__tvm_main__"
+
+ def __del__(self):
+ check_call(_LIB.TVMModFree(self.handle))
+
+ def __hash__(self):
+ return ctypes.cast(self.handle, ctypes.c_void_p).value
+
+ @property
+ def entry_func(self):
+ """Get the entry function
+
+ Returns
+ -------
+ f : Function
+ The entry function if exist
+ """
+ if self._entry:
+ return self._entry
+ self._entry = self.get_function(self.entry_name)
+ return self._entry
+
+ def get_function(self, name, query_imports=False):
+ """Get function from the module.
+
+ Parameters
+ ----------
+ name : str
+ The name of the function
+
+ query_imports : bool
+ Whether also query modules imported by this module.
+
+ Returns
+ -------
+ f : Function
+ The result function.
+ """
+ ret_handle = PackedFuncHandle()
+ check_call(_LIB.TVMModGetFunction(
+ self.handle, c_str(name),
+ ctypes.c_int(query_imports),
+ ctypes.byref(ret_handle)))
+ if not ret_handle.value:
+ raise AttributeError(
+ "Module has no function '%s'" % name)
+ return PackedFunc(ret_handle, False)
+
+ def import_module(self, module):
+ """Add module to the import list of current one.
+
+ Parameters
+ ----------
+ module : Module
+ The other module.
+ """
+ check_call(_LIB.TVMModImport(self.handle, module.handle))
+
+ def __getitem__(self, name):
+ if not isinstance(name, string_types):
+ raise ValueError("Can only take string as function name")
+ return self.get_function(name)
+
+ def __call__(self, *args):
+ if self._entry:
+ return self._entry(*args)
+ f = self.entry_func
+ return f(*args)
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime NDArray api"""
-from __future__ import absolute_import
-
-import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
-
-IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
-
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
- if sys.version_info >= (3, 0):
- from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
- from ._cy3.core import NDArrayBase as _NDArrayBase
- from ._cy3.core import _reg_extension
- else:
- from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
- from ._cy2.core import NDArrayBase as _NDArrayBase
- from ._cy2.core import _reg_extension
-except IMPORT_EXCEPT:
+ from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
+ from ._cy3.core import NDArrayBase as _NDArrayBase
+except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
- from ._ctypes.ndarray import _reg_extension
def context(dev_type, dev_id=0):
res = empty(self.shape, self.dtype, target)
return self._copyto(res)
raise ValueError("Unsupported target type %s" % str(type(target)))
-
-
-def register_extension(cls, fcreate=None):
- """Register a extension class to TVM.
-
- After the class is registered, the class will be able
- to directly pass as Function argument generated by TVM.
-
- Parameters
- ----------
- cls : class
- The class object to be registered as extension.
-
- fcreate : function, optional
- The creation function to create a class object given handle value.
-
- Note
- ----
- The registered class is requires one property: _tvm_handle.
-
- If the registered class is a subclass of NDArray,
- it is required to have a class attribute _array_type_code.
- Otherwise, it is required to have a class attribute _tvm_tcode.
-
- - ```_tvm_handle``` returns integer represents the address of the handle.
- - ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
- code of the class.
-
- Returns
- -------
- cls : class
- The class being registered.
-
- Example
- -------
- The following code registers user defined class
- MyTensor to be DLTensor compatible.
-
- .. code-block:: python
-
- @tvm.register_extension
- class MyTensor(object):
- _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
-
- def __init__(self):
- self.handle = _LIB.NewDLTensor()
-
- @property
- def _tvm_handle(self):
- return self.handle.value
- """
- assert hasattr(cls, "_tvm_tcode")
- if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
- raise ValueError("Cannot register create when extension tcode is same as buildin")
- _reg_extension(cls, fcreate)
- return cls
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime Object API"""
-from __future__ import absolute_import
-
-import sys
import ctypes
from .. import _api_internal
from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
-from .object_generic import ObjectGeneric, convert_to_object, const
-
-IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position,unused-import
if _FFI_MODE == "ctypes":
raise ImportError()
- if sys.version_info >= (3, 0):
- from ._cy3.core import _set_class_object
- from ._cy3.core import ObjectBase as _ObjectBase
- from ._cy3.core import _register_object
- else:
- from ._cy2.core import _set_class_object
- from ._cy2.core import ObjectBase as _ObjectBase
- from ._cy2.core import _register_object
-except IMPORT_EXCEPT:
+ from ._cy3.core import _set_class_object, _set_class_object_generic
+ from ._cy3.core import ObjectBase
+except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
- from ._ctypes.function import _set_class_object
- from ._ctypes.object import ObjectBase as _ObjectBase
- from ._ctypes.object import _register_object
+ from ._ctypes.packed_func import _set_class_object, _set_class_object_generic
+ from ._ctypes.object import ObjectBase
def _new_object(cls):
return cls.__new__(cls)
-class Object(_ObjectBase):
+class Object(ObjectBase):
"""Base class for all tvm's runtime objects."""
def __repr__(self):
return _api_internal._format_str(self)
return self.__hash__() == other.__hash__()
-def register_object(type_key=None):
- """register object type.
-
- Parameters
- ----------
- type_key : str or cls
- The type key of the node
-
- Examples
- --------
- The following code registers MyObject
- using type key "test.MyObject"
-
- .. code-block:: python
-
- @tvm.register_object("test.MyObject")
- class MyObject(Object):
- pass
- """
- object_name = type_key if isinstance(type_key, str) else type_key.__name__
-
- def register(cls):
- """internal register function"""
- if hasattr(cls, "_type_index"):
- tindex = cls._type_index
- else:
- tidx = ctypes.c_uint()
- if not _RUNTIME_ONLY:
- check_call(_LIB.TVMObjectTypeKey2Index(
- c_str(object_name), ctypes.byref(tidx)))
- else:
- # directly skip unknown objects during runtime.
- ret = _LIB.TVMObjectTypeKey2Index(
- c_str(object_name), ctypes.byref(tidx))
- if ret != 0:
- return cls
- tindex = tidx.value
- _register_object(tindex, cls)
- return cls
-
- if isinstance(type_key, str):
- return register
-
- return register(type_key)
-
-
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
# under the License.
"""Common implementation of object generic related logic"""
# pylint: disable=unused-import
-from __future__ import absolute_import
-
from numbers import Number, Integral
from .. import _api_internal
-from .base import string_types
-
-# Object base class
-_CLASS_OBJECTS = None
-
-def _set_class_objects(cls):
- global _CLASS_OBJECTS
- _CLASS_OBJECTS = cls
-
-def _scalar_type_inference(value):
- if hasattr(value, 'dtype'):
- dtype = str(value.dtype)
- elif isinstance(value, bool):
- dtype = 'bool'
- elif isinstance(value, float):
- # We intentionally convert the float to float32 since it's more common in DL.
- dtype = 'float32'
- elif isinstance(value, int):
- # We intentionally convert the python int to int32 since it's more common in DL.
- dtype = 'int32'
- else:
- raise NotImplementedError('Cannot automatically inference the type.'
- ' value={}'.format(value))
- return dtype
+from .base import string_types
+from .object import ObjectBase, _set_class_object_generic
+from .ndarray import NDArrayBase
+from .packed_func import PackedFuncBase, convert_to_tvm_func
+from .module import ModuleBase
class ObjectGeneric(object):
raise NotImplementedError()
+_CLASS_OBJECTS = (ObjectBase, NDArrayBase, ModuleBase)
+
+
def convert_to_object(value):
"""Convert a python value to corresponding object type.
raise ValueError("don't know how to convert type %s to object" % type(value))
+def convert(value):
+ """Convert value to TVM object or function.
+
+ Parameters
+ ----------
+ value : python value
+
+ Returns
+ -------
+ tvm_val : Object or Function
+ Converted value in TVM
+ """
+ if isinstance(value, (PackedFuncBase, ObjectBase)):
+ return value
+
+ if callable(value):
+ return convert_to_tvm_func(value)
+
+ return convert_to_object(value)
+
+
+def _scalar_type_inference(value):
+ if hasattr(value, 'dtype'):
+ dtype = str(value.dtype)
+ elif isinstance(value, bool):
+ dtype = 'bool'
+ elif isinstance(value, float):
+ # We intentionally convert the float to float32 since it's more common in DL.
+ dtype = 'float32'
+ elif isinstance(value, int):
+ # We intentionally convert the python int to int32 since it's more common in DL.
+ dtype = 'int32'
+ else:
+ raise NotImplementedError('Cannot automatically inference the type.'
+ ' value={}'.format(value))
+ return dtype
+
def const(value, dtype=None):
- """Construct a constant value for a given type.
+ """construct a constant
Parameters
----------
- value : int or float
- The input value
+ value : number
+ The content of the constant number.
dtype : str or None, optional
The data type.
Returns
-------
- expr : Expr
- Constant expression corresponds to the value.
+ const_val: tvm.Expr
+ The result expression.
"""
if dtype is None:
dtype = _scalar_type_inference(value)
+ if dtype == "uint64" and value >= (1 << 63):
+ return _api_internal._LargeUIntImm(
+ dtype, value & ((1 << 32) - 1), value >> 32)
return _api_internal._const(value, dtype)
+
+
+_set_class_object_generic(ObjectGeneric, convert_to_object)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, unused-import
+"""Packed Function namespace."""
+import ctypes
+from .base import _LIB, check_call, c_str, string_types, _FFI_MODE
+
+try:
+ # pylint: disable=wrong-import-position
+ if _FFI_MODE == "ctypes":
+ raise ImportError()
+ from ._cy3.core import _set_class_packed_func, _set_class_module
+ from ._cy3.core import PackedFuncBase
+ from ._cy3.core import convert_to_tvm_func
+except (RuntimeError, ImportError):
+ # pylint: disable=wrong-import-position
+ from ._ctypes.packed_func import _set_class_packed_func, _set_class_module
+ from ._ctypes.packed_func import PackedFuncBase
+ from ._ctypes.packed_func import convert_to_tvm_func
+
+
+PackedFuncHandle = ctypes.c_void_p
+
+class PackedFunc(PackedFuncBase):
+ """The PackedFunc object used in TVM.
+
+ Function plays an key role to bridge front and backend in TVM.
+ Function provide a type-erased interface, you can call function with positional arguments.
+
+ The compiled module returns Function.
+ TVM backend also registers and exposes its API as Functions.
+ For example, the developer function exposed in tvm.ir_pass are actually
+ C++ functions that are registered as PackedFunc
+
+ The following are list of common usage scenario of tvm.Function.
+
+ - Automatic exposure of C++ API into python
+ - To call PackedFunc from python side
+ - To call python callbacks to inspect results in generated code
+ - Bring python hook into C++ backend
+
+ See Also
+ --------
+ tvm.register_func: How to register global function.
+ tvm.get_global_func: How to get global function.
+ """
+
+_set_class_packed_func(PackedFunc)
# under the License.
# pylint: disable=invalid-name, unused-import
-"""Function namespace."""
-from __future__ import absolute_import
-
+"""FFI registry to register function and objects."""
import sys
import ctypes
-from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
-from .object_generic import _set_class_objects
+from .. import _api_internal
-IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
+from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY
try:
- # pylint: disable=wrong-import-position
+ # pylint: disable=wrong-import-position,unused-import
if _FFI_MODE == "ctypes":
raise ImportError()
- if sys.version_info >= (3, 0):
- from ._cy3.core import _set_class_function, _set_class_module
- from ._cy3.core import FunctionBase as _FunctionBase
- from ._cy3.core import NDArrayBase as _NDArrayBase
- from ._cy3.core import ObjectBase as _ObjectBase
- from ._cy3.core import convert_to_tvm_func
- else:
- from ._cy2.core import _set_class_function, _set_class_module
- from ._cy2.core import FunctionBase as _FunctionBase
- from ._cy2.core import NDArrayBase as _NDArrayBase
- from ._cy2.core import ObjectBase as _ObjectBase
- from ._cy2.core import convert_to_tvm_func
-except IMPORT_EXCEPT:
- # pylint: disable=wrong-import-position
- from ._ctypes.function import _set_class_function, _set_class_module
- from ._ctypes.function import FunctionBase as _FunctionBase
- from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
- from ._ctypes.object import ObjectBase as _ObjectBase
- from ._ctypes.function import convert_to_tvm_func
-
-FunctionHandle = ctypes.c_void_p
-
-class Function(_FunctionBase):
- """The PackedFunc object used in TVM.
-
- Function plays an key role to bridge front and backend in TVM.
- Function provide a type-erased interface, you can call function with positional arguments.
-
- The compiled module returns Function.
- TVM backend also registers and exposes its API as Functions.
- For example, the developer function exposed in tvm.ir_pass are actually
- C++ functions that are registered as PackedFunc
-
- The following are list of common usage scenario of tvm.Function.
-
- - Automatic exposure of C++ API into python
- - To call PackedFunc from python side
- - To call python callbacks to inspect results in generated code
- - Bring python hook into C++ backend
-
- See Also
+ from ._cy3.core import _register_object
+ from ._cy3.core import _reg_extension
+ from ._cy3.core import convert_to_tvm_func, _get_global_func, PackedFuncBase
+except (RuntimeError, ImportError):
+ # pylint: disable=wrong-import-position,unused-import
+ from ._ctypes.object import _register_object
+ from ._ctypes.ndarray import _reg_extension
+ from ._ctypes.packed_func import convert_to_tvm_func, _get_global_func, PackedFuncBase
+
+
+def register_object(type_key=None):
+ """register object type.
+
+ Parameters
+ ----------
+ type_key : str or cls
+ The type key of the node
+
+ Examples
--------
- tvm.register_func: How to register global function.
- tvm.get_global_func: How to get global function.
+ The following code registers MyObject
+ using type key "test.MyObject"
+
+ .. code-block:: python
+
+ @tvm.register_object("test.MyObject")
+ class MyObject(Object):
+ pass
"""
+ object_name = type_key if isinstance(type_key, str) else type_key.__name__
+
+ def register(cls):
+ """internal register function"""
+ if hasattr(cls, "_type_index"):
+ tindex = cls._type_index
+ else:
+ tidx = ctypes.c_uint()
+ if not _RUNTIME_ONLY:
+ check_call(_LIB.TVMObjectTypeKey2Index(
+ c_str(object_name), ctypes.byref(tidx)))
+ else:
+ # directly skip unknown objects during runtime.
+ ret = _LIB.TVMObjectTypeKey2Index(
+ c_str(object_name), ctypes.byref(tidx))
+ if ret != 0:
+ return cls
+ tindex = tidx.value
+ _register_object(tindex, cls)
+ return cls
+
+ if isinstance(type_key, str):
+ return register
+
+ return register(type_key)
+
+
+def register_extension(cls, fcreate=None):
+ """Register a extension class to TVM.
+
+ After the class is registered, the class will be able
+ to directly pass as Function argument generated by TVM.
+ Parameters
+ ----------
+ cls : class
+ The class object to be registered as extension.
+
+ fcreate : function, optional
+ The creation function to create a class object given handle value.
-class ModuleBase(object):
- """Base class for module"""
- __slots__ = ["handle", "_entry", "entry_name"]
-
- def __init__(self, handle):
- self.handle = handle
- self._entry = None
- self.entry_name = "__tvm_main__"
-
- def __del__(self):
- check_call(_LIB.TVMModFree(self.handle))
-
- def __hash__(self):
- return ctypes.cast(self.handle, ctypes.c_void_p).value
-
- @property
- def entry_func(self):
- """Get the entry function
-
- Returns
- -------
- f : Function
- The entry function if exist
- """
- if self._entry:
- return self._entry
- self._entry = self.get_function(self.entry_name)
- return self._entry
-
- def get_function(self, name, query_imports=False):
- """Get function from the module.
-
- Parameters
- ----------
- name : str
- The name of the function
-
- query_imports : bool
- Whether also query modules imported by this module.
-
- Returns
- -------
- f : Function
- The result function.
- """
- ret_handle = FunctionHandle()
- check_call(_LIB.TVMModGetFunction(
- self.handle, c_str(name),
- ctypes.c_int(query_imports),
- ctypes.byref(ret_handle)))
- if not ret_handle.value:
- raise AttributeError(
- "Module has no function '%s'" % name)
- return Function(ret_handle, False)
-
- def import_module(self, module):
- """Add module to the import list of current one.
-
- Parameters
- ----------
- module : Module
- The other module.
- """
- check_call(_LIB.TVMModImport(self.handle, module.handle))
-
- def __getitem__(self, name):
- if not isinstance(name, string_types):
- raise ValueError("Can only take string as function name")
- return self.get_function(name)
-
- def __call__(self, *args):
- if self._entry:
- return self._entry(*args)
- f = self.entry_func
- return f(*args)
+ Note
+ ----
+ The registered class is requires one property: _tvm_handle.
+
+ If the registered class is a subclass of NDArray,
+ it is required to have a class attribute _array_type_code.
+ Otherwise, it is required to have a class attribute _tvm_tcode.
+
+ - ```_tvm_handle``` returns integer represents the address of the handle.
+ - ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
+ code of the class.
+
+ Returns
+ -------
+ cls : class
+ The class being registered.
+
+ Example
+ -------
+ The following code registers user defined class
+ MyTensor to be DLTensor compatible.
+
+ .. code-block:: python
+
+ @tvm.register_extension
+ class MyTensor(object):
+ _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
+
+ def __init__(self):
+ self.handle = _LIB.NewDLTensor()
+
+ @property
+ def _tvm_handle(self):
+ return self.handle.value
+ """
+ assert hasattr(cls, "_tvm_tcode")
+ if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
+ raise ValueError("Cannot register create when extension tcode is same as buildin")
+ _reg_extension(cls, fcreate)
+ return cls
def register_func(func_name, f=None, override=False):
return 10
# Get it out from global function table
f = tvm.get_global_func("my_packed_func")
- assert isinstance(f, tvm.nd.Function)
+ assert isinstance(f, tvm.PackedFunc)
y = f(*targs)
assert y == 10
"""
ioverride = ctypes.c_int(override)
def register(myf):
"""internal register function"""
- if not isinstance(myf, Function):
+ if not isinstance(myf, PackedFuncBase):
myf = convert_to_tvm_func(myf)
check_call(_LIB.TVMFuncRegisterGlobal(
c_str(func_name), myf.handle, ioverride))
Returns
-------
- func : tvm.Function
+ func : PackedFunc
The function to be returned, None if function is missing.
"""
- handle = FunctionHandle()
- check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
- if handle.value:
- return Function(handle, False)
-
- if allow_missing:
- return None
-
- raise ValueError("Cannot find global function %s" % name)
-
+ return _get_global_func(name, allow_missing)
def list_global_func_names():
flocal.is_global = True
return flocal
+
def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name
ff.__name__ = fname
ff.__doc__ = ("TVM PackedFunc %s. " % fname)
setattr(target_module, ff.__name__, ff)
-
-_set_class_function(Function)
-_set_class_objects((_ObjectBase, _NDArrayBase, ModuleBase))
# under the License.
"""Common runtime ctypes."""
# pylint: disable=invalid-name
-from __future__ import absolute_import
-
import ctypes
import json
import numpy as np
# under the License.
"""Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin
-from __future__ import absolute_import as _abs
-
from numbers import Integral as _Integral
+import tvm._ffi
+
from ._ffi.base import string_types, TVMError
-from ._ffi.object import register_object, Object
-from ._ffi.object import convert_to_object as _convert_to_object
-from ._ffi.object_generic import _scalar_type_inference
-from ._ffi.function import Function
-from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
-from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
+from ._ffi.object_generic import convert, const
+from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal
from . import make as _make
return _api_internal._max_value(dtype)
-def const(value, dtype=None):
- """construct a constant
-
- Parameters
- ----------
- value : number
- The content of the constant number.
-
- dtype : str or None, optional
- The data type.
-
- Returns
- -------
- const_val: tvm.Expr
- The result expression.
- """
- if dtype is None:
- dtype = _scalar_type_inference(value)
- if dtype == "uint64" and value >= (1 << 63):
- return _api_internal._LargeUIntImm(
- dtype, value & ((1 << 32) - 1), value >> 32)
- return _api_internal._const(value, dtype)
-
-
def get_env_func(name):
"""Get an EnvFunc by a global name.
return _api_internal._EnvFuncGet(name)
-def convert(value):
- """Convert value to TVM node or function.
-
- Parameters
- ----------
- value : python value
-
- Returns
- -------
- tvm_val : Object or Function
- Converted value in TVM
- """
- if isinstance(value, (Function, Object)):
- return value
-
- if callable(value):
- return _convert_tvm_func(value)
-
- return _convert_to_object(value)
-
-
def load_json(json_str):
"""Load tvm object from json_str.
"""
return _make._OpFloorMod(a, b)
-
-_init_api("tvm.api")
-
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')
+
+tvm._ffi._init_api("tvm.api")
# under the License.
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
+import tvm._ffi
-from ._ffi.object import Object, register_object
-from ._ffi.function import _init_api
+from ._ffi.object import Object
from . import _api_internal
class IntSet(Object):
return _api_internal._IntSetIsEverything(self)
-@register_object("arith.IntervalSet")
+@tvm._ffi.register_object("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval [min_value, max_value]
_make_IntervalSet, min_value, max_value)
-@register_object("arith.ModularSet")
+@tvm._ffi.register_object("arith.ModularSet")
class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z """
def __init__(self, coeff, base):
_make_ModularSet, coeff, base)
-@register_object("arith.ConstIntBound")
+@tvm._ffi.register_object("arith.ConstIntBound")
class ConstIntBound(Object):
"""Represent constant integer bound
"Do not know how to handle type {}".format(type(info)))
-_init_api("tvm.arith")
+tvm._ffi._init_api("tvm.arith")
# specific language governing permissions and limitations
# under the License.
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
-from ._ffi.object import Object, register_object
-from ._ffi.function import _init_api
+import tvm._ffi
+
+from ._ffi.object import Object
from . import _api_internal
-@register_object
+@tvm._ffi.register_object
class Attrs(Object):
"""Attribute node, which is mainly use for defining attributes of relay operators.
return self.__getattr__(item)
-_init_api("tvm.attrs")
+tvm._ffi._init_api("tvm.attrs")
This module provides the functions to transform schedule to
LoweredFunc and compiled Module.
"""
-from __future__ import absolute_import as _abs
import warnings
+import tvm._ffi
-from ._ffi.function import Function
-from ._ffi.object import Object, register_object
+from ._ffi.object import Object
from . import api
from . import _api_internal
from . import tensor
DumpIR.scope_level -= 1
-@register_object
+@tvm._ffi.register_object
class BuildConfig(Object):
"""Configuration scope to set a build config option.
# specific language governing permissions and limitations
# under the License.
"""Code generation related functions."""
-from ._ffi.function import _init_api
+import tvm._ffi
def build_module(lowered_func, target):
"""Build lowered_func into Module.
"""
return _Build(lowered_func, target)
-_init_api("tvm.codegen")
+tvm._ffi._init_api("tvm.codegen")
# specific language governing permissions and limitations
# under the License.
"""Container data structures used in TVM DSL."""
-from __future__ import absolute_import as _abs
+import tvm._ffi
+
from tvm import ndarray as _nd
from . import _api_internal
-from ._ffi.object import Object, register_object, getitem_helper
-from ._ffi.function import _init_api
+from ._ffi.object import Object, getitem_helper
+
-@register_object
+@tvm._ffi.register_object
class Array(Object):
"""Array container of TVM.
return _api_internal._ArraySize(self)
-@register_object
+@tvm._ffi.register_object
class EnvFunc(Object):
"""Environment function.
return _api_internal._EnvFuncGetPackedFunc(self)
-@register_object
+@tvm._ffi.register_object
class Map(Object):
"""Map container of TVM.
return _api_internal._MapSize(self)
-@register_object
+@tvm._ffi.register_object
class StrMap(Map):
"""A special map container that has str as key.
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
-@register_object
+@tvm._ffi.register_object
class Range(Object):
"""Represent a range in TVM.
"""
-@register_object
+@tvm._ffi.register_object
class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM."""
MixedFunc = 0
DeviceFunc = 2
-@register_object("vm.ADT")
+@tvm._ffi.register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
return _Tuple(*fields)
-_init_api("tvm.container")
+tvm._ffi._init_api("tvm.container")
import os
import tempfile
import shutil
+import tvm._ffi
+
from tvm._ffi.base import string_types
-from tvm._ffi.function import get_global_func
from tvm.contrib import graph_runtime
from tvm.ndarray import array
from . import debug_result
fcreate = ctx[0]._rpc_sess.get_function(
"tvm.graph_runtime_debug.create")
else:
- fcreate = get_global_func("tvm.graph_runtime_debug.create")
+ fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_debug.create")
except ValueError:
raise ValueError(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
# under the License.
"""Minimum graph runtime that executes graph containing TVM PackedFunc."""
import numpy as np
+import tvm._ffi
from .._ffi.base import string_types
-from .._ffi.function import get_global_func
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base
if num_rpc_ctx == len(ctx):
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
else:
- fcreate = get_global_func("tvm.graph_runtime.create")
+ fcreate = tvm._ffi.get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
# specific language governing permissions and limitations
# under the License.
"""External function interface to NNPACK libraries."""
-from __future__ import absolute_import as _abs
+import tvm._ffi
from .. import api as _api
from .. import intrin as _intrin
-from .._ffi.function import _init_api
+
def is_available():
"""Check whether NNPACK is available, that is, `nnp_initialize()`
"tvm.contrib.nnpack.convolution_inference_weight_transform",
ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
-_init_api("tvm.contrib.nnpack")
+tvm._ffi._init_api("tvm.contrib.nnpack")
# specific language governing permissions and limitations
# under the License.
"""External function interface to random library."""
-from __future__ import absolute_import as _abs
+import tvm._ffi
from .. import api as _api
from .. import intrin as _intrin
-from .._ffi.function import _init_api
def randint(low, high, size, dtype='int32'):
"tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
-_init_api("tvm.contrib.random")
+tvm._ffi._init_api("tvm.contrib.random")
# specific language governing permissions and limitations
# under the License.
"""TFLite runtime that load and run tflite models."""
-from .._ffi.function import get_global_func
+import tvm._ffi
from ..rpc import base as rpc_base
def create(tflite_model_bytes, ctx, runtime_target='cpu'):
if device_type >= rpc_base.RPC_SESS_MASK:
fcreate = ctx._rpc_sess.get_function(runtime_func)
else:
- fcreate = get_global_func(runtime_func)
+ fcreate = tvm._ffi.get_global_func(runtime_func)
return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
# specific language governing permissions and limitations
# under the License.
"""Custom datatype functionality"""
-from __future__ import absolute_import as _abs
+import tvm._ffi
-from ._ffi.function import register_func as _register_func
from . import make as _make
from .api import convert
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
else:
lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
+ type_name
- _register_func(lower_func_name, lower_func)
+ tvm._ffi.register_func(lower_func_name, lower_func)
def create_lower_func(extern_func_name):
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
-from ._ffi.object import Object, register_object, ObjectGeneric
+import tvm._ffi
+
+from ._ffi.object import Object
+from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make
from . import generic as _generic
class LogicalExpr(PrimExpr):
pass
-@register_object("Variable")
+@tvm._ffi.register_object("Variable")
class Var(PrimExpr):
"""Symbolic variable.
_api_internal._Var, name, dtype)
-@register_object
+@tvm._ffi.register_object
class SizeVar(Var):
"""Symbolic variable to represent a tensor index size
which is greater or equal to zero
_api_internal._SizeVar, name, dtype)
-@register_object
+@tvm._ffi.register_object
class Reduce(PrimExpr):
"""Reduce node.
condition, value_index)
-@register_object
+@tvm._ffi.register_object
class FloatImm(ConstExpr):
"""Float constant.
self.__init_handle_by_constructor__(
_make.FloatImm, dtype, value)
-@register_object
+@tvm._ffi.register_object
class IntImm(ConstExpr):
"""Int constant.
return self.value
-@register_object
+@tvm._ffi.register_object
class StringImm(ConstExpr):
"""String constant.
return self.value != other
-@register_object
+@tvm._ffi.register_object
class Cast(PrimExpr):
"""Cast expression.
_make.Cast, dtype, value)
-@register_object
+@tvm._ffi.register_object
class Add(BinaryOpExpr):
"""Add node.
_make.Add, a, b)
-@register_object
+@tvm._ffi.register_object
class Sub(BinaryOpExpr):
"""Sub node.
_make.Sub, a, b)
-@register_object
+@tvm._ffi.register_object
class Mul(BinaryOpExpr):
"""Mul node.
_make.Mul, a, b)
-@register_object
+@tvm._ffi.register_object
class Div(BinaryOpExpr):
"""Div node.
_make.Div, a, b)
-@register_object
+@tvm._ffi.register_object
class Mod(BinaryOpExpr):
"""Mod node.
_make.Mod, a, b)
-@register_object
+@tvm._ffi.register_object
class FloorDiv(BinaryOpExpr):
"""FloorDiv node.
_make.FloorDiv, a, b)
-@register_object
+@tvm._ffi.register_object
class FloorMod(BinaryOpExpr):
"""FloorMod node.
_make.FloorMod, a, b)
-@register_object
+@tvm._ffi.register_object
class Min(BinaryOpExpr):
"""Min node.
_make.Min, a, b)
-@register_object
+@tvm._ffi.register_object
class Max(BinaryOpExpr):
"""Max node.
_make.Max, a, b)
-@register_object
+@tvm._ffi.register_object
class EQ(CmpExpr):
"""EQ node.
_make.EQ, a, b)
-@register_object
+@tvm._ffi.register_object
class NE(CmpExpr):
"""NE node.
_make.NE, a, b)
-@register_object
+@tvm._ffi.register_object
class LT(CmpExpr):
"""LT node.
_make.LT, a, b)
-@register_object
+@tvm._ffi.register_object
class LE(CmpExpr):
"""LE node.
_make.LE, a, b)
-@register_object
+@tvm._ffi.register_object
class GT(CmpExpr):
"""GT node.
_make.GT, a, b)
-@register_object
+@tvm._ffi.register_object
class GE(CmpExpr):
"""GE node.
_make.GE, a, b)
-@register_object
+@tvm._ffi.register_object
class And(LogicalExpr):
"""And node.
_make.And, a, b)
-@register_object
+@tvm._ffi.register_object
class Or(LogicalExpr):
"""Or node.
_make.Or, a, b)
-@register_object
+@tvm._ffi.register_object
class Not(LogicalExpr):
"""Not node.
_make.Not, a)
-@register_object
+@tvm._ffi.register_object
class Select(PrimExpr):
"""Select node.
_make.Select, condition, true_value, false_value)
-@register_object
+@tvm._ffi.register_object
class Load(PrimExpr):
"""Load node.
_make.Load, dtype, buffer_var, index, predicate)
-@register_object
+@tvm._ffi.register_object
class Ramp(PrimExpr):
"""Ramp node.
_make.Ramp, base, stride, lanes)
-@register_object
+@tvm._ffi.register_object
class Broadcast(PrimExpr):
"""Broadcast node.
_make.Broadcast, value, lanes)
-@register_object
+@tvm._ffi.register_object
class Shuffle(PrimExpr):
"""Shuffle node.
_make.Shuffle, vectors, indices)
-@register_object
+@tvm._ffi.register_object
class Call(PrimExpr):
"""Call node.
_make.Call, dtype, name, args, call_type, func, value_index)
-@register_object
+@tvm._ffi.register_object
class Let(PrimExpr):
"""Let node.
# TODO(@were): Make this module more complete.
# 1. Support HalideIR dumping to Hybrid Script
# 2. Support multi-level HalideIR
-
-from __future__ import absolute_import as _abs
-
import inspect
+import tvm._ffi
from .._ffi.base import decorate
-from .._ffi.function import _init_api
from ..build_module import form_body
from .module import HybridModule
return HybridModule(src, name)
-_init_api("tvm.hybrid")
+tvm._ffi._init_api("tvm.hybrid")
# under the License.
"""Expression Intrinsics and math functions in TVM."""
# pylint: disable=redefined-builtin
-from __future__ import absolute_import as _abs
+import tvm._ffi
+import tvm.codegen
-from ._ffi.function import register_func as _register_func
from . import make as _make
from .api import convert, const
from .expr import Call as _Call
call : Expr
The call expression.
"""
- import tvm
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
"""
- return _register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
+ return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
def _rule_float_suffix(op):
return call_pure_extern(op.dtype, op.name, *op.args)
return None
-@_register_func("tvm.default_trace_action")
+@tvm._ffi.register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args):
print(list(args))
from . import ir_pass as _pass
from . import container as _container
from ._ffi.base import string_types
-from ._ffi.object import ObjectGeneric
+from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call
You can read "include/tvm/tir/ir_pass.h" for the function signature and
"src/api/api_pass.cc" for the PackedFunc's body of these functions.
"""
-from ._ffi.function import _init_api
+import tvm._ffi
-_init_api("tvm.ir_pass")
+tvm._ffi._init_api("tvm.ir_pass")
Each api is a PackedFunc that can be called in a positional argument manner.
You can use make function to build the IR node.
"""
-from __future__ import absolute_import as _abs
-from ._ffi.function import _init_api
+import tvm._ffi
def range_by_min_extent(min_value, extent):
return _Node(*args)
-_init_api("tvm.make")
+tvm._ffi._init_api("tvm.make")
from enum import Enum
import tvm
+import tvm._ffi
+
from tvm.contrib import util as _util
from tvm.contrib import cc as _cc
-from .._ffi.function import _init_api
+
class LibType(Enum):
"""Enumeration of library types that can be compiled and loaded onto a device"""
return micro_device_dir
-_init_api("tvm.micro", "tvm.micro.base")
+tvm._ffi._init_api("tvm.micro", "tvm.micro.base")
import struct
from collections import namedtuple
+import tvm._ffi
-from ._ffi.function import ModuleBase, _set_class_module
-from ._ffi.function import _init_api
+from ._ffi.module import ModuleBase, _set_class_module
from ._ffi.libinfo import find_include_path
from .contrib import cc as _cc, tar as _tar, util as _util
return _Enabled(target)
-_init_api("tvm.module")
+tvm._ffi._init_api("tvm.module")
_set_class_module(Module)
the correctness of the program.
"""
# pylint: disable=invalid-name,unused-import
-from __future__ import absolute_import as _abs
+import tvm._ffi
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
-from ._ffi.ndarray import register_extension
-from ._ffi.object import register_object
-@register_object
+@tvm._ffi.register_object
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Node is the base class of all TVM AST.
-
-Normally user do not need to touch this api.
-"""
-# pylint: disable=unused-import
-from __future__ import absolute_import as _abs
-from ._ffi.object import Object, register_object
# specific language governing permissions and limitations
# under the License.
"""FFI exposing the passes for Relay program analysis."""
+import tvm._ffi
-from tvm._ffi.function import _init_api
-
-_init_api("relay._analysis", __name__)
+tvm._ffi._init_api("relay._analysis", __name__)
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
-from tvm._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay._base", __name__)
+tvm._ffi._init_api("relay._base", __name__)
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface for building Relay functions exposed from C++."""
-from tvm._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay.build_module", __name__)
+tvm._ffi._init_api("relay.build_module", __name__)
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
-from tvm._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay._expr", __name__)
+tvm._ffi._init_api("relay._expr", __name__)
This module includes MyPy type signatures for all of the
exposed modules.
"""
-from .._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay._make", __name__)
+tvm._ffi._init_api("relay._make", __name__)
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the Module exposed from C++."""
-from tvm._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay._module", __name__)
+tvm._ffi._init_api("relay._module", __name__)
# specific language governing permissions and limitations
# under the License.
"""FFI exposing the Relay type inference and checking."""
+import tvm._ffi
-from tvm._ffi.function import _init_api
-
-_init_api("relay._transform", __name__)
+tvm._ffi._init_api("relay._transform", __name__)
# specific language governing permissions and limitations
# under the License.
"""The interface of expr function exposed from C++."""
-from __future__ import absolute_import
+import tvm._ffi
from ... import build_module as _build
from ... import container as _container
-from ..._ffi.function import _init_api, register_func
-@register_func("relay.backend.lower")
+@tvm._ffi.register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
"""Backend function for lowering.
f, (_container.Array, tuple, list)) else [f]
-@register_func("relay.backend.build")
+@tvm._ffi.register_func("relay.backend.build")
def build(funcs, target, target_host=None):
"""Backend build function.
return _build.build(funcs, target=target, target_host=target_host)
-@register_func("relay._tensor_value_repr")
+@tvm._ffi.register_func("relay._tensor_value_repr")
def _tensor_value_repr(tvalue):
return str(tvalue.data.asnumpy())
-@register_func("relay._constant_repr")
+@tvm._ffi.register_func("relay._constant_repr")
def _tensor_constant_repr(tvalue):
return str(tvalue.data.asnumpy())
-_init_api("relay.backend", __name__)
+tvm._ffi._init_api("relay.backend", __name__)
# under the License.
"""The Relay virtual machine FFI namespace.
"""
-from tvm._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay._vm", __name__)
+tvm._ffi._init_api("relay._vm", __name__)
import tvm
import tvm.ndarray as _nd
from tvm import autotvm, container
-from tvm.object import Object
+from tvm._ffi.object import Object
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck
"""The base node types for the Relay language."""
-from __future__ import absolute_import as _abs
-from .._ffi.object import register_object as _register_tvm_node
+import tvm._ffi
+
from .._ffi.object import Object
from . import _make
from . import _expr
The type key of the node.
"""
if not isinstance(type_key, str):
- return _register_tvm_node(
+ return tvm._ffi.register_object(
"relay." + type_key.__name__)(type_key)
- return _register_tvm_node(type_key)
+ return tvm._ffi.register_object(type_key)
def register_relay_attr_node(type_key=None):
The type key of the node.
"""
if not isinstance(type_key, str):
- return _register_tvm_node(
+ return tvm._ffi.register_object(
"relay.attrs." + type_key.__name__)(type_key)
- return _register_tvm_node(type_key)
+ return tvm._ffi.register_object(type_key)
class RelayNode(Object):
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
-from ..._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay.op._make", __name__)
+tvm._ffi._init_api("relay.op._make", __name__)
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay.op.annotation._make", __name__)
+tvm._ffi._init_api("relay.op.annotation._make", __name__)
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay.op.contrib._make", __name__)
+tvm._ffi._init_api("relay.op.contrib._make", __name__)
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay.op.image._make", __name__)
+tvm._ffi._init_api("relay.op.image._make", __name__)
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay.op.memory._make", __name__)
+tvm._ffi._init_api("relay.op.memory._make", __name__)
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay.op.nn._make", __name__)
+tvm._ffi._init_api("relay.op.nn._make", __name__)
#pylint: disable=unused-argument
"""The base node types for the Relay language."""
import topi
-
-from ..._ffi.function import _init_api
+import tvm._ffi
from ..base import register_relay_node
from ..expr import Expr
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
return register(op_name, "FShapeFunc", shape_func, level)
-_init_api("relay.op", __name__)
-
@register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name)
name = ''
return _make.debug(expr, name)
+
+tvm._ffi._init_api("relay.op", __name__)
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay.op.vision._make", __name__)
+tvm._ffi._init_api("relay.op.vision._make", __name__)
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
-from ...._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay.qnn.op._make", __name__)
+tvm._ffi._init_api("relay.qnn.op._make", __name__)
# under the License.
#pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation."""
-from __future__ import absolute_import
import warnings
-
import topi
-from ..._ffi.function import register_func
+import tvm._ffi
+
from .. import expr as _expr
from .. import analysis as _analysis
from .. import op as _op
qctx.qnode_map[key] = qnode
return qnode
-register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
+tvm._ffi.register_func(
+ "relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
@register_annotate_function("nn.contrib_conv2d_NCHWc")
# under the License.
#pylint: disable=unused-argument
"""Internal module for quantization."""
-from __future__ import absolute_import
-from tvm._ffi.function import _init_api
+import tvm._ffi
-_init_api("relay._quantize", __name__)
+tvm._ffi._init_api("relay._quantize", __name__)
import struct
import random
import logging
+import tvm._ffi
-from .._ffi.function import _init_api
from .._ffi.base import py_str
# Magic header for RPC data plane
# Still use tvm.rpc for the foreign functions
-_init_api("tvm.rpc", "tvm.rpc.base")
+tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base")
import socket
import struct
import time
+import tvm._ffi
from . import base
from ..contrib import util
from .._ffi.base import TVMError
-from .._ffi import function
from .._ffi import ndarray as nd
from ..module import load as _load_module
def __init__(self):
# pylint: disable=super-init-not-called
self.context = nd.context
- self.get_function = function.get_global_func
+ self.get_function = tvm._ffi.get_global_func
self._temp = util.tempdir()
def upload(self, data, target=None):
- {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
-
-from __future__ import absolute_import
-
import os
import ctypes
import socket
import time
import sys
import signal
+import tvm._ffi
-from .._ffi.function import register_func
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
temp = util.tempdir()
# pylint: disable=unused-variable
- @register_func("tvm.rpc.server.workpath")
+ @tvm._ffi.register_func("tvm.rpc.server.workpath")
def get_workpath(path):
return temp.relpath(path)
- @register_func("tvm.rpc.server.load_module", override=True)
+ @tvm._ffi.register_func("tvm.rpc.server.load_module", override=True)
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
# specific language governing permissions and limitations
# under the License.
"""The computation schedule api of TVM."""
-from __future__ import absolute_import as _abs
+import tvm._ffi
+
from ._ffi.base import string_types
-from ._ffi.object import Object, register_object
-from ._ffi.object import convert_to_object as _convert_to_object
-from ._ffi.function import _init_api, Function
-from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
+from ._ffi.object import Object
+from ._ffi.object_generic import convert
+
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container
-def convert(value):
- """Convert value to TVM object or function.
-
- Parameters
- ----------
- value : python value
-
- Returns
- -------
- tvm_val : Object or Function
- Converted value in TVM
- """
- if isinstance(value, (Function, Object)):
- return value
-
- if callable(value):
- return _convert_tvm_func(value)
-
- return _convert_to_object(value)
-@register_object
+@tvm._ffi.register_object
class Buffer(Object):
"""Symbolic data buffer in TVM.
return _api_internal._BufferVStore(self, begin, value)
-@register_object
+@tvm._ffi.register_object
class Split(Object):
"""Split operation on axis."""
-@register_object
+@tvm._ffi.register_object
class Fuse(Object):
"""Fuse operation on axis."""
-@register_object
+@tvm._ffi.register_object
class Singleton(Object):
"""Singleton axis."""
-@register_object
+@tvm._ffi.register_object
class IterVar(Object, _expr.ExprOp):
"""Represent iteration variable.
return _api_internal._CreateSchedule(ops)
-@register_object
+@tvm._ffi.register_object
class Schedule(Object):
"""Schedule for all the stages."""
def __getitem__(self, k):
return factored[0] if len(factored) == 1 else factored
-@register_object
+@tvm._ffi.register_object
class Stage(Object):
"""A Stage represents schedule for one operation."""
def split(self, parent, factor=None, nparts=None):
"""
_api_internal._StageOpenGL(self)
-_init_api("tvm.schedule")
+tvm._ffi._init_api("tvm.schedule")
assert isinstance(st, tvm.stmt.Store)
assert(st.buffer_var == a)
"""
-from __future__ import absolute_import as _abs
-from ._ffi.object import Object, register_object
+import tvm._ffi
+from ._ffi.object import Object
from . import make as _make
class Stmt(Object):
pass
-@register_object
+@tvm._ffi.register_object
class LetStmt(Stmt):
"""LetStmt node.
_make.LetStmt, var, value, body)
-@register_object
+@tvm._ffi.register_object
class AssertStmt(Stmt):
"""AssertStmt node.
_make.AssertStmt, condition, message, body)
-@register_object
+@tvm._ffi.register_object
class ProducerConsumer(Stmt):
"""ProducerConsumer node.
_make.ProducerConsumer, func, is_producer, body)
-@register_object
+@tvm._ffi.register_object
class For(Stmt):
"""For node.
for_type, device_api, body)
-@register_object
+@tvm._ffi.register_object
class Store(Stmt):
"""Store node.
_make.Store, buffer_var, value, index, predicate)
-@register_object
+@tvm._ffi.register_object
class Provide(Stmt):
"""Provide node.
_make.Provide, func, value_index, value, args)
-@register_object
+@tvm._ffi.register_object
class Allocate(Stmt):
"""Allocate node.
extents, condition, body)
-@register_object
+@tvm._ffi.register_object
class AttrStmt(Stmt):
"""AttrStmt node.
_make.AttrStmt, node, attr_key, value, body)
-@register_object
+@tvm._ffi.register_object
class Free(Stmt):
"""Free node.
_make.Free, buffer_var)
-@register_object
+@tvm._ffi.register_object
class Realize(Stmt):
"""Realize node.
bounds, condition, body)
-@register_object
+@tvm._ffi.register_object
class SeqStmt(Stmt):
"""Sequence of statements.
return len(self.seq)
-@register_object
+@tvm._ffi.register_object
class IfThenElse(Stmt):
"""IfThenElse node.
_make.IfThenElse, condition, then_case, else_case)
-@register_object
+@tvm._ffi.register_object
class Evaluate(Stmt):
"""Evaluate node.
_make.Evaluate, value)
-@register_object
+@tvm._ffi.register_object
class Prefetch(Stmt):
"""Prefetch node.
We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
We can also use other specific function in this module to create specific targets.
"""
-from __future__ import absolute_import
-
import warnings
+import tvm._ffi
from ._ffi.base import _LIB_NAME
-from ._ffi.object import Object, register_object
+from ._ffi.object import Object
from . import _api_internal
try:
return opts
-@register_object
+@tvm._ffi.register_object
class Target(Object):
"""Target device information, use through TVM API.
_api_internal._ExitTargetScope(self)
-@register_object
+@tvm._ffi.register_object
class GenericFunc(Object):
"""GenericFunc node reference. This represents a generic function
that may be specialized for different targets. When this object is
# under the License.
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
-from __future__ import absolute_import as _abs
-from ._ffi.object import Object, register_object, ObjectGeneric, \
- convert_to_object
+import tvm._ffi
+
+from ._ffi.object import Object
+from ._ffi.object_generic import ObjectGeneric, convert_to_object
+
from . import _api_internal
from . import make as _make
from . import expr as _expr
"""Data content of the tensor."""
return self.tensor.dtype
-@register_object
+@tvm._ffi.register_object
class TensorIntrinCall(Object):
"""Intermediate structure for calling a tensor intrinsic."""
itervar_cls = None
-@register_object
+@tvm._ffi.register_object
class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""
return _api_internal._OpInputTensors(self)
-@register_object
+@tvm._ffi.register_object
class PlaceholderOp(Operation):
"""Placeholder operation."""
-@register_object
+@tvm._ffi.register_object
class BaseComputeOp(Operation):
"""Compute operation."""
@property
return self.__getattr__("reduce_axis")
-@register_object
+@tvm._ffi.register_object
class ComputeOp(BaseComputeOp):
"""Scalar operation."""
pass
-@register_object
+@tvm._ffi.register_object
class TensorComputeOp(BaseComputeOp):
"""Tensor operation."""
-@register_object
+@tvm._ffi.register_object
class ScanOp(Operation):
"""Scan operation."""
@property
return self.__getattr__("scan_axis")
-@register_object
+@tvm._ffi.register_object
class ExternOp(Operation):
"""External operation."""
-@register_object
+@tvm._ffi.register_object
class HybridOp(Operation):
"""Hybrid operation."""
@property
return self.__getattr__("axis")
-@register_object
+@tvm._ffi.register_object
class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
return _api_internal._LayoutFactorOf(self, axis)
-@register_object
+@tvm._ffi.register_object
class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
# specific language governing permissions and limitations
# under the License.
"""Tensor intrinsics"""
-from __future__ import absolute_import as _abs
+import tvm._ffi
+
from . import _api_internal
from . import api as _api
from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
-from ._ffi.object import Object, register_object
+from ._ffi.object import Object
def _get_region(tslice):
region.append(_make.range_by_min_extent(begin, 1))
return region
-@register_object
+@tvm._ffi.register_object
class TensorIntrin(Object):
"""Tensor intrinsic functions for certain computation.
# specific language governing permissions and limitations
# under the License.
"""FFI for CUDA TOPI ops and schedules"""
+import tvm._ffi
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.cuda", "topi.cuda")
+tvm._ffi._init_api("topi.cuda", "topi.cpp.cuda")
# specific language governing permissions and limitations
# under the License.
"""FFI for generic TOPI ops and schedules"""
+import tvm._ffi
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.generic", "topi.generic")
+tvm._ffi._init_api("topi.generic", "topi.cpp.generic")
import sys
import os
import ctypes
+import tvm._ffi
-from tvm._ffi.function import _init_api_prefix
from tvm._ffi import libinfo
def _get_lib_names():
_LIB, _LIB_NAME = _load_lib()
-_init_api_prefix("topi.cpp", "topi")
+tvm._ffi._init_api("topi", "topi.cpp")
# specific language governing permissions and limitations
# under the License.
"""FFI for NN TOPI ops and schedules"""
+import tvm._ffi
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.nn", "topi.nn")
+tvm._ffi._init_api("topi.nn", "topi.cpp.nn")
# specific language governing permissions and limitations
# under the License.
"""FFI for Rocm TOPI ops and schedules"""
+import tvm._ffi
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.rocm", "topi.rocm")
+tvm._ffi._init_api("topi.rocm", "topi.cpp.rocm")
# specific language governing permissions and limitations
# under the License.
"""FFI for TOPI utility functions"""
+import tvm._ffi
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.util", "topi.util")
+tvm._ffi._init_api("topi.util", "topi.cpp.util")
# under the License.
"""FFI for vision TOPI ops and schedules"""
-
-from tvm._ffi.function import _init_api_prefix
+import tvm._ffi
from . import yolo
-_init_api_prefix("topi.cpp.vision", "topi.vision")
+tvm._ffi._init_api("topi.vision", "topi.cpp.vision")
# specific language governing permissions and limitations
# under the License.
"""FFI for Yolo TOPI ops and schedules"""
+import tvm._ffi
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.vision.yolo", "topi.vision.yolo")
+tvm._ffi._init_api("topi.vision.yolo", "topi.cpp.vision.yolo")
# specific language governing permissions and limitations
# under the License.
"""FFI for x86 TOPI ops and schedules"""
+import tvm._ffi
-from tvm._ffi.function import _init_api_prefix
-
-_init_api_prefix("topi.cpp.x86", "topi.x86")
+tvm._ffi._init_api("topi.x86", "topi.cpp.x86")