[ARITH] Use explicit div mode in python. (#4014)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 27 Sep 2019 14:45:25 +0000 (07:45 -0700)
committerGitHub <noreply@github.com>
Fri, 27 Sep 2019 14:45:25 +0000 (07:45 -0700)
46 files changed:
docs/api/python/tvm.rst
python/tvm/api.py
python/tvm/contrib/nnpack.py
python/tvm/expr.py
python/tvm/generic.py
python/tvm/hybrid/parser.py
python/tvm/relay/op/_transform.py
src/api/api_ir.cc
src/contrib/hybrid/codegen_hybrid.cc
src/contrib/hybrid/codegen_hybrid.h
src/lang/expr_operator.cc
src/pass/lower_intrin.cc
tests/python/unittest/test_arith_canonical_simplify.py
tests/python/unittest/test_arith_const_int_bound.py
tests/python/unittest/test_arith_deduce_bound.py
tests/python/unittest/test_arith_intset.py
tests/python/unittest/test_arith_modular_set.py
tests/python/unittest/test_autotvm_flop_calculator.py
tests/python/unittest/test_build_lower.py
tests/python/unittest/test_codegen_llvm.py
tests/python/unittest/test_ir_builder.py
tests/python/unittest/test_lang_buffer.py
tests/python/unittest/test_lang_operator.py
tests/python/unittest/test_lang_tensor_overload_op.py
tests/python/unittest/test_pass_basic.py
tests/python/unittest/test_pass_equal.py
tests/python/unittest/test_pass_loop_partition.py
tests/python/unittest/test_schedule_bound_inference.py
tests/python/unittest/test_schedule_tensorize.py
topi/python/topi/arm_cpu/conv2d_spatial_pack.py
topi/python/topi/arm_cpu/conv2d_transpose.py
topi/python/topi/arm_cpu/depthwise_conv2d.py
topi/python/topi/cuda/conv2d_transpose_nchw.py
topi/python/topi/cuda/conv2d_winograd.py
topi/python/topi/cuda/nms.py
topi/python/topi/cuda/ssd/multibox.py
topi/python/topi/mali/conv2d.py
topi/python/topi/nn/bitserial_conv2d.py
topi/python/topi/nn/bitserial_dense.py
topi/python/topi/nn/conv2d.py
topi/python/topi/nn/depthwise_conv2d.py
topi/python/topi/nn/dilate.py
topi/python/topi/nn/flatten.py
topi/python/topi/x86/conv2d.py
topi/python/topi/x86/dense.py
topi/python/topi/x86/depthwise_conv2d.py

index a04502b..b517195 100644 (file)
@@ -35,6 +35,13 @@ The user facing API for computation declaration.
    tvm.thread_axis
    tvm.comm_reducer
    tvm.sum
+   tvm.div
+   tvm.indexdiv
+   tvm.indexmod
+   tvm.truncdiv
+   tvm.truncmod
+   tvm.floordiv
+   tvm.floormod
    tvm.min
    tvm.max
    tvm.tag_scope
@@ -53,6 +60,13 @@ The user facing API for computation declaration.
 .. autofunction:: tvm.thread_axis
 .. autofunction:: tvm.comm_reducer
 .. autofunction:: tvm.sum
+.. autofunction:: tvm.div
+.. autofunction:: tvm.indexdiv
+.. autofunction:: tvm.indexmod
+.. autofunction:: tvm.truncdiv
+.. autofunction:: tvm.truncmod
+.. autofunction:: tvm.floordiv
+.. autofunction:: tvm.floormod
 .. autofunction:: tvm.min
 .. autofunction:: tvm.max
 .. autofunction:: tvm.tag_scope
index b54d364..e7523bd 100644 (file)
@@ -890,6 +890,77 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
     reducer.__doc__ = doc_str.format(name)
     return reducer
 
+def div(a, b):
+    """Compute a / b as in C/C++ semantics.
+
+    Parameters
+    ----------
+    a : Expr
+        The left hand operand, known to be non-negative.
+
+    b : Expr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : Expr
+        The result expression.
+    Note
+    ----
+    When operands are integers, returns truncdiv(a, b).
+    """
+    return _make._OpDiv(a, b)
+
+
+def indexdiv(a, b):
+    """Compute floor(a / b) where a and b are non-negative.
+
+    Parameters
+    ----------
+    a : Expr
+        The left hand operand, known to be non-negative.
+
+    b : Expr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : Expr
+        The result expression.
+
+    Note
+    ----
+    Use this function to split non-negative indices.
+    This function may take advantage of operands'
+    non-negativeness.
+    """
+    return _make._OpIndexDiv(a, b)
+
+
+def indexmod(a, b):
+    """Compute the remainder of indexdiv. a and b are non-negative.
+
+    Parameters
+    ----------
+    a : Expr
+        The left hand operand, known to be non-negative.
+
+    b : Expr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : Expr
+        The result expression.
+
+    Note
+    ----
+    Use this function to split non-negative indices.
+    This function may take advantage of operands'
+    non-negativeness.
+    """
+    return _make._OpIndexMod(a, b)
+
 
 def truncdiv(a, b):
     """Compute the truncdiv of two expressions.
index c6530dd..aceab6d 100644 (file)
@@ -101,8 +101,11 @@ def convolution_inference(
     assert isinstance(stride, list) and len(stride) == 2
     batch, _, input_height, input_width = data.shape
     output_channels, _, kernel_height, kernel_width = kernel.shape
-    output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
-    output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
+    idxdiv = _api.indexdiv
+    output_height = idxdiv(
+        input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
+    output_width = idxdiv(
+        input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
 
     return _api.extern(
         (batch, output_channels, output_height, output_width),
@@ -153,8 +156,9 @@ def convolution_inference_without_weight_transform(
     batch, _, input_height, input_width = data.shape
     output_channels, _, _, _ = transformed_kernel.shape
     kernel_height, kernel_width = (3, 3)
-    output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
-    output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
+    idxdiv = _api.indexdiv
+    output_height = idxdiv(input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
+    output_width = idxdiv(input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
 
     return _api.extern(
         (batch, output_channels, output_height, output_width),
index 0d70ac9..a8bd651 100644 (file)
@@ -33,11 +33,25 @@ For example, you can use addexp.a to get the left operand of an Add node.
 # pylint: disable=missing-docstring
 from __future__ import absolute_import as _abs
 from ._ffi.node import NodeBase, NodeGeneric, register_node
+from ._ffi.runtime_ctypes import TVMType, TypeCode
 from . import make as _make
 from . import generic as _generic
 from . import _api_internal
 
 
+def div_ambiguity_error():
+    return RuntimeError(
+        "TVM supports multiple types of integer divisions, " +
+        "please call div, indexdiv/indexmod, floordiv/floormod " +
+        " or truncdiv/truncmod directly to avoid ambiguity in the code.")
+
+def _dtype_is_int(value):
+    if isinstance(value, int):
+        return True
+    return (isinstance(value, ExprOp) and
+            TVMType(value.dtype).type_code == TypeCode.INT)
+
+
 class ExprOp(object):
     def __add__(self, other):
         return _generic.add(self, other)
@@ -58,24 +72,35 @@ class ExprOp(object):
         return _generic.multiply(other, self)
 
     def __div__(self, other):
+        # if _dtype_is_int(self) and _dtype_is_int(other):
+        #     raise div_ambiguity_error()
         return _generic.divide(self, other)
 
     def __rdiv__(self, other):
+        # if _dtype_is_int(self) and _dtype_is_int(other):
+        #     raise div_ambiguity_error()
         return _generic.divide(other, self)
 
     def __truediv__(self, other):
-        return self.__div__(other)
+        # if _dtype_is_int(self) and _dtype_is_int(other):
+        #     raise div_ambiguity_error()
+        return _generic.divide(self, other)
 
     def __rtruediv__(self, other):
-        return self.__rdiv__(other)
+        # if _dtype_is_int(self) and _dtype_is_int(other):
+        #     raise div_ambiguity_error()
+        return _generic.divide(other, self)
 
     def __floordiv__(self, other):
-        return self.__div__(other)
+        # return _generic.floordiv(self, other)
+        return _generic.divide(self, other)
 
     def __rfloordiv__(self, other):
-        return self.__rdiv__(other)
+        # return _generic.floordiv(other, self)
+        return _generic.divide(other, self)
 
     def __mod__(self, other):
+        # raise div_ambiguity_error()
         return _make._OpMod(self, other)
 
     def __neg__(self):
index 7f5c3d6..b7bea7f 100644 (file)
@@ -25,6 +25,7 @@ from . import make as _make
 #Operator precedence used when overloading.
 __op_priority__ = 0
 
+
 def add(lhs, rhs):
     """Generic add operator.
 
@@ -78,7 +79,6 @@ def multiply(lhs, rhs):
     """
     return _make._OpMul(lhs, rhs)
 
-
 def divide(lhs, rhs):
     """Generic divide operator.
 
@@ -96,6 +96,23 @@ def divide(lhs, rhs):
     """
     return _make._OpDiv(lhs, rhs)
 
+def floordiv(lhs, rhs):
+    """Generic floordiv operator.
+
+    Parameters
+    ----------
+    lhs : object
+        The left operand.
+    rhs : object
+        The right operand.
+
+    Returns
+    -------
+    op : tvm.Expr
+        The result Expr of divide operaton.
+    """
+    return _make._OpFloorDiv(lhs, rhs)
+
 
 def cast(src, dtype):
     """Generic cast operator.
index 171a2f8..44db999 100644 (file)
@@ -31,6 +31,7 @@ from . import util
 from .preprocessor import determine_variable_usage
 from ..api import all as _all
 from ..api import any as _any
+
 from ..container import Array
 from ..tensor import Tensor, Operation
 from .. import _api_internal as _tvm_internal
@@ -78,6 +79,18 @@ class Symbol(Enum):
     ThreadBind = 10
 
 
+def _floordiv(x, y):
+    if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
+        return _api.floordiv(x, y)
+    return operator.floordiv(x, y)
+
+
+def _floormod(x, y):
+    if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
+        return _api.floormod(x, y)
+    return operator.mod(x, y)
+
+
 class HybridParser(ast.NodeVisitor):
     """Python AST visitor pass which finally lowers it to HalideIR"""
 
@@ -87,8 +100,8 @@ class HybridParser(ast.NodeVisitor):
         ast.Sub     : operator.sub,
         ast.Mult    : operator.mul,
         ast.Div     : operator.div if sys.version_info[0] == 2 else operator.truediv,
-        ast.FloorDiv: operator.div if sys.version_info[0] == 2 else operator.truediv,
-        ast.Mod     : operator.mod,
+        ast.FloorDiv: _floordiv,
+        ast.Mod     : _floormod,
         ast.BitOr   : operator.or_,
         ast.BitAnd  : operator.and_,
         ast.BitXor  : operator.xor,
index 5dddfc6..d1c0f09 100644 (file)
@@ -67,7 +67,7 @@ _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
 @script
 def _arange_shape_func(start, stop, step):
     out = output_tensor((1,), "int64")
-    out[0] = int64(ceil_div((float32(stop[0]) - float32(start[0])), float32(step[0])))
+    out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0])))
     return out
 
 @_reg.register_shape_func("arange", True)
