[REFACTOR][PY][API-CHANGE] Establish tvm.target
authortqchen <tqchen@octoml.ai>
Wed, 12 Feb 2020 21:39:45 +0000 (13:39 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 13 Feb 2020 00:36:23 +0000 (16:36 -0800)
Move the related target modules into tvm.target.

API change:
- tvm.target.current_target -> tvm.target.Target.current
- tvm.datatype -> tvm.target.datatype

96 files changed:
docs/api/python/index.rst
docs/api/python/target.rst
python/tvm/__init__.py
python/tvm/_ffi/runtime_ctypes.py
python/tvm/autotvm/task/dispatcher.py
python/tvm/build_module.py
python/tvm/codegen.py [deleted file]
python/tvm/contrib/clang.py
python/tvm/contrib/rocm.py
python/tvm/datatype.py [deleted file]
python/tvm/hybrid/calls.py
python/tvm/hybrid/runtime.py
python/tvm/intrin.py
python/tvm/relay/backend/vm.py
python/tvm/relay/build_module.py
python/tvm/relay/qnn/op/legalizations.py
python/tvm/relay/quantize/_calibrate.py
python/tvm/relay/quantize/_partition.py
python/tvm/target.py [deleted file]
python/tvm/target/__init__.py [new file with mode: 0644]
python/tvm/target/_ffi_api.py [new file with mode: 0644]
python/tvm/target/codegen.py [new file with mode: 0644]
python/tvm/target/datatype.py [new file with mode: 0644]
python/tvm/target/generic_func.py [new file with mode: 0644]
python/tvm/target/target.py [new file with mode: 0644]
src/runtime/c_runtime_api.cc
src/target/codegen.cc
src/target/datatype/registry.cc
src/target/generic_func.cc
src/target/llvm/llvm_module.cc
src/target/target.cc
src/tir/pass/lower_custom_datatypes.cc
tests/cpp/build_module_test.cc
tests/python/integration/test_dot.py
tests/python/relay/test_op_level2.py
tests/python/unittest/test_autotvm_common.py
tests/python/unittest/test_codegen_c_host.py
tests/python/unittest/test_codegen_device.py
tests/python/unittest/test_codegen_llvm.py
tests/python/unittest/test_codegen_static_init.py
tests/python/unittest/test_codegen_vm_basic.py
tests/python/unittest/test_codegen_x86.py
tests/python/unittest/test_custom_datatypes_mybfloat16.py
tests/python/unittest/test_lang_target.py
tests/python/unittest/test_runtime_extension.py
tests/python/unittest/test_runtime_module_load.py
topi/python/topi/arm_cpu/conv2d.py
topi/python/topi/bifrost/conv2d.py
topi/python/topi/bifrost/transforms.py
topi/python/topi/cuda/batch_matmul.py
topi/python/topi/cuda/conv1d.py
topi/python/topi/cuda/conv1d_transpose_ncw.py
topi/python/topi/cuda/conv2d.py
topi/python/topi/cuda/conv2d_direct.py
topi/python/topi/cuda/conv2d_transpose_nchw.py
topi/python/topi/cuda/conv2d_winograd.py
topi/python/topi/cuda/conv3d.py
topi/python/topi/cuda/conv3d_direct.py
topi/python/topi/cuda/deformable_conv2d.py
topi/python/topi/cuda/dense.py
topi/python/topi/cuda/depthwise_conv2d.py
topi/python/topi/cuda/group_conv2d_nchw.py
topi/python/topi/cuda/injective.py
topi/python/topi/cuda/nms.py
topi/python/topi/cuda/nn.py
topi/python/topi/cuda/pooling.py
topi/python/topi/cuda/rcnn/proposal.py
topi/python/topi/cuda/reduction.py
topi/python/topi/cuda/sort.py
topi/python/topi/cuda/ssd/multibox.py
topi/python/topi/cuda/vision.py
topi/python/topi/generic/extern.py
topi/python/topi/generic/injective.py
topi/python/topi/generic/nn.py
topi/python/topi/generic/vision.py
topi/python/topi/intel_graphics/conv2d.py
topi/python/topi/intel_graphics/depthwise_conv2d.py
topi/python/topi/mali/conv2d.py
topi/python/topi/nn/conv2d.py
topi/python/topi/rocm/conv2d.py
topi/python/topi/rocm/dense.py
topi/python/topi/rocm/nn.py
topi/python/topi/x86/batch_matmul.py
topi/python/topi/x86/conv2d.py
topi/python/topi/x86/conv2d_alter_op.py
topi/python/topi/x86/conv2d_int8.py
topi/python/topi/x86/dense.py
topi/python/topi/x86/tensor_intrin.py
topi/python/topi/x86/util.py
vta/python/vta/top/op.py
vta/scripts/tune_conv2d.py
vta/scripts/tune_conv2d_transpose.py
vta/scripts/tune_dense.py
vta/scripts/tune_group_conv2d.py
vta/scripts/tune_resnet.py
vta/tutorials/autotvm/tune_relay_vta.py

index 2daebad4e6767971ab4bc089a6189a75339d1e79..b37d44eda7b3f89b164a8a8f69a647a1a76d8dd7 100644 (file)
@@ -26,10 +26,10 @@ Python API
    ndarray
    error
    ir
+   target
    intrin
    tensor
    schedule
-   target
    build
    function
    autotvm
index a0f3569c6060d295b53827fe8f705e7f6b01cb17..6851c04c5b6b2f451f98085b28af49902c3ac8c3 100644 (file)
@@ -19,3 +19,4 @@ tvm.target
 ----------
 .. automodule:: tvm.target
     :members:
+   :imported-members:
index 69c24008c10fbd32f4b6fbd880c2b43abe2544c9..9327c0865689a9232e0ff08251d46a108faad7b0 100644 (file)
@@ -46,7 +46,6 @@ from . import expr
 from . import stmt
 from . import make
 from . import ir_pass
-from . import codegen
 from . import schedule
 
 from . import ir_builder
@@ -55,7 +54,6 @@ from . import generic
 from . import hybrid
 from . import testing
 from . import error
-from . import datatype
 
 
 from .api import *
index f779b516618c8080a8e46f9540b484133971d7ea..10f0fec82a809389afd0c800d4cb30aed43d10b4 100644 (file)
@@ -20,7 +20,6 @@ import ctypes
 import json
 import numpy as np
 from .base import _LIB, check_call
-from .. import _api_internal
 
 tvm_shape_index_t = ctypes.c_int64
 
@@ -48,6 +47,7 @@ class TVMByteArray(ctypes.Structure):
     _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
                 ("size", ctypes.c_size_t)]
 
