[ARITH] Explicitly state truncdiv/mod in pattern matching. (#3986)
authorTianqi Chen <tqchen@users.noreply.github.com>
Tue, 24 Sep 2019 18:01:37 +0000 (11:01 -0700)
committerGitHub <noreply@github.com>
Tue, 24 Sep 2019 18:01:37 +0000 (11:01 -0700)
* [ARITH] Explicitly state truncdiv/mod in pattern matching.

* Fix the dependent cpp test

include/tvm/expr_operator.h
src/arithmetic/canonical_simplify.cc
src/arithmetic/int_operator.h
src/arithmetic/modular_set.cc
src/arithmetic/pattern_match.h
src/arithmetic/rewrite_simplify.cc
src/lang/expr_operator.cc
tests/cpp/pattern_match_test.cc
tests/python/unittest/test_arith_rewrite_simplify.py

index b0e82e7..5f0f849 100644 (file)
@@ -333,6 +333,20 @@ TVM_DLL Expr operator||(Expr a, Expr b);
  */
 TVM_DLL Expr operator!(Expr a);
 /*!
+ * \brief compute division in C semantics.
+ *
+ * a / b as in C/C++.
+ *
+ * When operands are integers, it directly corresponds to truncdiv.
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL Expr div(Expr a, Expr b);
+/*!
  * \brief compute trunc(a / b)
  *
  * This is the default integer division behavior in C.
@@ -640,6 +654,21 @@ inline Expr make_zero(Type t) {
   return make_const(t, 0);
 }
 
+/*!
+ * \brief Helper function to raise a compiler error about division ambiguity.
+ * \note The call to this function will always results in a compiler error.
+ * \tparam TA Any class type.
+ */
+template<typename TA>
+inline void DivAmbiguityError(const TA& a) {
+  constexpr bool div_ambiguity = !std::is_class<TA>::value;
+  static_assert(div_ambiguity,
+                "TVM supports multiple types of integer divisions, "
+                "please call div, floordiv/floormod or truncdiv/truncmod directly "
+                "to avoid ambiguity in the code. "
+                "Checkout these functions in expr_operator.h.");
+}
+
 // additional const expression overloading
 #define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)            \
   inline Expr Name(Expr& a, Expr b) {                          \
@@ -688,12 +717,17 @@ TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
+TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>);  // NOLINT(*)
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<);  // NOLINT(*)
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
 // integer related ops
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
index e1fa6d6..8b0d29e 100644 (file)
@@ -67,7 +67,7 @@ enum DivMode {
 
 inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
   if (mode == kTruncDiv) {
-    return a % b;
+    return truncmod(a, b);
   } else {
     CHECK_EQ(mode, kFloorDiv);
     return floormod(a, b);
@@ -76,7 +76,7 @@ inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
 
 inline Expr DivImpl(Expr a, Expr b, DivMode mode) {
   if (mode == kTruncDiv) {
-    return a / b;
+    return truncdiv(a, b);
   } else {
     CHECK_EQ(mode, kFloorDiv);
     return floordiv(a, b);
index d920944..e1694a3 100644 (file)
@@ -93,6 +93,26 @@ inline bool WillOverflow<ir::Mod>(int64_t x,
 }
 
 /*!
+ * \brief Peform trunc division of two integers.
+ * \param x The left operand.
+ * \param y The right operand.
+ * \return the result.
+ */
+inline int64_t truncdiv(int64_t x, int64_t y) {
+  return x / y;
+}
+
+/*!
+ * \brief Compute the truncdiv remainder of two integers.
+ * \param x The left operand.
+ * \param y The right operand.
+ * \return the result.
+ */
+inline int64_t truncmod(int64_t x, int64_t y) {
+  return x % y;
+}
+
+/*!
  * \brief Peform floor division of two integers.
  * \param x The left operand.
  * \param y The right operand.
index c072f09..08454dd 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file modular_set.cc
  * \brief Modular set analysis
  */
@@ -111,7 +110,8 @@ class ModularSetAnalyzer::Impl :
     PVar<Var> var;
     PVar<Integer> coeff, base;
     // pattern match interesting constraints
-    if (((var % coeff) == base).Match(constraint)) {
+    if ((truncmod(var, coeff) == base).Match(constraint) ||
+        (floormod(var, coeff) == base).Match(constraint)) {
       Entry entry(coeff.Eval()->value, base.Eval()->value);
       return UpdateByIntersect(var.Eval(), entry);
     }
index 1278c7d..f7d5483 100644 (file)
@@ -300,31 +300,41 @@ class PConstWithTypeLike :
 };
 
 
-#define TVM_PATTERN_BINARY_OP(FuncName, NodeName)                   \
+#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep)     \
   template<typename TA, typename TB>                                \
   inline PBinaryExpr<NodeName, TA, TB>                              \
   FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {            \
+    CheckStep;                                                      \
     return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
   }                                                                 \
   template<typename TA>                                             \
   inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> >         \
   FuncName(const Pattern<TA>& a, int64_t b) {                       \
+    CheckStep;                                                      \
     return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b));     \
   }                                                                 \
   template<typename TA>                                             \
   inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA>          \
   FuncName(int64_t b, const Pattern<TA>& a) {                       \
+    CheckStep;                                                      \
     return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a);     \
   }
 
+#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
+  TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
+
+
+// raise ambiguity error for operator overload of / and %
+TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
+TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));
+
 // arithmetic expressions
 TVM_PATTERN_BINARY_OP(operator+, ir::Add);
 TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
 TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
-TVM_PATTERN_BINARY_OP(operator/, ir::Div);
-TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
 TVM_PATTERN_BINARY_OP(min, ir::Min);
 TVM_PATTERN_BINARY_OP(max, ir::Max);
+TVM_PATTERN_BINARY_OP(div, ir::Div);
 TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
 TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
 TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
index a567f50..e3b3e7a 100644 (file)
@@ -194,7 +194,7 @@ Mutate_(const Add* op, const Expr& self) {
 
     // DivMod rules
     // truc div
-    TVM_TRY_REWRITE((x / c1) * c1 + x % c1, x);
+    TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x);
     // floor div
     TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x);
 
@@ -208,7 +208,7 @@ Mutate_(const Add* op, const Expr& self) {
 
     // DivMod rules
     // truc div
-    TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1));
+    TVM_TRY_RECURSIVE_REWRITE(truncmod(y, c1) + x * c1, x * c1 + truncmod(y, c1));
     // floor div
     TVM_TRY_RECURSIVE_REWRITE(floormod(y, c1) + x * c1, x * c1 + floormod(y, c1));
   }
