[REFACTOR][PY] Establish tvm.tir
authortqchen <tqchen@octoml.ai>
Wed, 12 Feb 2020 23:38:39 +0000 (15:38 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 14 Feb 2020 04:24:01 +0000 (20:24 -0800)
- Move related files into the corresponding location as in C++
- Keep the top-level TVM API backward compatible to make minimum changes in topi

121 files changed:
python/tvm/__init__.py
python/tvm/api.py
python/tvm/autotvm/task/task.py
python/tvm/build_module.py
python/tvm/contrib/cblas.py
python/tvm/contrib/cublas.py
python/tvm/contrib/cublaslt.py
python/tvm/contrib/cudnn.py
python/tvm/contrib/miopen.py
python/tvm/contrib/mps.py
python/tvm/contrib/nnpack.py
python/tvm/contrib/random.py
python/tvm/contrib/rocblas.py
python/tvm/generic.py
python/tvm/hybrid/calls.py
python/tvm/hybrid/parser.py
python/tvm/hybrid/util.py
python/tvm/intrin.py
python/tvm/ir/__init__.py
python/tvm/ir/attrs.py
python/tvm/ir/base.py
python/tvm/ir/expr.py
python/tvm/make.py
python/tvm/relay/_parser.py
python/tvm/relay/backend/vm.py
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/op/_transform.py
python/tvm/schedule.py
python/tvm/target/__init__.py
python/tvm/target/datatype.py
python/tvm/target/intrin.py [new file with mode: 0644]
python/tvm/tensor.py
python/tvm/tensor_intrin.py
python/tvm/tir/__init__.py [new file with mode: 0644]
python/tvm/tir/_ffi_api.py [new file with mode: 0644]
python/tvm/tir/buffer.py [new file with mode: 0644]
python/tvm/tir/data_layout.py [new file with mode: 0644]
python/tvm/tir/expr.py [moved from python/tvm/expr.py with 77% similarity]
python/tvm/tir/generic.py [new file with mode: 0644]
python/tvm/tir/ir_builder.py [moved from python/tvm/ir_builder.py with 89% similarity]
python/tvm/tir/ir_pass.py [moved from python/tvm/ir_pass.py with 96% similarity]
python/tvm/tir/op.py [new file with mode: 0644]
python/tvm/tir/stmt.py [moved from python/tvm/stmt.py with 85% similarity]
src/api/api_ir.cc
src/api/api_lang.cc
src/ir/expr.cc
src/node/reflection.cc
src/printer/relay_text_printer.cc
src/tir/ir/buffer.cc
src/tir/ir/data_layout.cc
tests/python/integration/test_reduce.py
tests/python/relay/test_backend_compile_engine.py
tests/python/relay/test_cpp_build_module.py
tests/python/relay/test_external_codegen.py
tests/python/relay/test_external_runtime.py
tests/python/relay/test_ir_nodes.py
tests/python/relay/test_ir_parser.py
tests/python/relay/test_op_level2.py
tests/python/relay/test_op_level3.py
tests/python/relay/test_op_level4.py
tests/python/relay/test_pass_alpha_equal.py
tests/python/relay/test_pass_partition_graph.py
tests/python/relay/test_type_functor.py
tests/python/unittest/test_arith_canonical_simplify.py
tests/python/unittest/test_arith_const_int_bound.py
tests/python/unittest/test_arith_deduce_bound.py
tests/python/unittest/test_arith_domain_touched.py
tests/python/unittest/test_arith_intset.py
tests/python/unittest/test_arith_modular_set.py
tests/python/unittest/test_arith_rewrite_simplify.py
tests/python/unittest/test_arith_stmt_simplify.py
tests/python/unittest/test_codegen_cuda.py
tests/python/unittest/test_codegen_llvm.py
tests/python/unittest/test_codegen_opencl.py
tests/python/unittest/test_codegen_static_init.py
tests/python/unittest/test_codegen_vm_basic.py
tests/python/unittest/test_codegen_vulkan.py
tests/python/unittest/test_custom_datatypes_mybfloat16.py
tests/python/unittest/test_hybrid_script.py
tests/python/unittest/test_ir_builder.py
tests/python/unittest/test_lang_basic.py
tests/python/unittest/test_lang_constructor.py
tests/python/unittest/test_lang_container.py
tests/python/unittest/test_lang_data_layout.py
tests/python/unittest/test_lang_operator.py
tests/python/unittest/test_lang_reflection.py
tests/python/unittest/test_lang_schedule.py
tests/python/unittest/test_lang_tensor.py
tests/python/unittest/test_lang_tensor_overload_op.py
tests/python/unittest/test_pass_attrs_hash_equal.py
tests/python/unittest/test_pass_basic.py
tests/python/unittest/test_pass_bound_checkers.py
tests/python/unittest/test_pass_combine_context_call.py
tests/python/unittest/test_pass_decorate_device_scope.py
tests/python/unittest/test_pass_hoist_if.py
tests/python/unittest/test_pass_inject_copy_intrin.py
tests/python/unittest/test_pass_inject_double_buffer.py
tests/python/unittest/test_pass_inline.py
tests/python/unittest/test_pass_ir_transform.py
tests/python/unittest/test_pass_lift_attr_scope.py
tests/python/unittest/test_pass_loop_partition.py
tests/python/unittest/test_pass_lower_intrin.py
tests/python/unittest/test_pass_remove_no_op.py
tests/python/unittest/test_pass_rewrite_unsafe_select.py
tests/python/unittest/test_pass_storage_flatten.py
tests/python/unittest/test_pass_storage_rewrite.py
tests/python/unittest/test_pass_storage_sync.py
tests/python/unittest/test_pass_unroll.py
tests/python/unittest/test_pass_vectorize.py
tests/python/unittest/test_runtime_module_load.py
tests/python/unittest/test_schedule_schedule_ops.py
tests/python/unittest/test_schedule_tensorize.py
topi/python/topi/math.py
topi/python/topi/vision/rcnn/roi_pool.py
tutorials/dev/low_level_custom_pass.py
tutorials/language/intrin_math.py
tutorials/language/tuple_inputs.py
vta/python/vta/build_module.py
vta/python/vta/environment.py
vta/python/vta/intrin.py
vta/python/vta/ir_pass.py

index 9327c08..7d9cc1a 100644 (file)
@@ -33,36 +33,40 @@ from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
 from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
 from .runtime import ndarray as nd
 
+# tvm.error
+from . import error
+
 # tvm.ir
 from .ir import IRModule
 from .ir import transform
 from .ir import container
 from . import ir
 
+# tvm.tir
+from . import tir
+
+# tvm.target
+from . import target
+
 # others
 from . import tensor
 from . import arith
-from . import expr
-from . import stmt
 from . import make
-from . import ir_pass
 from . import schedule
-
-from . import ir_builder
-from . import target
-from . import generic
 from . import hybrid
 from . import testing
-from . import error
-
 
 from .api import *
-from .intrin import *
 from .tensor_intrin import decl_tensor_intrin
 from .schedule import create_schedule
 from .build_module import build, lower, build_config
 from .tag import tag_scope
 
+# backward compact for topi, to be removed later
+from .tir import expr, stmt, ir_builder, ir_pass, generic
+from .tir.op import *
+from . import intrin
+
 # Contrib initializers
 from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
 
index e7778d6..3a8eedc 100644 (file)
@@ -23,13 +23,17 @@ import tvm.ir
 
 from tvm.runtime import convert, const, DataType
 from tvm.ir import container as _container
+from tvm.tir import expr as _expr
+from tvm.tir import stmt as _stmt
+from tvm.tir import decl_buffer, layout, bijective_layout
+from tvm.tir import min_value, max_value, indexdiv, indexmod
+import tvm.tir._ffi_api
 
 from ._ffi.base import string_types, TVMError
 from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
 
 from . import _api_internal
 from . import make as _make
-from . import expr as _expr
 from . import tensor as _tensor
 from . import schedule as _schedule
 from . import tag as _tag
@@ -40,37 +44,6 @@ float32 = "float32"
 handle = "handle"
 
 
-def min_value(dtype):
-    """minimum value of dtype
-
-    Parameters
-    ----------
-    dtype : str
-        The data type.
-
-    Returns
-    -------
-    value : tvm.Expr
-        The minimum value of dtype.
-    """
-    return _api_internal._min_value(dtype)
-
-
-def max_value(dtype):
-    """maximum value of dtype
-
-    Parameters
-    ----------
-    dtype : str
-        The data type.
-
-    Returns
-    -------
-    value : tvm.Expr
-        The maximum value of dtype.
-    """
-    return _api_internal._max_value(dtype)
-
 def var(name="tindex", dtype=int32):
     """Create a new variable with specified name and dtype
 
@@ -87,7 +60,7 @@ def var(name="tindex", dtype=int32):
     var : Var
         The result symbolic variable.
     """
-    return _api_internal._Var(name, dtype)
+    return _expr.Var(name, dtype)
 
 
 def size_var(name="size", dtype=int32):
@@ -106,7 +79,7 @@ def size_var(name="size", dtype=int32):
     var : SizeVar
         The result symbolic shape variable.
     """
-    return _api_internal._SizeVar(name, dtype)
+    return _expr.SizeVar(name, dtype)
 
 
 def any(*args):
@@ -126,9 +99,9 @@ def any(*args):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    ret = _make._OpOr(args[0], args[1])
+    ret = tvm.tir._ffi_api._OpOr(args[0], args[1])
     for i in range(2, len(args)):
-        ret = _make._OpOr(ret, args[i])
+        ret = tvm.tir._ffi_api._OpOr(ret, args[i])
     return ret
 
 
@@ -150,9 +123,9 @@ def all(*args):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    ret = _make._OpAnd(args[0], args[1])
+    ret = tvm.tir._ffi_api._OpAnd(args[0], args[1])
     for i in range(2, len(args)):
-        ret = _make._OpAnd(ret, args[i])
+        ret = tvm.tir._ffi_api._OpAnd(ret, args[i])
     return ret
 
 
@@ -438,7 +411,7 @@ def extern(shape,
             output_placeholders.append(decl_buffer(shp, dt, name))
     body = fcompute(input_placeholders, output_placeholders)
     if isinstance(body, _expr.PrimExpr):
-        body = _make.Evaluate(body)
+        body = _stmt.Evaluate(body)
 
     op = _api_internal._ExternOp(name, tag, attrs,
                                  inputs, input_placeholders,
@@ -447,159 +420,6 @@ def extern(shape,
     return res[0] if len(res) == 1 else res
 
 
-def decl_buffer(shape,
-                dtype=None,
-                name="buffer",
-                data=None,
-                strides=None,
-                elem_offset=None,
-                scope="",
-                data_alignment=-1,
-                offset_factor=0,
-                buffer_type=""):
-    """Declare a new symbolic buffer.
-
-    Normally buffer is created automatically during lower and build.
-    This is only needed if user want to specify their own buffer layout.
-
-    See the note below for detailed discussion on usage of buffer.
-
-    Parameters
-    ----------
-    shape : tuple of Expr
-        The shape of the buffer.
-
-    dtype : str, optional
-        The data type of the buffer.
-
-    name : str, optional
-        The name of the buffer.
-
-    data : Var, optional
-        The data pointer in the buffer.
-
-    strides: array of Expr
-        The stride of the buffer.
-
-    elem_offset: Expr, optional
-        The beginning offset of the array to data.
-        In terms of number of elements of dtype.
-
-    scope: str, optional
-        The storage scope of the buffer, if not global.
-        If scope equals empty string, it means it is global memory.
-
-    data_alignment: int, optional
-        The alignment of data pointer in bytes.
-        If -1 is passed, the alignment will be set to TVM's internal default.
-
-    offset_factor: int, optional
-        The factor of elem_offset field, when set,
-        elem_offset is required to be multiple of offset_factor.
-        If 0 is pssed, the alignment will be set to 1.
-        if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
-
-    buffer_type: str, optional, {"", "auto_broadcast"}
-        auto_broadcast buffer allows one to implement broadcast computation
-        without considering whether dimension size equals to one.
-        TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
-
-    Returns
-    -------
-    buffer : Buffer
-        The created buffer
-
-    Example
-    -------
-    Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation,
-
-    .. code-block:: python
-
-        m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
-        n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
-        o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
-        A = tvm.placeholder((m0, m1, m2), name='A')
-        B = tvm.placeholder((n0, n1, n2), name='B')
-        C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
-        Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
-        Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
-        s = tvm.create_schedule(C.op)
-        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
-        ctx = tvm.cpu(0)
-        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
-        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
-        fadd(a, b, c)
-        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
-
-    Note
-    ----
-    Buffer data structure reflects the DLTensor structure in dlpack.
-    While DLTensor data structure is very general, it is usually helpful
-    to create function that only handles specific case of data structure
-    and make compiled function benefit from it.
-
-    If user pass strides and elem_offset is passed as None
-    when constructing the function, then the function will be specialized
-    for the DLTensor that is compact and aligned.
-    If user pass a fully generic symbolic array to the strides,
-    then the resulting function becomes fully generic.
-    """
-    shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape
-    dtype = float32 if dtype is None else dtype
-    strides = () if strides is None else strides
-    if offset_factor != 0 and elem_offset is None:
-        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
-        elem_offset = var('%s_elem_offset' % name, shape_dtype)
-    if data is None:
-        data = var(name, "handle")
-    return _api_internal._Buffer(
-        data, dtype, shape, strides, elem_offset, name, scope,
-        data_alignment, offset_factor, buffer_type)
-
-def layout(layout_str):
-    """Create a layout node from a string.
-
-    Parameters
-    ----------
-    layout_str : str
-        A layout representation is composed of upper cases, lower cases and numbers,
-        where upper case indicates a primal axis and
-        the corresponding lower case with factor size indicates the subordinate axis.
-        For example, NCHW16c can describe a 5-D tensor of
-        [batch_size, channel, height, width, channel_block].
-        Here subordinate axis channel_block=16 is the factor size of
-        the primal axis C (channel).
-
-    Returns
-    -------
-    layout : Layout
-        The created layout
-    """
-    return _api_internal._Layout(layout_str)
-
-def bijective_layout(src_layout, dst_layout):
-    """Create a bijective layout mapping.
-
-    Parameters
-    ----------
-    src_layout : str or Layout
-        source layout.
-
-    dst_layout : str or Layout
-        destination layout.
-
-    Returns
-    -------
-    bijective_layout : BijectiveLayout
-        The created bijective layout
-    """
-    if isinstance(src_layout, str):
-        src_layout = layout(src_layout)
-    if isinstance(dst_layout, str):
-        dst_layout = layout(dst_layout)
-    return _api_internal._BijectiveLayout(src_layout, dst_layout)
-
 def _IterVar(dom, name, iter_type, thread_tag=''):
     """Internal function to create IterVar
 
@@ -758,7 +578,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
             expr = convert([expr])
         result = convert(result)
         id_elem = convert(id_elem)
-        combiner = _make.CommReducer(lhs, rhs, result, id_elem)
+        combiner = _expr.CommReducer(lhs, rhs, result, id_elem)
         axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
         if where is None:
             where = convert(True)
@@ -810,164 +630,9 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
     reducer.__doc__ = doc_str.format(name)
     return reducer
 
-def div(a, b):
-    """Compute a / b as in C/C++ semantics.
-
-    Parameters
-    ----------
-    a : Expr
-        The left hand operand, known to be non-negative.
-
-    b : Expr
-        The right hand operand, known to be non-negative.
-
-    Returns
-    -------
-    res : Expr
-        The result expression.
-    Note
-    ----
-    When operands are integers, returns truncdiv(a, b).
-    """
-    return _make._OpDiv(a, b)
-
-
-def indexdiv(a, b):
-    """Compute floor(a / b) where a and b are non-negative.
-
-    Parameters
-    ----------
-    a : Expr
-        The left hand operand, known to be non-negative.
-
-    b : Expr
-        The right hand operand, known to be non-negative.
-
-    Returns
-    -------
-    res : Expr
-        The result expression.
-
-    Note
-    ----
-    Use this function to split non-negative indices.
-    This function may take advantage of operands'
-    non-negativeness.
-    """
-    return _make._OpIndexDiv(a, b)
-
-
-def indexmod(a, b):
-    """Compute the remainder of indexdiv. a and b are non-negative.
-
-    Parameters
-    ----------
-    a : Expr
-        The left hand operand, known to be non-negative.
-
-    b : Expr
-        The right hand operand, known to be non-negative.
-
-    Returns
-    -------
-    res : Expr
-        The result expression.
-
-    Note
-    ----
-    Use this function to split non-negative indices.
-    This function may take advantage of operands'
-    non-negativeness.
-    """
-    return _make._OpIndexMod(a, b)
-
-
-def truncdiv(a, b):
-    """Compute the truncdiv of two expressions.
-
-    Parameters
-    ----------
-    a : Expr
-        The left hand operand
-
-    b : Expr
-        The right hand operand
-
-    Returns
-    -------
-    res : Expr
-        The result expression.
-
-    Note
-    ----
-    This is the default integer division behavior in C.
-    """
-    return _make._OpTruncDiv(a, b)
-
-
-def truncmod(a, b):
-    """Compute the truncmod of two expressions.
-
-    Parameters
-    ----------
-    a : Expr
-        The left hand operand
-
-    b : Expr
-        The right hand operand
-
-    Returns
-    -------
-    res : Expr
-        The result expression.
-
-    Note
-    ----
-    This is the default integer division behavior in C.
-    """
-    return _make._OpTruncMod(a, b)
-
-
-def floordiv(a, b):
-    """Compute the floordiv of two expressions.
-
-    Parameters
-    ----------
-    a : Expr
-        The left hand operand
-
-    b : Expr
-        The right hand operand
-
-    Returns
-    -------
-    res : Expr
-        The result expression.
-    """
-    return _make._OpFloorDiv(a, b)
-
-
-def floormod(a, b):
-    """Compute the floormod of two expressions.
-
-    Parameters
-    ----------
-    a : Expr
-        The left hand operand
-
-    b : Expr
-        The right hand operand
-
-    Returns
-    -------
-    res : Expr
-        The result expression.
-    """
-    return _make._OpFloorMod(a, b)
-
-#pylint: disable=unnecessary-lambda
+# pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
-min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
-max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')
+min = comm_reducer(lambda x, y: tvm.tir._ffi_api._OpMin(x, y), max_value, name='min')
+max = comm_reducer(lambda x, y: tvm.tir._ffi_api._OpMax(x, y), min_value, name='max')
 
 tvm._ffi._init_api("tvm.api")
index 5067277..9ff8b24 100644 (file)
@@ -227,7 +227,7 @@ def args_to_workload(x, topi_compute_func=None):
         workload = 0
     else:
         raise RuntimeError('Do not support type "%s" in argument. Consider to use'
-                           'primitive types or tvm.expr.Var only' % type(x))
+                           'primitive types or tvm.tir.Var only' % type(x))
     return (get_func_name(topi_compute_func), ) + workload  if topi_compute_func else workload
 
 def template(func):
index c2993ac..de78a3e 100644 (file)
@@ -26,17 +26,19 @@ import tvm.runtime
 from tvm.runtime import Object, ndarray
 from tvm.ir import container
 from tvm.target import codegen
+from tvm.tir import expr
+from tvm.tir import ir_pass
+from tvm.tir import Stmt
+from tvm.tir.stmt import LoweredFunc
+
+from . import target as _target
 
 from . import api
 from . import _api_internal
 from . import tensor
 from . import schedule
-from . import expr
-from . import ir_pass
-from . import stmt as _stmt
-from . import target as _target
 from . import make
-from .stmt import LoweredFunc
+
 
 
 class DumpIR(object):
@@ -61,7 +63,7 @@ class DumpIR(object):
         def dump(*args, **kwargs):
             """dump function"""
             retv = func(*args, **kwargs)
-            if not isinstance(retv, (_stmt.Stmt, LoweredFunc, container.Array)):
+            if not isinstance(retv, (Stmt, LoweredFunc, container.Array)):
                 return retv
             fname = func.func_name if hasattr(func, 'func_name') else func.__name__
             pname = str(self._pass_id) + "_" + fname + "_ir.cc"
index 7c024b7..cdd4ce2 100644 (file)
@@ -15,9 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to BLAS libraries."""
-from __future__ import absolute_import as _abs
-
-from .. import api as _api, intrin as _intrin
+import tvm
+from .. import api as _api
 
 
 def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
@@ -46,7 +45,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
     return _api.extern(
         (n, m),
         [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
         ),
         name="C",
@@ -78,7 +77,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
     return _api.extern(
         (b, n, m),
         [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cblas.batch_matmul"
             if not iterative
             else "tvm.contrib.cblas.batch_matmul_iterative",
index cf4e1f5..75290a8 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to cuBLAS libraries."""
-from __future__ import absolute_import as _abs
-
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 
 def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
     """Create an extern op that compute matrix mult of A and rhs with cuBLAS
@@ -44,7 +42,7 @@ def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
     dtype = dtype if dtype is not None else lhs.dtype
     return _api.extern(
         (n, m), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cublas.matmul",
             ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
 
@@ -73,6 +71,6 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
     dtype = dtype if dtype is not None else lhs.dtype
     return _api.extern(
         (b, n, m), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cublas.batch_matmul",
             ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
index 5470fd0..1000ede 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to cuBLASlt libraries."""
-from __future__ import absolute_import as _abs
-
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
+
 
 def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
     """Create an extern op that compute matrix mult of A and rhs with cuBLAS
@@ -46,6 +45,6 @@ def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
     dtype = dtype if dtype is not None else lhs.dtype
     return _api.extern(
         (n, m), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cublaslt.matmul",
             ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
index 1b5caca..20b42d7 100644 (file)
@@ -18,8 +18,8 @@
 # pylint: disable-msg=C0103
 import ctypes
 import numpy as np
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 from .. import get_global_func as _get_global_func
 
 # algos can be read from cudnn.h
@@ -365,7 +365,7 @@ def conv_forward(x,
     if dims == 4:
         return _api.extern(
             oshape, [x, w],
-            lambda ins, outs: _intrin.call_packed(
+            lambda ins, outs: tvm.tir.call_packed(
                 "tvm.contrib.cudnn.conv2d.forward",
                 conv_mode,
                 tensor_format,
@@ -383,7 +383,7 @@ def conv_forward(x,
 
     return _api.extern(
         oshape, [x, w],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cudnn.conv3d.forward",
             conv_mode,
             tensor_format,
index e062ac1..7f024f7 100644 (file)
@@ -18,8 +18,8 @@
 # pylint: disable-msg=C0103
 import ctypes
 import numpy as np
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 from .. import get_global_func as _get_global_func
 
 
@@ -113,7 +113,7 @@ def conv2d_forward(x,
 
     return _api.extern(
         list(oshape), [x, w],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.miopen.conv2d.forward",
             conv_mode,
             data_type,
index d9cab24..5d84e89 100644 (file)
@@ -15,9 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to MPS libraries."""
-from __future__ import absolute_import as _abs
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 
 # pylint: disable=C0103,W0612
 
@@ -50,7 +49,7 @@ def matmul(lhs, rhs, transa=False, transb=False):
         n = c
     return _api.extern(
         (m, n), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
         name="C")
 
@@ -82,6 +81,6 @@ def conv2d(data, weight, pad='SAME', stride=1):
 
     return _api.extern(
         (n, ho, wo, co), [data, weight],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
         name="C")
index 3e2132e..a55a344 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to NNPACK libraries."""
+import tvm
 import tvm._ffi
-
 from .. import api as _api
-from .. import intrin as _intrin
 
 
 def is_available():
@@ -46,7 +45,7 @@ def fully_connected_inference(lhs, rhs, nthreads=1):
     m = rhs.shape[0]
     return _api.extern(
         (m, ), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.nnpack.fully_connected_inference",
             ins[0], ins[1], outs[0], nthreads), name="C")
 
@@ -110,7 +109,7 @@ def convolution_inference(
     return _api.extern(
         (batch, output_channels, output_height, output_width),
         [data, kernel, bias] if bias is not None else [data, kernel],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.nnpack.convolution_inference",
             ins[0],
             ins[1],
@@ -163,7 +162,7 @@ def convolution_inference_without_weight_transform(
     return _api.extern(
         (batch, output_channels, output_height, output_width),
         [data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.nnpack.convolution_inference_without_weight_transform",
             ins[0],
             ins[1],
@@ -198,7 +197,7 @@ def convolution_inference_weight_transform(
     return _api.extern(
         (output_channels, input_channels, transform_tile_size, transform_tile_size),
         [kernel],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.nnpack.convolution_inference_weight_transform",
             ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
 
index 059bf23..bcc9b17 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to random library."""
+import tvm
 import tvm._ffi
-
 from .. import api as _api
-from .. import intrin as _intrin
 
 
 def randint(low, high, size, dtype='int32'):
@@ -39,7 +38,7 @@ def randint(low, high, size, dtype='int32'):
         A tensor with specified size and dtype
     """
     assert 'int' in dtype, "the type of randint output must be int or uint"
-    return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
+    return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
         "tvm.contrib.random.randint", int(low), int(high), outs[0]), dtype=dtype)
 
 
@@ -67,7 +66,7 @@ def uniform(low, high, size):
     out : Tensor
         A tensor with specified size and dtype.
     """
-    return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
+    return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
         "tvm.contrib.random.uniform", float(low), float(high), outs[0]), dtype='float32')
 
 
@@ -91,7 +90,7 @@ def normal(loc, scale, size):
     out : Tensor
         A tensor with specified size and dtype
     """
-    return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
+    return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
         "tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
 
 
index bdd6146..e11be5a 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to rocBLAS libraries."""
-from __future__ import absolute_import as _abs
-
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 
 def matmul(lhs, rhs, transa=False, transb=False):
     """Create an extern op that compute matrix mult of A and rhs with rocBLAS
@@ -43,6 +41,6 @@ def matmul(lhs, rhs, transa=False, transb=False):
     m = rhs.shape[0] if transb else rhs.shape[1]
     return _api.extern(
         (n, m), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.rocblas.matmul",
             ins[0], ins[1], outs[0], transa, transb), name="C")
index b7bea7f..7c46312 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Generic opertors in TVM.
-We follow the numpy naming convention for this interface
-(e.g., tvm.generic.multitply ~ numpy.multiply).
-The default implementation is used by tvm.ExprOp.
-"""
-# pylint: disable=unused-argument
-from . import make as _make
-
-#Operator precedence used when overloading.
-__op_priority__ = 0
-
-
-def add(lhs, rhs):
-    """Generic add operator.
-
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of add operaton.
-    """
-    return _make._OpAdd(lhs, rhs)
-
-
-def subtract(lhs, rhs):
-    """Generic subtract operator.
-
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of subtract operaton.
-    """
-    return _make._OpSub(lhs, rhs)
-
-
-def multiply(lhs, rhs):
-    """Generic multiply operator.
-
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of multiply operaton.
-    """
-    return _make._OpMul(lhs, rhs)
-
-def divide(lhs, rhs):
-    """Generic divide operator.
-
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of divide operaton.
-    """
-    return _make._OpDiv(lhs, rhs)
-
-def floordiv(lhs, rhs):
-    """Generic floordiv operator.
-
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of divide operaton.
-    """
-    return _make._OpFloorDiv(lhs, rhs)
-
-
-def cast(src, dtype):
-    """Generic cast operator.
-
-    Parameters
-    ----------
-    src : object
-        The source operand.
-
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of divide operaton.
-    """
-    return _make._cast(dtype, src)
+"""Generic operators."""
+# pylint:disable=unused-wildcard-import, wildcard-import
+from .tir.generic import *
index 78ce2e2..0933628 100644 (file)
 """Intrinsics of TVM-Python Hybrid Script for Python compilation time
 semantic support."""
 from tvm.ir.container import Array
+from tvm import target as _tgt
+from tvm.tir import expr as _expr
+from tvm.tir import ir_pass
+from tvm.tir import call_pure_intrin
+from tvm.tir.stmt import For
+
 from .. import api as _api
-from .. import expr as _expr
-from .. import make as _make
-from .. import target as _tgt
-from .. import ir_pass
-from ..stmt import For
+
 from .util import _internal_assert
-from ..intrin import call_pure_intrin
 
-#pylint: disable=redefined-builtin
+# pylint: disable=redefined-builtin
 
 LOOP_INTRIN = {
     'range'       : For.Serial,
@@ -69,15 +70,15 @@ def bind(func_id, args):
 
 def _math_intrin(func_id, args):
     # pylint: disable=import-outside-toplevel
-    from .. import intrin
-    return getattr(intrin, func_id)(*args)
+    import tvm.tir.op
+    return getattr(tvm.tir.op, func_id)(*args)
 
 sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
 
 
 def _min_max(func_id, args):
     _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
-    return getattr(_make, func_id.title())(args[0], args[1])
+    return getattr(_expr, func_id.title())(args[0], args[1])
 
 
 min = max = _min_max #pylint: disable=invalid-name
@@ -127,7 +128,7 @@ def len(func_id, args):
 def _cast(func_id, args):
     _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), \
                      "Only one expression can be cast")
-    return _make.Cast(func_id, args[0])
+    return _expr.Cast(func_id, args[0])
 
 float16 = float32 = float64 = _cast #pylint: disable=invalid-name
 int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name
index a0b2dfe..6be0006 100644 (file)
@@ -24,7 +24,11 @@ import types
 import numbers
 
 from enum import Enum
-from tvm.ir.container import Array
+from tvm.ir import Array, Range
+import tvm.tir
+from tvm.tir import expr as _expr
+from tvm.tir import stmt as _stmt
+from tvm.tir import ir_pass as _ir_pass
 
 from .util import _internal_assert
 from . import calls
@@ -35,12 +39,7 @@ from ..api import any as _any
 
 from ..tensor import Tensor, Operation
 from .. import _api_internal as _tvm_internal
-from .. import expr as _expr
-from .. import make as _make
-from .. import stmt as _stmt
-
 from .. import api  as _api
-from .. import ir_pass as _ir_pass
 
 
 def concat_list_to_block(lst):
@@ -79,13 +78,13 @@ class Symbol(Enum):
 
 def _floordiv(x, y):
     if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
-        return _api.floordiv(x, y)
+        return tvm.tir.floordiv(x, y)
     return operator.floordiv(x, y)
 
 
 def _floormod(x, y):
     if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
-        return _api.floormod(x, y)
+        return tvm.tir.floormod(x, y)
     return operator.mod(x, y)
 
 
@@ -208,11 +207,11 @@ class HybridParser(ast.NodeVisitor):
             if _scope == 'global':
                 body = self.wrap_up_binds(body)
 
-            _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
+            _domain = [Range.make_by_min_extent(0, i) for i in _buf.shape]
             _dtype = _buf.dtype
             _true = _api.convert(True)
-            body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
-            body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
+            body = tvm.tir.Realize(_buf.op, 0, _dtype, _domain, _true, body)
+            body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
 
         for elem in to_pop:
             self.symbols.pop(elem)
@@ -223,7 +222,7 @@ class HybridParser(ast.NodeVisitor):
     def wrap_up_binds(self, body):
         for _, iter_var in self.binds.items():
             ext = iter_var.dom.extent
-            body = _make.AttrStmt(iter_var, 'thread_extent', ext, body)
+            body = tvm.tir.AttrStmt(iter_var, 'thread_extent', ext, body)
         self.binds = {}
         return body
 
@@ -271,7 +270,7 @@ class HybridParser(ast.NodeVisitor):
             return entry if isinstance(node.ctx, ast.Load) else None
         if ty is Symbol.BufferVar:
             if isinstance(node.ctx, ast.Load):
-                return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
+                return tvm.tir.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
                                   _expr.Call.Halide, entry.op, entry.value_index)
             return entry, [_api.const(0, 'int32')]
         # Do I need any assertion here?
@@ -304,10 +303,10 @@ class HybridParser(ast.NodeVisitor):
             args = [_api.const(0, 'int32')]
         _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
 
-        read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
+        read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
         value = HybridParser._binop_maker[type(node.op)](read, rhs)
 
-        return _make.Provide(buf.op, 0, value, args)
+        return tvm.tir.Provide(buf.op, 0, value, args)
 
 
     def visit_Assign(self, node):
@@ -358,13 +357,13 @@ class HybridParser(ast.NodeVisitor):
             lhs = self.visit(lhs_)
             if lhs is not None:
                 buf, args = lhs
-                return _make.Provide(buf.op, 0, rhs, args)
+                return tvm.tir.Provide(buf.op, 0, rhs, args)
             return util.make_nop()
 
         lhs, args = self.visit(lhs)
         _internal_assert(isinstance(lhs, Tensor), \
                          "An array access's LHS is expected to be a expr.Call!")
-        res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
+        res = tvm.tir.Provide(lhs.op, lhs.value_index, rhs, args)
         return res
 
 
@@ -391,8 +390,8 @@ class HybridParser(ast.NodeVisitor):
                     arr = arr[i.value]
             return arr
         if isinstance(node.ctx, ast.Load):
-            return _make.Call(arr.dtype, arr.name, args,
-                              _expr.Call.Halide, arr.op, arr.value_index)
+            return tvm.tir.Call(arr.dtype, arr.name, args,
+                                _expr.Call.Halide, arr.op, arr.value_index)
         return arr, args
 
     def visit_With(self, node):
@@ -426,14 +425,14 @@ class HybridParser(ast.NodeVisitor):
             else_body = visit_list_to_block(self.visit, node.orelse)
         else:
             else_body = None
-        return _make.IfThenElse(cond, if_body, else_body)
+        return tvm.tir.IfThenElse(cond, if_body, else_body)
 
 
     def visit_IfExp(self, node):
         cond = self.visit(node.test)
         if_body = self.visit(node.body)
         else_body = self.visit(node.orelse)
-        return _make.Select(cond, if_body, else_body)
+        return tvm.tir.Select(cond, if_body, else_body)
 
 
     def visit_Compare(self, node):
@@ -543,7 +542,7 @@ class HybridParser(ast.NodeVisitor):
         else:
             _internal_assert(not isinstance(for_type, tuple), \
                             "Micro expansion should be handled before!")
-            res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
+            res = tvm.tir.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
 
         self.symbols.pop(_name)
         return res
@@ -580,7 +579,7 @@ class HybridParser(ast.NodeVisitor):
     def visit_Assert(self, node):
         test = self.visit(node.test)
         mesg = _api.convert(self.visit(node.msg))
