[ARITH] cleanup the indexmod/div on python side (#4028)
author Tianqi Chen <tqchen@users.noreply.github.com>
Sat, 28 Sep 2019 21:43:44 +0000 (14:43 -0700)
committer GitHub <noreply@github.com>
Sat, 28 Sep 2019 21:43:44 +0000 (14:43 -0700)
24 files changed:
python/tvm/autotvm/task/task.py
python/tvm/expr.py
src/pass/rewrite_unsafe_select.cc
tests/python/relay/test_op_level3.py
tests/python/relay/test_op_level5.py
tests/python/unittest/test_autotvm_flop_calculator.py
tests/python/unittest/test_codegen_cuda.py
tests/python/unittest/test_ir_builder.py
tests/python/unittest/test_lang_buffer.py
tests/python/unittest/test_pass_rewrite_unsafe_select.py
tests/python/unittest/test_schedule_tensorize.py
topi/python/topi/arm_cpu/bitserial_conv2d.py
topi/python/topi/arm_cpu/conv2d.py
topi/python/topi/cuda/nms.py
topi/python/topi/cuda/rcnn/proposal.py
topi/python/topi/cuda/sort.py
topi/python/topi/cuda/ssd/multibox.py
topi/python/topi/nn/bitserial_conv2d.py
topi/python/topi/nn/sparse.py
topi/python/topi/util.py
topi/python/topi/vision/ssd/multibox.py
topi/python/topi/x86/conv2d_avx_1x1.py
tutorials/optimize/opt_gemm.py
vta/python/vta/ir_pass.py

index 901183f..e0db275 100644 (file)
@@ -350,7 +350,9 @@ def compute_flop(sch):
             return _count_flop(exp.value)
         if isinstance(exp, expr.Var):
             return 0
-        if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
+        if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
+                            expr.Div, expr.Mod,
+                            expr.FloorDiv, expr.FloorMod,
                             expr.Max, expr.Min,
                             expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
                             expr.And, expr.Or, expr.Not)):
index a8bd651..5b7c60d 100644 (file)
@@ -72,23 +72,23 @@ class ExprOp(object):
         return _generic.multiply(other, self)
 
     def __div__(self, other):
-        if _dtype_is_int(self) and _dtype_is_int(other):
-            raise div_ambiguity_error()
+        if _dtype_is_int(self) and _dtype_is_int(other):
+            raise div_ambiguity_error()
         return _generic.divide(self, other)
 
     def __rdiv__(self, other):
-        if _dtype_is_int(self) and _dtype_is_int(other):
-            raise div_ambiguity_error()
+        if _dtype_is_int(self) and _dtype_is_int(other):
+            raise div_ambiguity_error()
         return _generic.divide(other, self)
 
     def __truediv__(self, other):
-        if _dtype_is_int(self) and _dtype_is_int(other):
-            raise div_ambiguity_error()
+        if _dtype_is_int(self) and _dtype_is_int(other):
+            raise div_ambiguity_error()
         return _generic.divide(self, other)
 
     def __rtruediv__(self, other):
-        if _dtype_is_int(self) and _dtype_is_int(other):
-            raise div_ambiguity_error()
+        if _dtype_is_int(self) and _dtype_is_int(other):
+            raise div_ambiguity_error()
         return _generic.divide(other, self)
 
     def __floordiv__(self, other):
@@ -100,8 +100,8 @@ class ExprOp(object):
         return _generic.divide(other, self)
 
     def __mod__(self, other):
-        raise div_ambiguity_error()
-        return _make._OpMod(self, other)
+        raise div_ambiguity_error()
+        return _make._OpMod(self, other)
 
     def __neg__(self):
         neg_one = _api_internal._const(-1, self.dtype)
index 871efca..62db0b4 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
   bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
   bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
   bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); }
   bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
   bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
   bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }
index 424462f..2d92489 100644 (file)
@@ -373,6 +373,8 @@ def test_split_infer_type():
         yy = run_infer_type(y.astuple())
         assert yy.checked_type == ret_type
 
