From 6c81d784dc9459d684604fcf4190fda4cb956c1c Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 30 Jun 2019 18:05:48 -0700 Subject: [PATCH] [ARITH] Canonicalize comparison to move constant to one side (#3467) --- src/arithmetic/rewrite_simplify.cc | 6 ++++++ tests/python/unittest/test_arith_canonical_simplify.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 8ffebe5..bc8666e 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -1187,6 +1187,12 @@ Mutate_(const LT* op, const Expr& self) { TVM_TRY_RECURSIVE_REWRITE(z < min(x, y), z < x && z < y); TVM_TRY_RECURSIVE_REWRITE(z < max(x, y), z < x || z < y); + TVM_TRY_RECURSIVE_REWRITE(x < c1 - y, x + y < c1); + TVM_TRY_RECURSIVE_REWRITE(x < c1 + y, x - y < c1); + TVM_TRY_RECURSIVE_REWRITE(c1 - y < x, c1 < x + y); + TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y); + + TVM_TRY_REWRITE(x - c1 < 0, x < c1); TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1); } diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index f926da2..56d2bb1 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -166,12 +166,18 @@ def test_simplify_if_then_else(): tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528), (((((x*4) + y) - 466036) % 24528) -24512) % 16, x), y) + + res2 = tvm.if_then_else((x * 4) >= 466036 - y, + tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528), + (((((x*4) + y) - 466036) % 24528) -24512) % 16, + x), y) expected = tvm.if_then_else( tvm.expr.LE(466036, (x * 4 + y)), tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)), (((x*4) + y) - 4) % 16, x), y) ck.verify(res, expected) + ck.verify(res2, expected) # can only simplify if condition res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3) expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3) -- 2.7.4