-        return _make.AssertStmt(test, mesg, util.make_nop())
+        return tvm.tir.AssertStmt(test, mesg, util.make_nop())
 
 
 def parse_python(src, args, symbols, closure_vars):
index 8ef200a..2f449dc 100644 (file)
@@ -22,12 +22,13 @@ import logging
 import sys
 import numpy
 
+from tvm._ffi.base import numeric_types
 from tvm.ir.container import Array
+
+from tvm.tir import expr as _expr
+from tvm.tir import stmt as _stmt
+
 from .. import api as _api
-from .. import make as _make
-from .. import expr as _expr
-from .. import stmt as _stmt
-from .._ffi.base import numeric_types
 from ..tensor import Tensor
 
 
@@ -46,7 +47,7 @@ def _internal_assert(cond, err):
 # Useful constants. In avoid of runtime dependences, we use function calls to return them.
 def make_nop():
     """Returns a 'no operation' node in HalideIR."""
-    return _make.Evaluate(_api.const(0, dtype='int32'))
+    return _stmt.Evaluate(_api.const(0, dtype='int32'))
 
 
 def is_docstring(node):
@@ -77,10 +78,10 @@ def replace_io(body, rmap):
     def replace(op):
         if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
             buf = rmap[op.func]
-            return _make.Provide(buf.op, op.value_index, op.value, op.args)
+            return _stmt.Provide(buf.op, op.value_index, op.value, op.args)
         if isinstance(op, _expr.Call) and  op.func in rmap.keys():
             buf = rmap[op.func]
-            return _make.Call(buf.dtype, buf.name, op.args, \
+            return _expr.Call(buf.dtype, buf.name, op.args, \
                               _expr.Call.Halide, buf.op, buf.value_index)
         return None
 
index 04cbf9e..93e8fcb 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Expression Intrinsics and math functions in TVM."""
-# pylint: disable=redefined-builtin
-import tvm._ffi
-import tvm.target.codegen
-
-from . import make as _make
-from .api import convert, const
-from .expr import Call as _Call
-from .schedule import Buffer as _Buffer
-
-def _pack_buffer(buf):
-    """Build intrinsics that packs the buffer.
-    """
-    assert buf.shape
-    shape = _make.Call("handle", "tvm_stack_make_shape", buf.shape,
-                       _Call.Intrinsic, None, 0)
-    strides = _make.Call("handle", "tvm_stack_make_shape", buf.strides,
-                         _Call.Intrinsic, None, 0) if buf.strides else 0
-    pack_args = [buf.data,
-                 shape,
-                 strides,
-                 len(buf.shape),
-                 const(0, dtype=buf.dtype),
-                 buf.elem_offset]
-    return _make.Call("handle", "tvm_stack_make_array",
-                      pack_args, _Call.Intrinsic, None, 0)
-
-def call_packed(*args):
-    """Build expression by call an external packed function.
-
-    The argument to packed function can be Expr or Buffer.
-    The argument is the corresponding POD type when Expr is presented.
-
-    When the argument is Buffer, the corresponding PackedFunc
-    will recieve an TVMArrayHandle whose content is valid during the callback period.
-    If the PackedFunc is a python callback, then the corresponding argument is NDArray.
-
-    Parameters
-    ----------
-    args : list of Expr or Buffer.
-        Positional arguments.
-
-    Returns
-    -------
-    call : Expr
-        The call expression.
-
-    See Also
-    --------
-    tvm.extern : Create tensor with extern function call.
-    """
-    call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args]
-    return _make.Call(
-        "int32", "tvm_call_packed", call_args, _Call.Intrinsic, None, 0)
-
-
-def call_pure_intrin(dtype, func_name, *args):
-    """Build expression by calling a pure intrinsic function.
-
-    Intrinsics can be overloaded with multiple data types via
-    the intrinsic translation rule.
-
-    Parameters
-    ----------
-    dtype : str
-        The data type of the result.
-
-    func_name: str
-        The intrinsic function name.
-
-    args : list
-        Positional arguments.
-
-    Returns
-    -------
-    call : Expr
-        The call expression.
-    """
-    args = convert(args)
-    return _make.Call(
-        dtype, func_name, convert(args), _Call.PureIntrinsic, None, 0)
-
-
-def call_intrin(dtype, func_name, *args):
-    """Build expression by calling an intrinsic function.
-
-    Intrinsics can be overloaded with multiple data types via
-    the intrinsic translation rule.
-
-    Parameters
-    ----------
-    dtype : str
-        The data type of the result.
-
-    func_name: str
-        The intrinsic function name.
-
-    args : list
-        Positional arguments.
-
-    Returns
-    -------
-    call : Expr
-        The call expression.
-    """
-    args = convert(args)
-    return _make.Call(
-        dtype, func_name, convert(args), _Call.Intrinsic, None, 0)
-
-
-def call_pure_extern(dtype, func_name, *args):
-    """Build expression by calling a pure extern function.
-
-    Parameters
-    ----------
-    dtype : str
-        The data type of the result.
-
-    func_name: str
-        The extern function name.
-
-    args : list
-        Positional arguments.
-
-    Returns
-    -------
-    call : Expr
-        The call expression.
-    """
-    return _make.Call(
-        dtype, func_name, convert(args), _Call.PureExtern, None, 0)
-
-
-def call_extern(dtype, func_name, *args):
-    """Build expression by calling a extern function.
-
-    Parameters
-    ----------
-    dtype : str
-        The data type of the result.
-
-    func_name: str
-        The extern function name.
-
-    args : list
-        Positional arguments.
-
-    Returns
-    -------
-    call : Expr
-        The call expression.
-    """
-    return _make.Call(
-        dtype, func_name, convert(args), _Call.Extern, None, 0)
-
-
-def call_llvm_intrin(dtype, name, *args):
-    """Build expression by calling an llvm intrinsic function
-
-    Parameters
-    ----------
-    dtype : str
-       The data type of the result.
-
-    name : str
-       The name of the llvm intrinsic function.
-
-    args : list
-       Poistional arguments.
-
-    Returns
-    -------
-    call : Expr
-        The call expression.
-    """
-    llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
-    return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
-
-
-def exp(x):
-    """Take exponetial of input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "exp", x)
-
-
-def erf(x):
-    """Take gauss error function of the input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "erf", x)
-
-
-def tanh(x):
-    """Take hyperbolic tanh of input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "tanh", x)
-
-
-def sigmoid(x):
-    """Quick function to get sigmoid
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "sigmoid", x)
-
-
-def log(x):
-    """Take log of input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "log", x)
-
-def cos(x):
-    """Take cos of input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "cos", x)
-
-def sin(x):
-    """Take sin of input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "sin", x)
-
-def atan(x):
-    """Take atan of input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "atan", x)
-
-def sqrt(x):
-    """Take square root of input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "sqrt", x)
-
-
-def rsqrt(x):
-    """Take reciprocal of square root of input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "rsqrt", x)
-
-
-def floor(x):
-    """Take floor of float input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return _make.floor(x)
-
-
-def ceil(x):
-    """Take ceil of float input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return _make.ceil(x)
-
-
-def trunc(x):
-    """Get truncated value of the input.
-
-    The truncated value of the scalar x is the
-    nearest integer i which is closer to zero than x is.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return _make.trunc(x)
-
-
-def abs(x):
-    """Get absolute value of the input element-wise.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return _make.abs(x)
-
-
-def round(x):
-    """Round elements of the array to the nearest integer.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return _make.round(x)
-
-
-def nearbyint(x):
-    """Round elements of the array to the nearest integer.
-    This intrinsic uses llvm.nearbyint instead of llvm.round
-    which is faster but will results different from tvm.round.
-    Notably nearbyint rounds according to the rounding mode,
-    whereas tvm.round (llvm.round) ignores that.
-    For differences between the two see:
-    https://en.cppreference.com/w/cpp/numeric/math/round
-    https://en.cppreference.com/w/cpp/numeric/math/nearbyint
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return _make.nearbyint(x)
-
-
-def isnan(x):
-    """Check if input value is Nan.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return _make.isnan(x)
-
-
-def power(x, y):
-    """x power y
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    y : Expr
-        The exponent
-
-    Returns
-    -------
-    z : Expr
-        The result.
-    """
-    return _make._OpPow(convert(x), convert(y))
-
-
-def popcount(x):
-    """Count the number of set bits in input x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-
-    Returns
-    -------
-    y : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "popcount", x)
-
-def fmod(x, y):
-    """Return the remainder of x divided by y with the same sign as x.
-
-    Parameters
-    ----------
-    x : Expr
-        Input argument.
-    y : Expr
-        Input argument.
-
-    Returns
-    -------
-    z : Expr
-        The result.
-    """
-    return call_pure_intrin(x.dtype, "fmod", x, y)
-
-
-def if_then_else(cond, t, f):
-    """Conditional selection expression.
-
-    Parameters
-    ----------
-    cond : Expr
-        The condition
-
-    t : Expr
-        The result expression if cond is true.
-
-    f : Expr
-        The result expression if cond is false.
-
-    Returns
-    -------
-    result : Node
-        The result of conditional expression.
-
-    Note
-    ----
-    Unlike Select, if_then_else will not execute
-    the branch that does not satisfy the condition.
-    You can use it to guard against out of bound access.
-    Unlike Select, if_then_else cannot be vectorized
-    if some lanes in the vector have different conditions.
-    """
-    return _make._OpIfThenElse(convert(cond), convert(t), convert(f))
-
-
-# Intrinsic rule related code
-def register_intrin_rule(target, intrin, f=None, override=False):
-    """Register an intrinsic function generation rule.
-
-    Intrinsic generation rules are callback functions for
-    code generator to get device specific calls.
-    This function simply translates to.
-
-    :code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)`
-
-    TVM may already pre-register intrinsic rules in the backend.
-    However, user can use this function to change the intrinsic translation
-    behavior or add new intrinsic rules during runtime.
-
-    Parameters
-    ----------
-    target : str
-        The name of codegen target.
-
-    intrin : str
-        The name of the intrinsic.
-
-    f : function, optional
-        The function to be registered.
-
-    override: boolean optional
-        Whether override existing entry.
-
-    Returns
-    -------
-    fregister : function
-        Register function if f is not specified.
-
-    Examples
-    --------
-    The following code registers exp expansion rule for opencl.
-
-    .. code-block:: python
-
-        register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
-    """
-    return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
-
-
-def _rule_float_suffix(op):
-    """Intrinsic rule: Add float suffix if it is float32.
-
-    This is an example intrinsic generation rule.
-
-    Parameters
-    ----------
-    op : Expr
-        The call expression of original intrinsic.
-
-    Returns
-    -------
-    ret : Expr
-        The translated intrinsic rule.
-        Return same op if no translation is possible.
-
-    See Also
-    --------
-    register_intrin_rule : The registeration function for intrin rule.
-    """
-    if op.dtype == "float32":
-        return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
-    if op.dtype == "float64":
-        return call_pure_extern(op.dtype, op.name, *op.args)
-    return op
-
-
-def _rule_float_direct(op):
-    """Intrinsic rule: Directly call pure extern function for floats.
-
-    This is an example intrinsic generation rule.
-
-    Parameters
-    ----------
-    op : Expr
-        The call expression of original intrinsic.
-
-    Returns
-    -------
-    ret : Expr
-        The translated intrinsic rule.
-        Return same op if no translation is possible.
-
-    See Also
-    --------
-    register_intrin_rule : The registeration function for intrin rule.
-    """
-    if str(op.dtype).startswith("float"):
-        return call_pure_extern(op.dtype, op.name, *op.args)
-    return None
-
-@tvm._ffi.register_func("tvm.default_trace_action")
-def _tvm_default_trace_action(*args):
-    print(list(args))
-
-def trace(args, trace_action="tvm.default_trace_action"):
-    """Trace tensor data at the runtime.
-
-    The trace function allows to trace specific tensor at the
-    runtime. The tracing value should come as last argument.
-    The trace action should be specified, by default
-    tvm.default_trace_action is used.
-
-    Parameters
-    ----------
-    args : list of Expr or Buffers.
-        Positional arguments.
-
-    trace_action : str.
-        The name of the trace action.
-
-    Returns
-    -------
-    call : Expr
-        The call expression.
-
-    See Also
-    --------
-    tvm.call_packed : Creates packed function.
-    """
-    if not isinstance(args, list):
-        raise Exception("tvm.trace consumes the args as list type")
-    call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args]
-    call_args.insert(0, trace_action)
-    return _make.Call(
-        args[-1].dtype, "tvm_call_trace_packed", call_args, _Call.Intrinsic, None, 0)
-
-# opencl pattern for exp
-register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
-# default pattern for exp
-register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
+# pylint:disable=unused-wildcard-import, wildcard-import, redefined-builtin
+"""Backwared compatible layer for intrin."""
+from .tir.op import *
index d47e240..f122956 100644 (file)
@@ -24,7 +24,7 @@ from .type_relation import TypeCall, TypeRelation
 from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
 from .adt import Constructor, TypeData
 from .module import IRModule
-from .attrs import Attrs
+from .attrs import Attrs, make_node
 from .container import Array, Map
 
 from . import transform
index a596739..f30a18f 100644 (file)
@@ -18,6 +18,7 @@
 import tvm._ffi
 
 from tvm.runtime import Object
+import tvm.runtime._ffi_node_api
 from . import _ffi_api
 
 
@@ -91,3 +92,40 @@ class Attrs(Object):
 
     def __getitem__(self, item):
         return self.__getattr__(item)
+
+def make_node(type_key, **kwargs):
+    """Make a new IR node by its type key and fields
+
+    Parameters
+    ----------
+    type_key : str
+        The type key of the node.
+
+    **kwargs : dict
+        The fields of the node.
+
+    Returns
+    -------
+    node : Node
+        The corresponding IR Node
+
+    Note
+    ----
+    If the created node is instance of AttrsNode, then
+    the creator function will also run bound checks and
+    default value setup as supported by Attrs.
+
+    Example
+    -------
+    The following code constructs a IntImm object
+
+    .. code-block:: python
+
+       x = tvm.ir.make_node("IntImm", dtype="int32", value=10)
+       assert isinstance(x, tvm.tir.IntImm)
+       assert x.value == 10
+    """
+    args = [type_key]
+    for k, v in kwargs.items():
+        args += [k, v]
+    return tvm.runtime._ffi_node_api.MakeNode(*args)
index 3314ef1..07ed8e8 100644 (file)
@@ -53,7 +53,7 @@ class Node(Object):
         return _ffi_api.AsText(self, show_meta_data, annotate)
 
     def __str__(self):
-        return self.astext(show_meta_data=False)
+        return _ffi_api.PrettyPrint(self)
 
 
 @tvm._ffi.register_object("relay.SourceName")
index 46acd16..d29e73a 100644 (file)
@@ -99,3 +99,23 @@ class Range(Node):
     You do not need to create a Range explicitly.
     Python lists and tuples will be converted automatically to a Range in API functions.
     """
+    @staticmethod
+    def make_by_min_extent(min_value, extent):
+        """Construct a Range by min and extent.
+
+        This constructs a range in [min_value, min_value + extent)
+
+        Parameters
+        ----------
+        min_value : PrimExpr
+            The minimum value of the range.
+
+        extent : PrimExpr
+            The extent of the range.
+
+        Returns
+        -------
+        rng : Range
+            The constructed range.
+        """
+        return _ffi_api.range_by_min_extent(min_value, extent)
index 7f94d10..089c393 100644 (file)
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=unused-import
 """namespace of IR node builder make function
 
 This namespace is used for developers. While you do not see any declarations.
@@ -23,19 +24,22 @@ Each api is a PackedFunc that can be called in a positional argument manner.
 You can use make function to build the IR node.
 """
 import tvm._ffi
+import tvm.ir
+from tvm.ir import make_node as node
+from tvm.tir import Call
 
 
-def range_by_min_extent(min_value, extent):
+def make_by_min_extent(min_value, extent):
     """Construct a Range by min and extent.
 
     This constructs a range in [min_value, min_value + extent)
 
     Parameters
     ----------
-    min_value : Expr
+    min_value : PrimExpr
         The minimum value of the range.
 
-    extent : Expr
+    extent : PrimExpr
         The extent of the range.
 
     Returns
@@ -43,45 +47,6 @@ def range_by_min_extent(min_value, extent):
     rng : Range
         The constructed range.
     """
-    return _range_by_min_extent(min_value, extent)
-
-
-def node(type_key, **kwargs):
-    """Make a new DSL node by its type key and fields
-
-    Parameters
-    ----------
-    type_key : str
-        The type key of the node.
-
-    **kwargs : dict
-        The fields of the node.
-
-    Returns
-    -------
-    node : Node
-        The corresponding DSL Node
-
-    Note
-    ----
-    If the created node is instance of AttrsNode, then
-    the creator function will also run bound checks and
-    default value setup as supported by Attrs.
-
-    Example
-    -------
-    The following code constructs a IntImm object
-
-    .. code-block:: python
-
-       x = tvm.make.node("IntImm", dtype="int32", value=10)
-       assert isinstance(x, tvm.expr.IntImm)
-       assert x.value == 10
-    """
-    args = [type_key]
-    for k, v in kwargs.items():
-        args += [k, v]
-    return _Node(*args)
-
+    return tvm.ir.Range.make_by_min_extent(min_value, extent)
 
 tvm._ffi._init_api("tvm.make")
index c1a4130..78ab8ff 100644 (file)
@@ -509,7 +509,7 @@ class ParseTreeToRelayIR(RelayVisitor):
             _, type_params = zip(*type_params)
         self.exit_var_scope()
 
-        attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None
+        attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None
         return expr.Function(var_list, body, ret_type, type_params, attrs)
 
     @spanify
index 68b2b1c..6929be0 100644 (file)
@@ -181,11 +181,11 @@ class VMCompiler(object):
             raise ValueError("Target is not set in env or passed as argument.")
         tgts = {}
         if isinstance(target, (str, tvm.target.Target)):
-            dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
+            dev_type = tvm.tir.IntImm("int32", tvm.nd.context(str(target)).device_type)
             tgts[dev_type] = tvm.target.create(target)
         elif isinstance(target, dict):
             for dev, tgt in target.items():
-                dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
+                dev_type = tvm.tir.IntImm("int32", tvm.nd.context(dev).device_type)
                 tgts[dev_type] = tvm.target.create(tgt)
         else:
             raise TypeError("target is expected to be str, tvm.target.Target, " +
index ac2ea9d..f920682 100644 (file)
@@ -932,7 +932,7 @@ def _shape():
     def _impl(inputs, attr, params):
         is_symbolic_shape = False
         for axis in attr['_input_shapes'][inputs[0]]:
-            if not isinstance(axis, (int, tvm.expr.IntImm)):
+            if not isinstance(axis, (int, tvm.tir.IntImm)):
                 is_symbolic_shape = True
                 break
 
index 8a4cb2f..e6053b8 100644 (file)
@@ -557,7 +557,7 @@ def split_shape_func(attrs, inputs, _):
     """
     Shape function for split op.
     """
-    if isinstance(attrs.indices_or_sections, (int, tvm.expr.IntImm)):
+    if isinstance(attrs.indices_or_sections, (int, tvm.tir.IntImm)):
         indices_or_sections = get_const_int(attrs.indices_or_sections)
     else:
         indices_or_sections = get_const_tuple(attrs.indices_or_sections)
index 650bf9d..cd48d49 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=unused-import
 """The computation schedule api of TVM."""
 import tvm._ffi
-
 from tvm._ffi.base import string_types
+
 from tvm.runtime import Object, convert
 from tvm.ir import container as _container
+from tvm.tir import expr as _expr, Buffer
 
 from . import _api_internal
 from . import tensor as _tensor
-from . import expr as _expr
-
-
-@tvm._ffi.register_object
-class Buffer(Object):
-    """Symbolic data buffer in TVM.
-
-    Buffer provide a way to represent data layout
-    specialization of data structure in TVM.
-
-    Do not construct directly, use :any:`decl_buffer` instead.
-    See the documentation of :any:`decl_buffer` for more details.
-
-    See Also
-    --------
-    decl_buffer : Declare a buffer
-    """
-    READ = 1
-    WRITE = 2
-
-    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
-        """Get an access pointer to the head of buffer.
-
-        This is the recommended method to get buffer data
-        ptress when interacting with external functions.
-
-        Parameters
-        ----------
-        access_mask : int
-            The access pattern MASK. Indicate whether the
-            access will read or write to the data content.
-
-        ptr_type : str, optional
-            The data type of the result pointer. Do not specify
-            unless we want to cast pointer to specific type.
-
-        content_lanes: int, optional
-            The number of lanes for the data type. This value
-            is greater than one for vector types.
-
-        offset: Expr, optional
-            The offset of pointer. We can use it to offset by
-            the number of elements from the address of ptr.
-
-        Examples
-        --------
-        .. code-block:: python
-
-          import tvm.schedule.Buffer
-          # Get access ptr for read
-          buffer.access_ptr("r")
-          # Get access ptr for read/write with bitmask
-          buffer.access_ptr(Buffer.READ | Buffer.WRITE)
-          # Get access ptr for read/write with str flag
-          buffer.access_ptr("rw")
-          # Get access ptr for read with offset
-          buffer.access_ptr("r", offset = 100)
-        """
-        if isinstance(access_mask, string_types):
-            mask = 0
-            for value in access_mask:
-                if value == "r":
-                    mask = mask | Buffer.READ
-                elif value == "w":
-                    mask = mask | Buffer.WRITE
-                else:
-                    raise ValueError("Unknown access_mask %s" % access_mask)
-            access_mask = mask
-        offset = convert(offset)
-        return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
-                                              content_lanes, offset)
-
-    def vload(self, begin, dtype=None):
-        """Generate an Expr that loads dtype from begin index.
-
-        Parameters
-        ----------
-        begin : Array of Expr
-            The beginning index in unit of Buffer.dtype
-
-        dtype : str
-            The data type to be loaded,
-            can be vector type which have lanes that is multiple of Buffer.dtype
-
-        Returns
-        -------
-        load : Expr
-            The corresponding load expression.
-        """
-        begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
-        dtype = dtype if dtype else self.dtype
-        return _api_internal._BufferVLoad(self, begin, dtype)
-
-    def vstore(self, begin, value):
-        """Generate a Stmt that store value into begin index.
-
-        Parameters
-        ----------
-        begin : Array of Expr
-            The beginning index in unit of Buffer.dtype
-
-        value : Expr
-            The value to be stored.
-
-        Returns
-        -------
-        store : Stmt
-            The corresponding store stmt.
-        """
-        begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
-        return _api_internal._BufferVStore(self, begin, value)
 
 
 @tvm._ffi.register_object
index abe8436..2876496 100644 (file)
@@ -60,3 +60,4 @@ from .generic_func import GenericFunc
 from .generic_func import generic_func, get_native_generic_func, override_native_generic_func
 from . import datatype
 from . import codegen
+from .intrin import register_intrin_rule
index a9506b3..328568a 100644 (file)
@@ -19,7 +19,7 @@ import tvm._ffi
 
 import tvm.runtime._ffi_api
 from tvm.runtime import convert, DataType
-from tvm.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
+from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
 
 
 def register(type_name, type_code):
diff --git a/python/tvm/target/intrin.py b/python/tvm/target/intrin.py
new file mode 100644 (file)
index 0000000..acb0efe
--- /dev/null
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Target dependent intrinsic registration."""
+import tvm._ffi
+from tvm.tir import call_pure_extern
+
+
+# Intrinsic rule related code
+def register_intrin_rule(target, intrin, f=None, override=False):
+    """Register an intrinsic function generation rule.
+
+    Intrinsic generation rules are callback functions for
+    code generator to get device specific calls.
+    This function simply translates to.
+
+    :code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)`
+
+    TVM may already pre-register intrinsic rules in the backend.
+    However, user can use this function to change the intrinsic translation
+    behavior or add new intrinsic rules during runtime.
+
+    Parameters
+    ----------
+    target : str
+        The name of codegen target.
+
+    intrin : str
+        The name of the intrinsic.
+
+    f : function, optional
+        The function to be registered.
+
+    override: boolean optional
+        Whether override existing entry.
+
+    Returns
+    -------
+    fregister : function
+        Register function if f is not specified.
+
+    Examples
+    --------
+    The following code registers exp expansion rule for opencl.
+
+    .. code-block:: python
+
+        register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
+    """
+    return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
+
+
+def _rule_float_suffix(op):
+    """Intrinsic rule: Add float suffix if it is float32.
+
+    This is an example intrinsic generation rule.
+
+    Parameters
+    ----------
+    op : PrimExpr
+        The call expression of original intrinsic.
+
+    Returns
+    -------
+    ret : PrimExpr
+        The translated intrinsic rule.
+        Return same op if no translation is possible.
+
+    See Also
+    --------
+    register_intrin_rule : The registeration function for intrin rule.
+    """
+    if op.dtype == "float32":
+        return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
+    if op.dtype == "float64":
+        return call_pure_extern(op.dtype, op.name, *op.args)
+    return op
+
+
+def _rule_float_direct(op):
+    """Intrinsic rule: Directly call pure extern function for floats.
+
+    This is an example intrinsic generation rule.
+
+    Parameters
+    ----------
+    op : PrimExpr
+        The call expression of original intrinsic.
+
+    Returns
+    -------
+    ret : PrimExpr
+        The translated intrinsic rule.
+        Return same op if no translation is possible.
+
+    See Also
+    --------
+    register_intrin_rule : The registeration function for intrin rule.
+    """
+    if str(op.dtype).startswith("float"):
+        return call_pure_extern(op.dtype, op.name, *op.args)
+    return None
+
+# opencl pattern for exp
+register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
+# default pattern for exp
+register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
index bd25f84..00bd9d1 100644 (file)
 import tvm._ffi
 
 from tvm.runtime import Object, ObjectGeneric, convert_to_object
+from tvm.tir import expr as _expr
 
 from . import _api_internal
-from . import make as _make
-from . import expr as _expr
 
 
 class TensorSlice(ObjectGeneric, _expr.ExprOp):
@@ -74,7 +73,7 @@ class Tensor(Object, _expr.ExprOp):
             else:
                 raise ValueError("The indices must be expression")
 
-        return _make.Call(self.dtype, self.op.name,
+        return _expr.Call(self.dtype, self.op.name,
                           args, _expr.Call.Halide,
                           self.op, self.value_index)
 
@@ -207,136 +206,3 @@ class HybridOp(Operation):
     def axis(self):
         """Represent the IterVar axis, also defined when it is a HybridOp"""
         return self.__getattr__("axis")
-
-
-@tvm._ffi.register_object
-class Layout(Object):
-    """Layout is composed of upper cases, lower cases and numbers,
-    where upper case indicates a primal axis and
-    the corresponding lower case with factor size indicates the subordinate axis.
-    For example, NCHW16c can describe a 5-D tensor of
-    [batch_size, channel, height, width, channel_block].
-    Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
-
-    Do not construct directly, use :any:`layout` instead.
-    See the documentation of :any:`layout` for more details.
-
-    See Also
-    --------
-    layout : Declare a layout
-    """
-    def __len__(self):
-        return _api_internal._LayoutNdim(self)
-
-    def __contains__(self, axis):
-        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
-
-    def __getitem__(self, index):
-        if index >= len(self):
-            raise IndexError("Layout index out of range")
-        return _api_internal._LayoutGetItem(self, index)
-
-    def index_of(self, axis):
-        """Get the index of an axis
-
-        Parameters
-        ----------
-        axis : str
-            The axis name, need to be [a-z,A-Z]
-
-        Returns
-        -------
-        index : int
-            The index of the axis, -1 if not found.
-        """
-        return _api_internal._LayoutIndexOf(self, axis)
-
-    def factor_of(self, axis):
-        """Get the factor size of the subordinate axis.
-
-        Parameters
-        ----------
-        axis : str
-            The axis name, need to be [a-z,A-Z]
-
-        Returns
-        -------
-        factor : int
-            the size of the subordinate-axis of axis (if axis is a primal-axis),
-            or the size of axis itself (if axis is a subordinate-axis).
-            Return -1 if axis is not in the layout.
-        """
-        return _api_internal._LayoutFactorOf(self, axis)
-
-
-@tvm._ffi.register_object
-class BijectiveLayout(Object):
-    """Bijective mapping for two layouts (src-layout and dst-layout).
-    It provides shape and index conversion between each other.
-
-    Do not construct directly, use :any:`bijective_layout` instead.
-    See the documentation of :any:`bijective_layout` for more details.
-
-    See Also
-    --------
-    bijective_layout : Declare a bijective layout converter
-    """
-    def forward_index(self, index):
-        """Given the indices of the src-layout, infer the dst index.
-
-        Parameters
-        ----------
-        index: Array of Expr
-            The indices in src-layout.
-
-        Returns
-        -------
-        dst_index: Array of Expr
-            The inferred indices in dst-layout.
-        """
-        return _api_internal._BijectiveLayoutForwardIndex(self, index)
-
-    def backward_index(self, index):
-        """Given the indices of the dst-layout, infer the src index.
-
-        Parameters
-        ----------
-        index: Array of Expr
-            The indices in dst-layout.
-
-        Returns
-        -------
-        src_index: Array of Expr
-            The inferred indices in src-layout.
-        """
-        return _api_internal._BijectiveLayoutBackwardIndex(self, index)
-
-    def forward_shape(self, shape):
-        """Given the shape of the src-layout, infer the dst shape.
-
-        Parameters
-        ----------
-        shape: Array of Expr
-            The shape in src-layout.
-
-        Returns
-        -------
-        dst_shape: Array of Expr
-            The inferred shape in dst-layout.
-        """
-        return _api_internal._BijectiveLayoutForwardShape(self, shape)
-
-    def backward_shape(self, shape):
-        """Given the shape of the dst-layout, infer the src shape.
-
-        Parameters
-        ----------
-        shape: Array of Expr
-            The shape in dst-layout.
-
-        Returns
-        -------
-        src_shape: Array of Expr
-            The inferred shape in src-layout.
-        """
-        return _api_internal._BijectiveLayoutBackwardShape(self, shape)
index 2a01fe7..1fd8bee 100644 (file)
 import tvm._ffi
 
 from tvm.runtime import Object
+from tvm.ir import Range
+from tvm.tir import expr as _expr
+from tvm.tir import stmt as _stmt
+
 from . import _api_internal
 from . import api as _api
-from . import expr as _expr
-from . import stmt as _stmt
-from . import make as _make
 from . import tensor as _tensor
 from . import schedule as _schedule
 from .build_module import current_build_config
@@ -39,7 +40,7 @@ def _get_region(tslice):
                 begin = idx.var
             else:
                 begin = idx
-            region.append(_make.range_by_min_extent(begin, 1))
+            region.append(Range.make_by_min_extent(begin, 1))
     return region
 
 @tvm._ffi.register_object
@@ -136,7 +137,7 @@ def decl_tensor_intrin(op,
         scalar_params = []
     if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)):
         body = [body]
-    body = [_make.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
+    body = [_stmt.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
     if len(body) < 3:
         body += [None] * (3 - len(body))
     return _api_internal._TensorIntrin(
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
new file mode 100644 (file)
index 0000000..8621540
--- /dev/null
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+"""Namespace for Tensor-level IR"""
+from tvm.ir import PrimExpr
+from .buffer import Buffer, decl_buffer
+from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
+from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
+from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
+from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
+from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let
+
+from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
+from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
+from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
+
+from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
+from .op import call_llvm_intrin, min_value, max_value
+from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
+from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
+from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
+
+from . import ir_builder
+from . import ir_pass
diff --git a/python/tvm/tir/_ffi_api.py b/python/tvm/tir/_ffi_api.py
new file mode 100644 (file)
index 0000000..1b60b8c
--- /dev/null
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.tir"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("tir", __name__)
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
new file mode 100644 (file)
index 0000000..d0d01d7
--- /dev/null
@@ -0,0 +1,247 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Abstraction for array data structures."""
+from numbers import Integral
+import tvm._ffi
+
+from tvm._ffi.base import string_types
+from tvm.runtime import Object, convert
+from tvm.ir import PrimExpr
+from . import _ffi_api
+
+
+@tvm._ffi.register_object
+class Buffer(Object):
+    """Symbolic data buffer in TVM.
+
+    Buffer provide a way to represent data layout
+    specialization of data structure in TVM.
+
+    Do not construct directly, use :py:func:`~decl_buffer` instead.
+    See the documentation of :py:func:`decl_buffer` for more details.
+
+    See Also
+    --------
+    decl_buffer : Declare a buffer
+    """
+    READ = 1
+    WRITE = 2
+
+    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
+        """Get an access pointer to the head of buffer.
+
+        This is the recommended method to get buffer data
+        ptress when interacting with external functions.
+
+        Parameters
+        ----------
+        access_mask : int
+            The access pattern MASK. Indicate whether the
+            access will read or write to the data content.
+
+        ptr_type : str, optional
+            The data type of the result pointer. Do not specify
+            unless we want to cast pointer to specific type.
+
+        content_lanes: int, optional
+            The number of lanes for the data type. This value
+            is greater than one for vector types.
+
+        offset: Expr, optional
+            The offset of pointer. We can use it to offset by
+            the number of elements from the address of ptr.
+
+        Examples
+        --------
+        .. code-block:: python
+
+          # Get access ptr for read
+          buffer.access_ptr("r")
+          # Get access ptr for read/write with bitmask
+          buffer.access_ptr(Buffer.READ | Buffer.WRITE)
+          # Get access ptr for read/write with str flag
+          buffer.access_ptr("rw")
+          # Get access ptr for read with offset
+          buffer.access_ptr("r", offset = 100)
+        """
+        if isinstance(access_mask, string_types):
+            mask = 0
+            for value in access_mask:
+                if value == "r":
+                    mask = mask | Buffer.READ
+                elif value == "w":
+                    mask = mask | Buffer.WRITE
+                else:
+                    raise ValueError("Unknown access_mask %s" % access_mask)
+            access_mask = mask
+        offset = convert(offset)
+        return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type,
+                                        content_lanes, offset)
+
+    def vload(self, begin, dtype=None):
+        """Generate an Expr that loads dtype from begin index.
+
+        Parameters
+        ----------
+        begin : Array of Expr
+            The beginning index in unit of Buffer.dtype
+
+        dtype : str
+            The data type to be loaded,
+            can be vector type which have lanes that is multiple of Buffer.dtype
+
+        Returns
+        -------
+        load : Expr
+            The corresponding load expression.
+        """
+        begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
+        dtype = dtype if dtype else self.dtype
+        return _ffi_api.BufferVLoad(self, begin, dtype)
+
+    def vstore(self, begin, value):
+        """Generate a Stmt that store value into begin index.
+
+        Parameters
+        ----------
+        begin : Array of Expr
+            The beginning index in unit of Buffer.dtype
+
+        value : Expr
+            The value to be stored.
+
+        Returns
+        -------
+        store : Stmt
+            The corresponding store stmt.
+        """
+        begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
+        return _ffi_api.BufferVStore(self, begin, value)
+
+
+def decl_buffer(shape,
+                dtype=None,
+                name="buffer",
+                data=None,
+                strides=None,
+                elem_offset=None,
+                scope="",
+                data_alignment=-1,
+                offset_factor=0,
+                buffer_type=""):
+    """Declare a new symbolic buffer.
+
+    Normally buffer is created automatically during lower and build.
+    This is only needed if user want to specify their own buffer layout.
+
+    See the note below for detailed discussion on usage of buffer.
+
+    Parameters
+    ----------
+    shape : tuple of Expr
+        The shape of the buffer.
+
+    dtype : str, optional
+        The data type of the buffer.
+
+    name : str, optional
+        The name of the buffer.
+
+    data : Var, optional
+        The data pointer in the buffer.
+
+    strides: array of Expr
+        The stride of the buffer.
+
+    elem_offset: Expr, optional
+        The beginning offset of the array to data.
+        In terms of number of elements of dtype.
+
+    scope: str, optional
+        The storage scope of the buffer, if not global.
+        If scope equals empty string, it means it is global memory.
+
+    data_alignment: int, optional
+        The alignment of data pointer in bytes.
+        If -1 is passed, the alignment will be set to TVM's internal default.
+
+    offset_factor: int, optional
+        The factor of elem_offset field, when set,
+        elem_offset is required to be multiple of offset_factor.
+        If 0 is pssed, the alignment will be set to 1.
+        if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
+
+    buffer_type: str, optional, {"", "auto_broadcast"}
+        auto_broadcast buffer allows one to implement broadcast computation
+        without considering whether dimension size equals to one.
+        TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
+
+    Returns
+    -------
+    buffer : Buffer
+        The created buffer
+
+    Example
+    -------
+    Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation,
+
+    .. code-block:: python
+
+        m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
+        n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
+        o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
+        A = tvm.placeholder((m0, m1, m2), name='A')
+        B = tvm.placeholder((n0, n1, n2), name='B')
+        C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
+        Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
+        Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
+        s = tvm.create_schedule(C.op)
+        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
+        fadd(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
+    Note
+    ----
+    Buffer data structure reflects the DLTensor structure in dlpack.
+    While DLTensor data structure is very general, it is usually helpful
+    to create function that only handles specific case of data structure
+    and make compiled function benefit from it.
+
+    If user pass strides and elem_offset is passed as None
+    when constructing the function, then the function will be specialized
+    for the DLTensor that is compact and aligned.
+    If user pass a fully generic symbolic array to the strides,
+    then the resulting function becomes fully generic.
+    """
+    # pylint: disable=import-outside-toplevel
+    from .expr import Var
+
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    dtype = "float32" if dtype is None else dtype
+    strides = () if strides is None else strides
+    if offset_factor != 0 and elem_offset is None:
+        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
+        elem_offset = Var('%s_elem_offset' % name, shape_dtype)
+    if data is None:
+        data = Var(name, "handle")
+    return _ffi_api.Buffer(
+        data, dtype, shape, strides, elem_offset, name, scope,
+        data_alignment, offset_factor, buffer_type)
diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py
new file mode 100644 (file)
index 0000000..fd8c7a9
--- /dev/null
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Data layout."""
+import tvm._ffi
+
+from tvm.runtime import Object
+from . import _ffi_api
+
+@tvm._ffi.register_object
+class Layout(Object):
+    """Layout is composed of upper cases, lower cases and numbers,
+    where upper case indicates a primal axis and
+    the corresponding lower case with factor size indicates the subordinate axis.
+    For example, NCHW16c can describe a 5-D tensor of
+    [batch_size, channel, height, width, channel_block].
+    Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
+
+    See Also
+    --------
+    layout : Declare a layout
+    """
+    def __len__(self):
+        return _ffi_api.LayoutNdim(self)
+
+    def __contains__(self, axis):
+        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
+
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError("Layout index out of range")
+        return _ffi_api.LayoutGetItem(self, index)
+
+    def index_of(self, axis):
+        """Get the index of an axis
+
+        Parameters
+        ----------
+        axis : str
+            The axis name, need to be [a-z,A-Z]
+
+        Returns
+        -------
+        index : int
+            The index of the axis, -1 if not found.
+        """
+        return _ffi_api.LayoutIndexOf(self, axis)
+
+    def factor_of(self, axis):
+        """Get the factor size of the subordinate axis.
+
+        Parameters
+        ----------
+        axis : str
+            The axis name, need to be [a-z,A-Z]
+
+        Returns
+        -------
+        factor : int
+            the size of the subordinate-axis of axis (if axis is a primal-axis),
+            or the size of axis itself (if axis is a subordinate-axis).
+            Return -1 if axis is not in the layout.
+        """
+        return _ffi_api.LayoutFactorOf(self, axis)
+
+
+@tvm._ffi.register_object
+class BijectiveLayout(Object):
+    """Bijective mapping for two layouts (src-layout and dst-layout).
+    It provides shape and index conversion between each other.
+
+    Do not construct directly, use :any:`bijective_layout` instead.
+    See the documentation of :any:`bijective_layout` for more details.
+
+    Parameters
+    ----------
+    src_layout : str or Layout
+        source layout.
+
+    dst_layout : str or Layout
+        destination layout.
+
+    See Also
+    --------
+    bijective_layout : Declare a layout
+    """
+    def forward_index(self, index):
+        """Given the indices of the src-layout, infer the dst index.
+
+        Parameters
+        ----------
+        index: Array of Expr
+            The indices in src-layout.
+
+        Returns
+        -------
+        dst_index: Array of Expr
+            The inferred indices in dst-layout.
+        """
+        return _ffi_api.BijectiveLayoutForwardIndex(self, index)
+
+    def backward_index(self, index):
+        """Given the indices of the dst-layout, infer the src index.
+
+        Parameters
+        ----------
+        index: Array of Expr
+            The indices in dst-layout.
+
+        Returns
+        -------
+        src_index: Array of Expr
+            The inferred indices in src-layout.
+        """
+        return _ffi_api.BijectiveLayoutBackwardIndex(self, index)
+
+    def forward_shape(self, shape):
+        """Given the shape of the src-layout, infer the dst shape.
+
+        Parameters
+        ----------
+        shape: Array of Expr
+            The shape in src-layout.
+
+        Returns
+        -------
+        dst_shape: Array of Expr
+            The inferred shape in dst-layout.
+        """
+        return _ffi_api.BijectiveLayoutForwardShape(self, shape)
+
+    def backward_shape(self, shape):
+        """Given the shape of the dst-layout, infer the src shape.
+
+        Parameters
+        ----------
+        shape: Array of Expr
+            The shape in dst-layout.
+
+        Returns
+        -------
+        src_shape: Array of Expr
+            The inferred shape in src-layout.
+        """
+        return _ffi_api.BijectiveLayoutBackwardShape(self, shape)
+
+
+def layout(layout_str):
+    """Create a layout node from a string.
+
+    Parameters
+    ----------
+    layout_str : str
+        A layout representation is composed of upper cases, lower cases and numbers,
+        where upper case indicates a primal axis and
+        the corresponding lower case with factor size indicates the subordinate axis.
+        For example, NCHW16c can describe a 5-D tensor of
+        [batch_size, channel, height, width, channel_block].
+        Here subordinate axis channel_block=16 is the factor size of
+        the primal axis C (channel).
+
+    Returns
+    -------
+    layout : Layout
+        The created layout
+    """
+    return _ffi_api.Layout(layout_str)
+
+
+def bijective_layout(src_layout, dst_layout):
+    """Create a bijective layout mapping.
+
+    Parameters
+    ----------
+    src_layout : str or Layout
+        source layout.
+
+    dst_layout : str or Layout
+        destination layout.
+
+    Returns
+    -------
+    bijective_layout : BijectiveLayout
+        The created bijective layout
+    """
+    if isinstance(src_layout, str):
+        src_layout = layout(src_layout)
+    if isinstance(dst_layout, str):
+        dst_layout = layout(dst_layout)
+    return _ffi_api.BijectiveLayout(src_layout, dst_layout)
similarity index 77%
rename from python/tvm/expr.py
rename to python/tvm/tir/expr.py
index 1ff0697..92d6fbe 100644 (file)
@@ -27,16 +27,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
 
   x = tvm.var("n")
   y = x + 2
-  assert(isinstance(y, tvm.expr.Add))
+  assert(isinstance(y, tvm.tir.Add))
   assert(y.a == x)
 """
