Don't replace reduction init axis with new axis if bound to a thread. (#3408)
authorChristian Sarofeen <csarofeen@nvidia.com>
Mon, 12 Aug 2019 21:11:11 +0000 (17:11 -0400)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 12 Aug 2019 21:11:11 +0000 (14:11 -0700)
* Don't replace reduction init axis with new axis if bound to a thread.

* Linter.

* Reduce bind test case.

* Guard test on CUDA support.

* [CUDA TE TESTS] Add rfactor predicate test, add global bx and tx.

* [CUDA TE TESTS] Add loop partition test for simple rfactor case.

src/op/op_util.cc
tests/python/unittest/test_codegen_cuda.py
tests/python/unittest/test_pass_loop_partition.py

index 668408a..801f4fa 100644 (file)
@@ -69,11 +69,14 @@ MakeLoopNest(const Stage& stage,
 
     // initialize the offset and loop_level
     Var var = bind_iv->var;
-    if (new_loop_var) {
-      var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
-    }
+
     // Mark the iter var in the IR, to remember the point
     if (bind_iv->thread_tag.length() == 0) {
+      // Only generate new loop if we're not bound to a thread.
+      if (new_loop_var) {
+        var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
+      }
+
       ForType for_type = ForType::Serial;
       IterVarAttr it_attr;
       if (stage->iter_var_attrs.count(iv)) {
index e8439de..1fb9c0a 100644 (file)
@@ -1,3 +1,4 @@
+
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -19,6 +20,10 @@ import numpy as np
 from tvm.contrib.nvcc import have_fp16, have_int8
 from tvm.contrib import nvcc
 
+tx = tvm.thread_axis("threadIdx.x")
+bx = tvm.thread_axis("blockIdx.x")
+
+
 def test_cuda_vectorize_add():
     num_thread = 8
     def check_cuda(dtype, n, lanes):
@@ -35,8 +40,8 @@ def test_cuda_vectorize_add():
         B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
         s = tvm.create_schedule(B.op)
         xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
-        s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
-        s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
+        s[B].bind(xo, bx)
+        s[B].bind(xi, tx)
         fun = tvm.build(s, [A, B], "cuda")
         ctx = tvm.gpu(0)
         a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
@@ -65,8 +70,8 @@ def test_cuda_multiply_add():
                         lambda i: tvm.call_pure_extern("int32", "__dp4a", A[i], B[i], C[i]), name='D')
         s = tvm.create_schedule(D.op)
         xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
-        s[D].bind(xo, tvm.thread_axis("blockIdx.x"))
-        s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
+        s[D].bind(xo, bx)
+        s[D].bind(xi, tx)
         fun = tvm.build(s, [A, B, C, D], "cuda")
         np_a = np.random.randint(low=-128, high=127, size=(n,lanes))
         np_b = np.random.randint(low=-128, high=127, size=(n,lanes))
@@ -91,9 +96,9 @@ def test_cuda_vectorize_load():
         A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
         B = tvm.compute((n,), lambda i: A[i], name='B')
         s = tvm.create_schedule(B.op)
-        bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
-        s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
-        s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+        block, thread = s[B].split(B.op.axis[0], factor=num_thread)
+        s[B].bind(block, bx)
+        s[B].bind(thread, tx)
         fun = tvm.build(s, [A, B], "cuda", name="vector_load")
         np_a = np.random.randint(low=-128, high=127, size=(n,lanes))
         a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_a)
@@ -115,7 +120,7 @@ def test_cuda_make_int8x4():
         s = tvm.create_schedule(A.op)
         y, x = s[A].op.axis
         s[A].vectorize(x)
-        s[A].bind(y, tvm.thread_axis("blockIdx.x"))
+        s[A].bind(y, bx)
         fun = tvm.build(s, [A], "cuda", name="make_int8x4")
         np_a = np.full((n, lanes), value, dtype=dtype)
         a = tvm.nd.empty(np_a.shape, dtype, ctx)
@@ -133,7 +138,7 @@ def test_cuda_inf_nan():
         inf_value = tvm.const(value, dtype=dtype)
         C = tvm.compute((n,), lambda i: inf_value, name='C')
         s = tvm.create_schedule(C.op)
-        s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
+        s[C].bind(s[C].op.axis[0], tx)
         fun = tvm.build(s, [A, C], target)
         a = tvm.nd.empty((n,), A.dtype, ctx)
         c = tvm.nd.empty((n,), A.dtype, ctx)
@@ -197,6 +202,61 @@ def test_cuda_shuffle():
         module(nda, ndb, ndc)
         tvm.testing.assert_allclose(ndc.asnumpy(), ref)
 
+
+def test_cuda_reducition_binding():
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    k = tvm.reduce_axis((0, 32), 'k')
+    A = tvm.placeholder((96, 32), name='A')
+    B = tvm.compute( (96,), lambda m:
+                     tvm.sum(A[m, k], axis=k),
+                     name='B')
+    s = tvm.create_schedule(B.op)
+
+    s[B].reorder(B.op.reduce_axis[0], B.op.axis[0])
+
+    mo, _ = s[B].split(B.op.axis[0], 32)
+    s[B].bind(mo, tvm.thread_axis("blockIdx.x"))
+
+    fcuda = tvm.build(s, [A, B], "cuda")
+
+def test_rfactor_predicates():
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    n = tvm.reduce_axis((0, 129), 'n')
+    A = tvm.placeholder((129,), name='A')
+    B = tvm.compute( (1, ), lambda b:
+                     tvm.sum(A[n],
+                             axis=n),
+                     name='B'
+    )
+
+    s = tvm.create_schedule(B.op)
+
+    _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8)
+
+    BF = s.rfactor(B, ni, 0)
+    s[B].set_store_predicate(tx.var.equal(0))
+
+    s[B].bind(s[B].op.reduce_axis[0], tx)
+    s[B].bind(s[B].op.axis[0], bx)
+
+    s[BF].compute_at(s[B], s[B].op.axis[0])
+
+    _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2)
+
+    BF2 = s.rfactor(BF, noi, 0)
+
+    s[BF].bind(s[BF].op.axis[0], tx)
+    s[BF2].compute_at(s[BF], s[BF].op.axis[1])
+
+    fcuda = tvm.build(s, [A, B], "cuda")
+
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
@@ -204,3 +264,5 @@ if __name__ == "__main__":
     test_cuda_make_int8x4()
     test_cuda_inf_nan()
     test_cuda_shuffle()
