[ARITH] Bugfix: check arg positiveness for mod rules (#3279)
authorSergei Grechanik <grechanik.sergey@huawei.com>
Mon, 3 Jun 2019 15:52:31 +0000 (18:52 +0300)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 3 Jun 2019 15:52:31 +0000 (08:52 -0700)
src/arithmetic/rewrite_simplify.cc
tests/python/unittest/test_arith_rewrite_simplify.py

index 00198d9..ee32656 100644 (file)
@@ -634,10 +634,12 @@ Mutate_(const Mod* op, const Expr& self) {
     TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2,
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
+                       CanProveGreaterEqual((x * c1).Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
     TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2,
                        c2.Eval()->value > 0 &&
+                       c1.Eval()->value >= 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0));
 
@@ -645,7 +647,7 @@ Mutate_(const Mod* op, const Expr& self) {
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
-                       CanProveGreaterEqual(y.Eval(), 0));
+                       CanProveGreaterEqual((y * c1).Eval(), 0));
 
     // canonicalization: x % c == x % (-c) for truncated division
     // NOTE: trunc div required
index 1b03253..ee113e1 100644 (file)
@@ -302,9 +302,11 @@ def test_div_index_simplify():
 
 def test_mod_index_simplify():
     ck = RewriteChecker()
-    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
+    x, y, nx, ny, z = tvm.var("x"), tvm.var("y"), tvm.var("nx"), tvm.var("ny"), tvm.var("z")
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
+    ck.analyzer.update(nx, tvm.arith.ConstIntBound(-1000, 0), override=True)
+    ck.analyzer.update(ny, tvm.arith.ConstIntBound(-1000, 0), override=True)
 
     ck.verify(x * 10 % 2, 0)
     ck.verify((x * 10 + y) % 2, y % 2)
@@ -317,6 +319,19 @@ def test_mod_index_simplify():
     ck.verify((x + y * 10) % -2, x % 2)
     ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1)
 
+    ck.verify(x * (-10) % 2, 0)
+    ck.verify((x * (-10) + y) % 2, (x * (-10) + y) % 2)
+    ck.verify((x + (-10)) % 2, (x + (-10)) % 2)
+    ck.verify((x + y * (-10)) % 2, (x + y * (-10)) % 2)
+    ck.verify(x * (-10) % -2, 0)
+
+    ck.verify(nx * 10 % 2, 0)
+    ck.verify((nx * (-10) + y) % 2, y % 2)
+    ck.verify((x + ny * (-10)) % 2, x % 2)
+    ck.verify((nx * (-10) + 1 + ny * (-2) + 2) % 2, 1)
+    ck.verify(nx * 10 % -2, 0)
+    ck.verify((nx * (-10) + y) % -2, y % 2)
+    ck.verify((x + ny * (-10)) % -2, x % 2)
 
 def test_min_index_simplify():
     ck = RewriteChecker()