+
 class DataType(ctypes.Structure):
     """TVM datatype structure"""
     _fields_ = [("type_code", ctypes.c_uint8),
@@ -89,11 +89,13 @@ class DataType(ctypes.Structure):
             bits = 64
             head = ""
         elif head.startswith("custom"):
+            # pylint: disable=import-outside-toplevel
+            import tvm.runtime._ffi_api
             low, high = head.find('['), head.find(']')
             if not low or not high or low >= high:
                 raise ValueError("Badly formatted custom type string %s" % type_str)
             type_name = head[low + 1:high]
-            self.type_code = _api_internal._datatype_get_type_code(type_name)
+            self.type_code = tvm.runtime._ffi_api._datatype_get_type_code(type_name)
             head = head[high+1:]
         else:
             raise ValueError("Do not know how to handle type %s" % type_str)
@@ -102,13 +104,15 @@ class DataType(ctypes.Structure):
 
 
     def __repr__(self):
+        # pylint: disable=import-outside-toplevel
         if self.bits == 1 and self.lanes == 1:
             return "bool"
         if self.type_code in DataType.CODE2STR:
             type_name = DataType.CODE2STR[self.type_code]
         else:
+            import tvm.runtime._ffi_api
             type_name = "custom[%s]" % \
-                        _api_internal._datatype_get_type_name(self.type_code)
+                        tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
         x = "%s%d" % (type_name, self.bits)
         if self.lanes != 1:
             x += "x%d" % self.lanes
@@ -168,28 +172,35 @@ class TVMContext(ctypes.Structure):
         self.device_type = device_type
         self.device_id = device_id
 
+    def _GetDeviceAttr(self, device_type, device_id, attr_id):
+        """Internal helper function to invoke runtime.GetDeviceAttr"""
+        # pylint: disable=import-outside-toplevel
+        import tvm.runtime._ffi_api
+        return tvm.runtime._ffi_api.GetDeviceAttr(
+            device_type, device_id, attr_id)
+
     @property
     def exist(self):
         """Whether this device exist."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 0) != 0
 
     @property
     def max_threads_per_block(self):
         """Maximum number of threads on each block."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 1)
 
     @property
     def warp_size(self):
         """Number of threads that executes in concurrent."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 2)
 
     @property
     def max_shared_memory_per_block(self):
         """Total amount of shared memory per block in bytes."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 3)
 
     @property
@@ -203,25 +214,25 @@ class TVMContext(ctypes.Structure):
         version : str
             The version string in `major.minor` format.
         """
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 4)
 
     @property
     def device_name(self):
         """Return the string name of device."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 5)
 
     @property
     def max_clock_rate(self):
         """Return the max clock frequency of device."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 6)
 
     @property
     def multi_processor_count(self):
         """Return the number of compute units of device."""
-        return _api_internal._GetDeviceAttr(
+        return self._GetDeviceAttr(
             self.device_type, self.device_id, 7)
 
     @property
@@ -233,7 +244,7 @@ class TVMContext(ctypes.Structure):
         dims: List of int
             The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
         """
-        return json.loads(_api_internal._GetDeviceAttr(
+        return json.loads(self._GetDeviceAttr(
             self.device_type, self.device_id, 8))
 
     def sync(self):
index 15ed2953b28a77d978d5ab8ec7990baae01b9e2e..28a9fbba28340e87ec4a3ce49d1c2135f3c7bda8 100644 (file)
@@ -106,7 +106,7 @@ class DispatchContext(object):
             def _alter_conv2d_layout(attrs, inputs, tinfo):
                 workload = get_conv2d_workload(...)
                 dispatch_ctx = autotvm.task.DispatchContext.current
-                target = tvm.target.current_target()
+                target = tvm.target.Target.current()
                 config = dispatch_ctx.query(target, workload)
 
                 # Get conv2d_NCHWc workload from config
@@ -207,7 +207,7 @@ def dispatcher(fworkload):
 
     def dispatch_func(func, *args, **kwargs):
         """The wrapped dispatch function"""
-        tgt = _target.current_target()
+        tgt = _target.Target.current()
         workload = func(*args, **kwargs)
         cfg = DispatchContext.current.query(tgt, workload)
         if cfg.is_fallback and not cfg.template_key:
index 768f438844188f57dc0980fcbd2f18fd9d8507af..c2993ac278190832298c4c727f7122863fd1486a 100644 (file)
@@ -25,6 +25,8 @@ import tvm.runtime
 
 from tvm.runtime import Object, ndarray
 from tvm.ir import container
+from tvm.target import codegen
+
 from . import api
 from . import _api_internal
 from . import tensor
@@ -32,7 +34,6 @@ from . import schedule
 from . import expr
 from . import ir_pass
 from . import stmt as _stmt
-from . import codegen
 from . import target as _target
 from . import make
 from .stmt import LoweredFunc
@@ -602,7 +603,7 @@ def build(inputs,
                          "LoweredFunc.")
 
     if not isinstance(inputs, (dict, container.Map)):
-        target = _target.current_target() if target is None else target
+        target = _target.Target.current() if target is None else target
         target = target if target else "llvm"
         target_flist = {target: flist}
     else:
diff --git a/python/tvm/codegen.py b/python/tvm/codegen.py
deleted file mode 100644 (file)
index 7dc7bea..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Code generation related functions."""
-import tvm._ffi
-
-def build_module(lowered_func, target):
-    """Build lowered_func into Module.
-
-    Parameters
-    ----------
-    lowered_func : LoweredFunc
-        The lowered function
-
-    target : str
-        The target module type.
-
-    Returns
-    -------
-    module : Module
-        The corressponding module.
-    """
-    return _Build(lowered_func, target)
-
-tvm._ffi._init_api("tvm.codegen")
index c8c6c57d0538059186e10d5b56b168ed0ebad03d..cb7bdcc1fd31dd0814667309d4f8e689902887f3 100644 (file)
 # under the License.
 """Util to invoke clang in the system."""
 # pylint: disable=invalid-name
-from __future__ import absolute_import as _abs
 import subprocess
 
-from .._ffi.base import py_str
-from .. import codegen
+from tvm._ffi.base import py_str
+import tvm.target
 from . import util
 
 
@@ -44,8 +43,8 @@ def find_clang(required=True):
     matches the major llvm version that built with tvm
     """
     cc_list = []
-    if hasattr(codegen, "llvm_version_major"):
-        major = codegen.llvm_version_major()
+    major = tvm.target.codegen.llvm_version_major(allow_none=True)
+    if major is not None:
         cc_list += ["clang-%d.0" % major]
         cc_list += ["clang-%d" % major]
     cc_list += ["clang"]
index fba57f8524d06bd25d2d299434a4aca213be70d7..e5cebdd3f5dc107e4e3c6914c390cb27af556750 100644 (file)
 """Utility for ROCm backend"""
 import subprocess
 from os.path import join, exists
+
+from tvm._ffi.base import py_str
+import tvm.target
+
 from . import util
-from .._ffi.base import py_str
-from .. import codegen
 from ..api import register_func, convert
 
 def find_lld(required=True):
@@ -42,8 +44,8 @@ def find_lld(required=True):
     matches the major llvm version that built with tvm
     """
     lld_list = []
-    if hasattr(codegen, "llvm_version_major"):
-        major = codegen.llvm_version_major()
+    major = tvm.target.codegen.llvm_version_major(allow_none=True)
+    if major is not None:
         lld_list += ["ld.lld-%d.0" % major]
         lld_list += ["ld.lld-%d" % major]
     lld_list += ["ld.lld"]
diff --git a/python/tvm/datatype.py b/python/tvm/datatype.py
deleted file mode 100644 (file)
index 8a93673..0000000
+++ /dev/null
@@ -1,145 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Custom datatype functionality"""
-import tvm._ffi
-
-from . import make as _make
-from .api import convert
-from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
-from ._ffi.runtime_ctypes import DataType
-from . import _api_internal
-
-
-def register(type_name, type_code):
-    """Register a custom datatype with the given type name and type code
-    Currently, the type code is manually allocated by the user, and the
-    user must ensure that no two custom types share the same code.
-    Generally, this should be straightforward, as the user will be
-    manually registering all of their custom types.
-
-    Parameters
-    ----------
-    type_name : str
-        The name of the custom datatype
-
-    type_code : int
-        The type's code, which should be >= kCustomBegin
-    """
-    _api_internal._datatype_register(type_name, type_code)
-
-
-def get_type_name(type_code):
-    """Get the type name from the type code
-
-    Parameters
-    ----------
-    type_code : int
-        The type code
-    """
-    return _api_internal._datatype_get_type_name(type_code)
-
-
-def get_type_code(type_name):
-    """Get the type code from the type name
-
-    Parameters
-    ----------
-    type_name : str
-        The type name
-    """
-    return _api_internal._datatype_get_type_code(type_name)
-
-
-def get_type_registered(type_code):
-    """Get a boolean representing whether the type is registered
-
-    Parameters
-    ----------
-    type_code: int
-        The type code
-    """
-    return _api_internal._datatype_get_type_registered(type_code)
-
-
-def register_op(lower_func, op_name, target, type_name, src_type_name=None):
-    """Register an external function which computes the given op.
-
-    Currently, this will only work with Casts and binary expressions
-    whose arguments are named `a` and `b`.
-    TODO(gus) figure out what other special cases must be handled by
-        looking through expr.py.
-
-    Parameters
-    ----------
-    lower_func : function
-        The lowering function to call. See create_lower_func.
-
-    op_name : str
-        The name of the operation which the function computes, given by its
-        Halide::Internal class name (e.g. Add, LE, Cast).
-
-    target : str
-        The name of codegen target.
-
-    type_name : str
-        The name of the custom datatype, e.g. posit (but not custom[posit]8).
-
-    src_type_name : str
-        If op_name is "Cast", then this should be set to the source datatype of
-        the argument to the Cast. If op_name is not "Cast", this is unused.
-    """
-
-    if op_name == "Cast":
-        assert src_type_name is not None
-        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
-                          + type_name + "." + src_type_name
-    else:
-        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
-                          + type_name
-    tvm._ffi.register_func(lower_func_name, lower_func)
-
-
-def create_lower_func(extern_func_name):
-    """Returns a function which lowers an operation to a function call.
-
-    Parameters
-    ----------
-    extern_func_name : str
-        The name of the extern "C" function to lower to
-    """
-
-    def lower(op):
-        """
-        Takes an op---either a Cast or a binary op (e.g. an Add) and returns a
-        call to the specified external function, passing the op's argument
-        (Cast) or arguments (a binary op). The return type of the call depends
-        on the type of the op: if it is a custom type, then a uint of the same
-        width as the custom type is returned. Otherwise, the type is
-        unchanged."""
-        dtype = op.dtype
-        t = DataType(dtype)
-        if get_type_registered(t.type_code):
-            dtype = "uint" + str(t.bits)
-            if t.lanes > 1:
-                dtype += "x" + str(t.lanes)
-        if isinstance(op, (_Cast, _FloatImm)):
-            return _make.Call(dtype, extern_func_name, convert([op.value]),
-                              _Call.Extern, None, 0)
-        return _make.Call(dtype, extern_func_name, convert([op.a, op.b]),
-                          _Call.Extern, None, 0)
-
-    return lower
index 630c10fcf2dd16378552688dffcacb014d0fda30..78ce2e20e1fdb2d019c6c4a5badff664bcc09c0b 100644 (file)
@@ -154,8 +154,8 @@ def max_num_threads(func_id, args):
     _internal_assert(func_id == "max_num_threads", "This function cannot be directly invoked!")
     _internal_assert(args.__len__() <= 1, "At most one argument accepted!")
     if args.__len__() == 0:
-        res = _tgt.current_target().max_num_threads
+        res = _tgt.Target.current().max_num_threads
     else:
         _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint")
-        res = _tgt.current_target(args[0].value).max_num_threads
+        res = _tgt.Target.current(args[0].value).max_num_threads
     return _api.convert(res)
index aa00b4b802514aa031dc79a2733e85cc0c4108f6..9f92b80444e83a675e84bae0ff009f8607c6e320 100644 (file)
@@ -107,7 +107,7 @@ def sigmoid(x):
 
 def max_num_threads(allow_none=True):
     """Get max number of threads for GPU targets."""
-    return target.current_target(allow_none).max_num_threads
+    return target.Target.current(allow_none).max_num_threads
 
 
 HYBRID_GLOBALS = {
index 6146a718931888f15c1ef26e2d3b4bdc5f781a00..04cbf9ee86c7a0c23649460aa6d04bfddf4ee7b1 100644 (file)
@@ -17,7 +17,7 @@
 """Expression Intrinsics and math functions in TVM."""
 # pylint: disable=redefined-builtin
 import tvm._ffi
-import tvm.codegen
+import tvm.target.codegen
 
 from . import make as _make
 from .api import convert, const
@@ -189,7 +189,7 @@ def call_llvm_intrin(dtype, name, *args):
     call : Expr
         The call expression.
     """
-    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
+    llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(name)
     assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
     return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
 
index 557b9fd6e46df66a77cf689137528c4d616115a3..68b2b1c97c036865740320fdb182f0ab8b2a76af 100644 (file)
@@ -176,7 +176,7 @@ class VMCompiler(object):
 
     def _update_target(self, target):
         """Update target."""
-        target = target if target else tvm.target.current_target()
+        target = target if target else tvm.target.Target.current()
         if target is None:
             raise ValueError("Target is not set in env or passed as argument.")
         tgts = {}
index 8b347883fe3a3ffed156ecc0a8a0381a3eb6b912..6d9c850cb7ffa20cd69abf3a677e3d3c0e723cdd 100644 (file)
@@ -33,7 +33,7 @@ from .backend import interpreter as _interpreter
 from .backend.vm import VMExecutor
 
 def _update_target(target):
-    target = target if target else _target.current_target()
+    target = target if target else _target.Target.current()
     if target is None:
         raise ValueError("Target is not set in env or passed as argument.")
 
index 22785eec6b41dc1d2762c034d1eb08f13625297d..ad71313fef52cb1e64207d9e453c56e3b721f00c 100644 (file)
@@ -220,13 +220,13 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
 
 def is_fast_int8_on_intel():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
     return intel_supported_arches.intersection(set(target.options))
 
 def is_fast_int8_on_arm():
     """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     return '+v8.2a,+dotprod' in ' '.join(target.options)
 
 ########################
index 482a6f292f544e689ce30b260a1c4b239f795512..8f83bfbf0659a399b29332a435f58d930343aa0f 100644 (file)
@@ -37,8 +37,8 @@ def _get_profile_runtime(mod):
     func = mod['main']
     func = _quantize.CreateStatsCollector(func)
 
-    if tvm.target.current_target():
-        target = tvm.target.current_target()
+    if tvm.target.Target.current():
+        target = tvm.target.Target.current()
         ctx = tvm.context(target.target_name)
     else:
         target = 'llvm'
index c6a621db368a92cd43a2a3abc232efeafafc155c..fbac767cea24dab4ccd9fdc2307beef5d140b92a 100644 (file)
@@ -16,9 +16,7 @@
 # under the License.
 #pylint: disable=unused-argument,inconsistent-return-statements
 """Internal module for registering attribute for annotation."""
-from __future__ import absolute_import
-
-from ... import target as _target
+import tvm
 from .. import expr as _expr
 from .. import analysis as _analysis
 from ..base import register_relay_node
@@ -133,7 +131,7 @@ def add_partition_generic(ref_call, new_args, ctx):
 @register_partition_function("add")
 def add_partition_function(ref_call, new_args, ctx):
     """Rewrite function for ewise add for partition"""
-    target = _target.current_target()
+    target = tvm.target.Target.current()
     if target and 'cuda' in target.keys:
         #TODO(wuwei/ziheng) cuda specific rules
         return add_partition_generic(ref_call, new_args, ctx)
diff --git a/python/tvm/target.py b/python/tvm/target.py
deleted file mode 100644 (file)
index e149c89..0000000
+++ /dev/null
@@ -1,559 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Target management API of TVM.
-
-TVM's target string is in fomat ``<target_name> [-option=value]...``.
-
-Note
-----
-The list of options include:
-
-- **-device=<device name>**
-
-   The device name.
-
-- **-mtriple=<target triple>** or **-target**
-
-   Specify the target triple, which is useful for cross
-   compilation.
-
-- **-mcpu=<cpuname>**
-
-   Specify a specific chip in the current architecture to
-   generate code for. By default this is infered from the
-   target triple and autodetected to the current architecture.
-
-- **-mattr=a1,+a2,-a3,...**
-
-   Override or control specific attributes of the target,
-   such as whether SIMD operations are enabled or not. The
-   default set of attributes is set by the current CPU.
-
-- **-system-lib**
-
-   Build TVM system library module. System lib is a global module that contains
-   self registered functions in program startup. User can get the module using
-   :any:`tvm.runtime.system_lib`.
-   It is useful in environments where dynamic loading api like dlopen is banned.
-   The system lib will be available as long as the result code is linked by the program.
-
-We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
-We can also use other specific function in this module to create specific targets.
-"""
-import warnings
-import tvm._ffi
-
-from tvm.runtime import Object
-from ._ffi.base import _LIB_NAME
-from . import _api_internal
-
-try:
-    from decorator import decorate
-except ImportError as err_msg:
-    # Allow decorator to be missing in runtime
-    if _LIB_NAME != "libtvm_runtime.so":
-        raise err_msg
-
-def _merge_opts(opts, new_opts):
-    """Helper function to merge options"""
-    if isinstance(new_opts, str):
-        new_opts = new_opts.split()
-    if new_opts:
-        opt_set = set(opts)
-        new_opts = [opt for opt in new_opts if opt not in opt_set]
-        return opts + new_opts
-    return opts
-
-
-@tvm._ffi.register_object
-class Target(Object):
-    """Target device information, use through TVM API.
-
-    Note
-    ----
-    Do not use class constructor, you can create target using the following functions
-
-    - :any:`tvm.target.create` create target from string
-    - :any:`tvm.target.arm_cpu` create arm_cpu target
-    - :any:`tvm.target.cuda` create CUDA target
-    - :any:`tvm.target.rocm` create ROCM target
-    - :any:`tvm.target.mali` create Mali target
-    - :any:`tvm.target.intel_graphics` create Intel Graphics target
-    """
-    def __new__(cls):
-        # Always override new to enable class
-        obj = Object.__new__(cls)
-        obj._keys = None
-        obj._options = None
-        obj._libs = None
-        return obj
-
-    @property
-    def keys(self):
-        if not self._keys:
-            self._keys = [k.value for k in self.keys_array]
-        return self._keys
-
-    @property
-    def options(self):
-        if not self._options:
-            self._options = [o.value for o in self.options_array]
-        return self._options
-
-    @property
-    def libs(self):
-        if not self._libs:
-            self._libs = [l.value for l in self.libs_array]
-        return self._libs
-
-    @property
-    def model(self):
-        for opt in self.options_array:
-            if opt.value.startswith('-model='):
-                return opt.value[7:]
-        return 'unknown'
-
-    @property
-    def mcpu(self):
-        """Returns the mcpu from the target if it exists."""
-        mcpu = ''
-        if self.options is not None:
-            for opt in self.options:
-                if 'mcpu' in opt:
-                    mcpu = opt.split('=')[1]
-        return mcpu
-
-    def __enter__(self):
-        _api_internal._EnterTargetScope(self)
-        return self
-
-    def __exit__(self, ptype, value, trace):
-        _api_internal._ExitTargetScope(self)
-
-
-@tvm._ffi.register_object
-class GenericFunc(Object):
-    """GenericFunc node reference. This represents a generic function
-    that may be specialized for different targets. When this object is
-    called, a specialization is chosen based on the current target.
-
-    Note
-    ----
-    Do not construct an instance of this object, it should only ever be
-    used as a return value from calling into C++.
-    """
-    def __call__(self, *args):
-        return _api_internal._GenericFuncCallFunc(self, *args)
-
-    def set_default(self, func, allow_override=False):
-        """Set the default function to be used if no specializations match
-        the current target.
-
-        Parameters
-        ----------
-        func : function
-            The default function
-
-        allow_override : bool
-            Whether to allow the current default to be overridden
-        """
-        _api_internal._GenericFuncSetDefault(self, func, allow_override)
-
-    def register(self, func, key_list, allow_override=False):
-        """Register a specialization for this GenericFunc.
-
-        Parameters
-        ----------
-        func : function
-            The function to be registered.
-
-        key : str or list of str
-            The key to be registered.
-
-        allow_override : bool, optional
-            Whether to allow existing keys to be overridden.
-        """
-        key_list = [key_list] if isinstance(key_list, str) else key_list
-        _api_internal._GenericFuncRegisterFunc(self, func, key_list, allow_override)
-
-
-def get_native_generic_func(name):
-    """Get a generic function from the global registry. If no
-    function is registered under the given name, a new generic
-    function is created.
-
-    Parameters
-    ----------
-    name : string
-        The name of the generic function to get
-
-    Returns
-    -------
-    func : GenericFunc
-        The generic function for the given name
-    """
-    return _api_internal._GenericFuncGetGlobal(name)
-
-
-def override_native_generic_func(func_name):
-    """Override a generic function defined in C++
-
-    Generic function allows registration of further functions
-    that can be dispatched on current target context.
-    If no registered dispatch is matched, the fdefault will be called.
-
-    Parameters
-    ----------
-    func_name : string
-        The name of the generic func to be overridden
-
-    Returns
-    -------
-    fgeneric : function
-        A wrapped generic function.
-
-    Example
-    -------
-    .. code-block:: python
-
-      import tvm
-      # wrap function as target generic
-      @tvm.target.override_native_generic_func("my_func")
-      def my_func(a):
-          return a + 1
-      # register specialization of my_func under target cuda
-      @my_func.register("cuda")
-      def my_func_cuda(a):
-          return a + 2
-      # displays 3, because my_func is called
-      print(my_func(2))
-      # displays 4, because my_func_cuda is called
-      with tvm.target.cuda():
-          print(my_func(2))
-    """
-    generic_func_node = get_native_generic_func(func_name)
-
-    def fdecorate(fdefault):
-        """Wrap a target generic function, overriding the previous
-        default that was set for the generic function.
-
-        Parameters
-        ----------
-        fdefault : function
-            The default function.
-
-        Returns
-        -------
-        fgeneric : function
-            A wrapped generic function.
-
-        """
-        generic_func_node.set_default(fdefault, allow_override=True)
-
-        def register(key, func=None, override=True):
-            """Register function to be the dispatch function.
-
-            Parameters
-            ----------
-            key : str or list of str
-                The key to be registered.
-
-            func : function
-                The function to be registered.
-
-            override : bool, optional
-                Whether override existing registration.
-
-            Returns
-            -------
-            The register function is necessary.
-            """
-            def _do_reg(myf):
-                generic_func_node.register(myf, key, override)
-                return myf
-            if func:
-                return _do_reg(func)
-            return _do_reg
-
-        def dispatch_func(func, *args, **kwargs):
-            #pylint: disable=unused-argument
-            """The wrapped dispath function"""
-            if kwargs:
-                raise RuntimeError(
-                    "Keyword arguments cannot be used when invoking generic_func %s" % func_name)
-            return generic_func_node(*args)
-        fresult = decorate(fdefault, dispatch_func)
-        fresult.fdefault = fdefault
-        fresult.register = register
-        return fresult
-    return fdecorate
-
-def generic_func(fdefault):
-    """Wrap a target generic function.
-
-    Generic function allows registration of further functions
-    that can be dispatched on current target context.
-    If no registered dispatch is matched, the fdefault will be called.
-
-    Parameters
-    ----------
-    fdefault : function
-        The default function.
-
-    Returns
-    -------
-    fgeneric : function
-        A wrapped generic function.
-
-    Example
-    -------
-    .. code-block:: python
-
-      import tvm
-      # wrap function as target generic
-      @tvm.target.generic_func
-      def my_func(a):
-          return a + 1
-      # register specialization of my_func under target cuda
-      @my_func.register("cuda")
-      def my_func_cuda(a):
-          return a + 2
-      # displays 3, because my_func is called
-      print(my_func(2))
-      # displays 4, because my_func_cuda is called
-      with tvm.target.cuda():
-          print(my_func(2))
-    """
-    dispatch_dict = {}
-    func_name = fdefault.__name__
-
-    def register(key, func=None, override=False):
-        """Register function to be the dispatch function.
-
-        Parameters
-        ----------
-        key : str or list of str
-            The key to be registered.
-
-        func : function
-            The function to be registered.
-
-        override : bool
-            Whether override existing registration.
-
-        Returns
-        -------
-        The register function is necessary.
-        """
-        def _do_reg(myf):
-            key_list = [key] if isinstance(key, str) else key
-            for k in key_list:
-                if k in dispatch_dict and not override:
-                    raise ValueError(
-                        "Key is already registered for %s" % func_name)
-                dispatch_dict[k] = myf
-            return myf
-        if func:
-            return _do_reg(func)
-        return _do_reg
-
-    def dispatch_func(func, *args, **kwargs):
-        """The wrapped dispath function"""
-        target = current_target()
-        if target is None:
-            return func(*args, **kwargs)
-        for k in target.keys:
-            if k in dispatch_dict:
-                return dispatch_dict[k](*args, **kwargs)
-        return func(*args, **kwargs)
-    fdecorate = decorate(fdefault, dispatch_func)
-    fdecorate.register = register
-    fdecorate.fdefault = fdefault
-    return fdecorate
-
-
-def cuda(model='unknown', options=None):
-    """Returns a cuda target.
-
-    Parameters
-    ----------
-    model: str
-        The model of cuda device (e.g. 1080ti)
-    options : str or list of str
-        Additional options
-    """
-    opts = _merge_opts(['-model=%s' % model], options)
-    return _api_internal._TargetCreate("cuda", *opts)
-
-
-def rocm(model='unknown', options=None):
-    """Returns a ROCM target.
-
-    Parameters
-    ----------
-    model: str
-        The model of this device
-    options : str or list of str
-        Additional options
-    """
-    opts = _merge_opts(["-model=%s" % model], options)
-    return _api_internal._TargetCreate("rocm", *opts)
-
-
-def mali(model='unknown', options=None):
-    """Returns a ARM Mali GPU target.
-
-    Parameters
-    ----------
-    model: str
-        The model of this device
-    options : str or list of str
-        Additional options
-    """
-    opts = ["-device=mali", '-model=%s' % model]
-    opts = _merge_opts(opts, options)
-    return _api_internal._TargetCreate("opencl", *opts)
-
-
-def intel_graphics(model='unknown', options=None):
-    """Returns an Intel Graphics target.
-
-    Parameters
-    ----------
-    model: str
-        The model of this device
-    options : str or list of str
-        Additional options
-    """
-    opts = ["-device=intel_graphics", '-model=%s' % model]
-    opts = _merge_opts(opts, options)
-    return _api_internal._TargetCreate("opencl", *opts)
-
-
-def opengl(model='unknown', options=None):
-    """Returns a OpenGL target.
-
-    Parameters
-    ----------
-    options : str or list of str
-        Additional options
-    """
-    opts = _merge_opts(["-model=%s" % model], options)
-    return _api_internal._TargetCreate("opengl", *opts)
-
-
-def arm_cpu(model='unknown', options=None):
-    """Returns a ARM CPU target.
-    This function will also download pre-tuned op parameters when there is none.
-
-    Parameters
-    ----------
-    model: str
-        SoC name or phone name of the arm board.
-    options : str or list of str
-        Additional options
-    """
-    trans_table = {
-        "pixel2":    ["-model=snapdragon835", "-target=arm64-linux-android -mattr=+neon"],
-        "mate10":    ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
-        "mate10pro": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
-        "p20":       ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
-        "p20pro":    ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
-        "rasp3b":    ["-model=bcm2837", "-target=armv7l-linux-gnueabihf -mattr=+neon"],
-        "rasp4b":    ["-model=bcm2711", "-target=arm-linux-gnueabihf -mattr=+neon"],
-        "rk3399":    ["-model=rk3399", "-target=aarch64-linux-gnu -mattr=+neon"],
-        "pynq":      ["-model=pynq", "-target=armv7a-linux-eabi -mattr=+neon"],
-        "ultra96":   ["-model=ultra96", "-target=aarch64-linux-gnu -mattr=+neon"],
-    }
-    pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
-
-    opts = ["-device=arm_cpu"] + pre_defined_opt
-    opts = _merge_opts(opts, options)
-    return _api_internal._TargetCreate("llvm", *opts)
-
-
-def rasp(options=None):
-    """Return a Raspberry 3b target.
-
-    Parameters
-    ----------
-    options : str or list of str
-        Additional options
-    """
-    warnings.warn('tvm.target.rasp() is going to be deprecated. '
-                  'Please use tvm.target.arm_cpu("rasp3b")')
-    return arm_cpu('rasp3b', options)
-
-
-def vta(model='unknown', options=None):
-    opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
-    opts = _merge_opts(opts, options)
-    ret = _api_internal._TargetCreate("ext_dev", *opts)
-    return ret
-
-
-def bifrost(model='unknown', options=None):
-    """Return an ARM Mali GPU target (Bifrost architecture).
-
-    Parameters
-    ----------
-    options : str or list of str
-        Additional options
-    """
-    opts = ["-device=bifrost", '-model=%s' % model]
-    opts = _merge_opts(opts, options)
-    return _api_internal._TargetCreate("opencl", *opts)
-
-
-def create(target_str):
-    """Get a target given target string.
-
-    Parameters
-    ----------
-    target_str : str
-        The target string.
-
-    Returns
-    -------
-    target : Target
-        The target object
-
-    Note
-    ----
-    See the note on :any:`tvm.target` on target string format.
-    """
-    if isinstance(target_str, Target):
-        return target_str
-    if not isinstance(target_str, str):
-        raise ValueError("target_str has to be string type")
-
-    return _api_internal._TargetFromString(target_str)
-
-
-def current_target(allow_none=True):
-    """Returns the current target.
-
-    Parameters
-    ----------
-    allow_none : bool
-       Whether allow the current target to be none
-
-    Raises
-    ------
-    ValueError if current target is not set.
-    """
-    return _api_internal._GetCurrentTarget(allow_none)
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py
new file mode 100644 (file)
index 0000000..abe8436
--- /dev/null
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Target description and codgen module.
+
+TVM's target string is in fomat ``<target_name> [-option=value]...``.
+
+Note
+----
+The list of options include:
+
+- **-device=<device name>**
+
+   The device name.
+
+- **-mtriple=<target triple>** or **-target**
+
+   Specify the target triple, which is useful for cross
+   compilation.
+
+- **-mcpu=<cpuname>**
+
+   Specify a specific chip in the current architecture to
+   generate code for. By default this is infered from the
+   target triple and autodetected to the current architecture.
+
+- **-mattr=a1,+a2,-a3,...**
+
+   Override or control specific attributes of the target,
+   such as whether SIMD operations are enabled or not. The
+   default set of attributes is set by the current CPU.
+
+- **-system-lib**
+
+   Build TVM system library module. System lib is a global module that contains
+   self registered functions in program startup. User can get the module using
+   :any:`tvm.runtime.system_lib`.
+   It is useful in environments where dynamic loading api like dlopen is banned.
+   The system lib will be available as long as the result code is linked by the program.
+
+We can use :py:func:`~tvm.target.create` to create a tvm.target.Target from the target string.
+We can also use other specific function in this module to create specific targets.
+"""
+from .target import Target, create
+from .target import cuda, rocm, mali, intel_graphics, opengl, arm_cpu, rasp, vta, bifrost
+from .generic_func import GenericFunc
+from .generic_func import generic_func, get_native_generic_func, override_native_generic_func
+from . import datatype
+from . import codegen
diff --git a/python/tvm/target/_ffi_api.py b/python/tvm/target/_ffi_api.py
new file mode 100644 (file)
index 0000000..3f3c4f2
--- /dev/null
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.target"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("target", __name__)
diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py
new file mode 100644 (file)
index 0000000..e7bedaa
--- /dev/null
@@ -0,0 +1,76 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Code generation related functions."""
+from . import _ffi_api
+
+
+def build_module(lowered_func, target):
+    """Build lowered_func into Module.
+
+    Parameters
+    ----------
+    lowered_func : LoweredFunc
+        The lowered function
+
+    target : str
+        The target module type.
+
+    Returns
+    -------
+    module : runtime.Module
+        The corressponding module.
+    """
+    return _ffi_api.Build(lowered_func, target)
+
+
+def llvm_lookup_intrinsic_id(name):
+    """Lookup LLVM intrinsic id by name.
+
+    Parameters
+    ----------
+    name : str
+        The name of the intrinsic.
+
+    Returns
+    -------
+    intrin_id : int
+        The intrinsic id.
+    """
+    return _ffi_api.llvm_lookup_intrinsic_id(name)
+
+
+def llvm_version_major(allow_none=False):
+    """Get the major LLVM version.
+
+    Parameters
+    ----------
+    allow_none : bool
+        Whether do we allow none.
+
+    Returns
+    -------
+    major : int
+        The major LLVM version.
+    """
+    try:
+        return _ffi_api.llvm_version_major()
+    except AttributeError:
+        if allow_none:
+            return None
+        raise RuntimeError(
+            "LLVM version is not available, please check if you build with LLVM")
diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py
new file mode 100644 (file)
index 0000000..a9506b3
--- /dev/null
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Custom datatype functionality"""
+import tvm._ffi
+
+import tvm.runtime._ffi_api
+from tvm.runtime import convert, DataType
+from tvm.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
+
+
+def register(type_name, type_code):
+    """Register a custom datatype with the given type name and type code
+    Currently, the type code is manually allocated by the user, and the
+    user must ensure that no two custom types share the same code.
+    Generally, this should be straightforward, as the user will be
+    manually registering all of their custom types.
+
+    Parameters
+    ----------
+    type_name : str
+        The name of the custom datatype
+
+    type_code : int
+        The type's code, which should be >= kCustomBegin
+    """
+    tvm.runtime._ffi_api._datatype_register(type_name, type_code)
+
+
+def get_type_name(type_code):
+    """Get the type name from the type code
+
+    Parameters
+    ----------
+    type_code : int
+        The type code
+    """
+    return tvm.runtime._ffi_api._datatype_get_type_name(type_code)
+
+
+def get_type_code(type_name):
+    """Get the type code from the type name
+
+    Parameters
+    ----------
+    type_name : str
+        The type name
+    """
+    return tvm.runtime._ffi_api._datatype_get_type_code(type_name)
+
+
+def get_type_registered(type_code):
+    """Get a boolean representing whether the type is registered
+
+    Parameters
+    ----------
+    type_code: int
+        The type code
+    """
+    return tvm.runtime._ffi_api._datatype_get_type_registered(type_code)
+
+
+def register_op(lower_func, op_name, target, type_name, src_type_name=None):
+    """Register an external function which computes the given op.
+
+    Currently, this will only work with Casts and binary expressions
+    whose arguments are named `a` and `b`.
+    TODO(gus) figure out what other special cases must be handled by
+        looking through expr.py.
+
+    Parameters
+    ----------
+    lower_func : function
+        The lowering function to call. See create_lower_func.
+
+    op_name : str
+        The name of the operation which the function computes, given by its
+        Halide::Internal class name (e.g. Add, LE, Cast).
+
+    target : str
+        The name of codegen target.
+
+    type_name : str
+        The name of the custom datatype, e.g. posit (but not custom[posit]8).
+
+    src_type_name : str
+        If op_name is "Cast", then this should be set to the source datatype of
+        the argument to the Cast. If op_name is not "Cast", this is unused.
+    """
+
+    if op_name == "Cast":
+        assert src_type_name is not None
+        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
+                          + type_name + "." + src_type_name
+    else:
+        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
+                          + type_name
+    tvm._ffi.register_func(lower_func_name, lower_func)
+
+
+def create_lower_func(extern_func_name):
+    """Returns a function which lowers an operation to a function call.
+
+    Parameters
+    ----------
+    extern_func_name : str
+        The name of the extern "C" function to lower to
+    """
+
+    def lower(op):
+        """
+        Takes an op---either a Cast or a binary op (e.g. an Add) and returns a
+        call to the specified external function, passing the op's argument
+        (Cast) or arguments (a binary op). The return type of the call depends
+        on the type of the op: if it is a custom type, then a uint of the same
+        width as the custom type is returned. Otherwise, the type is
+        unchanged."""
+        dtype = op.dtype
+        t = DataType(dtype)
+        if get_type_registered(t.type_code):
+            dtype = "uint" + str(t.bits)
+            if t.lanes > 1:
+                dtype += "x" + str(t.lanes)
+        if isinstance(op, (_Cast, _FloatImm)):
+            return _Call(dtype, extern_func_name, convert([op.value]),
+                         _Call.Extern, None, 0)
+        return _Call(dtype, extern_func_name, convert([op.a, op.b]),
+                     _Call.Extern, None, 0)
+
+    return lower
diff --git a/python/tvm/target/generic_func.py b/python/tvm/target/generic_func.py
new file mode 100644 (file)
index 0000000..862fbed
--- /dev/null
@@ -0,0 +1,271 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Generic function."""
+
+import tvm._ffi
+
+try:
+    from decorator import decorate
+except ImportError as err_msg:
+    # Allow decorator to be missing in runtime
+    if _LIB_NAME != "libtvm_runtime.so":
+        raise err_msg
+
+from tvm.runtime import Object
+from . target import Target
+from . import _ffi_api
+
+
+@tvm._ffi.register_object
+class GenericFunc(Object):
+    """GenericFunc node reference. This represents a generic function
+    that may be specialized for different targets. When this object is
+    called, a specialization is chosen based on the current target.
+
+    Note
+    ----
+    Do not construct an instance of this object, it should only ever be
+    used as a return value from calling into C++.
+    """
+    def __call__(self, *args):
+        return _ffi_api.GenericFuncCallFunc(self, *args)
+
+    def set_default(self, func, allow_override=False):
+        """Set the default function to be used if no specializations match
+        the current target.
+
+        Parameters
+        ----------
+        func : function
+            The default function
+
+        allow_override : bool
+            Whether to allow the current default to be overridden
+        """
+        _ffi_api.GenericFuncSetDefault(self, func, allow_override)
+
+    def register(self, func, key_list, allow_override=False):
+        """Register a specialization for this GenericFunc.
+
+        Parameters
+        ----------
+        func : function
+            The function to be registered.
+
+        key : str or list of str
+            The key to be registered.
+
+        allow_override : bool, optional
+            Whether to allow existing keys to be overridden.
+        """
+        key_list = [key_list] if isinstance(key_list, str) else key_list
+        _ffi_api.GenericFuncRegisterFunc(self, func, key_list, allow_override)
+
+
+def get_native_generic_func(name):
+    """Get a generic function from the global registry. If no
+    function is registered under the given name, a new generic
+    function is created.
+
+    Parameters
+    ----------
+    name : string
+        The name of the generic function to get
+
+    Returns
+    -------
+    func : GenericFunc
+        The generic function for the given name
+    """
+    return _ffi_api.GenericFuncGetGlobal(name)
+
+
+def override_native_generic_func(func_name):
+    """Override a generic function defined in C++
+
+    Generic function allows registration of further functions
+    that can be dispatched on current target context.
+    If no registered dispatch is matched, the fdefault will be called.
+
+    Parameters
+    ----------
+    func_name : string
+        The name of the generic func to be overridden
+
+    Returns
+    -------
+    fgeneric : function
+        A wrapped generic function.
+
+    Example
+    -------
+    .. code-block:: python
+
+      import tvm
+      # wrap function as target generic
+      @tvm.target.override_native_generic_func("my_func")
+      def my_func(a):
+          return a + 1
+      # register specialization of my_func under target cuda
+      @my_func.register("cuda")
+      def my_func_cuda(a):
+          return a + 2
+      # displays 3, because my_func is called
+      print(my_func(2))
+      # displays 4, because my_func_cuda is called
+      with tvm.target.cuda():
+          print(my_func(2))
+    """
+    generic_func_node = get_native_generic_func(func_name)
+
+    def fdecorate(fdefault):
+        """Wrap a target generic function, overriding the previous
+        default that was set for the generic function.
+
+        Parameters
+        ----------
+        fdefault : function
+            The default function.
+
+        Returns
+        -------
+        fgeneric : function
+            A wrapped generic function.
+
+        """
+        generic_func_node.set_default(fdefault, allow_override=True)
+
+        def register(key, func=None, override=True):
+            """Register function to be the dispatch function.
+
+            Parameters
+            ----------
+            key : str or list of str
+                The key to be registered.
+
+            func : function
+                The function to be registered.
+
+            override : bool, optional
+                Whether override existing registration.
+
+            Returns
+            -------
+            The register function is necessary.
+            """
+            def _do_reg(myf):
+                generic_func_node.register(myf, key, override)
+                return myf
+            if func:
+                return _do_reg(func)
+            return _do_reg
+
+        def dispatch_func(func, *args, **kwargs):
+            #pylint: disable=unused-argument
+            """The wrapped dispath function"""
+            if kwargs:
+                raise RuntimeError(
+                    "Keyword arguments cannot be used when invoking generic_func %s" % func_name)
+            return generic_func_node(*args)
+        fresult = decorate(fdefault, dispatch_func)
+        fresult.fdefault = fdefault
+        fresult.register = register
+        return fresult
+    return fdecorate
+
+def generic_func(fdefault):
+    """Wrap a target generic function.
+
+    Generic function allows registration of further functions
+    that can be dispatched on current target context.
+    If no registered dispatch is matched, the fdefault will be called.
+
+    Parameters
+    ----------
+    fdefault : function
+        The default function.
+
+    Returns
+    -------
+    fgeneric : function
+        A wrapped generic function.
+
+    Example
+    -------
+    .. code-block:: python
+
+      import tvm
+      # wrap function as target generic
+      @tvm.target.generic_func
+      def my_func(a):
+          return a + 1
+      # register specialization of my_func under target cuda
+      @my_func.register("cuda")
+      def my_func_cuda(a):
+          return a + 2
+      # displays 3, because my_func is called
+      print(my_func(2))
+      # displays 4, because my_func_cuda is called
+      with tvm.target.cuda():
+          print(my_func(2))
+    """
+    dispatch_dict = {}
+    func_name = fdefault.__name__
+
+    def register(key, func=None, override=False):
+        """Register function to be the dispatch function.
+
+        Parameters
+        ----------
+        key : str or list of str
+            The key to be registered.
+
+        func : function
+            The function to be registered.
+
+        override : bool
+            Whether override existing registration.
+
+        Returns
+        -------
+        The register function is necessary.
+        """
+        def _do_reg(myf):
+            key_list = [key] if isinstance(key, str) else key
+            for k in key_list:
+                if k in dispatch_dict and not override:
+                    raise ValueError(
+                        "Key is already registered for %s" % func_name)
+                dispatch_dict[k] = myf
+            return myf
+        if func:
+            return _do_reg(func)
+        return _do_reg
+
+    def dispatch_func(func, *args, **kwargs):
+        """The wrapped dispath function"""
+        target = Target.current()
+        if target is None:
+            return func(*args, **kwargs)
+        for k in target.keys:
+            if k in dispatch_dict:
+                return dispatch_dict[k](*args, **kwargs)
+        return func(*args, **kwargs)
+    fdecorate = decorate(fdefault, dispatch_func)
+    fdecorate.register = register
+    fdecorate.fdefault = fdefault
+    return fdecorate
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
new file mode 100644 (file)
index 0000000..8405bb1
--- /dev/null
@@ -0,0 +1,272 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Target data structure."""
+import warnings
+import tvm._ffi
+
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object
+class Target(Object):
+    """Target device information, use through TVM API.
+
+    Note
+    ----
+    Do not use class constructor, you can create target using the following functions
+
+    - :py:func:`~tvm.target.create` create target from string
+    - :py:func:`~tvm.target.arm_cpu` create arm_cpu target
+    - :py:func:`~tvm.target.cuda` create CUDA target
+    - :py:func:`~tvm.target.rocm` create ROCM target
+    - :py:func:`~tvm.target.mali` create Mali target
+    - :py:func:`~tvm.target.intel_graphics` create Intel Graphics target
+    """
+    def __new__(cls):
+        # Always override new to enable class
+        obj = Object.__new__(cls)
+        obj._keys = None
+        obj._options = None
+        obj._libs = None
+        return obj
+
+    @property
+    def keys(self):
+        if not self._keys:
+            self._keys = [k.value for k in self.keys_array]
+        return self._keys
+
+    @property
+    def options(self):
+        if not self._options:
+            self._options = [o.value for o in self.options_array]
+        return self._options
+
+    @property
+    def libs(self):
+        if not self._libs:
+            self._libs = [l.value for l in self.libs_array]
+        return self._libs
+
+    @property
+    def model(self):
+        for opt in self.options_array:
+            if opt.value.startswith('-model='):
+                return opt.value[7:]
+        return 'unknown'
+
+    @property
+    def mcpu(self):
+        """Returns the mcpu from the target if it exists."""
+        mcpu = ''
+        if self.options is not None:
+            for opt in self.options:
+                if 'mcpu' in opt:
+                    mcpu = opt.split('=')[1]
+        return mcpu
+
+    def __enter__(self):
+        _ffi_api.EnterTargetScope(self)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        _ffi_api.ExitTargetScope(self)
+
+    @staticmethod
+    def current(allow_none=True):
+        """Returns the current target.
+
+        Parameters
+        ----------
+        allow_none : bool
+            Whether allow the current target to be none
+
+        Raises
+        ------
+        ValueError if current target is not set.
+        """
+        return _ffi_api.GetCurrentTarget(allow_none)
+
+
+def _merge_opts(opts, new_opts):
+    """Helper function to merge options"""
+    if isinstance(new_opts, str):
+        new_opts = new_opts.split()
+    if new_opts:
+        opt_set = set(opts)
+        new_opts = [opt for opt in new_opts if opt not in opt_set]
+        return opts + new_opts
+    return opts
+
+
+def cuda(model='unknown', options=None):
+    """Returns a cuda target.
+
+    Parameters
+    ----------
+    model: str
+        The model of cuda device (e.g. 1080ti)
+    options : str or list of str
+        Additional options
+    """
+    opts = _merge_opts(['-model=%s' % model], options)
+    return _ffi_api.TargetCreate("cuda", *opts)
+
+
+def rocm(model='unknown', options=None):
+    """Returns a ROCM target.
+
+    Parameters
+    ----------
+    model: str
+        The model of this device
+    options : str or list of str
+        Additional options
+    """
+    opts = _merge_opts(["-model=%s" % model], options)
+    return _ffi_api.TargetCreate("rocm", *opts)
+
+
+def mali(model='unknown', options=None):
+    """Returns a ARM Mali GPU target.
+
+    Parameters
+    ----------
+    model: str
+        The model of this device
+    options : str or list of str
+        Additional options
+    """
+    opts = ["-device=mali", '-model=%s' % model]
+    opts = _merge_opts(opts, options)
+    return _ffi_api.TargetCreate("opencl", *opts)
+
+
+def intel_graphics(model='unknown', options=None):
+    """Returns an Intel Graphics target.
+
+    Parameters
+    ----------
+    model: str
+        The model of this device
+    options : str or list of str
+        Additional options
+    """
+    opts = ["-device=intel_graphics", '-model=%s' % model]
+    opts = _merge_opts(opts, options)
+    return _ffi_api.TargetCreate("opencl", *opts)
+
+
+def opengl(model='unknown', options=None):
+    """Returns a OpenGL target.
+
+    Parameters
+    ----------
+    options : str or list of str
+        Additional options
+    """
+    opts = _merge_opts(["-model=%s" % model], options)
+    return _ffi_api.TargetCreate("opengl", *opts)
+
+
+def arm_cpu(model='unknown', options=None):
+    """Returns a ARM CPU target.
+    This function will also download pre-tuned op parameters when there is none.
+
+    Parameters
+    ----------
+    model: str
+        SoC name or phone name of the arm board.
+    options : str or list of str
+        Additional options
+    """
+    trans_table = {
+        "pixel2":    ["-model=snapdragon835", "-target=arm64-linux-android -mattr=+neon"],
+        "mate10":    ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
+        "mate10pro": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
+        "p20":       ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
+        "p20pro":    ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
+        "rasp3b":    ["-model=bcm2837", "-target=armv7l-linux-gnueabihf -mattr=+neon"],
+        "rasp4b":    ["-model=bcm2711", "-target=arm-linux-gnueabihf -mattr=+neon"],
+        "rk3399":    ["-model=rk3399", "-target=aarch64-linux-gnu -mattr=+neon"],
+        "pynq":      ["-model=pynq", "-target=armv7a-linux-eabi -mattr=+neon"],
+        "ultra96":   ["-model=ultra96", "-target=aarch64-linux-gnu -mattr=+neon"],
+    }
+    pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
+
+    opts = ["-device=arm_cpu"] + pre_defined_opt
+    opts = _merge_opts(opts, options)
+    return _ffi_api.TargetCreate("llvm", *opts)
+
+
+def rasp(options=None):
+    """Return a Raspberry 3b target.
+
+    Parameters
+    ----------
+    options : str or list of str
+        Additional options
+    """
+    warnings.warn('tvm.target.rasp() is going to be deprecated. '
+                  'Please use tvm.target.arm_cpu("rasp3b")')
+    return arm_cpu('rasp3b', options)
+
+
+def vta(model='unknown', options=None):
+    opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
+    opts = _merge_opts(opts, options)
+    ret = _ffi_api.TargetCreate("ext_dev", *opts)
+    return ret
+
+
+def bifrost(model='unknown', options=None):
+    """Return an ARM Mali GPU target (Bifrost architecture).
+
+    Parameters
+    ----------
+    options : str or list of str
+        Additional options
+    """
+    opts = ["-device=bifrost", '-model=%s' % model]
+    opts = _merge_opts(opts, options)
+    return _ffi_api.TargetCreate("opencl", *opts)
+
+
+def create(target_str):
+    """Get a target given target string.
+
+    Parameters
+    ----------
+    target_str : str
+        The target string.
+
+    Returns
+    -------
+    target : Target
+        The target object
+
+    Note
+    ----
+    See the note on :py:mod:`~tvm.target` on target string format.
+    """
+    if isinstance(target_str, Target):
+        return target_str
+    if not isinstance(target_str, str):
+        raise ValueError("target_str has to be string type")
+
+    return _ffi_api.TargetFromString(target_str)
index c60b4a8b95b1e75262255b7b4990bf3f29429d2b..8af2bd00bb35f863aa492e6c9869cfc46e338287 100644 (file)
@@ -46,20 +46,20 @@ namespace tvm {
 namespace runtime {
 
 std::string GetCustomTypeName(uint8_t type_code) {
-  auto f = tvm::runtime::Registry::Get("_datatype_get_type_name");
-  CHECK(f) << "Function _datatype_get_type_name not found";
+  auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_name");
+  CHECK(f) << "Function runtime._datatype_get_type_name not found";
   return (*f)(type_code).operator std::string();
 }
 
 uint8_t GetCustomTypeCode(const std::string& type_name) {
-  auto f = tvm::runtime::Registry::Get("_datatype_get_type_code");
-  CHECK(f) << "Function _datatype_get_type_code not found";
+  auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_code");
+  CHECK(f) << "Function runtime._datatype_get_type_code not found";
   return (*f)(type_name).operator int();
 }
 
 bool GetCustomTypeRegistered(uint8_t type_code) {
-  auto f = tvm::runtime::Registry::Get("_datatype_get_type_registered");
-  CHECK(f) << "Function _datatype_get_type_registered not found";
+  auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_registered");
+  CHECK(f) << "Function runtime._datatype_get_type_registered not found";
   return (*f)(type_code).operator bool();
 }
 
@@ -612,7 +612,7 @@ TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
   });
 
 // set device api
-TVM_REGISTER_GLOBAL("_GetDeviceAttr")
+TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     TVMContext ctx;
     ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
index a9c820160cde86ba50914f789e6f46f9d17aa73e..ee5e6a62b646263a25369a687b899a108b063fd2 100644 (file)
@@ -244,7 +244,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod,
   return (*codegen_f)(blob_byte_array, system_lib, target_triple);
 }
 
-TVM_REGISTER_GLOBAL("codegen._Build")
+TVM_REGISTER_GLOBAL("target.Build")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   if (args[0].IsObjectRef<tir::LoweredFunc>()) {
       *ret = Build({args[0]}, args[1]);
index b49b395ec08a5d8627bfc14da7062744648b1fab..c16182da3674c2d4746042e5ca81f3d16ff6d818 100644 (file)
@@ -25,22 +25,22 @@ namespace datatype {
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
 
-TVM_REGISTER_GLOBAL("_datatype_register")
+TVM_REGISTER_GLOBAL("runtime._datatype_register")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int()));
 });
 
-TVM_REGISTER_GLOBAL("_datatype_get_type_code")
+TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   *ret = datatype::Registry::Global()->GetTypeCode(args[0]);
 });
 
