Add TupleGetItem to CSE (#5931)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Fri, 26 Jun 2020 14:22:43 +0000 (07:22 -0700)
committerGitHub <noreply@github.com>
Fri, 26 Jun 2020 14:22:43 +0000 (07:22 -0700)
* Add TupleGetItem to CSE

* rename a local variable

src/relay/transforms/eliminate_common_subexpr.cc
tests/python/relay/test_pass_eliminate_common_subexpr.py

index 8f7375c..dc3f77e 100644 (file)
@@ -58,27 +58,52 @@ class CommonSubexprEliminator : public ExprMutator {
 
     auto it = expr_map_.find(new_call->op);
     if (it != expr_map_.end()) {
-      for (const CallNode* candidate : it->second) {
-        bool is_equivalent = true;
-        if (!attrs_equal(new_call->attrs, candidate->attrs)) {
-          continue;
+      for (const Expr& candidate_expr : it->second) {
+        if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
+          bool is_equivalent = true;
+          if (!attrs_equal(new_call->attrs, candidate->attrs)) {
+            continue;
+          }
+          for (size_t i = 0; i < new_call->args.size(); i++) {
+            if (!new_call->args[i].same_as(candidate->args[i]) &&
+                !IsEqualScalar(new_call->args[i], candidate->args[i])) {
+              is_equivalent = false;
+              break;
+            }
+          }
+          if (!is_equivalent) continue;
+          return GetRef<Call>(candidate);
         }
-        for (size_t i = 0; i < new_call->args.size(); i++) {
-          if (!new_call->args[i].same_as(candidate->args[i]) &&
-              !IsEqualScalar(new_call->args[i], candidate->args[i])) {
-            is_equivalent = false;
-            break;
+      }
+    }
+    expr_map_[new_call->op].push_back(new_expr);
+    return new_expr;
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* op) final {
+    Expr new_expr = ExprMutator::VisitExpr_(op);
+    const TupleGetItemNode* new_tuple_item = new_expr.as<TupleGetItemNode>();
+    CHECK(new_tuple_item);
+
+    if (fskip_ != nullptr && fskip_(new_expr)) {
+      return new_expr;
+    }
+
+    auto it = expr_map_.find(new_tuple_item->tuple);
+    if (it != expr_map_.end()) {
+      for (const Expr& candidate_expr : it->second) {
+        if (const TupleGetItemNode* candidate = candidate_expr.as<TupleGetItemNode>()) {
+          if (new_tuple_item->index == candidate->index) {
+            return GetRef<Expr>(candidate);
           }
         }
-        if (!is_equivalent) continue;
-        return GetRef<Call>(candidate);
       }
     }
-    expr_map_[new_call->op].push_back(new_call);
+    expr_map_[new_tuple_item->tuple].push_back(new_expr);
     return new_expr;
   }
 
-  std::unordered_map<Expr, std::vector<const CallNode*>, ObjectPtrHash, ObjectPtrEqual> expr_map_;
+  std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;
   runtime::TypedPackedFunc<bool(Expr)> fskip_;
 };
 
index 7af524d..45d21a4 100644 (file)
@@ -84,6 +84,35 @@ def test_callback():
     z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip))
     assert tvm.ir.structural_equal(z, expected())
 
+def test_tuple_get_time():
+    def before():
+        x = relay.var('x', shape=(1, 16, 1, 1))
+        var = relay.var('var', shape=(16,))
+        mean = relay.var('mean', shape=(16,))
+        beta = relay.var('beta', shape=(16,))
+        gamma = relay.var('gamma', shape=(16,))
+        BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
+        T1 = BN[0]
+        T2 = BN[0]
+        add = T1 + T2
+        f = relay.Function([x, var, mean, beta, gamma], add)
+        return f
+
+    def expected():
+        x = relay.var('x', shape=(1, 16, 1, 1))
+        var = relay.var('var', shape=(16,))
+        mean = relay.var('mean', shape=(16,))
+        beta = relay.var('beta', shape=(16,))
+        gamma = relay.var('gamma', shape=(16,))
+        BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
+        T1 = BN[0]
+        add = T1 + T1
+        f = relay.Function([x, var, mean, beta, gamma], add)
+        return run_opt_pass(f, transform.InferType())
+
+    z = before()
+    z = run_opt_pass(z, transform.EliminateCommonSubexpr())
+    assert tvm.ir.structural_equal(z, expected())
 
 if __name__ == "__main__":
     test_simple()