namespace te {
using namespace tir;
+//
+// Cross thread reduction transformation.
+//
+// The input loop nest in generic form (single reduction/thread case)
+//
+// let m be the reduction extent
+// let N be the thread extent
+// let input_pred be the predicate on the reduction
+//
+// B[..] = 0
+// for (tid, 0, N)
+// for (i, 0, floordiv(m+N-1, N))
+// if (i + tid * floordiv(m+N-1, N) < m)
+// if (input_pred)
+// B[..] = op(B[..], A[i + tid * floordiv(m+N-1,N)])
+//
+// The threaded reduction looks like
+//
+// (1) normal reductions (leaves)
+// for (i, 0, floordiv(m+N-1, N))
+// if (i + tid * floordiv(m+N-1, N) < m)
+// if (input_pred)
+// B_temp[0] = op(B_temp[0], A[i + tid * floordiv(m+N-1,N)])
+//
+// (2) threaded reduction does not require predicates as an identity
+// element will be filled if out of bounds.
+//
+// tvm_thread_allreduce(size, B_temp, (bool)1, tid)
+//
+// The last step is to write the final reduction variable,
+// which should be predicated by the existing input_pred, if any.
+// The consequence is that input_pred should be independent of
+// the reduction axis. Otherwise, we need to separate it into
+// a dependent part and an independent one.
+//
+// (3) write back
+// if (input_pred)
+// B[..] = B_temp[0]
+//
+// In summary, we are going to need two predicates
+//
+// * the original input_pred from reduction itself
+//
+// * the normal reduction axis predicate
+// normal_pred = (i + tid * floordiv(m+N-1,N)) < m
+// this predicate depends on the normal reduction variable.
+//
+// input_pred will be applied to both normal reduction and
+// the writeback step.
+//
Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
std::unordered_map<IterVar, PrimExpr> value_map;
auto nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map,
debug_keep_trivial_loop);
- auto conds = MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set<IterVar>());
size_t size = self->body.size();
CHECK_GT(size, 0);
CHECK(reduce);
reduces[i] = reduce;
}
- PrimExpr cond = reduces[0]->condition;
- for (PrimExpr v : conds) {
- cond = cond && v;
- }
+
+ // This computes the bound checking predicates in normal reduction.
+ auto normal_preds =
+ MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set<IterVar>());
+
+ // normal_pred = input_pred && normal_pred
+ PrimExpr input_pred = reduces[0]->condition;
+ normal_preds.push_back(input_pred);
+ normal_preds.erase(std::remove_if(normal_preds.begin(), normal_preds.end(),
+ [](const PrimExpr& e) { return !e.defined(); }),
+ normal_preds.end());
std::vector<std::vector<Stmt>> common, normal_red;
for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) {
freduce_args.push_back(reduces[0]->source[i]);
}
}
- freduce_args.push_back(cond);
+
+ // No constraints on the thread reduction step. It may perform redundant
+ // computation in rare cases. TODO(tvm-team): revisit this.
+ freduce_args.push_back(const_true(1));
std::vector<Var> res_handles(size);
for (size_t idx = 0; idx < size; ++idx) {
res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle());
}
}
}
+
// Checks for the thread.
- std::vector<PrimExpr> thread_head_check;
+ std::vector<PrimExpr> output_preds;
if (stage->store_predicate.defined()) {
- thread_head_check.emplace_back(stage->store_predicate);
+ output_preds.emplace_back(stage->store_predicate);
}
+ // Apply the existing input predicate if any.
+ output_preds.push_back(input_pred);
+
Stmt reduce_body = EvaluateNode::make(CallNode::make(
DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic));
reduce_body = AttrStmtNode::make(reduces[0]->combiner, tir::attr::reduce_scope,
if (!normal_red.empty()) {
Stmt init_body = SeqStmt::Flatten(normal_init);
Stmt update_body = SeqStmt::Flatten(normal_update);
+ update_body = MergeNest(MakeIfNest(normal_preds), update_body);
update_body = MergeNest(normal_red, update_body);
reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body);
- reduce_body = MergeNest(MakeIfNest(conds), reduce_body);
}
std::vector<Stmt> assigns(size);
stage->op, idx, LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
}
Stmt assign_body = SeqStmt::Flatten(assigns);
- assign_body = MergeNest(MakeIfNest(thread_head_check), assign_body);
- assign_body = MergeNest(MakeIfNest(conds), assign_body);
+ assign_body = MergeNest(MakeIfNest(output_preds), assign_body);
Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
body =
module(nda, ndb, ndc)
tvm.testing.assert_allclose(ndc.asnumpy(), ref)
+def test_crossthread_reduction1():
+ if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+ print("skip because cuda is not enabled..")
+ return
+
+ n = te.var("n")
+ m = te.var("m")
+ A = te.placeholder((n, m), name='A')
+ k = te.reduce_axis((0, m), "m")
+ B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+
+ def sched(nthd):
+ s = te.create_schedule(B.op)
+ ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd)
+ s[B].bind(ko, te.thread_axis("threadIdx.x"))
+ s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
+ func = tvm.build(s, [A, B], "cuda")
+ return func
+
+ def verify(nthd):
+ func = sched(nthd)
+ nn = 3
+ # checks three typical cases
+ vals = [nthd-1, nthd, nthd+1]
+ for kk in [x for x in vals]:
+ size = (nn, kk)
+ ctx = tvm.context("cuda", 0)
+ a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
+ b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
+ func(a, b)
+ tvm.testing.assert_allclose(b.asnumpy(), \
+ np.sum(a.asnumpy(), axis=1), rtol=1e-3)
+
+ verify(16)
+ verify(32)
+ verify(64)
+
+def test_crossthread_reduction2():
+ if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+ print("skip because cuda is not enabled..")
+ return
+
+ n = te.var("n")
+ k0 = te.var("k0")
+ k1 = te.var("k1")
+ A = te.placeholder((n, k0, k1), name='A')
+ k0 = te.reduce_axis((0, k0), "k0")
+ k1 = te.reduce_axis((0, k1), "k1")
+ B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B")
+
+ def sched(nthdx, nthdy):
+ s = te.create_schedule(B.op)
+ k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx)
+ k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy)
+ s[B].bind(k0o, te.thread_axis("threadIdx.x"))
+ s[B].bind(k1o, te.thread_axis("threadIdx.y"))
+ s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
+ func = tvm.build(s, [A, B], "cuda")
+ return func
+
+ def verify(nthdx, nthdy):
+ func = sched(nthdx, nthdy)
+ nn = 3
+ # checks three typical cases
+ vx = [nthdx-1, nthdx, nthdx+1]
+ vy = [nthdy-1, nthdy, nthdy+1]
+ for kk0, kk1 in [(x, y) for x in vx for y in vy]:
+ size = (nn, kk0, kk1)
+ ctx = tvm.context("cuda", 0)
+ a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
+ b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
+ func(a, b)
+ tvm.testing.assert_allclose(b.asnumpy(), \
+ np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3)
+
+ verify(16, 16)
+ verify(32, 32)
+ verify(16, 32)
+ verify(32, 16)
def test_cuda_reducition_binding():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
test_cuda_shuffle()
test_vectorized_casts()
test_cuda_reducition_binding()
+ test_crossthread_reduction1()
+ test_crossthread_reduction2()
test_rfactor_predicates()
test_cuda_const_float_to_half()
test_cuda_reduction()