return self;
}
+Expr RewriteSimplifier::Impl::
+Mutate_(const Cast* op, const Expr& self) {
+ Expr ret = IRMutator::Mutate_(op, self);
+ op = ret.as<Cast>();
+ return cast(op->type, op->value);
+}
+
Expr RewriteSimplifier::operator()(const Expr& expr) {
// Run simplification in post order
Expr res = expr;
Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override;
+ Expr Mutate_(const Cast* op, const Expr& self) override;
protected:
/*! \brief internal structure for comparison. */
Expr cast(const Type& t, Expr value) {
using ir::IntImm;
+ using ir::UIntImm;
using ir::FloatImm;
if (value.type() == t) return value;
// const fold IntImm as they are used in index computations
if (t.lanes() == 1) {
if (const IntImm* op = value.as<IntImm>()) {
return make_const(t, op->value);
+ } else if (const UIntImm* op = value.as<UIntImm>()) {
+ return make_const(t, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) {
return make_const(t, op->value);
}
if (value.type() != vtype) {
if (const IntImm* op = value.as<IntImm>()) {
value = make_const(vtype, op->value);
+ } else if (const UIntImm* op = value.as<UIntImm>()) {
+ return make_const(t, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) {
value = make_const(vtype, op->value);
} else {
z = tvm.expr.Let(x, 1, x + 1)
ck.verify(z + z, 4)
+def test_cast_simplify():
+ ck = RewriteChecker()
+ x = tvm.var("x")
+
+ dtypes = ["float32", "float16", "int32", "int8", "bool"]
+ for dtype1 in dtypes:
+ ck.verify(tvm.expr.Cast(dtype1, x - x), tvm.const(0, dtype1))
+ ck.verify(tvm.expr.Cast(dtype1, x == x), tvm.const(1, dtype1))
+ for dtype2 in dtypes:
+ for i in [0, 1, 2, 3]:
+ ck.verify(tvm.expr.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1))
+
if __name__ == "__main__":
test_floordiv_index_simplify()
test_floormod_index_simplify()
test_select_simplify()
test_logical_simplify()
test_let_simplify()
+ test_cast_simplify()