-# pylint: disable=missing-docstring
 import tvm._ffi
-from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode, const
 
-from . import make as _make
+from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode, const
+from tvm.ir import PrimExpr
+import tvm.ir._ffi_api
 from . import generic as _generic
-from . import _api_internal
+from . import _ffi_api
 
 
 def div_ambiguity_error():
@@ -45,6 +45,7 @@ def div_ambiguity_error():
         "please call div, indexdiv/indexmod, floordiv/floormod " +
         " or truncdiv/truncmod directly to avoid ambiguity in the code.")
 
+
 def _dtype_is_int(value):
     if isinstance(value, int):
         return True
@@ -53,6 +54,7 @@ def _dtype_is_int(value):
 
 
 class ExprOp(object):
+    """Operator overloading for Expr like expressions."""
     def __add__(self, other):
         return _generic.add(self, other)
 
@@ -98,44 +100,44 @@ class ExprOp(object):
         return _generic.floordiv(other, self)
 
     def __mod__(self, other):
-        return _make._OpFloorMod(self, other)
+        return _ffi_api._OpFloorMod(self, other)
 
     def __neg__(self):
         neg_one = const(-1, self.dtype)
         return self.__mul__(neg_one)
 
     def __lshift__(self, other):
-        return _make.left_shift(self, other)
+        return _ffi_api.left_shift(self, other)
 
     def __rshift__(self, other):
-        return _make.right_shift(self, other)
+        return _ffi_api.right_shift(self, other)
 
     def __and__(self, other):
-        return _make.bitwise_and(self, other)
+        return _ffi_api.bitwise_and(self, other)
 
     def __rand__(self, other):
-        return _make.bitwise_and(other, self)
+        return _ffi_api.bitwise_and(other, self)
 
     def __or__(self, other):
-        return _make.bitwise_or(self, other)
+        return _ffi_api.bitwise_or(self, other)
 
     def __ror__(self, other):
-        return _make.bitwise_or(other, self)
+        return _ffi_api.bitwise_or(other, self)
 
     def __xor__(self, other):
-        return _make.bitwise_xor(self, other)
+        return _ffi_api.bitwise_xor(self, other)
 
     def __rxor__(self, other):
-        return _make.bitwise_xor(other, self)
+        return _ffi_api.bitwise_xor(other, self)
 
     def __invert__(self):
-        return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
+        return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
 
     def __lt__(self, other):
-        return _make._OpLT(self, other)
+        return _ffi_api._OpLT(self, other)
 
     def __le__(self, other):
-        return _make._OpLE(self, other)
+        return _ffi_api._OpLE(self, other)
 
     def __eq__(self, other):
         return EqualOp(self, other)
@@ -144,10 +146,10 @@ class ExprOp(object):
         return NotEqualOp(self, other)
 
     def __gt__(self, other):
-        return _make._OpGT(self, other)
+        return _ffi_api._OpGT(self, other)
 
     def __ge__(self, other):
-        return _make._OpGE(self, other)
+        return _ffi_api._OpGE(self, other)
 
     def __nonzero__(self):
         raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
@@ -161,15 +163,15 @@ class ExprOp(object):
 
         Parameters
         ----------
-        other : Expr
+        other : PrimExpr
             The other expression
 
         Returns
         -------
-        ret : Expr
+        ret : PrimExpr
             The equality expression.
         """
-        return _make._OpEQ(self, other)
+        return _ffi_api._OpEQ(self, other)
 
     def astype(self, dtype):
         """Cast the expression to other type.
@@ -181,7 +183,7 @@ class ExprOp(object):
 
         Returns
         -------
-        expr : Expr
+        expr : PrimExpr
             Expression with new type
         """
         return _generic.cast(self, dtype)
@@ -195,10 +197,10 @@ class EqualOp(ObjectGeneric, ExprOp):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         Left operand.
 
-    b : Expr
+    b : PrimExpr
         Right operand.
     """
     # This class is not manipulated by C++. So use python's identity check function is sufficient
@@ -216,7 +218,7 @@ class EqualOp(ObjectGeneric, ExprOp):
 
     def asobject(self):
         """Convert object."""
-        return _make._OpEQ(self.a, self.b)
+        return _ffi_api._OpEQ(self.a, self.b)
 
 
 class NotEqualOp(ObjectGeneric, ExprOp):
@@ -227,10 +229,10 @@ class NotEqualOp(ObjectGeneric, ExprOp):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         Left operand.
 
-    b : Expr
+    b : PrimExpr
         Right operand.
     """
     # This class is not manipulated by C++. So use python's identity check function is sufficient
@@ -248,30 +250,30 @@ class NotEqualOp(ObjectGeneric, ExprOp):
 
     def asobject(self):
         """Convert object."""
-        return _make._OpNE(self.a, self.b)
+        return _ffi_api._OpNE(self.a, self.b)
 
 
-class PrimExpr(ExprOp, Object):
-    """Base class of all tvm Expressions"""
+class PrimExprWithOp(ExprOp, PrimExpr):
+    """Helper base class to inherit from PrimExpr."""
     # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
     # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
-    __hash__ = Object.__hash__
+    __hash__ = PrimExpr.__hash__
 
 
-class ConstExpr(PrimExpr):
+class ConstExpr(PrimExprWithOp):
     pass
 
-class BinaryOpExpr(PrimExpr):
+class BinaryOpExpr(PrimExprWithOp):
     pass
 
-class CmpExpr(PrimExpr):
+class CmpExpr(PrimExprWithOp):
     pass
 
-class LogicalExpr(PrimExpr):
+class LogicalExpr(PrimExprWithOp):
     pass
 
 @tvm._ffi.register_object("Variable")
-class Var(PrimExpr):
+class Var(PrimExprWithOp):
     """Symbolic variable.
 
     Parameters
@@ -279,18 +281,18 @@ class Var(PrimExpr):
     name : str
         The name
 
-    dtype : int
+    dtype : str
         The data type
     """
     def __init__(self, name, dtype):
         self.__init_handle_by_constructor__(
-            _api_internal._Var, name, dtype)
+            _ffi_api.Var, name, dtype)
 
 
 @tvm._ffi.register_object
 class SizeVar(Var):
     """Symbolic variable to represent a tensor index size
-       which is greater or equal to zero
+       which is greater or equal to zero.
 
     Parameters
     ----------
@@ -303,11 +305,34 @@ class SizeVar(Var):
     # pylint: disable=super-init-not-called
     def __init__(self, name, dtype):
         self.__init_handle_by_constructor__(
-            _api_internal._SizeVar, name, dtype)
+            _ffi_api.SizeVar, name, dtype)
+
+
+@tvm._ffi.register_object
+class CommReducer(Object):
+    """Communicative reduce operator
+
+    Parameters
+    ----------
+    lhs : List[Var]
+       The left arguments of the reducer.
+
+    rhs : List[Var]
+       The right arguments of the reducer.
+
+    result : List[PrimExpr]
+       The reduction results.
+
+    identity_element : List[PrimExpr]
+       The identity elements.
+    """
+    def __init__(self, lhs, rhs, result, identity_element):
+        self.__init_handle_by_constructor__(
+            _ffi_api.CommReducer, lhs, rhs, result, identity_element)
 
 
 @tvm._ffi.register_object
-class Reduce(PrimExpr):
+class Reduce(PrimExprWithOp):
     """Reduce node.
 
     Parameters
@@ -321,7 +346,7 @@ class Reduce(PrimExpr):
     rdom : list of IterVar
         The iteration domain
 
-    condition : Expr
+    condition : PrimExpr
         The reduce condition.
 
     value_index : int
@@ -329,7 +354,7 @@ class Reduce(PrimExpr):
     """
     def __init__(self, combiner, src, rdom, condition, value_index):
         self.__init_handle_by_constructor__(
-            _make.Reduce, combiner, src, rdom,
+            _ffi_api.Reduce, combiner, src, rdom,
             condition, value_index)
 
 
@@ -347,7 +372,7 @@ class FloatImm(ConstExpr):
     """
     def __init__(self, dtype, value):
         self.__init_handle_by_constructor__(
-            _make.FloatImm, dtype, value)
+            tvm.ir._ffi_api.FloatImm, dtype, value)
 
 @tvm._ffi.register_object
 class IntImm(ConstExpr):
@@ -363,7 +388,7 @@ class IntImm(ConstExpr):
     """
     def __init__(self, dtype, value):
         self.__init_handle_by_constructor__(
-            _make.IntImm, dtype, value)
+            tvm.ir._ffi_api.IntImm, dtype, value)
 
     def __int__(self):
         return self.value
@@ -380,7 +405,7 @@ class StringImm(ConstExpr):
     """
     def __init__(self, value):
         self.__init_handle_by_constructor__(
-            _make.StringImm, value)
+            _ffi_api.StringImm, value)
 
     def __eq__(self, other):
         if isinstance(other, ConstExpr):
@@ -394,7 +419,7 @@ class StringImm(ConstExpr):
 
 
 @tvm._ffi.register_object
-class Cast(PrimExpr):
+class Cast(PrimExprWithOp):
     """Cast expression.
 
     Parameters
@@ -402,12 +427,12 @@ class Cast(PrimExpr):
     dtype : str
         The data type
 
-    value : Expr
+    value : PrimExpr
         The value of the function.
     """
     def __init__(self, dtype, value):
         self.__init_handle_by_constructor__(
-            _make.Cast, dtype, value)
+            _ffi_api.Cast, dtype, value)
 
 
 @tvm._ffi.register_object
@@ -416,15 +441,15 @@ class Add(BinaryOpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.Add, a, b)
+            _ffi_api.Add, a, b)
 
 
 @tvm._ffi.register_object
@@ -433,15 +458,15 @@ class Sub(BinaryOpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.Sub, a, b)
+            _ffi_api.Sub, a, b)
 
 
 @tvm._ffi.register_object
@@ -450,15 +475,15 @@ class Mul(BinaryOpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.Mul, a, b)
+            _ffi_api.Mul, a, b)
 
 
 @tvm._ffi.register_object
@@ -467,15 +492,15 @@ class Div(BinaryOpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.Div, a, b)
+            _ffi_api.Div, a, b)
 
 
 @tvm._ffi.register_object
@@ -484,15 +509,15 @@ class Mod(BinaryOpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.Mod, a, b)
+            _ffi_api.Mod, a, b)
 
 
 @tvm._ffi.register_object
@@ -501,15 +526,15 @@ class FloorDiv(BinaryOpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.FloorDiv, a, b)
+            _ffi_api.FloorDiv, a, b)
 
 
 @tvm._ffi.register_object
@@ -518,15 +543,15 @@ class FloorMod(BinaryOpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.FloorMod, a, b)
+            _ffi_api.FloorMod, a, b)
 
 
 @tvm._ffi.register_object
@@ -535,15 +560,15 @@ class Min(BinaryOpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.Min, a, b)
+            _ffi_api.Min, a, b)
 
 
 @tvm._ffi.register_object
@@ -552,15 +577,15 @@ class Max(BinaryOpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.Max, a, b)
+            _ffi_api.Max, a, b)
 
 
 @tvm._ffi.register_object
@@ -569,15 +594,15 @@ class EQ(CmpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.EQ, a, b)
+            _ffi_api.EQ, a, b)
 
 
 @tvm._ffi.register_object
@@ -586,15 +611,15 @@ class NE(CmpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.NE, a, b)
+            _ffi_api.NE, a, b)
 
 
 @tvm._ffi.register_object
@@ -603,15 +628,15 @@ class LT(CmpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.LT, a, b)
+            _ffi_api.LT, a, b)
 
 
 @tvm._ffi.register_object
@@ -620,15 +645,15 @@ class LE(CmpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.LE, a, b)
+            _ffi_api.LE, a, b)
 
 
 @tvm._ffi.register_object
@@ -637,15 +662,15 @@ class GT(CmpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.GT, a, b)
+            _ffi_api.GT, a, b)
 
 
 @tvm._ffi.register_object
@@ -654,15 +679,15 @@ class GE(CmpExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.GE, a, b)
+            _ffi_api.GE, a, b)
 
 
 @tvm._ffi.register_object
@@ -671,15 +696,15 @@ class And(LogicalExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.And, a, b)
+            _ffi_api.And, a, b)
 
 
 @tvm._ffi.register_object
@@ -688,15 +713,15 @@ class Or(LogicalExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The left hand operand.
 
-    b : Expr
+    b : PrimExpr
         The right hand operand.
     """
     def __init__(self, a, b):
         self.__init_handle_by_constructor__(
-            _make.Or, a, b)
+            _ffi_api.Or, a, b)
 
 
 @tvm._ffi.register_object
@@ -705,16 +730,16 @@ class Not(LogicalExpr):
 
     Parameters
     ----------
-    a : Expr
+    a : PrimExpr
         The input value
     """
     def __init__(self, a):
         self.__init_handle_by_constructor__(
-            _make.Not, a)
+            _ffi_api.Not, a)
 
 
 @tvm._ffi.register_object
-class Select(PrimExpr):
+class Select(PrimExprWithOp):
     """Select node.
 
     Note
@@ -726,23 +751,23 @@ class Select(PrimExpr):
 
     Parameters
     ----------
-    condition : Expr
+    condition : PrimExpr
         The condition expression.
 
-    true_value : Expr
+    true_value : PrimExpr
         The value to take when condition is true.
 
-    false_value : Expr
+    false_value : PrimExpr
         The value to take when condition is false.
 
     """
     def __init__(self, condition, true_value, false_value):
         self.__init_handle_by_constructor__(
-            _make.Select, condition, true_value, false_value)
+            _ffi_api.Select, condition, true_value, false_value)
 
 
 @tvm._ffi.register_object
-class Load(PrimExpr):
+class Load(PrimExprWithOp):
     """Load node.
 
     Parameters
@@ -753,24 +778,25 @@ class Load(PrimExpr):
     buffer_var : Var
         The buffer variable in the load expression.
 
-    index : Expr
+    index : PrimExpr
         The index in the load.
 
-    predicate : Expr
+    predicate : PrimExpr
         The load predicate.
     """
-    def __init__(self, dtype, buffer_var, index, predicate):
+    def __init__(self, dtype, buffer_var, index, predicate=None):
+        args = [] if predicate is None else [predicate]
         self.__init_handle_by_constructor__(
-            _make.Load, dtype, buffer_var, index, predicate)
+            _ffi_api.Load, dtype, buffer_var, index, *args)
 
 
 @tvm._ffi.register_object
-class Ramp(PrimExpr):
+class Ramp(PrimExprWithOp):
     """Ramp node.
 
     Parameters
     ----------
-    base : Expr
+    base : PrimExpr
         The base expression.
 
     stride : ramp stride
@@ -781,16 +807,16 @@ class Ramp(PrimExpr):
     """
     def __init__(self, base, stride, lanes):
         self.__init_handle_by_constructor__(
-            _make.Ramp, base, stride, lanes)
+            _ffi_api.Ramp, base, stride, lanes)
 
 
 @tvm._ffi.register_object
-class Broadcast(PrimExpr):
+class Broadcast(PrimExprWithOp):
     """Broadcast node.
 
     Parameters
     ----------
-    value : Expr
+    value : PrimExpr
         The value of the expression.
 
     lanes : int
@@ -798,11 +824,11 @@ class Broadcast(PrimExpr):
     """
     def __init__(self, value, lanes):
         self.__init_handle_by_constructor__(
-            _make.Broadcast, value, lanes)
+            _ffi_api.Broadcast, value, lanes)
 
 
 @tvm._ffi.register_object
-class Shuffle(PrimExpr):
+class Shuffle(PrimExprWithOp):
     """Shuffle node.
 
     Parameters
@@ -815,11 +841,11 @@ class Shuffle(PrimExpr):
     """
     def __init__(self, vectors, indices):
         self.__init_handle_by_constructor__(
-            _make.Shuffle, vectors, indices)
+            _ffi_api.Shuffle, vectors, indices)
 
 
 @tvm._ffi.register_object
-class Call(PrimExpr):
+class Call(PrimExprWithOp):
     """Call node.
 
     Parameters
@@ -850,11 +876,11 @@ class Call(PrimExpr):
     PureIntrinsic = 5
     def __init__(self, dtype, name, args, call_type, func, value_index):
         self.__init_handle_by_constructor__(
-            _make.Call, dtype, name, args, call_type, func, value_index)
+            _ffi_api.Call, dtype, name, args, call_type, func, value_index)
 
 
 @tvm._ffi.register_object
-class Let(PrimExpr):
+class Let(PrimExprWithOp):
     """Let node.
 
     Parameters
@@ -862,12 +888,12 @@ class Let(PrimExpr):
     var : Var
         The variable in the binding.
 
-    value : Expr
+    value : PrimExpr
         The value in to be binded.
 
-    body : Expr
+    body : PrimExpr
         The body expression.
     """
     def __init__(self, var, value, body):
         self.__init_handle_by_constructor__(
-            _make.Let, var, value, body)
+            _ffi_api.Let, var, value, body)
diff --git a/python/tvm/tir/generic.py b/python/tvm/tir/generic.py
new file mode 100644 (file)
index 0000000..8a9cf8e
--- /dev/null
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Generic opertors in TVM.
+We follow the numpy naming convention for this interface
+(e.g., tvm.generic.multitply ~ numpy.multiply).
+The default implementation is used by tvm.ExprOp.
+"""
+# pylint: disable=unused-argument
+from . import _ffi_api
+
+# Operator precedence used when overloading.
+__op_priority__ = 0
+
+
+def add(lhs, rhs):
+    """Generic add operator.
+
+    Parameters
+    ----------
+    lhs : object
+        The left operand.
+    rhs : object
+        The right operand.
+
+    Returns
+    -------
+    op : tvm.Expr
+        The result Expr of add operaton.
+    """
+    return _ffi_api._OpAdd(lhs, rhs)
+
+
+def subtract(lhs, rhs):
+    """Generic subtract operator.
+
+    Parameters
+    ----------
+    lhs : object
+        The left operand.
+    rhs : object
+        The right operand.
+
+    Returns
+    -------
+    op : tvm.Expr
+        The result Expr of subtract operaton.
+    """
+    return _ffi_api._OpSub(lhs, rhs)
+
+
+def multiply(lhs, rhs):
+    """Generic multiply operator.
+
+    Parameters
+    ----------
+    lhs : object
+        The left operand.
+    rhs : object
+        The right operand.
+
+    Returns
+    -------
+    op : tvm.Expr
+        The result Expr of multiply operaton.
+    """
+    return _ffi_api._OpMul(lhs, rhs)
+
+def divide(lhs, rhs):
+    """Generic divide operator.
+
+    Parameters
+    ----------
+    lhs : object
+        The left operand.
+    rhs : object
+        The right operand.
+
+    Returns
+    -------
+    op : tvm.Expr
+        The result Expr of divide operaton.
+    """
+    return _ffi_api._OpDiv(lhs, rhs)
+
+def floordiv(lhs, rhs):
+    """Generic floordiv operator.
+
+    Parameters
+    ----------
+    lhs : object
+        The left operand.
+    rhs : object
+        The right operand.
+
+    Returns
+    -------
+    op : tvm.Expr
+        The result Expr of divide operaton.
+    """
+    return _ffi_api._OpFloorDiv(lhs, rhs)
+
+
+def cast(src, dtype):
+    """Generic cast operator.
+
+    Parameters
+    ----------
+    src : object
+        The source operand.
+
+    Returns
+    -------
+    op : tvm.Expr
+        The result Expr of divide operaton.
+    """
+    return _ffi_api._cast(dtype, src)
similarity index 89%
rename from python/tvm/ir_builder.py
rename to python/tvm/tir/ir_builder.py
index 4cc7f4f..b56e153 100644 (file)
 # under the License.
 """Developer API of IR node builder make function."""
 from tvm._ffi.base import string_types
-from tvm.runtime import ObjectGeneric, DataType
+from tvm.runtime import ObjectGeneric, DataType, convert, const
 from tvm.ir import container as _container
 
-from . import api as _api
 from . import stmt as _stmt
 from . import expr as _expr
-from . import make as _make
 from . import ir_pass as _pass
-from .expr import Call as _Call
+
 
 class WithScope(object):
     """Auxiliary scope  with"""
@@ -53,7 +51,7 @@ class BufferVar(ObjectGeneric):
     .. code-block:: python
 
         # The following code generate IR for x[0] = x[
-        ib = tvm.ir_builder.create()
+        ib = tvm.tir.ir_builder.create()
         x = ib.pointer("float32")
         x[0] = x[10] + 1
 
@@ -78,19 +76,19 @@ class BufferVar(ObjectGeneric):
     def __getitem__(self, index):
         t = DataType(self._content_type)
         if t.lanes > 1:
-            index = _make.Ramp(index * t.lanes, 1, t.lanes)
-        return _make.Load(self._content_type, self._buffer_var, index)
+            index = _expr.Ramp(index * t.lanes, 1, t.lanes)
+        return _expr.Load(self._content_type, self._buffer_var, index)
 
     def __setitem__(self, index, value):
-        value = _api.convert(value)
+        value = convert(value)
         if value.dtype != self._content_type:
             raise ValueError(
                 "data type does not match content type %s vs %s" % (
                     value.dtype, self._content_type))
         t = DataType(self._content_type)
         if t.lanes > 1:
-            index = _make.Ramp(index * t.lanes, 1, t.lanes)
-        self._builder.emit(_make.Store(self._buffer_var, value, index))
+            index = _expr.Ramp(index * t.lanes, 1, t.lanes)
+        self._builder.emit(_stmt.Store(self._buffer_var, value, index))
 
 
 class IRBuilder(object):
@@ -117,7 +115,7 @@ class IRBuilder(object):
         """Pop sequence from stack"""
         seq = self._seq_stack.pop()
         if not seq or callable(seq[-1]):
-            seq.append(_make.Evaluate(0))
+            seq.append(_stmt.Evaluate(0))
         seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x)))
         ret_seq = [seq[-1]]
 
@@ -138,7 +136,7 @@ class IRBuilder(object):
            The statement to be emitted or callable that build stmt given body.
         """
         if isinstance(stmt, _expr.Call):
-            stmt = _make.Evaluate(stmt)
+            stmt = _stmt.Evaluate(stmt)
         assert isinstance(stmt, _stmt.Stmt) or callable(stmt)
         self._seq_stack[-1].append(stmt)
 
@@ -167,10 +165,10 @@ class IRBuilder(object):
             x[i] = x[i - 1] + 1
         """
         if isinstance(node, string_types):
-            node = _make.StringImm(node)
+            node = _expr.StringImm(node)
         if isinstance(value, string_types):
-            value = _make.StringImm(value)
-        self.emit(lambda x: _make.AttrStmt(node, attr_key, value, x))
+            value = _expr.StringImm(value)
+        self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
 
     def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
         """Create a for iteration scope.
@@ -211,7 +209,7 @@ class IRBuilder(object):
             name = chr(ord(name) + self.nidx) if self.nidx < 3 else name + "_" + str(self.nidx - 3)
             self.nidx += 1
         self._seq_stack.append([])
-        loop_var = _api.var(name, dtype=dtype)
+        loop_var = _expr.Var(name, dtype=dtype)
         extent = end if begin == 0 else _pass.Simplify(end - begin)
         def _exit_cb():
             if for_type == "serial":
@@ -224,7 +222,7 @@ class IRBuilder(object):
                 for_type_id = 3
             else:
                 raise ValueError("Unknown for_type")
