[ARITH] Improve div/mod in rewrite simplifier (#3149)
authorSergei Grechanik <grechanik.sergey@huawei.com>
Mon, 27 May 2019 16:33:13 +0000 (19:33 +0300)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 27 May 2019 16:33:13 +0000 (09:33 -0700)
* [ARITH] Improve div/mod in rewrite simplifier

* Fix lint error

* Fuller file name in src/arithmetic/modular_set.h

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>
* Generalize some rules

* Replace gcd factoring with specialized rules

* Mark rules that don't work for non-truncated division

* More tests

src/arithmetic/modular_set.cc
src/arithmetic/rewrite_simplify.cc
tests/python/unittest/test_arith_rewrite_simplify.py

index 57e8294..b3e943f 100644 (file)
@@ -26,6 +26,8 @@
 #include <tvm/expr_operator.h>
 #include <tvm/ir_functor_ext.h>
 #include <limits>
+#include <utility>
+#include <unordered_map>
 #include "pattern_match.h"
 
 namespace tvm {
index 0de2a25..00198d9 100644 (file)
@@ -80,12 +80,6 @@ TryCompare(const Expr& x, int64_t val) {
       return kLT;
     }
   }
-  if (val == 0) {
-    ModularSet dmod = parent_->modular_set(diff);
-    if (dmod->base != 0) {
-      return kNE;
-    }
-  }
   ConstIntBound dbound = parent_->const_int_bound(diff);
   if (dbound->min_value > val) {
     return kGT;
@@ -99,6 +93,12 @@ TryCompare(const Expr& x, int64_t val) {
   if (dbound->max_value <= val) {
     return kLE;
   }
+  if (val == 0) {
+    ModularSet dmod = parent_->modular_set(diff);
+    if (dmod->base != 0) {
+      return kNE;
+    }
+  }
   return kUnknown;
 }
 
@@ -284,11 +284,39 @@ Mutate_(const Sub* op, const Expr& self) {
                        CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));
 
     // modular-div simplification
-    // Always pre-condition on positive integer domain
+    // Note that c*(x/c) + x % c == x is true for every x and c != 0 even for truncated division
     TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1,
-                       CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
+                       c1.Eval()->value != 0);
     TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1),
-                       CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
+                       c1.Eval()->value != 0);
+    TVM_TRY_REWRITE_IF(x - ((x + y) / c1) * c1, (x + y) % c1 - y,
+                       c1.Eval()->value != 0);
+    TVM_TRY_REWRITE_IF(((x + y) / c1) * c1 - x, y - ((x + y) % c1),
+                       c1.Eval()->value != 0);
+    TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y,
+                       c1.Eval()->value != 0);
+    TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, 0 - (x - y) % c1 - y,
+                       c1.Eval()->value != 0);
+
+    TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2,
+                       c1.Eval()->value != 0 &&
+                       c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+    TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2,
+                       c1.Eval()->value != 0 &&
+                       c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+    TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2,
+                       c1.Eval()->value != 0 &&
+                       c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+    TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2,
+                       c1.Eval()->value != 0 &&
+                       c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+    TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2,
+                       c1.Eval()->value != 0 &&
+                       c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+    TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, (0 - (x - y) % c1 - y) * c2,
+                       c1.Eval()->value != 0 &&
+                       c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
+
     TVM_TRY_REWRITE_IF((x + c1) / c3  - (x + c2) / c3,
                        ((x + (c1 % c3)) % c3 + (c1 - c2)) / c3,
                        CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
@@ -348,6 +376,7 @@ Mutate_(const Mul* op, const Expr& self) {
 
     // canonicalization
     TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1);
+    TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1);
     TVM_TRY_RECURSIVE_REWRITE_IF(
         (x - y) * c1, (y - x) * (0 - c1),
         c1.Eval()->value < 0);
@@ -396,6 +425,16 @@ Mutate_(const Div* op, const Expr& self) {
     // We adopt the default C division uses truncation instead of floordiv.
     // This means most rules need to check non-negativeness of the operands.
 
+    // TryConstFold doesn't work for negative cases because it is also used by legacy
+    // parts of tvm which still assume euclidean div. In this simplifier we assume that the division
+    // is truncated, so perform const folding again.
+    // NOTE: trunc div required
+    if ((c1 / c2).Match(ret)) {
+      int64_t c1val = c1.Eval()->value;
+      int64_t c2val = c2.Eval()->value;
+      return make_const(op->type, c1val / c2val);
+    }
+
     // while it is always true for trunc div
     // restrict to common case(positive div)
     TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2),
