From: xqdan Date: Tue, 9 Jun 2020 15:34:27 +0000 (+0800) Subject: [ARITH][BACKPORT-0.6] fix a min/max simplify bug (#5749) X-Git-Tag: upstream/0.7.0~596 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0ea99698f4379eedba26b06a9f426e613bf5b25e;p=platform%2Fupstream%2Ftvm.git [ARITH][BACKPORT-0.6] fix a min/max simplify bug (#5749) * fix a min/max simplify bug * fix cpplint * turn into oposite when c1val<0 and add more case * fix c1=0 Co-authored-by: xqdan --- diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 223b2e6..4149b15 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1025,8 +1025,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { if (min(x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + if (c1val == 0) { + return c2val < 0 ? c2.Eval() : c1.Eval(); + } if (c2val % c1val == 0) { - if (c2val / c1val >= 0) { + if (c1val > 0) { return (min(x, c2val / c1val) * c1val).Eval(); } else { return (max(x, c2val / c1val) * c1val).Eval(); @@ -1185,8 +1188,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { if (max(x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + if (c1val == 0) { + return c2val > 0 ? c2.Eval() : c1.Eval(); + } if (c2val % c1val == 0) { - if (c2val / c1val >= 0) { + if (c1val > 0) { return (max(x, c2val / c1val) * c1val).Eval(); } else { return (min(x, c2val / c1val) * c1val).Eval(); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index dbfdde3..813e10a 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -529,7 +529,13 @@ def test_min_index_simplify(): ck.verify(tvm.te.min(tvm.te.min(x, 11), 10), tvm.te.min(x, 10)) ck.verify(tvm.te.min(x * 3, 9), tvm.te.min(x, 3) * 3) + ck.verify(tvm.te.min(x * 2, 0), tvm.te.min(x, 0) * 2) + ck.verify(tvm.te.min(0 - x * 2, 0), tvm.te.max(x, 0) * -2) ck.verify(tvm.te.min(3 - x, 2), 3 - tvm.te.max(x, 1)) + ck.verify(tvm.te.min(x * (-2), -4), tvm.te.max(x, 2) * -2) + ck.verify(tvm.te.min(x * (-2), 4), tvm.te.max(x, -2) * -2) + ck.verify(tvm.te.min(x * (0), 4), 0) + ck.verify(tvm.te.min(x * (0), -4), -4) # DivMod rules # truc div @@ -610,6 +616,12 @@ def test_max_index_simplify(): ck.verify(tvm.te.max(x * 3, 9), tvm.te.max(x, 3) * 3) ck.verify(tvm.te.max(3 - x, 1), 3 - tvm.te.min(x, 2)) + ck.verify(tvm.te.max(x * 2, 0), tvm.te.max(x, 0) * 2) + ck.verify(tvm.te.max(0 - x * 2, 0), tvm.te.min(x, 0) * -2) + ck.verify(tvm.te.max(x * (-2), -4), tvm.te.min(x, 2) * -2) + ck.verify(tvm.te.max(x * (-2), 4), tvm.te.min(x, -2) * -2) + ck.verify(tvm.te.max(x * (0), 4), 4) + ck.verify(tvm.te.max(x * (0), -4), 0) # DivMod rules # truc div