[Reduction] Fix cross thread redunction (#5551)
authorWei Pan <60017475+wpan11nv@users.noreply.github.com>
Fri, 15 May 2020 21:28:19 +0000 (14:28 -0700)
committerGitHub <noreply@github.com>
Fri, 15 May 2020 21:28:19 +0000 (14:28 -0700)
- The predictions were not correctly applied after transformation.
  This leads to normal reduction itervar appearing outside of the loop,
  which is undefined. See detailed comments.

Signed-off-by: Wei Pan <weip@nvidia.com>
src/te/operation/cross_thread_reduction.cc
tests/python/unittest/test_target_codegen_cuda.py

index 0905631..cdcb124 100644 (file)
@@ -28,6 +28,56 @@ namespace tvm {
 namespace te {
 using namespace tir;
 
+//
+// Cross thread reduction transformation.
+//
+// The input loop nest in generic form (single reduction/thread case)
+//
+// let m be the reduction extent
+// let N be the thread extent
+// let input_pred be the predicate on the reduction
+//
+// B[..] = 0
+// for (tid, 0, N)
+//   for (i, 0, floordiv(m+N-1, N))
+//     if (i + tid * floordiv(m+N-1, N) < m)
+//       if (input_pred)
+//         B[..] = op(B[..], A[i + tid  * floordiv(m+N-1,N)])
+//
+// The threaded reduction looks like
+//
+// (1) normal reductions (leaves)
+// for (i, 0, floordiv(m+N-1, N))
+//   if (i + tid * floordiv(m+N-1, N) < m)
+//     if (input_pred)
+//       B_temp[0] = op(B_temp[0], A[i + tid  * floordiv(m+N-1,N)])
+//
+// (2) threaded reduction does not require predicates as an identity
+//     element will be filled if out of bounds.
+//
+// tvm_thread_allreduce(size, B_temp, (bool)1, tid)
+//
+// The last step is to write the final reduction variable,
+// which should be predicated by the existing input_pred if any
+// The consequence is that input_pred should be independent of
+// the reduction axis. Otherwise, we need to seperate it into
+// dependent part and independent one.
+//
+// (3) write back
+// if (input_pred)
+//    B[..] = B_temp[0]
+//
+// In summary, we are going to need two predicates
+//
+// * the original input_pred from reduction itself
+//
+// * the normal reduction axis predicate
+//     normal_pred = (i + tid * floordiv(m+N-1,N)) < m
+//   this predicate depends on the normal reduction variable.
+//
+// input_pred will be applied to both normal reduction and
+// the writeback step.
+//
 Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
                               const std::unordered_map<IterVar, Range>& dom_map,
                               bool debug_keep_trivial_loop) {
@@ -38,7 +88,6 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   std::unordered_map<IterVar, PrimExpr> value_map;
   auto nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map,
                            debug_keep_trivial_loop);
-  auto conds = MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set<IterVar>());
 
   size_t size = self->body.size();
   CHECK_GT(size, 0);
@@ -48,10 +97,17 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
     CHECK(reduce);
     reduces[i] = reduce;
   }
-  PrimExpr cond = reduces[0]->condition;
-  for (PrimExpr v : conds) {
-    cond = cond && v;
-  }
+
+  // This computes the bound checking predicates in normal reduction.
+  auto normal_preds =
+      MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set<IterVar>());
+
+  // normal_pred = input_pred && normal_pred
+  PrimExpr input_pred = reduces[0]->condition;
+  normal_preds.push_back(input_pred);
+  normal_preds.erase(std::remove_if(normal_preds.begin(), normal_preds.end(),
+                                    [](const PrimExpr& e) { return !e.defined(); }),
+                     normal_preds.end());
 
   std::vector<std::vector<Stmt>> common, normal_red;
   for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) {
@@ -109,7 +165,10 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
       freduce_args.push_back(reduces[0]->source[i]);
     }
   }
-  freduce_args.push_back(cond);
+
+  // No constraints on the thread reduction step. It may have redundent
+  // computation for rare cases. TODO(tvm-team): revisit this.
+  freduce_args.push_back(const_true(1));
   std::vector<Var> res_handles(size);
   for (size_t idx = 0; idx < size; ++idx) {
     res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle());
@@ -125,12 +184,16 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
       }
     }
   }
+
   // Checks for the thread.