@@ -608,6 +647,12 @@ Mutate_(const Mod* op, const Expr& self) {
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
+    // canonicalization: x % c == x % (-c) for truncated division
+    // NOTE: trunc div required
+    TVM_TRY_RECURSIVE_REWRITE_IF(x % c1,
+                                 x % PConst<Expr>(make_const(op->type, -c1.Eval()->value)),
+                                 c1.Eval()->value < 0);
+
     // try modular analysis
     if ((x % c1).Match(ret)) {
       ModularSet mod = parent_->modular_set(x.Eval());
@@ -1025,20 +1070,53 @@ Mutate_(const LT* op, const Expr& self) {
     TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x,
                        c1.Eval()->value < 0);
 
-    // require c1 > 0 to work for any div mode
     TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1,
                        c1.Eval()->value > 0 &&
                        c2.Eval()->value > 0);
-    TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
+    // NOTE: trunc div required
+    TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2,
+                       c1.Eval()->value <= 0 &&
+                       c2.Eval()->value > 0);
+    // NOTE: trunc div required (euclidean is ok too, floored is not)
+    TVM_TRY_REWRITE_IF(x * c2 < c1, (c1 - 1) / c2 - 1 < x,
                        c1.Eval()->value > 0 &&
+                       c2.Eval()->value < 0);
+    // NOTE: trunc div required (floored is ok too, euclidean is not)
+    TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x,
+                       c1.Eval()->value <= 0 &&
+                       c2.Eval()->value < 0);
+
+    // NOTE: trunc div required
+    TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x,
+                       c1.Eval()->value < 0 &&
                        c2.Eval()->value > 0);
-
     TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x,
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0);
+    // NOTE: trunc div required (floored is ok too, euclidean is not)
+    TVM_TRY_REWRITE_IF(c1 < x * c2, x < (c1 + 1) / c2 + 1,
+                       c1.Eval()->value < 0 &&
+                       c2.Eval()->value < 0);
+    // NOTE: trunc div required (euclidean is ok too, floored is not)
+    TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2,
+                       c1.Eval()->value >= 0 &&
+                       c2.Eval()->value < 0);
+
+    TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
+                       c1.Eval()->value > 0 &&
+                       c2.Eval()->value > 0);
+    // NOTE: trunc div required
+    TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * (c2 - 1) + 1,
+                       c1.Eval()->value > 0 &&
+                       c2.Eval()->value <= 0);
+
     TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x,
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0);
+    // NOTE: trunc div required
+    TVM_TRY_REWRITE_IF(c1 < x / c2, c1 * c2 < x,
+                       c1.Eval()->value < 0 &&
+                       c2.Eval()->value > 0);
 
     // division related simplificationx
     // invariance for any div mod: x - (x / c1) * c1 == x % c1
index be961a5..1b03253 100644 (file)
@@ -227,6 +227,25 @@ def test_sub_index_simplify():
     ck.verify(x - (x / 3) * 3, x % 3)
     ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3)
 
+    ck.verify(y - (y / (-5)) * (-5), y % 5)
+    ck.verify((y / 3) * 3 - y, 0 - y % 3)
+    ck.verify(y - ((y - 6) / 5) * 5, (y + (-6)) % 5 + 6)
+    ck.verify(((y - 6) / 5) * 5 - y, (-6) - (y + (-6)) % 5)
+    ck.verify(y - ((y + z) / 5) * 5, (y + z) % 5 - z)
+    ck.verify(((y + z) / 5) * 5 - y, z - (y + z) % 5)
+    ck.verify(y - ((y - z) / 5) * 5, (y - z) % 5 + z)
+    ck.verify(((y - z) / 5) * 5 - y, 0 - (y - z) % 5 - z)
+
+    ck.verify(y * 3 - (y / 2) * 6, (y % 2) * 3)
+    ck.verify((y / 3) * 6 - y * 2, (y % 3) * (-2))
+    ck.verify(y * 5 - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5)
+    ck.verify(y * 5 - ((y - z) / 2) * 10, ((y - z) % 2 + z) * 5)
+    ck.verify(((y + z) / 3) * 6 - y * 2, (z - (y + z) % 3) * 2)
+    ck.verify(((y - z) / 3) * 6 - y * 2, (0 - (y - z) % 3 - z) * 2)
+    ck.verify(5 * y - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5)
+    ck.verify(5 * y - 10 * ((y - z) / 2), ((y - z) % 2 + z) * 5)
+    ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2)
+    ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2)
 
 def test_mul_index_simplify():
     ck = RewriteChecker()
