[neurun] Enhance DFS implementation (#2561)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Tue, 4 Sep 2018 04:07:38 +0000 (13:07 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 4 Sep 2018 04:07:38 +0000 (13:07 +0900)
Now that we have operand def/use info, we can do DFS efficiently.
This is applied to PostDfsIterator. This might also be applied to
DAGVerifier, but operand def/use info is built when Graph building is
finished.

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
runtimes/neurun/src/graph/Graph.cc
runtimes/neurun/src/graph/Graph.h

index 04e08ff..ef97c84 100644 (file)
@@ -188,6 +188,8 @@ void Graph::DefaultIterator<is_const>::iterate(GraphRef graph, const IterFn &fn)
 template <bool is_const>
 void Graph::PostDfsIterator<is_const>::iterate(GraphRef graph, const IterFn &fn) const
 {
+  assert(!graph.isBuildingPhase()); // Restrict iteration condition
+
   std::vector<bool> visited(graph._operations.size(), false);
 
   std::function<void(const operation::Index &, NodeRef)> dfs_recursive =
@@ -196,22 +198,13 @@ void Graph::PostDfsIterator<is_const>::iterate(GraphRef graph, const IterFn &fn)
       return;
     visited[index.asInt()] = true;
 
-    auto outputs = node.getOutputs();
-    for (auto output : outputs)
+    for (auto output : node.getOutputs())
     {
-      // TODO Fix traversing algorithm
-      //      Every time need to search for operations that has `outgoing` as incoming from all
-      //      operations but we can hold that info cached
-      graph._operations.iterate([&](const operation::Index &cand_index, NodeRef cand_node) -> void {
-        auto inputs = cand_node.getInputs();
-        for (auto input : inputs)
-        {
-          if (output == input)
-          {
-            dfs_recursive(cand_index, cand_node);
-          }
-        }
-      });
+      const auto &operand = graph._operands.at(output);
+      for (const auto &use : operand.getUses().list())
+      {
+        dfs_recursive(use, graph._operations.at(use));
+      }
     }
 
     fn(node);
index 1e217cf..c9d1e38 100644 (file)
@@ -86,7 +86,7 @@ public:
   void finishBuilding(void);
   void lower(void);
   std::unique_ptr<linear::Linear> linearize(void);
-  bool isBuildingPhase(void) { return _phase == Phase::BUILDING; }
+  bool isBuildingPhase(void) const { return _phase == Phase::BUILDING; }
 
 private:
   void initializeUseDef();