+    idxd = tvm.indexdiv
+
     d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
     axis = tvm.var("axis")
     verify_split((5, 5, 2, 2), 5,
@@ -393,15 +395,15 @@ def test_split_infer_type():
                   axis=0)
     verify_split((d1, d2, d3, d4), 4,
                  relay.ty.TupleType(tvm.convert([
-                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
-                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
-                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
-                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
+                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])),
                   axis=2)
     verify_split((d1, d2, d3, d4), 2,
                  relay.ty.TupleType(tvm.convert([
-                     relay.ty.TensorType((d1/2, d2, d3, d4), "float32"),
-                     relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])),
+                     relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
+                     relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])),
                   axis=0)
     verify_split((d1, d2, d3, d4), (2, 4, 7),
                  relay.ty.TupleType(tvm.convert([
index f4ac673..8c10735 100644 (file)
@@ -487,8 +487,9 @@ def test_yolo_reorg_infer_shape():
         assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")
 
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
+    idxd = tvm.indexdiv
     verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
-    verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))
+    verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2)))
 
 def test_yolo_reorg():
     def verify_yolo_reorg(shape, stride):
index 54ade9a..5cafd02 100644 (file)
@@ -60,14 +60,14 @@ def test_pack_gemm():
         k = tvm.reduce_axis((0, L))
 
         bn = 4
-        fld = tvm.floordiv
-        flm = tvm.floormod
+        idxd = tvm.indexdiv
+        idxm = tvm.indexmod
 
         A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
         B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
         C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
         tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
-        C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
+        C = tvm.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)])
 
         s = tvm.create_schedule([C.op])
         assert compute_flop(s) == 2 * N * L * M
index 63aaf21..aa3a537 100644 (file)
@@ -37,7 +37,7 @@ def test_cuda_vectorize_add():
             print("skip because gpu does not support int8")
             return
         A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
-        B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
+        B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
         s = tvm.create_schedule(B.op)
         xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
         s[B].bind(xo, bx)
@@ -165,9 +165,10 @@ def test_cuda_shuffle():
         print("skip because cuda is not enabled..")
         return
 
+    idxm = tvm.indexmod
     a = tvm.placeholder((64, ), 'int32')
     b = tvm.placeholder((64, ), 'int32')
-    c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)])
+    c = tvm.compute((64, ), lambda x: a[x] + b[x - idxm(x, 4) + (3 - idxm(x, 4))])
     sch = tvm.create_schedule(c.op)
     x = c.op.axis[0]
     xo, xi = sch[c].split(x, 4)
index ef58174..c910c62 100644 (file)
@@ -109,14 +109,15 @@ def test_gpu():
     dtype = "float32"
     A = tvm.placeholder((n,), name='A')
     B = tvm.placeholder((n,), name='B')
-    fld = tvm.floordiv
+    idxd = tvm.indexdiv
+
     def test_device_ir(A, B, C):
         n = A.shape[0]
         max_threads = 32
         ib = tvm.ir_builder.create()
         bx = tvm.thread_axis("blockIdx.x")
         tx = tvm.thread_axis("threadIdx.x")
-        ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
+        ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads))
         ib.scope_attr(tx, "thread_extent", max_threads)
         idx = bx.var * max_threads + tx.var
         Aptr = ib.buffer_ptr(A)
index 9ad8b62..32c1745 100644 (file)
@@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod():
     def assert_simplified_equal(index_simplified, index_direct):
         assert tvm.ir_pass.Equal(index_simplified, index_direct),\
         "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
     # Test Case1
     index_simplified = A_stride.vload(
-        (idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
+        (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1))
     index_direct = A_stride.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
 
     # Test Case2
-    index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
-                                idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
-    index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
+    index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
+                                idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)))
+    index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case3
-    index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
-                                idxdiv(idxmod(k0, idxdiv(k1, s)), n),
-                                idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
-                                idxmod(idxmod(k0, idxdiv(k1, s)), n)))
+    index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
+                                idxd(idxm(k0, idxd(k1, s)), n),
+                                idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
+                                idxm(idxm(k0, idxd(k1, s)), n)))
     index_direct = A.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case4 (not able to simplify)
-    index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
-                                idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
-    index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
-                            (idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
+    index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
+                                idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
+    index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n +
+                            (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))))
     assert_simplified_equal(index_simplified, index_direct)
 
 
index b2d73ec..4c42899 100644 (file)
@@ -28,7 +28,7 @@ def test_rewrite_Select():
         tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
     zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value
 
-    a = tvm.expr.Select(i>10, y, z)
+    a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z)
     aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
     assert yy.name == "tvm_if_then_else"
     assert zz.name == "tvm_if_then_else"