@@ -131,12 +131,12 @@ def _reshape_shape_func(data_shape, newshape, ndim):
             assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
             if newshape[i+1] == -1:
                 assert newshape[i+2] != -1, "Split dims cannot both be -1."
-                out[dst_idx] = data_shape[src_idx] / int64(newshape[i+2])
+                out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
                 out[dst_idx+1] = int64(newshape[i+2])
             else:
                 out[dst_idx] = int64(newshape[i+1])
                 if newshape[i+2] == -1:
-                    out[dst_idx+1] = data_shape[src_idx] / int64(newshape[i+1])
+                    out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
                 else:
                     out[dst_idx+1] = int64(newshape[i+2])
             assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
@@ -159,7 +159,7 @@ def _reshape_shape_func(data_shape, newshape, ndim):
             new_size = int64(1)
             for i in const_range(out.shape[0]):
                 new_size *= out[i]
-            out[infer_idx] = old_size / new_size
+            out[infer_idx] = old_size // new_size
     return out
 
 @_reg.register_shape_func("reshape", False)
index d5fc1fa..b8ee144 100644 (file)
@@ -200,6 +200,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
 REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
 REGISTER_MAKE_BINARY_OP(_OpDiv, div);
 REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
+REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
 REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
 REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
 REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
index e98b366..54616ad 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -146,15 +146,28 @@ void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) {  // NOLINT(*)
 void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "*", os, this);
 }
+
 void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) {  // NOLINT(*)
   if (op->type.is_int())
     PrintBinaryExpr(op, "//", os, this);
   else
     PrintBinaryExpr(op, "/", os, this);
 }
+
+void CodeGenHybrid::VisitExpr_(const FloorDiv *op, std::ostream& os) {  // NOLINT(*)
+  if (op->type.is_int())
+    PrintBinaryExpr(op, "//", os, this);
+  else
+    PrintBinaryExpr(op, "/", os, this);
+}
+
 void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "%", os, this);
 }
+
+void CodeGenHybrid::VisitExpr_(const FloorMod *op, std::ostream& os) {  // NOLINT(*)
+  PrintBinaryExpr(op, "%", os, this);
+}
 void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "min", os, this);
 }
index e297556..498838f 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -100,6 +100,8 @@ class CodeGenHybrid :
   void VisitExpr_(const Mul* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Div* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Mod* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const FloorDiv* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const FloorMod* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Min* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Max* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const EQ* op, std::ostream& os) override;  // NOLINT(*)
@@ -161,12 +163,12 @@ class CodeGenHybrid :
   std::string GetUniqueName(std::string prefix);
   /*! \brief The output code string builder. */
   std::stringstream stream;
-  /*! 
+  /*!
    * \brief Get or allocate the ID for the given variable.
    * \param v The given variable.
    */
   std::string GetVarID(const Variable *v);
