[neurun] Introduce Graph::Iterator (#2387)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Tue, 21 Aug 2018 08:19:20 +0000 (17:19 +0900)
committer이춘석/동작제어Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Tue, 21 Aug 2018 08:19:20 +0000 (17:19 +0900)
This commit introduces `Graph::Iterator` class to support various
kinds of iterators. The template parameter `is_const` is used to
support both const and mutable iterator without writing same code
twice.

- Introduce `Graph::Iterator`
- Introduce `Graph::PostDfsIterator` class
- Remove `Graph::iteratePostDfs()` method

Resolve #2386

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
runtimes/neurun/src/graph/Graph.cc
runtimes/neurun/src/graph/Graph.h
runtimes/neurun/src/graph/operation/Set.cc
runtimes/neurun/src/graph/operation/Set.h
runtimes/neurun/src/linear/Linear.cc

index 577b271..6be8a67 100644 (file)
@@ -69,12 +69,37 @@ void Graph::lower(void)
   }
 }
 
-void Graph::iteratePostDfs(const std::function<void(const operation::Node &)> &fn) const
+std::unique_ptr<linear::Linear> Graph::linearize(void)
 {
-  std::vector<bool> visited(_operations.size(), false);
+  assert(_phase == Phase::LOWERED);
+
+  auto linear = nnfw::make_unique<linear::Linear>(*this);
+
+  // TODO Move the operations and operands to linear object
+
+  _phase = Phase::LINEARIZED;
+
+  return std::move(linear);
+}
+
+} // namespace graph
+} // namespace neurun
 
-  std::function<void(const operation::Index &index, const operation::Node &)> dfs_recursive =
-      [&](const operation::Index &index, const operation::Node &node) -> void {
+namespace neurun
+{
+namespace graph
+{
+
+template class Graph::PostDfsIterator<true>;
+template class Graph::PostDfsIterator<false>;
+
+template <bool is_const>
+void Graph::PostDfsIterator<is_const>::iterate(GraphRef graph, const IterFn &fn) const
+{
+  std::vector<bool> visited(graph._operations.size(), false);
+
+  std::function<void(const operation::Index &, NodeRef)> dfs_recursive =
+      [&](const operation::Index &index, NodeRef node) -> void {
     if (visited[index.asInt()])
       return;
     visited[index.asInt()] = true;
@@ -85,40 +110,26 @@ void Graph::iteratePostDfs(const std::function<void(const operation::Node &)> &f
       // TODO Fix traversing algorithm
       //      Every time need to search for operations that has `outgoing` as incoming from all
       //      operations but we can hold that info cached
-      _operations.iterate(
-          [&](const operation::Index &cand_index, const operation::Node &cand_node) {
-            auto inputs = cand_node.inputs();
-            for (auto input : inputs.list())
-            {
-              if (output == input)
-              {
-                dfs_recursive(cand_index, cand_node);
-              }
-            }
-          });
+      graph._operations.iterate([&](const operation::Index &cand_index, NodeRef cand_node) -> void {
+        auto inputs = cand_node.inputs();
+        for (auto input : inputs.list())
+        {
+          if (output == input)
+          {
+            dfs_recursive(cand_index, cand_node);
+          }
+        }
+      });
     }
 
     fn(node);
   };
 
-  _operations.iterate(dfs_recursive);
+  graph._operations.iterate(dfs_recursive);
 
   // All of the operations(nodes) must have been visited.
   assert(std::all_of(visited.begin(), visited.end(), [](bool v) { return v; }));
 }
 
-std::unique_ptr<linear::Linear> Graph::linearize(void)
-{
-  assert(_phase == Phase::LOWERED);
-
-  auto linear = nnfw::make_unique<linear::Linear>(*this);
-
-  // TODO Move the operations and operands to linear object
-
-  _phase = Phase::LINEARIZED;
-
-  return std::move(linear);
-}
-
 } // namespace graph
 } // namespace neurun
index 42321f8..6592d64 100644 (file)
@@ -33,6 +33,32 @@ private:
   };
 
 public:
+  template <bool is_const> class Iterator
+  {
+  public:
+    using GraphRef = typename std::conditional<is_const, const Graph &, Graph &>::type;
+    using NodeRef =
+        typename std::conditional<is_const, const operation::Node &, operation::Node &>::type;
+    using IterFn = std::function<void(NodeRef)>;
+
+  public:
+    virtual ~Iterator() = default;
+    virtual void iterate(GraphRef graph, const IterFn &fn) const = 0;
+  };
+
+  template <bool is_const = false> class PostDfsIterator final : public Iterator<is_const>
+  {
+  public:
+    using GraphRef = typename Iterator<is_const>::GraphRef;
+    using NodeRef = typename Iterator<is_const>::NodeRef;
+    using IterFn = typename Iterator<is_const>::IterFn;
+
+  public:
+    void iterate(GraphRef graph, const IterFn &fn) const;
+  };
+  using PostDfsConstIterator = PostDfsIterator<true>;
+
+public:
   Graph(void) = default;
 
   // Graph Building
@@ -56,10 +82,6 @@ public:
   operand::Set &operands() { return _operands; } // TODO Remove this non-const accessor
   const operation::Set &operations() const { return _operations; }
 
-public:
-  // TODO Introduce Iterator class to support many kinds of interation
-  void iteratePostDfs(const std::function<void(const operation::Node &)> &fn) const;
-
 private:
   Phase _phase{Phase::BUILDING};
   operation::Set _operations;
index 5e1a2b0..0b44cf8 100644 (file)
@@ -30,6 +30,14 @@ void Set::iterate(const std::function<void(const Index &, const Node &)> &fn) co
   }
 }
 
+void Set::iterate(const std::function<void(const Index &, Node &)> &fn)
+{
+  for (uint32_t index = 0; index < _nodes.size(); index++)
+  {
+    fn(Index{index}, *_nodes[index]);
+  }
+}
+
 } // namespace operation
 } // namespace graph
 } // namespace neurun
index d0bc536..dfcb871 100644 (file)
@@ -28,6 +28,7 @@ public:
   bool exist(const Index &) const;
   uint32_t size() const { return _nodes.size(); }
   void iterate(const std::function<void(const Index &, const Node &)> &fn) const;
+  void iterate(const std::function<void(const Index &, Node &)> &fn);
 
 private:
   std::vector<std::unique_ptr<Node>> _nodes;
index 73c1fb2..d8be62e 100644 (file)
@@ -20,10 +20,11 @@ Linear::Linear(const graph::Graph &graph)
   //   2. Append the node to vector when DFS for the node finishes(post order)
   //   3. Reverse the order of nodes
 
-  graph.iteratePostDfs([&](const neurun::graph::operation::Node &node) {
-    auto op = node.op();
-    _operations.emplace_back(op);
-  });
+  graph::Graph::PostDfsConstIterator().iterate(graph,
+                                               [&](const neurun::graph::operation::Node &node) {
+                                                 auto op = node.op();
+                                                 _operations.emplace_back(op);
+                                               });
 
   std::reverse(std::begin(_operations), std::end(_operations));
 }