@@ -314,48 +314,49 @@ Mutate_(const Sub* op, const Expr& self) {
     // DivMod rules
     // trucdiv
     // NOTE: c*(x/c) + x % c == x is true all division mode.
-    TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1,
+    TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1),
                        c1.Eval()->value != 0);
-    TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1),
+    TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1),
                        c1.Eval()->value != 0);
-    TVM_TRY_REWRITE_IF(x - ((x + y) / c1) * c1, (x + y) % c1 - y,
+    TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y,
                        c1.Eval()->value != 0);
-    TVM_TRY_REWRITE_IF(((x + y) / c1) * c1 - x, y - ((x + y) % c1),
+    TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1),
                        c1.Eval()->value != 0);
-    TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y,
+    TVM_TRY_REWRITE_IF(x - truncdiv(x - y, c1) * c1, truncmod(x - y, c1) + y,
                        c1.Eval()->value != 0);
-    TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, 0 - (x - y) % c1 - y,
+    TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y,
                        c1.Eval()->value != 0);
 
-    TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2,
+    TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2,
                        c1.Eval()->value != 0 &&
                        c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
-    TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2,
+    TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2,
                        c1.Eval()->value != 0 &&
                        c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
-    TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2,
+    TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2,
                        c1.Eval()->value != 0 &&
                        c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
-    TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2,
+    TVM_TRY_REWRITE_IF(truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2,
                        c1.Eval()->value != 0 &&
                        c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
-    TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2,
+    TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2,
                        c1.Eval()->value != 0 &&
                        c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
-    TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, (0 - (x - y) % c1 - y) * c2,
+    TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2,
                        c1.Eval()->value != 0 &&
                        c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
 
     // Proof in the case of floordiv, need positive condition.
     // let x = a * c3 + r
     // (x + c1) / c3 - x / c3 => (r + c1) / c3
-    TVM_TRY_REWRITE_IF((x + c1) / c3  - (x + c2) / c3,
-                       ((x + ((c2 % c3) + c3) % c3) % c3 + (c1 - c2)) / c3,
+    // NOTE: the use of floormod(c2, c3) was intentional to simplify the const.
+    TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3)  - truncdiv(x + c2, c3),
+                       truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
                        CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
                        c1.Eval()->value >= c2.Eval()->value &&
                        c3.Eval()->value > 0);
-    TVM_TRY_REWRITE_IF((x + c1) / c3  - x / c3,
-                       (x % c3 + c1) / c3,
+    TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3)  - truncdiv(x, c3),
+                       truncdiv(truncmod(x, c3) + c1, c3),
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        c1.Eval()->value >= 0 &&
                        c3.Eval()->value > 0);
@@ -478,14 +479,15 @@ Mutate_(const Div* op, const Expr& self) {
 
   // Vector rules
   if (op->type.lanes() != 1) {
-    TVM_TRY_REWRITE(broadcast(x, lanes) / broadcast(y, lanes),
-                    broadcast(x / y, lanes));
+    // NOTE: use div as the pattern also works for float.
+    TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)),
+                    broadcast(div(x, y), lanes));
     // ramp / bcast
-    if ((ramp(b1, c1, lanes) / broadcast(c2, lanes)).Match(ret)) {
+    if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) {
       int64_t c1val = c1.Eval()->value;
       int64_t c2val = c2.Eval()->value;
       if (c1val % c2val == 0) {
-        return ramp(b1 / c2, c1 / c2, lanes).Eval();
+        return ramp(div(b1, c2), div(c1, c2), lanes).Eval();
       }
       // If all possible indices in ramp are the same.
       if (CanProveGreaterEqual(b1.Eval(), 0)) {
@@ -493,7 +495,7 @@ Mutate_(const Div* op, const Expr& self) {
         int64_t ramp_min = bmod->base / c2val;
         int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
         if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
-          return broadcast(b1 / c2, lanes).Eval();
+          return broadcast(div(b1, c2), lanes).Eval();
         }
       }
     }
@@ -508,73 +510,79 @@ Mutate_(const Div* op, const Expr& self) {
     // parts of tvm which still assume euclidean div. In this simplifier we assume that the division
     // is truncated, so perform const folding again.
     // NOTE: trunc div required
-    if ((c1 / c2).Match(ret)) {
+    if (truncdiv(c1, c2).Match(ret)) {
       int64_t c1val = c1.Eval()->value;
       int64_t c2val = c2.Eval()->value;
-      return make_const(op->type, c1val / c2val);
+      return make_const(op->type, truncdiv(c1val, c2val));
     }
 
     // while it is always true for trunc div
     // restrict to common case(positive div)
-    TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2),
+    TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1), c2), truncdiv(x, c1 * c2),
                        c1.Eval()->value > 0 && c2.Eval()->value > 0);
 
-    TVM_TRY_REWRITE_IF((x / c1 + c2) / c3, (x + c1 * c2) / (c1 * c3),
+    TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3),
                        c1.Eval()->value > 0 &&
                        c2.Eval()->value >= 0 &&
                        c3.Eval()->value > 0 &&
                        CanProveGreaterEqual(x.Eval(), 0));
 