-TVM_REGISTER_GLOBAL("_datatype_get_type_name")
+TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   *ret = Registry::Global()->GetTypeName(args[0].operator int());
 });
 
-TVM_REGISTER_GLOBAL("_datatype_get_type_registered")
+TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   *ret = Registry::Global()->GetTypeRegistered(args[0].operator int());
 });
@@ -90,7 +90,6 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t
   } else {
     ss << runtime::TypeCode2Str(src_type_code);
   }
-
   return runtime::Registry::Get(ss.str());
 }
 
index 817d48f0cdbf9d73550c6f7bfee9053eebb0ea04..8eef4b75ff408fe200f40bebb1732370b884df19 100644 (file)
@@ -123,18 +123,18 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
   func.CallPacked(args, ret);
 }
 
-TVM_REGISTER_GLOBAL("_GenericFuncCreate")
+TVM_REGISTER_GLOBAL("target.GenericFuncCreate")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   *ret = GenericFunc(make_object<GenericFuncNode>());
   });
 
-TVM_REGISTER_GLOBAL("_GenericFuncGetGlobal")
+TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   std::string func_name = args[0];
   *ret = GenericFunc::Get(func_name);
   });
 
-TVM_REGISTER_GLOBAL("_GenericFuncSetDefault")
+TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   GenericFunc generic_func = args[0];
   // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