+    test_cuda_reducition_binding()
+    test_rfactor_predicates()
index 85cb9b9..eb11d76 100644 (file)
@@ -384,6 +384,34 @@ def test_double_splitting_with_indivisible_factors():
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy(), rtol=1e-5)
     tvm.testing.assert_allclose(d.asnumpy(), a.asnumpy(), rtol=1e-5)
 
+def test_simple_rfactor():
+    K = 16*4+4
+    k = tvm.reduce_axis((0, K), 'k')
+
+    A = tvm.placeholder((1, K), name='A')
+
+    B = tvm.compute( (1,), lambda b:
+            tvm.sum(A[b, k], axis=k),
+            name='B'
+    )
+
+    s = tvm.create_schedule(B.op)
+    ko, _ = s[B].split(s[B].op.reduce_axis[0], 16)
+    BF = s.rfactor(B, ko, 0)
+
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+
+    stmt1 = tvm.schedule.ScheduleOps(s, bounds)
+    stmt1 = tvm.ir_pass.Simplify(stmt1)
+
+    stmt2 = tvm.ir_pass.LoopPartition(stmt1, True)
+    stmt2 = tvm.ir_pass.Simplify(stmt2)
+
+    #make sure loop partition actually did something
+    assert not tvm.ir_pass.Equal(stmt1.body, stmt2.body)
+
+
 if __name__ == "__main__":
     test_basic()
     test_const_loop()
@@ -402,3 +430,4 @@ if __name__ == "__main__":
     test_cce_loop_3()
     test_conv_tiling()
     test_double_splitting_with_indivisible_factors()
+    test_simple_rfactor()