void VisitExpr_(const TupleNode* op) final {
CHECK(graph_.node_map.count(op));
Node* tuple_node = graph_.node_map.at(op);
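+ // Tag tuple nodes with a dedicated kTuple pattern so the later fusion phases can treat them specially.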
- tuple_node->pattern = kInjective;
+ tuple_node->pattern = kTuple;
for (const Expr& field : op->fields) {
if (field->checked_type().as<TensorTypeNode>()) {
this->Update(field, tuple_node, kInjective);
// no action is needed if the current node has no dominator
if (dom_node->parent == nullptr) continue;
CHECK(!graph_node->extern_ref);
- // Skip if current node is already fused to the parent.
size_t dom_parent_gindex = dom_node->parent->gnode->index;
+
+ if (phase == 2) {
+ // Fuse injective ops into intermediate tuples, if any
+ if (group_node->pattern > kInjective) continue;
+ Group* dom_parent_group = groups_[dom_parent_gindex];
+ Group* dom_root_group = dom_parent_group->FindRoot();
+ // If dom node group has a tuple as its root, we do not fuse tuple fields into it
+ if (dom_root_group->pattern == kTuple) continue;
+ if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
+ // Now we know the tuple has been fused into subsequent injective ops
+ auto fcond = [](OpPatternKind kind, bool is_sink) {
+ return kind <= kInjective;
+ };
+ // dom_root_group can also be tuple, as in inception layers
+ // CheckPath is needed to avoid fusing two intermediate tuples
+ if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+ CommitFuse(graph_node, dom_node->parent->gnode);
+ }
+ }
+ continue;
+ }
+
+ // Skip if current node is already fused to the parent.
if (groups_[dom_parent_gindex] != nullptr &&
group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
continue;
}
+ // Do not fuse into tuple for now
+ if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
// Try to fuse current node to its post-dominator.
if (group_node->pattern == kOutEWiseFusable) {
if (phase != 0) continue;
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
- } else if (group_node->pattern == kInjective) {
+ } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
// defer injective fusion to second phase.
// so conv2d always finishes fusing.
if (phase != 1) continue;
// get post dominator tree
auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
// run fusion algorithm.
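+ // phase 0 fuses out-elemwise-fusable ops (e.g. conv2d) with their consumers,
+ // phase 1 fuses the remaining injective ops,
+ // phase 2 fuses injective ops into intermediate tuples.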
- for (int phase = 0; phase < 2; ++phase) {
+ for (int phase = 0; phase < 3; ++phase) {
this->RunFuse(graph, post_dom_tree, phase);
}
return std::move(groups_);
Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
- Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
if (ret_group == gmap_.at(tuple)) {
- // This tuple is the root of its group. Check if all fields come from other groups.
- bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
- for (size_t i = 0; i < new_fields.size() && isolated; ++i) {
- isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
- }
- if (isolated) {
- // Do not put a isolated tuple into a function
- return ExprMutator::VisitExpr_(tuple);
- }
- // This tuple has been fused with other ops before it
- for (size_t i = 0; i < new_fields.size(); i++) {
- // Copy function arguments to tuple field of the output because currently graph memory
- // planer doesn't support inplace operations
- if (new_fields[i].as<VarNode>()) {
- auto copy = Copy(new_fields[i]);
- new_fields.Set(i, copy);
- }
- }
- return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
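+ // This tuple is the root of its group; do not wrap it in a fused function.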
+ return ExprMutator::VisitExpr_(tuple);
}
// This tuple is an intermediate node in the group
+ Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
return TupleNode::make(new_fields);
}
f0 = relay.Function([x], pooled)
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
- p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
- p1_copy = relay.copy(p1)
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
- out = relay.Tuple((upsampled, p1_copy))
- f1 = relay.Function([p0, p1], out)
+ f1 = relay.Function([p0], upsampled)
x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
- z = relay.Call(f1, [y, x])
- return relay.Function([x], z)
+ z = relay.Call(f1, [y])
+ tup = relay.Tuple((z, x))
+ return relay.Function([x], tup)
dshape = (1, 16, 64, 64)
z = before(dshape)
assert relay.ir_pass.alpha_equal(zz, after)
-def test_tuple_strided_slice():
- """
- Test fusion case where the number of fields of tuple and
- the number of parameters to the function containing the tuple are different
- """
-
- def before(dshape):
- x = relay.var("x", shape=dshape)
- slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
- slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
- out = relay.Tuple((slice1, slice2))
- return relay.Function([x], out)
-
- def expected(dshape):
- x = relay.var("x", shape=dshape)
- slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
- slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
- out = relay.Tuple((slice1, slice2))
- f0 = relay.Function([x], out)
-
- x = relay.var("x", shape=dshape)
- y = relay.Call(f0, [x])
- return relay.Function([x], y)
-
- dshape = (64, 64)
- z = before(dshape)
- z = relay.ir_pass.infer_type(z)
- zz = relay.ir_pass.fuse_ops(z, opt_level=0)
- assert not relay.ir_pass.free_vars(zz)
- zz = relay.ir_pass.fuse_ops(z, opt_level=2)
- zz = relay.ir_pass.infer_type(zz)
- assert not relay.ir_pass.free_vars(zz)
- after = relay.ir_pass.infer_type(expected(dshape))
- assert relay.ir_pass.alpha_equal(zz, after)
-
def test_stop_fusion():
def before(dshape):
assert relay.ir_pass.alpha_equal(zz, after)
+def test_tuple_intermediate():
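+ """
+ Test fusion case where a tuple (concatenate) is an intermediate node
+ between injective ops, so the whole graph fuses into a single function
+ """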
+ def before(x):
+ inj = relay.squeeze(x)
+ y1 = relay.add(inj, relay.const(1, "float32"))
+ tmp = relay.squeeze(inj)
+ tmp = relay.add(tmp, relay.const(1, "float32"))
+ y2 = relay.add(tmp, relay.const(1, "float32"))
+ y3 = relay.add(inj, relay.const(1, "float32"))
+ concat = relay.concatenate((y1, y2, y3), axis=1)
+ out_inj = relay.squeeze(concat)
+ out = relay.add(out_inj, relay.const(1, "float32"))
+ return relay.Function(relay.ir_pass.free_vars(out), out)
+
+ def expected(p0):
+ f0 = before(p0)
+ x = relay.var("x", shape=dshape)
+ y = relay.Call(f0, [x])
+ return relay.Function([x], y)
+
+ dshape = (1, 16, 64, 64)
+ x = relay.var("x", shape=dshape)
+ z = before(x)
+ z = relay.ir_pass.infer_type(z)
+ zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+ assert not relay.ir_pass.free_vars(zz)
+ zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+ relay.build(zz, 'llvm')
+ zz = relay.ir_pass.infer_type(zz)
+ assert not relay.ir_pass.free_vars(zz)
+ after = relay.ir_pass.infer_type(expected(x))
+ assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_tuple_consecutive():
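+ """
+ Test fusion case with consecutive intermediate tuples (nested concatenates)
+ and a tuple as the output root
+ """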
+ def gen_intermediate_tuple(x):
+ y1 = relay.add(x, relay.const(1, "float32"))
+ y2 = relay.add(x, relay.const(1, "float32"))
+ y3 = relay.add(x, relay.const(1, "float32"))
+ concat = relay.concatenate((y1, y2, y3), axis=1)
+ out = relay.add(concat, relay.const(1, "float32"))
+ return out
+
+ def gen_consecutive_tuple(x):
+ y1 = gen_intermediate_tuple(x)
+ y2 = gen_intermediate_tuple(x)
+ y3 = gen_intermediate_tuple(x)
+ concat = relay.concatenate((y1, y2, y3), axis=1)
+ return concat
+
+ def before(x):
+ concat = gen_consecutive_tuple(x)
+ pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+ out = relay.add(pooled, relay.const(1, "float32"))
+ out2 = relay.add(out, relay.const(1, "float32"))
+ out_tup = relay.Tuple((out, out2))
+ return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup)
+
+ def expected(dshape):
+ p0 = relay.var("p0", shape=dshape)
+ concat = gen_consecutive_tuple(p0)
+ f0 = relay.Function([p0], concat)
+
+ p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
+ pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+ out = relay.add(pooled, relay.const(1, "float32"))
+ f1 = relay.Function([p01], out)
+
+ p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
+ out = relay.add(p02, relay.const(1, "float32"))
+ f2 = relay.Function([p02], out)
+
+ x = relay.var("x", shape=dshape)
+ y = relay.Call(f0, [x])
+ z = relay.Call(f1, [y])
+ z2 = relay.Call(f2, [z])
+
+ return relay.Function([x], relay.Tuple((z, z2)))
+
+ dshape = (1, 16, 64, 64)
+ x = relay.var("x", shape=dshape)
+ z = before(x)
+ z = relay.ir_pass.infer_type(z)
+ zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+ assert not relay.ir_pass.free_vars(zz)
+ zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+ relay.build(zz, 'llvm')
+ zz = relay.ir_pass.infer_type(zz)
+ assert not relay.ir_pass.free_vars(zz)
+ after = relay.ir_pass.infer_type(expected(dshape))
+ assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_inception_like():
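+ """
+ Test fusion on an inception-like graph where parallel conv branches are
+ joined by concatenate; the concatenates are fused separately from the convs
+ """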
+ def conv(data):
+ y = relay.nn.conv2d(data, relay.var("w"),
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ channels=16)
+ return relay.nn.relu(data=y)
+
+ def inception_like(data):
+ c0 = conv(data)
+ c1 = conv(data)
+ return relay.concatenate((c0, c1), axis=1)
+
+ def before(dshape):
+ x = relay.var("x", shape=dshape)
+ in1 = inception_like(x)
+ in2 = inception_like(in1)
+ return relay.Function(relay.ir_pass.free_vars(in2), in2)
+
+ def expected(dshape):
+ p0 = relay.var("p0", shape=dshape)
+ c = conv(p0)
+ f0 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+ p01 = relay.var("p01", shape=dshape)
+ c = conv(p01)
+ f1 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+ p02 = relay.var("p02", shape=dshape)
+ p12 = relay.var("p12", shape=dshape)
+ concat1 = relay.concatenate((p02, p12), axis=1)
+ f_concat1 = relay.Function([p02, p12], concat1)
+
+ dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])
+
+ p03 = relay.var("p03", shape=dshape2)
+ c = conv(p03)
+ f2 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+ p04 = relay.var("p04", shape=dshape2)
+ c = conv(p04)
+ f3 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+ p05 = relay.var("p05", shape=dshape)
+ p15 = relay.var("p15", shape=dshape)
+ concat2 = relay.concatenate((p05, p15), axis=1)
+ f_concat2 = relay.Function([p05, p15], concat2)
+
+ x = relay.var("x", shape=dshape)
+ c1 = relay.Call(f0, [x, relay.var("w1")])
+ c2 = relay.Call(f1, [x, relay.var("w2")])
+ concat = relay.Call(f_concat1, [c1, c2])
+ c3 = relay.Call(f2, [concat, relay.var("w3")])
+ c4 = relay.Call(f3, [concat, relay.var("w4")])
+ out = relay.Call(f_concat2, [c3, c4])
+
+ return relay.Function(relay.ir_pass.free_vars(out), out)
+
+ dshape = (1, 16, 64, 64)
+ z = before(dshape)
+ z = relay.ir_pass.infer_type(z)
+ zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+ assert not relay.ir_pass.free_vars(zz)
+ zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+ relay.build(zz, 'llvm')
+ zz = relay.ir_pass.infer_type(zz)
+ assert not relay.ir_pass.free_vars(zz)
+ after = relay.ir_pass.infer_type(expected(dshape))
+ assert relay.ir_pass.alpha_equal(zz, after)
+
+
if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_concatenate()
test_tuple_root()
- test_tuple_strided_slice()
test_stop_fusion()
test_fuse_myia_regression()
test_fuse_tuple_get_elemwise()
test_tuple_get_root()
+ test_tuple_intermediate()
+ test_tuple_consecutive()
+ test_inception_like()