index 4bad959..59adf0c 100644 (file)
@@ -221,14 +221,15 @@ def test_tensorize_matmul():
 # This tests whether algorithm and intrinsics expressions are simplified
 # as much as possible first and then checked for equality. See Issue #696
 def test_tensorize_op():
-    tdiv = tvm.truncdiv
-    tmod = tvm.truncmod
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     def op_intrin():
         bh = 9
         bw = 9
         x = tvm.placeholder((5, 5), name='A')
         y = tvm.compute((bh, bw),
-                        lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)])
+                        lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)])
 
         def intrin_func(ins, outs):
             xx, = ins
@@ -239,7 +240,7 @@ def test_tensorize_op():
             return tvm.decl_tensor_intrin(y.op, intrin_func)
 
     A = tvm.placeholder((5, 5), name='A')
-    B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)])
+    B = tvm.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
     bt = op_intrin()
     s = tvm.create_schedule(B.op)
 
index 072c187..9b8360d 100644 (file)
@@ -70,6 +70,9 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
     OW = (PAD_W - KW) // WSTR + 1
     oshape = (1, OH, OW, CO)
 
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     # Pad input channels of weights and data when it is not a multiple of 8
     if CI_packed % 8 != 0:
         CI_PAD = CI_packed % 8
@@ -106,7 +109,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
     data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
 
     kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4)
-    if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0:
+    idxm = tvm.indexmod
+    if idxm(kernel_vec.shape[-1], 8) != 0 and CI_PAD != 0:
         kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD])
 
     N, H, W, IB, CI = data_q.shape
@@ -147,8 +151,12 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
     else:
         conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar')
 
