Create loops according to storage scope and thread hierarchies (#5190)
author yongfeng-nv <49211903+yongfeng-nv@users.noreply.github.com>
Fri, 10 Apr 2020 01:49:37 +0000 (21:49 -0400)
committer GitHub <noreply@github.com>
Fri, 10 Apr 2020 01:49:37 +0000 (18:49 -0700)
* Set IterVar index to 0 for local thread-bound IterVars.

* Lint fix

* Use rank instead of scope name to predicate; add tests (see the sketches after this list).

* Handle cases other than local/threadIdx.

* Revert warp to the old behavior.

* Modify test to cover global/blockIdx.

* Fix a typo.

* Update test_te_schedule_ops.py with more test coverage in test_local_stage_predicate; remove test_schedule_schedule_ops.py, which was added by mistake.

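In short: when a stage's storage scope is private to the thread level an
IterVar is bound to (e.g. a "local" stage bound to threadIdx.x), every thread
owns its own copy, so the bound IterVar can be pinned to its domain min
instead of appearing in the loop nest, and the if_then_else predicate it used
to generate disappears. A minimal sketch of the user-visible effect, trimmed
from the first test below (assumes the tvm.te API these tests use):

    import tvm

    A = tvm.te.placeholder((1, 3, 2), name="A")
    B = tvm.te.compute(A.shape, lambda i, j, k: A[i, j, k], name="B")
    C = tvm.te.compute(A.shape, lambda i, j, k: B[i, j, k], name="C")

    s = tvm.te.create_schedule(C.op)
    s[B].compute_at(s[C], s[C].op.axis[0])
    s[B].set_scope("local")  # B is private to each thread
    _, bni = s[B].split(s[B].op.axis[1], 3)
    tx = tvm.te.thread_axis("threadIdx.x")
    s[C].bind(s[C].op.axis[0], tvm.te.thread_axis("blockIdx.x"))
    s[C].bind(s[C].op.axis[1], tx)
    s[B].bind(bni, tx)
    # With this change the lowered body contains no if_then_else guard
    # around the local stage.
    print(tvm.lower(s, [A, C], simple_mode=True))
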
src/te/operation/op_util.cc
tests/python/unittest/test_te_schedule_ops.py

diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc
index 3714f43..4ecfe94 100644
@@ -29,6 +29,7 @@
 #include "op_util.h"
 #include "../schedule/message_passing.h"
 #include "../../arith/compute_expr.h"
+#include "../../runtime/thread_storage_scope.h"
 
 namespace tvm {
 namespace te {
@@ -162,7 +163,13 @@ MakeLoopNest(const Stage& stage,
       if (!debug_keep_trivial_loop && is_one(dom->extent)) {
         value_map[iv] = dom->min;
       } else {
-        value_map[iv] = var;
+        runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
+        if (stage->scope == "" || stage->scope == "warp" ||
+            static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
+          value_map[iv] = var;
+        } else {
+          value_map[iv] = dom->min;
+        }
       }
     }
     // annotate the extent of the IterVar
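
The new condition compares the numeric rank of the stage's storage scope with
the rank of the bound thread, instead of string-matching scope names. Below is
a standalone Python sketch of the same decision, assuming TVM's usual rank
assignments (runtime::StorageRank: global=0, shared=1, warp=2, local=3;
runtime::ThreadScope: blockIdx=0, vthread=1, threadIdx=1); needs_thread_var is
a hypothetical name, not a TVM API:

    STORAGE_RANK = {"global": 0, "shared": 1, "warp": 2, "local": 3}

    def thread_rank(thread_tag):
        # blockIdx.* is rank 0; vthread and threadIdx.* are rank 1.
        return 0 if thread_tag.startswith("blockIdx.") else 1

    def needs_thread_var(stage_scope, thread_tag):
        # "" (unset scope) and "warp" keep the old behavior, per the
        # commit notes; otherwise the thread index stays in the loop
        # nest only when the storage is not private to that thread level.
        if stage_scope in ("", "warp"):
            return True
        return STORAGE_RANK[stage_scope] <= thread_rank(thread_tag)

    assert not needs_thread_var("local", "threadIdx.x")   # pinned to min
    assert needs_thread_var("shared", "threadIdx.x")      # still needed
    assert not needs_thread_var("shared", "blockIdx.y")   # pinned to min
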
diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py
index 8d10cee..4e27ad3 100644
@@ -482,6 +482,92 @@ def test_schedule_compute_inline():
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
+
+def test_local_stage_predicate():
+    m = 1
+    n = 3
+    p = 2
+    A = tvm.te.placeholder((m, n, p), name='A')
+    B = tvm.te.compute((m, n, p), lambda bi, bj, bk: A[bi, bj, bk], name="B")
+    C = tvm.te.compute((m, n, p), lambda ci, cj, ck: B[ci, cj, ck], name="C")
+    by = tvm.te.thread_axis("blockIdx.y")
+    tx = tvm.te.thread_axis("threadIdx.x")
+    vx = tvm.te.thread_axis("vthread")
+
+    def schedule(thread_tag, mem_scope):
+        s = tvm.te.create_schedule(C.op)
+        s[B].compute_at(s[C], s[C].op.axis[0])
+        s[B].set_scope(mem_scope)
+        bno, bni = s[B].split(s[B].op.axis[1], n)
+        bx = tvm.te.thread_axis("blockIdx.x")
+        s[C].bind(s[C].op.axis[0], bx)
+        s[C].bind(s[C].op.axis[1], thread_tag)
+        s[B].bind(bni, thread_tag)
+        return s
+
+    def collect_visit(stmt, f):
+        ret = []
+        tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
+        return ret
+    # local vs. threadIdx
+    s = schedule(tx, "local")
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    # local vs. vthread
+    s = schedule(vx, "local")
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    # shared vs. blockIdx
+    s = schedule(by, "shared")
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
+
+def test_local_stage_predicate2():
+    A = tvm.te.placeholder((128, ), name="A")
+    B = tvm.te.compute((128, ), lambda bi: A[bi] + 1, name="B")
+    C = tvm.te.compute((128, ), lambda ci: B[ci] + 2, name="C")
+    s = tvm.te.create_schedule(C.op)
+    AA = s.cache_read(A, "local", [B])
+    s[B].set_scope("shared")
+    block_x = tvm.te.thread_axis("blockIdx.x")
+    thread_x = tvm.te.thread_axis((0, 32), "threadIdx.x")
+    oc, ic = s[C].split(s[C].op.axis[0], factor=64)
+    ooc, ioc = s[C].split(oc, factor=2)
+    oic, iic = s[C].split(ic, factor=32)
+    s[C].bind(ooc, block_x)
+    s[C].bind(iic, thread_x)
+    s[B].compute_at(s[C], ioc)
+    ob, ib = s[B].split(s[B].op.axis[0], factor=32)
+    s[B].bind(ib, thread_x)
+    s[AA].compute_root()
+    s[AA].compute_at(s[C], ooc)
+    oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
+    s[AA].bind(iaa, thread_x)
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+
+    def collect_visit(stmt, f):
+        ret = []
+        tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
+        return ret
+
+    def visit_stmt(op):
+        if isinstance(op, tvm.tir.Allocate):
+            return op.extents[0].value == 97
+        return False
+
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert (any(collect_visit(lowered_body, visit_stmt)))
+
+
 if __name__ == "__main__":
     test_loop_dep_reduce()
     test_loop_dep_reduce_cache_write()
@@ -506,3 +592,5 @@ if __name__ == "__main__":
     test_schedule_tensor_compute3()
     test_reduction_and_dummy_fuse_split()
     test_schedule_compute_inline()
+    test_local_stage_predicate()
+    test_local_stage_predicate2()
\ No newline at end of file
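
A note on the "== 97" assertion in test_local_stage_predicate2: AA is a
"local" cache of the 128-element A, split by factor 32 with the inner index
bound to threadIdx.x. Since local storage is inner to the threadIdx level, the
thread index is pinned to its min, each thread touches only offsets oaa * 32,
and bound inference allocates last-offset + 1 elements instead of 128. A
hedged sketch of that arithmetic:

    # Extent arithmetic behind `op.extents[0].value == 97` (a reading of
    # the expected lowering, not TVM code): after the thread index is
    # pinned to 0, only offsets oaa * 32 with oaa in 0..3 are touched.
    factor = 32
    outer_trip = 128 // factor               # oaa iterates 0..3
    extent = (outer_trip - 1) * factor + 1   # last touched offset + 1
    assert extent == 97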