From 2421a85474cc3ea291e87dab8133022e0d13c07f Mon Sep 17 00:00:00 2001 From: yongfeng-nv <49211903+yongfeng-nv@users.noreply.github.com> Date: Wed, 11 Mar 2020 23:59:23 -0400 Subject: [PATCH] Set split node's range to minimum of ext and split factor or split nparts, but only when PassDownDomain is called with allow_missing == false, i.e. by InferBound. Add a helper PassUpThreadBinding() to get a map telling whether an IterVar has at least one leaf IterVar deriving from it binding to a thread. Add two unit tests. (#5044) --- src/te/schedule/message_passing.cc | 76 +++++++++++++++++++++- .../unittest/test_schedule_bound_inference.py | 26 ++++++++ 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 5b6fa86..a7b2482 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -51,17 +51,66 @@ void Update(std::unordered_map* p_state, } } +/*! + * \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to + * a thread. + * + * \param stage The stage to operate on. + * \param p_state The propagation result of each IterVar. + */ +void PassUpThreadBinding(const Stage& stage, std::unordered_map* p_state) { + auto bound_to_thread = [&stage](const IterVar& iv) { + bool bound = false; + auto it = stage->iter_var_attrs.find(iv); + if (it != stage->iter_var_attrs.end()) { + bound = (*it).second->bind_thread.defined(); + } + return bound; + }; + + auto& state = *p_state; + // Fill p_state with leaf itervars + for (const IterVar& iv : stage->leaf_iter_vars) { + state[iv] = bound_to_thread(iv); + } + // Traverse the graph bottom-up to propagate thread binding information + for (size_t i = stage->relations.size(); i != 0; --i) { + IterVarRelation rel = stage->relations[i - 1]; + if (const SplitNode* s = rel.as()) { + state[s->parent] = state[s->inner] || state[s->outer]; + } else if (const FuseNode* s = rel.as()) { + state[s->inner] = state[s->fused]; + state[s->outer] = state[s->fused]; + } else if (const RebaseNode* s = rel.as()) { + state[s->parent] = state[s->rebased]; + } else if (rel.as()) { + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + void PassDownDomain(const Stage& stage, std::unordered_map* p_state, arith::Analyzer* actx, bool allow_missing) { - auto ceil_div = [actx](PrimExpr a, PrimExpr b) { + auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(indexmod(a, b) == 0)) { return actx->Simplify(indexdiv(a, b)); } return actx->Simplify(indexdiv(a + (b - 1), b)); }; + auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) { + if (actx->CanProve(a < b)) { + return actx->Simplify(a); + } + return actx->Simplify(b); + }; + + std::unordered_map dominating_thread; + PassUpThreadBinding(stage, &dominating_thread); + auto& state = *p_state; // forwar iteration on relations for (IterVarRelation rel : stage->relations) { @@ -72,14 +121,35 @@ void PassDownDomain(const Stage& stage, } CHECK(!state.count(r->inner)); const Range& range_parent = state.at(r->parent); + // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the + // following conditions are met: + // 1. No leaf IterVar derived from iv binds to any thread. People may use split + // to force an IterVar extent to match the number of allocated threads to fuse stages + // that require different number of threads. We don't want to change these extents. + // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound, + // rather than by an early compiler phase, such as rfactor(). We don't want to tighten an + // IterVar in an early phase allowing missing IterVars, because it may bind to a thread later. + // 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one + // zero-sized dimension. Split creates iv with a positive extent to avoid zero-extent + // IterVar. We don't touch it. + auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) { + return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent) + ? factor_or_nparts + : minimum_or_later(range_parent->extent, factor_or_nparts); + }; if (r->factor.defined()) { Update(p_state, r->inner, - Range::make_by_min_extent(0, r->factor), actx); + Range::make_by_min_extent( + 0, resolve_min_extent_for_split(r->inner, r->factor)), + actx); Update(p_state, r->outer, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->factor)), actx); } else { - Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx); + Update(p_state, r->outer, + Range::make_by_min_extent( + 0, resolve_min_extent_for_split(r->outer, r->nparts)), + actx); Update(p_state, r->inner, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->nparts)), actx); diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 484aa50..edae527 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -70,6 +70,32 @@ def test_bound3(): assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[1]].extent.value==16) +def test_bound_split_ext_less_than_factor(): + m = 8 + I = te.placeholder((m,), name='I') + EF = te.compute((m,), lambda i: I[i] * 2, name = "EF") + E = te.compute((m,), lambda i: EF[i] * 2, name = "E") + s = te.create_schedule([E.op]) + xo, xi = s[E].split(s[E].op.axis[0], factor = 32) + s[EF].compute_at(s[E], xo) + + bounds = tvm.te.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xi].extent.value == m + +def test_bound_split_ext_less_than_naprts(): + m = 8 + I = te.placeholder((m,), name='I') + EF = te.compute((m,), lambda i: I[i] * 2, name = "EF") + E = te.compute((m,), lambda i: EF[i] * 2, name = "E") + s = te.create_schedule([E.op]) + xo, xi = s[E].split(s[E].op.axis[0], nparts = 32) + s[EF].compute_at(s[E], xo) + + bounds = tvm.te.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xo].extent.value == m + def test_bound_split_divisible(): m = te.var('m') l = te.var('l') -- 2.7.4