TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2,
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 &&
+ CanProveGreaterEqual((x * c1).Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2,
c2.Eval()->value > 0 &&
+ c1.Eval()->value >= 0 &&
c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0));
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) &&
- CanProveGreaterEqual(y.Eval(), 0));
+ CanProveGreaterEqual((y * c1).Eval(), 0));
// canonicalization: x % c == x % (-c) for truncated division
// NOTE: trunc div required
def test_mod_index_simplify():
ck = RewriteChecker()
- x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
+ x, y, nx, ny, z = tvm.var("x"), tvm.var("y"), tvm.var("nx"), tvm.var("ny"), tvm.var("z")
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
+ ck.analyzer.update(nx, tvm.arith.ConstIntBound(-1000, 0), override=True)
+ ck.analyzer.update(ny, tvm.arith.ConstIntBound(-1000, 0), override=True)
ck.verify(x * 10 % 2, 0)
ck.verify((x * 10 + y) % 2, y % 2)
ck.verify((x + y * 10) % -2, x % 2)
ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1)
+ ck.verify(x * (-10) % 2, 0)
+ ck.verify((x * (-10) + y) % 2, (x * (-10) + y) % 2)
+ ck.verify((x + (-10)) % 2, (x + (-10)) % 2)
+ ck.verify((x + y * (-10)) % 2, (x + y * (-10)) % 2)
+ ck.verify(x * (-10) % -2, 0)
+
+ ck.verify(nx * 10 % 2, 0)
+ ck.verify((nx * (-10) + y) % 2, y % 2)
+ ck.verify((x + ny * (-10)) % 2, x % 2)
+ ck.verify((nx * (-10) + 1 + ny * (-2) + 2) % 2, 1)
+ ck.verify(nx * 10 % -2, 0)
+ ck.verify((nx * (-10) + y) % -2, y % 2)
+ ck.verify((x + ny * (-10)) % -2, x % 2)
def test_min_index_simplify():
ck = RewriteChecker()