-    if (((x * c1) / c2).Match(ret)) {
+    if (truncdiv(x * c1, c2).Match(ret)) {
       int64_t c1val = c1.Eval()->value;
       int64_t c2val = c2.Eval()->value;
       if (c1val > 0 && c2val > 0) {
-        if (c1val % c2val == 0) return (x * (c1 / c2)).Eval();
-        if (c2val % c1val == 0) return (x / (c2 / c1)).Eval();
+        if (c1val % c2val == 0) return (x * truncdiv(c1, c2)).Eval();
+        if (c2val % c1val == 0) return truncdiv(x, truncdiv(c2, c1)).Eval();
       }
     }
 
-    TVM_TRY_REWRITE(x / x, OneWithTypeLike(x));
-    TVM_TRY_REWRITE(x * c1 / x, c1);
-    TVM_TRY_REWRITE(c1 * x / x, c1);
+    TVM_TRY_REWRITE(truncdiv(x, x), OneWithTypeLike(x));
+    TVM_TRY_REWRITE(truncdiv(x * c1, x), c1);
+    TVM_TRY_REWRITE(truncdiv(c1 * x, x), c1);
 
     // Rules involving 2-operands.
-    TVM_TRY_REWRITE_IF((x * c1 + y) / c2, x * (c1 / c2) + y / c2,
+    TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2),
+                       x * truncdiv(c1, c2) + truncdiv(y, c2),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF(min(x * c1, y) / c2, min(x * (c1 / c2), y / c2),
+    TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2),
+                       min(x * truncdiv(c1, c2), truncdiv(y, c2)),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF(max(x * c1, y) / c2, max(x * (c1 / c2), y / c2),
+    TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2),
+                       max(x * truncdiv(c1, c2), truncdiv(y, c2)),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((y + x * c1) / c2, y / c2 + x * (c1 / c2),
+    TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2),
+                       truncdiv(y, c2) + x * truncdiv(c1, c2),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF(min(y, x * c1) / c2, min(y / c2, x * (c1 / c2)),
+    TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2),
+                       min(truncdiv(y, c2), x * truncdiv(c1, c2)),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF(max(y, x * c1) / c2, max(y / c2, x * (c1 / c2)),
+    TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2),
+                       max(truncdiv(y, c2), x * truncdiv(c1, c2)),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
@@ -582,80 +590,89 @@ Mutate_(const Div* op, const Expr& self) {
                        CanProveGreaterEqual(y.Eval(), 0));
 
     // Rules involving 3-operands.
-    TVM_TRY_REWRITE_IF((x * c1 + y + z) / c2, x * (c1 / c2) + (y + z)/ c2,
+    TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y + z, c2),
+                       x * truncdiv(c1, c2) + truncdiv(y + z, c2),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual((y + z).Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((x * c1 - y + z) / c2, x * (c1 / c2) + (z - y)/ c2,
+    TVM_TRY_REWRITE_IF(truncdiv(x * c1 - y + z, c2),
+                       x * truncdiv(c1, c2) + truncdiv(z - y, c2),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual((z - y).Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((x * c1 + y - z) / c2, x * (c1 / c2) + (y - z)/ c2,
+    TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y - z, c2),
+                       x * truncdiv(c1, c2) + truncdiv(y - z, c2),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual((y - z).Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((y + x * c1 + z) / c2, x * (c1 / c2) + (y + z) / c2,
+    TVM_TRY_REWRITE_IF(truncdiv(y + x * c1 + z, c2),
+                       x * truncdiv(c1, c2) + truncdiv(y + z, c2),
                        c1.Eval()->value > 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual((y + z).Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((x + c1) / c2, x / c2 + c1 / c2,
+    TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2),
+                       truncdiv(x, c2) + truncdiv(c1, c2),
                        c1.Eval()->value > 0 &&
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((x + y) / x, y / x + 1,
+    TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
-    TVM_TRY_REWRITE_IF((y + x) / x, y / x + 1,
+    TVM_TRY_REWRITE_IF(truncdiv(y + x, x), truncdiv(y, x) + 1,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF(((x + y) + z) / x, (y + z) / x + 1,
+    TVM_TRY_REWRITE_IF(truncdiv((x + y) + z, x),
+                       truncdiv(y + z, x) + 1,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual((y + z).Eval(), 0));
-    TVM_TRY_REWRITE_IF(((y + x) + z) / x, (y + z) / x + 1,
+    TVM_TRY_REWRITE_IF(truncdiv((y + x) + z, x),
+                       truncdiv(y + z, x) + 1,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual((y + z).Eval(), 0));
-    TVM_TRY_REWRITE_IF((y + (z + x)) / x, (y + z) / x + 1,
+    TVM_TRY_REWRITE_IF(truncdiv(y + (z + x), x),
+                       truncdiv(y + z, x) + 1,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual((y + z).Eval(), 0));
-    TVM_TRY_REWRITE_IF((y + (x + z)) / x, (y + z) / x + 1,
+    TVM_TRY_REWRITE_IF(truncdiv(y + (x + z), x),
+                       truncdiv(y + z, x) + 1,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual((y + z).Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((x * y) / y, x,
+    TVM_TRY_REWRITE_IF(truncdiv(x * y, y), x,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
-    TVM_TRY_REWRITE_IF((y * x) / y, x,
+    TVM_TRY_REWRITE_IF(truncdiv(y * x, y), x,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((x * z + y) / z, x + y / z,
+    TVM_TRY_REWRITE_IF(truncdiv(x * z + y, z), x + truncdiv(y, z),
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0) &&
                        CanProveGreaterEqual(z.Eval(), 0));
-    TVM_TRY_REWRITE_IF((z * x + y) / z, x + y / z,
+    TVM_TRY_REWRITE_IF(truncdiv(z * x + y, z), x + truncdiv(y, z),
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0) &&
                        CanProveGreaterEqual(z.Eval(), 0));
-    TVM_TRY_REWRITE_IF((y + x * z) / z, y / z + x,
+    TVM_TRY_REWRITE_IF(truncdiv(y + x * z, z), truncdiv(y, z) + x,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0) &&
                        CanProveGreaterEqual(z.Eval(), 0));
-    TVM_TRY_REWRITE_IF((y + z * x) / z, y / z + x,
+    TVM_TRY_REWRITE_IF(truncdiv(y + z * x, z), truncdiv(y, z) + x,
                        CanProveGreaterEqual(x.Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0) &&
                        CanProveGreaterEqual(z.Eval(), 0));
@@ -679,15 +696,15 @@ Mutate_(const Mod* op, const Expr& self) {
 
   // Vector rules
   if (op->type.lanes() != 1) {
-    TVM_TRY_REWRITE(broadcast(x, lanes) % broadcast(y, lanes),
-                    broadcast(x % y, lanes));
+    TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)),
+                    broadcast(truncmod(x, y), lanes));
 
     // ramp % bcast
-    if ((ramp(b1, c1, lanes) % broadcast(c2, lanes)).Match(ret)) {
+    if (truncmod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) {
       int64_t c1val = c1.Eval()->value;
       int64_t c2val = c2.Eval()->value;
       if (c1val % c2val == 0) {
-        return broadcast(b1 % c2, lanes).Eval();
+        return broadcast(truncmod(b1, c2), lanes).Eval();
       }
       // If all possible indices in ramp are the same.
       if (CanProveGreaterEqual(b1.Eval(), 0)) {
@@ -696,9 +713,10 @@ Mutate_(const Mod* op, const Expr& self) {
         int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
         if (bmod->coeff % c2val == 0) {
           if (ramp_min == ramp_max) {
-            return ramp(bmod->base % c2, c1, lanes).Eval();
+            return ramp(truncmod(bmod->base, c2), c1, lanes).Eval();
           } else {
-            return (ramp(bmod->base % c2, c1, lanes) % broadcast(c2, lanes)).Eval();
+            return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes),
+                            broadcast(c2, lanes)).Eval();
           }
         }
       }
@@ -709,23 +727,23 @@ Mutate_(const Mod* op, const Expr& self) {
     // Be-aware of the division rules:
     // We adopt the default C division uses truncation instead of floordiv.
     // This means most rules need to check non-negativeness of the operands.
-    TVM_TRY_REWRITE_IF((x * c1) % c2, ZeroWithTypeLike(x),
+    TVM_TRY_REWRITE_IF(truncmod(x * c1, c2), ZeroWithTypeLike(x),
                        c2.Eval()->value != 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0);
 
-    TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2,
+    TVM_TRY_REWRITE_IF(truncmod(x * c1 + y, c2), truncmod(y, c2),
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual((x * c1).Eval(), 0) &&
                        CanProveGreaterEqual(y.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2,
+    TVM_TRY_REWRITE_IF(truncmod(x + c1, c2), truncmod(x, c2),
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value >= 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF((x + y * c1) % c2, x % c2,
+    TVM_TRY_REWRITE_IF(truncmod(x + y * c1, c2), truncmod(x, c2),
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value % c2.Eval()->value == 0 &&
                        CanProveGreaterEqual(x.Eval(), 0) &&
@@ -733,18 +751,18 @@ Mutate_(const Mod* op, const Expr& self) {
 
     // canonicalization: x % c == x % (-c) for truncated division
     // NOTE: trunc div required
-    TVM_TRY_RECURSIVE_REWRITE_IF(x % c1,
-                                 x % PConst<Expr>(make_const(op->type, -c1.Eval()->value)),
+    TVM_TRY_RECURSIVE_REWRITE_IF(truncmod(x, c1),
+                                 truncmod(x, PConst<Expr>(make_const(op->type, -c1.Eval()->value))),
                                  c1.Eval()->value < 0);
 
     // try modular analysis
-    if ((x % c1).Match(ret)) {
+    if (truncmod(x, c1).Match(ret)) {
       ModularSet mod = analyzer_->modular_set(x.Eval());
       int64_t c1val = c1.Eval()->value;
       if (mod->coeff % c1val == 0 &&
           c1val > 0 &&
           CanProveGreaterEqual(x.Eval(), 0)) {
-        return (mod->base % c1).Eval();
+        return truncmod(mod->base, c1).Eval();
       }
     }
   }
@@ -798,7 +816,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
       int64_t c2val = c2.Eval()->value;
       if (c1val > 0 && c2val > 0) {
         if (c1val % c2val == 0) return (x * floordiv(c1, c2)).Eval();
-        if (c2val % c1val == 0) return (floordiv(x, floordiv(c2, c1))).Eval();
+        if (c2val % c1val == 0) return floordiv(x, floordiv(c2, c1)).Eval();
       }
     }
 
@@ -1025,18 +1043,18 @@ Mutate_(const Min* op, const Expr& self) {
     // DivMod rules
     // Divide up rounding: truc div
     // NOTE: trucdiv(x, y) >= floordiv(x, y)
-    TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, x), x,
+    TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, x), x,
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value + 1 == c2.Eval()->value);
-    TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, max(x, c2)), max(x, c2),
+    TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value + 1 == c2.Eval()->value &&
                        CanProveGreaterEqual(x.Eval(), 0));
 
-    TVM_TRY_REWRITE_IF(min(x, ((x + c1) / c2) * c2), x,
+    TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x,
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value + 1 == c2.Eval()->value);
-    TVM_TRY_REWRITE_IF(min(max(x, c2), ((x + c1) / c2) * c2), max(x, c2),
+    TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2),
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value + 1 == c2.Eval()->value &&
                        CanProveGreaterEqual(x.Eval(), 0));
@@ -1104,11 +1122,11 @@ Mutate_(const Min* op, const Expr& self) {
     TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2)));
 
     // scaling rule
-    if (min(x / c1, y / c1).Match(ret)) {
+    if (min(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
       if (c1.Eval()->value > 0) {
-        return (min(x, y) / c1).Eval();
+        return truncdiv(min(x, y), c1).Eval();
       } else {
-        return (max(x, y) / c1).Eval();
+        return truncdiv(max(x, y), c1).Eval();
       }
     }
     if (min(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
@@ -1210,10 +1228,12 @@ Mutate_(const Max* op, const Expr& self) {
     // DivMod rules
     // Divide up rounding: truc div
     // NOTE: trucdiv(x, y) >= floordiv(x, y)
-    TVM_TRY_REWRITE_IF(max(((x + c1) / c2) * c2, x), ((x + c1) / c2) * c2,
+    TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x),
+                       truncdiv(x + c1, c2) * c2,
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value + 1 == c2.Eval()->value);
-    TVM_TRY_REWRITE_IF(max(x, ((x + c1) / c2) * c2), ((x + c1) / c2) * c2,
+    TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2),
+                       truncdiv(x + c1, c2) * c2,
                        c2.Eval()->value > 0 &&
                        c1.Eval()->value + 1 == c2.Eval()->value);
 
@@ -1276,11 +1296,11 @@ Mutate_(const Max* op, const Expr& self) {
     TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2)));
 
     // scaling rule
-    if (max(x / c1, y / c1).Match(ret)) {
+    if (max(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
       if (c1.Eval()->value > 0) {
-        return (max(x, y) / c1).Eval();
+        return truncdiv(max(x, y), c1).Eval();
       } else {
-        return (min(x, y) / c1).Eval();
+        return truncdiv(min(x, y), c1).Eval();
       }
     }
     if (max(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
@@ -1425,70 +1445,70 @@ Mutate_(const LT* op, const Expr& self) {
 
     // constant cancelation: only need to make use of one mod
     // truc div
-    TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1,
+    TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1 - 1, c2) + 1,
                        c1.Eval()->value > 0 &&
                        c2.Eval()->value > 0);
     // NOTE: trunc div required
-    TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2,
+    TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1, c2),
                        c1.Eval()->value <= 0 &&
                        c2.Eval()->value > 0);
     // NOTE: trunc div required (euclidean is ok too, floored is not)
-    TVM_TRY_REWRITE_IF(x * c2 < c1, (c1 - 1) / c2 - 1 < x,
+    TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x,
                        c1.Eval()->value > 0 &&
                        c2.Eval()->value < 0);
     // NOTE: trunc div required (floored is ok too, euclidean is not)
-    TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x,
+    TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1, c2) < x,
                        c1.Eval()->value <= 0 &&
                        c2.Eval()->value < 0);
     // NOTE: trunc div required
-    TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x,
+    TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1 + 1, c2) - 1 < x,
                        c1.Eval()->value < 0 &&
                        c2.Eval()->value > 0);
-    TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x,
+    TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1, c2) < x,
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0);
     // NOTE: trunc div required (floored is ok too, euclidean is not)
-    TVM_TRY_REWRITE_IF(c1 < x * c2, x < (c1 + 1) / c2 + 1,
+    TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1 + 1, c2) + 1,
                        c1.Eval()->value < 0 &&
                        c2.Eval()->value < 0);
     // NOTE: trunc div required (euclidean is ok too, floored is not)
-    TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2,
+    TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1, c2),
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value < 0);
     // DivMod rules
     // trucdiv
-    TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
+    TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * c2,
                        c1.Eval()->value > 0 &&
                        c2.Eval()->value > 0);
     // NOTE: trunc div required
-    TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * (c2 - 1) + 1,
+    TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * (c2 - 1) + 1,
                        c1.Eval()->value > 0 &&
                        c2.Eval()->value <= 0);
 
-    TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x,
+    TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), (c1 + 1) * c2 - 1 < x,
                        c1.Eval()->value >= 0 &&
                        c2.Eval()->value > 0);
     // NOTE: trunc div required
-    TVM_TRY_REWRITE_IF(c1 < x / c2, c1 * c2 < x,
+    TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), c1 * c2 < x,
                        c1.Eval()->value < 0 &&
                        c2.Eval()->value > 0);
 
     // invariance for any div mod: x - (x / c1) * c1 == x % c1
-    TVM_TRY_REWRITE_IF((x / c1) * c1 < x, 0 < x % c1,
+    TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1),
                        c1.Eval()->value > 0);
-    TVM_TRY_REWRITE_IF((x / c1) * c1 < x + y, 0 < x % c1 + y,
+    TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, 0 < truncmod(x, c1) + y,
                        c1.Eval()->value > 0);
-    TVM_TRY_REWRITE_IF((x / c1) * c1 < x - y, y < x % c1,
+    TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, y < truncmod(x, c1),
                        c1.Eval()->value > 0);
 
-    TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x,
-                       c2 < (x + c2) % c1,
+    TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x,
+                       c2 < truncmod(x + c2, c1),
                        c1.Eval()->value > 0);
-    TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x + y,
-                       c2 < (x + c2) % c1 + y,
+    TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x + y,
+                       c2 < truncmod(x + c2, c1) + y,
                        c1.Eval()->value > 0);
-    TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x - y,
-                       y < (x + c2) % c1 + (0 - c2),
+    TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x - y,
+                       y < truncmod(x + c2, c1) + (0 - c2),
                        c1.Eval()->value > 0);
 
     // floordiv
index d7a40c1..f66b997 100644 (file)
@@ -178,13 +178,19 @@ Expr operator*(Expr a, Expr b) {
   return ir::Mul::make(a, b);
 }
 
-Expr truncdiv(Expr a, Expr b) {
+Expr div(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
   Expr ret = arith::TryConstFold<ir::Div>(a, b);
   if (ret.defined()) return ret;
   return ir::Div::make(a, b);
 }
 
+Expr truncdiv(Expr a, Expr b) {
+  CHECK(a.type().is_int() || a.type().is_uint());
+  CHECK(b.type().is_int() || b.type().is_uint());
+  return div(a, b);
+}
+
 Expr truncmod(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
   Expr ret = arith::TryConstFold<ir::Mod>(a, b);
@@ -193,7 +199,7 @@ Expr truncmod(Expr a, Expr b) {
 }
 
 Expr operator/(Expr a, Expr b) {
-  return truncdiv(a, b);
+  return div(a, b);
 }
 
 Expr operator%(Expr a, Expr b) {
index 934ac62..7fb654b 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -47,9 +47,9 @@ TEST(Pattern, Basic) {
   }
   CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));
   CHECK((px + min(py, px)).Match(z + min(y, z)));
-  CHECK((px + py / (px * py)).Match(x + 2 / (x * 2)));
-  CHECK((px - py % (px * pz)).Match(x - 2 % (x * 2)));
-  CHECK((px - py % (px * PConst<Expr>(2))).Match(x - 2 % (x * 2)));
+  CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2)));
+  CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2)));
+  CHECK((px - floormod(py, px * PConst<Expr>(2))).Match(x - floormod(2, x * 2)));
 
   // logicals
   CHECK((px == pz).Match(x == 1));
index ca30354..246ac13 100644 (file)
@@ -56,24 +56,26 @@ def test_vector_simplify():
               tvm.expr.Ramp(x * 2, 8, 4))
 
     ## DivMod rules
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     # truc div
-    ck.verify(y.astype("int32x2") / x.astype("int32x2"),
-              (y / x).astype("int32x2"))
-    ck.verify(tvm.expr.Ramp(x, 4, 4) / 2,
-              tvm.expr.Ramp(x/ 2, 2, 4))
+    ck.verify(tdiv(y.astype("int32x2"), x.astype("int32x2")),
+              tdiv(y, x).astype("int32x2"))
+    ck.verify(tdiv(tvm.expr.Ramp(x, 4, 4), 2),
+              tvm.expr.Ramp(tdiv(x, 2), 2, 4))
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
-    ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) / 8,
+    ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
               (x).astype("int32x4"))
