Set split node's range to minimum of ext and split factor or split nparts, but only...
authoryongfeng-nv <49211903+yongfeng-nv@users.noreply.github.com>
Thu, 12 Mar 2020 03:59:23 +0000 (23:59 -0400)
committerGitHub <noreply@github.com>
Thu, 12 Mar 2020 03:59:23 +0000 (20:59 -0700)
src/te/schedule/message_passing.cc
tests/python/unittest/test_schedule_bound_inference.py

index 5b6fa86..a7b2482 100644 (file)
@@ -51,17 +51,66 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
   }
 }
 
+/*!
+ * \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to
+ * a thread.
+ *
+ * \param stage The stage to operate on.
+ * \param p_state The propagation result of each IterVar.
+ */
+void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>* p_state) {
+  auto bound_to_thread = [&stage](const IterVar& iv) {
+    bool bound = false;
+    auto it = stage->iter_var_attrs.find(iv);
+    if (it != stage->iter_var_attrs.end()) {
+      bound = (*it).second->bind_thread.defined();
+    }
+    return bound;
+  };
+
+  auto& state = *p_state;
+  // Fill p_state with leaf itervars
+  for (const IterVar& iv : stage->leaf_iter_vars) {
+    state[iv] = bound_to_thread(iv);
+  }
+  // Traverse the graph bottom-up to propagate thread binding information
+  for (size_t i = stage->relations.size(); i != 0; --i) {
+    IterVarRelation rel = stage->relations[i - 1];
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      state[s->parent] = state[s->inner] || state[s->outer];
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      state[s->inner] = state[s->fused];
+      state[s->outer] = state[s->fused];
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      state[s->parent] = state[s->rebased];
+    } else if (rel.as<SingletonNode>()) {
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
 void PassDownDomain(const Stage& stage,
                     std::unordered_map<IterVar, Range>* p_state,
                     arith::Analyzer* actx,
                     bool allow_missing) {
-  auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
+  auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
     if (actx->CanProve(indexmod(a, b) == 0)) {
       return actx->Simplify(indexdiv(a, b));
     }
     return actx->Simplify(indexdiv(a + (b - 1), b));
   };
 
+  auto minimum_or_later  = [actx](const PrimExpr& a, const PrimExpr& b) {
+    if (actx->CanProve(a < b)) {
+      return actx->Simplify(a);
+    }
+    return actx->Simplify(b);
+  };
+
+  std::unordered_map<IterVar, bool> dominating_thread;
+  PassUpThreadBinding(stage, &dominating_thread);
+
   auto& state = *p_state;
   // forwar iteration on relations
   for (IterVarRelation rel : stage->relations) {
@@ -72,14 +121,35 @@ void PassDownDomain(const Stage& stage,
       }
       CHECK(!state.count(r->inner));
       const Range& range_parent = state.at(r->parent);
+      // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the
+      // following conditions are met:
+      // 1. No leaf IterVar derived from iv binds to any thread.  People may use split
+      // to force an IterVar extent to match the number of allocated threads to fuse stages
+      // that require different number of threads.  We don't want to change these extents.
+      // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound,
+      // rather than by an early compiler phase, such as rfactor().  We don't want to tighten an
+      // IterVar in an early phase allowing missing IterVars, because it may bind to a thread later.
+      // 3. range_parent's extent is not 0.  At lest one Topi test has a case where a tensor has one
+      // zero-sized dimension.  Split creates iv with a positive extent to avoid zero-extent
+      // IterVar.  We don't touch it.
+      auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) {
+        return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent)
+                   ? factor_or_nparts
+                   : minimum_or_later(range_parent->extent, factor_or_nparts);
+      };
       if (r->factor.defined()) {
         Update(p_state, r->inner,
-               Range::make_by_min_extent(0, r->factor), actx);
+               Range::make_by_min_extent(
+                   0, resolve_min_extent_for_split(r->inner, r->factor)),
+               actx);
         Update(p_state, r->outer,
                Range::make_by_min_extent(
                    0, ceil_div(range_parent->extent, r->factor)), actx);
       } else {
-        Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
+        Update(p_state, r->outer,
+               Range::make_by_min_extent(
+                   0, resolve_min_extent_for_split(r->outer, r->nparts)),
+               actx);
         Update(p_state, r->inner,
                Range::make_by_min_extent(
                    0, ceil_div(range_parent->extent, r->nparts)), actx);
index 484aa50..edae527 100644 (file)
@@ -70,6 +70,32 @@ def test_bound3():
     assert(bounds[A1.op.axis[0]].extent.value==32)
     assert(bounds[A1.op.axis[1]].extent.value==16)
 
+def test_bound_split_ext_less_than_factor():
+    m = 8
+    I = te.placeholder((m,), name='I')
+    EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
+    E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
+    s = te.create_schedule([E.op])
+    xo, xi = s[E].split(s[E].op.axis[0], factor = 32)
+    s[EF].compute_at(s[E], xo)
+
+    bounds = tvm.te.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    assert bounds[xi].extent.value == m
+
+def test_bound_split_ext_less_than_naprts():
+    m = 8
+    I = te.placeholder((m,), name='I')
+    EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
+    E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
+    s = te.create_schedule([E.op])
+    xo, xi = s[E].split(s[E].op.axis[0], nparts = 32)
+    s[EF].compute_at(s[E], xo)
+
+    bounds = tvm.te.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    assert bounds[xo].extent.value == m
+
 def test_bound_split_divisible():
     m = te.var('m')
     l = te.var('l')