[RELAY/PASS] Fix the extent for the post_stmt in the loop partition (#3734)
authorUmang Yadav <umang.yadav1@huawei.com>
Wed, 2 Oct 2019 20:13:10 +0000 (16:13 -0400)
committerziheng <ziheng@apache.org>
Wed, 2 Oct 2019 20:13:10 +0000 (13:13 -0700)
src/pass/loop_partition.cc
tests/python/unittest/test_pass_bound_checkers.py
topi/tests/python/test_topi_math.py

index dc79147..b2a1bea 100644 (file)
@@ -492,7 +492,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
     std::tie(middle_interval, cond_set) =
         GetIntervalAndCondset(finder.partitions, for_interval, false);
     if (middle_interval.is_nothing())
-      // we couldn't find an interval in which the condintions are provably true or false
+      // we couldn't find an interval in which the conditions are provably true or false
       // Therefore, we can't partition the loop based on those conds
       return Stmt();
     cond_value = false;
@@ -513,46 +513,42 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   bool pre_stmt_recurse = true;
   if (middle_interval_i->HasLowerBound()) {
     body_begin = ir::Simplify(middle_interval.min());
-    if (!analyzer_.CanProve(body_begin == min)) {
-      Expr cond = (body_begin - min >= 0);
-      if (!analyzer_.CanProve(cond)) {
-        LOG(WARNING) << "Cannot prove: " << cond
-                     << ", when generating the pre doubt loop";
-        body_begin = Max::make(body_begin, min);
-        // stop recursing on this interval if we can't prove it has non-negative length
-        pre_stmt_recurse = false;
-      }
-      if (!partition_thread_scope) {
-        Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
-        pre_stmt = MakeFor(node, body_begin - min, pre_body);
-      }
+    Expr cond = (body_begin - min >= 0);
+    if (!analyzer_.CanProve(cond)) {
+      LOG(WARNING) << "Cannot prove: " << cond
+                   << ", when generating the pre doubt loop";
+      body_begin = Max::make(body_begin, min);
+      // stop recursing on this interval if we can't prove it has non-negative length
+      pre_stmt_recurse = false;
+    }
+    if (!partition_thread_scope) {
+      Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
+      pre_stmt = MakeFor(node, body_begin - min, pre_body);
     }
   } else {
     body_begin = min;
   }
 
   // Calculating post-subrange and generating code for it.
-  // post-subrange = [post_doubt_begin, max]
+  // post-subrange = [post_doubt_begin, max+1)
   Expr post_doubt_begin;
   Stmt post_stmt;
   bool post_stmt_recurse = true;
   if (middle_interval_i->HasUpperBound()) {
     post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
-    if (!analyzer_.CanProve(middle_interval.max() == max)) {
-      // require the extent to be non-negative
-      Expr cond = (max - post_doubt_begin + 1 >= 0);
-      if (!analyzer_.CanProve(cond)) {
-        LOG(WARNING) << "Cannot prove: " << cond
-                     << ", when generating the post doubt loop";
-        post_doubt_begin = Min::make(post_doubt_begin, max);
-        // stop recursing on this interval if we can't prove it has non-negative length
-        post_stmt_recurse = false;
-      }
-      if (!partition_thread_scope) {
-        Stmt post_body =
-                Substitute(body, {{Var{var}, var + post_doubt_begin}});
-        post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
-      }
+    // require the extent to be non-negative
+    Expr cond = (max - post_doubt_begin + 1 >= 0);
+    if (!analyzer_.CanProve(cond)) {
+      LOG(WARNING) << "Cannot prove: " << cond
+                   << ", when generating the post doubt loop";
+      post_doubt_begin = Min::make(post_doubt_begin, max+1);
+      // stop recursing on this interval if we can't prove it has non-negative length
+      post_stmt_recurse = false;
+    }
+    if (!partition_thread_scope) {
+      Stmt post_body =
+        Substitute(body, {{Var{var}, var + post_doubt_begin}});
+      post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
     }
   } else {
     post_doubt_begin = max + 1;
index c319d4d..ada8169 100644 (file)
@@ -37,6 +37,7 @@ def lower(sch, args):
     bounds = tvm.schedule.InferBound(sch)
     stmt = tvm.schedule.ScheduleOps(sch, bounds)
     stmt = tvm.ir_pass.LoopPartition(stmt, True)
+    stmt = tvm.ir_pass.RemoveNoOp(stmt)
     stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
index 660d22c..ebbf6f7 100644 (file)
@@ -69,8 +69,16 @@ def test_ewise():
             foo(a, b)
             tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
-        for device in get_all_backend():
-            check_device(device)
+        check_device('llvm')
+        check_device('cuda')
+        check_device('opencl')
+        check_device('metal')
+        check_device('rocm')
+        check_device('vulkan')
+        check_device('nvptx')
+        check_device('llvm -device=arm-cpu')
+        check_device('opencl -device=mali')
+        check_device('aocl_sw_emu')
 
     def test_isnan(
         low,
@@ -109,8 +117,16 @@ def test_ewise():
             foo(a, b)
             tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
-        for device in get_all_backend():
-            check_device(device)
+        check_device('llvm')
+        check_device('cuda')
+        check_device('opencl')
+        check_device('metal')
+        check_device('rocm')
+        check_device('vulkan')
+        check_device('nvptx')
+        check_device('llvm -device=arm-cpu')
+        check_device('opencl -device=mali')
+        check_device('aocl_sw_emu')
 
     test_apply(topi.floor, "floor", np.floor, -100, 100)
     test_apply(topi.ceil, "ceil", np.ceil, -100, 100)