-    ck.verify(tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8,
-              tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8)
-    ck.verify(y.astype("int32x2") % x.astype("int32x2"),
-              (y % x).astype("int32x2"))
-    ck.verify(tvm.expr.Ramp(x, 4, 4) % 2,
-              tvm.expr.Broadcast(x % 2, 4))
-    ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) % 8,
+    ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8),
+              tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8))
+    ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")),
+              tmod(y, x).astype("int32x2"))
+    ck.verify(tmod(tvm.expr.Ramp(x, 4, 4), 2),
+              tvm.expr.Broadcast(tmod(x, 2), 4))
+    ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
               tvm.expr.Ramp(1, 1, 4))
-    ck.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8,
-              tvm.expr.Ramp(1, 15, 4) % 8)
+    ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 15, 4), 8),
+              tmod(tvm.expr.Ramp(1, 15, 4), 8))
 
     # floor div
     fld = tvm.floordiv
@@ -187,10 +189,12 @@ def test_add_index_simplify():
     ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9);
 
     # DivMod rules
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     # truc div
-    ck.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10))
+    ck.verify(y * tmod(x, 8) + 10 * tmod(x, 8), tmod(x, 8) * (y + 10))
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
-    ck.verify((x / 8) * 8 + x % 8, x)
+    ck.verify(tdiv(x, 8) * 8 + tmod(x, 8), x)
 
     # floor div
     fld = tvm.floordiv