-            self.emit(_make.For(
+            self.emit(_stmt.For(
                 loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
         return WithScope(loop_var, _exit_cb)
 
@@ -253,7 +251,7 @@ class IRBuilder(object):
         """
         self._seq_stack.append([])
         def _exit_cb():
-            self.emit(_make.IfThenElse(cond, self._pop_seq(), None))
+            self.emit(_stmt.IfThenElse(cond, self._pop_seq(), None))
         return WithScope(None, _exit_cb)
 
     def else_scope(self):
@@ -286,7 +284,7 @@ class IRBuilder(object):
         self._seq_stack[-1].pop()
         self._seq_stack.append([])
         def _exit_cb():
-            self.emit(_make.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
+            self.emit(_stmt.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
         return WithScope(None, _exit_cb)
 
     def new_scope(self):
@@ -326,13 +324,13 @@ class IRBuilder(object):
         buffer : BufferVar
             The buffer var representing the buffer.
         """
-        buffer_var = _api.var(name, dtype="handle")
+        buffer_var = _expr.Var(name, dtype="handle")
         if not isinstance(shape, (list, tuple, _container.Array)):
             shape = [shape]
         if scope:
             self.scope_attr(buffer_var, "storage_scope", scope)
-        self.emit(lambda x: _make.Allocate(
-            buffer_var, dtype, shape, _api.const(1, dtype="uint1"), x))
+        self.emit(lambda x: _stmt.Allocate(
+            buffer_var, dtype, shape, const(1, dtype="uint1"), x))
         return BufferVar(self, buffer_var, dtype)
 
     def pointer(self, content_type, name="ptr"):
@@ -351,7 +349,7 @@ class IRBuilder(object):
         ptr : BufferVar
             The buffer var representing the buffer.
         """
-        buffer_var = _api.var(name, dtype="handle")
+        buffer_var = _expr.Var(name, dtype="handle")
         return BufferVar(self, buffer_var, content_type)
 
     def buffer_ptr(self, buf):
@@ -380,7 +378,8 @@ class IRBuilder(object):
         expr : Expr
             The expression will likely tag.
         """
-        return _make.Call(expr.dtype, "likely", [expr], _Call.PureIntrinsic, None, 0)
+        return _expr.Call(expr.dtype, "likely", [expr],
+                          _expr.Call.PureIntrinsic, None, 0)
 
     def get(self):
         """Return the builded IR.
similarity index 96%
rename from python/tvm/ir_pass.py
rename to python/tvm/tir/ir_pass.py
index 9d7f340..239b1fb 100644 (file)
@@ -25,4 +25,4 @@ You can read "include/tvm/tir/ir_pass.h" for the function signature and
 """
 import tvm._ffi
 
-tvm._ffi._init_api("tvm.ir_pass")
+tvm._ffi._init_api("tvm.ir_pass", __name__)
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
new file mode 100644 (file)
index 0000000..a10fe69
--- /dev/null
@@ -0,0 +1,782 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""Operators used in TIR expression."""
+import tvm._ffi
+from tvm.runtime import convert, const
+from tvm.schedule import Buffer
+
+from .expr import Call
+from . import _ffi_api
+
+
+def _pack_buffer(buf):
+    """Build intrinsics that packs the buffer.
+    """
+    assert buf.shape
+    shape = Call("handle", "tvm_stack_make_shape", buf.shape,
+                 Call.Intrinsic, None, 0)
+    strides = Call("handle", "tvm_stack_make_shape", buf.strides,
+                   Call.Intrinsic, None, 0) if buf.strides else 0
+    pack_args = [buf.data,
+                 shape,
+                 strides,
+                 len(buf.shape),
+                 const(0, dtype=buf.dtype),
+                 buf.elem_offset]
+    return Call("handle", "tvm_stack_make_array",
+                pack_args, Call.Intrinsic, None, 0)
+
+def call_packed(*args):
+    """Build expression by call an external packed function.
+
+    The argument to packed function can be Expr or Buffer.
+    The argument is the corresponding POD type when Expr is presented.
+
+    When the argument is Buffer, the corresponding PackedFunc
+    will recieve an TVMArrayHandle whose content is valid during the callback period.
+    If the PackedFunc is a python callback, then the corresponding argument is NDArray.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    tvm.extern : Create tensor with extern function call.
+    """
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    return Call(
+        "int32", "tvm_call_packed", call_args, Call.Intrinsic, None, 0)
+
+
+def call_pure_intrin(dtype, func_name, *args):
+    """Build expression by calling a pure intrinsic function.
+
+    Intrinsics can be overloaded with multiple data types via
+    the intrinsic translation rule.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    func_name: str
+        The intrinsic function name.
+
+    args : list
+        Positional arguments.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    args = convert(args)
+    return Call(
+        dtype, func_name, convert(args), Call.PureIntrinsic, None, 0)
+
+
+def call_intrin(dtype, func_name, *args):
+    """Build expression by calling an intrinsic function.
+
+    Intrinsics can be overloaded with multiple data types via
+    the intrinsic translation rule.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    func_name: str
+        The intrinsic function name.
+
+    args : list
+        Positional arguments.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    args = convert(args)
+    return Call(
+        dtype, func_name, convert(args), Call.Intrinsic, None, 0)
+
+
+def call_pure_extern(dtype, func_name, *args):
+    """Build expression by calling a pure extern function.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    func_name: str
+        The extern function name.
+
+    args : list
+        Positional arguments.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return Call(
+        dtype, func_name, convert(args), Call.PureExtern, None, 0)
+
+
+def call_extern(dtype, func_name, *args):
+    """Build expression by calling a extern function.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    func_name: str
+        The extern function name.
+
+    args : list
+        Positional arguments.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return Call(
+        dtype, func_name, convert(args), Call.Extern, None, 0)
+
+
+def call_llvm_intrin(dtype, name, *args):
+    """Build expression by calling an llvm intrinsic function
+
+    Parameters
+    ----------
+    dtype : str
+       The data type of the result.
+
+    name : str
+       The name of the llvm intrinsic function.
+
+    args : list
+       Poistional arguments.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm.target import codegen
+    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
+
+
+@tvm._ffi.register_func("tvm.default_trace_action")
+def _tvm_default_trace_action(*args):
+    print(list(args))
+
+def trace(args, trace_action="tvm.default_trace_action"):
+    """Trace tensor data at the runtime.
+
+    The trace function allows to trace specific tensor at the
+    runtime. The tracing value should come as last argument.
+    The trace action should be specified, by default
+    tvm.default_trace_action is used.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffers.
+        Positional arguments.
+
+    trace_action : str.
+        The name of the trace action.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    tvm.tir.call_packed : Creates packed function.
+    """
+    if not isinstance(args, list):
+        raise Exception("tvm.trace consumes the args as list type")
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    call_args.insert(0, trace_action)
+    return tvm.tir.Call(
+        args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic, None, 0)
+
+
+
+def min_value(dtype):
+    """minimum value of dtype
+
+    Parameters
+    ----------
+    dtype : str
+        The data type.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The minimum value of dtype.
+    """
+    return _ffi_api.min_value(dtype)
+
+
+def max_value(dtype):
+    """maximum value of dtype
+
+    Parameters
+    ----------
+    dtype : str
+        The data type.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The maximum value of dtype.
+    """
+    return _ffi_api.max_value(dtype)
+
+
+def exp(x):
+    """Take exponetial of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "exp", x)
+
+
+def erf(x):
+    """Take gauss error function of the input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "erf", x)
+
+
+def tanh(x):
+    """Take hyperbolic tanh of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "tanh", x)
+
+
+def sigmoid(x):
+    """Quick function to get sigmoid
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "sigmoid", x)
+
+
+def log(x):
+    """Take log of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "log", x)
+
+def cos(x):
+    """Take cos of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "cos", x)
+
+def sin(x):
+    """Take sin of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "sin", x)
+
+def atan(x):
+    """Take atan of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "atan", x)
+
+def sqrt(x):
+    """Take square root of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "sqrt", x)
+
+
+def rsqrt(x):
+    """Take reciprocal of square root of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "rsqrt", x)
+
+
+def floor(x):
+    """Take floor of float input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.floor(x)
+
+
+def ceil(x):
+    """Take ceil of float input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.ceil(x)
+
+
+def trunc(x):
+    """Get truncated value of the input.
+
+    The truncated value of the scalar x is the
+    nearest integer i which is closer to zero than x is.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.trunc(x)
+
+
+def abs(x):
+    """Get absolute value of the input element-wise.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.abs(x)
+
+
+def round(x):
+    """Round elements of the array to the nearest integer.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.round(x)
+
+
+def nearbyint(x):
+    """Round elements of the array to the nearest integer.
+    This intrinsic uses llvm.nearbyint instead of llvm.round
+    which is faster but will results different from tvm.round.
+    Notably nearbyint rounds according to the rounding mode,
+    whereas tvm.round (llvm.round) ignores that.
+    For differences between the two see:
+    https://en.cppreference.com/w/cpp/numeric/math/round
+    https://en.cppreference.com/w/cpp/numeric/math/nearbyint
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.nearbyint(x)
+
+
+def isnan(x):
+    """Check if input value is Nan.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.isnan(x)
+
+
+def power(x, y):
+    """x power y
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    y : PrimExpr
+        The exponent
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return _ffi_api._OpPow(convert(x), convert(y))
+
+
+def popcount(x):
+    """Count the number of set bits in input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "popcount", x)
+
+def fmod(x, y):
+    """Return the remainder of x divided by y with the same sign as x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+    y : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "fmod", x, y)
+
+
+def if_then_else(cond, t, f):
+    """Conditional selection expression.
+
+    Parameters
+    ----------
+    cond : PrimExpr
+        The condition
+
+    t : PrimExpr
+        The result expression if cond is true.
+
+    f : PrimExpr
+        The result expression if cond is false.
+
+    Returns
+    -------
+    result : Node
+        The result of conditional expression.
+
+    Note
+    ----
+    Unlike Select, if_then_else will not execute
+    the branch that does not satisfy the condition.
+    You can use it to guard against out of bound access.
+    Unlike Select, if_then_else cannot be vectorized
+    if some lanes in the vector have different conditions.
+    """
+    return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f))
+
+
+def div(a, b):
+    """Compute a / b as in C/C++ semantics.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand, known to be non-negative.
+
+    b : PrimExpr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    Note
+    ----
+    When operands are integers, returns truncdiv(a, b).
+    """
+    return _ffi_api._OpDiv(a, b)
+
+
+def indexdiv(a, b):
+    """Compute floor(a / b) where a and b are non-negative.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand, known to be non-negative.
+
+    b : PrimExpr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+
+    Note
+    ----
+    Use this function to split non-negative indices.
+    This function may take advantage of operands'
+    non-negativeness.
+    """
+    return _ffi_api._OpIndexDiv(a, b)
+
+
+def indexmod(a, b):
+    """Compute the remainder of indexdiv. a and b are non-negative.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand, known to be non-negative.
+
+    b : PrimExpr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+
+    Note
+    ----
+    Use this function to split non-negative indices.
+    This function may take advantage of operands'
+    non-negativeness.
+    """
+    return _ffi_api._OpIndexMod(a, b)
+
+
+def truncdiv(a, b):
+    """Compute the truncdiv of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+
+    Note
+    ----
+    This is the default integer division behavior in C.
+    """
+    return _ffi_api._OpTruncDiv(a, b)
+
+
+def truncmod(a, b):
+    """Compute the truncmod of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+
+    Note
+    ----
+    This is the default integer division behavior in C.
+    """
+    return _ffi_api._OpTruncMod(a, b)
+
+
+def floordiv(a, b):
+    """Compute the floordiv of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api._OpFloorDiv(a, b)
+
+
+def floormod(a, b):
+    """Compute the floormod of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api._OpFloorMod(a, b)
similarity index 85%
rename from python/tvm/stmt.py
rename to python/tvm/tir/stmt.py
index e5feb50..bc02b7d 100644 (file)
@@ -25,18 +25,19 @@ Each statement node have subfields that can be visited from python side.
 
     x = tvm.var("n")
     a = tvm.var("array", tvm.handle)
-    st = tvm.make.Store(a, x + 1, 1)
-    assert isinstance(st, tvm.stmt.Store)
+    st = tvm.tir.stmt.Store(a, x + 1, 1)
+    assert isinstance(st, tvm.tir.stmt.Store)
     assert(st.buffer_var == a)
 """
 import tvm._ffi
 
 from tvm.runtime import Object
-from . import make as _make
+from . import _ffi_api
 
 
 class Stmt(Object):
-    pass
+    """Base class of all the statements."""
+
 
 @tvm._ffi.register_object
 class LetStmt(Stmt):
@@ -47,7 +48,7 @@ class LetStmt(Stmt):
     var : Var
         The variable in the binding.
 
-    value : Expr
+    value : PrimExpr
         The value in to be binded.
 
     body : Stmt
@@ -55,7 +56,7 @@ class LetStmt(Stmt):
     """
     def __init__(self, var, value, body):
         self.__init_handle_by_constructor__(
-            _make.LetStmt, var, value, body)
+            _ffi_api.LetStmt, var, value, body)
 
 
 @tvm._ffi.register_object
@@ -64,10 +65,10 @@ class AssertStmt(Stmt):
 
     Parameters
     ----------
-    condition : Expr
+    condition : PrimExpr
         The assert condition.
 
-    message : Expr
+    message : PrimExpr
         The error message.
 
     body : Stmt
@@ -75,7 +76,7 @@ class AssertStmt(Stmt):
     """
     def __init__(self, condition, message, body):
         self.__init_handle_by_constructor__(
-            _make.AssertStmt, condition, message, body)
+            _ffi_api.AssertStmt, condition, message, body)
 
 
 @tvm._ffi.register_object
@@ -95,7 +96,7 @@ class ProducerConsumer(Stmt):
     """
     def __init__(self, func, is_producer, body):
         self.__init_handle_by_constructor__(
-            _make.ProducerConsumer, func, is_producer, body)
+            _ffi_api.ProducerConsumer, func, is_producer, body)
 
 
 @tvm._ffi.register_object
@@ -107,10 +108,10 @@ class For(Stmt):
     loop_var : Var
         The loop variable.
 
-    min_val : Expr
+    min_val : PrimExpr
         The begining value.
 
-    extent : Expr
+    extent : PrimExpr
         The length of the loop.
 
     for_type : int
@@ -134,7 +135,7 @@ class For(Stmt):
                  device_api,
                  body):
         self.__init_handle_by_constructor__(
-            _make.For, loop_var, min_val, extent,
+            _ffi_api.For, loop_var, min_val, extent,
             for_type, device_api, body)
 
 
@@ -147,18 +148,19 @@ class Store(Stmt):
     buffer_var : Var
         The buffer Variable.
 
-    value : Expr
+    value : PrimExpr
         The value we want to store.
 
-    index : Expr
+    index : PrimExpr
         The index in the store expression.
 
-    predicate : Expr
+    predicate : PrimExpr
         The store predicate.
     """
-    def __init__(self, buffer_var, value, index, predicate):
+    def __init__(self, buffer_var, value, index, predicate=None):
+        args = [] if predicate is None else [predicate]
         self.__init_handle_by_constructor__(
-            _make.Store, buffer_var, value, index, predicate)
+            _ffi_api.Store, buffer_var, value, index, *args)
 
 
 @tvm._ffi.register_object
@@ -173,7 +175,7 @@ class Provide(Stmt):
     value_index : int
         The output value index
 
-    value : Expr
+    value : PrimExpr
         The value to be stored.
 
     args : list of Expr
@@ -181,7 +183,7 @@ class Provide(Stmt):
     """
     def __init__(self, func, value_index, value, args):
         self.__init_handle_by_constructor__(
-            _make.Provide, func, value_index, value, args)
+            _ffi_api.Provide, func, value_index, value, args)
 
 
 @tvm._ffi.register_object
@@ -199,7 +201,7 @@ class Allocate(Stmt):
     extents : list of Expr
         The extents of the allocate
 
-    condition : Expr
+    condition : PrimExpr
         The condition.
 
     body : Stmt
@@ -212,7 +214,7 @@ class Allocate(Stmt):
                  condition,
                  body):
         self.__init_handle_by_constructor__(
-            _make.Allocate, buffer_var, dtype,
+            _ffi_api.Allocate, buffer_var, dtype,
             extents, condition, body)
 
 
@@ -228,7 +230,7 @@ class AttrStmt(Stmt):
     attr_key : str
         Attribute type key.
 
-    value : Expr
+    value : PrimExpr
         The value of the attribute
 
     body : Stmt
@@ -236,7 +238,7 @@ class AttrStmt(Stmt):
     """
     def __init__(self, node, attr_key, value, body):
         self.__init_handle_by_constructor__(
-            _make.AttrStmt, node, attr_key, value, body)
+            _ffi_api.AttrStmt, node, attr_key, value, body)
 
 
 @tvm._ffi.register_object
@@ -250,7 +252,7 @@ class Free(Stmt):
     """
     def __init__(self, buffer_var):
         self.__init_handle_by_constructor__(
-            _make.Free, buffer_var)
+            _ffi_api.Free, buffer_var)
 
 
 @tvm._ffi.register_object
@@ -271,7 +273,7 @@ class Realize(Stmt):
     bounds : list of range
         The bound of realize
 
-    condition : Expr
+    condition : PrimExpr
         The realize condition.
 
     body : Stmt
@@ -285,7 +287,7 @@ class Realize(Stmt):
                  condition,
                  body):
         self.__init_handle_by_constructor__(
-            _make.Realize, func, value_index, dtype,
+            _ffi_api.Realize, func, value_index, dtype,
             bounds, condition, body)
 
 
@@ -300,7 +302,7 @@ class SeqStmt(Stmt):
     """
     def __init__(self, seq):
         self.__init_handle_by_constructor__(
-            _make.SeqStmt, seq)
+            _ffi_api.SeqStmt, seq)
 
     def __getitem__(self, i):
         return self.seq[i]
@@ -315,7 +317,7 @@ class IfThenElse(Stmt):
 
     Parameters
     ----------
-    condition : Expr
+    condition : PrimExpr
         The expression
 
     then_case : Stmt
@@ -326,7 +328,7 @@ class IfThenElse(Stmt):
     """
     def __init__(self, condition, then_case, else_case):
         self.__init_handle_by_constructor__(
-            _make.IfThenElse, condition, then_case, else_case)
+            _ffi_api.IfThenElse, condition, then_case, else_case)
 
 
 @tvm._ffi.register_object
@@ -335,12 +337,12 @@ class Evaluate(Stmt):
 
     Parameters
     ----------
-    value : Expr
+    value : PrimExpr
         The expression to be evalued.
     """
     def __init__(self, value):
         self.__init_handle_by_constructor__(
-            _make.Evaluate, value)
+            _ffi_api.Evaluate, value)
 
 
 @tvm._ffi.register_object
@@ -363,7 +365,7 @@ class Prefetch(Stmt):
     """
     def __init__(self, func, value_index, dtype, bounds):
         self.__init_handle_by_constructor__(
-            _make.Prefetch, func, value_index, dtype, bounds)
+            _ffi_api.Prefetch, func, value_index, dtype, bounds)
 
 
 @tvm._ffi.register_object
@@ -417,6 +419,3 @@ def stmt_list(stmt):
     if isinstance(stmt, ProducerConsumer):
         return stmt_list(stmt.body)
     return [stmt]
-
-
-_make.stmt_list = stmt_list
index 35810cb..1e71baf 100644 (file)
 namespace tvm {
 namespace tir {
 
-TVM_REGISTER_GLOBAL("_Var")
+TVM_REGISTER_GLOBAL("tir.Var")
 .set_body_typed([](std::string s, DataType t) {
     return Var(s, t);
   });
 
-TVM_REGISTER_GLOBAL("_SizeVar")
+TVM_REGISTER_GLOBAL("tir.SizeVar")
 .set_body_typed([](std::string s, DataType t) {
     return SizeVar(s, t);
   });
 
-TVM_REGISTER_GLOBAL("make.abs")
+TVM_REGISTER_GLOBAL("tir.abs")
 .set_body_typed(tvm::abs);
 
-TVM_REGISTER_GLOBAL("make.isnan")
+TVM_REGISTER_GLOBAL("tir.isnan")
 .set_body_typed(tvm::isnan);
 
-TVM_REGISTER_GLOBAL("make.floor")
+TVM_REGISTER_GLOBAL("tir.floor")
 .set_body_typed(tvm::floor);
 
-TVM_REGISTER_GLOBAL("make.ceil")
+TVM_REGISTER_GLOBAL("tir.ceil")
 .set_body_typed(tvm::ceil);
 
-TVM_REGISTER_GLOBAL("make.round")
+TVM_REGISTER_GLOBAL("tir.round")
 .set_body_typed(tvm::round);
 
-TVM_REGISTER_GLOBAL("make.nearbyint")
+TVM_REGISTER_GLOBAL("tir.nearbyint")
 .set_body_typed(tvm::nearbyint);
 
-TVM_REGISTER_GLOBAL("make.trunc")
+TVM_REGISTER_GLOBAL("tir.trunc")
 .set_body_typed(tvm::trunc);
 
-TVM_REGISTER_GLOBAL("make._cast")
+TVM_REGISTER_GLOBAL("tir._cast")
 .set_body_typed(tvm::cast);
 
-TVM_REGISTER_GLOBAL("make._range_by_min_extent")
+TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
 .set_body_typed(Range::make_by_min_extent);
 
 
-TVM_REGISTER_GLOBAL("make.SeqStmt")
+TVM_REGISTER_GLOBAL("tir.SeqStmt")
 .set_body_typed([](Array<Stmt> seq) {
   return SeqStmt(std::move(seq));
 });
 
-TVM_REGISTER_GLOBAL("make.For")
+TVM_REGISTER_GLOBAL("tir.For")
 .set_body_typed([](
   Var loop_var, PrimExpr min, PrimExpr extent,
   int for_type, int device_api, Stmt body) {
@@ -85,7 +85,7 @@ TVM_REGISTER_GLOBAL("make.For")
                    body);
 });
 
-TVM_REGISTER_GLOBAL("make.Load")
+TVM_REGISTER_GLOBAL("tir.Load")
 .set_body([](TVMArgs args,  TVMRetValue *ret) {
     DataType t = args[0];
     if (args.size() == 3) {
@@ -95,7 +95,7 @@ TVM_REGISTER_GLOBAL("make.Load")
     }
   });
 
-TVM_REGISTER_GLOBAL("make.Store")
+TVM_REGISTER_GLOBAL("tir.Store")
 .set_body([](TVMArgs args,  TVMRetValue *ret) {
     PrimExpr value = args[1];
     if (args.size() == 3) {
@@ -105,10 +105,10 @@ TVM_REGISTER_GLOBAL("make.Store")
     }
   });
 
-TVM_REGISTER_GLOBAL("make.Realize")
+TVM_REGISTER_GLOBAL("tir.Realize")
 .set_body_typed(RealizeNode::make);
 
-TVM_REGISTER_GLOBAL("make.Call")
+TVM_REGISTER_GLOBAL("tir.Call")
 .set_body_typed([](
   DataType type, std::string name,
   Array<PrimExpr> args, int call_type,
@@ -122,12 +122,12 @@ TVM_REGISTER_GLOBAL("make.Call")
                     value_index);
 });
 
-TVM_REGISTER_GLOBAL("make.CommReducer")
+TVM_REGISTER_GLOBAL("tir.CommReducer")
 .set_body_typed(CommReducerNode::make);
 
 // make from two arguments
 #define REGISTER_MAKE(NodeName)                                     \
-  TVM_REGISTER_GLOBAL("make."#NodeName)                             \
+  TVM_REGISTER_GLOBAL("tir."#NodeName)                             \
   .set_body_typed(NodeName ## Node::make);                          \
 
 
@@ -172,7 +172,7 @@ REGISTER_MAKE(Evaluate);
 
 // overloaded, needs special handling
 // has default args
-TVM_REGISTER_GLOBAL("make.Allocate")
+TVM_REGISTER_GLOBAL("tir.Allocate")
   .set_body_typed([](
     Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
   ){
@@ -180,14 +180,14 @@ TVM_REGISTER_GLOBAL("make.Allocate")
   });
 
 // operator overloading, smarter than make
-#define REGISTER_MAKE_BINARY_OP(Node, Func)                     \
-  TVM_REGISTER_GLOBAL("make."#Node)                             \
+#define REGISTER_MAKE_BINARY_OP(Node, Func)                    \
+  TVM_REGISTER_GLOBAL("tir."#Node)                              \
   .set_body_typed([](PrimExpr a, PrimExpr b) {                  \
     return (Func(a, b));                                        \
   })
 
-#define REGISTER_MAKE_BIT_OP(Node, Func)                                \
-  TVM_REGISTER_GLOBAL("make."#Node)                                     \
+#define REGISTER_MAKE_BIT_OP(Node, Func)                               \
+  TVM_REGISTER_GLOBAL("tir."#Node)                                      \
   .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
     bool lhs_is_int = args[0].type_code() == kDLInt;                    \
     bool rhs_is_int = args[1].type_code() == kDLInt;                    \
@@ -228,7 +228,7 @@ REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
 REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
 REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
 REGISTER_MAKE_BIT_OP(right_shift, operator>>);
-TVM_REGISTER_GLOBAL("make._OpIfThenElse")
+TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
 .set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
   return if_then_else(cond, true_value, false_value);
 });
index 591869e..d2f2cb6 100644 (file)
 
 namespace tvm {
 
-TVM_REGISTER_GLOBAL("_min_value")
+TVM_REGISTER_GLOBAL("tir.min_value")
 .set_body_typed(min_value);
 
-TVM_REGISTER_GLOBAL("_max_value")
+TVM_REGISTER_GLOBAL("tir.max_value")
 .set_body_typed(max_value);
 
 TVM_REGISTER_GLOBAL("Range")
@@ -49,66 +49,6 @@ TVM_REGISTER_GLOBAL("Range")
     }
   });
 
-namespace tir {
-
-TVM_REGISTER_GLOBAL("_Buffer")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    CHECK_EQ(args.size(), 10);
-    auto buffer_type = args[9].operator std::string();
-    BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
-    *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
-                            args[5], args[6], args[7], args[8], type);
-  });
-
-TVM_REGISTER_GLOBAL("_BufferAccessPtr")
-.set_body_method(&Buffer::access_ptr);
-
-TVM_REGISTER_GLOBAL("_BufferVLoad")
-.set_body_method(&Buffer::vload);
-
-TVM_REGISTER_GLOBAL("_BufferVStore")
-.set_body_method(&Buffer::vstore);
-
-TVM_REGISTER_GLOBAL("_Layout")
-.set_body_typed(LayoutNode::make);
-
-TVM_REGISTER_GLOBAL("_LayoutIndexOf")
-.set_body_typed([](Layout layout, std::string axis) -> int {
-  return layout.IndexOf(LayoutAxis::make(axis));
-});
-
-TVM_REGISTER_GLOBAL("_LayoutFactorOf")
-.set_body_typed([](Layout layout, std::string axis) -> int {
-  return layout.FactorOf(LayoutAxis::make(axis));
-});
-
-TVM_REGISTER_GLOBAL("_LayoutNdim")
-.set_body_typed([](Layout layout) -> int {
-  return layout.ndim();
-});
-
-TVM_REGISTER_GLOBAL("_LayoutGetItem")
-.set_body_typed([](Layout layout, int idx) -> std::string {
-  const LayoutAxis& axis = layout[idx];
-  return axis.name();
-});
-
-TVM_REGISTER_GLOBAL("_BijectiveLayout")
-.set_body_typed(BijectiveLayoutNode::make);
-
-TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardIndex")
-.set_body_method(&BijectiveLayout::ForwardIndex);
-
-TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardIndex")
-.set_body_method(&BijectiveLayout::BackwardIndex);
-
-TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardShape")
-.set_body_method(&BijectiveLayout::ForwardShape);
-
-TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape")
-.set_body_method(&BijectiveLayout::BackwardShape);
-}  // namespace tir
-
 namespace te {
 TVM_REGISTER_GLOBAL("_Tensor")
 .set_body_typed(TensorNode::make);
index 78c6879..4feabeb 100644 (file)
@@ -71,7 +71,7 @@ IntImm::IntImm(DataType dtype, int64_t value) {
   data_ = std::move(node);
 }
 
-TVM_REGISTER_GLOBAL("make.IntImm")
+TVM_REGISTER_GLOBAL("ir.IntImm")
 .set_body_typed([](DataType dtype, int64_t value) {
   return IntImm(dtype, value);
 });
@@ -97,7 +97,7 @@ FloatImm::FloatImm(DataType dtype, double value) {
   data_ = std::move(node);
 }
 
-TVM_REGISTER_GLOBAL("make.FloatImm")
+TVM_REGISTER_GLOBAL("ir.FloatImm")
 .set_body_typed([](DataType dtype, double value) {
   return FloatImm(dtype, value);
 });
index d61d72b..183079f 100644 (file)
@@ -304,6 +304,6 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr")
 TVM_REGISTER_GLOBAL("node.NodeListAttrNames")
 .set_body(NodeListAttrNames);
 
-TVM_REGISTER_GLOBAL("make._Node")
+TVM_REGISTER_GLOBAL("node.MakeNode")
 .set_body(MakeNode);
 }  // namespace tvm
index 00bf70b..86362e0 100644 (file)
@@ -906,7 +906,9 @@ static const char* kSemVer = "v0.0.4";
 // - relay_text_printer.cc (specific printing logics for relay)
 // - tir_text_printer.cc (specific printing logics for TIR)
 std::string PrettyPrint(const ObjectRef& node) {
-  return AsText(node, false, nullptr);
+  Doc doc;
+  doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
+  return doc.str();
 }
 
 std::string AsText(const ObjectRef& node,
@@ -918,6 +920,10 @@ std::string AsText(const ObjectRef& node,
   return doc.str();
 }
 
+
+TVM_REGISTER_GLOBAL("ir.PrettyPrint")
+.set_body_typed(PrettyPrint);
+
 TVM_REGISTER_GLOBAL("ir.AsText")
 .set_body_typed(AsText);
 }  // namespace tvm
index ff67e8d..19e32d6 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file buffer.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/buffer.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/tir/expr.h>
@@ -460,5 +461,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 });
 
 TVM_REGISTER_NODE_TYPE(BufferNode);
+
+
+TVM_REGISTER_GLOBAL("tir.Buffer")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    CHECK_EQ(args.size(), 10);
+    auto buffer_type = args[9].operator std::string();
+    BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
+    *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
+                            args[5], args[6], args[7], args[8], type);
+  });
+
+TVM_REGISTER_GLOBAL("tir.BufferAccessPtr")
+.set_body_method(&Buffer::access_ptr);
+
+TVM_REGISTER_GLOBAL("tir.BufferVLoad")
+.set_body_method(&Buffer::vload);
+
+TVM_REGISTER_GLOBAL("tir.BufferVStore")
+.set_body_method(&Buffer::vstore);
+
 }  // namespace tir
 }  // namespace tvm
index 8a5125b..9cc07a8 100644 (file)
@@ -21,6 +21,7 @@
  * \file src/lang/data_layout.cc
  * \brief Data Layout expression.
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/data_layout.h>
 #include <tvm/tir/ir_pass.h>
 #include <cctype>
@@ -371,5 +372,44 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     p->stream << "BijectiveLayout(" << b->src_layout.name()
               << "->" << b->dst_layout.name() << ")";
   });
+
+TVM_REGISTER_GLOBAL("tir.Layout")
+.set_body_typed(LayoutNode::make);
+
+TVM_REGISTER_GLOBAL("tir.LayoutIndexOf")
+.set_body_typed([](Layout layout, std::string axis) -> int {
+  return layout.IndexOf(LayoutAxis::make(axis));
+});
+
+TVM_REGISTER_GLOBAL("tir.LayoutFactorOf")
+.set_body_typed([](Layout layout, std::string axis) -> int {
+  return layout.FactorOf(LayoutAxis::make(axis));
+});
+
+TVM_REGISTER_GLOBAL("tir.LayoutNdim")
+.set_body_typed([](Layout layout) -> int {
+  return layout.ndim();
+});
+
+TVM_REGISTER_GLOBAL("tir.LayoutGetItem")
+.set_body_typed([](Layout layout, int idx) -> std::string {
+  const LayoutAxis& axis = layout[idx];
+  return axis.name();
+});
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayout")
+.set_body_typed(BijectiveLayoutNode::make);
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
+.set_body_method(&BijectiveLayout::ForwardIndex);
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex")
+.set_body_method(&BijectiveLayout::BackwardIndex);
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape")
+.set_body_method(&BijectiveLayout::ForwardShape);
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape")
+.set_body_method(&BijectiveLayout::BackwardShape);
 }  // namespace tir
 }  // namespace tvm
index 454354e..62c0290 100644 (file)
@@ -24,7 +24,7 @@ def test_reduce_prims():
         n = tvm.size_var('n')
         m = tvm.size_var('m')
         A = tvm.placeholder((n, m), name='A')
-        R = tvm.compute((n, ), lambda i: tvm.expr.Select((i > 1), 1, 0), name='R')
+        R = tvm.compute((n, ), lambda i: tvm.tir.Select((i > 1), 1, 0), name='R')
         k = tvm.reduce_axis((0, m))
         B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
         # schedule
@@ -232,8 +232,8 @@ def test_rfactor_elemwise_threads():
 
 def test_argmax():
     def fcombine(x, y):
-        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
-        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
+        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
+        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
         return lhs, rhs
 
     def fidentity(t0, t1):
@@ -279,8 +279,8 @@ def test_argmax():
 
 def test_rfactor_argmax():
     def fcombine(x, y):
-        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
-        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
+        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
+        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
         return lhs, rhs
 
     def fidentity(t0, t1):
index 640da0b..fd7ec18 100644 (file)
@@ -82,10 +82,10 @@ def test_compile_tuple_dup():
 def test_compile_full():
     # Shape calculations can happen in int64. The test checks that full operator
     # can handle when shapes are not int32
-    shape = (tvm.expr.IntImm('int32', 1),
-             tvm.expr.IntImm('int64', 16),
-             tvm.expr.IntImm('int64', 16),
-             tvm.expr.IntImm('int32', 64))
+    shape = (tvm.tir.IntImm('int32', 1),
+             tvm.tir.IntImm('int64', 16),
+             tvm.tir.IntImm('int64', 16),
+             tvm.tir.IntImm('int32', 64))
     output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
     f = relay.Function([], output)
     mod = tvm.IRModule.from_expr(f)
index 2af4a20..674e214 100644 (file)
@@ -41,7 +41,7 @@ def test_basic_build():
     }
     # build
     targets = {
-        tvm.expr.IntImm("int32", ctx.device_type): tgt
+        tvm.tir.IntImm("int32", ctx.device_type): tgt
     }
     g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params)
 
index 3735259..608bc2a 100644 (file)
@@ -77,9 +77,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
 
 
 def set_external_func_attr(func, compiler, ext_symbol):
-    func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
-    func = func.set_attribute("Compiler", tvm.expr.StringImm(compiler))
-    func = func.set_attribute("ExternalSymbol", tvm.expr.StringImm(ext_symbol))
+    func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+    func = func.set_attribute("Compiler", tvm.tir.StringImm(compiler))
+    func = func.set_attribute("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
     return func
 
 
index 7e8c832..713aca9 100644 (file)
@@ -307,7 +307,7 @@ def get_synthetic_lib():
     subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2,
                                 gcc_input3], relay.copy(gcc_input0))
     subgraph0 = subgraph0.set_attribute(
-        "Primitive", tvm.expr.IntImm("int32", 1))
+        "Primitive", tvm.tir.IntImm("int32", 1))
 
     # Call subgraph0
     subgraph0_ret = relay.Call(subgraph0, [x, w0, w1, w2])
@@ -320,7 +320,7 @@ def get_synthetic_lib():
     subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6,
                                 gcc_input7], relay.copy(gcc_input4))
     subgraph1 = subgraph1.set_attribute(
-        "Primitive", tvm.expr.IntImm("int32", 1))
+        "Primitive", tvm.tir.IntImm("int32", 1))
 
     # Call subgraph1
     subgraph1_ret = relay.Call(subgraph1, [x, w3, w4, w5])
index 29e578b..ad15255 100644 (file)
@@ -17,7 +17,7 @@
 """ test ir"""
 import tvm
 from tvm import relay
-from tvm.expr import *
+from tvm.tir.expr import *
 from tvm.relay import op
 from tvm.relay.analysis import graph_equal
 import numpy as np
@@ -110,7 +110,7 @@ def test_type_relation():
 
     num_inputs = 2
     func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
