* \param parent The parent group.
*/
void MergeFromTo(Group* child, Group* parent) {
- // refuse the fusion if too many ops are going to be fused together
- if (child->num_nodes + parent->num_nodes > kMaxFusedOps)
- return;
+ // update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
child = child->FindRoot();
parent = parent->FindRoot();
CHECK(!graph_node->extern_ref);
size_t dom_parent_gindex = dom_node->parent->gnode->index;
+ // refuse the fusion if too many ops are going to be fused together
+ if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps)
+ continue;
+
if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
if (group_node->pattern > kInjective) continue;
mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
mod = transform.FuseOps()(mod)
+def test_fuse_max():
+ """Test the constraint of number of nodes in op fusion."""
+ # Upper bound on ops per fused function that this test expects.
+ # NOTE(review): presumably has to match kMaxFusedOps in the fusion pass -- confirm.
+ max_fused_ops = 256
+ # n is the length of the exp chain built by before(). It has to exceed
+ # max_fused_ops for a split to happen, and stay below 2*max_fused_ops
+ # because expected() only models two fused functions.
+ n = 300
+ # Input program: n relay.exp calls chained on a single (10, 20) variable.
+ def before():
+ x = relay.var("x", shape=(10, 20))
+ y = x
+ for i in range(n):
+ y = relay.exp(y)
+ return relay.Function([x], y)
+
+ # Hand-built fused form: f1 applies the first max_fused_ops exp calls, and
+ # f2 applies the remaining n - max_fused_ops calls to f1's result.
+ def expected():
+ x = relay.var("p", shape=(10, 20))
+ y = x
+ for i in range(max_fused_ops):
+ y = relay.exp(y)
+ f1 = relay.Function([x], y)
+ # x is rebound here to the outer function's parameter; f1 keeps "p".
+ x = relay.var("x", shape=(10, 20))
+ z = relay.Call(f1, [x])
+ xx = relay.var("pp", shape=(10, 20))
+ yy = xx
+ for i in range(n-max_fused_ops):
+ yy = relay.exp(yy)
+ f2 = relay.Function([xx], yy)
+ zz = relay.Call(f2, [z])
+ return relay.Function([x], zz)
+
+ z = before()
+ # NOTE(review): the result of the fuse_opt_level=2 run is overwritten by the
+ # next line and never compared; only the default FuseOps() output is checked.
+ zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
+ zz = run_opt_pass(z, transform.FuseOps())
+ # Run InferType over the hand-built expectation before comparing it with the pass output.
+ after = run_opt_pass(expected(), transform.InferType())
+ assert relay.analysis.alpha_equal(zz, after)
if __name__ == "__main__":
+ # Call the fusion tests directly when this file is executed as a script.
test_fuse_simple()
test_fuse_parallel_injective()
test_immutable()
test_split()
+ test_fuse_max()