@@ -256,31 +260,33 @@ def test_sub_index_simplify():
 
     # DivMod patterns
     # truc div
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
-    ck.verify(x - (x / 3) * 3, x % 3)
-
-    ck.verify((x + 5) / 3 - x / 3, ((x % 3) + 5)/ 3)
-    ck.verify((x + 5) / 3 - (x + 1) / 3, (((x + 1) % 3) + 4)/ 3)
-
-    ck.verify(y - (y / (-5)) * (-5), y % 5)
-    ck.verify((y / 3) * 3 - y, 0 - y % 3)
-    ck.verify(y - ((y - 6) / 5) * 5, (y + (-6)) % 5 + 6)
-    ck.verify(((y - 6) / 5) * 5 - y, (-6) - (y + (-6)) % 5)
-    ck.verify(y - ((y + z) / 5) * 5, (y + z) % 5 - z)
-    ck.verify(((y + z) / 5) * 5 - y, z - (y + z) % 5)
-    ck.verify(y - ((y - z) / 5) * 5, (y - z) % 5 + z)
-    ck.verify(((y - z) / 5) * 5 - y, 0 - (y - z) % 5 - z)
-
-    ck.verify(y * 3 - (y / 2) * 6, (y % 2) * 3)
-    ck.verify((y / 3) * 6 - y * 2, (y % 3) * (-2))
-    ck.verify(y * 5 - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5)
-    ck.verify(y * 5 - ((y - z) / 2) * 10, ((y - z) % 2 + z) * 5)
-    ck.verify(((y + z) / 3) * 6 - y * 2, (z - (y + z) % 3) * 2)
-    ck.verify(((y - z) / 3) * 6 - y * 2, (0 - (y - z) % 3 - z) * 2)
-    ck.verify(5 * y - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5)
-    ck.verify(5 * y - 10 * ((y - z) / 2), ((y - z) % 2 + z) * 5)
-    ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2)
-    ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2)
+    ck.verify(x - tdiv(x, 3) * 3, tmod(x, 3))
+
+    ck.verify(tdiv(x + 5, 3) - tdiv(x, 3), tdiv(tmod(x, 3) + 5, 3))
+    ck.verify(tdiv(x + 5, 3) - tdiv(x + 1, 3), tdiv(tmod(x + 1, 3) + 4, 3))
+
+    ck.verify(y - tdiv(y, (-5)) * (-5), tmod(y, 5))
+    ck.verify(tdiv(y, 3) * 3 - y, 0 - tmod(y, 3))
+    ck.verify(y - tdiv(y - 6, 5) * 5, tmod(y + (-6), 5) + 6)
+    ck.verify(tdiv(y - 6, 5) * 5 - y, (-6) - tmod(y + (-6), 5))
+    ck.verify(y - tdiv(y + z, 5) * 5, tmod(y + z, 5) - z)
+    ck.verify(tdiv(y + z, 5) * 5 - y, z - tmod(y + z, 5))
+    ck.verify(y - tdiv(y - z, 5) * 5, tmod(y - z, 5) + z)
+    ck.verify(tdiv(y - z, 5) * 5 - y, 0 - tmod(y - z, 5) - z)
+
+    ck.verify(y * 3 - tdiv(y, 2) * 6, tmod(y, 2) * 3)
+    ck.verify(tdiv(y, 3) * 6 - y * 2, tmod(y, 3) * (-2))
+    ck.verify(y * 5 - tdiv(y + z, 2) * 10, (tmod(y + z, 2) - z) * 5)
+    ck.verify(y * 5 - tdiv(y - z, 2) * 10, (tmod(y - z, 2) + z) * 5)
+    ck.verify(tdiv(y + z, 3) * 6 - y * 2, (z - tmod(y + z, 3)) * 2)
+    ck.verify(tdiv(y - z, 3) * 6 - y * 2, (0 - tmod(y - z, 3) - z) * 2)
+    ck.verify(5 * y - tdiv(y + z, 2) * 10, (tmod(y + z, 2) - z) * 5)
+    ck.verify(5 * y - 10 * tdiv(y - z, 2), (tmod(y - z, 2) + z) * 5)
+    ck.verify(6 * tdiv(y + z, 3) - y * 2, (z - tmod(y + z, 3)) * 2)
+    ck.verify(tdiv(y - z, 3) * 6 - 2 * y, (0 - tmod(y - z, 3) - z) * 2)
 
     # floor div
     fld = tvm.floordiv
