[OpFusion] Make the max number of fused ops configurable (#6327)
authormasahi <masahi129@gmail.com>
Mon, 24 Aug 2020 20:15:23 +0000 (05:15 +0900)
committerGitHub <noreply@github.com>
Mon, 24 Aug 2020 20:15:23 +0000 (13:15 -0700)
src/relay/transforms/fuse_ops.cc
tests/python/relay/test_pass_fuse_ops.py

index 01f1eee..85b74cc 100644 (file)
@@ -83,6 +83,8 @@ constexpr uint32_t kMaxFusedOps = 256;
 
 static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer);
+
 /*!
  * \brief Indexed data flow graph in forward direction.
  *  This is a temporary data structure used for operator fusion analysis.
@@ -496,8 +498,8 @@ DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForward
  */
 class GraphPartitioner {
  public:
-  explicit GraphPartitioner(support::Arena* arena, int opt_level)
-      : arena_(arena), opt_level_(opt_level) {}
+  explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
+      : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
   /*!
    * \brief Group as a union find data structure.
    */
@@ -549,6 +551,8 @@ class GraphPartitioner {
   support::Arena* arena_;
   /*! \brief optimization level for fuse operation. */
   int opt_level_;
+  /*! \brief The maximum number of operations in one fused function */
+  size_t max_fuse_depth_;
   /*! \brief The internal groups. */
   std::vector<Group*> groups_;
   /*! \brief internal field used for deduplication */
@@ -604,11 +608,11 @@ class GraphPartitioner {
    * \param parent The parent group.
    */
   void MergeFromTo(Group* child, Group* parent) {
-    // update the number of nodes of the parent group
-    parent->num_nodes += child->num_nodes;
     child = child->FindRoot();
     parent = parent->FindRoot();
     if (child == parent) return;
+    // update the number of nodes of the parent group
+    parent->num_nodes += child->num_nodes;
     child->parent = parent;
     // update master ref and pattern
     if (child->master_ref != nullptr) {
@@ -643,6 +647,32 @@ class GraphPartitioner {
     CommitFuse_(src, sink, target);
   }
 
+  size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
+    if (src == sink || visited_.count(src)) return 0;
+    visited_.insert(src);
+    Group* gnode = groups_[src->index];
+    CHECK(gnode != nullptr);
+    auto sum = gnode->num_nodes;
+    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
+      sum += CountNodesUptoSink_(link->value.node, sink);
+    }
+    return sum;
+  }
+
+  // Count the number of nodes in a fused subgraph if child is additionaly fused.
+  // dom_parent is already known to be a part of the subgraph.
+  // For a diamond structure, there can be multiple paths connecting child and dom_parent.
+  // All intermediate nodes between child and dom_parent are taken into account.
+  // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot()
+  // is important for correct calculation.
+  size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
+                                     IndexedForwardGraph::Node* dom_parent) {
+    Group* target = groups_[dom_parent->index];
+    visited_.clear();
+    CHECK(child != dom_parent);
+    return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
+  }
+
   // Initialize the groups.
   void InitGroups(const IndexedForwardGraph& graph) {
     groups_.resize(graph.post_dfs_order.size());
@@ -675,7 +705,8 @@ class GraphPartitioner {
       size_t dom_parent_gindex = dom_node->parent->gnode->index;
 
       // refuse the fusion if too many ops are going to be fused together
-      if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) continue;
+      if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
+        continue;
 
       if (phase == 2) {
         // Fuse injective ops into intermediate tuples, if any
@@ -769,10 +800,10 @@ std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
 class FuseMutator : private ExprMutator {
  public:
   // Run the transform
-  Expr Transform(const Expr& body, int fuse_opt_level) {
+  Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) {
     // setup the group map.
     auto graph = IndexedForwardGraph::Create(&arena_, body);
-    auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph);
+    auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph);
     for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
       CHECK(graph.post_dfs_order[nid]->ref != nullptr);
       gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
@@ -926,8 +957,8 @@ class FuseMutator : private ExprMutator {
   }
 };
 
-Expr FuseOps(const Expr& expr, int fuse_opt_level, const IRModule& module) {
-  return FuseMutator().Transform(expr, fuse_opt_level);
+Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) {
+  return FuseMutator().Transform(expr, fuse_opt_level, max_fuse_depth);
 }
 
 namespace transform {
@@ -936,7 +967,8 @@ Pass FuseOps(int fuse_opt_level) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
         int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
-        return Downcast<Function>(FuseOps(f, opt_level, m));
+        auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
+        return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value(), m));
       };
   return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
 }
