[ARITH] Simplify casts of constants 0 and 1 (#3758)
authorSergei Grechanik <grechanik.sergey@huawei.com>
Tue, 13 Aug 2019 15:28:28 +0000 (18:28 +0300)
committerTianqi Chen <tqchen@users.noreply.github.com>
Tue, 13 Aug 2019 15:28:28 +0000 (08:28 -0700)
* [ARITH] Simplify casts of constants 0 and 1

* [EXPR] is_const_value to check whether non-ints are consts

* Revert "[EXPR] is_const_value to check whether non-ints are consts"

This reverts commit 7e1b3462e3f74fd0afb1541d72978107cfa23c30.

* Use tvm::cast

src/arithmetic/rewrite_simplify.cc
src/arithmetic/rewrite_simplify.h
src/lang/expr_operator.cc
tests/python/unittest/test_arith_rewrite_simplify.py

index cf26d12..8b4114b 100644 (file)
@@ -1757,6 +1757,13 @@ Mutate_(const Variable* op, const Expr& self) {
   return self;
 }
 
+Expr RewriteSimplifier::Impl::
+Mutate_(const Cast* op, const Expr& self) {
+  Expr ret = IRMutator::Mutate_(op, self);
+  op = ret.as<Cast>();
+  return cast(op->type, op->value);
+}
+
 Expr RewriteSimplifier::operator()(const Expr& expr) {
   // Run simplification in post order
   Expr res = expr;
index f46dcc9..0a2bedc 100644 (file)
@@ -70,6 +70,7 @@ class RewriteSimplifier::Impl : public IRMutator {
   Expr Mutate_(const Call* op, const Expr& self) override;
   Expr Mutate_(const Let* op, const Expr& self) override;
   Expr Mutate_(const Variable* op, const Expr& self) override;
+  Expr Mutate_(const Cast* op, const Expr& self) override;
 
  protected:
   /*! \brief internal structure for comparison. */
index 9b6adff..6383b71 100644 (file)
@@ -105,12 +105,15 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {
 
 Expr cast(const Type& t, Expr value) {
   using ir::IntImm;
+  using ir::UIntImm;
   using ir::FloatImm;
   if (value.type() == t) return value;
   // const fold IntImm as they are used in index computations
   if (t.lanes() == 1) {
     if (const IntImm* op = value.as<IntImm>()) {
       return make_const(t, op->value);
+    } else if (const UIntImm* op = value.as<UIntImm>()) {
+      return make_const(t, op->value);
     } else if (const FloatImm* op = value.as<FloatImm>()) {
       return make_const(t, op->value);
     }
@@ -122,6 +125,8 @@ Expr cast(const Type& t, Expr value) {
       if (value.type() != vtype) {
         if (const IntImm* op = value.as<IntImm>()) {
           value = make_const(vtype, op->value);
+        } else if (const UIntImm* op = value.as<UIntImm>()) {
+          return make_const(t, op->value);
         } else if (const FloatImm* op = value.as<FloatImm>()) {
           value = make_const(vtype, op->value);
         } else {
index 6ff0183..ca30354 100644 (file)
@@ -804,6 +804,18 @@ def test_let_simplify():
     z = tvm.expr.Let(x, 1, x + 1)
     ck.verify(z + z, 4)
 
+def test_cast_simplify():
+    ck = RewriteChecker()
+    x = tvm.var("x")
+
+    dtypes = ["float32", "float16", "int32", "int8", "bool"]
+    for dtype1 in dtypes:
+        ck.verify(tvm.expr.Cast(dtype1, x - x), tvm.const(0, dtype1))
+        ck.verify(tvm.expr.Cast(dtype1, x == x), tvm.const(1, dtype1))
+        for dtype2 in dtypes:
+            for i in [0, 1, 2, 3]:
+                ck.verify(tvm.expr.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1))
+
 if __name__ == "__main__":
     test_floordiv_index_simplify()
     test_floormod_index_simplify()
@@ -819,3 +831,4 @@ if __name__ == "__main__":
     test_select_simplify()
     test_logical_simplify()
     test_let_simplify()
+    test_cast_simplify()