[loco] Introduce input_nodes/output_nodes helper (#3428)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 10 May 2019 06:41:54 +0000 (15:41 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 10 May 2019 06:41:54 +0000 (15:41 +0900)
* [loco] Introduce input_nodes/output_nodes helper

This commit introduces input_nodes/output_nodes helper functions which
allow users to easily enumerate graph-level input/output nodes.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Update test name

contrib/loco/include/loco/IR/Graph.h
contrib/loco/src/IR/Graph.cpp
contrib/loco/src/IR/Graph.test.cpp

index 440d8e9..a7a18d4 100644 (file)
@@ -23,6 +23,7 @@
 
 #include <string>
 #include <memory>
+#include <vector>
 
 namespace loco
 {
@@ -157,6 +158,10 @@ private:
   OutputContext _output_ctx;
 };
 
+// TODO Use "const Graph *"
+std::vector<Node *> input_nodes(Graph *);
+std::vector<Node *> output_nodes(Graph *);
+
 std::unique_ptr<Graph> make_graph(void);
 
 } // namespace loco
index 3f8979e..43156fa 100644 (file)
@@ -28,6 +28,32 @@ Graph::NodeContext::~NodeContext()
   }
 }
 
+std::vector<loco::Node *> input_nodes(loco::Graph *g)
+{
+  std::vector<loco::Node *> res;
+
+  for (uint32_t n = 0; n < g->inputs()->size(); ++n)
+  {
+    auto node = g->inputs()->at(n)->node();
+    res.emplace_back(node);
+  }
+
+  return res;
+}
+
+std::vector<loco::Node *> output_nodes(loco::Graph *g)
+{
+  std::vector<loco::Node *> res;
+
+  for (uint32_t n = 0; n < g->outputs()->size(); ++n)
+  {
+    auto node = g->outputs()->at(n)->node();
+    res.emplace_back(node);
+  }
+
+  return res;
+}
+
 std::unique_ptr<Graph> make_graph(void) { return std::unique_ptr<Graph>{new Graph}; }
 
 } // namespace loco
index 61d0419..f583813 100644 (file)
@@ -42,3 +42,36 @@ TEST(NamedTest, setter_and_getter)
   elem.name("name");
   ASSERT_EQ(elem.name(), "name");
 }
+
+TEST(GraphTest, graph_inout_enumeration)
+{
+  auto g = loco::make_graph();
+
+  std::vector<loco::Pull *> pull_nodes;
+
+  auto pull_1 = g->nodes()->create<loco::Pull>();
+  auto pull_2 = g->nodes()->create<loco::Pull>();
+  auto pull_3 = g->nodes()->create<loco::Pull>();
+
+  auto push_1 = g->nodes()->create<loco::Push>();
+  auto push_2 = g->nodes()->create<loco::Push>();
+  auto push_3 = g->nodes()->create<loco::Push>();
+
+  g->inputs()->create()->node(pull_2);
+  g->inputs()->create()->node(pull_1);
+
+  g->outputs()->create()->node(push_1);
+  g->outputs()->create()->node(push_3);
+
+  auto input_nodes = loco::input_nodes(g.get());
+
+  ASSERT_EQ(input_nodes.size(), 2);
+  ASSERT_EQ(input_nodes.at(0), g->inputs()->at(0)->node());
+  ASSERT_EQ(input_nodes.at(1), g->inputs()->at(1)->node());
+
+  auto output_nodes = loco::output_nodes(g.get());
+
+  ASSERT_EQ(output_nodes.size(), 2);
+  ASSERT_EQ(output_nodes.at(0), g->outputs()->at(0)->node());
+  ASSERT_EQ(output_nodes.at(1), g->outputs()->at(1)->node());
+}