Ghost nodes in NNVM graph (#3290)
authorPrzemyslaw Tredak <ptredak@nvidia.com>
Wed, 5 Jun 2019 23:27:16 +0000 (16:27 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 5 Jun 2019 23:27:16 +0000 (16:27 -0700)
nnvm/include/nnvm/op_attr_types.h
nnvm/src/core/graph.cc

index bcc8247..470b8e8 100644 (file)
@@ -137,6 +137,17 @@ using FInferType = FInferNodeEntryAttr<int>;
 using TIsBackward = bool;
 
 /*!
+ * \brief Whether this op is a ghost node.
+ * If TIsGhost is true:
+ *   - The node with this op will not be visible in the indexed graph.
+ *
+ * \note Register under "TIsGhost"
+ * This enables shape/type inference for backward nodes when
+ * fusion is present.
+ */
+using TIsGhost = bool;
+
+/*!
  * \brief Get possible inplace options.
  *  This function enables optimization to reuse memory of inputs in output.
  * \param attrs The attributes of the node
index 92ff986..29149f4 100644 (file)
@@ -76,6 +76,8 @@ IndexedGraph::IndexedGraph(const Graph &g) {
 
   DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
              (const NodePtr& n) {
+      const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");
+      if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
       CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
       uint32_t nid = static_cast<uint32_t>(nodes_.size());
       CHECK(n);
@@ -103,6 +105,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
       inputs_rptr.push_back(input_entries_.size());
       // control deps
       for (const auto& nptr : n->control_deps) {
+        if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue;
         auto it = node2index_.find(nptr.get());
         CHECK(it != node2index_.end() && it->first == nptr.get());
         control_deps_.push_back(it->second);