index 1727429..90e80d8 100644 (file)
@@ -587,17 +587,14 @@ def test_split():
 
 def test_fuse_max():
     """Test the constraint of number of nodes in op fusion."""
-    max_fused_ops = 256
-    # n is the number of nodes to be fused, should be less than 2*max_fused_ops
-    n = 300
-    def before():
+    def before(n):
         x = relay.var("x", shape=(10, 20))
         y = x
         for i in range(n):
             y = relay.exp(y)
         return relay.Function([x], y)
 
-    def expected():
+    def expected(n, max_fused_ops):
         x = relay.var("p", shape=(10, 20))
         y = x
         for i in range(max_fused_ops):
@@ -608,6 +605,7 @@ def test_fuse_max():
         z = relay.Call(f1, [x])
         xx = relay.var("pp", shape=(10, 20))
         yy = xx
+        # it is assumed that there are two fused functions
         for i in range(n-max_fused_ops):
             yy = relay.exp(yy)
         f2 = relay.Function([xx], yy)
@@ -615,10 +613,22 @@ def test_fuse_max():
         zz = relay.Call(f2, [z])
         return relay.Function([x], zz)
 
-    z = before()
+    max_fused_ops = 256
+    n = 300
+    z = before(n)
     zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     zz = run_opt_pass(z, transform.FuseOps())
-    after = run_opt_pass(expected(), transform.InferType())
+    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
+    assert tvm.ir.structural_equal(zz, after)
+
+    max_fused_ops = 10
+    n = 20
+    z = before(n)
+    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
+
+    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
+        zz = run_opt_pass(z, transform.FuseOps())
+
     assert tvm.ir.structural_equal(zz, after)
 
 
@@ -722,6 +732,47 @@ def test_fuse_bcast_reduce_scalar():
     assert tvm.ir.structural_equal(m["main"], after)
 
 
+def test_fuse_max_diamond():
+    def create_diamond(x, branch_len):
+        x1 = x
+        x2 = x
+        for _ in range(branch_len):
+            x1 = relay.exp(x1)
+            x2 = relay.exp(x2)
+        return relay.add(x1, x2)
+
+    def before(branch_len, num_diamond):
+        x = relay.var("x", shape=(10, 20))
+        out = x
+        for _ in range(num_diamond):
+            out = create_diamond(out, branch_len)
+        return relay.Function([x], out)
+
+    def after(branch_len, num_diamond):
+        def create_diamond_func(inp):
+            inp_var = relay.var("p", shape=(10, 20))
+            d = create_diamond(inp_var, branch_len)
+            f = relay.Function([inp_var], d)
+            f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+            return relay.Call(f, [inp])
+
+        inp = relay.var("x", shape=(10, 20))
+        out = inp
+        for _ in range(num_diamond):
+            out = create_diamond_func(out)
+        return relay.Function([inp], out)
+
+    branch_len = 5
+    max_fused_ops = branch_len * 2 + 1  # the number of ops in one diamond
+    num_diamond = 3
+
+    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
+        fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps())
+
+    expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType())
+    assert tvm.ir.structural_equal(fused, expected)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -741,3 +792,4 @@ if __name__ == "__main__":
     test_fuse_take()
     test_fuse_gather_nd()
     test_fuse_bcast_reduce_scalar()
+    test_fuse_max_diamond()