-  /*! 
+  /*!
    * \brief Get or allocate the ID for the given tensor.
    * \param func The tensor to allocate a name.
    * \param value_index The value index of the given tensor.
index e502736..46a0737 100644 (file)
@@ -216,6 +216,8 @@ Expr indexmod(Expr a, Expr b) {
 }
 
 Expr floordiv(Expr a, Expr b) {
+  CHECK(a.type().is_int() || a.type().is_uint());
+  CHECK(b.type().is_int() || b.type().is_uint());
   BinaryOpMatchTypes(a, b);
   Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
   if (ret.defined()) return ret;
@@ -223,6 +225,8 @@ Expr floordiv(Expr a, Expr b) {
 }
 
 Expr floormod(Expr a, Expr b) {
+  CHECK(a.type().is_int() || a.type().is_uint());
+  CHECK(b.type().is_int() || b.type().is_uint());
   BinaryOpMatchTypes(a, b);
   Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
   if (ret.defined()) return ret;
index 813037a..bbc3c57 100644 (file)
@@ -74,9 +74,6 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
     if (op == nullptr) return ret;
     int shift;
     const DataType& dtype = op->type;
-    if (dtype.is_float()) {
-      return floor(Div::make(op->a, op->b));
-    }
     CHECK(dtype.is_int() || !dtype.is_uint());
 
     if (is_const_power_of_two_integer(op->b, &shift)) {
index 7f444d2..28a4388 100644 (file)
@@ -33,9 +33,11 @@ def test_mul_sum_simplify():
               x * 13 + z * 4 + y * 4 +6)
     ck.verify(x * 3 - 4 * x + 1, 1 - x)
     ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     # trucdiv
-    ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
-    ck.verify((x + y + x + y * 3) % 2, 0)
+    ck.verify(tdiv(x + y + x + y * 3, 2), y * 2 + x)
+    ck.verify(tmod(x + y + x + y * 3, 2), 0)
 
     # floordiv
     fld = tvm.floordiv
@@ -51,28 +53,31 @@ def test_split_index_simplify():
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
 
     # trucdiv
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
+
     # split div const
-    ck.verify((x/3) *3 + x % 3, x)
-    ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
-    ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4)
-    ck.verify((x % 2) / 8, 0)
-    ck.verify((x % 2) / 7, 0)
-    ck.verify(((x % 16) / 2) * 2 / 6, (x % 16) / 6)
+    ck.verify(tdiv(x, 3) *3 + tmod(x, 3), x)
+    ck.verify(tdiv(x, 6) * 6 + tmod(tdiv(x, 3), 2) * 3 + tmod(x, 3), x)
+    ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 4), tdiv(tmod(x, 16), 4))
+    ck.verify(tdiv(tmod(x, 2), 8), 0)
+    ck.verify(tdiv(tmod(x, 2), 7), 0)
+    ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 6), tdiv(tmod(x, 16), 6))
 
     # split mod const
-    ck.verify((x * 8) % 16, (x % 2) * 8)
-    ck.verify((x * 8) % 2, 0)
+    ck.verify(tmod((x * 8), 16), tmod(x, 2) * 8)
+    ck.verify(tmod(x * 8, 2), 0)
 
     # simplify then fold
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000))
-    ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
+    ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
     # complex fold
-    ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y)
+    ck.verify(tdiv(z * 9 + y, 2) * 2 + tmod(z * 9 + y, 2), z * 9 + y)
 
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)
     ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)
-    ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
+    ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
 
     # floordiv
     fld = tvm.floordiv
@@ -85,23 +90,24 @@ def test_split_index_simplify():
     ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))
 
     # cannot simplify mixed case, unless we canonicalize into one mode.
-    ck.verify((x/6) * 2 + fld(x,3) % 2, (x/6) * 2 + fld(x,3) % 2)
+    ck.verify(tdiv(x,6) * 2 + tmod(fld(x,3), 2), tdiv(x,6) * 2 + tmod(fld(x,3), 2))
 
 
 def test_div_simplify():
     ck = CanonicalChecker()
     x = tvm.var("x")
+    tdiv = tvm.truncdiv
 
     # truc div
-    ck.verify((16+48*x)/16, x*3 + 1)
+    ck.verify(tdiv(16+48*x,16), x*3 + 1)
     # (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
     # (17+48*x)/16 != 1+3*x
-    ck.verify((17+48*x)/16, (x * 48 + 17) / 16)
+    ck.verify(tdiv(17 + 48 * x, 16), tdiv(x * 48 + 17, 16))
     # However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10))
-    ck.verify((17+48*x)/16, x * 3 + 1)
+    ck.verify(tdiv(17 + 48 * x, 16), x * 3 + 1)
     # Trying expressions that are not simplifiable for any values of the variables
-    ck.verify((17+47*x)/16, (x * 47 + 17) / 16)
+    ck.verify(tdiv(17 + 47 * x, 16), tdiv(x * 47 + 17, 16))
 
     # floordiv
     fld = tvm.floordiv
@@ -124,8 +130,10 @@ def test_canonical_mixed():
     ck = CanonicalChecker()
     x = tvm.var("x")
     z = tvm.const(3, "int32")
-    ck.verify(x / (z*z) - x / (z*z), 0)
-    ck.verify(x / (z+z) - x / (z+z), 0)
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
+    ck.verify(tdiv(x, (z*z)) - tdiv(x, (z*z)), 0)
+    ck.verify(tdiv(x, (z+z)) - tdiv(x, (z+z)), 0)
     ck.verify(x - 2 < 3, x < 5)
     ck.verify(tvm.max(x, 1) - tvm.max(x, 1), 0)
     ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0)
@@ -207,42 +215,44 @@ def test_reduce_simplify():
               tvm.sum(k + j, [k, j]))
     ck.verify(tvm.sum(A[3], []), A[3])
     # The rule below is not typical, removed for now
-    ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k))
+    ck.verify(tvm.sum(tvm.div(k, 10), k), tvm.sum(tvm.const(0, "int32"), k))
 
 
 def test_simplify_if_then_else():
     ck = CanonicalChecker()
     x = tvm.var("x")
     y = tvm.var("y")
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     # simplification that takes condition into account.
     res = tvm.if_then_else((x * 4 + y) >= 466036,
-                           tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
-                                            (((((x*4) + y)  - 466036) % 24528) -24512) % 16,
+                           tvm.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
+                                            tmod(tmod(((x*4) + y)  - 466036, 24528) -24512, 16),
                                             x), y)
 
     res2 = tvm.if_then_else((x * 4) >= 466036 - y,
-                           tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
-                                            (((((x*4) + y)  - 466036) % 24528) -24512) % 16,
+                           tvm.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
+                                            tmod(tmod(((x*4) + y)  - 466036, 24528) -24512, 16),
                                             x), y)
     expected = tvm.if_then_else(
         tvm.expr.LE(466036, (x * 4 + y)),
-        tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)),
-                         (((x*4) + y)  - 4) % 16,
+        tvm.if_then_else(tvm.expr.LE(24512, tmod(((x*4) + y) - 4, 24528)),
+                                     tmod(((x*4) + y)  - 4, 16),
                          x), y)
     ck.verify(res, expected)
     ck.verify(res2, expected)
     # can only simplify if condition
-    res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3)
-    expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3)
+    res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
+    expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
     ck.verify(res, ck.analyzer.canonical_simplify(expected))
 
     res = tvm.expr.Select(x >= 10,
-                          tvm.if_then_else(x / 3 > 2, x, 0), 0)
+                          tvm.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
     expected = tvm.expr.Select(x >= 10, x, 0)
     ck.verify(res, ck.analyzer.canonical_simplify(expected))
 
     res = tvm.expr.Select(x >= 10,
-                          tvm.if_then_else(x / 3 < 2, x, 0), 0)
+                          tvm.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
     ck.verify(res, 0)
 
 
@@ -250,20 +260,20 @@ def test_complex_cases():
     ck = CanonicalChecker()
     x = tvm.var("x")
     y = tvm.var("y")
-    res2 = (((((((((((x*128) + y) % 1296)/36)*2) + 1)/2)*36) +
-              ((((((x*128) + y) % 36)*2) + 1)/2))
-             - (((x*128) + y) % 1296)) + 1)
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
+    res2 = (tdiv(tdiv(tmod(x*128 + y, 1296),36)*2 + 1,2)*36 +
+            tdiv(tmod((x*128) + y, 36)*2 + 1,2)
+            - tmod((x*128) + y, 1296) + 1)
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
     ck.verify(res2, 1)
 
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True)
-    res3 = ((((((((((x*1024) + y)/65536) + ((((x*1024) + y) % 65536)/256))
-                 + ((((x*1024) + y) % 256)/16)) + (((x*1024) + y) % 16)) - (y/256)) -
-              ((y % 256)/16))  - (y % 16)) - (x*4))
-    ck.verify(res3, ((((x*1024) + y)/256) - (y/256)) - (x*4))
-
-
+    res3 = (tdiv(x*1024 + y,65536) + tdiv(tmod(x*1024 + y, 65536),256)
+            + tdiv(tmod(x*1024 + y, 256),16) + tmod(x*1024 + y, 16) - tdiv(y,256) -
+            tdiv(tmod(y, 256),16) - tmod(y, 16) - (x*4))
+    ck.verify(res3, tdiv((x*1024) + y, 256) - tdiv(y,256) - (x*4))
 
 
 if __name__ == "__main__":
index 9b41bdc..4944861 100644 (file)
@@ -38,12 +38,13 @@ def test_dtype_bound():
 def test_cast_bound():
     analyzer = tvm.arith.Analyzer()
     x = tvm.var("x", dtype="int8")
-    bd = analyzer.const_int_bound((x % 3).astype("uint32"))
+    tmod = tvm.truncmod
+    bd = analyzer.const_int_bound(tmod(x, 3).astype("uint32"))
     assert bd.min_value == 0
     assert bd.max_value == 2
 
     bd = analyzer.const_int_bound(
-        (x % 3).astype("float32").astype("int32"))
+        tmod(x, 3).astype("float32").astype("int32"))
     assert bd.min_value == -2
     assert bd.max_value == 2
 
@@ -98,47 +99,50 @@ def test_mul_bound():
     assert bd.max_value == bd.POS_INF
 
 
-def test_div_bound():
+def test_truncdiv_bound():
     analyzer = tvm.arith.Analyzer()
     x, y = tvm.var("x"), tvm.var("y")
+    tdiv = tvm.truncdiv
 
     analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
-    bd = analyzer.const_int_bound(x / y)
+    bd = analyzer.const_int_bound(tdiv(x, y))
     assert bd.min_value == -2
 
     analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
     analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True)
-    bd = analyzer.const_int_bound(x / y)
+    bd = analyzer.const_int_bound(tdiv(x, y))
     assert bd.min_value == -4
     assert bd.max_value == 9
 
     analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
     analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True)
-    bd = analyzer.const_int_bound(x / y)
+    bd = analyzer.const_int_bound(tdiv(x, y))
     assert bd.min_value == bd.NEG_INF
     assert bd.max_value == bd.POS_INF
 
 
-def test_mod_bound():
+def test_truncmod_bound():
     analyzer = tvm.arith.Analyzer()
     x, y = tvm.var("x"), tvm.var("y")
 
+    tmod = tvm.truncmod
+
     analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
-    bd = analyzer.const_int_bound(x % y)
+    bd = analyzer.const_int_bound(tmod(x, y))
     assert bd.min_value == -9
     assert bd.max_value == 4
 
     analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
-    bd = analyzer.const_int_bound(x % y)
+    bd = analyzer.const_int_bound(tmod(x, y))
     assert bd.min_value == -9
     assert bd.max_value == 9
 
     analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
     analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
-    bd = analyzer.const_int_bound(x % y)
+    bd = analyzer.const_int_bound(tmod(x, y))
     assert bd.min_value == 0
     assert bd.max_value == 9
 
@@ -253,9 +257,12 @@ def test_shift_and_bound():
 def test_mix_index_bound():
     analyzer = tvm.arith.Analyzer()
     x, y = tvm.var("x"), tvm.var("y")
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
+
     analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1))
     analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1))
-    bd = analyzer.const_int_bound((x % 8) + (x / 8) * 8)
+    bd = analyzer.const_int_bound(tmod(x, 8) + tdiv(x, 8) * 8)
     assert bd.min_value == 0
     assert bd.max_value == 24 - 1
 
@@ -263,7 +270,7 @@ def test_mix_index_bound():
     assert bd.min_value == 0
     assert bd.max_value == 24 * 3 - 1
 
-    bd = analyzer.const_int_bound((x % 7) + (x / 7) * 7)
+    bd = analyzer.const_int_bound(tmod(x, 7) + tdiv(x, 7) * 7)
     assert bd.min_value == 0
     assert bd.max_value == (23 // 7) * 7 + 6
 
@@ -273,8 +280,8 @@ if __name__ == "__main__":
     test_cast_bound()
     test_add_sub_bound()
     test_mul_bound()
-    test_div_bound()
-    test_mod_bound()
+    test_truncdiv_bound()
+    test_truncmod_bound()
     test_floordiv_bound()
     test_floormod_bound()
     test_min_max_bound()
index e41532d..235c935 100644 (file)
@@ -35,9 +35,11 @@ def test_deduce():
     d_s = tvm.arith.IntervalSet(-3, -1)
     zero = tvm.const(0, "int32")
 
+    tdiv = tvm.truncdiv
+
     e0 = (-b)*a+c-d
     res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = ((d - c) /(b*-1) + (-1))
+    ans0 = (tdiv(d - c, b*-1) + (-1))
     assert_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
@@ -46,7 +48,7 @@ def test_deduce():
 
     e0 = d*a+c-d
     res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = ((d-c)/d - 1)
+    ans0 = (tdiv(d-c,d) - 1)
     assert_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
@@ -56,7 +58,7 @@ def test_deduce():
 
     e1 = (a*4+b < c)
     res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
-    ans1 = (((c - b) + -1)/4 -1)
+    ans1 = (tdiv((c - b) + -1,4) -1)
     assert_expr_equal(res1.max_value, ans1)
 
 
@@ -79,7 +81,7 @@ def test_deduce():
 
     e3 = (-b)+a*c-d
     res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
-    ans3 = 2/c+1
+    ans3 = tdiv(2,c)+1
     assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
 
     res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
index ebf50e1..17cc6f1 100644 (file)
@@ -60,13 +60,14 @@ def test_add_sub():
 def test_mul_div():
     ck = IntSetChecker()
     x, y = tvm.var("x"), tvm.var("y")
+    tdiv = tvm.truncdiv
     ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
     ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
     ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
     ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
 
-    ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
-    ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
+    ck.verify(tdiv(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
+    ck.verify(tdiv(x, 2), {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
 
     fld = tvm.floordiv
     ck.verify(fld(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
@@ -76,9 +77,10 @@ def test_mul_div():
 def test_mod():
     ck = IntSetChecker()
     x, y = tvm.var("x"), tvm.var("y")
+    tmod = tvm.truncmod
     ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
-    ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
-    ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
+    ck.verify(tmod(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
+    ck.verify(tmod(x, 10), {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
 
     flm = tvm.floormod
     ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
index 80fff4e..1ce7197 100644 (file)
@@ -54,7 +54,8 @@ def test_div_shift():
     analyzer = tvm.arith.Analyzer()
     x, y = tvm.var("x"), tvm.var("y")
     # not sure if x is non-negative
-    m = analyzer.modular_set((x * 4 + 2) / 2)
+    tdiv = tvm.truncdiv
+    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
     assert m.coeff == 1
     assert m.base == 0
     # right shift always round down so it is fine
@@ -67,7 +68,7 @@ def test_div_shift():
     assert m.base == 1
     # x is non-negative
     analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
-    m = analyzer.modular_set((x * 4 + 2) / 2)
+    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
     assert m.coeff == 2
     assert m.base == 1
 
@@ -92,6 +93,7 @@ def test_mix_index():
     a = tvm.var("a")
     b = tvm.var("b")
     analyzer = tvm.arith.Analyzer()
+    tdiv = tvm.truncdiv
     m = analyzer.modular_set(a * 4 + b * 6 + 7)
     assert m.coeff == 2
     assert m.base == 1
@@ -100,11 +102,11 @@ def test_mix_index():
     assert m.coeff == 4
     assert m.base == 3
 
-    m = analyzer.modular_set((a * 4 + 1) / (b * 8 + 3))
+    m = analyzer.modular_set(tdiv(a * 4 + 1, b * 8 + 3))
     assert m.coeff == 1
     assert m.base == 0
 
-    m = analyzer.modular_set((a * 4 + 1) * (b * 8 / 4))
+    m = analyzer.modular_set((a * 4 + 1) * tdiv(b * 8, 4))
     assert m.coeff == 2
     assert m.base == 0
 
@@ -121,11 +123,13 @@ def test_constraint_scope():
     a = tvm.var("a")
     b = tvm.var("b")
     analyzer = tvm.arith.Analyzer()
-    with analyzer.constraint_scope(b % 4 == 2):
+    tmod = tvm.truncmod
+
+    with analyzer.constraint_scope(tmod(b, 4) == 2):
         m = analyzer.modular_set(b + 1)
         assert m.coeff == 4
         assert m.base == 3
-        with analyzer.constraint_scope(a % 2 == 1):
+        with analyzer.constraint_scope(tmod(a, 2) == 1):
             m = analyzer.modular_set(b + a * 2)
             assert m.coeff == 4
             assert m.base == 0
@@ -140,15 +144,16 @@ def test_constraint_scope():
 def test_intersect():
     a = tvm.var("a")
     analyzer = tvm.arith.Analyzer()
-    with analyzer.constraint_scope(a % 4 == 1):
-        with analyzer.constraint_scope(a % 3 == 1):
+    tmod = tvm.truncmod
+    with analyzer.constraint_scope(tmod(a, 4) == 1):
+        with analyzer.constraint_scope(tmod(a, 3) == 1):
             m = analyzer.modular_set(a)
             assert m.coeff == 12
             assert m.base == 1
 
-    with analyzer.constraint_scope(a % 3 == 2):
-        with analyzer.constraint_scope(a % 5 == 3):
-            with analyzer.constraint_scope(a % 7 == 2):
+    with analyzer.constraint_scope(tmod(a, 3) == 2):
+        with analyzer.constraint_scope(tmod(a, 5) == 3):
+            with analyzer.constraint_scope(tmod(a, 7) == 2):
                 m = analyzer.modular_set(a)
                 assert m.coeff == 105
                 assert m.base == 23
index 2507732..54ade9a 100644 (file)
@@ -60,11 +60,14 @@ def test_pack_gemm():
         k = tvm.reduce_axis((0, L))
 
         bn = 4
+        fld = tvm.floordiv
+        flm = tvm.floormod
+
         A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
         B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
         C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
         tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
-        C = tvm.compute((N, M), lambda i, j: C_pack[i // bn][j // bn][i % bn][j % bn])
+        C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
 
         s = tvm.create_schedule([C.op])
         assert compute_flop(s) == 2 * N * L * M
@@ -119,9 +122,11 @@ def test_average_pool():
         OH = (H - KH) + 1
         OW = (W - KW) + 1
 
+
         C = tvm.compute(
             (N, CO, OH, OW),
-            lambda n, co, h, w: tvm.sum(D[n][co][h + kh][w + kw].astype(acc_dtype) / (KW * KH), axis=[kh, kw]))
+            lambda n, co, h, w: tvm.sum(
+                tvm.div(D[n][co][h + kh][w + kw].astype(acc_dtype), (KW * KH)), axis=[kh, kw]))
 
         s = tvm.create_schedule([C.op])
 
index 082b85f..090120c 100644 (file)
@@ -35,7 +35,7 @@ def test_lower_rfactor():
 def test_dependent_output_shape():
     n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x')
     A = tvm.placeholder((n, m))
-    B = tvm.compute((m, n/x), lambda i, j: A[i,j] , name='B')
+    B = tvm.compute((m, n//x), lambda i, j: A[i,j] , name='B')
     s = tvm.create_schedule(B.op)
     mod = tvm.build(s, [A, B, x])
 
index 06f1801..f4401d0 100644 (file)
@@ -409,7 +409,7 @@ def test_llvm_div():
     """Check that the semantics of div and mod is the same as in C/C++"""
     def check_div(start, end, divisor, dtype):
         T = tvm.compute((end - start,),
-                        lambda i: tvm.expr.Cast(dtype, (start + i)) / tvm.const(divisor, dtype))
+                        lambda i: tvm.div(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
         s = tvm.create_schedule([T.op])
         f = tvm.build(s, [T], "llvm")
         a = tvm.nd.empty((end - start,), dtype)
@@ -418,8 +418,9 @@ def test_llvm_div():
         tvm.testing.assert_allclose(a.asnumpy(), ref)
 
     def check_mod(start, end, divisor, dtype):
+        tmod = tvm.truncmod
         T = tvm.compute((end - start,),
-                        lambda i: tvm.expr.Cast(dtype, (start + i)) % tvm.const(divisor, dtype))
+                        lambda i: tmod(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
         s = tvm.create_schedule([T.op])
         f = tvm.build(s, [T], "llvm")
         a = tvm.nd.empty((end - start,), dtype)
@@ -443,7 +444,7 @@ def test_llvm_div():
 def test_llvm_fp_math():
     def check_llvm_reciprocal(n):
         A = tvm.placeholder((n,), name='A')
-        B = tvm.compute((n,), lambda i: 1.0/(1e+37*A[i]), name='B')
+        B = tvm.compute((n,), lambda i: tvm.div(1.0,(1e+37*A[i])), name='B')
 
         s = tvm.create_schedule(B.op)
         f = tvm.build(s, [A, B], "llvm")
index 322c37e..ef58174 100644 (file)
@@ -41,8 +41,9 @@ def test_if():
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
     A = ib.pointer("float32", name="A")
+    tmod = tvm.truncmod
     with ib.for_range(0, n, name="i") as i:
-        with ib.if_scope((i % 2) == 0):
+        with ib.if_scope(tmod(i, 2) == 0):
             A[i] = A[i] + 1
         with ib.else_scope():
             A[0] = A[i] + 2
@@ -108,13 +109,14 @@ def test_gpu():
     dtype = "float32"
     A = tvm.placeholder((n,), name='A')
     B = tvm.placeholder((n,), name='B')
+    fld = tvm.floordiv
     def test_device_ir(A, B, C):
         n = A.shape[0]
         max_threads = 32
         ib = tvm.ir_builder.create()
         bx = tvm.thread_axis("blockIdx.x")
         tx = tvm.thread_axis("threadIdx.x")
-        ib.scope_attr(bx, "thread_extent", (n+max_threads-1) // max_threads)
+        ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
         ib.scope_attr(tx, "thread_extent", max_threads)
         idx = bx.var * max_threads + tx.var
         Aptr = ib.buffer_ptr(A)
index 8c7e13a..9ad8b62 100644 (file)
@@ -94,24 +94,31 @@ def test_buffer_index_merge_mult_mod():
     def assert_simplified_equal(index_simplified, index_direct):
         assert tvm.ir_pass.Equal(index_simplified, index_direct),\
         "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     # Test Case1
-    index_simplified = A_stride.vload(((k0 % k1) / s, (k0 % k1) % s + (k0 / k1) * k1))
+    index_simplified = A_stride.vload(
+        (idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
     index_direct = A_stride.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
+
     # Test Case2
-    index_simplified = A.vload(((k0 % (k1 / s)) / n,
-                                (k0 % (k1 / s)) % n + (k0 % k1)))
-    index_direct = A.vload((0, k0 % k1 + k0 % (k1 / s)))
+    index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
+                                idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
+    index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case3
-    index_simplified = A.vload((((k0 / (k1 / s)) * (k1 / s)) / n + (k0 % (k1 / s)) / n,
-                                ((k0 / (k1 / s)) * (k1 / s)) % n + (k0 % (k1 / s)) % n))
+    index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
+                                idxdiv(idxmod(k0, idxdiv(k1, s)), n),
+                                idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
+                                idxmod(idxmod(k0, idxdiv(k1, s)), n)))
     index_direct = A.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case4 (not able to simplify)
-    index_simplified = A.vload(((k0 % (k1 / s)) / n,
-                                (k0 % (k1 / n)) % n + (k0 % k1)))
-    index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1))))
+    index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
+                                idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
+    index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
+                            (idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
     assert_simplified_equal(index_simplified, index_direct)
 
 
@@ -143,14 +150,14 @@ def test_buffer_broadcast():
     check()
 
 
-def test_bbuffer_roadcast_expr():
+def test_buffer_broadcast_expr():
     n0, m0, x = tvm.var('n0'), tvm.var('m0'), tvm.var('x')
     n1, m1 = tvm.var('n1'), tvm.var('m1')
     o0, o1 = tvm.var('o0'), tvm.var('o1')
 
     A = tvm.placeholder((m0, n0), name='A')
     B = tvm.placeholder((m1, n1), name='B')
-    C = tvm.compute((o0, o1/x), lambda i, j: A[i, j] + B[i, j], name='C')
+    C = tvm.compute((o0, o1//x), lambda i, j: A[i, j] + B[i, j], name='C')
 
     Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
     Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
index 2d30d29..c57f4a1 100644 (file)
@@ -32,10 +32,11 @@ def test_const_fold():
         if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
             raise ValueError("check error: %s vs %s " % (x, y))
 
+    tmod = tvm.truncmod
     check(lambda x, y: x + y, 3, 4)
     check(lambda x, y: x * y, 3, 12)
     check(lambda x, y: x * y - 10, 3, 12)
-    check(lambda x, y: x - y % 10, 3, 12)
+    check(lambda x, y: x - tmod(y, 10), 3, 12)
     check(lambda x, y: x // y + 10, 100, 12)
     check(lambda x, y: x & y + 10, 112, 128)
     check(lambda x, y: x > y, 112, 128)
@@ -47,13 +48,15 @@ def test_const_fold():
 
 def test_const_fold2():
     x = tvm.var("x")
+    tmod = tvm.truncmod
+    tdiv = tvm.truncdiv
     assert (x + 0).same_as(x)
     assert (0 + x).same_as(x)
     assert (x - 0).same_as(x)
-    assert (x % 1).value == 0
+    assert tmod(x, 1).value == 0
     assert (x * 1).same_as(x)
     assert (1 * x).same_as(x)
-    assert isinstance((1 / x), tvm.expr.Div)
+    assert isinstance(tdiv(1, x), tvm.expr.Div)
 
 def test_const_fold3():
     # Test that using ints with logic operations is forbidden
@@ -88,8 +91,9 @@ def test_const_fold3():
 def test_const_fold4():
     x1 = tvm.const(4, "int32")
     x2 = x1 + 5
+    tdiv = tvm.truncdiv
     assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
-    x3 = x2 / 3
+    x3 = tdiv(x2, 3)
     assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
     x4 = x3 + 0.55
     assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6
index 7571df5..3cb6e65 100644 (file)
@@ -72,7 +72,7 @@ def test_combination():
     A = tvm.placeholder((n, m), name='A')
     B = tvm.placeholder((n, m), name='B')
     C = tvm.placeholder((n, m), name='C')
-    D = k + A - B * C / x
+    D = k + A - B * C + x
     s = tvm.create_schedule(D.op)
     foo = tvm.build(s, [x, A, B, C, D], "llvm")
     ctx = tvm.cpu(0)
@@ -82,7 +82,7 @@ def test_combination():
     c = tvm.nd.array(np.random.uniform(size=(n, m)).astype(C.dtype), ctx)
     d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
     foo(x, a, b, c, d)
-    tvm.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() / x)
+    tvm.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() + x)
 
 
 def verify_tensor_scalar_bop(shape, typ="add"):
index b05d75a..8f35611 100644 (file)
 import tvm
 
 def test_simplify():
+  tdiv = tvm.truncdiv
+  tmod = tvm.truncmod
   x = tvm.var('x')
   e1 = tvm.ir_pass.Simplify(x + 2 + 1)
   assert(tvm.ir_pass.Equal(e1, x + 3))
   e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x)
   assert(tvm.ir_pass.Equal(e2, x * 8))
-  e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
-  assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
+  e3 = tvm.ir_pass.Simplify(x - tdiv(x, 3) * 3)
+  assert(tvm.ir_pass.Equal(e3, tmod(x, 3)))
 
 
 def test_verify_ssa():
index 7b9fd7f..8bd491b 100644 (file)
@@ -24,7 +24,7 @@ def test_equal_expr():
         return x + y + 1
 
     def func2():
-        return tvm.exp((x + y + 1) * y / 4)
+        return tvm.exp(tvm.truncdiv((x + y + 1) * y, 4))
 
     assert tvm.ir_pass.Equal(func1(), func1())
     assert tvm.ir_pass.Equal(func2(), func2())
index bb09966..b6fcfa3 100644 (file)
@@ -162,7 +162,7 @@ def test_condition():
     ib = tvm.ir_builder.create()
     m = tvm.var('m')
     n = tvm.var('n')
-    with ib.for_range(0, ((n+3)/4), 'i') as i:
+    with ib.for_range(0, tvm.truncdiv(n+3,4), 'i') as i:
       with ib.for_range(0, 4, 'j') as j:
         ib.emit(tvm.make.Evaluate(
           tvm.make.Select(ib.likely(i*4+j<n), m, n)))
@@ -206,7 +206,7 @@ def test_everything_during_deduction():
     ib = tvm.ir_builder.create()
     with ib.for_range(0, n, 'i') as i:
         with ib.for_range(0, 32, 'j') as j:
-            with ib.if_scope(ib.likely(i/j < m)):
+            with ib.if_scope(ib.likely(tvm.truncdiv(i,j) < m)):
                 # this guard will produce everything during deduction
                 ib.emit(tvm.make.Evaluate(m))
     stmt = ib.get()
index 1ff9853..9c3d1df 100644 (file)
@@ -111,9 +111,11 @@ def test_bound_fusesplit1():
 
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.container.Map)
-    assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - (xo * split1) / l ).value == 0)
+    idxdiv = tvm.indexdiv
+    assert(tvm.ir_pass.Simplify(
+            bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0)
 
-    expected_extent = (((xo + 1) * split1 - 1) / l - (xo * split1) / l + 1)
+    expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1)
     for i in range(1, 6):
         for j in range(1, 6):
             for k in range(1, 6):
@@ -121,7 +123,7 @@ def test_bound_fusesplit1():
                 comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value
                 exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value
                 assert(comp_ext == exp_ext)
-    
+
     assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0)
 
 def test_bound_fusesplit2():
@@ -394,11 +396,11 @@ def test_bound_simplification_failure():
         if not bounds[A.op.axis[0]].extent.value <= 2:
             print(stmt)
             assert bounds[A.op.axis[0]].extent.value <= 2
-
+    tdiv = tvm.truncdiv
     # These are hard to simplify, moreover we don't simplify them
     _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)]))
     _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)]))
-    _check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)]))
+    _check(tvm.compute((10,), lambda i: A[-2*tdiv(i,2) - tvm.min(i, 0-i)]))
     _check(tvm.compute((10,), lambda i: A[i + (0 - i)]))
     # This would cause out of bounds, but we nevertheless include it
     _check(tvm.compute((10,), lambda i: A[i]))
index b8b9e3d..4bad959 100644 (file)
@@ -221,11 +221,14 @@ def test_tensorize_matmul():
 # This tests whether algorithm and intrinsics expressions are simplified
 # as much as possible first and then checked for equality. See Issue #696
 def test_tensorize_op():
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     def op_intrin():
         bh = 9
         bw = 9
         x = tvm.placeholder((5, 5), name='A')
-        y = tvm.compute((bh, bw), lambda i,j: x[j/3 + i%3, j%3+ i/3])
+        y = tvm.compute((bh, bw),
+                        lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)])
 
         def intrin_func(ins, outs):
             xx, = ins
@@ -236,7 +239,7 @@ def test_tensorize_op():
             return tvm.decl_tensor_intrin(y.op, intrin_func)
 
     A = tvm.placeholder((5, 5), name='A')
-    B = tvm.compute((9,9), lambda i, j: A[j/3 + i%3, j%3 + i/3])
+    B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)])
     bt = op_intrin()
     s = tvm.create_schedule(B.op)
 
index 2066b9a..b566c98 100644 (file)
@@ -128,8 +128,13 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
                     kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
                     axis=[ci, kh, kw]), name='conv')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     output = tvm.compute(oshape, lambda n, co, h, w:
-                         conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
+                         conv[n,
+                              idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
+                              idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
                          name='output_unpack', tag='spatial_conv2d_output')
     return output
 
index bd56ef3..65f1024 100644 (file)
@@ -123,8 +123,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
                 kernel_vec[co, ci, KH - 1 - kh, KW - 1 - kw, vc].astype(out_dtype),
                 axis=[ci, kh, kw]), name='conv')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     output = tvm.compute(oshape, lambda n, co, h, w:
-                         conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
+                         conv[n,
+                              idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
+                              idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
                          name='output_unpack', tag='spatial_conv2d_transpose_output')
     return output
 
index 51088df..a5dc1e9 100644 (file)
@@ -293,21 +293,29 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
     kh = tvm.reduce_axis((0, KH), name='kh')
     kw = tvm.reduce_axis((0, KW), name='kw')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     if dilation_h != 1 or dilation_w != 1:
-        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-                          tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, kh, kw, vh, vw]
-                                  .astype(out_dtype) *
-                                  kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
-                                  axis=[kh, kw]), name='depthwise_conv')
+        conv = tvm.compute(
+            ovshape, lambda n, co, h, w, vh, vw, vc: \
+            tvm.sum(data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw]
+                    .astype(out_dtype) *
+                    kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
+                    axis=[kh, kw]), name='depthwise_conv')
     else:
         conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-                           tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, vh * HSTR + kh,
+                           tvm.sum(data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh,
                                             vw * WSTR + kw].astype(out_dtype) *
-                                   kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
+                                   kernel_vec[idxdiv(co, M),
+                                              idxmod(co, M),
+                                              kh, kw, vc].astype(out_dtype),
                                    axis=[kh, kw]), name='depthwise_conv')
 
     output = tvm.compute(oshape, lambda n, co, h, w:
-                         conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
+                         conv[n,
+                              idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
+                              idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
                          name='output_unpack', tag='spatial_depthwise_conv_nchw_output')
     return output
 
index ceb331c..3201878 100644 (file)
@@ -69,9 +69,11 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
                       [0, 0, (bpad_bottom + stride_h - 1) // stride_h,
                        (bpad_right + stride_w - 1) // stride_w], name='FirstPad')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     # remove extra padding introduced by dilatation
-    border_h = (stride_h - bpad_top % stride_h) % stride_h
-    border_w = (stride_w - bpad_left % stride_w) % stride_w
+    border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
+    border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
 
     # dilation stage
     data = FirstPad
@@ -83,8 +85,8 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
         index_tuple = []
         for i in range(n):
             if not equal_const_int(strides[i], 1):
-                index_tuple.append(indices[i] // strides[i])
-                not_zero.append((indices[i] % strides[i]).equal(0))
+                index_tuple.append(idxdiv(indices[i], strides[i]))
+                not_zero.append(idxmod(indices[i], strides[i]).equal(0))
             else:
                 index_tuple.append(indices[i])
         if not_zero:
index 29f14a0..50c4852 100644 (file)
@@ -85,10 +85,12 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
     else:
         kernel_pack = kernel
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     # pack input tile
     input_tile = tvm.compute((CI, P, alpha, alpha), lambda c, p, eps, nu:
-                             data_pad[p // (nH * nW)][c][p // nW % nH * m + eps]
-                             [p % nW * m + nu], name='d')
+                             data_pad[idxdiv(p, (nH * nW))][c][idxmod(idxdiv(p, nW), nH) * m + eps]
+                             [idxmod(p, nW) * m + nu], name='d')
 
     # transform data
     r_a = tvm.reduce_axis((0, alpha), 'r_a')
@@ -113,7 +115,10 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
 
     # output
     output = tvm.compute((N, CO, H, W), lambda n, co, h, w:
-                         inverse[co][n * nH * nW + (h // m) * nW + w // m][h % m][w % m],
+                         inverse[co,
+                                 n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                 idxmod(h, m),
+                                 idxmod(w, m)],
                          name='output', tag='conv2d_nchw_winograd')
     cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
 
index 417e2dc..6ff8a79 100644 (file)
@@ -245,7 +245,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
     new_range = num_anchors // elem_per_thread + 1
     # Scan: Downsweep:
     with ib. if_scope(tid < batch_size * num_anchors):
-        i = tid / num_anchors # number of batches
+        i = tid // num_anchors # number of batches
         j = tid % num_anchors # number of anchors
         with ib.if_scope(j < elem_per_thread):
             idx[tid] = idx_in[tid]
@@ -304,7 +304,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
     tid = bx * max_threads + tx
 
     with ib.if_scope(tid < batch_size * num_anchors):
-        i = tid / num_anchors
+        i = tid // num_anchors
         j = tid % num_anchors
         base_idx = i * num_anchors * elem_length
         with ib.if_scope(flag[tid] > 0):
index f7e5f94..03fa999 100644 (file)
@@ -315,7 +315,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
     tid = bx * max_threads + tx
 
     with ib.if_scope(tid < batch_size * num_anchors):
-        i = tid / num_anchors
+        i = tid // num_anchors
         j = tid % num_anchors
         with ib.if_scope(cls_id[tid] > 0):
             with ib.if_scope(tid == 0):
index 2382127..45882b7 100644 (file)
@@ -293,11 +293,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
                     tvm.sum(input_tile[ci][p][r_a][r_b][vp] * B[r_a][eps] * B[r_b][nu],
                             axis=[r_a, r_b]), name='V')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     # batch gemm
     ci = tvm.reduce_axis((0, CI), name='c')
     M = tvm.compute((alpha, alpha, CO, P_round), lambda eps, nu, co, p:
-                    tvm.sum(U[eps][nu][co // bna][ci][co % bna] *
-                            V[eps][nu][p // bnb][ci][p % bnb], axis=ci), name='M')
+                    tvm.sum(U[eps][nu][idxdiv(co, bna)][ci][idxmod(co, bna)] *
+                            V[eps][nu][idxdiv(p, bnb)][ci][idxmod(p, bnb)], axis=ci), name='M')
 
     r_a = tvm.reduce_axis((0, alpha), 'r_a')
     r_b = tvm.reduce_axis((0, alpha), 'r_b')
@@ -307,7 +310,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
 
     # unpack output
     output = tvm.compute((N, CO, H, W), lambda n, co, h, w:
-                         Y[co][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
+                         Y[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                           idxmod(h, m), idxmod(w, m)]
                          # The following hack term is used to make the padding in batch gemm ("M")
                          # effective, otherwise the padding will be eliminated by bound inference.
                          # Use `tvm.expr.Mul` instead of `*` to avoid issues in const folding.
index c04ff01..2faabf2 100644 (file)
@@ -313,10 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
                        axis=[ci, dh, dw, b1, b2])
 
     conv = tvm.compute(ovshape, _conv, name='conv_out')
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
 
-    return tvm.compute(oshape, lambda n, co, h, w:
-                       conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
-                       name='conv_vec', tag='spatial_bitserial_conv_nchw')
+    return tvm.compute(
+        oshape, lambda n, co, h, w:
+        conv[n][idxdiv(co, VC)][idxdiv(h, VH)][idxdiv(
+            w, VW)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
+        name='conv_vec', tag='spatial_bitserial_conv_nchw')
 
 @autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
 def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
@@ -415,9 +419,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
 
     conv = tvm.compute(ovshape, _conv, name='conv')
 
-    return tvm.compute(oshape, lambda n, h, w, co:
-                       conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
-                       name='output_unpack', tag='spatial_bitserial_conv_nhwc')
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+    return tvm.compute(
+        oshape, lambda n, h, w, co:
+        conv[n][idxdiv(h, VH)][idxdiv(w, VW)][idxdiv(
+            co, VC)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
+        name='output_unpack', tag='spatial_bitserial_conv_nhwc')
 
 @tvm.target.generic_func
 def bitserial_conv2d_legalize(attrs, inputs, types):
index b28b3a4..d77a1b7 100644 (file)
@@ -121,13 +121,18 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
     weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
                              weight_packed[xo*VX+vx][wb][k], name='weight_vec')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
-        (tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]) -
-         tvm.popcount(~weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k])).astype(out_dtype)
+        (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) -
+         tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k])
+        ).astype(out_dtype)
         << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
 
     matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
-        tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]).astype(out_dtype)
+        tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
+                    ).astype(out_dtype)
         << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
 
     # binary ops
index 590600a..904dd54 100644 (file)
@@ -480,17 +480,20 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
     kh = tvm.reduce_axis((0, kernel_height), name='kh')
     kw = tvm.reduce_axis((0, kernel_width), name='kw')
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                        tvm.sum(data_pad[n,
-                                        ic // ic_bn,
+                                        idxdiv(ic, ic_bn),
                                         oh * HSTR + kh * dilation_h,
                                         ow * WSTR + kw * dilation_w,
-                                        ic % ic_bn].astype(out_dtype)
+                                        idxmod(ic, ic_bn)].astype(out_dtype)
                                * kernel[oc_chunk,
-                                        ic // ic_bn,
+                                        idxdiv(ic, ic_bn),
                                         kh,
                                         kw,
-                                        ic % ic_bn,
+                                        idxmod(ic, ic_bn),
                                         oc_block],
                                axis=[ic, kh, kw]),
                        name='conv2d_NCHWc', tag="conv2d_NCHWc")
index e703bec..f50e357 100644 (file)
@@ -105,14 +105,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
     pad_after = [0, 0, pad_down, pad_right]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
     # depthconv stage
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     di = tvm.reduce_axis((0, filter_height), name='di')
     dj = tvm.reduce_axis((0, filter_width), name='dj')
     Output = tvm.compute(
         (batch, out_channel, out_height, out_width),
         lambda b, c, i, j: tvm.sum(
-            (PaddedInput[b, c/channel_multiplier, i*stride_h+di*dilation_h,
+            (PaddedInput[b, idxdiv(c, channel_multiplier), i*stride_h+di*dilation_h,
                          j*stride_w+dj*dilation_w].astype(out_dtype) *
-             Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)),
+             Filter[idxdiv(c, channel_multiplier),
+                    idxmod(c, channel_multiplier), di, dj].astype(out_dtype)),
             axis=[di, dj]),
         name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
     return Output
@@ -176,14 +179,19 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
     pad_after = [0, pad_down, pad_right, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
     # depthconv stage
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     di = tvm.reduce_axis((0, filter_height), name='di')
     dj = tvm.reduce_axis((0, filter_width), name='dj')
     Output = tvm.compute(
         (batch, out_height, out_width, out_channel),
         lambda b, i, j, c: tvm.sum(
             (PaddedInput[b, i*stride_h + di*dilation_h, j*stride_w + dj*dilation_w,
-                         c/channel_multiplier].astype(out_dtype) *
-             Filter[di, dj, c/channel_multiplier, c%channel_multiplier].astype(out_dtype)),
+                         idxdiv(c, channel_multiplier)].astype(out_dtype) *
+             Filter[di, dj,
+                    idxdiv(c, channel_multiplier),
+                    idxmod(c, channel_multiplier)].astype(out_dtype)),
             axis=[di, dj]),
         name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
     return Output
@@ -286,11 +294,13 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
     dh = tvm.reduce_axis((0, Out_grad.shape[1].value), name='dh')
     dw = tvm.reduce_axis((0, Out_grad.shape[2].value), name='dw')
     db = tvm.reduce_axis((0, batch), name='db')
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
 
     Weight_grad = tvm.compute(
         (filter_h, filter_w, in_c, channel_multiplier),
         lambda fh, fw, c, m: tvm.sum(
-            Out_grad[db, dh, dw, c*channel_multiplier+m%channel_multiplier] *
+            Out_grad[db, dh, dw, c*channel_multiplier+idxmod(m, channel_multiplier)] *
             padded_in[db, fh+dh*stride_h, fw+dw*stride_w, c], axis=[db, dh, dw]),
         tag='depthwise_conv2d_backward_weight_nhwc')
 
index 24bf0d3..d952453 100644 (file)
@@ -52,10 +52,12 @@ def dilate(data, strides, name="DilatedInput"):
     def _dilate(*indices):
         not_zero = []
         index_tuple = []
+        idxdiv = tvm.indexdiv
+        idxmod = tvm.indexmod
         for i in range(n):
             if not util.equal_const_int(strides[i], 1):
-                index_tuple.append(indices[i] / strides[i])
-                not_zero.append((indices[i] % strides[i]).equal(0))
+                index_tuple.append(idxdiv(indices[i], strides[i]))
+                not_zero.append(idxmod(indices[i], strides[i]).equal(0))
             else:
                 index_tuple.append(indices[i])
         if not_zero:
index eb07f5e..dba9b7c 100644 (file)
@@ -38,12 +38,14 @@ def flatten(data):
     for i in range(1, len(ishape)):
         dim = dim * ishape[i]
     oshape = [ishape[0], dim]
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
 
     def unwrap(idx, shape):
         index = []
         for s in reversed(shape):
-            index.append(idx % s)
-            idx = idx / s
+            index.append(idxmod(idx, s))
+            idx = idxdiv(idx, s)
         return list(reversed(index))
 
     return tvm.compute(oshape, lambda i, j: data(i, *unwrap(j, ishape[1:])))
index 6565de2..e7c57e0 100644 (file)
@@ -175,16 +175,20 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
     ic = tvm.reduce_axis((0, in_channel), name='ic')
     kh = tvm.reduce_axis((0, kernel_height), name='kh')
     kw = tvm.reduce_axis((0, kernel_width), name='kw')
+    idxmod = tvm.indexmod
+    idxdiv = tvm.indexdiv
 
     conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                       tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh*dilation_h, ic%ic_bn,
+                       tvm.sum(data_vec[n, idxdiv(ic, ic_bn), oh*HSTR+kh*dilation_h,
+                                        idxmod(ic, ic_bn),
                                         ow*WSTR+kw*dilation_w].astype(out_dtype) *
-                               kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn,
+                               kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kh, kw,
+                                          idxmod(ic, ic_bn),
                                           oc_block].astype(out_dtype),
                                axis=[ic, kh, kw]), name='conv')
 
     unpack = tvm.compute(unpack_shape,
-                         lambda n, c, h, w: conv[n, c // oc_bn, h, w, c % oc_bn]
+                         lambda n, c, h, w: conv[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)]
                          .astype(out_dtype),
                          name='output_unpack',
                          tag='conv2d_nchw')
@@ -311,14 +315,17 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     cfg = get_config()
     _create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout)
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     # change shape with the value in config
     ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
                            cfg["tile_ow"].size[-1])
-    new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
+    new_data_shape = (raw_data_shape[0], idxdiv(raw_data_shape[1], ic_bn),
                       raw_data_shape[2], raw_data_shape[3], ic_bn)
     data_layout = "NCHW%dc" % ic_bn
     out_layout = "NCHW%dc" % oc_bn
-    new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
+    new_kernel_shape = (idxdiv(raw_kernel_shape[0], oc_bn),
+                        idxdiv(raw_kernel_shape[1], ic_bn),
                         raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
     new_data = tvm.placeholder(new_data_shape, data.dtype)
     new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
@@ -334,12 +341,14 @@ def _conv2d_infer_layout(workload, cfg):
     _, data, kernel, strides, padding, dilation, layout, dtype = workload
     batch_size, in_channel, in_height, in_width = data[:-1]
     out_channel, _, k_height, k_width = kernel[:-1]
-    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
-    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
+    idxdiv = tvm.indexdiv
+
+    out_height = idxdiv(in_height + 2 * padding[0] - k_height, strides[0]) + 1
+    out_width = idxdiv(in_width + 2 * padding[1] - k_width, strides[1]) + 1
     tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
+    in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
     in_layout = "NCHW%dc" % tile_ic
-    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
+    out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
     out_layout = "NCHW%dc" % tile_oc
     return ((in_shape, in_layout),), ((out_shape, out_layout),)
 
index 9b7624e..ec401bf 100644 (file)
@@ -64,11 +64,13 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
     packw = tvm.compute(packw_shape,
                         lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
 
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
     k = tvm.reduce_axis((0, K), name="k")
     C = tvm.compute((M, N),
                     lambda y, x: tvm.sum(
                         data[y, k].astype(out_dtype) *
-                        packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
+                        packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
                         axis=k),
                     tag="dense_pack")
     if bias is not None:
index d713a54..8af41da 100644 (file)
@@ -117,14 +117,19 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
         data_pad = data
 
     # depthconv stage
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
     kh = tvm.reduce_axis((0, filter_height), name='kh')
     kw = tvm.reduce_axis((0, filter_width), name='kw')
     Output = tvm.compute(
         (batch, out_channel_chunk, out_height, out_width, out_channel_block),
         lambda b, oco, oh, ow, oci: tvm.sum(
-            (data_pad[b, (oco * out_channel_block + oci) // channel_multiplier // in_channel_block,
-                      oh*HSTR+kh, ow*WSTR+kw,
-                      ((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block]
+            (data_pad[
+                b,
+                idxdiv(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block),
+                oh*HSTR+kh, ow*WSTR+kw,
+                idxmod(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block)]
              .astype(out_dtype) *
              kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)),
             axis=[kh, kw]),