    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
+
+def test_local_stage_predicate():
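+    # B is staged in per-thread ("local") or per-block ("shared") memory and
+    # bound to the same thread/block index as its consumer C; the lowered
+    # body should then contain no boundary predicate (IfThenElse).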
+    m = 1
+    n = 3
+    p = 2
+    A = tvm.te.placeholder((m, n, p), name='A')
+    B = tvm.te.compute((m, n, p), lambda bi, bj, bk: A[bi, bj, bk], name="B")
+    C = tvm.te.compute((m, n, p), lambda ci, cj, ck: B[ci, cj, ck], name="C")
+    by = tvm.te.thread_axis("blockIdx.y")
+    tx = tvm.te.thread_axis("threadIdx.x")
+    vx = tvm.te.thread_axis("vthread")
+
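+    # Helper: stage B in `mem_scope`, compute it under C's outermost axis,
+    # and bind both B and C to `thread_tag`.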
+    def schedule(thread_tag, mem_scope):
+        s = tvm.te.create_schedule(C.op)
+        s[B].compute_at(s[C], s[C].op.axis[0])
+        s[B].set_scope(mem_scope)
+        bno, bni = s[B].split(s[B].op.axis[1], n)
+        bx = tvm.te.thread_axis("blockIdx.x")
+        s[C].bind(s[C].op.axis[0], bx)
+        s[C].bind(s[C].op.axis[1], thread_tag)
+        s[B].bind(bni, thread_tag)
+        return s
+
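+    # Apply f to every node of the lowered TIR and collect the results.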
+    def collect_visit(stmt, f):
+        ret = []
+        tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
+        return ret
+    # local vs. threadIdx
+    s = schedule(tx, "local")
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    # local vs. vthread
+    s = schedule(vx, "local")
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    # shared vs. blockIdx
+    s = schedule(by, "shared")
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
+def test_local_stage_predicate2():
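+    # A is cached into a per-thread "local" buffer (AA) for B, and B itself
+    # is staged in "shared" memory; the lowered body should again contain no
+    # boundary predicate (IfThenElse).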
+    A = tvm.te.placeholder((128, ), name="A")
+    B = tvm.te.compute((128, ), lambda bi: A[bi] + 1, name="B")
+    C = tvm.te.compute((128, ), lambda ci: B[ci] + 2, name="C")
+    s = tvm.te.create_schedule(C.op)
+    AA = s.cache_read(A, "local", [B])
+    s[B].set_scope("shared")
+    block_x = tvm.te.thread_axis("blockIdx.x")
+    thread_x = tvm.te.thread_axis((0, 32), "threadIdx.x")
+    oc, ic = s[C].split(s[C].op.axis[0], factor=64)
+    ooc, ioc = s[C].split(oc, factor=2)
+    oic, iic = s[C].split(ic, factor=32)
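+    # Loop nest over C: ooc(1) x ioc(2) x oic(2) x iic(32) = 128 elements,
+    # with ooc bound to blockIdx.x and iic bound to threadIdx.x.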
+    s[C].bind(ooc, block_x)
+    s[C].bind(iic, thread_x)
+    s[B].compute_at(s[C], ioc)
+    ob, ib = s[B].split(s[B].op.axis[0], factor=32)
+    s[B].bind(ib, thread_x)
+    s[AA].compute_root()
+    s[AA].compute_at(s[C], ooc)
+    oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
+    s[AA].bind(iaa, thread_x)
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+
+    def collect_visit(stmt, f):
+        ret = []
+        tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
+        return ret
+
+    def visit_stmt(op):
+        if isinstance(op, tvm.tir.Allocate):
+            return op.extents[0].value == 97
+        return False
+
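+    # No boundary predicate is expected, and the local cache AA should be
+    # allocated with extent 97: each thread touches A[tx], A[tx + 32],
+    # A[tx + 64] and A[tx + 96], a span of 97 elements (assuming no
+    # per-thread offset compaction of the local buffer).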
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert (any(collect_visit(lowered_body, visit_stmt)))
+
+
if __name__ == "__main__":
    test_loop_dep_reduce()
    test_loop_dep_reduce_cache_write()
    test_schedule_tensor_compute3()
    test_reduction_and_dummy_fuse_split()
    test_schedule_compute_inline()
+    test_local_stage_predicate()
+    test_local_stage_predicate2()
\ No newline at end of file