@@ -292,6 +311,11 @@ def test_mod_index_simplify():
     ck.verify((x + 10) % 2, x % 2)
     ck.verify((x + y * 10) % 2, x % 2)
     ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1)
+    ck.verify(x * 10 % -2, 0)
+    ck.verify((x * 10 + y) % -2, y % 2)
+    ck.verify((x + 10) % -2, x % 2)
+    ck.verify((x + y * 10) % -2, x % 2)
+    ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1)
 
 
 def test_min_index_simplify():
@@ -449,6 +473,50 @@ def test_cmp_simplify():
     ck.verify(x / 2 < 3, x < 6)
     ck.verify(x * 4 <= 2, x <= 0)
     ck.verify(3 < x / 2, tvm.expr.LT(7, x))
+    ck.verify(x / 3 >= 0, tvm.expr.LE(-2, x))
+    ck.verify((0 - x * 3) <= 0, tvm.expr.LE(0, x))
+    ck.verify((0 - x * 3) >= 0, tvm.expr.LE(x, 0))
+    ck.verify(2 * x <= 0, x <= 0)
+
+    ck.verify(x * 2 >= 3, tvm.expr.LE(2, x))
+    ck.verify(x * 2 >= 2, tvm.expr.LE(1, x))
+    ck.verify(x * 2 >= 1, tvm.expr.LE(1, x))
+    ck.verify(x * 2 >= 0, tvm.expr.LE(0, x))
+    ck.verify(x * 2 >= -1, tvm.expr.LE(0, x))
+    ck.verify(x * 2 >= -2, tvm.expr.LE(-1, x))
+    ck.verify(x * 2 >= -3, tvm.expr.LE(-1, x))
+
+    ck.verify(x * 2 <= 3, tvm.expr.LE(x, 1))
+    ck.verify(x * 2 <= 2, tvm.expr.LE(x, 1))
+    ck.verify(x * 2 <= 1, tvm.expr.LE(x, 0))
+    ck.verify(x * 2 <= 0, tvm.expr.LE(x, 0))
+    ck.verify(x * 2 <= -1, tvm.expr.LE(x, -1))
+    ck.verify(x * 2 <= -2, tvm.expr.LE(x, -1))
+    ck.verify(x * 2 <= -3, tvm.expr.LE(x, -2))
+
+    ck.verify(x * (-2) >= 3, tvm.expr.LE(x, -2))
+    ck.verify(x * (-2) >= 2, tvm.expr.LE(x, -1))
+    ck.verify(x * (-2) >= 1, tvm.expr.LE(x, -1))
+    ck.verify(x * (-2) >= 0, tvm.expr.LE(x, 0))
+    ck.verify(x * (-2) >= -1, tvm.expr.LE(x, 0))
+    ck.verify(x * (-2) >= -2, tvm.expr.LE(x, 1))
+    ck.verify(x * (-2) >= -3, tvm.expr.LE(x, 1))
+
+    ck.verify(x * (-2) <= 3, tvm.expr.LE(-1, x))
+    ck.verify(x * (-2) <= 2, tvm.expr.LE(-1, x))
+    ck.verify(x * (-2) <= 1, tvm.expr.LE(0, x))
+    ck.verify(x * (-2) <= 0, tvm.expr.LE(0, x))
+    ck.verify(x * (-2) <= -1, tvm.expr.LE(1, x))
+    ck.verify(x * (-2) <= -2, tvm.expr.LE(1, x))
+    ck.verify(x * (-2) <= -3, tvm.expr.LE(2, x))
+
+    ck.verify(x / 2 >= 1, tvm.expr.LE(2, x))
+    ck.verify(x / 2 >= 0, tvm.expr.LE(-1, x))
+    ck.verify(x / 2 >= -1, tvm.expr.LE(-3, x))
+
+    ck.verify(x / 2 <= 1, tvm.expr.LE(x, 3))
+    ck.verify(x / 2 <= 0, tvm.expr.LE(x, 1))
+    ck.verify(x / 2 <= -1, tvm.expr.LE(x, -2))
 
     ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4))
     ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0))
@@ -480,6 +548,7 @@ def test_cmp_simplify():
     ck.verify(x*y <= 0, tvm.const(1, "bool"))
     ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool"))
     ck.verify(y*y >= 0, tvm.const(1, "bool"))
+    ck.verify(x*6 <= -3, tvm.const(0, "bool"))
 
 
 def test_logical_simplify():