@@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncSetDefault")
     .set_default(*func, allow_override);
   });
 
-TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
+TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   GenericFunc generic_func = args[0];
   // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
@@ -162,7 +162,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
     .register_func(tags_vector, *func, allow_override);
   });
 
-TVM_REGISTER_GLOBAL("_GenericFuncCallFunc")
+TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   GenericFunc generic_func = args[0];
   TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1);
index f28bad4d63a6e95af6d4f5685cc06582e69d6208..30755fcfc1253faabb4d037fb580cfa571d3c09f 100644 (file)
@@ -349,11 +349,6 @@ unsigned LookupLLVMIntrinsic(const std::string& name) {
   return llvm::Function::lookupIntrinsicID(name);
 }
 
-TVM_REGISTER_GLOBAL("codegen.llvm_lookup_intrinsic_id")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
-  });
-
 TVM_REGISTER_GLOBAL("codegen.build_llvm")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
     auto n = make_object<LLVMModuleNode>();
@@ -361,9 +356,13 @@ TVM_REGISTER_GLOBAL("codegen.build_llvm")
     *rv = runtime::Module(n);
   });
 
-TVM_REGISTER_GLOBAL("codegen.llvm_version_major")
+TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
+  });
+
+TVM_REGISTER_GLOBAL("target.llvm_version_major")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    std::ostringstream os;
     int major = TVM_LLVM_VERSION / 10;
     *rv = major;
   });
