[Relay] Fix operator fusion for multiple output (#3871)
author雾雨魔理沙 <lolisa@marisa.moe>
Thu, 5 Sep 2019 21:39:13 +0000 (14:39 -0700)
committermasahi <masahi129@gmail.com>
Thu, 5 Sep 2019 21:39:13 +0000 (06:39 +0900)
* save

* add test

* refactor

* fix indent

* save

* refactor

src/relay/ir/pretty_printer.cc
src/relay/pass/fuse_ops.cc
tests/python/relay/test_pass_fuse_ops.py

index 0ee76dc..5197414 100644 (file)
@@ -304,14 +304,16 @@ class PrettyPrinter :
    * \return The corresponding name.
    */
   Doc AllocTypeVar(const TypeVar& var) {
+    if (memo_type_.count(var)) {
+      Doc val = memo_type_[var];
+      val << "-malformed-ir";
+      return val;
+    }
     std::string name = var->var->name_hint;
     if (name.length() == 0 || !std::isalpha(name[0])) {
       name = "t" + name;
     }
     Doc val = GetUniqueName("%" + name);
-    if (memo_type_.count(var)) {
-      val << "-malformed-ir";
-    }
     memo_type_[var] = val;
     if (var->kind != kType) {
       val << ": " << Print(var->kind);
@@ -325,16 +327,18 @@ class PrettyPrinter :
    * \return The corresponding name.
    */
   Doc AllocVar(const Var& var) {
+    // still print if ir is malformed, but show the error.
+    if (memo_.count(var)) {
+      Doc val = memo_[var];
+      val << "-malformed-ir";
+      return val;
+    }
     std::string name = var->name_hint();
     // always make sure first name is alpha
     if (name.length() == 0 || !std::isalpha(name[0])) {
       name = "v" + name;
     }
     Doc val = GetUniqueName("%" + name);
-    // still print if ir is malformed, but show the error.
-    if (memo_.count(var)) {
-      val << "-malformed-ir";
-    }
     memo_[var] = val;
     if (var->type_annotation.defined()) {
       val << ": " << Print(var->type_annotation);
index 9dc180f..b5faf4c 100644 (file)
@@ -18,7 +18,7 @@
  */
 
 /*!
- * Copyright (c) 2018 by Contributors
+ * Copyright (c) 2019 by Contributors
  *
  * \file src/tvm/relay/pass/fuse_ops.cc
  *
@@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     node->pattern = op_pattern;
     this->Update(call->op, nullptr, kOpaque);
     const auto* rtype = call->checked_type().as<TensorTypeNode>();
-    // pass the message back to all the children it references.
+    // pass the analysis back to all the children it references.
     for (size_t i = 0; i < call->args.size(); ++i) {
       const auto* arg_type =
           call->args[i]->checked_type().as<TensorTypeNode>();
-      // specifically check if result type
+      // specifically check if result type is the same as arguments type
       OpPatternKind edge_pattern = op_pattern;
       if (edge_pattern == kBroadcast &&
           arg_type != nullptr &&
@@ -403,12 +403,12 @@ class DominatorTree {
     return rhs;
   }
   /*!
-   * \brief Find the least common acenstor of the two nodes.
+   * \brief Find the least common ancestor of the two nodes.
    * \param lhs The left node.
    * \param rhs The right node.
    * \param edge_pattern
    *        The combined edge pattern across all the parents.
-   * \return The least common ancestor of thw two.
+   * \return The least common ancestor of the two.
    */
   static Node* LeastCommonAncestor(
       Node* lhs,
@@ -436,17 +436,43 @@ class DominatorTree {
     }
     return lhs;
   }
-};
-
-DominatorTree DominatorTree::PostDom(common::Arena* arena,
-                                     const IndexedForwardGraph& graph) {
-  DominatorTree tree;
-  tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
-  // reverse topo order
-  for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
-    size_t index = i - 1;
+  /*!
+   * \brief Find the least common ancestor of a list of nodes.
+   * \param nodes the nodes.
+   * \param edge_pattern
+   *        The combined edge pattern across all the parents.
+   * \return The least common ancestor of all nodes.
+   */
+  Node* LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& input_nodes,
+                            OpPatternKind* edge_pattern) {
+    auto link = input_nodes.head;
+    if (link == nullptr) {
+      return nullptr;
+    }
+    auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
+      size_t oindex = edge.node->index;
+      CHECK_LT(oindex, nodes.size());
+      Node* onode = nodes[oindex];
+      CHECK(onode != nullptr);
+      return onode;
+    };
+    Node* parent = get_node(link->value);
+    *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
+    link = link->next;
+    for (; link != nullptr; link = link->next) {
+      parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
+      *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
+    }
+    return parent;
+  }
+  /*!
+   * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node.
+   * \param arena The Arena.
+   * \param gnode An IndexedForwardGraph Node.
+   * \return The DominatorTree Node.
+   */
+  Node* GetNode(common::Arena* arena, IndexedForwardGraph::Node* gnode) {
     Node* tnode = arena->make<Node>();
-    auto* gnode = graph.post_dfs_order[index];
     tnode->gnode = gnode;
     if (gnode->extern_ref) {
       tnode->depth = 1;
@@ -455,24 +481,24 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena,
     } else {
       // find the LCAs of all outputs.
       OpPatternKind pattern = kElemWise;
-      Node* parent = nullptr;
-      for (auto link = gnode->outputs.head; link != nullptr; link= link->next) {
-        size_t oindex = link->value.node->index;
-        CHECK_LT(oindex, tree.nodes.size());
-        Node* onode = tree.nodes[oindex];
-        CHECK(onode != nullptr);
-        if (parent != nullptr) {
-          parent = LeastCommonAncestor(parent, onode, &pattern);
-        } else {
-          parent = onode;
-        }
-        pattern = CombinePattern(pattern, link->value.pattern);
-      }
+      Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
       tnode->depth = parent ? parent->depth + 1 : 1;
       tnode->parent = parent;
       tnode->pattern = pattern;
     }
-    tree.nodes[index] = tnode;
+    return tnode;
+  }
+};
+
+
+DominatorTree DominatorTree::PostDom(common::Arena* arena,
+                                     const IndexedForwardGraph& graph) {
+  DominatorTree tree;
+  tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
+  // reverse topo order
+  for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
+    size_t index = i - 1;
+    tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
   }
   return tree;
 }
@@ -614,7 +640,7 @@ class GraphPartitioner {
     // merge the current group to the parent if possible.
     MergeFromTo(gnode, target);
     for (auto link = src->outputs.head; link != nullptr; link = link->next) {
-      CommitFuse_(link->value.node, sink, target);;
+      CommitFuse_(link->value.node, sink, target);
     }
   }
   /*!
@@ -851,7 +877,7 @@ class FuseMutator : private ExprMutator {
 
   Expr VisitExpr_(const TupleNode* tuple) {
     auto* ret_group = gmap_.at(tuple)->FindRoot();
-    if (ret_group == gmap_.at(tuple)) {
+    if (ret_group->root_ref == tuple) {
       return ExprMutator::VisitExpr_(tuple);
     }
     // This tuple is an intermediate node in the group
@@ -863,7 +889,7 @@ class FuseMutator : private ExprMutator {
     auto* ret_group = gmap_.at(tuple_get)->FindRoot();
     auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
     auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
-    if (ret_group == gmap_.at(tuple_get)) {
+    if (ret_group->root_ref == tuple_get) {
       if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
         // Isolated. This case occurs when tuple is created by an Opaque op
         // e.g. multibox_transform_loc
@@ -922,45 +948,8 @@ class FuseMutator : private ExprMutator {
   }
 };
 
-// Temporary solution, should be handled by implementing a "FunctionPass"
-// which applies fusion to each function.
-struct GlobalVarLiveness : ExprVisitor {
-  Module module;
-  std::set<GlobalVar> visited;
-
-  explicit GlobalVarLiveness(const Module& mod) : module(mod), visited() {}
-
-  void VisitExpr_(const GlobalVarNode* gvar_node) {
-    auto gvar = GetRef<GlobalVar>(gvar_node);
-    if (visited.find(gvar) == visited.end()) {
-      visited.insert(gvar);
-      this->VisitExpr(this->module->Lookup(gvar));
-    }
-  }
-};
-
-std::set<GlobalVar> LiveGlobals(const Module& mod, const Expr& expr) {
-  auto gvl = GlobalVarLiveness(mod);
-  gvl.VisitExpr(expr);
-  return gvl.visited;
-}
-
 Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
-  // First we convert all chains of fusable ops into
-  // abstracted functions which we mark as primtive
-  // then we convert these primtive functions into
-  // new operators.
-  if (!module.defined()) {
-    return FuseMutator().Transform(expr, fuse_opt_level);
-  } else {
-    auto lgvs = LiveGlobals(module, expr);
-    for (auto lv : lgvs) {
-      auto body = module->Lookup(lv);
-      auto e = FuseMutator().Transform(body, fuse_opt_level);
-      module->Add(lv, Downcast<Function>(e), true);
-    }
-    return FuseMutator().Transform(expr, fuse_opt_level);
-  }
+  return FuseMutator().Transform(expr, fuse_opt_level);
 }
 
 namespace transform {
index 4c03840..f148502 100644 (file)
@@ -541,6 +541,18 @@ def test_immutable():
     assert relay.analysis.alpha_equal(new_mod, expected())
 
 
+def test_split():
+    """Test that the result is well formed."""
+    x = relay.var("x", shape=(6, 9))
+    y = relay.split(x, 3).astuple()
+    a = relay.TupleGetItem(y, 0)
+    b = relay.TupleGetItem(y, 1)
+    c = relay.TupleGetItem(y, 2)
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
+    mod = transform.FuseOps()(mod)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -555,3 +567,4 @@ if __name__ == "__main__":
     test_inception_like()
     test_fuse_parallel_injective()
     test_immutable()
+    test_split()