[RELAY][PASS] Common subexpression elimination (#2639)
authorWuwei Lin <vincentl13x@gmail.com>
Sun, 3 Mar 2019 18:15:12 +0000 (02:15 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sun, 3 Mar 2019 18:15:12 +0000 (10:15 -0800)
python/tvm/relay/ir_pass.py
src/relay/pass/eliminate_common_subexpr.cc [new file with mode: 0644]
src/relay/pass/pattern_util.h
tests/python/relay/test_pass_eliminate_common_subexpr.py [new file with mode: 0644]

index 02a6e8b..04b92ba 100644 (file)
@@ -564,3 +564,23 @@ def get_total_mac_number(expr):
       The number of MACs (multiply-accumulate) of a model
     """
     return _ir_pass.GetTotalMacNumber(expr)
+
+
+def eliminate_common_subexpr(expr, fskip=None):
+    """
+    Eliminate common subexpressions.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression.
+
+    fskip: function
+        The callback function that decides whether an expression should be skipped.
+
+    Returns
+    -------
+    expr : tvm.relay.Expr
+      The output expression.
+    """
+    return _ir_pass.eliminate_common_subexpr(expr, fskip)
diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc
new file mode 100644 (file)
index 0000000..10e6f92
--- /dev/null
@@ -0,0 +1,72 @@
+/*!
+ * Copyright (c) 2019 by Contributors
+ *
+ * \file eliminate_common_subexpr.cc
+ * \brief Combine common subexpressions.
+ *
+ * This is an optimization pass that eliminates common subexpressions. During the pass, it tries
+ * to replace an expression with a previously appeared expression with the same input and
+ * attributes. The fskip callback argument allows us to skip specific expressions.
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/expr_functor.h>
+#include <unordered_map>
+#include "./pattern_util.h"
+
+namespace tvm {
+namespace relay {
+
+class CommonSubexprEliminator : public ExprMutator {
+ public:
+  explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip): fskip_(fskip) {}
+
+  Expr VisitExpr_(const CallNode* call) final {
+    static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+    Expr new_expr = ExprMutator::VisitExpr_(call);
+    const CallNode* new_call = new_expr.as<CallNode>();
+    CHECK(new_call);
+    const OpNode* op = new_call->op.as<OpNode>();
+    AttrsEqual attrs_equal;
+
+    if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
+      return new_expr;
+    }
+    if (fskip_ != nullptr && fskip_(new_expr)) {
+      return new_expr;
+    }
+
+    auto it = expr_map_.find(new_call->op);
+    if (it != expr_map_.end()) {
+      for (const CallNode* candidate : it->second) {
+        bool is_equivalent = true;
+        if (!attrs_equal(new_call->attrs, candidate->attrs)) {
+          continue;
+        }
+        for (size_t i = 0; i < new_call->args.size(); i++) {
+          if (!new_call->args[i].same_as(candidate->args[i]) &&
+              !IsEqualScalar(new_call->args[i], candidate->args[i])) {
+            is_equivalent = false;
+            break;
+          }
+        }
+        if (!is_equivalent) continue;
+        return GetRef<Call>(candidate);
+      }
+    }
+    expr_map_[new_call->op].push_back(new_call);
+    return new_expr;
+  }
+
+  std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> expr_map_;
+  runtime::TypedPackedFunc<bool(Expr)> fskip_;
+};
+
+Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
+  return CommonSubexprEliminator(callback)(expr);
+}
+
+TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr")
+.set_body_typed<Expr(Expr, PackedFunc)>(EliminateCommonSubexpr);
+
+}  // namespace relay
+}  // namespace tvm
index 0644c26..e59efa9 100644 (file)
@@ -191,6 +191,21 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
   return ConstantNode::make(arr);
 }
 
+/*!
+ * \brief Check if two expressions are equal scalars.
+ * \param a The expression to be checked.
+ * \param b The expression to be checked
+ * \return Whether two expressions are equal scalars.
+ */
+inline bool IsEqualScalar(const Expr& a, const Expr& b) {
+  const auto* constant_a = a.as<ConstantNode>();
+  const auto* constant_b = b.as<ConstantNode>();
+  if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
+    return false;
+  }
+  return AlphaEqual(a, b);
+}
+
 inline Expr GetField(Expr t, size_t i) {
   return TupleGetItemNode::make(t, i);
 }
diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py
new file mode 100644 (file)
index 0000000..381a54a
--- /dev/null
@@ -0,0 +1,63 @@
+"""Test eliminate common subexpr pass"""
+from tvm import relay
+from tvm.relay.op import register_alter_op_layout
+from tvm.relay import ir_pass
+
+
+def test_simple():
+    def before():
+        x = relay.var("x", shape=(1, 16))
+        y1 = relay.nn.relu(x)
+        y2 = relay.nn.relu(x)
+        y1 = relay.add(y1, relay.const(1.0, "float32"))
+        y2 = relay.add(y2, relay.const(1.0, "float32"))
+        y = relay.add(y1, y2)
+        f = relay.Function([x], y)
+        return f
+
+    def expected():
+        x = relay.var("x", shape=(1, 16))
+        y = relay.nn.relu(x)
+        y = relay.add(y, relay.const(1.0, "float32"))
+        y = relay.add(y, y)
+        f = relay.Function([x], y)
+        return f
+
+    z = before()
+    z = ir_pass.eliminate_common_subexpr(z)
+    assert ir_pass.alpha_equal(z, expected())
+
+
+def test_callback():
+    def before():
+        x = relay.var("x", shape=(1, 16))
+        y1 = relay.nn.relu(x)
+        y2 = relay.nn.relu(x)
+        y1 = relay.add(y1, relay.const(1.0, "float32"))
+        y2 = relay.add(y2, relay.const(1.0, "float32"))
+        y = relay.add(y1, y2)
+        f = relay.Function([x], y)
+        return f
+
+    def expected():
+        x = relay.var("x", shape=(1, 16))
+        y = relay.nn.relu(x)
+        y1 = relay.add(y, relay.const(1.0, "float32"))
+        y2 = relay.add(y, relay.const(1.0, "float32"))
+        y = relay.add(y1, y2)
+        f = relay.Function([x], y)
+        return f
+
+    def fskip(expr):
+        if isinstance(expr, relay.expr.Call) and expr.op.name == 'add':
+            return True
+        return False
+
+    z = before()
+    z = ir_pass.eliminate_common_subexpr(z, fskip)
+    assert ir_pass.alpha_equal(z, expected())
+
+
+if __name__ == "__main__":
+    test_simple()
+    test_callback()