[Relay, OpFusion] Better tuple fusion implementation (#3092)
authormasahi <masahi129@gmail.com>
Mon, 29 Apr 2019 02:18:41 +0000 (11:18 +0900)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 29 Apr 2019 02:18:41 +0000 (19:18 -0700)
include/tvm/relay/op_attr_types.h
python/tvm/relay/op/op.py
src/relay/pass/fuse_ops.cc
tests/python/relay/test_backend_compile_engine.py
tests/python/relay/test_pass_fuse_ops.py

index 464bc1c..ca7f6e5 100644 (file)
@@ -49,6 +49,9 @@ enum OpPatternKind {
   // Complex operation, can still fuse elemwise operations into its output.
   // but cannot chain another complex op
   kOutEWiseFusable = 4,
+  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
+  // but treated specially
+  kTuple = 7,
   // Opaque operation, cannot fuse anything.
   kOpaque = 8
 };
index 6312f02..6ba2079 100644 (file)
@@ -112,6 +112,8 @@ class OpPattern(object):
     COMM_REDUCE = 3
     # Complex op, can still fuse ewise into it
     OUT_ELEMWISE_FUSABLE = 4
+    # Represents tuple node
+    TUPLE = 7
     # Not fusable opaque op
     OPAQUE = 8
 
index 12e3174..55d6098 100644 (file)
@@ -267,7 +267,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   void VisitExpr_(const TupleNode* op) final {
     CHECK(graph_.node_map.count(op));
     Node* tuple_node = graph_.node_map.at(op);
-    tuple_node->pattern = kInjective;
+    tuple_node->pattern = kTuple;
     for (const Expr& field : op->fields) {
       if (field->checked_type().as<TensorTypeNode>()) {
         this->Update(field, tuple_node, kInjective);
@@ -661,12 +661,36 @@ class GraphPartitioner {
       // no actions needed if the current node have no dominator
       if (dom_node->parent == nullptr) continue;
       CHECK(!graph_node->extern_ref);
-      // Skip if current node is already fused to the parent.
       size_t dom_parent_gindex = dom_node->parent->gnode->index;
+
+      if (phase == 2) {
+        // Fuse injective ops into intermediate tuples, if any
+        if (group_node->pattern > kInjective) continue;
+        Group* dom_parent_group = groups_[dom_parent_gindex];
+        Group* dom_root_group = dom_parent_group->FindRoot();
+        // If dom node group has a tuple as its root, we do not fuse tuple fields into it
+        if (dom_root_group->pattern == kTuple) continue;
+        if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
+          // Now we know the tuple has been fused into subsequent injective ops
+          auto fcond = [](OpPatternKind kind, bool is_sink) {
+            return kind <= kInjective;
+          };
+          // dom_root_group can also be tuple, as in inception layers
+          // CheckPath is needed to avoid fusing two intermediate tuples
+          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+            CommitFuse(graph_node, dom_node->parent->gnode);
+          }
+        }
+        continue;
+      }
+
+      // Skip if current node is already fused to the parent.
       if (groups_[dom_parent_gindex] != nullptr &&
           group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
         continue;
       }
+      // Do not fuse into tuple for now
+      if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
       // Try to fuse current node to its post-dominator.
       if (group_node->pattern == kOutEWiseFusable) {
         if (phase != 0) continue;
@@ -702,7 +726,7 @@ class GraphPartitioner {
             CommitFuse(graph_node, dom_node->parent->gnode);
           }
         }
-      } else if (group_node->pattern == kInjective) {
+      } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
         // defer injective fusion to second phase.
         // so conv2d always finishes fusing.
         if (phase != 1) continue;
@@ -728,7 +752,7 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
   // get post dominator tree
   auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
   // run fusion algorithm.
-  for (int phase = 0; phase < 2; ++phase) {
+  for (int phase = 0; phase < 3; ++phase) {
     this->RunFuse(graph, post_dom_tree, phase);
   }
   return std::move(groups_);
@@ -821,29 +845,11 @@ class FuseMutator : private ExprMutator {
 
   Expr VisitExpr_(const TupleNode* tuple) {
     auto* ret_group = gmap_.at(tuple)->FindRoot();
-    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
     if (ret_group == gmap_.at(tuple)) {
-      // This tuple is the root of its group. Check if all fields come from other groups.
-      bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
-      for (size_t i = 0; i < new_fields.size() && isolated; ++i) {
-        isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
-      }
-      if (isolated) {
-        // Do not put a isolated tuple into a function
-        return ExprMutator::VisitExpr_(tuple);
-      }
-      // This tuple has been fused with other ops before it
-      for (size_t i = 0; i < new_fields.size(); i++) {
-        // Copy function arguments to tuple field of the output because currently graph memory
-        // planer doesn't support inplace operations
-        if (new_fields[i].as<VarNode>()) {
-          auto copy = Copy(new_fields[i]);
-          new_fields.Set(i, copy);
-        }
-      }
-      return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
+      return ExprMutator::VisitExpr_(tuple);
     }
     // This tuple is an intermediate node in the group
+    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
     return TupleNode::make(new_fields);
   }
 
index 3b479b8..ca4619c 100644 (file)
@@ -69,8 +69,16 @@ def test_compile_injective_with_tuple():
     relay.build(func, 'llvm')
 
 
+def test_compile_tuple_dup():
+    x = relay.var("data", shape=(16, 16))
+    log = relay.log(x)
+    output = relay.Tuple([log, log])
+    f = relay.Function([x], output)
+    relay.build(f, 'llvm')
+
+
 if __name__ == "__main__":
     test_compile_engine()
     test_compile_placeholder_bypass()
     test_compile_injective_with_tuple()
-
+    test_compile_tuple_dup()
index baafbee..bdffdf7 100644 (file)
@@ -176,16 +176,14 @@ def test_tuple_root():
         f0 = relay.Function([x], pooled)
 
         p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
-        p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
-        p1_copy = relay.copy(p1)
         upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
-        out = relay.Tuple((upsampled, p1_copy))
-        f1 = relay.Function([p0, p1], out)
+        f1 = relay.Function([p0], upsampled)
 
         x = relay.var("x", shape=dshape)
         y = relay.Call(f0, [x])
-        z = relay.Call(f1, [y, x])
-        return relay.Function([x], z)
+        z = relay.Call(f1, [y])
+        tup = relay.Tuple((z, x))
+        return relay.Function([x], tup)
 
     dshape = (1, 16, 64, 64)
     z = before(dshape)
@@ -199,41 +197,6 @@ def test_tuple_root():
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
-def test_tuple_strided_slice():
-    """
-    Test fusion case where the number of fields of tuple and
-    the number of parameters to the function containing the tuple are different
-    """
-
-    def before(dshape):
-        x = relay.var("x", shape=dshape)
-        slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
-        slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
-        out = relay.Tuple((slice1, slice2))
-        return relay.Function([x], out)
-
-    def expected(dshape):
-        x = relay.var("x", shape=dshape)
-        slice1 = relay.strided_slice(x, begin=[0, 0], end=[dshape[1]//2, dshape[1]], strides=[1,1])
-        slice2 = relay.strided_slice(x, begin=[dshape[1]//2, 0], end=[dshape[0], dshape[1]], strides=[1,1])
-        out = relay.Tuple((slice1, slice2))
-        f0 = relay.Function([x], out)
-
-        x = relay.var("x", shape=dshape)
-        y = relay.Call(f0, [x])
-        return relay.Function([x], y)
-
-    dshape = (64, 64)
-    z = before(dshape)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
-    after = relay.ir_pass.infer_type(expected(dshape))
-    assert relay.ir_pass.alpha_equal(zz, after)
-
 
 def test_stop_fusion():
     def before(dshape):
@@ -377,13 +340,178 @@ def test_tuple_get_root():
     assert relay.ir_pass.alpha_equal(zz, after)
 
 
+def test_tuple_intermediate():
+    def before(x):
+        inj = relay.squeeze(x)
+        y1 = relay.add(inj, relay.const(1, "float32"))
+        tmp = relay.squeeze(inj)
+        tmp = relay.add(tmp, relay.const(1, "float32"))
+        y2 = relay.add(tmp, relay.const(1, "float32"))
+        y3 = relay.add(inj, relay.const(1, "float32"))
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        out_inj = relay.squeeze(concat)
+        out = relay.add(out_inj, relay.const(1, "float32"))
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    def expected(p0):
+        f0 = before(p0)
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        return relay.Function([x], y)
+
+    dshape = (1, 16, 64, 64)
+    x = relay.var("x", shape=dshape)
+    z = before(x)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(x))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_tuple_consecutive():
+    def gen_intermediate_tuple(x):
+        y1 = relay.add(x, relay.const(1, "float32"))
+        y2 = relay.add(x, relay.const(1, "float32"))
+        y3 = relay.add(x, relay.const(1, "float32"))
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        out = relay.add(concat, relay.const(1, "float32"))
+        return out
+
+    def gen_consecutive_tuple(x):
+        y1 = gen_intermediate_tuple(x)
+        y2 = gen_intermediate_tuple(x)
+        y3 = gen_intermediate_tuple(x)
+        concat = relay.concatenate((y1, y2, y3), axis=1)
+        return concat
+
+    def before(x):
+        concat = gen_consecutive_tuple(x)
+        pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        out = relay.add(pooled, relay.const(1, "float32"))
+        out2 = relay.add(out, relay.const(1, "float32"))
+        out_tup = relay.Tuple((out, out2))
+        return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup)
+
+    def expected(dshape):
+        p0 = relay.var("p0", shape=dshape)
+        concat = gen_consecutive_tuple(p0)
+        f0 = relay.Function([p0], concat)
+
+        p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
+        pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        out = relay.add(pooled, relay.const(1, "float32"))
+        f1 = relay.Function([p01], out)
+
+        p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
+        out = relay.add(p02, relay.const(1, "float32"))
+        f2 = relay.Function([p02], out)
+
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f0, [x])
+        z = relay.Call(f1, [y])
+        z2 = relay.Call(f2, [z])
+
+        return relay.Function([x], relay.Tuple((z, z2)))
+
+    dshape = (1, 16, 64, 64)
+    x = relay.var("x", shape=dshape)
+    z = before(x)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+def test_inception_like():
+    def conv(data):
+        y = relay.nn.conv2d(data, relay.var("w"),
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            channels=16)
+        return relay.nn.relu(data=y)
+
+    def inception_like(data):
+        c0 = conv(data)
+        c1 = conv(data)
+        return relay.concatenate((c0, c1), axis=1)
+
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        in1 = inception_like(x)
+        in2 = inception_like(in1)
+        return relay.Function(relay.ir_pass.free_vars(in2), in2)
+
+    def expected(dshape):
+        p0 = relay.var("p0", shape=dshape)
+        c = conv(p0)
+        f0 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p01 = relay.var("p01", shape=dshape)
+        c = conv(p01)
+        f1 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p02 = relay.var("p02", shape=dshape)
+        p12 = relay.var("p12", shape=dshape)
+        concat1 = relay.concatenate((p02, p12), axis=1)
+        f_concat1 = relay.Function([p02, p12], concat1)
+
+        dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])
+
+        p03 = relay.var("p03", shape=dshape2)
+        c = conv(p03)
+        f2 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p04 = relay.var("p04", shape=dshape2)
+        c = conv(p04)
+        f3 = relay.Function(relay.ir_pass.free_vars(c), c)
+
+        p05 = relay.var("p05", shape=dshape)
+        p15 = relay.var("p15", shape=dshape)
+        concat2 = relay.concatenate((p05, p15), axis=1)
+        f_concat2 = relay.Function([p05, p15], concat2)
+
+        x = relay.var("x", shape=dshape)
+        c1 = relay.Call(f0, [x, relay.var("w1")])
+        c2 = relay.Call(f1, [x, relay.var("w2")])
+        concat = relay.Call(f_concat1, [c1, c2])
+        c3 = relay.Call(f2, [concat, relay.var("w3")])
+        c4 = relay.Call(f3, [concat, relay.var("w4")])
+        out = relay.Call(f_concat2, [c3, c4])
+
+        return relay.Function(relay.ir_pass.free_vars(out), out)
+
+    dshape = (1, 16, 64, 64)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    relay.build(zz, 'llvm')
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
     test_concatenate()
     test_tuple_root()
-    test_tuple_strided_slice()
     test_stop_fusion()
     test_fuse_myia_regression()
     test_fuse_tuple_get_elemwise()
     test_tuple_get_root()
+    test_tuple_intermediate()
+    test_tuple_consecutive()
+    test_inception_like()