-    conv = tvm.compute(oshape, lambda n, h, w, co:
-                       conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
+
+    conv = tvm.compute(oshape,
+                       lambda n, h, w, co:
+                       conv_vec[n,
+                                idxd(h, VH), idxd(w, VW), idxd(co, VC),
+                                idxm(h, VH), idxm(w, VW), idxm(co, VC)].astype(out_dtype),
                        name='conv', tag='spatial_bitserial_conv_nhwc')
 
     return conv
index 73a97d2..f5cbbf0 100644 (file)
@@ -171,6 +171,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
     data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
 
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     r = KW
     m = tile_size
     alpha = m + r - 1
@@ -190,10 +193,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     VK = cfg['tile_k'].size[-1]
 
     # pack input tile
-    input_tile = tvm.compute((C, P // VP, alpha, alpha, VP),
+    input_tile = tvm.compute((C, idxd(P, VP), alpha, alpha, VP),
                              lambda c, b, eps, nu, bb:
-                             data_pad[(b*VP+bb) // (nH*nW)][c][(b*VP+bb) // nW % nH * m + eps]
-                             [(b*VP+bb) % nW * m + nu],
+                             data_pad[idxd(b*VP + bb, nH*nW), c,
+                                      idxm(idxd(b*VP + bb, nW), nH) * m + eps,
+                                      idxm(b*VP + bb, nW) * m + nu],
                              name='d')
 
     # transform kernel
@@ -202,22 +206,22 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     else:
         r_kh = tvm.reduce_axis((0, KH), 'r_kh')
         r_kw = tvm.reduce_axis((0, KW), 'r_kw')
-        U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk:
+        U = tvm.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk:
                         tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
                                 G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')
 
     # transform image
     r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
     r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
-    V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb:
+    V = tvm.compute((alpha, alpha, idxd(P, VP), C, VP), lambda eps, nu, b, c, bb:
                     tvm.sum(input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) *
                             B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V')
 
     # batch gemm
     c = tvm.reduce_axis((0, C), name='c')
     M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
-                    tvm.sum(U[eps][nu][k // VK][c][k % VK] *
-                            V[eps][nu][b // VP][c][b % VP], axis=c), name='M')
+                    tvm.sum(U[eps][nu][idxd(k, VK)][c][idxm(k, VK)] *
+                            V[eps][nu][idxd(b, VP)][c][idxm(b, VP)], axis=c), name='M')
 
     # inverse transform
     r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
@@ -228,7 +232,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
 
     # unpack output
     output = tvm.compute((N, K, H, W), lambda n, k, h, w:
-                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
+                         Y[k][n * nH * nW + idxd(h, m) * nW + idxd(w, m),
+                              idxm(h, m), idxm(w, m)],
                          name='output', tag='winograd_conv2d_output')
 
     # we have to manually assign effective GFLOP for winograd
@@ -517,6 +522,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     N, CI, H, W = get_const_tuple(data.shape)
     CO, _, KH, KW = get_const_tuple(kernel.shape)
 
+    idxd = tvm.indexdiv
+
     if groups == 1:
         # query config of this workload
         workload = autotvm.task.args_to_workload(
@@ -535,7 +542,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
 
             # Store the same config for the altered operator (workload)
             new_data = data
-            new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
+            new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
             dispatch_ctx.update(target, new_workload, cfg)
@@ -553,7 +560,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
                                                                    tile_size=tile_size)
             weight = F.reshape(weight,
-                               newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
+                               newshape=(KH + tile_size - 1,
+                                         KW + tile_size - 1,
+                                         idxd(CO, VC), VC, CI))
             weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
 
             copy_inputs[1] = weight
@@ -561,7 +570,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
 
             # Store the same config for the altered operator (workload)
             new_data = data
-            new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
+            new_weight = tvm.placeholder((KH + tile_size - 1,
+                                          KH + tile_size -1,
+                                          idxd(CO, VC), CI, VC),
                                          kernel.dtype)
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_weight, strides, padding, dilation,
@@ -612,7 +623,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             # Store the same config for the altered operator (workload)
             new_data = data
             CO, M, KH, KW = get_const_tuple(kernel.shape)
-            new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype)
+            new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_kernel, strides, padding, dilation, out_dtype],
                 depthwise_conv2d_nchw)
index 6ff8a79..33fc724 100644 (file)
@@ -243,14 +243,16 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
     new_range = num_anchors // elem_per_thread + 1
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
     # Scan: Downsweep:
     with ib. if_scope(tid < batch_size * num_anchors):
-        i = tid // num_anchors # number of batches
-        j = tid % num_anchors # number of anchors
+        i = idxd(tid, num_anchors) # number of batches
+        j = idxm(tid, num_anchors) # number of anchors
         with ib.if_scope(j < elem_per_thread):
             idx[tid] = idx_in[tid]
         with ib.else_scope():
-            idx[tid] = idx_in[tid] + partial[i * new_range + j // elem_per_thread - 1]
+            idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1]
 
     return ib.get()
 
@@ -303,9 +305,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
 
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     with ib.if_scope(tid < batch_size * num_anchors):
-        i = tid // num_anchors
-        j = tid % num_anchors
+        i = idxd(tid, num_anchors)
+        j = idxm(tid, num_anchors)
         base_idx = i * num_anchors * elem_length
         with ib.if_scope(flag[tid] > 0):
             with ib.for_range(0, elem_length) as k:
index 06226d1..54f73a1 100644 (file)
@@ -79,10 +79,13 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r
     p_im_info = ib.buffer_ptr(im_info_buf)
     p_out = ib.buffer_ptr(out_buf)
 
+    idxm = tvm.indexmod
+    idxd = tvm.indexdiv
+
     with ib.if_scope(tid < batch * height * width):
-        w = tid % width
-        h = (tid // width) % height
-        b = tid // width // height
+        w = idxm(tid, width)
+        h = idxm(idxd(tid, width), height)
+        b = idxd(idxd(tid, width), height)
 
         for k in range(num_anchors):
             out_index = tid * num_anchors + k
@@ -163,6 +166,8 @@ def argsort_ir(data_buf, out_index_buf):
     temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
     temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
 
+    idxm = tvm.indexmod
+
     with ib.for_range(0, batch, for_type="unroll") as b:
         start = b * num_bbox
         for i in range(2):
@@ -170,7 +175,7 @@ def argsort_ir(data_buf, out_index_buf):
             with ib.if_scope(bbox_id < num_bbox):
                 index_out[start + bbox_id] = bbox_id
         with ib.for_range(0, num_bbox) as k:
-            offset = start + 2 * tid + (k % 2)
+            offset = start + 2 * tid + idxm(k, 2)
             with ib.if_scope(
                 tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])):
                 temp_data[0] = p_data[offset]
index c45465e..b02c14b 100644 (file)
@@ -115,6 +115,8 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
     ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                           tvm.convert(['shared']),
                           tvm.expr.Call.Intrinsic, None, 0))
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
 
     with ib.for_range(0, axis_mul_before) as i:
         with ib.for_range(0, axis_mul_after) as j:
@@ -122,13 +124,13 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
             base_idx = i * shape[axis] * axis_mul_after + j
             # OddEvenTransposeSort
             with ib.for_range(0, current_sort_num) as k:
-                with ib.if_scope(tid < (current_sort_num + 1) // 2):
-                    offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
+                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
+                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
                     if is_ascend:
-                        cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
+                        cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
                                        values_out[offset] > values_out[offset + axis_mul_after])
                     else:
-                        cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
+                        cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
                                        values_out[offset] < values_out[offset + axis_mul_after])
                     with ib.if_scope(cond):
                         temp_data[0] = values_out[offset]
