[CUDA] Fix codegen for warp shuffle intrinsics (#5606)
authorShizhi Tang <rd0x01@gmail.com>
Mon, 18 May 2020 02:55:05 +0000 (10:55 +0800)
committerGitHub <noreply@github.com>
Mon, 18 May 2020 02:55:05 +0000 (19:55 -0700)
* fix shfl intrin

* improve test_lower_warp_memory_cuda_half_a_warp

src/target/source/intrin_rule_cuda.cc
tests/python/unittest/test_tir_transform_lower_warp_memory.py

index 4e4abd9..7ebcfa6 100644 (file)
@@ -116,7 +116,7 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
   CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
-  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
+  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
   const char* name = T()(call->dtype, call->name);
   *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern);
 }
index bd55377..c3cf289 100644 (file)
@@ -136,30 +136,32 @@ def test_lower_warp_memory_cuda_half_a_warp():
             print("Skip because gpu does not have fp16 support")
             return
 
-        m = 16
-        A = te.placeholder((m,), name='A', dtype=dtype)
-        B = te.compute((m,), lambda i: A[(i + 1) % m], name='B')
+        n, m = 16, 16
+        A = te.placeholder((n, m,), name='A', dtype=dtype)
+        B = te.compute((n, m,), lambda j, i: A[j, (i + 1) % m], name='B')
 
         cuda_target = tvm.target.create("cuda")
         assert cuda_target.thread_warp_size == 2 * m
         with cuda_target:
             s = te.create_schedule(B.op)
             tx = te.thread_axis("threadIdx.x")
+            ty = te.thread_axis("threadIdx.y")
             bx = te.thread_axis("blockIdx.x")
 
             AA = s.cache_read(A, "warp", [B])
-            xo, xi = s[B].split(B.op.axis[0], nparts=1)
-            s[B].bind(xi, tx)
-            s[B].bind(xo, bx)
-            s[AA].compute_at(s[B], xo)
-            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
-            s[AA].bind(xo, bx)
-            s[AA].bind(xi, tx)
+            y, x = B.op.axis
+            z, y = s[B].split(y, nparts=2)
+            s[B].bind(x, tx)
+            s[B].bind(y, ty)
+            s[B].bind(z, bx)
+            s[AA].compute_at(s[B], y)
+            _, x = AA.op.axis
+            s[AA].bind(x, tx)
 
             ctx = tvm.gpu(0)
             func = tvm.build(s, [A, B], "cuda")
-            A_np = np.array(list(range(m)), dtype=dtype)
-            B_np = np.array(list(range(1, m)) + [0], dtype=dtype)
+            A_np = np.array([list(range(i, m + i)) for i in range(n)], dtype=dtype)
+            B_np = np.array([list(range(1 + i, m + i)) + [i] for i in range(n)], dtype=dtype)
             A_nd = tvm.nd.array(A_np, ctx)
             B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
             func(A_nd, B_nd)