@@ -323,46 +329,48 @@ def test_mul_index_simplify():
 def test_div_index_simplify():
     ck = RewriteChecker()
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
 
-    ck.verify(x / x, 1)
+    ck.verify(tdiv(x, x), 1)
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
     ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)
 
-    ck.verify(x / 2 / 3, x / 6)
-    ck.verify((x / 2 + 1) / 3, (x + 2) / 6)
-    ck.verify(x * 2 / 4, x / 2)
-    ck.verify(x * 4 / 2, x * 2)
+    ck.verify(tdiv(tdiv(x, 2), 3), tdiv(x, 6))
+    ck.verify(tdiv(tdiv(x, 2) + 1, 3), tdiv(x + 2, 6))
+    ck.verify(tdiv(x * 2, 4), tdiv(x, 2))
+    ck.verify(tdiv(x * 4, 2), x * 2)
 
-    ck.verify((x * 4 + y) / 2, x * 2 + y / 2)
-    ck.verify(tvm.min(x * 6, y) / 2, tvm.min(x * 3, y / 2))
-    ck.verify(tvm.max(x * 6, y) / 2, tvm.max(x * 3, y / 2))
+    ck.verify(tdiv(x * 4 + y, 2), x * 2 + tdiv(y, 2))
+    ck.verify(tdiv(tvm.min(x * 6, y), 2), tvm.min(x * 3, tdiv(y, 2)))
+    ck.verify(tdiv(tvm.max(x * 6, y), 2), tvm.max(x * 3, tdiv(y, 2)))
 
-    ck.verify((y + x * 4) / 2, y / 2 + x * 2)
-    ck.verify(tvm.min(y, x * 6) / 2, tvm.min(y / 2, x * 3))
-    ck.verify(tvm.max(y, x * 6) / 2, tvm.max(y / 2, x * 3))
+    ck.verify(tdiv(y + x * 4, 2), tdiv(y, 2) + x * 2)
+    ck.verify(tdiv(tvm.min(y, x * 6), 2), tvm.min(tdiv(y, 2), x * 3))
+    ck.verify(tdiv(tvm.max(y, x * 6), 2), tvm.max(tdiv(y, 2), x * 3))
 
     # 3-operands
-    ck.verify((x * 6 + y + z) / 2, x * 3 + (y + z) / 2)
-    ck.verify((x * 6 - y + (y + 3)) / 2, x * 3 + 1)
-    ck.verify((x * 6 + (y + 3) - y) / 2, x * 3 + 1)
-    ck.verify((y + x * 6 + z) / 2, x * 3 + (y + z) / 2)
-    ck.verify((x + 4) / 2, x / 2 + 2)
+    ck.verify(tdiv(x * 6 + y + z, 2), x * 3 + tdiv(y + z, 2))
+    ck.verify(tdiv(x * 6 - y + (y + 3), 2), x * 3 + 1)
+    ck.verify(tdiv(x * 6 + (y + 3) - y, 2), x * 3 + 1)
+    ck.verify(tdiv(y + x * 6 + z, 2), x * 3 + tdiv(y + z, 2))
+    ck.verify(tdiv(x + 4, 2), tdiv(x, 2) + 2)
 
