Expr Mutate_(const Mul* op, const Expr& self) final;
Expr Mutate_(const Div* op, const Expr& self) final;
Expr Mutate_(const Mod* op, const Expr& self) final;
+ Expr Mutate_(const Min* op, const Expr& self) final;
+ Expr Mutate_(const Max* op, const Expr& self) final;
+ Expr Mutate_(const EQ* op, const Expr& self) final;
+ Expr Mutate_(const NE* op, const Expr& self) final;
+ Expr Mutate_(const LT* op, const Expr& self) final;
+ Expr Mutate_(const LE* op, const Expr& self) final;
+ Expr Mutate_(const GT* op, const Expr& self) final;
+ Expr Mutate_(const GE* op, const Expr& self) final;
+ Expr Mutate_(const And* op, const Expr& self) final;
+ Expr Mutate_(const Or* op, const Expr& self) final;
+ Expr Mutate_(const Not* op, const Expr& self) final;
+ Expr Mutate_(const Select* op, const Expr& self) final;
+ Expr Mutate_(const Ramp* op, const Expr& self) final;
private:
+ /*! \brief internal structure for comparison. */
+ enum CompareResult {
+ kUnknown,
+ kEQ,
+ kGT,
+ kLT,
+ kNE
+ };
// reference to the main analyzer
Analyzer* parent_;
// counter to record recursive rewrite depth.
// Whether x == val
bool CanProveEqual(const Expr& x, int64_t val) {
// TODO(tqchen) refer back to super-analyzer.
- Expr res = Mutate(x);
- if (const auto* ptr = res.as<ir::IntImm>()) {
- return ptr->value == val;
+ return TryCompare(x, val) == kEQ;
+ }
+ // try to prove x equals val
+ CompareResult TryCompare(const Expr& x, int64_t val) {
+ Expr diff = Mutate(x);
+ if (const auto* ptr = diff.as<IntImm>()) {
+ if (ptr->value == val) {
+ return kEQ;
+ } else if (ptr->value > val) {
+ return kGT;
+ } else if (ptr->value < val) {
+ return kLT;
+ }
+ }
+ if (val == 0) {
+ ModularSet dmod = parent_->modular_set(diff);
+ if (dmod->base != 0) {
+ return kNE;
+ }
+ }
+ ConstIntBound dbound = parent_->const_int_bound(diff);
+ if (dbound->min_value > val) {
+ return kGT;
+ }
+ if (dbound->max_value < val) {
+ return kLT;
}
- return false;
+ return kUnknown;
}
+
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
// Pattern var to match any expression
PVar<Expr> x, y, z, b1;
// Pattern var match IntImm
- PVar<Integer> c1, c2, c3;
+ PVar<Integer> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
return ret;
}
+Expr RewriteSimplifier::Impl::
+Mutate_(const Min* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<Min>();
+ Expr const_res = TryConstFold<Min>(op->a, op->b);
+ if (const_res.defined()) return const_res;
+
+ // Pattern var to match any expression
+ PVar<Expr> x, y, z, s1, s2;
+ // Pattern var match IntImm
+ PVar<Integer> c1, c2;
+ PVar<int> lanes;
+
+ // vector rule
+ if (op->type.lanes() != 1) {
+ TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)),
+ broadcast(min(x, y), lanes));
+ TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)),
+ min(x, broadcast(min(y, z), lanes)));
+ }
+ if (IsIndexType(op->type)) {
+ TVM_TRY_REWRITE(min(x, x), x);
+
+ // constant int bound
+ ConstIntBound a_bound = parent_->const_int_bound(op->a);
+ ConstIntBound b_bound = parent_->const_int_bound(op->b);
+ if (a_bound->max_value <= b_bound->min_value) {
+ return op->a;
+ }
+ if (b_bound->max_value <= a_bound->min_value) {
+ return op->b;
+ }
+
+ // constant comparison
+ if (min(x + c1, x + c2).Match(ret)) {
+ if (c1.Eval()->value < c2.Eval()->value) {
+ return (x + c1).Eval();
+ } else {
+ return (x + c2).Eval();
+ }
+ }
+ if (min(x + c1, x).Match(ret) ||
+ min(x, x + c1).Match(ret)) {
+ if (c1.Eval()->value < 0) {
+ return (x + c1).Eval();
+ } else {
+ return x.Eval();
+ }
+ }
+ if (min(c1 - x, c2 - x).Match(ret)) {
+ if (c1.Eval()->value < c2.Eval()->value) {
+ return (c1 - x).Eval();
+ } else {
+ return (c2 - x).Eval();
+ }
+ }
+
+ // Divide up rounding
+ TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, x), x,
+ c2.Eval()->value > 0 &&
+ c1.Eval()->value + 1 == c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, max(x, c2)), max(x, c2),
+ c2.Eval()->value > 0 &&
+ c1.Eval()->value + 1 == c2.Eval()->value &&
+ CanProveGreaterEqual(x.Eval(), 0));
+
+ TVM_TRY_REWRITE_IF(min(x, ((x + c1) / c2) * c2), x,
+ c2.Eval()->value > 0 &&
+ c1.Eval()->value + 1 == c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(min(max(x, c2), ((x + c1) / c2) * c2), max(x, c2),
+ c2.Eval()->value > 0 &&
+ c1.Eval()->value + 1 == c2.Eval()->value &&
+ CanProveGreaterEqual(x.Eval(), 0));
+
+ TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y));
+ TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y));
+ TVM_TRY_REWRITE(min(min(x, y), max(x, y)), min(x, y));
+ TVM_TRY_REWRITE(min(min(x, y), max(y, x)), min(x, y));
+
+ TVM_TRY_REWRITE(min(max(x, y), x), x);
+ TVM_TRY_REWRITE(min(max(x, y), y), y);
+ TVM_TRY_REWRITE(min(min(x, y), x), min(x, y));
+ TVM_TRY_REWRITE(min(min(x, y), y), min(x, y));
+
+ TVM_TRY_REWRITE(min(x, max(x, y)), x);
+ TVM_TRY_REWRITE(min(y, max(x, y)), y);
+ TVM_TRY_REWRITE(min(x, min(x, y)), min(x, y));
+ TVM_TRY_REWRITE(min(y, min(x, y)), min(x, y));
+
+ TVM_TRY_REWRITE(min(min(min(x, y), z), y), min(min(x, y), z));
+ TVM_TRY_REWRITE(min(min(min(min(x, y), z), s1), y), min(min(min(x, y), z), s1));
+ TVM_TRY_REWRITE(min(min(min(min(min(x, y), z), s1), s2), y),
+ min(min(min(min(x, y), z), s1), s2));
+
+ TVM_TRY_REWRITE(min(max(x, y), max(x, z)), max(min(y, z), x));
+ TVM_TRY_REWRITE(min(max(x, y), max(z, x)), max(min(y, z), x));
+ TVM_TRY_REWRITE(min(max(y, x), max(x, z)), max(min(y, z), x));
+ TVM_TRY_REWRITE(min(max(y, x), max(z, x)), max(min(y, z), x));
+
+ TVM_TRY_REWRITE(min(min(x, y), min(x, z)), min(min(y, z), x));
+ TVM_TRY_REWRITE(min(min(x, y), min(z, x)), min(min(y, z), x));
+ TVM_TRY_REWRITE(min(min(y, x), min(x, z)), min(min(y, z), x));
+ TVM_TRY_REWRITE(min(min(y, x), min(z, x)), min(min(y, z), x));
+
+ TVM_TRY_REWRITE(min(y + x, z + x), min(y, z) + x);
+ TVM_TRY_REWRITE(min(y + x, x + z), min(y, z) + x);
+ TVM_TRY_REWRITE(min(x + y, x + z), min(y, z) + x);
+ TVM_TRY_REWRITE(min(x + y, z + x), min(y, z) + x);
+
+ // sub distribution
+ TVM_TRY_REWRITE(min(y - x, z - x), min(y, z) - x);
+ TVM_TRY_REWRITE(min(x - y, x - z), x - max(y, z));
+
+ // constant folding rule.
+ TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2)));
+
+ // scaling rule
+ if (min(x / c1, y / c1).Match(ret)) {
+ if (c1.Eval()->value > 0) {
+ return (min(x, y) / c1).Eval();
+ } else {
+ return (max(x, y) / c1).Eval();
+ }
+ }
+ if (min(x * c1, y * c1).Match(ret)) {
+ if (c1.Eval()->value > 0) {
+ return (min(x, y) * c1).Eval();
+ } else {
+ return (max(x, y) * c1).Eval();
+ }
+ }
+ if (min(x * c1, c2).Match(ret)) {
+ int64_t c1val = c1.Eval()->value;
+ int64_t c2val = c2.Eval()->value;
+ if (c2val % c1val == 0) {
+ if (c2val / c1val >= 0) {
+ return (min(x, c2val / c1val) * c1val).Eval();
+ } else {
+ return (max(x, c2val / c1val) * c1val).Eval();
+ }
+ }
+ }
+
+ // canonicalization
+ TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1));
+ TVM_TRY_RECURSIVE_REWRITE(min(c1 - x, c2), c1 - max(x, c2 - c1));
+ }
+
+ // condition rules.
+ TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)),
+ select(x, min(y, s1), min(z, s2)));
+ return ret;
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const Max* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<Max>();
+ Expr const_res = TryConstFold<Max>(op->a, op->b);
+ if (const_res.defined()) return const_res;
+
+ // Pattern var to match any expression
+ PVar<Expr> x, y, z, s1, s2;
+ // Pattern var match IntImm
+ PVar<Integer> c1, c2;
+ PVar<int> lanes;
+
+ // vector rule
+ if (op->type.lanes() != 1) {
+ TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)),
+ broadcast(max(x, y), lanes));
+ TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)),
+ max(x, broadcast(max(y, z), lanes)));
+ }
+ if (IsIndexType(op->type)) {
+ TVM_TRY_REWRITE(max(x, x), x);
+
+ // constant int bound
+ ConstIntBound a_bound = parent_->const_int_bound(op->a);
+ ConstIntBound b_bound = parent_->const_int_bound(op->b);
+ if (a_bound->min_value >= b_bound->max_value) {
+ return op->a;
+ }
+ if (b_bound->min_value >= a_bound->max_value) {
+ return op->b;
+ }
+
+ // constant comparison
+ if (max(x + c1, x + c2).Match(ret)) {
+ if (c1.Eval()->value > c2.Eval()->value) {
+ return (x + c1).Eval();
+ } else {
+ return (x + c2).Eval();
+ }
+ }
+ if (max(x + c1, x).Match(ret) ||
+ max(x, x + c1).Match(ret)) {
+ if (c1.Eval()->value > 0) {
+ return (x + c1).Eval();
+ } else {
+ return x.Eval();
+ }
+ }
+ if (max(c1 - x, c2 - x).Match(ret)) {
+ if (c1.Eval()->value > c2.Eval()->value) {
+ return (c1 - x).Eval();
+ } else {
+ return (c2 - x).Eval();
+ }
+ }
+
+ // Divide up rounding
+ TVM_TRY_REWRITE_IF(max(((x + c1) / c2) * c2, x), ((x + c1) / c2) * c2,
+ c2.Eval()->value > 0 &&
+ c1.Eval()->value + 1 == c2.Eval()->value);
+ TVM_TRY_REWRITE_IF(max(x, ((x + c1) / c2) * c2), ((x + c1) / c2) * c2,
+ c2.Eval()->value > 0 &&
+ c1.Eval()->value + 1 == c2.Eval()->value);
+
+ TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y));
+ TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y));
+ TVM_TRY_REWRITE(max(max(x, y), min(x, y)), max(x, y));
+ TVM_TRY_REWRITE(max(max(x, y), min(y, x)), max(x, y));
+
+ TVM_TRY_REWRITE(max(min(x, y), x), x);
+ TVM_TRY_REWRITE(max(min(x, y), y), y);
+ TVM_TRY_REWRITE(max(max(x, y), x), max(x, y));
+ TVM_TRY_REWRITE(max(max(x, y), y), max(x, y));
+
+ TVM_TRY_REWRITE(max(x, min(x, y)), x);
+ TVM_TRY_REWRITE(max(y, min(x, y)), y);
+ TVM_TRY_REWRITE(max(x, max(x, y)), max(x, y));
+ TVM_TRY_REWRITE(max(y, max(x, y)), max(x, y));
+
+ TVM_TRY_REWRITE(max(max(max(x, y), z), y), max(max(x, y), z));
+ TVM_TRY_REWRITE(max(max(max(max(x, y), z), s1), y), max(max(max(x, y), z), s1));
+ TVM_TRY_REWRITE(max(max(max(max(max(x, y), z), s1), s2), y),
+ max(max(max(max(x, y), z), s1), s2));
+
+ // max/max cancelation
+ TVM_TRY_REWRITE(max(max(x, y), max(x, z)), max(max(y, z), x));
+ TVM_TRY_REWRITE(max(max(x, y), max(z, x)), max(max(y, z), x));
+ TVM_TRY_REWRITE(max(max(y, x), max(x, z)), max(max(y, z), x));
+ TVM_TRY_REWRITE(max(max(y, x), max(z, x)), max(max(y, z), x));
+
+ // max/min distribution
+ TVM_TRY_REWRITE(max(min(x, y), min(x, z)), min(max(y, z), x));
+ TVM_TRY_REWRITE(max(min(x, y), min(z, x)), min(max(y, z), x));
+ TVM_TRY_REWRITE(max(min(y, x), min(x, z)), min(max(y, z), x));
+ TVM_TRY_REWRITE(max(min(y, x), min(z, x)), min(max(y, z), x));
+
+ // add distribution
+ TVM_TRY_REWRITE(max(y + x, z + x), max(y, z) + x);
+ TVM_TRY_REWRITE(max(y + x, x + z), max(y, z) + x);
+ TVM_TRY_REWRITE(max(x + y, x + z), max(y, z) + x);
+ TVM_TRY_REWRITE(max(x + y, z + x), max(y, z) + x);
+
+ // sub distribution
+ TVM_TRY_REWRITE(max(y - x, z - x), max(y, z) - x);
+ TVM_TRY_REWRITE(max(x - y, x - z), x - min(y, z));
+
+ // constant folding rule.
+ TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2)));
+
+ // scaling rule
+ if (max(x / c1, y / c1).Match(ret)) {
+ if (c1.Eval()->value > 0) {
+ return (max(x, y) / c1).Eval();
+ } else {
+ return (min(x, y) / c1).Eval();
+ }
+ }
+ if (max(x * c1, y * c1).Match(ret)) {
+ if (c1.Eval()->value > 0) {
+ return (max(x, y) * c1).Eval();
+ } else {
+ return (min(x, y) * c1).Eval();
+ }
+ }
+ if (max(x * c1, c2).Match(ret)) {
+ int64_t c1val = c1.Eval()->value;
+ int64_t c2val = c2.Eval()->value;
+ if (c2val % c1val == 0) {
+ if (c2val / c1val >= 0) {
+ return (max(x, c2val / c1val) * c1val).Eval();
+ } else {
+ return (min(x, c2val / c1val) * c1val).Eval();
+ }
+ }
+ }
+
+ // canonicalization
+ TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1));
+ TVM_TRY_RECURSIVE_REWRITE(max(c1 - x, c2), c1 - min(x, c2 - c1));
+ }
+
+ // condition rules.
+ TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)),
+ select(x, max(y, s1), max(z, s2)));
+ return ret;
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const EQ* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<EQ>();
+ Expr const_res = TryConstFold<EQ>(op->a, op->b);
+ if (const_res.defined()) return const_res;
+
+ // Pattern var to match any expression
+ PVar<Expr> x, y;
+ // Pattern var match IntImm
+ PVar<Integer> c1;
+ PVar<int> lanes;
+
+ // vector rule
+ if (op->type.lanes() != 1) {
+ TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes),
+ broadcast(x == y, lanes));
+ }
+
+ if (IsIndexType(op->a.type())) {
+ CompareResult result = TryCompare(op->a - op->b, 0);
+ if (result != kUnknown) {
+ if (result == kEQ) {
+ return make_const(op->type, true);
+ } else {
+ return make_const(op->type, false);
+ }
+ }
+ TVM_TRY_REWRITE(x - c1 == 0, x == c1);
+ TVM_TRY_REWRITE(c1 - x == 0, x == c1);
+ TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1);
+ TVM_TRY_REWRITE(x * y == 0, x == 0 || y == 0);
+ }
+ return ret;
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const NE* op, const Expr& self) {
+ return Mutate(Not::make(op->a == op->b));
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const LE* op, const Expr& self) {
+ return Mutate(Not::make(op->b < op->a));
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const GT* op, const Expr& self) {
+ return Mutate(op->b < op->a);
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const GE* op, const Expr& self) {
+ return Mutate(Not::make(op->a < op->b));
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const LT* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<LT>();
+ Expr const_res = TryConstFold<LT>(op->a, op->b);
+ if (const_res.defined()) return const_res;
+
+ // Pattern var to match any expression
+ PVar<Expr> x, y, z, s1, s2;
+ // Pattern var match IntImm
+ PVar<Integer> c1, c2;
+ PVar<int> lanes;
+
+ // vector rule
+ if (op->type.lanes() != 1) {
+ TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes),
+ broadcast(x < y, lanes));
+ TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes),
+ broadcast(x < y, lanes));
+ }
+
+ if (IsIndexType(op->a.type())) {
+ CompareResult result = TryCompare(op->a - op->b, 0);
+ if (result == kLT) {
+ return make_const(op->type, true);
+ }
+ if (result == kEQ || result == kGT) {
+ return make_const(op->type, false);
+ }
+
+ TVM_TRY_REWRITE(x + y < x + z, y < z);
+ TVM_TRY_REWRITE(x + y < z + x, y < z);
+ TVM_TRY_REWRITE(y + x < x + z, y < z);
+ TVM_TRY_REWRITE(y + x < z + x, y < z);
+ TVM_TRY_REWRITE(y - x < z - x, y < z);
+ TVM_TRY_REWRITE(x - y < x - z, z < y);
+
+ TVM_TRY_REWRITE(x < x + z, 0 < z);
+ TVM_TRY_REWRITE(x < z + x, 0 < z);
+ TVM_TRY_REWRITE(x < x - z, z < 0);
+ TVM_TRY_REWRITE(c1 < x + c2, c1 - c2 < x);
+ TVM_TRY_REWRITE(c1 < c2 - x, x < c2 - c1);
+
+ TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y,
+ c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x,
+ c1.Eval()->value < 0);
+
+ // require c1 > 0 to work for any div mode
+ TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1,
+ c1.Eval()->value > 0 &&
+ c2.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
+ c1.Eval()->value > 0 &&
+ c2.Eval()->value > 0);
+
+ TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x,
+ c1.Eval()->value >= 0 &&
+ c2.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x,
+ c1.Eval()->value >= 0 &&
+ c2.Eval()->value > 0);
+
+ // division related simplificationx
+ // invariance for any div mod: x - (x / c1) * c1 == x % c1
+ TVM_TRY_REWRITE_IF((x / c1) * c1 < x, 0 < x % c1,
+ c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF((x / c1) * c1 < x + y, 0 < x % c1 + y,
+ c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF((x / c1) * c1 < x - y, y < x % c1,
+ c1.Eval()->value > 0);
+
+ TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x,
+ c2 < (x + c2) % c1,
+ c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x + y,
+ c2 < (x + c2) % c1 + y,
+ c1.Eval()->value > 0);
+ TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x - y,
+ y < (x + c2) % c1 + (0 - c2),
+ c1.Eval()->value > 0);
+
+ // canonicalization rule
+ TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z);
+ TVM_TRY_RECURSIVE_REWRITE(max(x, y) < z, x < z && y < z);
+ TVM_TRY_RECURSIVE_REWRITE(z < min(x, y), z < x && z < y);
+ TVM_TRY_RECURSIVE_REWRITE(z < max(x, y), z < x || z < y);
+
+ TVM_TRY_REWRITE(x - c1 < 0, x < c1);
+ TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1);
+ }
+ return ret;
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const Not* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<Not>();
+ Expr const_res = TryConstFold<Not>(op->a);
+ if (const_res.defined()) return const_res;
+ // Pattern var to match any expression
+ PVar<Expr> x, y;
+ PVar<int> lanes;
+ if (op->type.lanes() != 1) {
+ TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
+ }
+
+ TVM_TRY_REWRITE(!(!x), x);
+ TVM_TRY_REWRITE(!(x <= y), y < x);
+ TVM_TRY_REWRITE(!(x >= y), x < y);
+ TVM_TRY_REWRITE(!(x < y), y <= x);
+ TVM_TRY_REWRITE(!(x > y), x <= y);
+ TVM_TRY_REWRITE(!(x == y), x != y);
+ TVM_TRY_REWRITE(!(x != y), x == y);
+ TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y));
+ TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y));
+ return ret;
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const And* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<And>();
+ Expr const_res = TryConstFold<And>(op->a, op->b);
+ if (const_res.defined()) return const_res;
+
+ // Pattern var to match any expression
+ PVar<Expr> x, y;
+ // Pattern var match IntImm
+ PVar<Integer> c1, c2;
+ PVar<int> lanes;
+
+ if (op->type.lanes() != 1) {
+ TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes),
+ broadcast(x && y, lanes));
+ }
+
+ auto cfalse = PConst<Expr>(make_const(op->type, false));
+ TVM_TRY_REWRITE(x == y && x != y, cfalse);
+ TVM_TRY_REWRITE(x != y && x == y, cfalse);
+ TVM_TRY_REWRITE(x && !x, cfalse);
+ TVM_TRY_REWRITE(x <= y && y < x, cfalse);
+ TVM_TRY_REWRITE(y < x && y <= x, cfalse);
+
+ TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse,
+ c2.Eval()->value + 1 >= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse,
+ c2.Eval()->value + 1 >= c1.Eval()->value);
+
+ TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse,
+ c2.Eval()->value >= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse,
+ c2.Eval()->value >= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse,
+ c2.Eval()->value >= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse,
+ c2.Eval()->value >= c1.Eval()->value);
+
+ TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse,
+ c2.Eval()->value > c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse,
+ c2.Eval()->value > c1.Eval()->value);
+
+ TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
+ TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);
+ return ret;
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const Or* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<Or>();
+ Expr const_res = TryConstFold<Or>(op->a, op->b);
+ if (const_res.defined()) return const_res;
+
+ // Pattern var to match any expression
+ PVar<Expr> x, y;
+ // Pattern var match IntImm
+ PVar<Integer> c1, c2;
+ PVar<int> lanes;
+
+ if (op->type.lanes() != 1) {
+ TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes),
+ broadcast(x || y, lanes));
+ }
+
+ auto ctrue = PConst<Expr>(make_const(op->type, true));
+
+ TVM_TRY_REWRITE(x == y || x != y, ctrue);
+ TVM_TRY_REWRITE(x != y || x == y, ctrue);
+ TVM_TRY_REWRITE(x || !x, ctrue);
+ TVM_TRY_REWRITE(x <= y || y < x, ctrue);
+ TVM_TRY_REWRITE(y < x || y <= x, ctrue);
+
+ TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue,
+ c2.Eval()->value < c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue,
+ c2.Eval()->value < c1.Eval()->value);
+
+ TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue,
+ c2.Eval()->value <= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue,
+ c2.Eval()->value <= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue,
+ c2.Eval()->value <= c1.Eval()->value);
+ TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue,
+ c2.Eval()->value <= c1.Eval()->value);
+
+ TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue,
+ c2.Eval()->value <= c1.Eval()->value + 1);
+ TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue,
+ c2.Eval()->value <= c1.Eval()->value + 1);
+
+ TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
+ TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
+ return ret;
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const Ramp* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<Ramp>();
+ if (is_zero(op->stride)) {
+ return Broadcast::make(op->base, op->lanes);
+ }
+ return ret;
+}
+
+Expr RewriteSimplifier::Impl::
+Mutate_(const Select* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<Select>();
+ if (is_zero(op->condition)) {
+ return op->false_value;
+ }
+ if (is_one(op->condition)) {
+ return op->true_value;
+ }
+ // Pattern var to match any expression
+ PVar<Expr> x, y;
+
+ TVM_TRY_REWRITE(select(x, y, y), y);
+ return ret;
+}
Expr RewriteSimplifier::operator()(const Expr& expr) {
return impl_->PostOrderSimplify(expr);
def verify(self, data, expected):
res = self.analyzer.rewrite_simplify(data)
- assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(
- data, res, expected)
+ assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected)
def test_vector_simplify():
ck.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8,
tvm.expr.Ramp(1, 15, 4) % 8)
+ # Min/Max rules
+ vx = tvm.var("vx", dtype="int32x2")
+ vc = tvm.var("vc", dtype="uint1")
+ ck.verify(tvm.min(y.astype("int32x2"), x.astype("int32x2")),
+ tvm.min(y, x).astype("int32x2"))
+ ck.verify(tvm.min(tvm.min(vx, y.astype("int32x2")), x.astype("int32x2")),
+ tvm.min(vx, tvm.min(y, x).astype("int32x2")))
+ ck.verify(tvm.max(y.astype("int32x2"), x.astype("int32x2")),
+ tvm.max(y, x).astype("int32x2"))
+ ck.verify(tvm.max(tvm.max(vx, y.astype("int32x2")), x.astype("int32x2")),
+ tvm.max(vx, tvm.max(y, x).astype("int32x2")))
+
+ ## Logical rules
+ ck.verify(y.astype("int32x2").equal(x.astype("int32x2")),
+ (y.equal(x)).astype("uint1x2"))
+ ck.verify(tvm.expr.NE(y.astype("int32x2"), (x.astype("int32x2"))),
+ (tvm.expr.NE(y, x)).astype("uint1x2"))
+ ck.verify(y.astype("int32x2") > x.astype("int32x2"),
+ (x < y).astype("uint1x2"))
+ ck.verify(y.astype("int32x2") >= x.astype("int32x2"),
+ (x <= y).astype("uint1x2"))
+ ck.verify(y.astype("int32x2") < x.astype("int32x2"),
+ (y < x).astype("uint1x2"))
+ ck.verify(y.astype("int32x2") <= x.astype("int32x2"),
+ (y <= x).astype("uint1x2"))
+ ck.verify(tvm.expr.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
+ (tvm.expr.And(y <= x, vc)).astype("uint1x2"))
+ ck.verify(tvm.expr.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
+ (tvm.expr.Or(y <= x, vc)).astype("uint1x2"))
def test_select_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
# Add rules
- ck.verify(tvm.expr.Select(x > 0, y, 0) + tvm.expr.Select(x > 0, 1, z),
- tvm.expr.Select(x > 0, y + 1, z))
- ck.verify(tvm.expr.Select(x > 0, y, 1) - tvm.expr.Select(x > 0, 1, z),
- tvm.expr.Select(x > 0, y + (-1), 1 - z))
- ck.verify(tvm.expr.Select(x > 0, y, z) - y,
- tvm.expr.Select(x > 0, 0, z - y))
- ck.verify(tvm.expr.Select(x > 0, y, z) - z,
- tvm.expr.Select(x > 0, y - z, 0))
+ ck.verify(tvm.expr.Select(x < 0, y, 0) + tvm.expr.Select(x < 0, 1, z),
+ tvm.expr.Select(x < 0, y + 1, z))
+ ck.verify(tvm.expr.Select(x < 0, y, 1) - tvm.expr.Select(x < 0, 1, z),
+ tvm.expr.Select(x < 0, y + (-1), 1 - z))
+ ck.verify(tvm.expr.Select(x < 0, y, z) - y,
+ tvm.expr.Select(x < 0, 0, z - y))
+ ck.verify(tvm.expr.Select(x < 0, y, z) - z,
+ tvm.expr.Select(x < 0, y - z, 0))
+ ck.verify(tvm.min(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)),
+ tvm.expr.Select(x < 0, tvm.min(y, 1), tvm.min(0, z)))
+ ck.verify(tvm.max(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)),
+ tvm.expr.Select(x < 0, tvm.max(y, 1), tvm.max(0, z)))
+
+ ck.verify(tvm.expr.Select(x * 3 + 1 != 0, y, z), y)
+ ck.verify(tvm.expr.Select(x * 3 + 1 == 0, y, z), z)
+ ck.verify(tvm.expr.Select(x > 0, y + 1, y + 1), y + 1)
def test_add_index_simplify():
ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1)
+def test_min_index_simplify():
+ ck = RewriteChecker()
+ x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
+ # const int bound
+ ck.verify(tvm.min(x % 2, y % 2 + 10), x % 2)
+
+ ck.verify(tvm.min(x + 1, x + 10), x + 1)
+ ck.verify(tvm.min(x + 111, x + 10), x + 10)
+ ck.verify(tvm.min(x + 1, x), x)
+ ck.verify(tvm.min(x, x + 2), x)
+ ck.verify(tvm.min(1 - x, 2 - x), 1 - x)
+ ck.verify(tvm.min(3 - x, 2 - x), 2 - x)
+
+ ck.verify(tvm.min((x + 3) / 4 * 4, x), x)
+ ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
+ ck.verify(tvm.min((x + 3) / 4 * 4, tvm.max(x, 4)), tvm.max(x, 4))
+ ck.verify(tvm.min(x, (x + 3) / 4 * 4), x)
+ ck.verify(tvm.min(tvm.max(x, 4), (x + 3) / 4 * 4), tvm.max(x, 4))
+ ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
+
+ ck.verify(tvm.min(tvm.max(x, y), tvm.min(x, y)), tvm.min(x, y))
+ ck.verify(tvm.min(tvm.max(x, y), tvm.min(y, x)), tvm.min(x, y))
+
+ ck.verify(tvm.min(tvm.max(x, y), x), x)
+ ck.verify(tvm.min(tvm.max(y, x), x), x)
+ ck.verify(tvm.min(tvm.min(x, y), x), tvm.min(x, y))
+ ck.verify(tvm.min(tvm.min(x, y), y), tvm.min(x, y))
+
+ ck.verify(tvm.min(x, tvm.max(x, y)), x)
+ ck.verify(tvm.min(x, tvm.max(y, x)), x)
+ ck.verify(tvm.min(x, tvm.min(x, y)), tvm.min(x, y))
+ ck.verify(tvm.min(y, tvm.min(x, y)), tvm.min(x, y))
+
+ ck.verify(tvm.min(tvm.min(tvm.min(x, y), z), y),
+ tvm.min(tvm.min(x, y), z))
+ ck.verify(tvm.min(tvm.min(tvm.min(tvm.min(x, y), z), x * 2), y),
+ tvm.min(tvm.min(tvm.min(x, y), z), x * 2))
+ ck.verify(tvm.min(tvm.min(tvm.min(tvm.min(tvm.min(x, y), z), x * 2), z * 2), y),
+ tvm.min(tvm.min(tvm.min(tvm.min(x, y), z), x * 2), z * 2))
+
+ ck.verify(tvm.min(tvm.max(x, y), tvm.max(x, z)), tvm.max(tvm.min(y, z), x))
+ ck.verify(tvm.min(tvm.max(x, y), tvm.max(z, x)), tvm.max(tvm.min(y, z), x))
+ ck.verify(tvm.min(tvm.max(y, x), tvm.max(x, z)), tvm.max(tvm.min(y, z), x))
+ ck.verify(tvm.min(tvm.max(y, x), tvm.max(z, x)), tvm.max(tvm.min(y, z), x))
+
+ ck.verify(tvm.min(y + x, z + x), tvm.min(y, z) + x)
+ ck.verify(tvm.min(y + x, x + z), tvm.min(y, z) + x)
+ ck.verify(tvm.min(x + y, z + x), tvm.min(y, z) + x)
+ ck.verify(tvm.min(x + y, x + z), tvm.min(y, z) + x)
+
+ ck.verify(tvm.min(x - y, x - z), x - tvm.max(y, z))
+ ck.verify(tvm.min(y - x, z - x), tvm.min(y, z) - x)
+
+ ck.verify(tvm.min(tvm.min(x, 1), 10), tvm.min(x, 1))
+ ck.verify(tvm.min(tvm.min(x, 11), 10), tvm.min(x, 10))
+
+ ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10)
+ ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10))
+ ck.verify(tvm.min(x * 3, 9), tvm.min(x, 3) * 3)
+
+
+def test_max_index_simplify():
+ ck = RewriteChecker()
+ x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
+ # const int bound
+ ck.verify(tvm.max(x % 2, y % 2 + 10), y % 2 + 10)
+
+ ck.verify(tvm.max(x + 1, x + 10), x + 10)
+ ck.verify(tvm.max(x + 111, x + 10), x + 111)
+ ck.verify(tvm.max(x + 1, x), x + 1)
+ ck.verify(tvm.max(x, x + 2), x + 2)
+ ck.verify(tvm.max(1 - x, 2 - x), 2 - x)
+ ck.verify(tvm.max(3 - x, 2 - x), 3 - x)
+
+ ck.verify(tvm.max((x + 3) / 4 * 4, x), (x + 3) / 4 * 4)
+
+ ck.verify(tvm.max(tvm.min(x, y), tvm.max(x, y)), tvm.max(x, y))
+ ck.verify(tvm.max(tvm.min(x, y), tvm.max(y, x)), tvm.max(x, y))
+
+ ck.verify(tvm.max(tvm.min(x, y), x), x)
+ ck.verify(tvm.max(tvm.min(y, x), x), x)
+ ck.verify(tvm.max(tvm.max(x, y), x), tvm.max(x, y))
+ ck.verify(tvm.max(tvm.max(x, y), y), tvm.max(x, y))
+
+ ck.verify(tvm.max(x, tvm.min(x, y)), x)
+ ck.verify(tvm.max(x, tvm.min(y, x)), x)
+ ck.verify(tvm.max(x, tvm.max(x, y)), tvm.max(x, y))
+ ck.verify(tvm.max(y, tvm.max(x, y)), tvm.max(x, y))
+
+ ck.verify(tvm.max(tvm.max(tvm.max(x, y), z), y),
+ tvm.max(tvm.max(x, y), z))
+ ck.verify(tvm.max(tvm.max(tvm.max(tvm.max(x, y), z), x * 2), y),
+ tvm.max(tvm.max(tvm.max(x, y), z), x * 2))
+ ck.verify(tvm.max(tvm.max(tvm.max(tvm.max(tvm.max(x, y), z), x * 2), z * 2), y),
+ tvm.max(tvm.max(tvm.max(tvm.max(x, y), z), x * 2), z * 2))
+
+ ck.verify(tvm.max(tvm.min(x, y), tvm.min(x, z)), tvm.min(tvm.max(y, z), x))
+ ck.verify(tvm.max(tvm.min(x, y), tvm.min(z, x)), tvm.min(tvm.max(y, z), x))
+ ck.verify(tvm.max(tvm.min(y, x), tvm.min(x, z)), tvm.min(tvm.max(y, z), x))
+ ck.verify(tvm.max(tvm.min(y, x), tvm.min(z, x)), tvm.min(tvm.max(y, z), x))
+
+ ck.verify(tvm.max(y + x, z + x), tvm.max(y, z) + x)
+ ck.verify(tvm.max(y + x, x + z), tvm.max(y, z) + x)
+ ck.verify(tvm.max(x + y, z + x), tvm.max(y, z) + x)
+ ck.verify(tvm.max(x + y, x + z), tvm.max(y, z) + x)
+
+ ck.verify(tvm.max(x - y, x - z), x - tvm.min(y, z))
+ ck.verify(tvm.max(y - x, z - x), tvm.max(y, z) - x)
+
+ ck.verify(tvm.max(tvm.max(x, 1), 10), tvm.max(x, 10))
+ ck.verify(tvm.max(tvm.max(x, 11), 10), tvm.max(x, 11))
+
+ ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10)
+ ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10))
+ ck.verify(tvm.max(x * 3, 9), tvm.max(x, 3) * 3)
+
+
+def test_cmp_simplify():
+ ck = RewriteChecker()
+ x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
+ # const int bound
+ ck.verify((x % 2 + 10).equal(0), tvm.const(0, "bool"))
+ ck.verify(tvm.expr.NE(x % 2 + 10, 0), tvm.const(1, "bool"))
+ ck.verify(x % 2 + 10 > 1, tvm.const(1, "bool"))
+ ck.verify(x % 2 + 10 <= 1, tvm.const(0, "bool"))
+ ck.verify(x * 3 + 10 == 0, tvm.const(0, "bool"))
+ ck.verify(x * 3 + 10 != 0, tvm.const(1, "bool"))
+
+ # canonicalization
+ ck.verify((x - 10).equal(0), x.equal(10))
+ ck.verify((10 - x).equal(0), x.equal(10))
+ ck.verify((x * y).equal(0), tvm.expr.Or(x.equal(0), y.equal(0)))
+
+ # cmp bound
+ ck.verify(x + y < x + z, y < z)
+ ck.verify(x + y < z + x, y < z)
+ ck.verify(y + x < x + z, y < z)
+ ck.verify(y + x < z + x, y < z)
+ ck.verify(y - x < z - x, y < z)
+ ck.verify(x - y < x - z, z < y)
+
+ ck.verify(x < z + x, tvm.expr.LT(0, z))
+ ck.verify(x < x + z, tvm.expr.LT(0, z))
+
+ ck.verify(100 < x + 1, tvm.expr.LT(99, x))
+ ck.verify(1 < 100 - x, tvm.expr.LT(x, 99))
+ ck.verify(x * 3 < y * 3, x < y)
+ ck.verify(x * (-3) < y * (-3), y < x)
+ ck.verify(x * 3 >= y * 3, y <= x)
+
+ ck.verify(x * 4 >= 2, tvm.expr.LE(1, x))
+ ck.verify(x * 2 >= 50, tvm.expr.LE(25, x))
+ ck.verify(x / 2 < 3, x < 6)
+ ck.verify(x * 4 <= 2, x <= 0)
+ ck.verify(3 < x / 2, tvm.expr.LT(7, x))
+
+ ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4))
+ ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0))
+
+ ck.verify(x / 4 * 4 < x + y, tvm.expr.LT(0, x % 4 + y))
+ ck.verify(x / 4 * 4 < x - y, tvm.expr.LT(y, x % 4))
+
+ ck.verify((x + 2) / 4 * 4 >= x, tvm.expr.LE((x + 2) % 4, 2))
+ ck.verify((x + 2) / 4 * 4 >= x + y, tvm.expr.LE((x + 2) % 4 + y, 2))
+ ck.verify((x + 2) / 4 * 4 >= x - y, tvm.expr.LE((x + 2) % 4 + (-2), y))
+
+
+ ck.verify(tvm.min(x, 11) < 10, x < 10)
+ ck.verify(tvm.min(x, 8) < 10, tvm.const(1, "bool"))
+ ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x))
+ ck.verify(x + 1 < tvm.max(8, x), x < 7)
+
+
+def test_logical_simplify():
+ ck = RewriteChecker()
+ x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
+
+ ck.verify(tvm.expr.And(tvm.expr.EQ(x, y), tvm.expr.NE(x, y)),
+ tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(tvm.expr.NE(x, y), tvm.expr.EQ(x, y)),
+ tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(x > 1, tvm.expr.Not(x > 1)), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(x <= y, y < x), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(y < x, y <= x), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(x < 1, 0 < x), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(x < 0, 1 < x), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(x < 1, 1 <= x), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(x <= 1, 1 < x), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(1 <= x, x < 1), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(1 < x, x <= 1), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(x <= 1, 2 <= x), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(2 <= x, x <= 1), tvm.const(False, "bool"))
+ ck.verify(tvm.expr.And(x == 1, x != 2), x == 1)
+
+
+ ck.verify(tvm.expr.Or(tvm.expr.EQ(x, y), tvm.expr.NE(x, y)),
+ tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(tvm.expr.NE(x, y), tvm.expr.EQ(x, y)),
+ tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(x > y, tvm.expr.Not(x < y)), tvm.const(True, "bool"))
+
+ ck.verify(tvm.expr.Or(x <= y, y < x), tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(y < x, y <= x), tvm.const(True, "bool"))
+
+ ck.verify(tvm.expr.Or(x < 1, 0 < x), tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(0 < x, x < 1), tvm.const(True, "bool"))
+
+ ck.verify(tvm.expr.Or(x < 1, 1 <= x), tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(x <= 1, 1 < x), tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(1 <= x, x < 1), tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(1 < x, x <= 1), tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(x <= 1, 2 <= x), tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(2 <= x, x <= 1), tvm.const(True, "bool"))
+ ck.verify(tvm.expr.Or(x != 1, x == 2), x != 1)
+
+
if __name__ == "__main__":
- test_mod_index_simplify()
+ test_cmp_simplify()
test_vector_simplify()
test_add_index_simplify()
test_sub_index_simplify()
test_mul_index_simplify()
test_div_index_simplify()
+ test_max_index_simplify()
+ test_min_index_simplify()
+ test_mod_index_simplify()
test_select_simplify()
+ test_logical_simplify()