value_map[iv] = dom->min;
} else {
runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
- if (stage->scope == "" || stage->scope == "warp" ||
+ if (stage->scope == "" ||
static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
value_map[iv] = var;
+ } else if (stage->scope == "warp" && ts.rank == 1) {
+ // To determine whether a thread index is inside or outside a warp, we need
+ // to know the thread extent. We leave a warning for now.
+ if (ts.dim_index == 0) {
+ value_map[iv] = var;
+ } else {
+ LOG(WARNING)
+ << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
+ << "TVM assumes only threadIdx.x indicates threads inside a warp, "
+ << "while threadIdx.y and threadIdx.z indicates different warps.";
+ value_map[iv] = dom->min;
+ }
} else {
value_map[iv] = dom->min;
}
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)
+def test_lower_warp_memory_correct_indices():
+ n = 32
+ A = te.placeholder((2, n, n), name='A', dtype="float32")
+ C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C')
+
+ s = te.create_schedule(C.op)
+ bk_x = te.thread_axis("blockIdx.x")
+ th_y = te.thread_axis("threadIdx.y")
+ th_x = te.thread_axis("threadIdx.x")
+ B = s.cache_read(A, "warp", [C])
+ cx, ci, cj = C.op.axis
+ bx, bi, bj = B.op.axis
+ s[C].bind(cj, th_x)
+ s[C].bind(cx, bk_x)
+ s[B].compute_at(s[C], cx)
+ s[B].bind(bi, th_y)
+ s[B].bind(bj, th_x)
+
+ bounds = tvm.te.schedule.InferBound(s)
+ ir = tvm.te.schedule.ScheduleOps(s, bounds)
+ inner_func = ir.body.body.body.body
+ store_A_warp = inner_func.body.seq[0].body.body
+ indices = list(store_A_warp.args)
+
+ # A.warp is actually many buffers, one for each warp, although they are all called A.warp
+ # 1. If we are accessing from different threads within a same warp (different
+ # threadIdx.x), we need to distinguish between each elements using threadIdx.x,
+ # so threadIdx.x is one if the indices.
+ # 2. If we are accessing from different warps (different threadIdx.y), we are actually
+ # assessing different buffers, so there is no need to distinguish from elements,
+ # and therefore threadIdx.y is NOT a index.
+ idx_names = map(lambda x: x.name,
+ filter(lambda x: type(x) is tvm.tir.expr.Var, indices))
+ assert "threadIdx.x" in idx_names
+ assert "threadIdx.y" not in idx_names
+
def test_lower_warp_memory_cuda_end_to_end():
def check_cuda(dtype):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
if __name__ == "__main__":
test_lower_warp_memory_local_scope()
+ test_lower_warp_memory_correct_indices()
test_lower_warp_memory_cuda_end_to_end()
test_lower_warp_memory_cuda_half_a_warp()
test_lower_warp_memory_cuda_2_buffers()