for (PrimExpr v : conds) {
cond = cond && v;
}
+
+ // Partition the surrounding loop nest: reduction axes bound to a thread
+ // are consumed by the cross-thread allreduce intrinsic, reduction axes
+ // left unbound ("normal") are reduced serially first, and every other
+ // loop is common to both.
+ std::vector<std::vector<Stmt>> common, normal_red;
+ for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) {
+ IterVar iv = stage->leaf_iter_vars[i];
+ IterVarAttr attr;
+ auto it = stage->iter_var_attrs.find(iv);
+ if (it != stage->iter_var_attrs.end()) {
+ attr = (*it).second;
+ }
+ if (iv->iter_type == kCommReduce) {
+ if (attr.defined() && attr->bind_thread.defined()) {
+ // Thread-bound reduction axis: its loop stays with the common nest.
+ common.emplace_back(nest[i + 1]);
+ } else {
+ // Unbound reduction axis: keep its loop for the serial pre-reduction.
+ normal_red.emplace_back(nest[i + 1]);
+ }
+ } else {
+ common.emplace_back(nest[i + 1]);
+ }
+ }
+
+ // The thread_allreduce intrinsic misbehaves when it loads its input from
+ // the same res_handles buffer that it stores the result into, so the
+ // serial ("normal") reduction accumulates into a separate temporary
+ // buffer (normal_res_handles) instead.
+ std::vector<Var> normal_res_handles;
+ std::vector<Stmt> normal_init, normal_update;
+ if (!normal_red.empty()) {
+ normal_res_handles.reserve(size);
+ normal_init.reserve(size);
+ // reserve, not resize: the statements below are appended with
+ // emplace_back, and resize would leave `size` null Stmts at the front
+ // (which SeqStmt::Flatten would then emit).
+ normal_update.reserve(size);
+ const CommReducerNode* combiner = reduces[0]->combiner.as<CommReducerNode>();
+ CHECK(combiner);
+ // lhs[i] re-loads the running value of output i from its temporary.
+ Array<PrimExpr> lhs;
+ for (size_t i = 0; i < size; ++i) {
+ DataType t = reduces[i]->dtype;
+ normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), DataType::Handle());
+ lhs.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes())));
+ }
+ // Seed with the combiner's identity, then fold the source in each step.
+ Array<PrimExpr> init_value = combiner->identity_element;
+ Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source);
+ for (size_t i = 0; i < size; ++i) {
+ DataType t = reduces[i]->dtype;
+ normal_init.emplace_back(StoreNode::make(
+ normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
+ normal_update.emplace_back(StoreNode::make(
+ normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
+ }
+ }
+
Array<PrimExpr> freduce_args;
freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
for (size_t i = 0; i < size; ++i) {
- freduce_args.push_back(reduces[0]->source[i]);
+ if (!normal_red.empty()) {
+ // Feed the serially pre-reduced value (accumulated in
+ // normal_reduce_temp{i}) into the cross-thread reduce intrinsic.
+ DataType t = reduces[i]->dtype;
+ freduce_args.push_back(LoadNode::make(
+ t, normal_res_handles[i], 0, const_true(t.lanes())));
+ } else {
+ // No unbound reduce axes: pass the reduction source in directly.
+ freduce_args.push_back(reduces[0]->source[i]);
+ }
}
freduce_args.push_back(cond);
std::vector<Var> res_handles(size);
tir::attr::reduce_scope,
make_zero(DataType::Handle()),
reduce_body);
+
+ if (!normal_red.empty()) {
+ // Initialize the local temporaries, run the serial reduction over the
+ // unbound reduce-axis loops, then continue with the cross-thread
+ // reduce body built above.
+ Stmt init_body = SeqStmt::Flatten(normal_init);
+ Stmt update_body = SeqStmt::Flatten(normal_update);
+ update_body = MergeNest(normal_red, update_body);
+ reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body);
+ // Guard the whole sequence with the original reduction predicates.
+ reduce_body = MergeNest(MakeIfNest(conds), reduce_body);
+ }
+
std::vector<Stmt> assigns(size);
for (size_t idx = 0; idx < size; ++idx) {
DataType t = reduces[idx]->dtype;
res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
body = AttrStmtNode::make(
res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
+ if (!normal_red.empty()) {
+ body = AllocateNode::make(
+ normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+ body = AttrStmtNode::make(
+ normal_res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
+ }
}
body = Substitute(body, value_map);
- return MergeNest(nest, body);
+ return MergeNest(common, body);
}
} // namespace te
} // namespace tvm
check_cuda("float32")
check_cuda("float16")
+def test_cuda_mix_threaded_and_normal_reduction():
+ # Regression test: sum a 2-D tensor while binding only one of the two
+ # reduction axes to threadIdx.x, so lowering must combine a cross-thread
+ # reduction with a serial ("normal") reduction over the unbound axis.
+ def check_cuda(dtype, m=32, n=32):
+ if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+ print("skip because cuda is not enabled..")
+ return
+ if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+ print("Skip because gpu does not have fp16 support")
+ return
+
+ a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
+ b = topi.sum(a)
+ with tvm.target.cuda():
+ sb = tvm.te.create_schedule(b.op)
+ i, _ = b.op.reduce_axis
+ # Bind only the first reduce axis; the second stays a serial loop.
+ sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
+ ctx = tvm.gpu(0)
+ func = tvm.build(sb, [a, b], 'cuda')
+ a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
+ b_np = np.sum(a_np)
+ a_nd = tvm.nd.array(a_np, ctx)
+ b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
+ func(a_nd, b_nd)
+ # Scalar result; loose rtol accommodates the float16 path.
+ tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
+
+ check_cuda("float32")
+ check_cuda("float16")
+
def test_cuda_floordiv_with_vectorization():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
test_rfactor_predicates()
test_cuda_const_float_to_half()
test_cuda_reduction()
+ test_cuda_mix_threaded_and_normal_reduction()
test_cuda_floordiv_with_vectorization()
test_vectorized_intrin1()
test_vectorized_intrin2()
- test_vectorized_popcount()
\ No newline at end of file
+ test_vectorized_popcount()