-  std::vector<PrimExpr> thread_head_check;
+  std::vector<PrimExpr> output_preds;
   if (stage->store_predicate.defined()) {
-    thread_head_check.emplace_back(stage->store_predicate);
+    output_preds.emplace_back(stage->store_predicate);
   }
 
+  // Apply the existing input predicate if any.
+  output_preds.push_back(input_pred);
+
   Stmt reduce_body = EvaluateNode::make(CallNode::make(
       DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic));
   reduce_body = AttrStmtNode::make(reduces[0]->combiner, tir::attr::reduce_scope,
@@ -139,9 +202,9 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   if (!normal_red.empty()) {
     Stmt init_body = SeqStmt::Flatten(normal_init);
     Stmt update_body = SeqStmt::Flatten(normal_update);
+    update_body = MergeNest(MakeIfNest(normal_preds), update_body);
     update_body = MergeNest(normal_red, update_body);
     reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body);
-    reduce_body = MergeNest(MakeIfNest(conds), reduce_body);
   }
 
   std::vector<Stmt> assigns(size);
@@ -151,8 +214,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
         stage->op, idx, LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
   }
   Stmt assign_body = SeqStmt::Flatten(assigns);
-  assign_body = MergeNest(MakeIfNest(thread_head_check), assign_body);
-  assign_body = MergeNest(MakeIfNest(conds), assign_body);
+  assign_body = MergeNest(MakeIfNest(output_preds), assign_body);
   Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
   for (size_t idx = size; idx != 0; --idx) {
     body =
index 50705e8..9692058 100644 (file)
@@ -227,6 +227,85 @@ def test_cuda_shuffle():
         module(nda, ndb, ndc)
         tvm.testing.assert_allclose(ndc.asnumpy(), ref)
 
+def test_crossthread_reduction1():
+    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    n = te.var("n")
+    m = te.var("m")
+    A = te.placeholder((n, m), name='A')
+    k = te.reduce_axis((0, m), "m")
+    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+
+    def sched(nthd):
+        s = te.create_schedule(B.op)
+        ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd)
+        s[B].bind(ko, te.thread_axis("threadIdx.x"))
+        s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
+        func = tvm.build(s, [A, B], "cuda")
+        return func
+
+    def verify(nthd):
+        func = sched(nthd)
+        nn = 3
+        # checks three typical cases
+        vals = [nthd-1, nthd, nthd+1]
+        for kk in [x for x in vals]:
+            size = (nn, kk)
+            ctx = tvm.context("cuda", 0)
+            a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
+            b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
+            func(a, b)
+            tvm.testing.assert_allclose(b.asnumpy(), \
+                np.sum(a.asnumpy(), axis=1), rtol=1e-3)
+
+    verify(16)
+    verify(32)
+    verify(64)
+
+def test_crossthread_reduction2():
+    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    n = te.var("n")
+    k0 = te.var("k0")
+    k1 = te.var("k1")
+    A = te.placeholder((n, k0, k1), name='A')
+    k0 = te.reduce_axis((0, k0), "k0")
+    k1 = te.reduce_axis((0, k1), "k1")
+    B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B")
+
+    def sched(nthdx, nthdy):
+        s = te.create_schedule(B.op)
+        k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx)
+        k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy)
+        s[B].bind(k0o, te.thread_axis("threadIdx.x"))
+        s[B].bind(k1o, te.thread_axis("threadIdx.y"))
+        s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
+        func = tvm.build(s, [A, B], "cuda")
+        return func
+
+    def verify(nthdx, nthdy):
+        func = sched(nthdx, nthdy)
+        nn = 3
+        # checks three typical cases
+        vx = [nthdx-1, nthdx, nthdx+1]
+        vy = [nthdy-1, nthdy, nthdy+1]
+        for kk0, kk1 in [(x, y) for x in vx for y in vy]:
+            size = (nn, kk0, kk1)
+            ctx = tvm.context("cuda", 0)
+            a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
+            b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
+            func(a, b)
+            tvm.testing.assert_allclose(b.asnumpy(), \
+                np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3)
+
+    verify(16, 16)
+    verify(32, 32)
+    verify(16, 32)
+    verify(32, 16)
 
 def test_cuda_reducition_binding():
     if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
@@ -609,6 +688,8 @@ if __name__ == "__main__":
     test_cuda_shuffle()
     test_vectorized_casts()
     test_cuda_reducition_binding()
+    test_crossthread_reduction1()
+    test_crossthread_reduction2()
     test_rfactor_predicates()
     test_cuda_const_float_to_half()
     test_cuda_reduction()