[TE] Fix MakeLoopNest for warp memory (#5382)
authorTang, Shizhi <rd0x01@gmail.com>
Sat, 9 May 2020 00:52:30 +0000 (08:52 +0800)
committerGitHub <noreply@github.com>
Sat, 9 May 2020 00:52:30 +0000 (17:52 -0700)
src/te/operation/op_util.cc
tests/python/unittest/test_tir_transform_lower_warp_memory.py

index f7e0e51..bee573e 100644 (file)
@@ -163,9 +163,21 @@ MakeLoopNest(const Stage& stage,
         value_map[iv] = dom->min;
       } else {
         runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
-        if (stage->scope == "" || stage->scope == "warp" ||
+        if (stage->scope == "" ||
             static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
           value_map[iv] = var;
+        } else if (stage->scope == "warp" && ts.rank == 1) {
+          // To determine whether a thread index is inside or outside a warp, we need
+          // to know the thread extent. We leave a warning for now.
+          if (ts.dim_index == 0) {
+            value_map[iv] = var;
+          } else {
+            LOG(WARNING)
+              << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
+              << "TVM assumes only threadIdx.x indicates threads inside a warp, "
+              << "while threadIdx.y and threadIdx.z indicates different warps.";
+            value_map[iv] = dom->min;
+          }
         } else {
           value_map[iv] = dom->min;
         }
index 51be480..bd55377 100644 (file)
@@ -47,6 +47,42 @@ def test_lower_warp_memory_local_scope():
     assert(fdevice.body.body.value.value == "local")
     assert(fdevice.body.body.body.extents[0].value == 2)
 
+def test_lower_warp_memory_correct_indices():
+    n = 32
+    A = te.placeholder((2, n, n), name='A', dtype="float32")
+    C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C')
+
+    s = te.create_schedule(C.op)
+    bk_x = te.thread_axis("blockIdx.x")
+    th_y = te.thread_axis("threadIdx.y")
+    th_x = te.thread_axis("threadIdx.x")
+    B = s.cache_read(A, "warp", [C])
+    cx, ci, cj = C.op.axis
+    bx, bi, bj = B.op.axis
+    s[C].bind(cj, th_x)
+    s[C].bind(cx, bk_x)
+    s[B].compute_at(s[C], cx)
+    s[B].bind(bi, th_y)
+    s[B].bind(bj, th_x)
+
+    bounds = tvm.te.schedule.InferBound(s)
+    ir = tvm.te.schedule.ScheduleOps(s, bounds)
+    inner_func = ir.body.body.body.body
+    store_A_warp = inner_func.body.seq[0].body.body
+    indices = list(store_A_warp.args)
+
+    # A.warp is actually many buffers, one for each warp, although they are all called A.warp
+    # 1. If we are accessing from different threads within a same warp (different
+    #    threadIdx.x), we need to distinguish between each elements using threadIdx.x,
+    #    so threadIdx.x is one if the indices.
+    # 2. If we are accessing from different warps (different threadIdx.y), we are actually
+    #    assessing different buffers, so there is no need to distinguish from elements,
+    #    and therefore threadIdx.y is NOT a index.
+    idx_names = map(lambda x: x.name,
+            filter(lambda x: type(x) is tvm.tir.expr.Var, indices))
+    assert "threadIdx.x" in idx_names
+    assert "threadIdx.y" not in idx_names
+
 def test_lower_warp_memory_cuda_end_to_end():
     def check_cuda(dtype):
         if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
@@ -182,6 +218,7 @@ def test_lower_warp_memory_cuda_2_buffers():
 
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
+    test_lower_warp_memory_correct_indices()
     test_lower_warp_memory_cuda_end_to_end()
     test_lower_warp_memory_cuda_half_a_warp()
     test_lower_warp_memory_cuda_2_buffers()