[Fix] Fix the logic of the number of nodes checking in op fusion (#4074)
authorYida Wang <yidawa@gmail.com>
Thu, 10 Oct 2019 19:24:32 +0000 (12:24 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Thu, 10 Oct 2019 19:24:32 +0000 (12:24 -0700)
* move the number of nodes constraint in op fusion up to the dom tree level

* add test case of limiting the max number of ops to be fused

* uncomment other test cases

src/relay/pass/fuse_ops.cc
tests/python/relay/test_pass_fuse_ops.py

index 935c37d..acee2c1 100644 (file)
@@ -623,9 +623,7 @@ class GraphPartitioner {
    * \param parent The parent group.
    */
   void MergeFromTo(Group* child, Group* parent) {
-    // refuse the fusion if too many ops are going to be fused together
-    if (child->num_nodes + parent->num_nodes > kMaxFusedOps)
-      return;
+    // update the number of nodes of the parent group
     parent->num_nodes += child->num_nodes;
     child = child->FindRoot();
     parent = parent->FindRoot();
@@ -701,6 +699,10 @@ class GraphPartitioner {
       CHECK(!graph_node->extern_ref);
       size_t dom_parent_gindex = dom_node->parent->gnode->index;
 
+      // refuse the fusion if too many ops are going to be fused together
+      if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps)
+        continue;
+
       if (phase == 2) {
         // Fuse injective ops into intermediate tuples, if any
         if (group_node->pattern > kInjective) continue;
index f148502..45faa14 100644 (file)
@@ -552,6 +552,39 @@ def test_split():
     mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
     mod = transform.FuseOps()(mod)
 
+def test_fuse_max():
+    """Test the constraint of number of nodes in op fusion."""
+    max_fused_ops = 256
+    # n is the number of nodes to be fused, should be less than 2*max_fused_ops
+    n = 300
+    def before():
+        x = relay.var("x", shape=(10, 20))
+        y = x
+        for i in range(n):
+            y = relay.exp(y)
+        return relay.Function([x], y)
+
+    def expected():
+        x = relay.var("p", shape=(10, 20))
+        y = x
+        for i in range(max_fused_ops):
+            y = relay.exp(y)
+        f1 = relay.Function([x], y)
+        x = relay.var("x", shape=(10, 20))
+        z = relay.Call(f1, [x])
+        xx = relay.var("pp", shape=(10, 20))
+        yy = xx
+        for i in range(n-max_fused_ops):
+            yy = relay.exp(yy)
+        f2 = relay.Function([xx], yy)
+        zz = relay.Call(f2, [z])
+        return relay.Function([x], zz)
+
+    z = before()
+    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
+    zz = run_opt_pass(z, transform.FuseOps())
+    after = run_opt_pass(expected(), transform.InferType())
+    assert relay.analysis.alpha_equal(zz, after)
 
 if __name__ == "__main__":
     test_fuse_simple()
@@ -568,3 +601,4 @@ if __name__ == "__main__":
     test_fuse_parallel_injective()
     test_immutable()
     test_split()
+    test_fuse_max()