[FIX] Fix for a specific case when loop partitioning with indivisible (#4243)
author Kimish Patel <kimishpatel@fb.com>
Fri, 15 Nov 2019 22:37:37 +0000 (14:37 -0800)
committer ziheng <ziheng@apache.org>
Fri, 15 Nov 2019 22:37:37 +0000 (14:37 -0800)
factors and the resulting nested loop is broken.
This happens because we create zero-extent loops, which are only
fixed up afterwards. However, the unroll pass breaks on the zero-extent
loop.

src/pass/loop_partition.cc
tests/python/unittest/test_pass_loop_partition.py

index b2a1bea..1d669c8 100644 (file)
@@ -513,17 +513,19 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   bool pre_stmt_recurse = true;
   if (middle_interval_i->HasLowerBound()) {
     body_begin = ir::Simplify(middle_interval.min());
-    Expr cond = (body_begin - min >= 0);
-    if (!analyzer_.CanProve(cond)) {
-      LOG(WARNING) << "Cannot prove: " << cond
-                   << ", when generating the pre doubt loop";
-      body_begin = Max::make(body_begin, min);
-      // stop recursing on this interval if we can't prove it has non-negative length
-      pre_stmt_recurse = false;
-    }
-    if (!partition_thread_scope) {
-      Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
-      pre_stmt = MakeFor(node, body_begin - min, pre_body);
+    if (!analyzer_.CanProve(body_begin == min)) {
+      Expr cond = (body_begin - min >= 0);
+      if (!analyzer_.CanProve(cond)) {
+        LOG(WARNING) << "Cannot prove: " << cond
+                     << ", when generating the pre doubt loop";
+        body_begin = Max::make(body_begin, min);
+        // stop recursing on this interval if we can't prove it has non-negative length
+        pre_stmt_recurse = false;
+      }
+      if (!partition_thread_scope) {
+        Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
+        pre_stmt = MakeFor(node, body_begin - min, pre_body);
+      }
     }
   } else {
     body_begin = min;
@@ -536,19 +538,21 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   bool post_stmt_recurse = true;
   if (middle_interval_i->HasUpperBound()) {
     post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
-    // require the extent to be non-negative
-    Expr cond = (max - post_doubt_begin + 1 >= 0);
-    if (!analyzer_.CanProve(cond)) {
-      LOG(WARNING) << "Cannot prove: " << cond
-                   << ", when generating the post doubt loop";
-      post_doubt_begin = Min::make(post_doubt_begin, max+1);
-      // stop recursing on this interval if we can't prove it has non-negative length
-      post_stmt_recurse = false;
-    }
-    if (!partition_thread_scope) {
-      Stmt post_body =
-        Substitute(body, {{Var{var}, var + post_doubt_begin}});
-      post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+    if (!analyzer_.CanProve(middle_interval.max() == max)) {
+      // require the extent to be non-negative
+      Expr cond = (max - post_doubt_begin + 1 >= 0);
+      if (!analyzer_.CanProve(cond)) {
+        LOG(WARNING) << "Cannot prove: " << cond
+                     << ", when generating the post doubt loop";
+        post_doubt_begin = Min::make(post_doubt_begin, max+1);
+        // stop recursing on this interval if we can't prove it has non-negative length
+        post_stmt_recurse = false;
+      }
+      if (!partition_thread_scope) {
+        Stmt post_body =
+          Substitute(body, {{Var{var}, var + post_doubt_begin}});
+        post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+      }
     }
   } else {
     post_doubt_begin = max + 1;
index b6fcfa3..0217095 100644 (file)
@@ -365,6 +365,27 @@ def test_conv_tiling():
     stmt = tvm.ir_pass.Simplify(stmt)
     assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
 
+
+def test_multilevel_splitting_with_indivisble_factors():
+    import topi
+    A = tvm.placeholder((130,), dtype="float32")
+    B = topi.nn.relu(A)
+    s = tvm.create_schedule(B.op)
+    (y,) = s[B].op.axis
+    (yo, yi) = s[B].split(y, factor=8)
+    (yoo, yoi) = s[B].split(yo, factor=16)
+    s[B].reorder(yoo, yoi, yi)
+    s[B].unroll(yi)
+
+    # Lower with partition_const_loop enabled and count the Max nodes left in the lowered body.
+    with tvm.build_config(partition_const_loop=True):
+        lowered_body = tvm.lower(s, [A, B]).body
+        def visit_stmt(op):
+            return(isinstance(op, tvm.expr.Max))
+        num_max = collect_visit(lowered_body, visit_stmt)
+        assert num_max.count(True) == 10
+
+
 def test_double_splitting_with_indivisible_factors():
     m = 48
     dtype="float32"
@@ -443,4 +464,5 @@ if __name__ == "__main__":
     test_cce_loop_3()
     test_conv_tiling()
     test_double_splitting_with_indivisible_factors()
+    test_multilevel_splitting_with_indivisble_factors()
     test_simple_rfactor()