[ARITH] Bugfix div subtract rewrite rule (#3504)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sun, 7 Jul 2019 03:44:13 +0000 (20:44 -0700)
committerGitHub <noreply@github.com>
Sun, 7 Jul 2019 03:44:13 +0000 (20:44 -0700)
src/arithmetic/rewrite_simplify.cc
tests/python/unittest/test_arith_rewrite_simplify.py

index 6cc829d..06a28b5 100644 (file)
@@ -342,13 +342,16 @@ Mutate_(const Sub* op, const Expr& self) {
                        c1.Eval()->value != 0 &&
                        c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
 
+    // Proof in the case of floordiv, need positive condition.
+    // let x = a * c3 + r
+    // (x + c1) / c3 - x / c3 => (r + c1) / c3
     TVM_TRY_REWRITE_IF((x + c1) / c3  - (x + c2) / c3,
-                       ((x + (c1 % c3)) % c3 + (c1 - c2)) / c3,
+                       ((x + ((c2 % c3) + c3) % c3) % c3 + (c1 - c2)) / c3,
                        CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
                        c1.Eval()->value >= c2.Eval()->value &&
                        c3.Eval()->value > 0);
     TVM_TRY_REWRITE_IF((x + c1) / c3  - x / c3,
-                       ((x + (c1 % c3)) % c3 + c1) / c3,
+                       (x % c3 + c1) / c3,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        c1.Eval()->value >= 0 &&
                        c3.Eval()->value > 0);
index 8bbade9..0ad62ec 100644 (file)
@@ -236,7 +236,9 @@ def test_sub_index_simplify():
     # div pattern
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
     ck.verify(x - (x / 3) * 3, x % 3)
-    ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3)
+
+    ck.verify((x + 5) / 3 - x / 3, ((x % 3) + 5)/ 3)
+    ck.verify((x + 5) / 3 - (x + 1) / 3, (((x + 1) % 3) + 4)/ 3)
 
     ck.verify(y - (y / (-5)) * (-5), y % 5)
     ck.verify((y / 3) * 3 - y, 0 - y % 3)
@@ -258,6 +260,7 @@ def test_sub_index_simplify():
     ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2)
     ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2)
 
+
 def test_mul_index_simplify():
     ck = RewriteChecker()
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")