-    attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
 
     tr = relay.TypeRelation(func, args, num_inputs, attrs)
     assert tr.args == args
index 261cbb9..bcce9b4 100644 (file)
@@ -69,7 +69,7 @@ type List[A] {
 """
 
 def roundtrip(expr):
-    x = relay.fromtext(str(expr))
+    x = relay.fromtext(expr.astext())
     assert_graph_equal(x, expr)
 
 
@@ -343,7 +343,7 @@ def test_func():
     # attributes
     assert parses_as(
         "fn (n=5) { () }",
-        relay.Function([], UNIT, None, None, tvm.make.node("DictAttrs", n=relay.const(5)))
+        relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
     )
 
 
index 5876a70..0d3fd4b 100644 (file)
@@ -630,8 +630,8 @@ def test_upsampling_infer_type():
     y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
     "method=\"BINLINEAR\"" in y.astext()
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)),
-                                                tvm.expr.Cast("int32", tvm.round(w*scale))),
+    assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", tvm.round(h*scale)),
+                                                tvm.tir.Cast("int32", tvm.round(w*scale))),
                                                 "float32")
     n, c = tvm.size_var("n"), tvm.size_var("c")
     x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
@@ -647,9 +647,9 @@ def test_upsampling3d_infer_type():
     y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
 
     yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(d*scale)),
-                                                tvm.expr.Cast("int32", tvm.round(h*scale)),
-                                                tvm.expr.Cast("int32", tvm.round(w*scale))),
+    assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", tvm.round(d*scale)),
+                                                tvm.tir.Cast("int32", tvm.round(h*scale)),
+                                                tvm.tir.Cast("int32", tvm.round(w*scale))),
                                                 "float32")
     n, c = tvm.size_var("n"), tvm.size_var("c")
     x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32"))
index 9c5dfac..c5f340a 100644 (file)
@@ -517,7 +517,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
         alpha_shape = (data[axis],)
         assert zz.args[1].checked_type == relay.TensorType(alpha_shape, "float32")
 
-    if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or not alpha:
+    if all(isinstance(v, tvm.tir.Var) == 1 for v in data) or not alpha:
         return
 
     func = relay.Function([x, y], z)
index 0243adc..c5cd708 100644 (file)
@@ -154,7 +154,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
     out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype
     assert zz.checked_type == relay.ty.TensorType(output, out_type)
 
-    if all(isinstance(v, tvm.expr.Var) == 1 for v in data):
+    if all(isinstance(v, tvm.tir.Var) == 1 for v in data):
         return
 
     func = relay.Function([x], z)
index bdc032e..5985273 100644 (file)
@@ -160,9 +160,9 @@ def test_type_relation_alpha_equal():
     broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
     identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")
 
-    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
+    attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4))
 
     tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
     same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
@@ -322,7 +322,7 @@ def test_multi_node_subgraph():
     p00 = relay.subtract(z00, w01)
     q00 = relay.multiply(p00, w02)
     func0 = relay.Function([x0, w00, w01, w02], q00)
-    func0 = func0.set_attribute("FuncName", tvm.expr.StringImm("a"))
+    func0 = func0.set_attribute("FuncName", tvm.tir.StringImm("a"))
 
     x1 = relay.var('x1', shape=(10, 10))
     w10 = relay.var('w10', shape=(10, 10))
@@ -332,7 +332,7 @@ def test_multi_node_subgraph():
     p10 = relay.subtract(z10, w11)
     q10 = relay.multiply(p10, w12)
     func1 = relay.Function([x1, w10, w11, w12], q10)
-    func1 = func1.set_attribute("FuncName", tvm.expr.StringImm("b"))
+    func1 = func1.set_attribute("FuncName", tvm.tir.StringImm("b"))
     assert not alpha_equal(func0, func1)
 
 
@@ -413,9 +413,9 @@ def test_call_alpha_equal():
     v1 = relay.Var("v1")
     v2 = relay.Var("v2")
 
-    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
+    attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4))
 
     tt1 = relay.TensorType((1, 2, 3), "float32")
     tt2 = relay.TensorType((), "int8")
index 27a143b..6f20278 100644 (file)
@@ -303,11 +303,11 @@ def test_extern_ccompiler_default_ops():
         add = x0 + y0
         # Function that uses C compiler
         func = relay.Function([x0, y0], add)
-        func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
+        func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
         func = func.set_attribute("Compiler",
-                                  tvm.expr.StringImm("ccompiler"))
+                                  tvm.tir.StringImm("ccompiler"))
         func = func.set_attribute("ExternalSymbol",
-                                  tvm.expr.StringImm("ccompiler_0"))
+                                  tvm.tir.StringImm("ccompiler_0"))
         add_call = relay.Call(func, [x, y])
         # Function that uses default compiler. Ops are fused in this function.
         p0 = relay.var("p0", shape=(8, 8))
@@ -316,7 +316,7 @@ def test_extern_ccompiler_default_ops():
         concat = relay.concatenate([log, exp], axis=0)
         fused_func = relay.Function([p0], concat)
         fused_func = fused_func.set_attribute("Primitive",
-                                              tvm.expr.IntImm("int32", 1))
+                                              tvm.tir.IntImm("int32", 1))
         fused_call = relay.Call(fused_func, [add_call])
         main = relay.Function([x, y], fused_call)
         mod = tvm.IRModule()
index b84a49a..854301b 100644 (file)
@@ -65,7 +65,7 @@ def test_tuple_type():
 
 def test_type_relation():
     func = tvm.ir.EnvFunc.get('tvm.relay.type_relation.Broadcast')
-    attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4))
+    attrs = tvm.ir.make_node('attrs.TestAttrs', name='attr', padding=(3,4))
     tp = TypeVar('tp')
     tf = FuncType([], TupleType([]), [], [])
     tt = TensorType([1, 2, 3], 'float32')
index 28a4388..35822d2 100644 (file)
@@ -151,9 +151,9 @@ def test_reduce_combiner_simplify():
     prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
 
     sum_or_prod = comm_reducer(
-        lambda x, y: tvm.expr.Select(dummy < 0,
+        lambda x, y: tvm.tir.Select(dummy < 0,
                                      x + y, x*y),
-        lambda t0: tvm.expr.Select(dummy < 0,
+        lambda t0: tvm.tir.Select(dummy < 0,
                                    tvm.const(0, t0), tvm.const(1, t0)))
     sum_and_prod = comm_reducer(
         lambda x, y: (x[0] + y[0],
@@ -199,7 +199,7 @@ def test_reduce_combiner_simplify():
             assert tvm.ir_pass.Equal(lhs, rhs)
 
     # Test that components with side effects are not removed
-    side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)
+    side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic, None, 0)
     ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
              sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
     ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
@@ -211,7 +211,7 @@ def test_reduce_simplify():
     k = tvm.reduce_axis((0, 10), name="k")
     j = tvm.reduce_axis((-5, 3), name="j")
     A = tvm.placeholder((10,), name='A')
-    ck.verify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j]),
+    ck.verify(tvm.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]),
               tvm.sum(k + j, [k, j]))
     ck.verify(tvm.sum(A[3], []), A[3])
     # The rule below is not typical, removed for now
@@ -235,23 +235,23 @@ def test_simplify_if_then_else():
                                             tmod(tmod(((x*4) + y)  - 466036, 24528) -24512, 16),
                                             x), y)
     expected = tvm.if_then_else(
-        tvm.expr.LE(466036, (x * 4 + y)),
-        tvm.if_then_else(tvm.expr.LE(24512, tmod(((x*4) + y) - 4, 24528)),
+        tvm.tir.LE(466036, (x * 4 + y)),
+        tvm.if_then_else(tvm.tir.LE(24512, tmod(((x*4) + y) - 4, 24528)),
                                      tmod(((x*4) + y)  - 4, 16),
                          x), y)
     ck.verify(res, expected)
     ck.verify(res2, expected)
     # can only simplify if condition
-    res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
-    expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
+    res = tvm.tir.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
+    expected = tvm.tir.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
     ck.verify(res, ck.analyzer.canonical_simplify(expected))
 
-    res = tvm.expr.Select(x >= 10,
+    res = tvm.tir.Select(x >= 10,
                           tvm.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
-    expected = tvm.expr.Select(x >= 10, x, 0)
+    expected = tvm.tir.Select(x >= 10, x, 0)
     ck.verify(res, ck.analyzer.canonical_simplify(expected))
 
-    res = tvm.expr.Select(x >= 10,
+    res = tvm.tir.Select(x >= 10,
                           tvm.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
     ck.verify(res, 0)
 
index ae2837d..aba56ac 100644 (file)
@@ -228,7 +228,7 @@ def test_select_bound():
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
 
     bd = analyzer.const_int_bound(
-        tvm.expr.Select(x > 1, (y < 0).astype("int32"), y + 1))
+        tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1))
     assert bd.min_value == 0
     assert bd.max_value == 11
 
index 33e31c7..787dfe8 100644 (file)
@@ -19,7 +19,7 @@ import tvm
 
 def assert_expr_equal(a, b):
     res =  tvm.ir_pass.Simplify(a - b)
-    equal = isinstance(res, tvm.expr.IntImm) and res.value == 0
+    equal = isinstance(res, tvm.tir.IntImm) and res.value == 0
     if not equal:
         raise ValueError("{} and {} are not equal".format(a, b))
 
index 9139cc2..3e45d4e 100644 (file)
@@ -23,14 +23,14 @@ def test_domain_touched():
     m = tvm.var('m')
     a = tvm.placeholder((n, m), name = 'a')
     b = tvm.placeholder((n, m), name = 'b')
-    ir = tvm.make.For(
+    ir = tvm.tir.For(
             i, 0, n, 0, 0,
-            tvm.make.For(j, 0, m, 0, 0,
-                tvm.make.Provide(
+            tvm.tir.For(j, 0, m, 0, 0,
+                tvm.tir.Provide(
                     a.op,
                     0,
-                    tvm.make.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
-                    tvm.make.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
+                    tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
+                    tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
                     [i, j]
                 )
             )
@@ -51,7 +51,7 @@ def test_domain_touched():
     assert a_domain_rw[0].min.value == -1
     assert a_domain_rw[0].extent.value == 101
     assert a_domain_rw[1].min.value == -1
-    assert isinstance(a_domain_rw[1].extent, tvm.expr.Add)
+    assert isinstance(a_domain_rw[1].extent, tvm.tir.Add)
     assert a_domain_rw[1].extent.a.name == 'm'
     assert a_domain_rw[1].extent.b.value == 1
 
index 20e3f57..d83d33d 100644 (file)
@@ -41,7 +41,7 @@ def test_vector():
     base = 10
     stride = 3
     lanes = 2
-    s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
+    s = tvm.arith.intset_vector(tvm.tir.Ramp(base, stride, lanes))
     assert s.min_value.value == base
     assert s.max_value.value == base + stride * lanes - 1
 
@@ -99,7 +99,7 @@ def test_max_min():
 def test_select():
     ck = IntSetChecker()
     x, y = tvm.var("x"), tvm.var("y")
-    ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1),
+    ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1),
               {x : tvm.arith.IntervalSet(0, 10)}, (-1, 11))
 
 
index 1ce7197..6bb86e4 100644 (file)
@@ -84,7 +84,7 @@ def test_min_max_select():
     assert m.coeff == 3
     assert m.base == 1
 
-    m = analyzer.modular_set(tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2))
+    m = analyzer.modular_set(tvm.tir.Select(x > 0, x * 3 + 1, y * 9 + 2))
     assert m.coeff == 1
     assert m.base == 0
 
index 99c2942..84560e8 100644 (file)
@@ -29,31 +29,31 @@ def test_vector_simplify():
     ck = RewriteChecker()
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
     # Add rules
-    ck.verify(tvm.expr.Ramp(x, 1, 4) + tvm.expr.Ramp(y, 2, 4),
-              tvm.expr.Ramp(x + y, 3, 4))
-    ck.verify(tvm.expr.Ramp(x, 1, 2) + y,
-              tvm.expr.Ramp(x + y, 1, 2))
-    ck.verify(y + tvm.expr.Ramp(x, 1, 2) ,
-              tvm.expr.Ramp(y + x, 1, 2))
+    ck.verify(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4),
+              tvm.tir.Ramp(x + y, 3, 4))
+    ck.verify(tvm.tir.Ramp(x, 1, 2) + y,
+              tvm.tir.Ramp(x + y, 1, 2))
+    ck.verify(y + tvm.tir.Ramp(x, 1, 2) ,
+              tvm.tir.Ramp(y + x, 1, 2))
     ck.verify(y.astype("int32x2") + x.astype("int32x2"),
               (y + x).astype("int32x2"))
     # Sub rules
-    ck.verify(tvm.expr.Ramp(x, 4, 4) - tvm.expr.Ramp(y, 2, 4),
-              tvm.expr.Ramp(x - y, 2, 4))
-    ck.verify(tvm.expr.Ramp(x, 1, 2) - y,
-              tvm.expr.Ramp(x - y, 1, 2))
-    ck.verify(y - tvm.expr.Ramp(x, 1, 2) ,
-              tvm.expr.Ramp(y - x, -1, 2))
+    ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4),
+              tvm.tir.Ramp(x - y, 2, 4))
+    ck.verify(tvm.tir.Ramp(x, 1, 2) - y,
+              tvm.tir.Ramp(x - y, 1, 2))
+    ck.verify(y - tvm.tir.Ramp(x, 1, 2) ,
+              tvm.tir.Ramp(y - x, -1, 2))
     ck.verify(y.astype("int32x2") - x.astype("int32x2"),
               (y - x).astype("int32x2"))
 
     # Mul rules
     ck.verify(y.astype("int32x2") * x.astype("int32x2"),
               (y * x).astype("int32x2"))
-    ck.verify(tvm.expr.Ramp(x, 4, 4) * 2,
-              tvm.expr.Ramp(x * 2, 8, 4))
-    ck.verify(2 * tvm.expr.Ramp(x, 4, 4),
-              tvm.expr.Ramp(x * 2, 8, 4))
+    ck.verify(tvm.tir.Ramp(x, 4, 4) * 2,
+              tvm.tir.Ramp(x * 2, 8, 4))
+    ck.verify(2 * tvm.tir.Ramp(x, 4, 4),
+              tvm.tir.Ramp(x * 2, 8, 4))
 
     ## DivMod rules
     tdiv = tvm.truncdiv
@@ -61,21 +61,21 @@ def test_vector_simplify():
     # truc div
     ck.verify(tdiv(y.astype("int32x2"), x.astype("int32x2")),
               tdiv(y, x).astype("int32x2"))
-    ck.verify(tdiv(tvm.expr.Ramp(x, 4, 4), 2),
-              tvm.expr.Ramp(tdiv(x, 2), 2, 4))
+    ck.verify(tdiv(tvm.tir.Ramp(x, 4, 4), 2),
+              tvm.tir.Ramp(tdiv(x, 2), 2, 4))
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
-    ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
+    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8),
               (x).astype("int32x4"))
-    ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8),
-              tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8))
+    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8),
+              tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
     ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")),
               tmod(y, x).astype("int32x2"))
-    ck.verify(tmod(tvm.expr.Ramp(x, 4, 4), 2),
-              tvm.expr.Broadcast(tmod(x, 2), 4))
-    ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
-              tvm.expr.Ramp(1, 1, 4))
-    ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 15, 4), 8),
-              tmod(tvm.expr.Ramp(1, 15, 4), 8))
+    ck.verify(tmod(tvm.tir.Ramp(x, 4, 4), 2),
+              tvm.tir.Broadcast(tmod(x, 2), 4))
+    ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8),
+              tvm.tir.Ramp(1, 1, 4))
+    ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8),
+              tmod(tvm.tir.Ramp(1, 15, 4), 8))
 
     # floor div
     fld = tvm.floordiv
@@ -83,20 +83,20 @@ def test_vector_simplify():
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-10, 1000), override=True)
     ck.verify(fld(y.astype("int32x2"), x.astype("int32x2")),
               fld(y, x).astype("int32x2"))
-    ck.verify(fld(tvm.expr.Ramp(x, 4, 4), 2),
-              tvm.expr.Ramp(fld(x, 2), 2, 4))
-    ck.verify(fld(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
+    ck.verify(fld(tvm.tir.Ramp(x, 4, 4), 2),
+              tvm.tir.Ramp(fld(x, 2), 2, 4))
+    ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8),
               (x).astype("int32x4"))
-    ck.verify(fld(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8),
-              fld(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8))
+    ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8),
+              fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
     ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")),
               flm(y, x).astype("int32x2"))
-    ck.verify(flm(tvm.expr.Ramp(x, 4, 4), 2),
-              tvm.expr.Broadcast(flm(x, 2), 4))
-    ck.verify(flm(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
-              tvm.expr.Ramp(1, 1, 4))
-    ck.verify(flm(tvm.expr.Ramp(x * 8 + 1, 15, 4), 8),
-              flm(tvm.expr.Ramp(1, 15, 4), 8))
+    ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2),
+              tvm.tir.Broadcast(flm(x, 2), 4))
+    ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8),
+              tvm.tir.Ramp(1, 1, 4))
+    ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8),
+              flm(tvm.tir.Ramp(1, 15, 4), 8))
 
     # Min/Max rules
     vx = tvm.var("vx", dtype="int32x2")
@@ -113,8 +113,8 @@ def test_vector_simplify():
     ## Logical rules
     ck.verify(y.astype("int32x2").equal(x.astype("int32x2")),
               (y.equal(x)).astype("uint1x2"))
-    ck.verify(tvm.expr.NE(y.astype("int32x2"), (x.astype("int32x2"))),
-              (tvm.expr.NE(y, x)).astype("uint1x2"))
+    ck.verify(tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))),
+              (tvm.tir.NE(y, x)).astype("uint1x2"))
     ck.verify(y.astype("int32x2") > x.astype("int32x2"),
               (x < y).astype("uint1x2"))
     ck.verify(y.astype("int32x2") >= x.astype("int32x2"),
@@ -123,32 +123,32 @@ def test_vector_simplify():
               (y < x).astype("uint1x2"))
     ck.verify(y.astype("int32x2") <= x.astype("int32x2"),
               (y <= x).astype("uint1x2"))
-    ck.verify(tvm.expr.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
-              (tvm.expr.And(y <= x, vc)).astype("uint1x2"))
-    ck.verify(tvm.expr.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
-              (tvm.expr.Or(y <= x, vc)).astype("uint1x2"))
+    ck.verify(tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
+              (tvm.tir.And(y <= x, vc)).astype("uint1x2"))
+    ck.verify(tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
+              (tvm.tir.Or(y <= x, vc)).astype("uint1x2"))
 
 
 def test_select_simplify():
     ck = RewriteChecker()
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
     # Add rules
-    ck.verify(tvm.expr.Select(x < 0, y, 0) + tvm.expr.Select(x < 0, 1, z),
-              tvm.expr.Select(x < 0, y + 1, z))
-    ck.verify(tvm.expr.Select(x < 0, y, 1) - tvm.expr.Select(x < 0, 1, z),
-              tvm.expr.Select(x < 0, y + (-1), 1 - z))
-    ck.verify(tvm.expr.Select(x < 0, y, z) - y,
-              tvm.expr.Select(x < 0, 0, z - y))
-    ck.verify(tvm.expr.Select(x < 0, y, z) - z,
-              tvm.expr.Select(x < 0, y - z, 0))
-    ck.verify(tvm.min(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)),
-              tvm.expr.Select(x < 0, tvm.min(y, 1), tvm.min(0, z)))
-    ck.verify(tvm.max(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)),
-              tvm.expr.Select(x < 0, tvm.max(y, 1), tvm.max(0, z)))
-
-    ck.verify(tvm.expr.Select(x * 3 + 1 != 0, y, z), y)
-    ck.verify(tvm.expr.Select(x * 3 + 1 == 0, y, z), z)
-    ck.verify(tvm.expr.Select(x > 0, y + 1, y + 1), y + 1)
+    ck.verify(tvm.tir.Select(x < 0, y, 0) + tvm.tir.Select(x < 0, 1, z),
+              tvm.tir.Select(x < 0, y + 1, z))
+    ck.verify(tvm.tir.Select(x < 0, y, 1) - tvm.tir.Select(x < 0, 1, z),
+              tvm.tir.Select(x < 0, y + (-1), 1 - z))
+    ck.verify(tvm.tir.Select(x < 0, y, z) - y,
+              tvm.tir.Select(x < 0, 0, z - y))
+    ck.verify(tvm.tir.Select(x < 0, y, z) - z,
+              tvm.tir.Select(x < 0, y - z, 0))
+    ck.verify(tvm.min(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)),
+              tvm.tir.Select(x < 0, tvm.min(y, 1), tvm.min(0, z)))
+    ck.verify(tvm.max(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)),
+              tvm.tir.Select(x < 0, tvm.max(y, 1), tvm.max(0, z)))
+
+    ck.verify(tvm.tir.Select(x * 3 + 1 != 0, y, z), y)
+    ck.verify(tvm.tir.Select(x * 3 + 1 == 0, y, z), z)
+    ck.verify(tvm.tir.Select(x > 0, y + 1, y + 1), y + 1)
 
 
 def test_add_index_simplify():
@@ -633,7 +633,7 @@ def test_cmp_simplify():
     tmod = tvm.truncmod
     # const int bound
     ck.verify((tmod(x, 2) + 10).equal(0), tvm.const(0, "bool"))
-    ck.verify(tvm.expr.NE(tmod(x, 2) + 10, 0), tvm.const(1, "bool"))
+    ck.verify(tvm.tir.NE(tmod(x, 2) + 10, 0), tvm.const(1, "bool"))
     ck.verify(tmod(x, 2) + 10 > 1, tvm.const(1, "bool"))
     ck.verify(tmod(x, 2) + 10 <= 1, tvm.const(0, "bool"))
     ck.verify(flm(x, 2) + 2 > 1, tvm.const(1, "bool"))
@@ -645,7 +645,7 @@ def test_cmp_simplify():
     # canonicalization
     ck.verify((x - 10).equal(0), x.equal(10))
     ck.verify((10 - x).equal(0), x.equal(10))
-    ck.verify((x * y).equal(0), tvm.expr.Or(x.equal(0), y.equal(0)))
+    ck.verify((x * y).equal(0), tvm.tir.Or(x.equal(0), y.equal(0)))
 
     # cmp bound
     ck.verify(x + y < x + z, y < z)
@@ -655,104 +655,104 @@ def test_cmp_simplify():
     ck.verify(y - x < z - x, y < z)
     ck.verify(x - y < x - z, z < y)
 
-    ck.verify(x < z + x, tvm.expr.LT(0, z))
-    ck.verify(x < x + z, tvm.expr.LT(0, z))
+    ck.verify(x < z + x, tvm.tir.LT(0, z))
+    ck.verify(x < x + z, tvm.tir.LT(0, z))
 
-    ck.verify(100 < x + 1, tvm.expr.LT(99, x))
-    ck.verify(1 < 100 - x, tvm.expr.LT(x, 99))
+    ck.verify(100 < x + 1, tvm.tir.LT(99, x))
+    ck.verify(1 < 100 - x, tvm.tir.LT(x, 99))
     ck.verify(x * 3 < y * 3, x < y)
     ck.verify(x * (-3) < y * (-3), y < x)
     ck.verify(x * 3 >= y * 3, y <= x)
 
-    ck.verify(x * 4 >= 2, tvm.expr.LE(1, x))
-    ck.verify(x * 2 >= 50, tvm.expr.LE(25, x))
+    ck.verify(x * 4 >= 2, tvm.tir.LE(1, x))
+    ck.verify(x * 2 >= 50, tvm.tir.LE(25, x))
     ck.verify(x * 4 <= 2, x <= 0)
-    ck.verify((0 - x * 3) <= 0, tvm.expr.LE(0, x))
-    ck.verify((0 - x * 3) >= 0, tvm.expr.LE(x, 0))
+    ck.verify((0 - x * 3) <= 0, tvm.tir.LE(0, x))
+    ck.verify((0 - x * 3) >= 0, tvm.tir.LE(x, 0))
     ck.verify(2 * x <= 0, x <= 0)
 
-    ck.verify(x * 2 >= 3, tvm.expr.LE(2, x))
-    ck.verify(x * 2 >= 2, tvm.expr.LE(1, x))
-    ck.verify(x * 2 >= 1, tvm.expr.LE(1, x))
-    ck.verify(x * 2 >= 0, tvm.expr.LE(0, x))
-    ck.verify(x * 2 >= -1, tvm.expr.LE(0, x))
-    ck.verify(x * 2 >= -2, tvm.expr.LE(-1, x))
-    ck.verify(x * 2 >= -3, tvm.expr.LE(-1, x))
-
-    ck.verify(x * 2 <= 3, tvm.expr.LE(x, 1))
-    ck.verify(x * 2 <= 2, tvm.expr.LE(x, 1))
-    ck.verify(x * 2 <= 1, tvm.expr.LE(x, 0))
-    ck.verify(x * 2 <= 0, tvm.expr.LE(x, 0))
-    ck.verify(x * 2 <= -1, tvm.expr.LE(x, -1))
-    ck.verify(x * 2 <= -2, tvm.expr.LE(x, -1))
-    ck.verify(x * 2 <= -3, tvm.expr.LE(x, -2))
-
-    ck.verify(x * (-2) >= 3, tvm.expr.LE(x, -2))
-    ck.verify(x * (-2) >= 2, tvm.expr.LE(x, -1))
-    ck.verify(x * (-2) >= 1, tvm.expr.LE(x, -1))
-    ck.verify(x * (-2) >= 0, tvm.expr.LE(x, 0))
-    ck.verify(x * (-2) >= -1, tvm.expr.LE(x, 0))
-    ck.verify(x * (-2) >= -2, tvm.expr.LE(x, 1))
-    ck.verify(x * (-2) >= -3, tvm.expr.LE(x, 1))
-
-    ck.verify(x * (-2) <= 3, tvm.expr.LE(-1, x))
-    ck.verify(x * (-2) <= 2, tvm.expr.LE(-1, x))
-    ck.verify(x * (-2) <= 1, tvm.expr.LE(0, x))
-    ck.verify(x * (-2) <= 0, tvm.expr.LE(0, x))
-    ck.verify(x * (-2) <= -1, tvm.expr.LE(1, x))
-    ck.verify(x * (-2) <= -2, tvm.expr.LE(1, x))
-    ck.verify(x * (-2) <= -3, tvm.expr.LE(2, x))
+    ck.verify(x * 2 >= 3, tvm.tir.LE(2, x))
+    ck.verify(x * 2 >= 2, tvm.tir.LE(1, x))
+    ck.verify(x * 2 >= 1, tvm.tir.LE(1, x))
+    ck.verify(x * 2 >= 0, tvm.tir.LE(0, x))
+    ck.verify(x * 2 >= -1, tvm.tir.LE(0, x))
+    ck.verify(x * 2 >= -2, tvm.tir.LE(-1, x))
+    ck.verify(x * 2 >= -3, tvm.tir.LE(-1, x))
+
+    ck.verify(x * 2 <= 3, tvm.tir.LE(x, 1))
+    ck.verify(x * 2 <= 2, tvm.tir.LE(x, 1))
+    ck.verify(x * 2 <= 1, tvm.tir.LE(x, 0))
+    ck.verify(x * 2 <= 0, tvm.tir.LE(x, 0))
+    ck.verify(x * 2 <= -1, tvm.tir.LE(x, -1))
+    ck.verify(x * 2 <= -2, tvm.tir.LE(x, -1))
+    ck.verify(x * 2 <= -3, tvm.tir.LE(x, -2))
+
+    ck.verify(x * (-2) >= 3, tvm.tir.LE(x, -2))
+    ck.verify(x * (-2) >= 2, tvm.tir.LE(x, -1))
+    ck.verify(x * (-2) >= 1, tvm.tir.LE(x, -1))
+    ck.verify(x * (-2) >= 0, tvm.tir.LE(x, 0))
+    ck.verify(x * (-2) >= -1, tvm.tir.LE(x, 0))
+    ck.verify(x * (-2) >= -2, tvm.tir.LE(x, 1))
+    ck.verify(x * (-2) >= -3, tvm.tir.LE(x, 1))
+
+    ck.verify(x * (-2) <= 3, tvm.tir.LE(-1, x))
+    ck.verify(x * (-2) <= 2, tvm.tir.LE(-1, x))
+    ck.verify(x * (-2) <= 1, tvm.tir.LE(0, x))
+    ck.verify(x * (-2) <= 0, tvm.tir.LE(0, x))
+    ck.verify(x * (-2) <= -1, tvm.tir.LE(1, x))
+    ck.verify(x * (-2) <= -2, tvm.tir.LE(1, x))
+    ck.verify(x * (-2) <= -3, tvm.tir.LE(2, x))
 
     # DivMod rules
     # truc div
     ck.verify(tdiv(x, 2) < 3, x < 6)
-    ck.verify(3 < tdiv(x, 2), tvm.expr.LT(7, x))
-    ck.verify(tdiv(x, 3) >= 0, tvm.expr.LE(-2, x))
-    ck.verify(tdiv(x, 2) >= 1, tvm.expr.LE(2, x))
-    ck.verify(tdiv(x, 2) >= 0, tvm.expr.LE(-1, x))
-    ck.verify(tdiv(x, 2) >= -1, tvm.expr.LE(-3, x))
+    ck.verify(3 < tdiv(x, 2), tvm.tir.LT(7, x))
+    ck.verify(tdiv(x, 3) >= 0, tvm.tir.LE(-2, x))
+    ck.verify(tdiv(x, 2) >= 1, tvm.tir.LE(2, x))
+    ck.verify(tdiv(x, 2) >= 0, tvm.tir.LE(-1, x))
+    ck.verify(tdiv(x, 2) >= -1, tvm.tir.LE(-3, x))
 
-    ck.verify(tdiv(x, 2) <= 1, tvm.expr.LE(x, 3))
-    ck.verify(tdiv(x, 2) <= 0, tvm.expr.LE(x, 1))
-    ck.verify(tdiv(x, 2) <= -1, tvm.expr.LE(x, -2))
+    ck.verify(tdiv(x, 2) <= 1, tvm.tir.LE(x, 3))
+    ck.verify(tdiv(x, 2) <= 0, tvm.tir.LE(x, 1))
+    ck.verify(tdiv(x, 2) <= -1, tvm.tir.LE(x, -2))
 
-    ck.verify(tdiv(x, 4) * 4 < x, tvm.expr.LT(0, tmod(x, 4)))
-    ck.verify(tdiv(x, 4) * 4 >= x, tvm.expr.LE(tmod(x, 4), 0))
+    ck.verify(tdiv(x, 4) * 4 < x, tvm.tir.LT(0, tmod(x, 4)))
+    ck.verify(tdiv(x, 4) * 4 >= x, tvm.tir.LE(tmod(x, 4), 0))
 
-    ck.verify(tdiv(x, 4) * 4 < x + y, tvm.expr.LT(0, tmod(x, 4) + y))
-    ck.verify(tdiv(x, 4) * 4 < x - y, tvm.expr.LT(y, tmod(x, 4)))
+    ck.verify(tdiv(x, 4) * 4 < x + y, tvm.tir.LT(0, tmod(x, 4) + y))
+    ck.verify(tdiv(x, 4) * 4 < x - y, tvm.tir.LT(y, tmod(x, 4)))
 
-    ck.verify(tdiv(x + 2, 4) * 4 >= x, tvm.expr.LE(tmod(x + 2, 4), 2))
-    ck.verify(tdiv(x + 2, 4) * 4 >= x + y, tvm.expr.LE(tmod(x + 2, 4) + y, 2))
-    ck.verify(tdiv(x + 2, 4) * 4 >= x - y, tvm.expr.LE(tmod(x + 2, 4) + (-2), y))
+    ck.verify(tdiv(x + 2, 4) * 4 >= x, tvm.tir.LE(tmod(x + 2, 4), 2))
+    ck.verify(tdiv(x + 2, 4) * 4 >= x + y, tvm.tir.LE(tmod(x + 2, 4) + y, 2))
+    ck.verify(tdiv(x + 2, 4) * 4 >= x - y, tvm.tir.LE(tmod(x + 2, 4) + (-2), y))
 
     # floor div
     ck.verify(fld(x, 2) < 3, x < 6)
-    ck.verify(3 < fld(x, 2), tvm.expr.LT(7, x))
-    ck.verify(-3 < fld(x, 2), tvm.expr.LT(-5, x))
-    ck.verify(fld(x, 3) >= 0, tvm.expr.LE(0, x))
-    ck.verify(fld(x, 2) >= 1, tvm.expr.LE(2, x))
-    ck.verify(fld(x, 2) >= 0, tvm.expr.LE(0, x))
-    ck.verify(fld(x, 2) >= -1, tvm.expr.LE(-2, x))
-
-    ck.verify(fld(x, 2) <= 1, tvm.expr.LE(x, 3))
-    ck.verify(fld(x, 2) <= 0, tvm.expr.LE(x, 1))
-    ck.verify(fld(x, 2) <= -1, tvm.expr.LE(x, -1))
-
-    ck.verify(fld(x, 4) * 4 < x, tvm.expr.LT(0, flm(x, 4)))
-    ck.verify(fld(x, 4) * 4 >= x, tvm.expr.LE(flm(x, 4), 0))
-
-    ck.verify(fld(x, 4) * 4 < x + y, tvm.expr.LT(0, flm(x, 4) + y))
-    ck.verify(fld(x, 4) * 4 < x - y, tvm.expr.LT(y, flm(x, 4)))
-
-    ck.verify(fld(x + 2, 4) * 4 >= x, tvm.expr.LE(flm(x + 2, 4), 2))
-    ck.verify(fld(x + 2, 4) * 4 >= x + y, tvm.expr.LE(flm(x + 2, 4) + y, 2))
-    ck.verify(fld(x + 2, 4) * 4 >= x - y, tvm.expr.LE(flm(x + 2, 4) + (-2), y))
+    ck.verify(3 < fld(x, 2), tvm.tir.LT(7, x))
+    ck.verify(-3 < fld(x, 2), tvm.tir.LT(-5, x))
+    ck.verify(fld(x, 3) >= 0, tvm.tir.LE(0, x))
+    ck.verify(fld(x, 2) >= 1, tvm.tir.LE(2, x))
+    ck.verify(fld(x, 2) >= 0, tvm.tir.LE(0, x))
+    ck.verify(fld(x, 2) >= -1, tvm.tir.LE(-2, x))
+
+    ck.verify(fld(x, 2) <= 1, tvm.tir.LE(x, 3))
+    ck.verify(fld(x, 2) <= 0, tvm.tir.LE(x, 1))
+    ck.verify(fld(x, 2) <= -1, tvm.tir.LE(x, -1))
+
+    ck.verify(fld(x, 4) * 4 < x, tvm.tir.LT(0, flm(x, 4)))
+    ck.verify(fld(x, 4) * 4 >= x, tvm.tir.LE(flm(x, 4), 0))
+
+    ck.verify(fld(x, 4) * 4 < x + y, tvm.tir.LT(0, flm(x, 4) + y))
+    ck.verify(fld(x, 4) * 4 < x - y, tvm.tir.LT(y, flm(x, 4)))
+
+    ck.verify(fld(x + 2, 4) * 4 >= x, tvm.tir.LE(flm(x + 2, 4), 2))
+    ck.verify(fld(x + 2, 4) * 4 >= x + y, tvm.tir.LE(flm(x + 2, 4) + y, 2))
+    ck.verify(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4) + (-2), y))
     # End DivMod Rules
 
     ck.verify(tvm.min(x, 11) < 10, x < 10)
     ck.verify(tvm.min(x, 8) < 10, tvm.const(1, "bool"))
-    ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x))
+    ck.verify(tvm.max(8, x) > 10, tvm.tir.LT(10, x))
     ck.verify(x + 1 < tvm.max(8, x), x < 7)
 
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True)
@@ -777,48 +777,48 @@ def test_logical_simplify():
     ck = RewriteChecker()
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
 
-    ck.verify(tvm.expr.And(tvm.expr.EQ(x, y), tvm.expr.NE(x, y)),
+    ck.verify(tvm.tir.And(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)),
               tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(tvm.expr.NE(x, y), tvm.expr.EQ(x, y)),
+    ck.verify(tvm.tir.And(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)),
               tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(x > 1, tvm.expr.Not(x > 1)), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(x <= y, y < x), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(y < x, x <= y), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(x < 1, 0 < x), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(x < 0, 1 < x), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(x < 1, 1 <= x), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(x <= 1, 1 < x), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(1 <= x, x < 1), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(1 < x, x <= 1), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(x <= 1, 2 <= x), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(2 <= x, x <= 1), tvm.const(False, "bool"))
-    ck.verify(tvm.expr.And(x == 1, x != 2), x == 1)
-
-
-    ck.verify(tvm.expr.Or(tvm.expr.EQ(x, y), tvm.expr.NE(x, y)),
+    ck.verify(tvm.tir.And(x > 1, tvm.tir.Not(x > 1)), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(x <= y, y < x), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(y < x, x <= y), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(x < 1, 0 < x), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(x < 0, 1 < x), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(x < 1, 1 <= x), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(x <= 1, 1 < x), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(1 <= x, x < 1), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(1 < x, x <= 1), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(x <= 1, 2 <= x), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(2 <= x, x <= 1), tvm.const(False, "bool"))
+    ck.verify(tvm.tir.And(x == 1, x != 2), x == 1)
+
+
+    ck.verify(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)),
               tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(tvm.expr.NE(x, y), tvm.expr.EQ(x, y)),
+    ck.verify(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)),
               tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(x > y, tvm.expr.Not(x > y)), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.const(True, "bool"))
 
-    ck.verify(tvm.expr.Or(x <= y, y < x), tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(y < x, y >= x), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(x <= y, y < x), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(y < x, y >= x), tvm.const(True, "bool"))
 
-    ck.verify(tvm.expr.Or(x < 1, 0 < x), tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(0 < x, x < 1), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(x < 1, 0 < x), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(0 < x, x < 1), tvm.const(True, "bool"))
 
-    ck.verify(tvm.expr.Or(x < 1, 1 <= x), tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(x <= 1, 1 < x), tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(1 <= x, x < 1), tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(1 < x, x <= 1), tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(x <= 1, 2 <= x), tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(2 <= x, x <= 1), tvm.const(True, "bool"))
-    ck.verify(tvm.expr.Or(x != 1, x == 2), x != 1)
+    ck.verify(tvm.tir.Or(x < 1, 1 <= x), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(x <= 1, 1 < x), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(1 <= x, x < 1), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(1 < x, x <= 1), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(x <= 1, 2 <= x), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(2 <= x, x <= 1), tvm.const(True, "bool"))
+    ck.verify(tvm.tir.Or(x != 1, x == 2), x != 1)
 
 def test_let_simplify():
     ck = RewriteChecker()
     x, y = tvm.var("x"), tvm.var("y")
-    z = tvm.expr.Let(x, 1, x + 1)
+    z = tvm.tir.Let(x, 1, x + 1)
     ck.verify(z + z, 4)
 
 def test_cast_simplify():
@@ -827,11 +827,11 @@ def test_cast_simplify():
 
     dtypes = ["float32", "float16", "int32", "int8", "bool"]
     for dtype1 in dtypes:
-        ck.verify(tvm.expr.Cast(dtype1, x - x), tvm.const(0, dtype1))
-        ck.verify(tvm.expr.Cast(dtype1, x == x), tvm.const(1, dtype1))
+        ck.verify(tvm.tir.Cast(dtype1, x - x), tvm.const(0, dtype1))
+        ck.verify(tvm.tir.Cast(dtype1, x == x), tvm.const(1, dtype1))
         for dtype2 in dtypes:
             for i in [0, 1, 2, 3]:
-                ck.verify(tvm.expr.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1))
+                ck.verify(tvm.tir.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1))
 
 if __name__ == "__main__":
     test_floordiv_index_simplify()
index 9e0b477..58b6083 100644 (file)
@@ -25,9 +25,9 @@ def test_stmt_simplify():
         with ib.if_scope(i < 12):
             A[i] = C[i]
 
-    body = tvm.stmt.LetStmt(n, 10, ib.get())
+    body = tvm.tir.LetStmt(n, 10, ib.get())
     body = tvm.ir_pass.CanonicalSimplify(body)
-    assert isinstance(body.body, tvm.stmt.Store)
+    assert isinstance(body.body, tvm.tir.Store)
 
 
 def test_thread_extent_simplify():
@@ -42,9 +42,9 @@ def test_thread_extent_simplify():
     ib.scope_attr(ty, "thread_extent", 1)
     with ib.if_scope(tx + ty < 12):
         A[tx] = C[tx + ty]
-    body = tvm.stmt.LetStmt(n, 10, ib.get())
+    body = tvm.tir.LetStmt(n, 10, ib.get())
     body = tvm.ir_pass.CanonicalSimplify(body)
-    assert isinstance(body.body.body.body, tvm.stmt.Store)
+    assert isinstance(body.body.body.body, tvm.tir.Store)
 
 
 def test_basic_likely_elimination():
index bfeb652..79b3544 100644 (file)
@@ -185,19 +185,19 @@ def test_cuda_shuffle():
 
     def my_vectorize(stmt):
         def vectorizer(op):
-            if op.for_type == tvm.stmt.For.Vectorized:
+            if op.for_type == tvm.tir.For.Vectorized:
                 four = tvm.const(4, 'int32')
-                idx = tvm.make.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
+                idx = tvm.tir.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
                 all_ones = tvm.const(1, 'int32x4')
                 store = op.body
                 value = store.value
-                new_a = tvm.make.Load('int32x4', value.a.buffer_var, idx, all_ones)
+                new_a = tvm.tir.Load('int32x4', value.a.buffer_var, idx, all_ones)
                 bs, ids = [], []
                 for i in range(4):
-                    bs.append(tvm.make.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
+                    bs.append(tvm.tir.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
                     ids.append(tvm.const(3 - i, 'int32'))
-                new_b = tvm.make.Shuffle(bs, ids)
-                return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
+                new_b = tvm.tir.Shuffle(bs, ids)
+                return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
             return None
         return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
 
index c60f381..ca32293 100644 (file)
@@ -29,9 +29,9 @@ def test_llvm_intrin():
         tvm.call_pure_intrin("handle", "tvm_address_of", A[0]),
         0, 3, 1
     ]
-    ib.emit(tvm.make.Evaluate(
-        tvm.make.Call(
-            "int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
+    ib.emit(tvm.tir.Evaluate(
+        tvm.tir.Call(
+            "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
     body = ib.get()
     func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
     fcode = tvm.build(func, None, "llvm")
@@ -643,14 +643,14 @@ def test_llvm_shuffle():
 
         def vectorizer(op):
             store = op.body
-            idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
+            idx = tvm.tir.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
             all_ones = tvm.const(1, 'int32x8')
             value = store.value
-            b_idx = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
-            new_a = tvm.make.Load('int32x8', value.a.buffer_var, idx, all_ones)
-            new_b = tvm.make.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
+            b_idx = tvm.tir.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
+            new_a = tvm.tir.Load('int32x8', value.a.buffer_var, idx, all_ones)
+            new_b = tvm.tir.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
             value = new_a + new_b
-            return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
+            return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
 
         return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
 
index cf89608..3b9b4a7 100644 (file)
@@ -40,7 +40,7 @@ def test_opencl_ternary_expression():
         true_value = tvm.const(1, dtype=dtype)
         false_value = tvm.const(3, dtype=dtype)
         max_lhs = tvm.const(2, dtype=dtype)
-        max_rhs = tvm.expr.Select(A[0] > 0, true_value, false_value)
+        max_rhs = tvm.tir.Select(A[0] > 0, true_value, false_value)
         C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
         s = tvm.create_schedule(C.op)
         s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
index 3bfe013..4d71cb3 100644 (file)
@@ -26,7 +26,7 @@ def test_static_callback():
     ib = tvm.ir_builder.create()
     A = ib.buffer_ptr(Ab)
     cp = tvm.thread_axis((0, 1), "cop")
-    finit = tvm.make.StringImm("TVMBackendRunOnce")
+    finit = tvm.tir.StringImm("TVMBackendRunOnce")
     ib.scope_attr(cp, "coproc_uop_scope", finit)
     with ib.for_range(0, n, "i", for_type="parallel") as i:
         A[i] = A[i] + 1
index d477983..7f08c75 100644 (file)
@@ -34,7 +34,7 @@ def test_stack_vm_basic():
 
     n = tvm.size_var('n')
     Ab = tvm.decl_buffer((n, ), tvm.float32)
-    stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
+    stmt = tvm.tir.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
     fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
     fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
     fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm")
@@ -75,7 +75,7 @@ def test_stack_vm_cond():
     ib = tvm.ir_builder.create()
     A = ib.buffer_ptr(Ab)
     with ib.for_range(0, n - 1, "i") as i:
-        with ib.if_scope(tvm.make.EQ(i,  4)):
+        with ib.if_scope(tvm.tir.EQ(i,  4)):
             A[i + 1] = A[i] + 1
         with ib.else_scope():
             A[i + 1] = A[i] + 2
index d9e3c43..d480a0f 100644 (file)
@@ -31,7 +31,7 @@ def test_vector_comparison():
         A = tvm.placeholder(n, dtype=dtype, name='A')
         B = tvm.compute(
             A.shape,
-            lambda i: tvm.expr.Select(
+            lambda i: tvm.tir.Select(
                 A[i] >= 0, A[i] + tvm.const(1, dtype),
                 tvm.const(0, dtype)), name='B')
         s = tvm.create_schedule(B.op)
index 00f9b33..cae4813 100644 (file)
@@ -18,7 +18,7 @@
 import tvm
 from ctypes import *
 import topi
-import tvm.ir_pass as ir_pass
+import tvm.tir.ir_pass as ir_pass
 import numpy as np
 
 tgt = "llvm"
@@ -126,7 +126,7 @@ def test_bfloat_add_and_cast_FloatImm():
     Z = topi.cast(
         topi.add(
             topi.cast(X, dtype="custom[bfloat]16"),
-            tvm.expr.FloatImm("custom[bfloat]16", 1.5)),
+            tvm.tir.FloatImm("custom[bfloat]16", 1.5)),
         dtype="float")
 
     s = tvm.create_schedule([Z.op])
index 87e5a26..311dae8 100644 (file)
@@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
     def tvm_val_2_py_val(val):
         val = tvm.ir_pass.Substitute(val, var_dict)
         val = tvm.ir_pass.Simplify(val)
-        assert isinstance(val, (tvm.expr.IntImm,))
+        assert isinstance(val, (tvm.tir.IntImm,))
         return val.value
 
     ctx = tvm.context(target, 0)
@@ -46,14 +46,14 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
             shape = [tvm_val_2_py_val(j) for j in i.shape]
             emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
             nd_args.append(tvm.nd.array(emu_args[-1], ctx))
-        elif isinstance(i, tvm.expr.Var):
+        elif isinstance(i, tvm.tir.Var):
             emu_args.append(tvm_val_2_py_val(i))
             nd_args.append(emu_args[-1])
         else:
             assert isinstance(i, list)
             emu_args.append(numpy.array(i))
 
-    compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
+    compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.tir.Var))] + \
                    (outs if isinstance(outs, list) else [outs])
     module = tvm.build(sch,
                        compile_args,
@@ -76,7 +76,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
     for nd, np in zip(out_tensors, ref_data):
         tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
 
-    module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))]
+    module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.tir.Var))]
     module_outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     h_module = tvm.hybrid.build(sch, module_args, module_outs)
 
@@ -111,32 +111,32 @@ def test_outer_product():
         return
 
     #Check for i in (0, n)
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
     assert ir.loop_var.name == 'i'
     assert ir.min.value == 0
     assert ir.extent.name == 'n'
     ibody = ir.body
-    assert isinstance(ibody, tvm.stmt.For)
+    assert isinstance(ibody, tvm.tir.For)
     #Check for j in (0, m)
     assert ibody.loop_var.name == 'j'
     assert ibody.min.value == 0
     assert ibody.extent.name == 'm'
     #Check loop body
     jblock = ibody.body
-    assert isinstance(jblock, tvm.stmt.SeqStmt)
+    assert isinstance(jblock, tvm.tir.SeqStmt)
     jbody = jblock[0]
-    assert isinstance(jbody, tvm.stmt.AssertStmt)
-    assert isinstance(jbody.message, tvm.expr.StringImm)
+    assert isinstance(jbody, tvm.tir.AssertStmt)
+    assert isinstance(jbody.message, tvm.tir.StringImm)
     assert jbody.message.value == "index out of range!"
     jbody = jblock[1]
-    assert isinstance(jbody, tvm.stmt.Provide)
+    assert isinstance(jbody, tvm.tir.Provide)
     assert jbody.func.name == 'c'
     assert len(jbody.args) == 2
     assert jbody.args[0].name == 'i'
     assert jbody.args[1].name == 'j'
-    assert isinstance(jbody.value, tvm.expr.Mul)
+    assert isinstance(jbody.value, tvm.tir.Mul)
     mul = jbody.value
-    assert isinstance(mul.a, tvm.expr.Call)
+    assert isinstance(mul.a, tvm.tir.Call)
     assert mul.a.name == 'a'
     assert mul.b.name == 'b'
 
@@ -177,21 +177,21 @@ def test_fanout():
         return
 
     #Check for i in (0, n-3)
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
     assert ir.loop_var.name == 'i'
     assert ir.min.value == 0
     assert tvm.ir_pass.Equal(ir.extent, n - 3)
     #Check loopbody
     ibody = ir.body
-    assert isinstance(ibody, tvm.stmt.AttrStmt)
+    assert isinstance(ibody, tvm.tir.AttrStmt)
     abody = ibody.body
-    assert isinstance(abody, tvm.stmt.Realize)
+    assert isinstance(abody, tvm.tir.Realize)
     assert abody.bounds[0].min.value == 0
     assert abody.bounds[0].extent.value == 1
     assert abody.func.name == 'sigma'
     #Check i loop body
     rbody = abody.body
-    assert isinstance(rbody[0], tvm.stmt.Provide)
+    assert isinstance(rbody[0], tvm.tir.Provide)
     assert rbody[0].func.name == 'sigma'
     assert len(rbody[0].args) == 1
     assert rbody[0].args[0].value == 0
@@ -201,13 +201,13 @@ def test_fanout():
     assert jloop.min.value == 0
     assert jloop.extent.value == 3
     jbody = jloop.body
-    assert isinstance(jbody, tvm.stmt.Provide)
+    assert isinstance(jbody, tvm.tir.Provide)
     assert len(jbody.args) == 1
     assert jbody.args[0].value == 0
     assert jbody.func.name == 'sigma'
-    assert isinstance(jbody.value, tvm.expr.Add)
+    assert isinstance(jbody.value, tvm.tir.Add)
     value = jbody.value
-    assert isinstance(value.a, tvm.expr.Call)
+    assert isinstance(value.a, tvm.tir.Call)
     assert value.a.name == 'sigma'
     assert len(value.a.args) == 1
     assert value.a.args[0].value == 0
@@ -215,17 +215,17 @@ def test_fanout():
     assert len(value.b.args) == 1
     assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var)
     divide= rbody[2]
-    assert isinstance(divide, tvm.stmt.Provide)
+    assert isinstance(divide, tvm.tir.Provide)
     assert len(divide.args) == 1
     assert divide.args[0].value == 0
     value = divide.value
-    assert isinstance(value, tvm.expr.Mul)
+    assert isinstance(value, tvm.tir.Mul)
     assert value.a.name == 'sigma'
     assert len(value.a.args) == 1
     assert value.a.args[0].value == 0
     assert abs(value.b.value - (1 / 3.0)) < 1e-5
     write = rbody[3]
-    assert isinstance(write, tvm.stmt.Provide)
+    assert isinstance(write, tvm.tir.Provide)
     assert write.func.name == 'b'
     assert write.value.name == 'sigma'
     assert len(write.value.args) == 1
@@ -260,9 +260,9 @@ def test_looptype():
     iloop = ir[0]
     jloop = ir[1]
     kloop = ir[2]
-    assert iloop.for_type == tvm.stmt.For.Parallel
-    assert jloop.for_type == tvm.stmt.For.Vectorized
-    assert kloop.for_type == tvm.stmt.For.Unrolled
+    assert iloop.for_type == tvm.tir.For.Parallel
+    assert jloop.for_type == tvm.tir.For.Vectorized
+    assert kloop.for_type == tvm.tir.For.Unrolled
 
     func, ins, outs = run_and_check(looptype, [a, b, c])
     run_and_check(func, ins, outs=outs)
@@ -364,7 +364,7 @@ def test_bind():
     c = foo(a)
     s = tvm.create_schedule(c.op)
     ir = tvm.lower(s, [a, c], simple_mode=True)
-    assert not isinstance(ir, tvm.stmt.AttrStmt)
+    assert not isinstance(ir, tvm.tir.AttrStmt)
     func, ins, outs = run_and_check(foo, [a], target='cuda')
     run_and_check(func, ins, outs=outs, target='cuda')
 
@@ -729,20 +729,20 @@ def test_schedule():
     sch[c].vectorize(ji)
     sch[c].reorder(ii, io, joo, joi, ji)
     ir = tvm.lower(sch, [a, b, c], simple_mode=True)
-    assert isinstance(ir, tvm.stmt.ProducerConsumer)
+    assert isinstance(ir, tvm.tir.ProducerConsumer)
     ir = ir.body
-    assert isinstance(ir, tvm.stmt.AttrStmt)
+    assert isinstance(ir, tvm.tir.AttrStmt)
     ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
     assert ir.loop_var.name == 'i.inner'
     ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
     assert ir.loop_var.name == 'i.outer'
     ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
     assert ir.loop_var.name == 'j.outer.outer'
     ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
     assert ir.loop_var.name == 'j.outer.inner'
     ir = ir.body
     func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
@@ -752,11 +752,11 @@ def test_schedule():
     sch = tvm.create_schedule(c.op)
     sch[c].fuse(c.op.axis[0], c.op.axis[1])
     ir = tvm.lower(sch, [a, b, c], simple_mode=True)
-    assert isinstance(ir, tvm.stmt.ProducerConsumer)
+    assert isinstance(ir, tvm.tir.ProducerConsumer)
     ir = ir.body
-    assert isinstance(ir, tvm.stmt.AttrStmt)
+    assert isinstance(ir, tvm.tir.AttrStmt)
     ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
     assert ir.loop_var.name == 'i.j.fused'
     func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
     run_and_check(func, ins, outs=outs)
index 7486629..5679625 100644 (file)
@@ -28,14 +28,14 @@ def test_for():
 
     body = ib.get()
     print(body)
-    assert isinstance(body, tvm.stmt.AttrStmt)
+    assert isinstance(body, tvm.tir.AttrStmt)
     body = body.body
-    assert isinstance(body, tvm.stmt.Allocate)
+    assert isinstance(body, tvm.tir.Allocate)
     body = body.body
-    assert isinstance(body, tvm.stmt.For)
+    assert isinstance(body, tvm.tir.For)
     body = body.body
-    assert isinstance(body, tvm.stmt.SeqStmt)
-    assert isinstance(body[1], tvm.stmt.For)
+    assert isinstance(body, tvm.tir.SeqStmt)
+    assert isinstance(body[1], tvm.tir.For)
 
 def test_if():
     ib = tvm.ir_builder.create()
@@ -50,11 +50,11 @@ def test_if():
 
     body = ib.get()
     assert A == A
-    assert isinstance(body, tvm.stmt.For)
+    assert isinstance(body, tvm.tir.For)
     body = body.body
-    assert isinstance(body, tvm.stmt.IfThenElse)
-    assert isinstance(body.condition, tvm.expr.EQ)
-    assert isinstance(body.then_case.index, tvm.expr.Var)
+    assert isinstance(body, tvm.tir.IfThenElse)
+    assert isinstance(body.condition, tvm.tir.EQ)
+    assert isinstance(body.then_case.index, tvm.tir.Var)
     assert body.else_case.index.value == 0
 
 def test_prefetch():
@@ -64,10 +64,10 @@ def test_prefetch():
 
     with ib.for_range(0, n, name="i") as i:
         ib.emit(
-            tvm.make.Prefetch(
+            tvm.tir.Prefetch(
                 A.op, A.value_index, A.dtype,
-                [tvm.make.range_by_min_extent(i+1, 2),
-                 tvm.make.range_by_min_extent(0, 20)]))
+                [tvm.ir.Range.make_by_min_extent(i+1, 2),
+                 tvm.ir.Range.make_by_min_extent(0, 20)]))
     body = ib.get()
     assert body.body.bounds[0].extent.value == 2
 
index 0015d6d..7339925 100644 (file)
@@ -22,7 +22,7 @@ def test_const():
     x = tvm.const(1, "int32")
     print(x.dtype)
     assert x.dtype == tvm.int32
-    assert isinstance(x, tvm.expr.IntImm)
+    assert isinstance(x, tvm.tir.IntImm)
 
 
 def test_scalar_dtype_inference():
@@ -45,47 +45,47 @@ def test_make():
     x = tvm.const(1, "int32")
     y = tvm.var("x")
     z = x + y
-    assert isinstance(tvm.max(x, y), tvm.expr.Max)
-    assert isinstance(tvm.min(x, y), tvm.expr.Min)
+    assert isinstance(tvm.max(x, y), tvm.tir.Max)
+    assert isinstance(tvm.min(x, y), tvm.tir.Min)
 
 
 def test_ir():
     x = tvm.const(1, "int32")
-    y = tvm.make.IntImm('int32', 1)
+    y = tvm.tir.IntImm('int32', 1)
     z = x + y
-    stmt = tvm.make.Evaluate(z)
-    assert isinstance(stmt, tvm.stmt.Evaluate)
+    stmt = tvm.tir.Evaluate(z)
+    assert isinstance(stmt, tvm.tir.Evaluate)
 
 
 def test_ir2():
     x = tvm.var("n")
     a = tvm.var("array", tvm.handle)
-    st = tvm.make.Store(a, x + 1, 1)
-    assert isinstance(st, tvm.stmt.Store)
+    st = tvm.tir.Store(a, x + 1, 1)
+    assert isinstance(st, tvm.tir.Store)
     assert(st.buffer_var == a)
 
 
 def test_let():
     x = tvm.var('x')
     y = tvm.var('y')
-    stmt = tvm.make.LetStmt(
-        x, 10, tvm.make.Evaluate(x + 1));
+    stmt = tvm.tir.LetStmt(
+        x, 10, tvm.tir.Evaluate(x + 1));
 
 
 def test_cast():
     x = tvm.var('x', dtype="float32")
     y = x.astype("int32")
     z = x.astype("float32x4")
-    assert isinstance(y, tvm.expr.Cast)
-    assert isinstance(z, tvm.expr.Broadcast)
+    assert isinstance(y, tvm.tir.Cast)
+    assert isinstance(z, tvm.tir.Broadcast)
     assert z.lanes == 4
 
 
 def test_attr():
     x = tvm.var('x')
     y = tvm.var('y')
-    stmt = tvm.make.AttrStmt(
-        y, "stride", 10, tvm.make.Evaluate(x + 1));
+    stmt = tvm.tir.AttrStmt(
+        y, "stride", 10, tvm.tir.Evaluate(x + 1));
     assert stmt.node == y
 
     a = tvm.convert(1)
@@ -105,9 +105,9 @@ def test_basic():
 
 
 def test_stmt():
-    x = tvm.make.Evaluate(0)
-    tvm.make.For(tvm.var('i'), 0, 1,
-                 tvm.stmt.For.Serial, 0,
+    x = tvm.tir.Evaluate(0)
+    tvm.tir.For(tvm.var('i'), 0, 1,
+                 tvm.tir.For.Serial, 0,
                  x)
 
 
@@ -207,7 +207,7 @@ def test_equality():
 
 def test_equality_string_imm():
     x = 'a'
-    y = tvm.make.StringImm(x)
+    y = tvm.tir.StringImm(x)
     x == y.value
     x == y
 
index c418785..4ce7e87 100644 (file)
 import tvm
 
 def test_expr_constructor():
-    x = tvm.expr.Var("xx", "float32")
-    assert isinstance(x, tvm.expr.Var)
+    x = tvm.tir.Var("xx", "float32")
+    assert isinstance(x, tvm.tir.Var)
     assert x.name == "xx"
 
-    x = tvm.expr.Reduce(None, [1],
+    x = tvm.tir.Reduce(None, [1],
                         [tvm.api._IterVar((0, 1), "x", 2)],
                         None, 0)
-    assert isinstance(x, tvm.expr.Reduce)
+    assert isinstance(x, tvm.tir.Reduce)
     assert x.combiner == None
     assert x.value_index == 0
 
-    x = tvm.expr.FloatImm("float32", 1.0)
-    assert isinstance(x, tvm.expr.FloatImm)
+    x = tvm.tir.FloatImm("float32", 1.0)
+    assert isinstance(x, tvm.tir.FloatImm)
     assert x.value == 1.0
     assert x.dtype == "float32"
 
-    x = tvm.expr.IntImm("int64", 2)
-    assert isinstance(x, tvm.expr.IntImm)
+    x = tvm.tir.IntImm("int64", 2)
+    assert isinstance(x, tvm.tir.IntImm)
     assert x.value == 2
     assert x.dtype == "int64"
 
-    x = tvm.expr.StringImm("xyza")
-    assert isinstance(x, tvm.expr.StringImm)
+    x = tvm.tir.StringImm("xyza")
+    assert isinstance(x, tvm.tir.StringImm)
     assert x.value == "xyza"
 
-    x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1))
-    assert isinstance(x, tvm.expr.Cast)
+    x = tvm.tir.Cast("float32", tvm.tir.IntImm("uint32", 1))
+    assert isinstance(x, tvm.tir.Cast)
     assert x.dtype == "float32"
     assert x.value.value == 1
 
     a = tvm.const(1.0, dtype="float32")
     b = tvm.var("x", dtype="float32")
 
-    for cls in [tvm.expr.Add,
-                tvm.expr.Sub,
-                tvm.expr.Mul,
-                tvm.expr.Div,
-                tvm.expr.Mod,
-                tvm.expr.Min,
-                tvm.expr.Max,
-                tvm.expr.LT,
-                tvm.expr.LE,
-                tvm.expr.GT,
-                tvm.expr.GE]:
+    for cls in [tvm.tir.Add,
+                tvm.tir.Sub,
+                tvm.tir.Mul,
+                tvm.tir.Div,
+                tvm.tir.Mod,
+                tvm.tir.Min,
+                tvm.tir.Max,
+                tvm.tir.LT,
+                tvm.tir.LE,
+                tvm.tir.GT,
+                tvm.tir.GE]:
         x = cls(a, b)
         assert isinstance(x, cls)
         assert x.a == a
@@ -70,58 +70,58 @@ def test_expr_constructor():
     a = tvm.convert(tvm.var("x") > 1)
     b = tvm.convert(tvm.var("x") == 1)
 
-    for cls in [tvm.expr.And,
-                tvm.expr.Or]:
+    for cls in [tvm.tir.And,
+                tvm.tir.Or]:
         x = cls(a, b)
         assert isinstance(x, cls)
         assert x.a == a
         assert x.b.same_as(b)
 
-    x = tvm.expr.Not(a)
-    assert isinstance(x, tvm.expr.Not)
+    x = tvm.tir.Not(a)
+    assert isinstance(x, tvm.tir.Not)
     assert x.a == a
 
-    x = tvm.expr.Select(a, a, b)
-    assert isinstance(x, tvm.expr.Select)
+    x = tvm.tir.Select(a, a, b)
+    assert isinstance(x, tvm.tir.Select)
     assert x.true_value == a
     assert x.false_value == b
     assert x.condition == a
 
     buffer_var = tvm.var("x", dtype="handle")
-    x = tvm.expr.Load("float32", buffer_var, 1, a)
-    assert isinstance(x, tvm.expr.Load)
+    x = tvm.tir.Load("float32", buffer_var, 1, a)
+    assert isinstance(x, tvm.tir.Load)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
     assert x.index.value == 1
     assert x.predicate == a
 
-    x = tvm.expr.Ramp(1, 2, 10)
-    assert isinstance(x, tvm.expr.Ramp)
+    x = tvm.tir.Ramp(1, 2, 10)
+    assert isinstance(x, tvm.tir.Ramp)
     assert x.base.value == 1
     assert x.stride.value == 2
     assert x.lanes == 10
 
-    x = tvm.expr.Broadcast(a, 10)
-    assert isinstance(x, tvm.expr.Broadcast)
+    x = tvm.tir.Broadcast(a, 10)
+    assert isinstance(x, tvm.tir.Broadcast)
     assert x.value == a
     assert x.lanes == 10
 
-    x = tvm.expr.Shuffle([a], [0])
-    assert isinstance(x, tvm.expr.Shuffle)
+    x = tvm.tir.Shuffle([a], [0])
+    assert isinstance(x, tvm.tir.Shuffle)
     assert x.vectors[0] == a
     assert x.indices[0].value == 0
 
-    x = tvm.expr.Call("float32", "xyz", [a], tvm.expr.Call.Extern, None, 0)
-    assert isinstance(x, tvm.expr.Call)
+    x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern, None, 0)
+    assert isinstance(x, tvm.tir.Call)
     assert x.dtype == "float32"
     assert x.name == "xyz"
     assert x.args[0] == a
-    assert x.call_type == tvm.expr.Call.Extern
+    assert x.call_type == tvm.tir.Call.Extern
     assert x.func == None
     assert x.value_index == 0
 
     v = tvm.var("aa")
-    x = tvm.expr.Let(v, 1, v)
+    x = tvm.tir.Let(v, 1, v)
     assert x.var == v
     assert x.value.value == 1
     assert x.body == v
@@ -130,75 +130,75 @@ def test_expr_constructor():
 def test_stmt_constructor():
     v = tvm.var("aa")
     buffer_var = tvm.var("buf", dtype="handle")
-    nop = tvm.stmt.Evaluate(1)
-    x = tvm.stmt.LetStmt(v, 1, tvm.stmt.Evaluate(1))
-    assert isinstance(x, tvm.stmt.LetStmt)
+    nop = tvm.tir.Evaluate(1)
+    x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1))
+    assert isinstance(x, tvm.tir.LetStmt)
     assert x.var == v
     assert x.value.value == 1
-    assert isinstance(x.body, tvm.stmt.Evaluate)
+    assert isinstance(x.body, tvm.tir.Evaluate)
 
-    x = tvm.stmt.AttrStmt(v == 1, "xx", 1, tvm.stmt.Evaluate(1))
-    assert isinstance(x, tvm.stmt.AttrStmt)
+    x = tvm.tir.AttrStmt(v == 1, "xx", 1, tvm.tir.Evaluate(1))
+    assert isinstance(x, tvm.tir.AttrStmt)
     assert x.value.value == 1
 
-    x = tvm.stmt.AssertStmt(tvm.const(1, "uint1"),
+    x = tvm.tir.AssertStmt(tvm.const(1, "uint1"),
                             tvm.convert("hellow"),
                             nop)
-    assert isinstance(x, tvm.stmt.AssertStmt)
+    assert isinstance(x, tvm.tir.AssertStmt)
     assert x.body == nop
 
-    x = tvm.stmt.ProducerConsumer(None, True, nop)
-    assert isinstance(x, tvm.stmt.ProducerConsumer)
+    x = tvm.tir.ProducerConsumer(None, True, nop)
+    assert isinstance(x, tvm.tir.ProducerConsumer)
     assert x.body == nop
 
-    x = tvm.stmt.For(tvm.var("x"), 0, 10, 0, 0, nop)
-    assert isinstance(x, tvm.stmt.For)
+    x = tvm.tir.For(tvm.var("x"), 0, 10, 0, 0, nop)
+    assert isinstance(x, tvm.tir.For)
     assert x.min.value == 0
     assert x.extent.value == 10
     assert x.body == nop
 
-    x = tvm.stmt.Store(buffer_var, 1, 10, tvm.const(1, "uint1"))
-    assert isinstance(x, tvm.stmt.Store)
+    x = tvm.tir.Store(buffer_var, 1, 10, tvm.const(1, "uint1"))
+    assert isinstance(x, tvm.tir.Store)
     assert x.buffer_var == buffer_var
     assert x.index.value == 10
     assert x.value.value == 1
 
     tensor = tvm.placeholder((), dtype="float32")
-    x = tvm.stmt.Provide(tensor.op, 0, 10, [])
-    assert isinstance(x, tvm.stmt.Provide)
+    x = tvm.tir.Provide(tensor.op, 0, 10, [])
+    assert isinstance(x, tvm.tir.Provide)
     assert x.value_index == 0
     assert x.value.value == 10
 
-    x = tvm.stmt.Allocate(buffer_var, "float32", [10],
+    x = tvm.tir.Allocate(buffer_var, "float32", [10],
                           tvm.const(1, "uint1"), nop)
-    assert isinstance(x, tvm.stmt.Allocate)
+    assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
     assert x.body == nop
 
-    x = tvm.stmt.AttrStmt(buffer_var, "xyz", 1, nop)
-    assert isinstance(x, tvm.stmt.AttrStmt)
+    x = tvm.tir.AttrStmt(buffer_var, "xyz", 1, nop)
+    assert isinstance(x, tvm.tir.AttrStmt)
     assert x.node == buffer_var
     assert x.attr_key == "xyz"
     assert x.body == nop
 
-    x = tvm.stmt.Free(buffer_var)
-    assert isinstance(x, tvm.stmt.Free)
+    x = tvm.tir.Free(buffer_var)
+    assert isinstance(x, tvm.tir.Free)
     assert x.buffer_var == buffer_var
 
-    x = tvm.stmt.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop)
-    assert isinstance(x, tvm.stmt.Realize)
+    x = tvm.tir.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop)
+    assert isinstance(x, tvm.tir.Realize)
     assert x.body == nop
 
-    x = tvm.stmt.IfThenElse(tvm.const(1, "uint1"),
-                            tvm.stmt.Evaluate(11),
+    x = tvm.tir.IfThenElse(tvm.const(1, "uint1"),
+                            tvm.tir.Evaluate(11),
                             nop)
-    assert isinstance(x, tvm.stmt.IfThenElse)
+    assert isinstance(x, tvm.tir.IfThenElse)
     assert x.then_case.value.value == 11
     assert x.else_case == nop
 
-    x = tvm.stmt.Prefetch(None, 1, "float32", [])
-    assert isinstance(x, tvm.stmt.Prefetch)
+    x = tvm.tir.Prefetch(None, 1, "float32", [])
+    assert isinstance(x, tvm.tir.Prefetch)
     assert x.value_index == 1
 
 
index 4f8a93b..0b9fad9 100644 (file)
@@ -69,7 +69,7 @@ def test_map_save_load_json():
 def test_in_container():
     arr = tvm.convert(['a', 'b', 'c'])
     assert 'a' in arr
-    assert tvm.make.StringImm('a') in arr
+    assert tvm.tir.StringImm('a') in arr
     assert 'd' not in arr
 
 def test_ndarray_container():
index cde4a81..4c1cafc 100644 (file)
@@ -20,9 +20,9 @@ import tvm
 from topi.util import get_const_tuple
 
 def test_layout():
-    layout = tvm.layout("NCHW16c")
+    layout = tvm.tir.layout("NCHW16c")
     assert layout is not None
-    assert isinstance(layout, tvm.tensor.Layout)
+    assert isinstance(layout, tvm.tir.Layout)
 
     assert layout.factor_of("c") == 16
     assert layout.factor_of("C") == 16
@@ -63,7 +63,7 @@ def test_bilayout_convertible():
 
 def test_bilayout_shape():
     bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
-    assert isinstance(bilayout, tvm.tensor.BijectiveLayout)
+    assert isinstance(bilayout, tvm.tir.BijectiveLayout)
 
     dst_shape = bilayout.forward_shape((1, 32, 7, 7))
     assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
index 26783e6..d32b4c5 100644 (file)
@@ -29,7 +29,7 @@ def test_const_fold():
     def check(f, *args):
         x = f(*[tvm.const(x, "int32") for x in args])
         y = f(*args)
-        if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y):
+        if not isinstance(x, (tvm.tir.IntImm,)) or x.value != int(y):
             raise ValueError("check error: %s vs %s " % (x, y))
 
     tmod = tvm.truncmod
@@ -56,7 +56,7 @@ def test_const_fold2():
     assert tmod(x, 1).value == 0
     assert (x * 1).same_as(x)
     assert (1 * x).same_as(x)
-    assert isinstance(tdiv(1, x), tvm.expr.Div)
+    assert isinstance(tdiv(1, x), tvm.tir.Div)
 
 def test_const_fold3():
     # Test that using ints with logic operations is forbidden
@@ -92,17 +92,17 @@ def test_const_fold4():
     x1 = tvm.const(4, "int32")
     x2 = x1 + 5
     tdiv = tvm.truncdiv
-    assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
+    assert isinstance(x2, tvm.tir.IntImm) and x2.value == 9
     x3 = tdiv(x2, 3)
-    assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
+    assert isinstance(x3, tvm.tir.IntImm) and x3.value == 3
     x4 = x3 + 0.55
-    assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6
+    assert isinstance(x4, tvm.tir.FloatImm) and abs(x4.value - 3.55) < 1e-6
     x5 = tvm.ceil(x4)
-    assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
+    assert isinstance(x5, tvm.tir.FloatImm) and x5.value == 4
     x6 = x5.astype('int')
-    assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4, "x6={}".format(x6)
+    assert isinstance(x6, tvm.tir.IntImm) and x6.value == 4, "x6={}".format(x6)
     y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
-    assert isinstance(y, tvm.expr.IntImm) and y.value == 6
+    assert isinstance(y, tvm.tir.IntImm) and y.value == 6
 
 
 def test_binary_dtype_match():
index b971e38..e97e73a 100644 (file)
@@ -31,7 +31,7 @@ def test_make_smap():
     # save load json
     x = tvm.const(1, "int32")
     y = tvm.const(10, "int32")
-    z = tvm.expr.Add(x, y)
+    z = tvm.tir.Add(x, y)
     smap = tvm.convert({"z": z, "x": x})
     json_str = tvm.ir.save_json(tvm.convert([smap]))
     arr = tvm.ir.load_json(json_str)
@@ -40,11 +40,11 @@ def test_make_smap():
 
 
 def test_make_node():
-    x = tvm.make.node("IntImm", dtype="int32", value=10)
-    assert isinstance(x, tvm.expr.IntImm)
+    x = tvm.ir.make_node("IntImm", dtype="int32", value=10)
+    assert isinstance(x, tvm.tir.IntImm)
     assert x.value == 10
     A = tvm.placeholder((10, ), name='A')
-    AA = tvm.make.node("Tensor",
+    AA = tvm.ir.make_node("Tensor",
                        shape=A.shape,
                        dtype=A.dtype,
                        op=A.op,
@@ -55,25 +55,25 @@ def test_make_node():
 
 def test_make_attrs():
     try:
-        x = tvm.make.node("attrs.TestAttrs", unknown_key=1, name="xx")
+        x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
         assert False
     except tvm.error.TVMError as e:
         assert str(e).find("unknown_key") != -1
 
     try:
-        x = tvm.make.node("attrs.TestAttrs", axis=100, name="xx")
+        x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
         assert False
     except tvm.error.TVMError as e:
         assert str(e).find("upper bound") != -1
 
-    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4))
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4))
     assert x.name == "xx"
     assert x.padding[0].value == 3
     assert x.padding[1].value == 4
     assert x.axis == 10
 
 
-    dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
+    dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
     assert dattr.x.value == 1
     datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
     assert dattr.name.value == "xyz"
@@ -104,7 +104,7 @@ def test_env_func():
     assert y(1) == 2
     assert y.func(1) == 2
 
-    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4), func=y)
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4), func=y)
     assert x.name == "xx"
     assert x.padding[0].value == 3
     assert x.padding[1].value == 4
index 6b5b7fa..10843f9 100644 (file)
@@ -240,7 +240,7 @@ def test_tensor_intrin_scalar_params():
     C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
     s = tvm.create_schedule(C.op)
     stmt = tvm.lower(s, [A, C], simple_mode=True)
-    assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
+    assert isinstance(stmt.body.body.body, tvm.tir.Evaluate)
     assert len(stmt.body.body.body.value.args) == 5
     assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
     assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
index a8b5fc0..2de5e19 100644 (file)
@@ -128,7 +128,7 @@ def test_tensor_compute1():
 
     s = tvm.create_schedule(C.op)
     stmt = tvm.lower(s, [A, B, C], simple_mode=True)
-    assert isinstance(stmt.body.body, tvm.stmt.Evaluate)
+    assert isinstance(stmt.body.body, tvm.tir.Evaluate)
 
 def test_tensor_compute2():
     M = 2048
@@ -171,8 +171,8 @@ def test_tensor_compute2():
 
     s = tvm.create_schedule(C.op)
     stmt = tvm.lower(s, [A, B, C], simple_mode=True)
-    assert isinstance(stmt.body.body.body[0], tvm.stmt.Evaluate)
-    assert isinstance(stmt.body.body.body[1].body, tvm.stmt.Evaluate)
+    assert isinstance(stmt.body.body.body[0], tvm.tir.Evaluate)
+    assert isinstance(stmt.body.body.body[1].body, tvm.tir.Evaluate)
 
 def test_tensor_scan():
     m = tvm.size_var("m")
@@ -259,7 +259,7 @@ def test_tuple_with_different_deps():
     stmt = tvm.schedule.ScheduleOps(sch, bounds)
 
     def get_B1_realize(x):
-        if isinstance(x, tvm.stmt.Realize) and \
+        if isinstance(x, tvm.tir.Realize) and \
            x.func == B1.op and x.value_index == 1:
             ret.append(x)
     ret = []
index 3ccd6f9..98fdeaa 100644 (file)
@@ -29,8 +29,8 @@ def test_operator_type_and_tags():
     B1 = B[0]
     B2 = B[0,0]
 
-    assert isinstance(k + n, tvm.expr.PrimExpr)
-    assert isinstance(n + n, tvm.expr.PrimExpr)
+    assert isinstance(k + n, tvm.tir.PrimExpr)
+    assert isinstance(n + n, tvm.tir.PrimExpr)
     assert isinstance(k + A, tvm.tensor.Tensor)
     assert isinstance(A + k, tvm.tensor.Tensor)
     assert isinstance(n + A, tvm.tensor.Tensor)
@@ -53,11 +53,11 @@ def test_operator_type_and_tags():
     assert (B + A).op.tag == topi.tag.BROADCAST
     assert (B + B).op.tag == topi.tag.BROADCAST
 
-    assert isinstance(k + B2, tvm.expr.PrimExpr)
-    assert isinstance(B2 + k, tvm.expr.PrimExpr)
-    assert isinstance(n + B2, tvm.expr.PrimExpr)
-    assert isinstance(B2 + n, tvm.expr.PrimExpr)
-    assert isinstance(B2 + B2, tvm.expr.PrimExpr)
+    assert isinstance(k + B2, tvm.tir.PrimExpr)
+    assert isinstance(B2 + k, tvm.tir.PrimExpr)
+    assert isinstance(n + B2, tvm.tir.PrimExpr)
+    assert isinstance(B2 + n, tvm.tir.PrimExpr)
+    assert isinstance(B2 + B2, tvm.tir.PrimExpr)
     assert isinstance(B2 + A, tvm.tensor.Tensor)
     assert isinstance(A + B2, tvm.tensor.Tensor)
     assert isinstance(B2 + B, tvm.tensor.Tensor)
index bb4c196..2bd94e0 100644 (file)
 import tvm
 
 def test_attrs_equal():
-    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
-    y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
-    z = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4,1))
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
+    y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
+    z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1))
     assert tvm.ir_pass.AttrsEqual(x, y)
     assert not tvm.ir_pass.AttrsEqual(x, z)
 
-    dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
+    dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
     assert not tvm.ir_pass.AttrsEqual(dattr, x)
-    dattr2 = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
+    dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
     assert tvm.ir_pass.AttrsEqual(dattr, dattr2)
 
     assert tvm.ir_pass.AttrsEqual({"x": x}, {"x": y})
@@ -42,8 +42,8 @@ def test_attrs_equal():
 
 def test_attrs_hash():
     fhash = tvm.ir_pass.AttrsHash
-    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
-    y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
+    y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
     assert fhash({"x": x}) == fhash({"x": y})
     assert fhash({"x": x}) != fhash({"x": [y, 1]})
     assert fhash({"x": [x, 1]}) == fhash({"x": [y, 1]})
index 8f35611..93c815a 100644 (file)
@@ -31,16 +31,16 @@ def test_simplify():
 def test_verify_ssa():
     x = tvm.var('x')
     y = tvm.var()
-    z = tvm.make.Evaluate(x + y)
+    z = tvm.tir.Evaluate(x + y)
     assert(tvm.ir_pass.VerifySSA(z))
 
 
 def test_convert_ssa():
     x = tvm.var('x')
     y = tvm.var()
-    let1 = tvm.make.Let(x, 1, x + 1)
-    let2 = tvm.make.Let(x, 1, x + y)
-    z = tvm.make.Evaluate(let1 + let2)
+    let1 = tvm.tir.Let(x, 1, x + 1)
+    let2 = tvm.tir.Let(x, 1, x + y)
+    z = tvm.tir.Evaluate(let1 + let2)
     assert(not tvm.ir_pass.VerifySSA(z))
     z_ssa = tvm.ir_pass.ConvertSSA(z)
     assert(tvm.ir_pass.VerifySSA(z_ssa))
index e62e539..6b959e0 100644 (file)
@@ -166,12 +166,12 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
 
 def test_in_bounds_const_loop_partition_ir():
     def check_attr_stmt (x):
-        if isinstance(x, tvm.stmt.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n):
+        if isinstance(x, tvm.tir.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n):
             return True
         return False
 
     def check_branch_stmt (x):
-        if isinstance(x, tvm.stmt.IfThenElse):
+        if isinstance(x, tvm.tir.IfThenElse):
             return True
         return False
 
@@ -183,7 +183,7 @@ def test_in_bounds_const_loop_partition_ir():
         assert (count == nums)
 
     def collect_branch_stmt (x):
-        if isinstance(x, tvm.stmt.IfThenElse):
+        if isinstance(x, tvm.tir.IfThenElse):
             branch_collector.append(x)
 
     n = 21
index a25568f..ef741a4 100644 (file)
@@ -20,8 +20,8 @@ def test_for():
     dev_type = tvm.var("dev_type")
     def device_context(dev_id):
         ctx = tvm.call_extern("handle", "device_context", dev_type, dev_id)
-        return tvm.make.Call(
-            "handle", "tvm_thread_context", [ctx], tvm.expr.Call.Intrinsic, None, 0)
+        return tvm.tir.Call(
+            "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic, None, 0)
 
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
index 9ffd565..b464354 100644 (file)
@@ -33,7 +33,7 @@ def test_decorate_device():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt1 = tvm.ir_pass.Simplify(stmt)
     stmt2 = tvm.ir_pass.DecorateDeviceScope(stmt1)
-    assert isinstance(stmt2, tvm.stmt.AttrStmt)
+    assert isinstance(stmt2, tvm.tir.AttrStmt)
     assert stmt2.attr_key == "device_scope"
     assert stmt1 == stmt2.body
 
index 4a28cf6..2eb641b 100644 (file)
@@ -24,19 +24,19 @@ def verify_structure(stmt, expected_struct):
     struct = {}
     def _extract_vars(op):
         global var_list
-        if isinstance(op, tvm.expr.Var):
+        if isinstance(op, tvm.tir.Var):
             var_list.append(op.name)
 
     def _visit(op):
         key = op
-        if isinstance(op, tvm.stmt.IfThenElse):
+        if isinstance(op, tvm.tir.IfThenElse):
             global var_list
             tvm.ir_pass.PostOrderVisit(op.condition, _extract_vars)
             val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))]
             var_list.clear()
-        elif isinstance(op, tvm.stmt.For):
+        elif isinstance(op, tvm.tir.For):
             val = [(op.body,), ("For", op.loop_var.name)]
-        elif isinstance(op, tvm.stmt.AttrStmt):
+        elif isinstance(op, tvm.tir.AttrStmt):
             val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))]
         else:
             return
@@ -61,9 +61,9 @@ def test_basic():
         with ib.for_range(0, m, "j") as j:
             with ib.for_range(0, n, "k") as k:
                 with ib.if_scope(ib.likely(i < 2)):
-                    ib.emit(tvm.make.Evaluate(m))
+                    ib.emit(tvm.tir.Evaluate(m))
                 with ib.else_scope():
-                    ib.emit(tvm.make.Evaluate(n))
+                    ib.emit(tvm.tir.Evaluate(n))
 
     stmt = ib.get()
     new_stmt = tvm.ir_pass.HoistIfThenElse(stmt)
@@ -82,7 +82,7 @@ def test_no_else():
         with ib.for_range(0, m, "j") as j:
             with ib.for_range(0, n, "k") as k:
                 with ib.if_scope(ib.likely(i < 2)):
-                    ib.emit(tvm.make.Evaluate(m))
+                    ib.emit(tvm.tir.Evaluate(m))
 
     stmt = ib.get()
     new_stmt = tvm.ir_pass.HoistIfThenElse(stmt)
index 858b1e8..f49388d 100644 (file)
@@ -33,7 +33,7 @@ def test_copy2d():
         assert dst.strides[1].value == 1
         assert src.strides[0] == l
         assert tuple(src.shape) == (m, l)
-        return tvm.make.Evaluate(0)
+        return tvm.tir.Evaluate(0)
     stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
 
 def test_copy_pad():
@@ -57,7 +57,7 @@ def test_copy_pad():
         assert pad_after[0].value == 1
         assert pad_after[1].value == 0
         assert pad_value.value == 1.0
-        return tvm.make.Evaluate(0)
+        return tvm.tir.Evaluate(0)
     stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
 
 def test_single_point_test():
@@ -76,7 +76,7 @@ def test_single_point_test():
         assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0
         assert tvm.ir_pass.Simplify(src.strides[0]).value == 1
         assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1
-        return tvm.make.Evaluate(0)
+        return tvm.tir.Evaluate(0)
     stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
 
 def assert_expr_equal(a, b):
@@ -109,7 +109,7 @@ def test_copy_pad_split():
         assert_expr_equal(pad_before[0], rpad_before)
         assert_expr_equal(pad_after[0], rpad_after)
         assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
-        return tvm.make.Evaluate(0)
+        return tvm.tir.Evaluate(0)
     stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
 
 
index aa569ce..cf8f78c 100644 (file)
@@ -37,13 +37,13 @@ def test_double_buffer():
     stmt = ib.get()
     stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert isinstance(stmt.body.body, tvm.stmt.Allocate)
+    assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
     f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
     f = tvm.ir_pass.ThreadSync(f, "shared")
     count = [0]
     def count_sync(op):
-        if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
+        if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
             count[0] += 1
     tvm.ir_pass.PostOrderVisit(f.body, count_sync)
     assert count[0] == 4
index e87353e..521a6f9 100644 (file)
@@ -20,7 +20,7 @@ def test_inline():
     m = tvm.size_var('m')
     A = tvm.placeholder((m,), name='A')
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
-    stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
+    stmt = tvm.tir.Evaluate(T[10] + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
         stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     print(stmt)
@@ -39,11 +39,11 @@ def test_inline2():
     m = tvm.size_var('m')
     A = tvm.placeholder((m,), name='A')
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
-    stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
+    stmt = tvm.tir.Evaluate(tvm.exp(T[10]) + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
         stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     def check(op):
-        if isinstance(op, tvm.expr.Call):
+        if isinstance(op, tvm.tir.Call):
             assert op.func != T.op
     tvm.ir_pass.PostOrderVisit(stmt, check)
 
index 098e0d7..b024a3c 100644 (file)
@@ -32,12 +32,12 @@ def test_ir_transform():
         return None
 
     def postorder(op):
-        assert isinstance(op, tvm.expr.Call)
+        assert isinstance(op, tvm.tir.Call)
         if op.name == "TestA":
             return tvm.call_extern("int32", "TestB", op.args[0] + 1)
         return op
     body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"])
-    stmt_list = tvm.make.stmt_list(body.body.body)
+    stmt_list = tvm.tir.stmt_list(body.body.body)
     assert stmt_list[0].value.args[0].name == "TestB"
     assert stmt_list[1].value.value == 0
 
index b281e17..181f4ef 100644 (file)
@@ -20,7 +20,7 @@ def test_coproc_lift():
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
     cp = tvm.thread_axis((0, 1), "cop")
-    value = tvm.make.StringImm("xxx")
+    value = tvm.tir.StringImm("xxx")
 
     A = ib.allocate("float32", n, name="A", scope="global")
     with ib.for_range(0, n, name="i") as i:
index 9812660..e9df98e 100644 (file)
@@ -24,7 +24,7 @@ def collect_visit(stmt, f):
 
 def find_top_produce(stmt):
     def f(x, ret):
-        if isinstance(x, tvm.stmt.ProducerConsumer):
+        if isinstance(x, tvm.tir.ProducerConsumer):
             ret.append(x)
     ret = []
     tvm.ir_pass.PostOrderVisit(stmt, lambda x : f(x, ret))
@@ -90,13 +90,13 @@ def test_multi_loop():
         with ib.for_range(0, n, "j") as j:
             with ib.for_range(0, m, "k") as k:
                 with ib.if_scope(ib.likely(i*m+j+k < n)):
-                    ib.emit(tvm.make.Evaluate(m))
+                    ib.emit(tvm.tir.Evaluate(m))
                 with ib.else_scope():
-                    ib.emit(tvm.make.Evaluate(n))
+                    ib.emit(tvm.tir.Evaluate(n))
     stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+    assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_multi_if():
     ib = tvm.ir_builder.create()
@@ -106,13 +106,13 @@ def test_multi_if():
         with ib.for_range(0, n, 'j') as j:
             with ib.for_range(0, m, 'k') as k:
                 with ib.if_scope(ib.likely(i*m+j+k < n)):
-                    ib.emit(tvm.make.Evaluate(m))
+                    ib.emit(tvm.tir.Evaluate(m))
                 with ib.else_scope():
-                    ib.emit(tvm.make.Evaluate(n))
+                    ib.emit(tvm.tir.Evaluate(n))
                 with ib.if_scope(ib.likely(i*m+j-k < n)):
-                    ib.emit(tvm.make.Evaluate(m))
+                    ib.emit(tvm.tir.Evaluate(m))
                 with ib.else_scope():
-                    ib.emit(tvm.make.Evaluate(n))
+                    ib.emit(tvm.tir.Evaluate(n))
     stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.Simplify(stmt)
@@ -157,7 +157,7 @@ def test_vectorize():
     stmt = lower(s, [A, B])
     body = stmt.body.body.body.body.body
     assert(x.var.name not in str(body.condition))
-    assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp))))
+    assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))))
 
 def test_condition():
     ib = tvm.ir_builder.create()
@@ -165,24 +165,24 @@ def test_condition():
     n = tvm.size_var('n')
     with ib.for_range(0, tvm.truncdiv(n+3,4), 'i') as i:
       with ib.for_range(0, 4, 'j') as j:
-        ib.emit(tvm.make.Evaluate(
-          tvm.make.Select(ib.likely(i*4+j<n), m, n)))
+        ib.emit(tvm.tir.Evaluate(
+          tvm.tir.Select(ib.likely(i*4+j<n), m, n)))
     stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.expr.Select))))
+    assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
 
 def test_condition_EQ():
     ib = tvm.ir_builder.create()
     m = tvm.size_var('m')
     n = tvm.size_var('n')
     with ib.for_range(0, 10, 'i') as i:
-            ib.emit(tvm.make.Evaluate(
-                tvm.make.Select(ib.likely(tvm.expr.EQ(i, 5)), m, n)))
+            ib.emit(tvm.tir.Evaluate(
+                tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n)))
     stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt, True)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.expr.Select))))
+    assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))
 
 def test_thread_axis2():
     n = tvm.convert(4096)
@@ -209,11 +209,11 @@ def test_everything_during_deduction():
         with ib.for_range(0, 32, 'j') as j:
             with ib.if_scope(ib.likely(tvm.truncdiv(i,j) < m)):
                 # this guard will produce everything during deduction
-                ib.emit(tvm.make.Evaluate(m))
+                ib.emit(tvm.tir.Evaluate(m))
     stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse))
+    assert(isinstance(stmt.body.body, tvm.tir.IfThenElse))
 
 def test_single_likely():
     n = 60
@@ -229,7 +229,7 @@ def test_single_likely():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.LoopPartition(stmt, True)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_multi_likely():
     n = 94
@@ -250,7 +250,7 @@ def test_multi_likely():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.LoopPartition(stmt, True)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_oneD_pool():
     m = tvm.size_var('m')
@@ -277,7 +277,7 @@ def test_oneD_pool():
     stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt, True)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_cce_loop_1():
   ib = tvm.ir_builder.create()
@@ -298,7 +298,7 @@ def test_cce_loop_1():
   stmt = ib.get()
   stmt = tvm.ir_pass.LoopPartition(stmt, True)
   stmt = tvm.ir_pass.Simplify(stmt)
-  assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+  assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_cce_loop_2():
   ib = tvm.ir_builder.create()
@@ -317,7 +317,7 @@ def test_cce_loop_2():
   stmt = ib.get()
   stmt = tvm.ir_pass.LoopPartition(stmt, True)
   stmt = tvm.ir_pass.Simplify(stmt)
-  assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+  assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 
 def test_cce_loop_3():
@@ -335,7 +335,7 @@ def test_cce_loop_3():
     stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt,True)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 def test_conv_tiling():
     HSTR = WSTR = 1
@@ -364,7 +364,7 @@ def test_conv_tiling():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.LoopPartition(stmt, True)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
 
 def test_multilevel_splitting_with_indivisble_factors():
@@ -382,7 +382,7 @@ def test_multilevel_splitting_with_indivisble_factors():
     with tvm.build_config(partition_const_loop=True):
         lowered_body = tvm.lower(s, [A, B]).body
         def visit_stmt(op):
-            return(isinstance(op, tvm.expr.Max))
+            return(isinstance(op, tvm.tir.Max))
         num_max = collect_visit(lowered_body, visit_stmt)
         assert num_max.count(True) == 10
 
@@ -407,7 +407,7 @@ def test_double_splitting_with_indivisible_factors():
     # Find the beginning of the Halide IR corresponding to kernel code
     # and make sure it doesn't have an if statements left
     top_produce = find_top_produce(f.body)
-    assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+    assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
     # check functional correctness of generated code
     ctx = tvm.context(target, 0)
index e8bc6b1..1e54f38 100644 (file)
@@ -19,8 +19,8 @@ import numpy as np
 
 def lower_intrin(stmt):
     """wrapper to call transformation in stmt"""
-    lower_expr = isinstance(stmt, tvm.expr.PrimExpr)
-    stmt = tvm.stmt.Evaluate(stmt) if lower_expr else stmt
+    lower_expr = isinstance(stmt, tvm.tir.PrimExpr)
+    stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt  = tvm.ir_pass._LowerIntrinStmt(stmt, "llvm")
     return stmt.value if lower_expr else stmt.body
@@ -33,8 +33,8 @@ def check_value(expr, vx, vy, data, fref):
 
     def make_binds(i):
         x = expr
-        x = tvm.expr.Let(vx, A[i], x)
-        x = tvm.expr.Let(vy, B[i], x)
+        x = tvm.tir.Let(vx, A[i], x)
+        x = tvm.tir.Let(vy, B[i], x)
         return x
 
     C = tvm.compute((n,), make_binds)
@@ -72,13 +72,13 @@ def test_lower_floordiv():
         res = lower_intrin(tvm.floordiv(x, y))
         check_value(res, x, y, data, lambda a, b: a // b)
         # rhs >= 0
-        res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floordiv(x, y), zero))
+        res = lower_intrin(tvm.tir.Select(y >= 0, tvm.floordiv(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
         # involves max
-        res = lower_intrin(tvm.expr.Select(y >= 0, tvm.max(tvm.floordiv(x, y), zero), zero))
+        res = lower_intrin(tvm.tir.Select(y >= 0, tvm.max(tvm.floordiv(x, y), zero), zero))
         check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0)
         # lhs >= 0
-        res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floordiv(x, y), zero))
+        res = lower_intrin(tvm.tir.Select(tvm.all(y >= 0, x >= 0), tvm.floordiv(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
         # const power of two
         res = lower_intrin(tvm.floordiv(x, tvm.const(8, dtype=dtype)))
@@ -95,10 +95,10 @@ def test_lower_floormod():
         res = lower_intrin(tvm.floormod(x, y))
         check_value(res, x, y, data, lambda a, b: a % b)
         # rhs >= 0
-        res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floormod(x, y), zero))
+        res = lower_intrin(tvm.tir.Select(y >= 0, tvm.floormod(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
         # lhs >= 0
-        res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floormod(x, y), zero))
+        res = lower_intrin(tvm.tir.Select(tvm.all(y >= 0, x >= 0), tvm.floormod(x, y), zero))
         check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
         # const power of two
         res = lower_intrin(tvm.floormod(x, tvm.const(8, dtype=dtype)))
index d287b85..a3927f7 100644 (file)
@@ -17,7 +17,7 @@
 import tvm
 
 def nop():
-    return tvm.stmt.Evaluate(0)
+    return tvm.tir.Evaluate(0)
 
 def test_remove_no_op():
     i = tvm.var('i')
@@ -27,25 +27,25 @@ def test_remove_no_op():
     n = tvm.var('n')
     dtype = 'int64'
     Ab = tvm.decl_buffer((n, ), dtype)
-    stmt = tvm.make.For(
+    stmt = tvm.tir.For(
         i, 0, 4, 0, 0,
-        tvm.make.For(
+        tvm.tir.For(
             j, 0, n, 0, 0,
-            tvm.make.For(
+            tvm.tir.For(
                 k, 0, m, 0, 0,
-                tvm.make.IfThenElse(
-                    (i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)))))
+                tvm.tir.IfThenElse(
+                    (i*m+j+k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)))))
     ret = tvm.ir_pass.RemoveNoOp(stmt)
-    assert(isinstance(ret, tvm.stmt.Evaluate))
-    store = tvm.make.Store(Ab.data,
-                           tvm.make.Load(dtype, Ab.data, i) + 1,
+    assert(isinstance(ret, tvm.tir.Evaluate))
+    store = tvm.tir.Store(Ab.data,
+                           tvm.tir.Load(dtype, Ab.data, i) + 1,
                            i + 1)
-    stmt2 = tvm.stmt.SeqStmt([nop(), tvm.stmt.SeqStmt([store, nop()])])
+    stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])])
     assert(tvm.ir_pass.RemoveNoOp(stmt2) == store)
     # remove zero extent loop
-    stmt3 = tvm.make.For(i, 0, 0, 0, 0, store)
+    stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store)
     ret = tvm.ir_pass.RemoveNoOp(stmt3)
-    assert(isinstance(ret, tvm.stmt.Evaluate))
+    assert(isinstance(ret, tvm.tir.Evaluate))
 
 
 if __name__ == "__main__":
index 4c42899..dc6ae82 100644 (file)
@@ -21,18 +21,18 @@ def test_rewrite_Select():
     ib = tvm.ir_builder.create()
     A = ib.allocate("float32", 100, name="A", scope="global")
     i = tvm.var("i")
-    y = tvm.expr.Select(i > 1, A[i-1], 1.0)
-    yy = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(y)).value
+    y = tvm.tir.Select(i > 1, A[i-1], 1.0)
+    yy = tvm.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(y)).value
 
-    z = tvm.expr.Select(
-        tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
-    zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value
+    z = tvm.tir.Select(
+        tvm.tir.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
+    zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(z)).value
 
-    a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z)
-    aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
+    a = tvm.tir.Select(tvm.floordiv(i, 4) > 10, y, z)
+    aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(a)).value
     assert yy.name == "tvm_if_then_else"
     assert zz.name == "tvm_if_then_else"
-    assert isinstance(aa, tvm.expr.Select)
+    assert isinstance(aa, tvm.tir.Select)
 
 
 if __name__ == "__main__":
index 2bee66c..47a43c7 100644 (file)
@@ -40,12 +40,12 @@ def test_flatten_prefetch():
     _A= tvm.decl_buffer(A.shape, A.dtype, name = 'A');
     i = tvm.size_var('i')
     j = tvm.size_var('j')
-    region = [tvm.make.range_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]]
-    stmt = tvm.make.Prefetch(A.op, 0, A.dtype, region)
+    region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]]
+    stmt = tvm.tir.Prefetch(A.op, 0, A.dtype, region)
     stmt = tvm.ir_pass.StorageFlatten(stmt, {A: _A}, 64)
     stmt = tvm.ir_pass.Simplify(stmt)
     assert stmt.extent.value == 2
-    assert isinstance(stmt.body, tvm.stmt.For)
+    assert isinstance(stmt.body, tvm.tir.For)
     assert stmt.body.extent.value == 2
 
 
@@ -89,13 +89,13 @@ def test_flatten_double_buffer():
     stmt = tvm.ir_pass.StorageFlatten(stmt, {}, 64)
     stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2)
     stmt = tvm.ir_pass.Simplify(stmt)
-    assert isinstance(stmt.body.body, tvm.stmt.Allocate)
+    assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
     f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
     f = tvm.ir_pass.ThreadSync(f, "shared")
     count = [0]
     def count_sync(op):
-        if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
+        if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
             count[0] += 1
     tvm.ir_pass.PostOrderVisit(f.body, count_sync)
     assert count[0] == 4
index 6fd6f8b..d4125d0 100644 (file)
@@ -39,7 +39,7 @@ def test_storage_share():
     # verify inplace folding works
     num_alloc = [0]
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
     tvm.ir_pass.PostOrderVisit(stmt, verify)
     assert num_alloc[0] == 1
@@ -48,7 +48,7 @@ def register_mem(scope_tb, max_bits):
     #Register mem
     @tvm.register_func("tvm.info.mem.%s" % scope_tb)
     def mem_info_inp_buffer():
-        return tvm.make.node("MemoryInfo",
+        return tvm.ir.make_node("MemoryInfo",
                         unit_bits= 16,
                         max_simd_bits=32,
                         max_num_bits=max_bits,
@@ -74,7 +74,7 @@ def test_alloc_seq():
     body = tvm.ir_pass.StorageRewrite(body)
     num_alloc = [0]
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
             assert n.extents[0].value == 200
     tvm.ir_pass.PostOrderVisit(body, verify)
@@ -123,7 +123,7 @@ def test_alloc_different_dtypes():
 
     def dtype_test(dtype_list, length):
         def verify(n):
-            if isinstance(n, tvm.stmt.Allocate):
+            if isinstance(n, tvm.tir.Allocate):
                 assert n.extents[0].value == offset
 
         body = stmt_generater(dtype_list, length)
@@ -166,7 +166,7 @@ def test_inplace_rule():
     # verify inplace folding works
     num_alloc = [0]
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
     tvm.ir_pass.PostOrderVisit(stmt, verify)
     assert num_alloc[0] == 2
@@ -196,7 +196,7 @@ def test_storage_combine():
     stmt = tvm.ir_pass.StorageRewrite(stmt)
     num_alloc = [0]
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
             assert (n.extents[0].value == 16)
     tvm.ir_pass.PostOrderVisit(stmt, verify)
@@ -231,7 +231,7 @@ def test_storage_share_gpu():
     alloc_stats = {"global": 0, "shared": 0}
 
     def verify(n):
-        if isinstance(n, tvm.stmt.AttrStmt):
+        if isinstance(n, tvm.tir.AttrStmt):
             if n.attr_key == "storage_scope":
                 alloc_stats[n.value.value] += 1
     tvm.ir_pass.PostOrderVisit(stmt, verify)
@@ -248,14 +248,14 @@ def test_parallel_alloc():
 
     body = ib.get()
     body = tvm.ir_pass.StorageRewrite(body)
-    assert (isinstance(body.body.body, tvm.stmt.Allocate))
+    assert (isinstance(body.body.body, tvm.tir.Allocate))
 
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
     with ib.for_range(0, n, name="t") as i:
         ib.scope_attr(
             tvm.const(1, "int32") , "pragma_scope",
-            tvm.make.StringImm("parallel_launch_point"))
+            tvm.tir.StringImm("parallel_launch_point"))
         with ib.for_range(0, n, name="i", for_type="parallel") as i:
             with ib.for_range(0, 10, name="j") as j:
                 A = ib.allocate("float32", n, name="A", scope="global")
@@ -263,7 +263,7 @@ def test_parallel_alloc():
     body = ib.get()
     body = tvm.ir_pass.StorageRewrite(body)
 
-    assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))
+    assert(isinstance(body.body.body.body.body, tvm.tir.Allocate))
 
 def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
     #Test Buffer
@@ -295,7 +295,7 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
     # verify inplace folding works
     num_alloc = [0]
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
     tvm.ir_pass.PostOrderVisit(stmt, verify)
     assert num_alloc[0] == 2
@@ -387,7 +387,7 @@ def test_inplace_rule3():
     # verify only have one allocations.
     # verify inplace folding works
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             assert n.extents[0].value == 70
     tvm.ir_pass.PostOrderVisit(stmt, verify)
 
@@ -413,7 +413,7 @@ def test_alloc_seq_type():
     body = tvm.ir_pass.StorageRewrite(body)
     num_alloc = [0]
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
             assert n.extents[0].value == 500
     tvm.ir_pass.PostOrderVisit(body, verify)
@@ -442,7 +442,7 @@ def test_alloc_seq_type2():
     body = tvm.ir_pass.StorageRewrite(body)
     num_alloc = [0]
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
             assert n.extents[0].value == 200
     tvm.ir_pass.PostOrderVisit(body, verify)
@@ -473,7 +473,7 @@ def test_reuse_small_buffer():
     num_alloc = [0]
 
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             num_alloc[0] += 1
             assert n.extents[0].value == 800
     tvm.ir_pass.PostOrderVisit(body, verify)
@@ -512,7 +512,7 @@ def test_large_input():
     s = tvm.create_schedule(c.op)
     stmt = tvm.lower(s, [a, b, c], simple_mode=True)
     def verify(n):
-        if isinstance(n, tvm.stmt.Allocate):
+        if isinstance(n, tvm.tir.Allocate):
             assert n.extents[0].value == 268435456
     tvm.ir_pass.PostOrderVisit(stmt, verify)
 
index 55596ee..0ed0c99 100644 (file)
@@ -40,14 +40,14 @@ def test_storage_sync():
     flist = tvm.ir_pass.SplitHostDevice(f)
     f = flist[1]
     f = tvm.ir_pass.ThreadSync(f, "shared")
-    body_list = tvm.make.stmt_list(f.body.body.body.body)
+    body_list = tvm.tir.stmt_list(f.body.body.body.body)
     assert(body_list[1].value.name == "tvm_storage_sync")
 
 
 def test_coproc_sync():
     @tvm.register_func("tvm.info.mem.global.cache")
     def meminfo_cache():
-        return tvm.make.node(
+        return tvm.ir.make_node(
             "MemoryInfo",
             unit_bits=8,
             max_simd_bits=32,
@@ -66,7 +66,7 @@ def test_coproc_sync():
     stmt = ib.get()
     stmt = tvm.ir_pass.CoProcSync(stmt)
     body = stmt.body.body.body
-    blist = tvm.make.stmt_list(body)
+    blist = tvm.tir.stmt_list(body)
     assert(blist[1].value.name == "cop.coproc_read_barrier")
     assert(blist[1].value.args[3].value == 80)
     assert(blist[-2].value.name == "cop.coproc_sync")
@@ -119,9 +119,9 @@ def test_coproc_sync3():
 
     stmt = ib.get()
     stmt = tvm.ir_pass.CoProcSync(stmt)
-    slist = tvm.make.stmt_list(stmt[0].body.body)
+    slist = tvm.tir.stmt_list(stmt[0].body.body)
     push_st = slist[2]
-    slist = tvm.make.stmt_list(slist[-1])
+    slist = tvm.tir.stmt_list(slist[-1])
     pop_st = slist[0].body[0]
 
     assert(push_st.value.name == "cop.coproc_dep_push")
index e5ef9d0..c6b536b 100644 (file)
@@ -30,26 +30,26 @@ def test_unroll_loop():
             Aptr[j + 1] = Aptr[i] + 1
 
     stmt = ib.get()
-    assert isinstance(stmt, tvm.stmt.For)
+    assert isinstance(stmt, tvm.tir.For)
     ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, True)
-    assert not isinstance(ret, tvm.stmt.For)
+    assert not isinstance(ret, tvm.tir.For)
     ret = tvm.ir_pass.UnrollLoop(stmt, 15, 8, 0, True)
-    assert isinstance(ret, tvm.stmt.For)
+    assert isinstance(ret, tvm.tir.For)
     ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, False)
-    assert isinstance(ret, tvm.stmt.For)
-    assert ret.for_type == tvm.stmt.For.Unrolled
+    assert isinstance(ret, tvm.tir.For)
+    assert ret.for_type == tvm.tir.For.Unrolled
 
     ib = tvm.ir_builder.create()
     ib.scope_attr(tvm.const(0, "int32"), "pragma_auto_unroll_max_step", 16)
     ib.emit(stmt)
     wrapped = ib.get()
-    wrapped = tvm.stmt.SeqStmt([wrapped, stmt])
-    assert isinstance(ret, tvm.stmt.For)
+    wrapped = tvm.tir.SeqStmt([wrapped, stmt])
+    assert isinstance(ret, tvm.tir.For)
     ret = tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
-    assert isinstance(ret[0], tvm.stmt.For)
-    assert ret[0].for_type == tvm.stmt.For.Unrolled
-    assert isinstance(ret[1], tvm.stmt.For)
-    assert ret[1].for_type != tvm.stmt.For.Unrolled
+    assert isinstance(ret[0], tvm.tir.For)
+    assert ret[0].for_type == tvm.tir.For.Unrolled
+    assert isinstance(ret[1], tvm.tir.For)
+    assert ret[1].for_type != tvm.tir.For.Unrolled
 
 def test_unroll_fake_loop():
     ib = tvm.ir_builder.create()
@@ -65,7 +65,7 @@ def test_unroll_fake_loop():
 
     stmt = ib.get()
     ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
-    assert isinstance(ret[0], tvm.stmt.Store)
+    assert isinstance(ret[0], tvm.tir.Store)
 
 def test_unroll_single_count_loops():
     n = tvm.size_var('n')
index fca22a1..d1cd2d4 100644 (file)
@@ -26,12 +26,12 @@ def test_vectorize_loop():
             A[j] = tvm.const(1, A.dtype)
     stmt = ib.get()
 
-    assert isinstance(stmt.body, tvm.stmt.For)
+    assert isinstance(stmt.body, tvm.tir.For)
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
-    assert isinstance(stmt, tvm.stmt.For)
-    assert not isinstance(stmt.body, tvm.stmt.For)
-    assert isinstance(stmt.body.index, tvm.expr.Ramp)
-    assert isinstance(stmt.body.value, tvm.expr.Broadcast)
+    assert isinstance(stmt, tvm.tir.For)
+    assert not isinstance(stmt.body, tvm.tir.For)
+    assert isinstance(stmt.body.index, tvm.tir.Ramp)
+    assert isinstance(stmt.body.value, tvm.tir.Broadcast)
 
 def test_vectorize_vector():
     dtype = 'int64'
@@ -42,12 +42,12 @@ def test_vectorize_vector():
         with ib.for_range(0, 4, for_type="vectorize") as j:
             A[j] = tvm.const(1, A.dtype)
     stmt = ib.get()
-    assert isinstance(stmt.body, tvm.stmt.For)
+    assert isinstance(stmt.body, tvm.tir.For)
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
-    assert isinstance(stmt, tvm.stmt.For)
-    assert not isinstance(stmt.body, tvm.stmt.For)
-    assert isinstance(stmt.body.index, tvm.expr.Ramp)
-    assert isinstance(stmt.body.value, tvm.expr.Broadcast)
+    assert isinstance(stmt, tvm.tir.For)
+    assert not isinstance(stmt.body, tvm.tir.For)
+    assert isinstance(stmt.body.index, tvm.tir.Ramp)
+    assert isinstance(stmt.body.value, tvm.tir.Broadcast)
 
 
 def test_vectorize_with_if():
@@ -63,11 +63,11 @@ def test_vectorize_with_if():
                 A[i] = 2.0
     stmt = ib.get()
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
-    assert isinstance(stmt, tvm.stmt.IfThenElse)
-    assert isinstance(stmt.then_case.index, tvm.expr.Ramp)
-    assert isinstance(stmt.then_case.value, tvm.expr.Add)
+    assert isinstance(stmt, tvm.tir.IfThenElse)
+    assert isinstance(stmt.then_case.index, tvm.tir.Ramp)
+    assert isinstance(stmt.then_case.value, tvm.tir.Add)
     assert stmt.then_case.value.dtype == "float32x4"
-    assert isinstance(stmt.else_case, tvm.stmt.For)
+    assert isinstance(stmt.else_case, tvm.tir.For)
 
 def test_vectorize_with_le_cond():
     n = tvm.var('n')
@@ -78,7 +78,7 @@ def test_vectorize_with_le_cond():
             A[i] = A[i] + 1
     stmt = ib.get()
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
-    assert isinstance(stmt, tvm.stmt.For)
+    assert isinstance(stmt, tvm.tir.For)
 
 def test_vectorize_with_ge_cond():
     n = tvm.var('n')
@@ -89,7 +89,7 @@ def test_vectorize_with_ge_cond():
             A[i] = A[i] + 1
     stmt = ib.get()
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
-    assert isinstance(stmt, tvm.stmt.For)
+    assert isinstance(stmt, tvm.tir.For)
 
 def test_vectorize_if_then_else():
     n = tvm.var('n')
@@ -102,7 +102,7 @@ def test_vectorize_if_then_else():
                                A[i] + 1, A[i])
     stmt = ib.get()
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
-    assert isinstance(stmt, tvm.stmt.For)
+    assert isinstance(stmt, tvm.tir.For)
 
 
     ib = tvm.ir_builder.create()
@@ -113,10 +113,10 @@ def test_vectorize_if_then_else():
                                            k > 0,
                                            A[k * 4 + i], 0)
     stmt = ib.get()
-    assert isinstance(stmt.body, tvm.stmt.For)
+    assert isinstance(stmt.body, tvm.tir.For)
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
-    assert not isinstance(stmt.body, tvm.stmt.For)
-    assert isinstance(stmt.body.value.args[2], tvm.expr.Broadcast)
+    assert not isinstance(stmt.body, tvm.tir.For)
+    assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast)
 
 
 if __name__ == "__main__":
index b1a784b..1cbc157 100644 (file)
@@ -50,10 +50,10 @@ def test_dso_module_load():
         Ab = tvm.decl_buffer((n, ), dtype)
         i = tvm.var('i')
         # for i in 0 to n-1:
-        stmt = tvm.make.For(
+        stmt = tvm.tir.For(
             i, 0, n - 1, 0, 0,
-            tvm.make.Store(Ab.data,
-                           tvm.make.Load(dtype, Ab.data, i) + 1,
+            tvm.tir.Store(Ab.data,
+                           tvm.tir.Load(dtype, Ab.data, i) + 1,
                            i + 1))
         fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
         fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
index 1030a99..2fc84bb 100644 (file)
@@ -79,8 +79,8 @@ def test_schedule_scan():
 
 def test_inline_multi_reduce():
     def argmax_comp(x, y):
-        idx = tvm.expr.Select((x[1] >= y[1]), x[0], y[0])
-        val = tvm.expr.Select((x[1] >= y[1]), x[1], y[1])
+        idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
+        val = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
         return idx, val
     def argmax_init(idx_typ, val_typ):
         return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
@@ -142,7 +142,7 @@ def test_inline_mixed():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     def check(x):
-        if isinstance(x, tvm.expr.Call):
+        if isinstance(x, tvm.tir.Call):
             assert x.func != A2
     tvm.ir_pass.PostOrderVisit(s[C].op.body[0], check)
 
@@ -426,7 +426,7 @@ def test_loop_dep_reduce_cache_write():
     X = tvm.placeholder(shape=(10,), name="x")
     def f(n):
         rv = tvm.reduce_axis((0, n))
-        init = lambda dtype: tvm.expr.Select(n > 1, tvm.const(0, dtype), n.astype(dtype))
+        init = lambda dtype: tvm.tir.Select(n > 1, tvm.const(0, dtype), n.astype(dtype))
         sum = tvm.comm_reducer(lambda x, y: tvm.max(x + y, n.astype('float32')), init, name='sum')
         return sum(X[rv], axis=rv)
     Y = tvm.compute(X.shape, f, name="y")
index 59adf0c..ac60c2d 100644 (file)
@@ -321,7 +321,7 @@ def test_tensorize_tensor_compute_op():
     stmt = tvm.schedule.ScheduleOps(s, dom_map)
     # The loop that we tried to tensorize still exists in the code
     # That means tensorize didn't work as expected
-    assert isinstance(stmt.body.body.body, tvm.stmt.For)
+    assert isinstance(stmt.body.body.body, tvm.tir.For)
     assert stmt.body.body.body.loop_var.name == C.op.axis[0].var.name
 
 
index 84c9da3..c3e1a10 100644 (file)
@@ -429,7 +429,9 @@ def cast(x, dtype):
     if isinstance(x, tvm.tensor.Tensor):
         return tvm.compute(
             x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
-    return tvm.make._cast(dtype, x)
+    # pylint: disable=import-outside-toplevel
+    from tvm.tir import _ffi_api
+    return _ffi_api._cast(dtype, x)
 
 
 def reinterpret(x, dtype):
index 702f551..53ffe35 100644 (file)
@@ -85,7 +85,7 @@ def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
         min_value = lambda dtype: tvm.if_then_else(non_empty, tvm.min_value(dtype),
                                                    tvm.const(0.0, dtype))
         # pylint: disable=unnecessary-lambda
-        _max = tvm.comm_reducer(lambda x, y: tvm.make._OpMax(x, y), min_value, name='max')
+        _max = tvm.comm_reducer(lambda x, y: tvm.max(x, y), min_value, name='max')
         rh = tvm.reduce_axis((0, hend - hstart), 'rh')
         rw = tvm.reduce_axis((0, wend - wstart), 'rw')
         return _max(data[batch_index, c, hstart+rh, wstart+rw], axis=[rh, rw])
index eb672bb..97c4a1f 100644 (file)
@@ -86,8 +86,8 @@ print(ir)
 loops = []
 def find_width8(op):
     """ Find all the 'For' nodes whose extent can be divided by 8. """
-    if isinstance(op, tvm.stmt.For):
-        if isinstance(op.extent, tvm.expr.IntImm):
+    if isinstance(op, tvm.tir.For):
+        if isinstance(op.extent, tvm.tir.IntImm):
             if op.extent.value % 8 == 0:
                 loops.append(op)
 
@@ -113,8 +113,8 @@ def vectorize8(op):
         name = op.loop_var.name
         lo, li = tvm.var(name + '.outer'), tvm.var(name + '.inner')
         body = tvm.ir_pass.Substitute(op.body, {op.loop_var: lo * 8 + li})
-        body = tvm.make.For(li, 0, 8, tvm.stmt.For.Vectorized, 0, body)
-        body = tvm.make.For(lo, 0, extent // 8, tvm.stmt.For.Serial, 0, body)
+        body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body)
+        body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body)
         return body
     return None
 
index f338e70..c1af984 100644 (file)
@@ -96,7 +96,7 @@ print(fopencl.imported_modules[0].get_source())
 #
 def my_cuda_math_rule(op):
     """Customized CUDA intrinsic lowering rule"""
-    assert isinstance(op, tvm.expr.Call)
+    assert isinstance(op, tvm.tir.Call)
     if op.dtype == "float32":
         # call float function
         return tvm.call_pure_extern("float32", "%sf" % op.name, op.args[0])
@@ -106,7 +106,7 @@ def my_cuda_math_rule(op):
     else:
         # cannot do translation, return self.
         return op
-tvm.register_intrin_rule("cuda", "exp", my_cuda_math_rule, override=True)
+tvm.target.register_intrin_rule("cuda", "exp", my_cuda_math_rule, override=True)
 ######################################################################
 # Register the rule to TVM with override option to override existing rule.
 # Notice the difference between the printed code from previous one:
@@ -135,7 +135,7 @@ def my_cuda_mylog_rule(op):
         return tvm.call_pure_extern("float64", "log", op.args[0])
     else:
         return op
-tvm.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
+tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
 
 n = tvm.var("n")
 A = tvm.placeholder((n,), name='A')
index 8fb8083..0c5c85c 100644 (file)
@@ -61,8 +61,8 @@ print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
 # x and y are the operands of reduction, both of them is a tuple of index
 # and value.
 def fcombine(x, y):
-    lhs = tvm.expr.Select((x[1] >= y[1]), x[0], y[0])
-    rhs = tvm.expr.Select((x[1] >= y[1]), x[1], y[1])
+    lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
+    rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
     return lhs, rhs
 
 # our identity element also need to be a tuple, so `fidentity` accepts
index df67faa..f368362 100644 (file)
@@ -68,7 +68,7 @@ def build_config(debug_flag=0, **kwargs):
             env.dev.command_handle,
             debug_flag)
 
-        return tvm.make.stmt_seq(debug, stmt)
+        return tvm.tir.stmt_seq(debug, stmt)
     pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
                  (1, ir_pass.inject_dma_intrin),
                  (1, ir_pass.inject_skip_copy),
index 83db612..8d58958 100644 (file)
@@ -62,11 +62,11 @@ class DevContext(object):
 
     def __init__(self, env):
         self.vta_axis = tvm.thread_axis("vta")
-        self.vta_push_uop = tvm.make.StringImm("VTAPushGEMMOp")
+        self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp")
         ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
-        self.command_handle = tvm.make.Call(
+        self.command_handle = tvm.tir.Call(
             "handle", "tvm_thread_context", [ctx],
-            tvm.expr.Call.Intrinsic, None, 0)
+            tvm.tir.Call.Intrinsic, None, 0)
         self.DEBUG_NO_SYNC = False
         env._dev_ctx = self
         self.gemm = intrin.gemm(env, env.mock_mode)
@@ -256,29 +256,29 @@ def get_env():
 @tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
 def mem_info_inp_buffer():
     spec = get_env()
-    return tvm.make.node("MemoryInfo",
-                         unit_bits=spec.INP_ELEM_BITS,
-                         max_simd_bits=spec.INP_ELEM_BITS,
-                         max_num_bits=spec.INP_BUFF_SIZE * 8,
-                         head_address=None)
+    return tvm.ir.make_node("MemoryInfo",
+                            unit_bits=spec.INP_ELEM_BITS,
+                            max_simd_bits=spec.INP_ELEM_BITS,
+                            max_num_bits=spec.INP_BUFF_SIZE * 8,
+                            head_address=None)
 
 @tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
 def mem_info_wgt_buffer():
     spec = get_env()
-    return tvm.make.node("MemoryInfo",
-                         unit_bits=spec.WGT_ELEM_BITS,
-                         max_simd_bits=spec.WGT_ELEM_BITS,
-                         max_num_bits=spec.WGT_BUFF_SIZE * 8,
-                         head_address=None)
+    return tvm.ir.make_node("MemoryInfo",
+                            unit_bits=spec.WGT_ELEM_BITS,
+                            max_simd_bits=spec.WGT_ELEM_BITS,
+                            max_num_bits=spec.WGT_BUFF_SIZE * 8,
+                            head_address=None)
 
 @tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
 def mem_info_acc_buffer():
     spec = get_env()
-    return tvm.make.node("MemoryInfo",
-                         unit_bits=spec.ACC_ELEM_BITS,
-                         max_simd_bits=spec.ACC_ELEM_BITS,
-                         max_num_bits=spec.ACC_BUFF_SIZE * 8,
-                         head_address=None)
+    return tvm.ir.make_node("MemoryInfo",
+                            unit_bits=spec.ACC_ELEM_BITS,
+                            max_simd_bits=spec.ACC_ELEM_BITS,
+                            max_num_bits=spec.ACC_BUFF_SIZE * 8,
+                            head_address=None)
 
 # TVM related registration
 @tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
index 77c7ff2..a43fc75 100644 (file)
@@ -98,7 +98,7 @@ def gemm(env, mock=False):
                     0, 0, 0))
             return irb.get()
         # return a triple of normal-set, reset, update
-        nop = tvm.make.Evaluate(0)
+        nop = tvm.tir.Evaluate(0)
         if mock:
             return (nop, nop, nop)
         return (instr(0), instr(1), instr(2))
index e42e3a0..8b8a2f0 100644 (file)
@@ -59,8 +59,8 @@ def fold_uop_loop(stmt_in):
 
     def _fold_outermost_loop(body):
         stmt = body
-        while not isinstance(stmt, tvm.stmt.For):
-            if isinstance(stmt, (tvm.stmt.ProducerConsumer,)):
+        while not isinstance(stmt, tvm.tir.For):
+            if isinstance(stmt, (tvm.tir.ProducerConsumer,)):
                 stmt = stmt.body
             else:
                 return None, body, None
@@ -70,7 +70,7 @@ def fold_uop_loop(stmt_in):
         fail = [False]
 
         def _post_order(op):
-            assert isinstance(op, tvm.expr.Call)
+            assert isinstance(op, tvm.tir.Call)
             base_args = 2
             if op.name == "VTAUopPush":
                 args = []
@@ -112,7 +112,7 @@ def fold_uop_loop(stmt_in):
 
     def _do_fold(stmt):
         if (stmt.attr_key == "coproc_uop_scope" and
-                isinstance(stmt.value, tvm.expr.StringImm) and
+                isinstance(stmt.value, tvm.tir.StringImm) and
                 stmt.value.value == env.dev.vta_push_uop.value):
             body = stmt.body
             begins = []
@@ -133,8 +133,8 @@ def fold_uop_loop(stmt_in):
             if body == stmt.body:
                 return stmt
             ends = list(reversed(ends))
-            body = tvm.stmt.stmt_seq(*(begins + [body] + ends))
-            return tvm.make.AttrStmt(
+            body = tvm.tir.stmt_seq(*(begins + [body] + ends))
+            return tvm.tir.AttrStmt(
                 stmt.node, stmt.attr_key, stmt.value, body)
         return None
     out = tvm.ir_pass.IRTransform(
@@ -163,40 +163,40 @@ def cpu_access_rewrite(stmt_in):
     env = get_env()
     rw_info = {}
     def _post_order(op):
-        if isinstance(op, tvm.stmt.Allocate):
+        if isinstance(op, tvm.tir.Allocate):
             buffer_var = op.buffer_var
             if not buffer_var in rw_info:
                 return None
             new_var = rw_info[buffer_var]
-            let_stmt = tvm.make.LetStmt(
+            let_stmt = tvm.tir.LetStmt(
                 new_var, tvm.call_extern(
                     "handle", "VTABufferCPUPtr",
                     env.dev.command_handle,
                     buffer_var), op.body)
-            alloc = tvm.make.Allocate(
+            alloc = tvm.tir.Allocate(
                 buffer_var, op.dtype, op.extents,
                 op.condition, let_stmt)
             del rw_info[buffer_var]
             return alloc
-        if isinstance(op, tvm.expr.Load):
+        if isinstance(op, tvm.tir.Load):
             buffer_var = op.buffer_var
             if not buffer_var in rw_info:
                 rw_info[buffer_var] = tvm.var(
                     buffer_var.name + "_ptr", "handle")
             new_var = rw_info[buffer_var]
-            return tvm.make.Load(op.dtype, new_var, op.index)
-        if isinstance(op, tvm.stmt.Store):
+            return tvm.tir.Load(op.dtype, new_var, op.index)
+        if isinstance(op, tvm.tir.Store):
             buffer_var = op.buffer_var
             if not buffer_var in rw_info:
                 rw_info[buffer_var] = tvm.var(
                     buffer_var.name + "_ptr", "handle")
             new_var = rw_info[buffer_var]
-            return tvm.make.Store(new_var, op.value, op.index)
+            return tvm.tir.Store(new_var, op.value, op.index)
         raise RuntimeError("not reached")
     stmt = tvm.ir_pass.IRTransform(
         stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
     for buffer_var, new_var in rw_info.items():
-        stmt = tvm.make.LetStmt(
+        stmt = tvm.tir.LetStmt(
             new_var, tvm.call_extern(
                 "handle", "VTABufferCPUPtr",
                 env.dev.command_handle,
@@ -222,15 +222,15 @@ def lift_alloc_to_scope_begin(stmt_in):
         for op in slist:
             if op.body == body:
                 body = op
-            elif isinstance(op, tvm.stmt.Allocate):
-                body = tvm.make.Allocate(
+            elif isinstance(op, tvm.tir.Allocate):
+                body = tvm.tir.Allocate(
                     op.buffer_var, op.dtype,
                     op.extents, op.condition, body)
-            elif isinstance(op, tvm.stmt.AttrStmt):
-                body = tvm.make.AttrStmt(
+            elif isinstance(op, tvm.tir.AttrStmt):
+                body = tvm.tir.AttrStmt(
                     op.node, op.attr_key, op.value, body)
-            elif isinstance(op, tvm.stmt.For):
-                body = tvm.make.For(
+            elif isinstance(op, tvm.tir.For):
+                body = tvm.tir.For(
                     op.loop_var, op.min, op.extent, op.for_type,
                     op.device_api, body)
             else:
@@ -239,24 +239,24 @@ def lift_alloc_to_scope_begin(stmt_in):
         return body
 
     def _pre_order(op):
-        if isinstance(op, tvm.stmt.For):
+        if isinstance(op, tvm.tir.For):
             lift_stmt.append([])
-        elif isinstance(op, tvm.stmt.AttrStmt):
+        elif isinstance(op, tvm.tir.AttrStmt):
             if op.attr_key == "virtual_thread":
                 lift_stmt.append([])
 
     def _post_order(op):
-        if isinstance(op, tvm.stmt.Allocate):
+        if isinstance(op, tvm.tir.Allocate):
             lift_stmt[-1].append(op)
             return op.body
-        if isinstance(op, tvm.stmt.AttrStmt):
+        if isinstance(op, tvm.tir.AttrStmt):
             if op.attr_key == "storage_scope":
                 lift_stmt[-1].append(op)
                 return op.body
             if op.attr_key == "virtual_thread":
                 return _merge_block(lift_stmt.pop() + [op], op.body)
             return op
-        if isinstance(op, tvm.stmt.For):
+        if isinstance(op, tvm.tir.For):
             return _merge_block(lift_stmt.pop() + [op], op.body)
         raise RuntimeError("not reached")
     stmt = tvm.ir_pass.IRTransform(
@@ -280,7 +280,7 @@ def inject_skip_copy(stmt_in):
     """
     def _do_fold(stmt):
         if _match_pragma(stmt, "skip_dma_copy"):
-            return tvm.make.Evaluate(0)
+            return tvm.tir.Evaluate(0)
         return None
     return tvm.ir_pass.IRTransform(
         stmt_in, _do_fold, None, ["AttrStmt"])
@@ -303,13 +303,13 @@ def inject_coproc_sync(stmt_in):
     def _do_fold(stmt):
         if _match_pragma(stmt, "coproc_sync"):
             success[0] = True
-            sync = tvm.make.Call(
-                "int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
-            return tvm.stmt.SeqStmt([stmt.body, tvm.make.Evaluate(sync)])
+            sync = tvm.tir.Call(
+                "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
+            return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
         if _match_pragma(stmt, "trim_loop"):
             op = stmt.body
-            assert isinstance(op, tvm.stmt.For)
-            return tvm.make.For(
+            assert isinstance(op, tvm.tir.For)
+            return tvm.tir.For(
                 op.loop_var, op.min, 2, op.for_type,
                 op.device_api, op.body)
         return None
@@ -640,9 +640,9 @@ def inject_conv2d_transpose_skip(stmt_in):
     selects = []
 
     def _find_basics(op):
-        if isinstance(op, tvm.expr.Call):
+        if isinstance(op, tvm.tir.Call):
             calls.append(op)
-        elif isinstance(op, tvm.expr.Select):
+        elif isinstance(op, tvm.tir.Select):
             selects.append(op)
 
     def _do_fold(op):
@@ -665,7 +665,7 @@ def inject_conv2d_transpose_skip(stmt_in):
                 args = op.body.body.args
                 res_tensor = op.body.body.func.output(0)
                 tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
-                inner = tvm.make.AttrStmt(
+                inner = tvm.tir.AttrStmt(
                     [dout, res_tensor], 'buffer_bind_scope',
                     tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
                 return inner
@@ -697,19 +697,19 @@ def inject_conv2d_transpose_skip(stmt_in):
                 args = conv_call.args
                 tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
                        1, 0, 1, 0, env.BLOCK_OUT)
-                inner = tvm.make.AttrStmt(
+                inner = tvm.tir.AttrStmt(
                     [dout, res_tensor], 'buffer_bind_scope',
                     tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
                 args = kernel_call.args
                 tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
                        1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
-                inner = tvm.make.AttrStmt(
+                inner = tvm.tir.AttrStmt(
                     [dwgt, kernel_tensor], 'buffer_bind_scope',
                     tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
                 args = data_call.args
                 tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
                        1, 0, 1, 0, env.BLOCK_IN)
-                inner = tvm.make.AttrStmt(
+                inner = tvm.tir.AttrStmt(
                     [dinp, pad_data_tensor], 'buffer_bind_scope',
                     tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
                 return inner
@@ -739,11 +739,11 @@ def annotate_alu_coproc_scope(stmt_in):
             irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                            env.dev.get_task_qid(env.dev.QID_COMPUTE))
             irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
-                           tvm.make.StringImm("VTAPushALUOp"))
+                           tvm.tir.StringImm("VTAPushALUOp"))
             irb.emit(stmt)
             return irb.get()
         if _match_pragma(stmt, "skip_alu"):
-            return tvm.make.Evaluate(0)
+            return tvm.tir.Evaluate(0)
         return stmt
 
     stmt_out = tvm.ir_pass.IRTransform(
@@ -810,7 +810,7 @@ def inject_alu_intrin(stmt_in):
             # Get to the innermost loop body
             loop_body = stmt.body
             nest_size = 0
-            while isinstance(loop_body, tvm.stmt.For):
+            while isinstance(loop_body, tvm.tir.For):
                 loop_body = loop_body.body
                 nest_size += 1
             # Get the src/dst arguments
@@ -825,27 +825,27 @@ def inject_alu_intrin(stmt_in):
                 extents.append(tmp_body.extent)
                 tmp_body = tmp_body.body
             # Derive opcode
-            if isinstance(loop_body.value, tvm.expr.Add):
+            if isinstance(loop_body.value, tvm.tir.Add):
                 alu_opcode = env.dev.ALU_OPCODE_ADD
                 lhs = loop_body.value.a
                 rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.expr.Sub):
+            elif isinstance(loop_body.value, tvm.tir.Sub):
                 alu_opcode = env.dev.ALU_OPCODE_SUB
                 lhs = loop_body.value.a
                 rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.expr.Mul):
+            elif isinstance(loop_body.value, tvm.tir.Mul):
                 alu_opcode = env.dev.ALU_OPCODE_MUL
                 lhs = loop_body.value.a
                 rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.expr.Min):
+            elif isinstance(loop_body.value, tvm.tir.Min):
                 alu_opcode = env.dev.ALU_OPCODE_MIN
                 lhs = loop_body.value.a
                 rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.expr.Max):
+            elif isinstance(loop_body.value, tvm.tir.Max):
                 alu_opcode = env.dev.ALU_OPCODE_MAX
                 lhs = loop_body.value.a
                 rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.expr.Call):
+            elif isinstance(loop_body.value, tvm.tir.Call):
                 if loop_body.value.name == 'shift_left':
                     alu_opcode = env.dev.ALU_OPCODE_SHR
                     lhs = loop_body.value.args[0]
@@ -857,7 +857,7 @@ def inject_alu_intrin(stmt_in):
                 else:
                     raise RuntimeError(
                         "Function call not recognized %s" % (loop_body.value.name))
-            elif isinstance(loop_body.value, tvm.expr.Load):
+            elif isinstance(loop_body.value, tvm.tir.Load):
                 alu_opcode = env.dev.ALU_OPCODE_SHR
                 lhs = loop_body.value
                 rhs = tvm.const(0, "int32")
@@ -871,12 +871,12 @@ def inject_alu_intrin(stmt_in):
             # Check if lhs/rhs is immediate
             use_imm = False
             imm_val = None
-            if isinstance(rhs, tvm.expr.IntImm):
+            if isinstance(rhs, tvm.tir.IntImm):
                 assert lhs.buffer_var.same_as(dst_var)
                 src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices)
                 use_imm = True
                 imm_val = rhs
-            if isinstance(lhs, tvm.expr.IntImm):
+            if isinstance(lhs, tvm.tir.IntImm):
                 assert rhs.buffer_var.same_as(dst_var)
                 src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices)
                 use_imm = True