-    ck.verify((x + y) / x, y / x + 1)
-    ck.verify((y + x) / x, y / x + 1)
-    ck.verify(((x + y) + z) / x, (y + z) / x + 1)
-    ck.verify(((y + x) + z) / x, (y + z) / x + 1)
-    ck.verify((y + (x + z)) / x, (y + z) / x + 1)
-    ck.verify((y + (z + x)) / x, (y + z) / x + 1)
+    ck.verify(tdiv(x + y, x), tdiv(y, x) + 1)
+    ck.verify(tdiv(y + x, x), tdiv(y, x) + 1)
+    ck.verify(tdiv((x + y) + z, x), tdiv(y + z, x) + 1)
+    ck.verify(tdiv((y + x) + z, x), tdiv(y + z, x) + 1)
+    ck.verify(tdiv(y + (x + z), x), tdiv(y + z, x) + 1)
+    ck.verify(tdiv(y + (z + x), x), tdiv(y + z, x) + 1)
 
-    ck.verify((x * y) / y, x)
-    ck.verify((y * x) / y, x)
+    ck.verify(tdiv(x * y, y), x)
+    ck.verify(tdiv(y * x, y), x)
 
-    ck.verify((x * z + y) / z, x + y / z)
-    ck.verify((z * x + y) / z, x + y / z)
-    ck.verify((y + x * z) / z, y / z + x)
-    ck.verify((y + z * x) / z, y / z + x)
+    ck.verify(tdiv(x * z + y, z), x + tdiv(y, z))
+    ck.verify(tdiv(z * x + y, z), x + tdiv(y, z))
+    ck.verify(tdiv(y + x * z, z), tdiv(y, z) + x)
+    ck.verify(tdiv(y + z * x, z), tdiv(y, z) + x)
 
 
 def test_floordiv_index_simplify():
@@ -417,31 +425,33 @@ def test_mod_index_simplify():
     ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
     ck.analyzer.update(nx, tvm.arith.ConstIntBound(-1000, 0), override=True)
     ck.analyzer.update(ny, tvm.arith.ConstIntBound(-1000, 0), override=True)
-
-    ck.verify(x * 10 % 2, 0)
-    ck.verify((x * 10 + y) % 2, y % 2)
-    ck.verify((x + 10) % 2, x % 2)
-    ck.verify((x + y * 10) % 2, x % 2)
-    ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1)
-    ck.verify(x * 10 % -2, 0)
-    ck.verify((x * 10 + y) % -2, y % 2)
-    ck.verify((x + 10) % -2, x % 2)
-    ck.verify((x + y * 10) % -2, x % 2)
-    ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1)
-
-    ck.verify(x * (-10) % 2, 0)
-    ck.verify((x * (-10) + y) % 2, (x * (-10) + y) % 2)
-    ck.verify((x + (-10)) % 2, (x + (-10)) % 2)
-    ck.verify((x + y * (-10)) % 2, (x + y * (-10)) % 2)
-    ck.verify(x * (-10) % -2, 0)
-
-    ck.verify(nx * 10 % 2, 0)
-    ck.verify((nx * (-10) + y) % 2, y % 2)
-    ck.verify((x + ny * (-10)) % 2, x % 2)
-    ck.verify((nx * (-10) + 1 + ny * (-2) + 2) % 2, 1)
-    ck.verify(nx * 10 % -2, 0)
-    ck.verify((nx * (-10) + y) % -2, y % 2)
-    ck.verify((x + ny * (-10)) % -2, x % 2)
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
+
+    ck.verify(tmod(x * 10, 2), 0)
+    ck.verify(tmod(x * 10 + y, 2), tmod(y, 2))
+    ck.verify(tmod(x + 10, 2), tmod(x, 2))
+    ck.verify(tmod(x + y * 10, 2), tmod(x, 2))
+    ck.verify(tmod(x* 10 + 1 + y * 2 + 2, 2), 1)
+    ck.verify(tmod(x * 10, -2), 0)
+    ck.verify(tmod(x * 10 + y, -2), tmod(y, 2))
+    ck.verify(tmod(x + 10, -2), tmod(x, 2))
+    ck.verify(tmod(x + y * 10, -2), tmod(x, 2))
+    ck.verify(tmod(x* 10 + 1 + y * 2 + 2, -2), 1)
+
+    ck.verify(tmod(x * (-10), 2), 0)
+    ck.verify(tmod(x * (-10) + y, 2), tmod(x * (-10) + y, 2))
+    ck.verify(tmod(x + (-10), 2), tmod(x + (-10), 2))
+    ck.verify(tmod(x + y * (-10), 2), tmod(x + y * (-10), 2))
+    ck.verify(tmod(x * (-10), -2), 0)
+
+    ck.verify(tmod(nx * 10, 2), 0)
+    ck.verify(tmod(nx * (-10) + y, 2), tmod(y, 2))
+    ck.verify(tmod(x + ny * (-10), 2), tmod(x, 2))
+    ck.verify(tmod(nx * (-10) + 1 + ny * (-2) + 2, 2), 1)
+    ck.verify(tmod(nx * 10, -2), 0)
+    ck.verify(tmod(nx * (-10) + y, -2), tmod(y, 2))
+    ck.verify(tmod(x + ny * (-10), -2), tmod(x, 2))
 
 
 def test_floormod_index_simplify():
@@ -468,8 +478,10 @@ def test_min_index_simplify():
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
     fld = tvm.floordiv
     flm = tvm.floormod
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     # const int bound
-    ck.verify(tvm.min(x % 2, y % 2 + 10), x % 2)
+    ck.verify(tvm.min(tmod(x, 2), tmod(y, 2) + 10), tmod(x, 2))
     ck.verify(tvm.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2))
 
     ck.verify(tvm.min(x + 1, x + 10), x + 1)