@@ -199,6 +201,9 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
     temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
     is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)
 
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     with ib.for_range(0, axis_mul_before) as i:
         with ib.for_range(0, axis_mul_after) as j:
             current_sort_num = valid_count[i * axis_mul_after + j]
@@ -207,10 +212,10 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
                 output[base_idx + tid * axis_mul_after] = tid
             # OddEvenTransposeSort
             with ib.for_range(0, current_sort_num) as k:
-                with ib.if_scope(tid < (current_sort_num + 1) // 2):
-                    offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
+                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
+                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
                     with ib.if_scope(tvm.all(is_ascend == 1, \
-                                             2 * tid + (k % 2) + 1 < current_sort_num, \
+                                             2 * tid + idxm(k, 2) + 1 < current_sort_num, \
                                              data[offset] > data[offset + axis_mul_after])):
                         temp_data[0] = data[offset]
                         data[offset] = data[offset + axis_mul_after]
@@ -219,7 +224,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
                         output[offset] = output[offset + axis_mul_after]
                         output[offset + axis_mul_after] = temp_index[0]
                     with ib.if_scope(tvm.all(is_ascend == 0, \
-                                             2 * tid + (k % 2) + 1 < current_sort_num, \
+                                             2 * tid + idxm(k, 2) + 1 < current_sort_num, \
                                              data[offset] < data[offset + axis_mul_after])):
                         temp_data[0] = data[offset]
                         data[offset] = data[offset + axis_mul_after]
index 03fa999..e1af436 100644 (file)
@@ -95,8 +95,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
 
             for k in range(num_sizes + num_ratios - 1):
                 w = if_then_else(k < num_sizes,
-                                 size_ratio_concat[k] * in_height / in_width / 2.0,
-                                 size_ratio_concat[0] * in_height / in_width *
+                                 float(size_ratio_concat[k]) * in_height / in_width / 2.0,
+                                 float(size_ratio_concat[0]) * in_height / in_width *
                                  math.sqrt(size_ratio_concat[k + 1]) / 2.0)
                 h = if_then_else(
                     k < num_sizes, size_ratio_concat[k] / 2.0,
@@ -204,10 +204,12 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
 
     with ib.if_scope(tid < batch_size * num_anchors):
-        i = tid / num_anchors
-        j = tid % num_anchors
+        i = idxd(tid, num_anchors)
+        j = idxm(tid, num_anchors)
         valid_count[i] = 0
         score[tid] = -1.0
         cls_id[tid] = 0
@@ -314,9 +316,13 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
 
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     with ib.if_scope(tid < batch_size * num_anchors):
-        i = tid // num_anchors
-        j = tid % num_anchors
+        i = idxd(tid, num_anchors)
+        j = idxm(tid, num_anchors)
+
         with ib.if_scope(cls_id[tid] > 0):
             with ib.if_scope(tid == 0):
                 out_base_idx = i * num_anchors * 6
index 2faabf2..932c141 100644 (file)
@@ -313,13 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
                        axis=[ci, dh, dw, b1, b2])
 
     conv = tvm.compute(ovshape, _conv, name='conv_out')
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
 
     return tvm.compute(
         oshape, lambda n, co, h, w:
-        conv[n][idxdiv(co, VC)][idxdiv(h, VH)][idxdiv(
-            w, VW)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
+        conv[n,
+             idxd(co, VC), idxd(h, VH), idxd(w, VW),
+             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
         name='conv_vec', tag='spatial_bitserial_conv_nchw')
 
 @autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
@@ -419,12 +420,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
 
     conv = tvm.compute(ovshape, _conv, name='conv')
 
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
     return tvm.compute(
         oshape, lambda n, h, w, co:
-        conv[n][idxdiv(h, VH)][idxdiv(w, VW)][idxdiv(
-            co, VC)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
+        conv[n,
+             idxd(h, VH), idxd(w, VW), idxd(co, VC),
+             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
         name='output_unpack', tag='spatial_bitserial_conv_nhwc')
 
 @tvm.target.generic_func
index 11116b2..584126e 100644 (file)
@@ -94,12 +94,15 @@ def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
         x_val = data[i, bs_c * block_j + c]
         return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c])
 
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     bsrmm_block = tvm.compute(
         (m, num_blocks, bs_r), _compute_block,
         tag="sparse_dense_bsrmm_block")
     return tvm.compute(
         (m, num_blocks * bs_r),
-        lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r],
+        lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
         tag="sparse_dense_bsrmm")
 
 @tvm.target.generic_func
index 6de916c..1bf3a10 100644 (file)
@@ -232,10 +232,12 @@ def unravel_index(idx, shape):
     indices : tuple of int or tvm.expr.IntImm
         Corresponding coordinate of the 1D index
     """
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
     indices = []
     for i in range(len(shape) - 1, -1, -1):
-        indices.append(idx % shape[i])
-        idx = idx // shape[i]
+        indices.append(idxm(idx, shape[i]))
+        idx = idxd(idx, shape[i])
     indices = indices[::-1]
     return indices
 
@@ -257,12 +259,13 @@ def const_matrix(matrix, name="const_matrix"):
     """
     row, col = matrix.shape
     dtype = str(matrix.dtype)
+    idxm = tvm.indexmod
 
     def select_array(i, j):
         now = tvm.const(0.0, dtype)
         for ii in range(row):
             for jj in range(col):
-                now = tvm.expr.Select(tvm.all(i % row == ii, j % col == jj),
+                now = tvm.expr.Select(tvm.all(idxm(i, row) == ii, idxm(j, col) == jj),
                                       tvm.const(matrix[ii][jj], dtype),
                                       now)
         return now
index ca1b4a9..135315b 100644 (file)
@@ -73,10 +73,10 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
             center_w = (j + offset_w) * steps_w
             for k in const_range(num_sizes + num_ratios - 1):
                 if k < num_sizes:
-                    w = sizes[k] * in_height / in_width / 2.0
+                    w = float32(sizes[k] * in_height) / in_width / 2.0
                     h = sizes[k] / 2.0
                 else:
-                    w = sizes[0] * in_height / in_width \
+                    w = float32(sizes[0] * in_height) / in_width \
                         * sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
                     h = sizes[0] / sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
                 count = i * in_width * (num_sizes + num_ratios - 1) \
index 3d0978c..6e36e93 100644 (file)
@@ -309,8 +309,15 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
 
     # packing the Filter to let memory access be consecutive for AVX512 intrinsic
     # Done in pre-compute stage
-    packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4)
-    PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e],
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
+    packw_shape = (kernel_h, kernel_w, idxd(num_filter, 16), 16 * idxd(channel, 4), 4)
+    PackW = tvm.compute(packw_shape,
+                        lambda a, b, c, d, e:
+                        Filter[a, b,
+                               c*16 + idxm(d, 16),
+                               idxd(d, 16) * 4 + e],
                         name="packed_filter")
 
     rc = tvm.reduce_axis((0, in_channel), name='rc')
@@ -321,7 +328,9 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
         lambda nn, yy, xx, ff: tvm.sum(
             PaddedInput[nn, yy * stride_h + ry * dilation_h,
                         xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
-            PackW[ry, rx, ff/16, (rc/4)*16+ff%16, rc%4].astype(out_dtype), axis=[ry, rx, rc]),
+            PackW[ry, rx, idxd(ff, 16),
+                  idxd(rc, 4) * 16 + idxm(ff, 16),
+                  idxm(rc, 4)].astype(out_dtype), axis=[ry, rx, rc]),
         name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8")
     return Output
 
index 0fb73ec..a23589a 100644 (file)
@@ -247,7 +247,7 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
 # We have to re-write the algorithm slightly.
 packedB = tvm.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name='packedB')
 C = tvm.compute((M, N),
-                lambda x, y: tvm.sum(A[x, k] * packedB[y / bn, k, y % bn], axis=k),
+                lambda x, y: tvm.sum(A[x, k] * packedB[y // bn, k, tvm.indexmod(y, bn)], axis=k),
                 name = 'C')
 
 s = tvm.create_schedule(C.op)
index 06a1975..12ef7da 100644 (file)
@@ -335,6 +335,9 @@ def inject_dma_intrin(stmt_in):
         Transformed statement
     """
     env = get_env()
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     def _check_compact(buf):
         ndim = len(buf.shape)
         size = tvm.const(1, buf.shape[0].dtype)
@@ -369,7 +372,7 @@ def inject_dma_intrin(stmt_in):
             x_size = 1
             x_stride = buf.strides[ndim - base]
             next_base = base
-            if not util.equal_const_int(x_stride % elem_block, 0):
+            if not util.equal_const_int(idxm(x_stride, elem_block), 0):
                 raise RuntimeError(
                     "scope %s need to have block=%d, shape=%s, strides=%s" % (
                         scope, elem_block, buf.shape, buf.strides))
@@ -394,7 +397,7 @@ def inject_dma_intrin(stmt_in):
             raise RuntimeError("Expect buffer type to be %s instead of %s" %
                                (dtype, buf.dtype))
         shape, strides = buf.shape, buf.strides
-        if not util.equal_const_int(buf.elem_offset % elem_block, 0):
+        if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
             raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
         if allow_fold:
             shape, strides = _fold_buffer_dim(buf, scope, elem_block)
@@ -421,7 +424,7 @@ def inject_dma_intrin(stmt_in):
                 x_size = 1
                 x_stride = 1
                 y_size = 1
-                return x_size, y_size, x_stride, buf.elem_offset / elem_block
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
             if not util.equal_const_int(strides[-2] - elem_block, 0):
                 raise_error()
 
@@ -429,15 +432,15 @@ def inject_dma_intrin(stmt_in):
                 x_size = shape[-2]
                 x_stride = shape[-2]
                 y_size = 1
-                return x_size, y_size, x_stride, buf.elem_offset / elem_block
-            if not util.equal_const_int(strides[-3] % elem_block, 0):
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+            if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
                 raise_error()
 
             if ndim == 3:
                 x_size = shape[-2]
-                x_stride = strides[-3] / elem_block
+                x_stride = idxd(strides[-3], elem_block)
                 y_size = shape[-3]
-                return x_size, y_size, x_stride, buf.elem_offset / elem_block
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
 
         else:
             if not util.equal_const_int(strides[-1], 1):
@@ -451,7 +454,7 @@ def inject_dma_intrin(stmt_in):
                 x_size = 1
                 x_stride = 1
                 y_size = 1
-                return x_size, y_size, x_stride, buf.elem_offset / elem_block
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
             if not util.equal_const_int(strides[-3], elem_block):
                 raise_error()
 
@@ -459,15 +462,15 @@ def inject_dma_intrin(stmt_in):
                 x_size = shape[-3]
                 x_stride = shape[-3]
                 y_size = 1
-                return x_size, y_size, x_stride, buf.elem_offset / elem_block
-            if not util.equal_const_int(strides[-4] % elem_block, 0):
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+            if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
                 raise_error()
 
             if ndim == 4:
                 x_size = shape[-3]
-                x_stride = strides[-4] / elem_block
+                x_stride = idxd(strides[-4], elem_block)
                 y_size = shape[-4]
-                return x_size, y_size, x_stride, buf.elem_offset / elem_block
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
 
         raise_error()
 
@@ -765,6 +768,8 @@ def inject_alu_intrin(stmt_in):
         Transformed statement
     """
     env = get_env()
+    idxm = tvm.indexmod
+
     def _do_fold(stmt):
         def _equal(x, y):
             return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)
@@ -910,10 +915,10 @@ def inject_alu_intrin(stmt_in):
             assert len(extents) != 0
             assert tvm.ir_pass.Equal(
                 tvm.ir_pass.Simplify(
-                    src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
+                    idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
             assert tvm.ir_pass.Equal(
                 tvm.ir_pass.Simplify(
-                    dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
+                    idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
             assert tvm.ir_pass.Equal(src_coeff[-2], 1)
             assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
             if env.BATCH > 1: