fix lower_warp_memory (#5247)
authorTang, Shizhi <rd0x01@gmail.com>
Mon, 6 Apr 2020 15:43:38 +0000 (23:43 +0800)
committerGitHub <noreply@github.com>
Mon, 6 Apr 2020 15:43:38 +0000 (08:43 -0700)
src/tir/transforms/lower_warp_memory.cc
tests/python/unittest/test_tir_transform_lower_warp_memory.py

index 0361100..1921db5 100644 (file)
@@ -219,13 +219,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
   }
 
  protected:
-  PrimExpr Mutate_(const VarNode* op) {
+  PrimExpr VisitExpr_(const VarNode* op) override {
     CHECK(op != buffer_)
         << "Cannot access address of warp memory directly";
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Stmt VisitStmt_(const StoreNode* op) {
+  Stmt VisitStmt_(const StoreNode* op) override {
     if (op->buffer_var.get() == buffer_) {
       PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
@@ -235,7 +235,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
     }
   }
 
-  PrimExpr Mutate_(const LoadNode* op) {
+  PrimExpr VisitExpr_(const LoadNode* op) override {
     if (op->buffer_var.get() == buffer_) {
       PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
index cf6ef72..25204eb 100644 (file)
 # under the License.
 import tvm
 from tvm import te
+from tvm.contrib.nvcc import have_fp16
 
-def test_lower_warp_mem():
+import numpy as np
+
+def test_lower_warp_memory_local_scope():
     m = 128
     A = te.placeholder((m,), name='A')
     B = te.compute((m,), lambda i: A[i] + 3, name='B')
@@ -44,6 +47,50 @@ def test_lower_warp_mem():
     assert(fdevice.body.body.value.value == "local")
     assert(fdevice.body.body.body.extents[0].value == 2)
 
+def test_lower_warp_memory_cuda_end_to_end():
+    def check_cuda(dtype):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        m = 128
+        A = te.placeholder((m,), name='A', dtype=dtype)
+        B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name='B')
+
+        cuda_target = tvm.target.create("cuda")
+        assert cuda_target.thread_warp_size == 32
+        with cuda_target:
+            s = te.create_schedule(B.op)
+            AA = s.cache_read(A, "warp", [B])
+            xo, xi = s[B].split(B.op.axis[0], 64)
+            xi0, xi1 = s[B].split(xi, factor=32)
+            tx = te.thread_axis("threadIdx.x")
+            s[B].bind(xi1, tx)
+            s[B].bind(xo, te.thread_axis("blockIdx.x"))
+            s[AA].compute_at(s[B], xo)
+            xo, xi = s[AA].split(s[AA].op.axis[0], 32)
+            s[AA].bind(xi, tx)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B], "cuda")
+            A_np = np.array(list(range(m)), dtype=dtype)
+            B_np = np.array(
+                    list(range(1, 32)) + [0] +
+                    list(range(33, 64)) + [32] +
+                    list(range(65, 96)) + [64] +
+                    list(range(97, 128)) + [96],
+                    dtype=dtype)
+            A_nd = tvm.nd.array(A_np, ctx)
+            B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
+            func(A_nd, B_nd)
+            tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
 
 if __name__ == "__main__":
-    test_lower_warp_mem()
+    test_lower_warp_memory_local_scope()
+    test_lower_warp_memory_cuda_end_to_end()