@@ -521,13 +533,14 @@ def test_min_index_simplify():
     # DivMod rules
     # truc div
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
-    ck.verify(tvm.min((x + 3) / 4 * 4, x), x)
-    ck.verify(tvm.min((x + 3) / 4 * 4, tvm.max(x, 4)), tvm.max(x, 4))
-    ck.verify(tvm.min(x, (x + 3) / 4 * 4), x)
-    ck.verify(tvm.min(tvm.max(x, 4), (x + 3) / 4 * 4), tvm.max(x, 4))
+    ck.verify(tvm.min(tdiv(x + 3, 4) * 4, x), x)
+    ck.verify(tvm.min(tdiv(x + 3, 4) * 4, tvm.max(x, 4)), tvm.max(x, 4))
+    ck.verify(tvm.min(x, tdiv(x + 3, 4) * 4), x)
+    ck.verify(tvm.min(tvm.max(x, 4), tdiv(x + 3, 4) * 4), tvm.max(x, 4))
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
-    ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10)
-    ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10))
+    ck.verify(tvm.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.min(x, y), 10))
+    ck.verify(tvm.min(tdiv(x, (-10)), tdiv(y, (-10))),
+              tdiv(tvm.max(x, y), (-10)))
 
     # floor div
     ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
@@ -545,8 +558,10 @@ def test_max_index_simplify():
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
     flm = tvm.floormod
     fld = tvm.floordiv
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     # const int bound
-    ck.verify(tvm.max(x % 2, y % 2 + 10), y % 2 + 10)
+    ck.verify(tvm.max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10)
     ck.verify(tvm.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10)
 
     ck.verify(tvm.max(x + 1, x + 10), x + 10)
@@ -597,9 +612,9 @@ def test_max_index_simplify():
 
     # DivMod rules
     # truc div
-    ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10)
-    ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10))
-    ck.verify(tvm.max((x + 3) / 4 * 4, x), (x + 3) / 4 * 4)
+    ck.verify(tvm.max(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.max(x, y), 10))
+    ck.verify(tvm.max(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.min(x, y), (-10)))
+    ck.verify(tvm.max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4)
 
     # floordiv
     ck.verify(tvm.max(fld(x, 10), fld(y, 10)), fld(tvm.max(x, y), 10))
@@ -614,11 +629,13 @@ def test_cmp_simplify():
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
     flm = tvm.floormod
     fld = tvm.floordiv
+    tdiv = tvm.truncdiv
+    tmod = tvm.truncmod
     # const int bound
-    ck.verify((x % 2 + 10).equal(0), tvm.const(0, "bool"))
-    ck.verify(tvm.expr.NE(x % 2 + 10, 0), tvm.const(1, "bool"))
-    ck.verify(x % 2 + 10 > 1, tvm.const(1, "bool"))
-    ck.verify(x % 2 + 10 <= 1, tvm.const(0, "bool"))
+    ck.verify((tmod(x, 2) + 10).equal(0), tvm.const(0, "bool"))
+    ck.verify(tvm.expr.NE(tmod(x, 2) + 10, 0), tvm.const(1, "bool"))
+    ck.verify(tmod(x, 2) + 10 > 1, tvm.const(1, "bool"))
+    ck.verify(tmod(x, 2) + 10 <= 1, tvm.const(0, "bool"))
     ck.verify(flm(x, 2) + 2 > 1, tvm.const(1, "bool"))
     ck.verify(flm(x, 2) + 10 <= 1, tvm.const(0, "bool"))
 
@@ -688,26 +705,26 @@ def test_cmp_simplify():
 
     # DivMod rules
     # truc div
-    ck.verify(x / 2 < 3, x < 6)
-    ck.verify(3 < x / 2, tvm.expr.LT(7, x))
-    ck.verify(x / 3 >= 0, tvm.expr.LE(-2, x))
-    ck.verify(x / 2 >= 1, tvm.expr.LE(2, x))
-    ck.verify(x / 2 >= 0, tvm.expr.LE(-1, x))
-    ck.verify(x / 2 >= -1, tvm.expr.LE(-3, x))
+    ck.verify(tdiv(x, 2) < 3, x < 6)
+    ck.verify(3 < tdiv(x, 2), tvm.expr.LT(7, x))
+    ck.verify(tdiv(x, 3) >= 0, tvm.expr.LE(-2, x))
+    ck.verify(tdiv(x, 2) >= 1, tvm.expr.LE(2, x))
+    ck.verify(tdiv(x, 2) >= 0, tvm.expr.LE(-1, x))
+    ck.verify(tdiv(x, 2) >= -1, tvm.expr.LE(-3, x))
 
-    ck.verify(x / 2 <= 1, tvm.expr.LE(x, 3))
-    ck.verify(x / 2 <= 0, tvm.expr.LE(x, 1))
-    ck.verify(x / 2 <= -1, tvm.expr.LE(x, -2))
+    ck.verify(tdiv(x, 2) <= 1, tvm.expr.LE(x, 3))
+    ck.verify(tdiv(x, 2) <= 0, tvm.expr.LE(x, 1))
+    ck.verify(tdiv(x, 2) <= -1, tvm.expr.LE(x, -2))
 
-    ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4))
-    ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0))
+    ck.verify(tdiv(x, 4) * 4 < x, tvm.expr.LT(0, tmod(x, 4)))
+    ck.verify(tdiv(x, 4) * 4 >= x, tvm.expr.LE(tmod(x, 4), 0))
 
-    ck.verify(x / 4 * 4 < x + y, tvm.expr.LT(0, x % 4 + y))
-    ck.verify(x / 4 * 4 < x - y, tvm.expr.LT(y, x % 4))
+    ck.verify(tdiv(x, 4) * 4 < x + y, tvm.expr.LT(0, tmod(x, 4) + y))
+    ck.verify(tdiv(x, 4) * 4 < x - y, tvm.expr.LT(y, tmod(x, 4)))
 
-    ck.verify((x + 2) / 4 * 4 >= x, tvm.expr.LE((x + 2) % 4, 2))
-    ck.verify((x + 2) / 4 * 4 >= x + y, tvm.expr.LE((x + 2) % 4 + y, 2))
-    ck.verify((x + 2) / 4 * 4 >= x - y, tvm.expr.LE((x + 2) % 4 + (-2), y))
+    ck.verify(tdiv(x + 2, 4) * 4 >= x, tvm.expr.LE(tmod(x + 2, 4), 2))
+    ck.verify(tdiv(x + 2, 4) * 4 >= x + y, tvm.expr.LE(tmod(x + 2, 4) + y, 2))
+    ck.verify(tdiv(x + 2, 4) * 4 >= x - y, tvm.expr.LE(tmod(x + 2, 4) + (-2), y))
 
     # floor div
     ck.verify(fld(x, 2) < 3, x < 6)
@@ -753,7 +770,7 @@ def test_cmp_simplify():
     ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool"))
     ck.verify(y*y >= 0, tvm.const(1, "bool"))
     ck.verify(x*6 <= -3, tvm.const(0, "bool"))
-    ck.verify((y - 1) % 3 == 0, (y + (-1)) % 3 == 0)
+    ck.verify(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0)
 
 
 def test_logical_simplify():