[ARITH] Bugfix min/max const canonicalize rule (#3386)
authorTianqi Chen <tqchen@users.noreply.github.com>
Tue, 18 Jun 2019 04:51:33 +0000 (21:51 -0700)
committerGitHub <noreply@github.com>
Tue, 18 Jun 2019 04:51:33 +0000 (21:51 -0700)
3rdparty/dmlc-core
src/arithmetic/rewrite_simplify.cc
tests/python/unittest/test_arith_rewrite_simplify.py

index fbe142b..3943914 160000 (submodule)
@@ -1 +1 @@
-Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661
+Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f
index ee32656..ea65306 100644 (file)
@@ -813,7 +813,9 @@ Mutate_(const Min* op, const Expr& self) {
 
     // canonicalization
     TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1));
-    TVM_TRY_RECURSIVE_REWRITE(min(c1 - x, c2), c1 - max(x, c2 - c1));
+    TVM_TRY_RECURSIVE_REWRITE_IF(
+        min(c1 - x, c2), c1 - max(x, c1 - c2),
+        c2.Eval()->value != 0);
   }
 
   // condition rules.
@@ -961,7 +963,8 @@ Mutate_(const Max* op, const Expr& self) {
 
     // canonicalization
     TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1));
-    TVM_TRY_RECURSIVE_REWRITE(max(c1 - x, c2), c1 - min(x, c2 - c1));
+    TVM_TRY_RECURSIVE_REWRITE_IF(
+        max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0);
   }
 
   // condition rules.
index 596e54d..07d460e 100644 (file)
@@ -392,6 +392,7 @@ def test_min_index_simplify():
     ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10)
     ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10))
     ck.verify(tvm.min(x * 3, 9), tvm.min(x, 3) * 3)
+    ck.verify(tvm.min(3 - x, 2), 3 - tvm.max(x,  1))
 
 
 def test_max_index_simplify():
@@ -448,6 +449,7 @@ def test_max_index_simplify():
     ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10)
     ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10))
     ck.verify(tvm.max(x * 3, 9), tvm.max(x, 3) * 3)
+    ck.verify(tvm.max(3 - x, 1), 3 - tvm.min(x,  2))
 
 
 def test_cmp_simplify():