index 245425a63921428a3a95e23fbf8f27b74a85d452..05253a5a2bc935c23a79f568ec4da9145f892c69 100644 (file)
@@ -144,7 +144,7 @@ Target CreateTarget(const std::string& target_name,
   return Target(t);
 }
 
-TVM_REGISTER_GLOBAL("_TargetCreate")
+TVM_REGISTER_GLOBAL("target.TargetCreate")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   std::string target_name = args[0];
   std::vector<std::string> options;
@@ -156,7 +156,7 @@ TVM_REGISTER_GLOBAL("_TargetCreate")
   *ret = CreateTarget(target_name, options);
   });
 
-TVM_REGISTER_GLOBAL("_TargetFromString")
+TVM_REGISTER_GLOBAL("target.TargetFromString")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   std::string target_str = args[0];
   *ret = Target::Create(target_str);
@@ -269,7 +269,7 @@ tvm::Target Target::Current(bool allow_not_defined) {
   return Target();
 }
 
-TVM_REGISTER_GLOBAL("_GetCurrentTarget")
+TVM_REGISTER_GLOBAL("target.GetCurrentTarget")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   bool allow_not_defined = args[0];
   *ret = Target::Current(allow_not_defined);
@@ -284,10 +284,10 @@ class Target::Internal {
   }
 };
 
-TVM_REGISTER_GLOBAL("_EnterTargetScope")
+TVM_REGISTER_GLOBAL("target.EnterTargetScope")
 .set_body_typed(Target::Internal::EnterScope);
 
-TVM_REGISTER_GLOBAL("_ExitTargetScope")
+TVM_REGISTER_GLOBAL("target.ExitTargetScope")
 .set_body_typed(Target::Internal::ExitScope);
 
 namespace target {
index 66ea5743240aaecc07689953e72750b06e972965..b24fdf158f4a25438c4f2c9a2d1d61d8ff480761 100644 (file)
@@ -95,19 +95,19 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return expr;
   }
 
-#define DEFINE_MUTATE__(OP, NodeName)                                              \
-  inline PrimExpr VisitExpr_(const NodeName* op) final {                                     \
-    auto type_code = op->dtype.code();                                             \
+#define DEFINE_MUTATE__(OP, NodeName)                                   \
+  inline PrimExpr VisitExpr_(const NodeName* op) final {                \
+    auto type_code = op->dtype.code();                                  \
     bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);                                   \
-    op = expr.as<NodeName>();                                                            \
-    if (toBeLowered) {                                                             \
-      auto lower = datatype::Get##OP##LowerFunc(target_, type_code);               \
-      CHECK(lower) << #OP " lowering function for target " << target_ << " type "  \
-                   << static_cast<unsigned>(type_code) << " not found";            \
-      return (*lower)(expr);                                                       \
-    }                                                                              \
-    return expr;                                                                   \
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);                    \
+    op = expr.as<NodeName>();                                           \
+    if (toBeLowered) {                                                  \
+      auto lower = datatype::Get##OP##LowerFunc(target_, type_code);    \
+      CHECK(lower) << #OP " lowering function for target " << target_ << " type " \
+                   << static_cast<unsigned>(type_code) << " not found"; \
+      return (*lower)(expr);                                            \
+    }                                                                   \
+    return expr;                                                        \
   }
 
   DEFINE_MUTATE__(Add, AddNode);
index 02626a468aa2c54070548ac8fc8430b7ced7459d..c2c808fa7895429717cae4d5d8acdcb5d63b812d 100644 (file)
@@ -103,11 +103,11 @@ TEST(BuildModule, Heterogeneous) {
     return copy[i] - C[i];
   }, "elemwise_sub");
 
-  const runtime::PackedFunc* enter_target_scope_func = runtime::Registry::Get("_EnterTargetScope");
-  (*enter_target_scope_func)(target_cuda);
+  With<Target> cuda_scope(target_cuda);
   auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
 
-  (*enter_target_scope_func)(target_llvm);
+
+  With<Target> llvm_scope(target_llvm);
   auto s2 = create_schedule({elemwise_sub->op});
 
   auto config = BuildConfig::Create();
index db5214b91d1ffe105db42f08fe6ee426387fab13..f95787dd94a40dfb5d4b4e641ef5b460d32d82ad 100644 (file)
@@ -55,7 +55,7 @@ def test_dot():
         if not tvm.runtime.enabled(target):
             print("Target %s is not enabled" % target)
             return
-        f = tvm.codegen.build_module(fapi, target)
+        f = tvm.target.codegen.build_module(fapi, target)
         # verify
         ctx = tvm.cpu(0)
         a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx)
index ea729618097e729874809ec844267f2dd72f773a..5876a7052a2d5bbf75c7f947f5a6c6b0330d588f 100644 (file)
@@ -1115,7 +1115,7 @@ def test_conv2d_int8_intrinsics():
 
     # compile conv2d for x86 (skylake, cascadelake) and test assembly contains *pmadd* instructions
     targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]
-    llvm_version = tvm.codegen.llvm_version_major()
+    llvm_version = tvm.target.codegen.llvm_version_major()
     for target in targets:
         if llvm_version >= 8:
             dtypes = ('uint8', 'int8', 'int32')
@@ -1208,7 +1208,7 @@ def test_depthwise_conv2d_int8():
     parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
 
     targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]
-    llvm_version = tvm.codegen.llvm_version_major()
+    llvm_version = tvm.target.codegen.llvm_version_major()
     for target in targets:
         if llvm_version >= 8:
             with relay.build_config(opt_level=3):
index 4f8758e2eaf7891563df38ecee82f049b92fece2..7043e473ec4d3cd5e65df05ad0073ed5e74c4963 100644 (file)
@@ -50,7 +50,7 @@ def matmul(N, L, M, dtype):
 
 @autotvm.template
 def bad_matmul(N, L, M, dtype):
-    if 'bad_device' in tvm.target.current_target().keys:
+    if 'bad_device' in tvm.target.Target.current().keys:
         A = tvm.placeholder((N, L), name='A', dtype=dtype)
         B = tvm.placeholder((L, M), name='B', dtype=dtype)
 
index 271237b51503222cf58ab96fc01ddc4aaeed72df..a126c07c8ac149cc6d5027dffc038dd762bb7e8d 100644 (file)
@@ -75,7 +75,7 @@ def test_add_pipeline():
         f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline")
         fsplits = [x for x in tvm.ir_pass.SplitHostDevice(f1)]
         fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
-        mhost = tvm.codegen.build_module(fsplits[0], "c")
+        mhost = tvm.target.codegen.build_module(fsplits[0], "c")
         temp = util.tempdir()
         path_dso = temp.relpath("temp.so")
         mhost.export_library(path_dso)
index fe416e6312d99054c95ce820f80b1800dde705b9..63ee03028e7ee18c088c2053f5df3518ee7ab47f 100644 (file)
@@ -84,8 +84,8 @@ def test_add_pipeline():
             return
         if not tvm.runtime.enabled(host):
             return
-        mhost = tvm.codegen.build_module(fsplits[0], host)
-        mdev = tvm.codegen.build_module(fsplits[1:], device)
+        mhost = tvm.target.codegen.build_module(fsplits[0], host)
+        mdev = tvm.target.codegen.build_module(fsplits[1:], device)
         mhost.import_module(mdev)
         code = mdev.get_source()
         f = mhost.entry_func
