[ARITH][BACKPORT-0.6] fix a min/max simplify bug (#5749)
authorxqdan <danxiaoqiang@126.com>
Tue, 9 Jun 2020 15:34:27 +0000 (23:34 +0800)
committerGitHub <noreply@github.com>
Tue, 9 Jun 2020 15:34:27 +0000 (08:34 -0700)
* fix a min/max simplify bug

* fix cpplint

* turn into oposite when c1val<0 and add more case

* fix c1=0

Co-authored-by: xqdan <danxiaoqiang@huawei.com>
src/arith/rewrite_simplify.cc
tests/python/unittest/test_arith_rewrite_simplify.py

index 223b2e6..4149b15 100644 (file)
@@ -1025,8 +1025,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
     if (min(x * c1, c2).Match(ret)) {
       int64_t c1val = c1.Eval()->value;
       int64_t c2val = c2.Eval()->value;
+      if (c1val == 0) {
+        return c2val < 0 ? c2.Eval() : c1.Eval();
+      }
       if (c2val % c1val == 0) {
-        if (c2val / c1val >= 0) {
+        if (c1val > 0) {
           return (min(x, c2val / c1val) * c1val).Eval();
         } else {
           return (max(x, c2val / c1val) * c1val).Eval();
@@ -1185,8 +1188,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
     if (max(x * c1, c2).Match(ret)) {
       int64_t c1val = c1.Eval()->value;
       int64_t c2val = c2.Eval()->value;
+      if (c1val == 0) {
+        return c2val > 0 ? c2.Eval() : c1.Eval();
+      }
       if (c2val % c1val == 0) {
-        if (c2val / c1val >= 0) {
+        if (c1val > 0) {
           return (max(x, c2val / c1val) * c1val).Eval();
         } else {
           return (min(x, c2val / c1val) * c1val).Eval();
index dbfdde3..813e10a 100644 (file)
@@ -529,7 +529,13 @@ def test_min_index_simplify():
     ck.verify(tvm.te.min(tvm.te.min(x, 11), 10), tvm.te.min(x, 10))
 
     ck.verify(tvm.te.min(x * 3, 9), tvm.te.min(x, 3) * 3)
+    ck.verify(tvm.te.min(x * 2, 0), tvm.te.min(x, 0) * 2)
+    ck.verify(tvm.te.min(0 - x * 2, 0), tvm.te.max(x, 0) * -2)
     ck.verify(tvm.te.min(3 - x, 2), 3 - tvm.te.max(x,  1))
+    ck.verify(tvm.te.min(x * (-2), -4), tvm.te.max(x, 2) * -2)
+    ck.verify(tvm.te.min(x * (-2), 4), tvm.te.max(x, -2) * -2)
+    ck.verify(tvm.te.min(x * (0), 4), 0)
+    ck.verify(tvm.te.min(x * (0), -4), -4)
 
     # DivMod rules
     # truc div
@@ -610,6 +616,12 @@ def test_max_index_simplify():
 
     ck.verify(tvm.te.max(x * 3, 9), tvm.te.max(x, 3) * 3)
     ck.verify(tvm.te.max(3 - x, 1), 3 - tvm.te.min(x,  2))
+    ck.verify(tvm.te.max(x * 2, 0), tvm.te.max(x, 0) * 2)
+    ck.verify(tvm.te.max(0 - x * 2, 0), tvm.te.min(x, 0) * -2)
+    ck.verify(tvm.te.max(x * (-2), -4), tvm.te.min(x, 2) * -2)
+    ck.verify(tvm.te.max(x * (-2), 4), tvm.te.min(x, -2) * -2)
+    ck.verify(tvm.te.max(x * (0), 4), 4)
+    ck.verify(tvm.te.max(x * (0), -4), 0)
 
     # DivMod rules
     # truc div