[schedule] Improve ceil_divide in tile/split (#3842)
authorYizhi Liu <liuyizhi@apache.org>
Fri, 6 Sep 2019 13:29:31 +0000 (21:29 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 6 Sep 2019 13:29:31 +0000 (21:29 +0800)
src/schedule/message_passing.cc
tests/python/unittest/test_schedule_bound_inference.py

index b39a6c6..c5c79ea 100644 (file)
@@ -56,6 +56,9 @@ void PassDownDomain(const Stage& stage,
                     arith::Analyzer* actx,
                     bool allow_missing) {
   auto ceil_div = [actx](Expr a, Expr b) {
+    if (actx->CanProve(a % b == 0)) {
+      return actx->Simplify(a / b);
+    }
     return actx->Simplify((a + (b - 1)) / b);
   };
 
index 21be6b7..1ff9853 100644 (file)
@@ -69,6 +69,33 @@ def test_bound3():
     assert(bounds[A1.op.axis[0]].extent.value==32)
     assert(bounds[A1.op.axis[1]].extent.value==16)
 
+def test_bound_split_divisible():
+    m = tvm.var('m')
+    l = tvm.var('l')
+    A = tvm.placeholder((8 * m, l), name='A')
+    B = tvm.compute((8 * m, l), lambda i, j: A[i, j], name='B')
+    s = tvm.create_schedule(B.op)
+    xo, xi = s[B].split(B.op.axis[0], 8)
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    assert bounds[xo].extent == m
+    assert bounds[xi].extent.value == 8
+
+def test_bound_tile_divisible():
+    m = tvm.var('m')
+    l = tvm.var('l')
+    shape = (8 * m, 32 * l)
+    A = tvm.placeholder(shape, name='A')
+    B = tvm.compute(shape, lambda i, j: A[i, j], name='B')
+    s = tvm.create_schedule(B.op)
+    xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32)
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    assert bounds[xo].extent == m
+    assert bounds[xi].extent.value == 8
+    assert bounds[yo].extent == l
+    assert bounds[yi].extent.value == 32
+
 def test_bound_fusesplit1():
     m = tvm.var('m')
     l = tvm.var('l')
@@ -393,3 +420,5 @@ if __name__ == "__main__":
     test_bound_simplification_failure()
     test_bound_fusesplit1()
     test_bound_fusesplit2()
+    test_bound_split_divisible()
+    test_bound_tile_divisible()