From: 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 Date: Fri, 10 May 2019 06:41:54 +0000 (+0900) Subject: [loco] Introduce input_nodes/output_nodes helper (#3428) X-Git-Tag: nncc_backup~626 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e80da89079e826efe055bd16debec0b84b6aa4db;p=platform%2Fcore%2Fml%2Fnnfw.git [loco] Introduce input_nodes/output_nodes helper (#3428) * [loco] Introduce input_nodes/output_nodes helper This commit introduces input_nodes/output_nodes helper functions which allow users to easily enumerate graph-level input/output nodes. Signed-off-by: Jonghyun Park * Update test name --- diff --git a/contrib/loco/include/loco/IR/Graph.h b/contrib/loco/include/loco/IR/Graph.h index 440d8e9..a7a18d4 100644 --- a/contrib/loco/include/loco/IR/Graph.h +++ b/contrib/loco/include/loco/IR/Graph.h @@ -23,6 +23,7 @@ #include #include +#include namespace loco { @@ -157,6 +158,10 @@ private: OutputContext _output_ctx; }; +// TODO Use "const Graph *" +std::vector input_nodes(Graph *); +std::vector output_nodes(Graph *); + std::unique_ptr make_graph(void); } // namespace loco diff --git a/contrib/loco/src/IR/Graph.cpp b/contrib/loco/src/IR/Graph.cpp index 3f8979e..43156fa 100644 --- a/contrib/loco/src/IR/Graph.cpp +++ b/contrib/loco/src/IR/Graph.cpp @@ -28,6 +28,32 @@ Graph::NodeContext::~NodeContext() } } +std::vector input_nodes(loco::Graph *g) +{ + std::vector res; + + for (uint32_t n = 0; n < g->inputs()->size(); ++n) + { + auto node = g->inputs()->at(n)->node(); + res.emplace_back(node); + } + + return res; +} + +std::vector output_nodes(loco::Graph *g) +{ + std::vector res; + + for (uint32_t n = 0; n < g->outputs()->size(); ++n) + { + auto node = g->outputs()->at(n)->node(); + res.emplace_back(node); + } + + return res; +} + std::unique_ptr make_graph(void) { return std::unique_ptr{new Graph}; } } // namespace loco diff --git a/contrib/loco/src/IR/Graph.test.cpp b/contrib/loco/src/IR/Graph.test.cpp index 61d0419..f583813 100644 --- a/contrib/loco/src/IR/Graph.test.cpp +++ b/contrib/loco/src/IR/Graph.test.cpp @@ -42,3 +42,36 @@ TEST(NamedTest, setter_and_getter) elem.name("name"); ASSERT_EQ(elem.name(), "name"); } + +TEST(GraphTest, graph_inout_enumeration) +{ + auto g = loco::make_graph(); + + std::vector pull_nodes; + + auto pull_1 = g->nodes()->create(); + auto pull_2 = g->nodes()->create(); + auto pull_3 = g->nodes()->create(); + + auto push_1 = g->nodes()->create(); + auto push_2 = g->nodes()->create(); + auto push_3 = g->nodes()->create(); + + g->inputs()->create()->node(pull_2); + g->inputs()->create()->node(pull_1); + + g->outputs()->create()->node(push_1); + g->outputs()->create()->node(push_3); + + auto input_nodes = loco::input_nodes(g.get()); + + ASSERT_EQ(input_nodes.size(), 2); + ASSERT_EQ(input_nodes.at(0), g->inputs()->at(0)->node()); + ASSERT_EQ(input_nodes.at(1), g->inputs()->at(1)->node()); + + auto output_nodes = loco::output_nodes(g.get()); + + ASSERT_EQ(output_nodes.size(), 2); + ASSERT_EQ(output_nodes.at(0), g->outputs()->at(0)->node()); + ASSERT_EQ(output_nodes.at(1), g->outputs()->at(1)->node()); +}