@@ -110,8 +110,8 @@ def test_add_pipeline():
             fmt = "hsaco"
         else:
             fmt = device
-        mhost = tvm.codegen.build_module(fsplits[0], host)
-        mdev = tvm.codegen.build_module(fsplits[1:], device)
+        mhost = tvm.target.codegen.build_module(fsplits[0], host)
+        mdev = tvm.target.codegen.build_module(fsplits[1:], device)
         temp = util.tempdir()
         mpath = temp.relpath("test.%s" % fmt)
         mdev.save(mpath)
index a37bc2a736e321120920a72ae34c1ef95183debe..c60f3816722c19730a4799ad0cd582d42ac09932 100644 (file)
@@ -570,9 +570,9 @@ def test_dwarf_debug_information():
     def check_llvm_object():
         if not tvm.runtime.enabled("llvm"):
             return
-        if tvm.codegen.llvm_version_major() < 5:
+        if tvm.target.codegen.llvm_version_major() < 5:
             return
-        if tvm.codegen.llvm_version_major() > 6:
+        if tvm.target.codegen.llvm_version_major() > 6:
             return
         # build two functions
         f2 = tvm.lower(s, [A, B, C], name="fadd1")
@@ -607,9 +607,9 @@ def test_dwarf_debug_information():
     def check_llvm_ir():
         if not tvm.runtime.enabled("llvm"):
             return
-        if tvm.codegen.llvm_version_major() < 5:
+        if tvm.target.codegen.llvm_version_major() < 5:
             return
-        if tvm.codegen.llvm_version_major() > 6:
+        if tvm.target.codegen.llvm_version_major() > 6:
             return
         # build two functions
         f2 = tvm.lower(s, [A, B, C], name="fadd1")
index 80c4fa4df0e889bed85373867b1a96871027af70..3bfe01319a3a3c6c53603c172a296728bc522318 100644 (file)
@@ -33,7 +33,7 @@ def test_static_callback():
     stmt = ib.get()
     fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
     fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
-    f = tvm.codegen.build_module(fapi, "llvm")
+    f = tvm.target.codegen.build_module(fapi, "llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
     f(a)
@@ -57,7 +57,7 @@ def test_static_init():
     stmt = ib.get()
     fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
     fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
-    f = tvm.codegen.build_module(fapi, "llvm")
+    f = tvm.target.codegen.build_module(fapi, "llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
 
index 60a948db68bb358479735a33c9d098fa62dde2f1..d477983b097911b08146dac983395f01c9dac1ad 100644 (file)
@@ -21,7 +21,7 @@ def run_jit(fapi, check):
     for target in ["llvm", "stackvm"]:
         if not tvm.runtime.enabled(target):
             continue
-        f = tvm.codegen.build_module(fapi, target)
+        f = tvm.target.codegen.build_module(fapi, target)
         s = f.get_source()
         check(f)
 
index 06591a3cbf76b64d8e874a1751e87f28ea5a77ba..e17c6cf8cbcceedd4c939ccd24ebce8eebbfeffe 100644 (file)
@@ -19,9 +19,9 @@ import re
 
 
 def test_fp16_to_fp32():
-    if tvm.codegen.llvm_version_major() < 6:
+    if tvm.target.codegen.llvm_version_major() < 6:
         print("Skipping due to LLVM version being {} < 6".format(
-            tvm.codegen.llvm_version_major()))
+            tvm.target.codegen.llvm_version_major()))
         return
 
     def fp16_to_fp32(target, width, match=None, not_match=None):
index 79c02efa6cc7cebec7039a417039b7d74ae231ea..00f9b3329835bf62261c5937c9c410affd417dfb 100644 (file)
@@ -29,19 +29,19 @@ def setup_module():
     # In this case, we have built the test functions used below right into TVM.
     # CDLL("libmybfloat16.so", RTLD_GLOBAL)
 
-    tvm.datatype.register("bfloat", 129)
+    tvm.target.datatype.register("bfloat", 129)
 
