[TE] Support mixing normal and cross-thread reduction (#5193)
authorTang, Shizhi <beantang.tang@gmail.com>
Sat, 4 Apr 2020 01:49:56 +0000 (09:49 +0800)
committerGitHub <noreply@github.com>
Sat, 4 Apr 2020 01:49:56 +0000 (18:49 -0700)
* Support mixing normal and cross-thread reduction

* minor improvements

src/te/operation/compute_op.cc
src/te/operation/cross_thread_reduction.cc
tests/python/unittest/test_target_codegen_cuda.py

index 6123c61..6f703c9 100644 (file)
@@ -443,8 +443,6 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
         << "Cannot mix cross thread reduction with Tensorize";
     return ComputeType::kTensorize;
   }
-  CHECK(normal_red == 0 || thread_red == 0)
-      << "Cannot mix normal reduction with thread reduce";
   if (thread_red != 0) {
     return ComputeType::kCrossThreadReduction;
   } else {
index 705d231..1b3d87d 100644 (file)
@@ -57,10 +57,63 @@ Stmt MakeCrossThreadReduction(
   for (PrimExpr v : conds) {
     cond = cond && v;
   }
+
+  std::vector<std::vector<Stmt>> common, normal_red;
+  for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) {
+    IterVar iv = stage->leaf_iter_vars[i];
+    IterVarAttr attr;
+    auto it = stage->iter_var_attrs.find(iv);
+    if (it != stage->iter_var_attrs.end()) {
+      attr = (*it).second;
+    }
+    if (iv->iter_type == kCommReduce) {
+      if (attr.defined() && attr->bind_thread.defined()) {
+        common.emplace_back(nest[i + 1]);
+      } else {
+        normal_red.emplace_back(nest[i + 1]);
+      }
+    } else {
+      common.emplace_back(nest[i + 1]);
+    }
+  }
+
+  // If we load from and then store into the same res_handles in the thread_allreduce intrinsic,
+  // something goes wrong, so we use an extra variable here for normal reduction.
+  std::vector<Var> normal_res_handles;
+  std::vector<Stmt> normal_init, normal_update;
+  if (!normal_red.empty()) {
+    normal_res_handles.reserve(size);
+    normal_init.reserve(size);
+    normal_update.resize(size);
+    const CommReducerNode* combiner = reduces[0]->combiner.as<CommReducerNode>();
+    CHECK(combiner);
+    Array<PrimExpr> lhs;
+    for (size_t i = 0; i < size; ++i) {
+      DataType t = reduces[i]->dtype;
+      normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), DataType::Handle());
+      lhs.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes())));
+    }
+    Array<PrimExpr> init_value = combiner->identity_element;
+    Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source);
+    for (size_t i = 0; i < size; ++i) {
+      DataType t = reduces[i]->dtype;
+      normal_init.emplace_back(StoreNode::make(
+            normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
+      normal_update.emplace_back(StoreNode::make(
+            normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
+    }
+  }
+
   Array<PrimExpr> freduce_args;
   freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
   for (size_t i = 0; i < size; ++i) {
-    freduce_args.push_back(reduces[0]->source[i]);
+    if (!normal_red.empty()) {
+      DataType t = reduces[i]->dtype;
+      freduce_args.push_back(LoadNode::make(
+            t, normal_res_handles[i], 0, const_true(t.lanes())));
+    } else {
+      freduce_args.push_back(reduces[0]->source[i]);
+    }
   }
   freduce_args.push_back(cond);
   std::vector<Var> res_handles(size);
@@ -94,6 +147,15 @@ Stmt MakeCrossThreadReduction(
       tir::attr::reduce_scope,
       make_zero(DataType::Handle()),
       reduce_body);
+
+  if (!normal_red.empty()) {
+    Stmt init_body = SeqStmt::Flatten(normal_init);
+    Stmt update_body = SeqStmt::Flatten(normal_update);
+    update_body = MergeNest(normal_red, update_body);
+    reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body);
+    reduce_body = MergeNest(MakeIfNest(conds), reduce_body);
+  }
+
   std::vector<Stmt> assigns(size);
   for (size_t idx = 0; idx < size; ++idx) {
     DataType t = reduces[idx]->dtype;
@@ -110,9 +172,15 @@ Stmt MakeCrossThreadReduction(
       res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
     body = AttrStmtNode::make(
       res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
+    if (!normal_red.empty()) {
+      body = AllocateNode::make(
+        normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+      body = AttrStmtNode::make(
+        normal_res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
+    }
   }
   body = Substitute(body, value_map);
-  return MergeNest(nest, body);
+  return MergeNest(common, body);
 }
 }  // namespace te
 }  // namespace tvm
index 75d6c14..bb162f4 100644 (file)
@@ -321,6 +321,33 @@ def test_cuda_reduction():
     check_cuda("float32")
     check_cuda("float16")
 
+def test_cuda_mix_threaded_and_normal_reduction():
+    def check_cuda(dtype, m=32, n=32):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
+        b = topi.sum(a)
+        with tvm.target.cuda():
+            sb = tvm.te.create_schedule(b.op)
+            i, _ = b.op.reduce_axis
+            sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
+            ctx = tvm.gpu(0)
+            func = tvm.build(sb, [a, b], 'cuda')
+            a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
+            b_np = np.sum(a_np)
+            a_nd = tvm.nd.array(a_np, ctx)
+            b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
+            func(a_nd, b_nd)
+            tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
+
 def test_cuda_floordiv_with_vectorization():
     if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
         print("skip because cuda is not enabled..")
@@ -528,7 +555,8 @@ if __name__ == "__main__":
     test_rfactor_predicates()
     test_cuda_const_float_to_half()
     test_cuda_reduction()
+    test_cuda_mix_threaded_and_normal_reduction()
     test_cuda_floordiv_with_vectorization()
     test_vectorized_intrin1()
     test_vectorized_intrin2()
-    test_vectorized_popcount()
\ No newline at end of file
+    test_vectorized_popcount()