assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)
+def test_bound_split_divisible():
+ m = tvm.var('m')
+ l = tvm.var('l')
+ A = tvm.placeholder((8 * m, l), name='A')
+ B = tvm.compute((8 * m, l), lambda i, j: A[i, j], name='B')
+ s = tvm.create_schedule(B.op)
+ xo, xi = s[B].split(B.op.axis[0], 8)
+ bounds = tvm.schedule.InferBound(s)
+ assert isinstance(bounds, tvm.container.Map)
+ assert bounds[xo].extent == m
+ assert bounds[xi].extent.value == 8
+
+def test_bound_tile_divisible():
+ m = tvm.var('m')
+ l = tvm.var('l')
+ shape = (8 * m, 32 * l)
+ A = tvm.placeholder(shape, name='A')
+ B = tvm.compute(shape, lambda i, j: A[i, j], name='B')
+ s = tvm.create_schedule(B.op)
+ xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32)
+ bounds = tvm.schedule.InferBound(s)
+ assert isinstance(bounds, tvm.container.Map)
+ assert bounds[xo].extent == m
+ assert bounds[xi].extent.value == 8
+ assert bounds[yo].extent == l
+ assert bounds[yi].extent.value == 32
+
def test_bound_fusesplit1():
m = tvm.var('m')
l = tvm.var('l')
test_bound_simplification_failure()
test_bound_fusesplit1()
test_bound_fusesplit2()
+ test_bound_split_divisible()
+ test_bound_tile_divisible()