[loco] Visit each node only once (#3378)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 29 Apr 2019 01:49:03 +0000 (10:49 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 29 Apr 2019 01:49:03 +0000 (10:49 +0900)
The current implementation of postorder_traversal may visits the same
node multiple times.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/loco/src/IR/Algorithm.cpp
contrib/loco/src/IR/Algorithm.test.cpp

index 87e92a2..1211552 100644 (file)
@@ -17,6 +17,7 @@
 #include "loco/IR/Algorithm.h"
 
 #include <cassert>
+#include <set>
 #include <stack>
 
 namespace
@@ -53,8 +54,13 @@ std::vector<loco::Node *> postorder_traversal(const std::vector<loco::Node *> &r
 {
   std::vector<loco::Node *> res;
 
+  std::set<loco::Node *> visited_nodes;
   std::stack<Frame> frames;
 
+  auto visited = [&visited_nodes](loco::Node *node) {
+    return visited_nodes.find(node) != visited_nodes.end();
+  };
+
   // NOTE There is not much difference between "auto" and "auto &" as node is of "loco::Node *"
   // type.
   for (auto node : roots)
@@ -66,6 +72,16 @@ std::vector<loco::Node *> postorder_traversal(const std::vector<loco::Node *> &r
   {
     auto &top_frame = frames.top();
 
+    if (top_frame.pos() == -1)
+    {
+      if (visited(top_frame.ptr()))
+      {
+        frames.pop();
+        continue;
+      }
+      visited_nodes.insert(top_frame.ptr());
+    }
+
     top_frame.advance();
 
     assert(top_frame.pos() >= 0);
index d616dcf..eb1835a 100644 (file)
 #include "loco/IR/Algorithm.h"
 #include "loco/IR/Graph.h"
 
+#include <algorithm>
+
 #include <gtest/gtest.h>
 
+namespace
+{
+
+bool contains(const std::vector<loco::Node *> &vec, loco::Node *val)
+{
+  return std::any_of(vec.begin(), vec.end(), [val](loco::Node *node) { return node == val; });
+}
+
+} // namespace
+
 TEST(AlgorithmTest, postorder_traversal)
 {
   auto g = loco::make_graph();
@@ -37,3 +49,28 @@ TEST(AlgorithmTest, postorder_traversal)
   ASSERT_EQ(seq.at(0), pull_1);
   ASSERT_EQ(seq.at(1), push);
 }
+
+TEST(AlgorithmTest, postorder_traversal_visit_once)
+{
+  auto g = loco::make_graph();
+
+  // Create a network of the following form:
+  //
+  //   Push1  Push2 <-- outputs
+  //    \     /
+  //     Pull  <-- input
+  //
+  auto pull = g->nodes()->create<loco::Pull>();
+  auto push_1 = g->nodes()->create<loco::Push>();
+  auto push_2 = g->nodes()->create<loco::Push>();
+
+  push_1->from(pull);
+  push_2->from(pull);
+
+  auto seq = loco::postorder_traversal({push_1, push_2});
+
+  ASSERT_EQ(seq.size(), 3);
+  ASSERT_TRUE(contains(seq, pull));
+  ASSERT_TRUE(contains(seq, push_1));
+  ASSERT_TRUE(contains(seq, push_2));
+}