static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer);
+
/*!
* \brief Indexed data flow graph in forward direction.
* This is a temporary data structure used for operator fusion analysis.
*/
class GraphPartitioner {
public:
- explicit GraphPartitioner(support::Arena* arena, int opt_level)
- : arena_(arena), opt_level_(opt_level) {}
+ explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
+ : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
/*!
* \brief Group as a union find data structure.
*/
support::Arena* arena_;
/*! \brief optimization level for fuse operation. */
int opt_level_;
+ /*! \brief The maximum number of operations in one fused function */
+ size_t max_fuse_depth_;
/*! \brief The internal groups. */
std::vector<Group*> groups_;
/*! \brief internal field used for deduplication */
* \param parent The parent group.
*/
void MergeFromTo(Group* child, Group* parent) {
- // update the number of nodes of the parent group
- parent->num_nodes += child->num_nodes;
child = child->FindRoot();
parent = parent->FindRoot();
if (child == parent) return;
+ // update the number of nodes of the parent group
+ parent->num_nodes += child->num_nodes;
child->parent = parent;
// update master ref and pattern
if (child->master_ref != nullptr) {
CommitFuse_(src, sink, target);
}
+ size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
+ if (src == sink || visited_.count(src)) return 0;
+ visited_.insert(src);
+ Group* gnode = groups_[src->index];
+ CHECK(gnode != nullptr);
+ auto sum = gnode->num_nodes;
+ for (auto link = src->outputs.head; link != nullptr; link = link->next) {
+ sum += CountNodesUptoSink_(link->value.node, sink);
+ }
+ return sum;
+ }
+
+ // Count the number of nodes in a fused subgraph if child is additionaly fused.
+ // dom_parent is already known to be a part of the subgraph.
+ // For a diamond structure, there can be multiple paths connecting child and dom_parent.
+ // All intermediate nodes between child and dom_parent are taken into account.
+ // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot()
+ // is important for correct calculation.
+ size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
+ IndexedForwardGraph::Node* dom_parent) {
+ Group* target = groups_[dom_parent->index];
+ visited_.clear();
+ CHECK(child != dom_parent);
+ return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
+ }
+
// Initialize the groups.
void InitGroups(const IndexedForwardGraph& graph) {
groups_.resize(graph.post_dfs_order.size());
size_t dom_parent_gindex = dom_node->parent->gnode->index;
// refuse the fusion if too many ops are going to be fused together
- if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) continue;
+ if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
+ continue;
if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
class FuseMutator : private ExprMutator {
public:
// Run the transform
- Expr Transform(const Expr& body, int fuse_opt_level) {
+ Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) {
// setup the group map.
auto graph = IndexedForwardGraph::Create(&arena_, body);
- auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph);
+ auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph);
for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
CHECK(graph.post_dfs_order[nid]->ref != nullptr);
gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
}
};
-Expr FuseOps(const Expr& expr, int fuse_opt_level, const IRModule& module) {
- return FuseMutator().Transform(expr, fuse_opt_level);
+Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) {
+ return FuseMutator().Transform(expr, fuse_opt_level, max_fuse_depth);
}
namespace transform {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
- return Downcast<Function>(FuseOps(f, opt_level, m));
+ auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
+ return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value(), m));
};
return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}
def test_fuse_max():
"""Test the constraint of number of nodes in op fusion."""
- max_fused_ops = 256
- # n is the number of nodes to be fused, should be less than 2*max_fused_ops
- n = 300
- def before():
+ def before(n):
x = relay.var("x", shape=(10, 20))
y = x
for i in range(n):
y = relay.exp(y)
return relay.Function([x], y)
- def expected():
+ def expected(n, max_fused_ops):
x = relay.var("p", shape=(10, 20))
y = x
for i in range(max_fused_ops):
z = relay.Call(f1, [x])
xx = relay.var("pp", shape=(10, 20))
yy = xx
+ # it is assumed that there are two fused functions
for i in range(n-max_fused_ops):
yy = relay.exp(yy)
f2 = relay.Function([xx], yy)
zz = relay.Call(f2, [z])
return relay.Function([x], zz)
- z = before()
+ max_fused_ops = 256
+ n = 300
+ z = before(n)
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
- after = run_opt_pass(expected(), transform.InferType())
+ after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
+ assert tvm.ir.structural_equal(zz, after)
+
+ max_fused_ops = 10
+ n = 20
+ z = before(n)
+ after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
+
+ with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
+ zz = run_opt_pass(z, transform.FuseOps())
+
assert tvm.ir.structural_equal(zz, after)
assert tvm.ir.structural_equal(m["main"], after)
+def test_fuse_max_diamond():
+ def create_diamond(x, branch_len):
+ x1 = x
+ x2 = x
+ for _ in range(branch_len):
+ x1 = relay.exp(x1)
+ x2 = relay.exp(x2)
+ return relay.add(x1, x2)
+
+ def before(branch_len, num_diamond):
+ x = relay.var("x", shape=(10, 20))
+ out = x
+ for _ in range(num_diamond):
+ out = create_diamond(out, branch_len)
+ return relay.Function([x], out)
+
+ def after(branch_len, num_diamond):
+ def create_diamond_func(inp):
+ inp_var = relay.var("p", shape=(10, 20))
+ d = create_diamond(inp_var, branch_len)
+ f = relay.Function([inp_var], d)
+ f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+ return relay.Call(f, [inp])
+
+ inp = relay.var("x", shape=(10, 20))
+ out = inp
+ for _ in range(num_diamond):
+ out = create_diamond_func(out)
+ return relay.Function([inp], out)
+
+ branch_len = 5
+ max_fused_ops = branch_len * 2 + 1 # the number of ops in one diamond
+ num_diamond = 3
+
+ with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
+ fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps())
+
+ expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType())
+ assert tvm.ir.structural_equal(fused, expected)
+
+
if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_fuse_take()
test_fuse_gather_nd()
test_fuse_bcast_reduce_scalar()
+ test_fuse_max_diamond()