-    tvm.datatype.register_op(
-        tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "Cast",
+    tvm.target.datatype.register_op(
+        tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"), "Cast",
         "llvm", "bfloat", "float")
-    tvm.datatype.register_op(
-        tvm.datatype.create_lower_func("BFloat16ToFloat_wrapper"), "Cast",
+    tvm.target.datatype.register_op(
+        tvm.target.datatype.create_lower_func("BFloat16ToFloat_wrapper"), "Cast",
         "llvm", "float", "bfloat")
-    tvm.datatype.register_op(
-        tvm.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm",
+    tvm.target.datatype.register_op(
+        tvm.target.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm",
         "bfloat")
-    tvm.datatype.register_op(
-        tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm",
+    tvm.target.datatype.register_op(
+        tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm",
         "llvm", "bfloat")
 
 def lower_datatypes_and_build(schedule, args):
index 85417d462c3330d73d7427f35194aabf09a5b69f..6da99f82704737d26c0b37d2f18765c4455602aa 100644 (file)
@@ -50,7 +50,7 @@ def test_target_dispatch():
     with tvm.target.create("metal"):
         assert mygeneric(1) == 3
 
-    assert tvm.target.current_target() is None
+    assert tvm.target.Target.current() is None
 
 
 def test_target_string_parse():
index 38a7b43761b58b872b0830fecedcbf5f2bf40e9a..5207b0956941df9823260e5783672226f0112f0b 100644 (file)
@@ -39,7 +39,7 @@ def test_dltensor_compatible():
     stmt = ib.get()
     fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
     fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
-    f = tvm.codegen.build_module(fapi, "stackvm")
+    f = tvm.target.codegen.build_module(fapi, "stackvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     aview = MyTensorView(a)
     f(aview)
index e47db94c4353bbcf6910195c40664c6625cafdb7..b1a784bc48b69df6a7582f9ef08ac5ded39ab574 100644 (file)
@@ -57,7 +57,7 @@ def test_dso_module_load():
                            i + 1))
         fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
         fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
-        m = tvm.codegen.build_module(fapi, "llvm")
+        m = tvm.target.codegen.build_module(fapi, "llvm")
         for name in names:
             m.save(name)
 
index b0d4d1361ccc696e928ec18936ada1a2f816a009..f0d650adeac1a9fefb35abe45c53721754010110 100644 (file)
@@ -588,7 +588,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     idxd = tvm.indexdiv
 
     if groups == 1:
-        target = tvm.target.current_target()
+        target = tvm.target.Target.current()
         dispatch_ctx = autotvm.DispatchContext.current
         cfg = dispatch_ctx.query(target, workload)
 
@@ -693,12 +693,12 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
         else:
             raise RuntimeError("Unsupported template_key '%s'" % cfg.template_key)
     else:
-        target = tvm.target.current_target()
+        target = tvm.target.Target.current()
         dispatch_ctx = autotvm.DispatchContext.current
         cfg = dispatch_ctx.query(target, workload)
 
         if cfg.is_fallback:  # if is fallback, clear query cache and return None
-            autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
+            autotvm.task.clear_fallback_cache(tvm.target.Target.current(), workload)
             if layout == 'NHWC' and kernel_layout == 'HWOI':
                 new_attrs['data_layout'] = 'NCHW'
                 new_attrs['kernel_layout'] = 'OIHW'
index 3ee231eea41986bab757027bcd1e7f64ab6afe58..2ae65800e9256c9ce1232d32944ae59011bbebd6 100644 (file)
@@ -156,7 +156,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
             # this part to make tuning records correct
             s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region')
         else:
-            max_threads = tvm.target.current_target(allow_none=False).max_num_threads
+            max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
             co, ci, kh, kw, vc = s[kernel_vec].op.axis
             fused = s[kernel_vec].fuse(co, ci, kh, kw, vc)
             fused, vec = s[kernel_vec].split(fused, VC)
index ea3e51082657e677a3911b508b130c0d6397275b..d7fc292f0adec0319b37bc076145440e2affc874 100644 (file)
@@ -24,7 +24,7 @@ def fuse_and_bind(s, tensor, axis=None, num_thread=None):
     """Fuse all the axis and bind to GPU threads"""
     axis = axis or s[tensor].op.axis
     fused = s[tensor].fuse(*axis)
-    max_threads = tvm.target.current_target(allow_none=False).max_num_threads
+    max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
     bx, tx = s[tensor].split(fused, num_thread or max_threads)
     s[tensor].bind(bx, tvm.thread_axis("blockIdx.x"))
     s[tensor].bind(tx, tvm.thread_axis("threadIdx.x"))
index 2d1b93ec038276851d62fe34250fbde51cfda6f6..24fc2a17aa183c9541cd941dcbee6c007f3fc9fa 100644 (file)
@@ -41,7 +41,7 @@ def batch_matmul_cuda(x, y):
     output : tvm.Tensor
         3-D with shape [batch, M, N]
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if target.target_name == "cuda" and "cublas" in target.libs:
         return cublas.batch_matmul(x, y, False, True)
     return batch_matmul_default(x, y)
@@ -61,7 +61,7 @@ def schedule_batch_matmul(outs):
     s: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if target.target_name == "cuda" and "cublas" in target.libs:
         return generic.schedule_extern(outs)
 
index 201921564cbf25ffca69eed18a4ad65417769e00..43754a31df48d9a3b580f70d1223fe3ea3a90bfe 100644 (file)
@@ -115,7 +115,7 @@ def schedule_conv1d_ncw(cfg, outs):
             cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
-            target = tvm.target.current_target()
+            target = tvm.target.Target.current()
             if target.target_name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
@@ -230,7 +230,7 @@ def schedule_conv1d_nwc(cfg, outs):
             cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
-            target = tvm.target.current_target()
+            target = tvm.target.Target.current()
             if target.target_name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
index be7824e71e814be0fecb62e7012a0667938c9616..4cedbd529f024f5994101929f274e03544a3fe6e 100644 (file)
@@ -116,7 +116,7 @@ def schedule_conv1d_transpose_ncw_cuda(cfg, outs):
             cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
-            target = tvm.target.current_target()
+            target = tvm.target.Target.current()
             if target.target_name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
index d831ba494d9f4932260202945c748dcddddda13e..f26069cfc3f0c86adcab397a764600b33671dbe6 100644 (file)
@@ -69,7 +69,7 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
     output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
 
     if "cudnn" in target.libs:
         if layout == 'NCHW':
@@ -148,7 +148,7 @@ def schedule_conv2d_nchw_cuda(cfg, outs):
     s: Schedule
         The computation schedule for conv2d.
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if 'cudnn' in target.libs:
         return generic.schedule_extern(outs)
 
@@ -186,7 +186,7 @@ def schedule_conv2d_nhwc_cuda(cfg, outs):
     s: Schedule
         The computation schedule for conv2d.
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if 'cudnn' in target.libs:
         return generic.schedule_extern(outs)
 
index d64712550855fb2a751debc90890c6fa76192538..b7df88579f4936e67561f45fffbd71a0a19161dd 100644 (file)
@@ -34,7 +34,7 @@ def schedule_direct_cuda(cfg, s, conv):
     cfg.define_split("tile_rx", rx, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if target.target_name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
index 26bc261696742d005902e31dfbf766fe3f55482b..be9f31567bc9f9e73516777c62ebbcd2bc19a09f 100644 (file)
@@ -170,7 +170,7 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
             cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
-            target = tvm.target.current_target()
+            target = tvm.target.Target.current()
             if target.target_name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
index dfa569a556ce1b34b44b44c5ac4ae29998c8137f..37307d62357d9142067a8c8cae6713917c36d29a 100644 (file)
@@ -194,7 +194,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     cfg.define_split("tile_x", x, num_outputs=4)
     cfg.define_split("tile_rc", rc, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if target.target_name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
@@ -325,7 +325,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
     Unlike other TOPI functions, this function operates on both graph level and operator level,
     so we have to pass 'F' to make it support our two versions of graph IR,  Relay.
     """
-    if 'cudnn' in tvm.target.current_target().libs or 'miopen' in tvm.target.current_target().libs:
+    if 'cudnn' in tvm.target.Target.current().libs or 'miopen' in tvm.target.Target.current().libs:
         return None
 
     copy_inputs = list(inputs)
@@ -349,7 +349,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
     CO, _, KH, KW = get_const_tuple(kernel.shape)
 
     dispatch_ctx = autotvm.DispatchContext.current
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
 
     if groups == 1:
         # query config of this workload
index 7d3c0b4afc1b436d7cfcf8a0c8723913bafe18d6..b46f284ef5b7e034d7eb7ab05e8fa501b72020a6 100644 (file)
@@ -64,7 +64,7 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o
     output : tvm.Tensor
         5-D with shape [batch, out_channel, out_depth, out_height, out_width]
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
 
     if "cudnn" in target.libs:
         if layout == 'NCDHW':
@@ -126,7 +126,7 @@ def schedule_conv3d_ncdhw_cuda(cfg, outs):
     s: Schedule
         The computation schedule for conv2d.
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if 'cudnn' in target.libs:
         return generic.schedule_extern(outs)
 
@@ -160,7 +160,7 @@ def schedule_conv3d_ndhwc_cuda(cfg, outs):
     s: Schedule
         The computation schedule for conv2d.
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if 'cudnn' in target.libs:
         return generic.schedule_extern(outs)
 
index e38dbcbfa0020203d7e1b4bed163bd9126580ac3..ad48deb2753983982ee49a596e480ede52a0e7c2 100644 (file)
@@ -36,7 +36,7 @@ def schedule_direct_3d_cuda(cfg, s, conv):
     cfg.define_split("tile_rx", rx, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if target.target_name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
index a0e1cb8f5fc652bdf1cc5861816822d9a44a7de0..33a8c9adc1ca660cf52904b1aa8bdcbacf3e551f 100644 (file)
@@ -67,7 +67,7 @@ def schedule_direct_cuda(cfg, s, conv):
     cfg.define_split("tile_rx", rx, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if target.target_name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
index f17feb0b7a981574c03439bd46982c38c58f2a80..1a1af703c55cf85300be860b54bbf4d0d9bf8295 100644 (file)
@@ -60,7 +60,7 @@ def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
         out_dtype = data.dtype
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "cublas" in target.libs:
         matmul = cublas.matmul(data, weight, False, True, out_dtype)
         if bias is not None:
@@ -87,7 +87,7 @@ def schedule_dense(cfg, outs):
         The computation schedule for dense.
     """
     # pylint: disable=unused-argument
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
 
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     if target.target_name == "cuda" and "cublas" in target.libs:
@@ -259,7 +259,7 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
     batch, in_dim = get_const_tuple(data.shape)
     out_dim, _ = get_const_tuple(weight.shape)
 
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "cublas" in target.libs:
         matmul = cublas.matmul(data, weight, False, True, out_dtype)
         if bias is not None:
@@ -290,7 +290,7 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
 def schedule_dense_int8(cfg, outs):
     """Dense schedule for int8 on CUDA"""
     s = tvm.create_schedule([x.op for x in outs])
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
 
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     if "cublas" in target.libs:
index 6dbfbfe39cae8795dd196a35dc00311d08ae686a..05e1117ac2cee1530b00eba3bf9eb25d960ad794 100644 (file)
@@ -57,7 +57,7 @@ def schedule_depthwise_conv2d_nchw_cuda(cfg, outs):
             cfg.define_split("tile_x", x, num_outputs=4)
             cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
 
-            target = tvm.target.current_target()
+            target = tvm.target.Target.current()
             if target.target_name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
@@ -166,7 +166,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
 
         # num_thread here could be 728, it is larger than cuda.max_num_threads
         num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value
-        target = tvm.target.current_target()
+        target = tvm.target.Target.current()
         if target and (target.target_name not in ["cuda", "nvptx"]):
             num_thread = target.max_num_threads
         xoc, xic = s[Output].split(c, factor=num_thread)
index f4bb73470651eb03133caaf703ba8176a5988f92..54e8427daf79a8d9591baa3f3c5f350c7c0a90e1 100644 (file)
@@ -340,7 +340,7 @@ def schedule_group_conv2d_nchw_direct(cfg, s, conv):
     cfg.define_split("tile_rx", rx, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if target.target_name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
index 0a131148be68674aded432a8433bdf39cdfc9685..b77a9792471627d4fa132777e97a3bcc4e149415 100644 (file)
@@ -37,7 +37,7 @@ def schedule_injective_from_existing(sch, out):
          The updated schedule.
     """
     fused = sch[out].fuse(*sch[out].op.axis)
-    num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+    num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
     max_block = 256
 
     try:
index 8ddb44efc738225cca526feac067479c32cf93d0..38f87a9523c8b703890de564f98fefe44818d695 100644 (file)
@@ -71,7 +71,7 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index
     id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
     score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
 
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = batch_size * num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -120,7 +120,7 @@ def get_valid_counts_upsweep(data, idx_in, idx, partial):
     idx_in = ib.buffer_ptr(idx_in)
     idx = ib.buffer_ptr(idx)
     partial = ib.buffer_ptr(partial)
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     elem_per_thread = num_anchors // max_threads + 1
     nthread_tx = max_threads
     nthread_bx = batch_size
@@ -176,7 +176,7 @@ def get_valid_counts_scan(data, partial_in, partial):
     ib = tvm.ir_builder.create()
     partial_in = ib.buffer_ptr(partial_in)
     partial = ib.buffer_ptr(partial)
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     elem_per_thread = num_anchors // max_threads + 1
     nthread_tx = max_threads
     nthread_bx = batch_size
@@ -234,7 +234,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
     idx_in = ib.buffer_ptr(idx_in)
     idx = ib.buffer_ptr(idx)
     partial = ib.buffer_ptr(partial)
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     elem_per_thread = num_anchors // max_threads + 1
     nthread_tx = max_threads
     nthread_bx = batch_size * num_anchors // max_threads + 1
@@ -297,7 +297,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
     valid_count = ib.buffer_ptr(valid_count)
     out = ib.buffer_ptr(out)
 
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -356,7 +356,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     elem_per_thread = num_anchors // max_threads + 1
     new_range = num_anchors // elem_per_thread + 1
     temp_flag_buf = api.decl_buffer(
@@ -482,7 +482,7 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
     num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
 
     max_threads = int(
-        tvm.target.current_target(allow_none=False).max_num_threads)
+        tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -594,7 +594,7 @@ def invalid_to_bottom_pre(data, flag, idx):
     idx = ib.buffer_ptr(idx)
 
     max_threads = int(math.sqrt(
-        tvm.target.current_target(allow_none=False).max_num_threads))
+        tvm.target.Target.current(allow_none=False).max_num_threads))
     nthread_tx = max_threads
     nthread_bx = num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -654,7 +654,7 @@ def invalid_to_bottom_ir(data, flag, idx, out):
     out = ib.buffer_ptr(out)
 
     max_threads = int(math.sqrt(
-        tvm.target.current_target(allow_none=False).max_num_threads))
+        tvm.target.Target.current(allow_none=False).max_num_threads))
     nthread_tx = max_threads
     nthread_bx = num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
index a5c310eb1a45aba41b0291bd625283ace4a6c1a1..327afa87edb5e61a4d416ebf295367d13ba41226 100644 (file)
@@ -37,6 +37,6 @@ def schedule_lrn(outs):
     sch: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.cuda.schedule_lrn(cpp_target, outs)
index f11085afb3cd9a2688c4d75592713527bba0847e..2bf1e6bb9ef0db80f4bebe2ab625720f43be3cb0 100644 (file)
@@ -112,7 +112,7 @@ def schedule_pool(outs, layout):
     def _schedule(PaddedInput, Pool):
         if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
             s[PaddedInput].compute_inline()
-        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+        num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
         if Pool.op in s.outputs:
             Out = Pool
             OL = s.cache_write(Pool, "local")
@@ -177,7 +177,7 @@ def schedule_pool_grad_cuda(outs):
         else:
             out = outs[0].op.output(0)
         fused = s[out].fuse(*s[out].op.axis)
-        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+        num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
         bx, tx = s[out].split(fused, factor=num_thread)
         s[out].bind(bx, tvm.thread_axis("blockIdx.x"))
         s[out].bind(tx, tvm.thread_axis("threadIdx.x"))
index 7567c651772b71a27dc8f02132dbd6907bbbcae2..4344226d787e59885d0cd46e8bc1b4c3a0b26ccf 100644 (file)
@@ -64,7 +64,7 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r
     """
     batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape)
     num_anchors //= 2
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = (batch * height * width) // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -152,7 +152,7 @@ def argsort_ir(data_buf, out_index_buf):
         The result IR statement.
     """
     batch, num_bbox = get_const_tuple(data_buf.shape)
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     ib = tvm.ir_builder.create()
     p_data = ib.buffer_ptr(data_buf)
     index_out = ib.buffer_ptr(out_index_buf)
@@ -225,7 +225,7 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
         return i / u
 
     batch, num_bbox = get_const_tuple(out_buf.shape)
-    max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
+    max_threads = int(math.sqrt(tvm.target.Target.current(allow_none=False).max_num_threads))
     tx = tvm.thread_axis("threadIdx.x")
     bx = tvm.thread_axis("blockIdx.x")
     ib = tvm.ir_builder.create()
index 2968ab75e040aba4756a02d25439bad8e7f2a35e..69c685cb50b4c67d6a97d9f2679c317c6b6152e9 100644 (file)
@@ -35,7 +35,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
     if len(sch[data_out].op.axis) > 0:
         all_reduce = False
         num_thread = 32
-        target = tvm.target.current_target()
+        target = tvm.target.Target.current()
         if target and target.target_name == "opencl":
             # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py
             # don't know why
@@ -45,7 +45,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
         thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
     else:
         all_reduce = True
-        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+        num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
         thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
 
     # Fuse and refactor the reduce axis
index 0e7a23eb14ab15192f2474ac16338d50fbcbb8a5..b32cce75362f8af2a17378fa3aa121405f38f02b 100644 (file)
@@ -87,7 +87,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
             axis_mul_before *= value
         elif i > axis:
             axis_mul_after *= value
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     ib = tvm.ir_builder.create()
     data = ib.buffer_ptr(data)
     values_out = ib.buffer_ptr(values_out)
@@ -186,7 +186,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
             axis_mul_before *= value
         elif i > axis:
             axis_mul_after *= value
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     ib = tvm.ir_builder.create()
     data = ib.buffer_ptr(data)
     valid_count = ib.buffer_ptr(valid_count)
index e1af4365520ee6a77987a884817b9365e3f4557a..10ba7a1051ea620db763ee421a4d3cf9b538979d 100644 (file)
@@ -60,7 +60,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
         The result IR statement.
     """
     max_threads = int(math.sqrt(
-        tvm.target.current_target(allow_none=False).max_num_threads))
+        tvm.target.Target.current(allow_none=False).max_num_threads))
     tx = tvm.thread_axis("threadIdx.x")
     ty = tvm.thread_axis("threadIdx.y")
     bx = tvm.thread_axis("blockIdx.x")
@@ -196,7 +196,7 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
 
     threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold)
 
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = (batch_size *  num_anchors) // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -307,7 +307,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
     score = ib.buffer_ptr(temp_score)
     out_loc = ib.buffer_ptr(out)
 
-    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = (batch_size * num_anchors) // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
index 2df273ff50e3dab9f241243f3cea13849eeb9fb4..d456aadf4f5efe53df2b0049f8879d53f3218c1b 100644 (file)
@@ -53,7 +53,7 @@ def schedule_reorg(outs):
     s: Schedule
         The computation schedule for reorg.
     """
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.cuda.schedule_injective(cpp_target, outs)
 
index a0601147d5e8aeb4a50ff44fa029a03984d0563c..e895385e8b66fb4a6706d6d91922d91ffe688806 100644 (file)
@@ -36,5 +36,5 @@ def schedule_extern(outs):
     sch: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     return cpp.generic.schedule_extern(target, outs)
index 178363dc0d4ff0635f14780cf592f72e2f791807..2aff96f9636c1a31b713d7421ad39efe99c93c7d 100644 (file)
@@ -54,7 +54,7 @@ def schedule_injective(outs):
     sch: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     if target.target_name != "llvm":
         raise RuntimeError("schedule_injective not registered for '%s'" % target)
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
index be9e54e97a1e74484d5b0b725f97863e2c538fee..8831829412027cb1360740a45fbe02ff525f46f6 100644 (file)
@@ -22,7 +22,7 @@ from .. import cpp
 
 def _default_schedule(outs, auto_inline):
     """Default schedule for llvm."""
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     if target.target_name not in ("llvm", "c"):
         raise RuntimeError("schedule not registered for '%s'" % target)
@@ -645,7 +645,7 @@ def schedule_lrn(outs):
     sch: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.generic.default_schedule(cpp_target, outs, False)
 
@@ -686,6 +686,6 @@ def schedule_sparse_transpose(outs):
 
 @tvm.target.generic_func
 def schedule_batch_matmul(outs):
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.generic.default_schedule(cpp_target, outs, False)
index a1e096a85880614ecc5bb38a4f25230ce5e68143..85d9153e6424b6ecd3754d4ca006fce0463145d6 100644 (file)
@@ -22,7 +22,7 @@ from .. import cpp
 
 def _default_schedule(outs, auto_inline):
     """Default schedule for llvm."""
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     if target.target_name != "llvm":
         raise RuntimeError("schedule not registered for '%s'" % target)
@@ -48,7 +48,7 @@ def schedule_reorg(outs):
     s: Schedule
       The computation schedule for the op.
     """
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.generic.default_schedule(cpp_target, outs, False)
 
index f02eb497f519ce9736d114f0a23ff9379aab1acf..65ea590905f911e31105bee2b323641cf6db1787 100644 (file)
@@ -221,7 +221,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
         return None
 
     dispatch_ctx = autotvm.task.DispatchContext.current
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
 
     # query schedule and fallback if necessary
     workload = autotvm.task.args_to_workload(
index c747c539d7fe6ec50a942083274e985ec16d34a8..97b7376933de8dcd9bb30bfe561dbabe438934c5 100644 (file)
@@ -59,7 +59,7 @@ def schedule_depthwise_conv2d_nchw_intel_graphics(cfg, outs):
             cfg.define_split("tile_x", x, num_outputs=4)
             cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
 
-            target = tvm.target.current_target()
+            target = tvm.target.Target.current()
             if target.target_name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
@@ -167,7 +167,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
 
         # num_thread here could be 728, it is larger than cuda.max_num_threads
         num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value
-        target = tvm.target.current_target()
+        target = tvm.target.Target.current()
         if target and (target.target_name not in ["cuda", "nvptx"]):
             num_thread = target.max_num_threads
         xoc, xic = s[Output].split(c, factor=num_thread)
index ea4661f7602e23b5ace572c58c2094e3fa556096..35a86e991c236bde1ba510859c13b6107c8760a5 100644 (file)
@@ -153,7 +153,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
             # this part to make tuning records correct
             s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region')
         else:
-            max_threads = tvm.target.current_target(allow_none=False).max_num_threads
+            max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
             co, ci, kh, kw, vc = s[kernel_vec].op.axis
             fused = s[kernel_vec].fuse(co, ci, kh, kw, vc)
             fused, vec = s[kernel_vec].split(fused, VC)
index abdb5f22e36980768d7ae44faf364001c2aa648c..52f4b12a1d2dc5b51dc53957c3d65c5fcb5673ef 100644 (file)
@@ -465,7 +465,7 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
 
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
         get_const_tuple(kernel.shape)
     num_filter = oc_chunk * oc_bn
index 0a41838aa50ed349a4d80455e524644c8c1665b4..be29c6f6b0cc59a7425fc6b3671aa3576b164a35 100644 (file)
@@ -57,7 +57,7 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
         4-D with shape [batch, out_channel, out_height, out_width]
     """
 
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "miopen" in target.libs:
         assert layout == 'NCHW', "Only NCHW layout is supported."
         CO, CI, KH, KW = get_const_tuple(kernel.shape)
@@ -106,7 +106,7 @@ def schedule_conv2d_nchw_rocm(cfg, outs):
     s: Schedule
         The computation schedule for conv2d.
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if target and "miopen" in target.libs:
         return generic.schedule_extern(outs)
 
index 6fca7cd7965670edfee00b1fb2bf42d4e678436e..f2adeaabef61ae49e436b76f97e845380100be19 100644 (file)
@@ -56,7 +56,7 @@ def dense_rocm(cfg, data, weight, bias=None, out_dtype=None):
         out_dtype = data.dtype
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "rocblas" in target.libs:
         assert out_dtype == data.dtype, "Mixed precision not supported."
         matmul = rocblas.matmul(data, weight, False, True)
@@ -83,7 +83,7 @@ def schedule_dense(cfg, outs):
     s: Schedule
         The computation schedule for dense.
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if target.target_name == "rocm" and "rocblas" in target.libs:
         return generic.schedule_extern(outs)
     return topi.cuda.schedule_dense(cfg, outs)
index bb6a8bf43557aa7381380c461c289f755a2bd509..8a9c8c393da6e5514fb05ac5e5d3817b50754792 100644 (file)
@@ -23,6 +23,6 @@ from .. import cpp
 
 @generic.schedule_lrn.register(["rocm", "gpu"])
 def schedule_lrn(outs):
-    target = tvm.target.current_target(allow_none=False)
+    target = tvm.target.Target.current(allow_none=False)
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.rocm.schedule_lrn(cpp_target, outs)
index 25b49d12400db7945f2568913c03dedd1381cad9..fef6c48d6bedc0052d54d7025f4d7638c7008d66 100644 (file)
@@ -43,7 +43,7 @@ def _declaration_batch_matmul_nopack(cfg, x, y):
     output : tvm.Tensor
         3-D with shape [batch, M, N]
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "cblas" in target.libs:
         return cblas.batch_matmul(x, y, False, True)
 
@@ -83,7 +83,7 @@ def schedule_batch_matmul(cfg, outs):
     sch: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "cblas" in target.libs:
         return generic.schedule_extern(outs)
 
index 1ba4f68be6c4cc3be279d6c2d6ef233294cee83a..95ce3376ac3a1a82756af6fdcfad792c9cb3a3bc 100644 (file)
@@ -74,7 +74,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
         kh, kw, oc, _ = kshape
     elif pat.match(layout) is not None:
         n, ic_chunk, h, w, ic_bn = dshape
-        target = tvm.target.current_target(allow_none=False)
+        target = tvm.target.Target.current(allow_none=False)
         oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
         assert ic_chunk == k_ic_chunk
         assert ic_bn == k_ic_bn
@@ -423,7 +423,7 @@ def _schedule_conv2d_NCHWc(cfg, outs):
                 data = data_pad.op.input_tensors[0]
 
             args = [s, cfg, data_vec, conv_out, outs[0]]
-            target = tvm.target.current_target(allow_none=False)
+            target = tvm.target.Target.current(allow_none=False)
             _, _, kh, kw, _, _, = get_const_tuple(kernel.shape)
             if kh == 1 and kw == 1:
                 conv2d_avx_1x1._schedule_conv_NCHWc(*args)
index 9374387fb23a47af2bd938aae65a27a99f4947ba..8b0c13c2c0bba3eac74dc46e7fb70a06eae9983c 100644 (file)
@@ -75,7 +75,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
 
     # Set workload. Config update.
     dispatch_ctx = autotvm.task.DispatchContext.current
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
 
     if is_depthwise:
         workload = autotvm.task.args_to_workload(
index 1701643844e105d05c724a6fff970dbaca306888..20712d2f6f4fe0e01e5ed5622f8bde6760ddddda 100644 (file)
@@ -64,11 +64,11 @@ def _is_int8_hw_support(data_dtype, kernel_dtype):
     is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'
 
     # 2) Check LLVM support
-    llvm_version = tvm.codegen.llvm_version_major()
+    llvm_version = tvm.target.codegen.llvm_version_major()
     is_llvm_support = llvm_version >= 8
 
     # 3) Check target
-    mcpu = tvm.target.current_target().mcpu
+    mcpu = tvm.target.Target.current().mcpu
     is_target_support = False
     if mcpu in ('skylake-avx512', 'cascadelake'):
         is_target_support = True
@@ -89,7 +89,7 @@ def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, lay
         kh, kw, oc, _ = kshape
     elif pat.match(layout) is not None:
         n, ic_chunk, h, w, ic_bn = dshape
-        target = tvm.target.current_target(allow_none=False)
+        target = tvm.target.Target.current(allow_none=False)
         oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
         ic = ic_chunk * ic_bn
         assert ic == k_ic * k_ic_f * k_ic_s
@@ -205,7 +205,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, outs):
                 data = data_pad.op.input_tensors[0]
 
             args = [s, cfg, data_vec, conv_out, outs[0]]
-            target = tvm.target.current_target(allow_none=False)
+            target = tvm.target.Target.current(allow_none=False)
             # int8 conv kernel is 7-dim
             _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
             if kh == 1 and kw == 1:
index dd1822f0fd7375f14cb967e763bbd5a8cc4010e2..c6c3d5e667acebfcfa5aa05d66ae2cd4cf9fc119 100644 (file)
@@ -28,7 +28,7 @@ from ..util import traverse_inline, get_const_tuple
 
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct")
 def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "cblas" in target.libs:
         C = cblas.matmul(data, weight, False, True)
         if bias is not None:
@@ -119,7 +119,7 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct")
 def _schedule_dense(cfg, outs):
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "cblas" in target.libs:
         return generic.schedule_extern(outs)
 
@@ -136,7 +136,7 @@ def _schedule_dense(cfg, outs):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
 def _schedule_dense_pack(cfg, outs):
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "cblas" in target.libs:
         return generic.schedule_extern(outs)
 
@@ -151,7 +151,7 @@ def _schedule_dense_pack(cfg, outs):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
 def _schedule_dense_nopack(cfg, outs):
-    target = tvm.target.current_target()
+    target = tvm.target.Target.current()
     if "cblas" in target.libs:
         return generic.schedule_extern(outs)
 
index a8ad251115d79d07c49027c23d2e97f0707912e6..dc9e1456d2cdef1fb7a0cb273471ba522eeebcb1 100644 (file)
 """Core kernel of dot product of 4 Int8 operations"""
 #pylint: disable=invalid-name
 import tvm
+import tvm.target.codegen
 
 
 def dot_16x1x16_uint8_int8_int32():
     """Dispatch the most optimized intrin depending on the target"""
-    mcpu = tvm.target.current_target().mcpu
+    mcpu = tvm.target.Target.current().mcpu
 
     assert mcpu in ("skylake-avx512", "cascadelake"), \
             "An old Intel machine that does not have fast Int8 support."
@@ -254,7 +255,7 @@ def dot_16x1x16_uint8_int8_int32_cascadelake():
             vec_b = ins[1].vload([0, 0], "int8x64")
 
             vnni_inst_name = 'llvm.x86.avx512.vpdpbusd.512'
-            llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
+            llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
 
             if llvm_id != 0: # VNNI is available for current LLVM version
                 vec_bi32 = tvm.call_pure_intrin('int32x16', 'reinterpret', vec_b)
index aff37aa3620255b957495c3d9bf4807ef76b88b9..04931f577b5118c0400fd2390690ea47f1b8c599 100644 (file)
@@ -19,7 +19,7 @@ from __future__ import absolute_import as _abs
 import tvm
 
 def get_fp32_len():
-    mcpu = tvm.target.current_target().mcpu
+    mcpu = tvm.target.Target.current().mcpu
     fp32_vec_len = 8
     if mcpu in ('skylake-avx512', 'cascadelake'):
         fp32_vec_len = 16
index ae77f00fb8a9f9cb4df36c9dbd111294726ab538..bf6409cc9405fba290f23b8c3ec21d860a06f497 100644 (file)
@@ -84,7 +84,7 @@ def compute_conv2d(attrs, inputs, output_type, target):
                                               groups,
                                               out_dtype)]
         # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.current_target().model):
+        with tvm.target.arm_cpu(tvm.target.Target.current().model):
             return _nn.compute_conv2d(attrs, inputs, output_type, target)
 
     # If VTA is not the target, default to _nn def
@@ -105,8 +105,8 @@ def schedule_conv2d(attrs, outs, target):
                 return topi.generic.schedule_conv2d_nchw(outs)
             return topi.generic.schedule_group_conv2d_nchw(outs)
         # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.current_target().model):
-            return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target())
+        with tvm.target.arm_cpu(tvm.target.Target.current().model):
+            return _nn.schedule_conv2d(attrs, outs, tvm.target.Target.current())
 
     # If VTA is not the target, default to _nn def
     return _nn.schedule_conv2d(attrs, outs, target)
@@ -128,7 +128,7 @@ def compute_conv2d_transpose(attrs, inputs, output_type, target):
             return [topi.nn.conv2d_transpose_nchw(
                 inputs[0], inputs[1], strides, padding, out_dtype)]
         # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.current_target().model):
+        with tvm.target.arm_cpu(tvm.target.Target.current().model):
             return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
 
     # If VTA is not the target, default to _nn def
@@ -145,11 +145,11 @@ def schedule_conv2d_transpose(attrs, outputs, target):
         if is_packed_layout(layout):
             return topi.nn.schedule_conv2d_transpose_nchw(outputs)
         # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.current_target().model):
-            return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target())
+        with tvm.target.arm_cpu(tvm.target.Target.current().model):
+            return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())
 
     # If VTA is not the target, default to _nn def
-    return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target())
+    return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())
 
 
 @reg.register_compute("nn.dense", level=15)
@@ -163,7 +163,7 @@ def compute_dense(attrs, inputs, out_type, target):
             target = tvm.target.create(target)
             return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
         # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.current_target().model):
+        with tvm.target.arm_cpu(tvm.target.Target.current().model):
             return _nn.compute_dense(attrs, inputs, out_type, target)
 
     # If VTA is not the target, default to _nn def
@@ -179,8 +179,8 @@ def schedule_dense(attrs, outs, target):
             assert target.device_name == "vta"
             return topi.generic.schedule_dense(outs)
         # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.current_target().model):
-            return _nn.schedule_dense(attrs, outs, tvm.target.current_target())
+        with tvm.target.arm_cpu(tvm.target.Target.current().model):
+            return _nn.schedule_dense(attrs, outs, tvm.target.Target.current())
 
     # If VTA is not the target, default to _nn def
     return _nn.schedule_dense(attrs, outs, target)
index 2780f26ca57cc33a38d59eb905f6d24ee185a035..265a6392b0546e547be20c229d616bbc8fb7da82 100644 (file)
@@ -80,7 +80,7 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation):
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
 
-    if tvm.target.current_target().device_name == 'vta':
+    if tvm.target.Target.current().device_name == 'vta':
         s = topi.generic.schedule_conv2d_nchw([res])
     else:
         s = tvm.create_schedule([res.op])
index f779b76f8277561cf94c55ccc744351d10212e88..d6475abff66793bc9e69d7e496798cd7be250590 100644 (file)
@@ -68,7 +68,7 @@ def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding):
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
 
-    if tvm.target.current_target().device_name == 'vta':
+    if tvm.target.Target.current().device_name == 'vta':
         s = topi.generic.schedule_conv2d_transpose_nchw([res])
     else:
         s = tvm.create_schedule([res.op])
index 7813b00fc878a33af531b25fb99e115ac2989d0c..fa49be7f9b278641e04b5b2d4ec71c4c89815b3a 100644 (file)
@@ -59,7 +59,7 @@ def dense(N, CI, CO):
         res = my_clip(res, 0, 127)
         res = topi.cast(res, "int8")
 
-    if tvm.target.current_target().device_name == 'vta':
+    if tvm.target.Target.current().device_name == 'vta':
         s = topi.generic.schedule_dense([res])
     else:
         s = tvm.create_schedule([res.op])
index c578090e26aa3d4b00f301b3966137053f400ea0..555154d708fce4253b53a0e45738f96f5d0e10f4 100644 (file)
@@ -80,7 +80,7 @@ def group_conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, group):
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
 
-    if tvm.target.current_target().device_name == 'vta':
+    if tvm.target.Target.current().device_name == 'vta':
         s = topi.generic.schedule_group_conv2d_nchw([res])
     else:
         s = tvm.create_schedule([res.op])
index 9d8ed8980bb01d1e56ea552a550046bc8679f9e9..b9edc30e5ba314706392d254ffbdfcfd7b472698 100644 (file)
@@ -84,7 +84,7 @@ def register_vta_tuning_tasks():
             res = my_clip(res, 0, 127)
             res = topi.cast(res, "int8")
 
-        if tvm.target.current_target().device_name == 'vta':
+        if tvm.target.Target.current().device_name == 'vta':
             s = topi.generic.schedule_conv2d_nchw([res])
         else:
             s = tvm.create_schedule([res.op])
@@ -102,7 +102,7 @@ def register_vta_tuning_tasks():
             res = my_clip(res, 0, 127)
             res = topi.cast(res, "int8")
 
-        if tvm.target.current_target().device_name == 'vta':
+        if tvm.target.Target.current().device_name == 'vta':
             s = topi.generic.schedule_dense([res])
         else:
             s = tvm.create_schedule([res.op])
index 3221c3b77b1fb71cf74b60ea1fd38a8c7a2082e2..94fba3db29890b739af1d321eef1c6d634e63ba5 100644 (file)
@@ -321,7 +321,7 @@ def register_vta_tuning_tasks():
             res = my_clip(res, 0, 127)
             res = topi.cast(res, "int8")
 
-        if tvm.target.current_target().device_name == 'vta':
+        if tvm.target.Target.current().device_name == 'vta':
             s = topi.generic.schedule_conv2d_nchw([res])
         else:
             s = tvm.create_schedule([res.op])