From: 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 Date: Tue, 20 Aug 2019 02:29:19 +0000 (+0900) Subject: [loco] Introduce push_node/pull_node helper (#6684) X-Git-Tag: accepted/tizen/unified/20190903.052428~316 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ef1f91edad5b3f2efb8993d2f06e5b2bd3caffe9;p=platform%2Fcore%2Fml%2Fnnfw.git [loco] Introduce push_node/pull_node helper (#6684) This commit introduces push_node/pull_node helper which serves as an alternative to "node" method in GraphInput/GraphOutput class. Signed-off-by: Jonghyun Park --- diff --git a/compiler/loco/include/loco/IR/Nodes.h b/compiler/loco/include/loco/IR/Nodes.h index 572f5c1..d86ed56 100644 --- a/compiler/loco/include/loco/IR/Nodes.h +++ b/compiler/loco/include/loco/IR/Nodes.h @@ -37,6 +37,7 @@ namespace loco { +class Graph; class GraphInput; class GraphOutput; @@ -86,6 +87,9 @@ private: void link(GraphOutput *, Push *push); +/// @brief Find a Push node with a given output index +Push *push_node(Graph *g, const GraphOutputIndex &index); + /** * @brief Create a value from user data */ @@ -133,6 +137,9 @@ private: void link(GraphInput *, Pull *pull); +/// @brief Find a Pull node with a given input index +Pull *pull_node(Graph *g, const GraphInputIndex &index); + /** * @brief Create a new value identical to its input * diff --git a/compiler/loco/src/IR/Nodes.cpp b/compiler/loco/src/IR/Nodes.cpp index 5b0492c..a7006a7 100644 --- a/compiler/loco/src/IR/Nodes.cpp +++ b/compiler/loco/src/IR/Nodes.cpp @@ -77,6 +77,21 @@ void link(GraphOutput *output, Push *push) push->index(output->index()); } +Push *push_node(Graph *g, const GraphOutputIndex &index) +{ + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + if (auto push = dynamic_cast(g->nodes()->at(n))) + { + if (push->indexed() && push->index() == index) + { + return push; + } + } + } + return nullptr; +} + } // namespace loco /** @@ -161,6 +176,21 @@ void link(GraphInput *input, Pull *pull) pull->index(input->index()); } +Pull *pull_node(Graph *g, const GraphInputIndex &index) +{ + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + if (auto pull = dynamic_cast(g->nodes()->at(n))) + { + if (pull->indexed() && pull->index() == index) + { + return pull; + } + } + } + return nullptr; +} + } // namespace loco /** diff --git a/compiler/loco/src/loco.test.cpp b/compiler/loco/src/loco.test.cpp index 42a2921..4c4f51a 100644 --- a/compiler/loco/src/loco.test.cpp +++ b/compiler/loco/src/loco.test.cpp @@ -102,4 +102,7 @@ TEST(LOCO, identity_network_V2) ASSERT_EQ(pull_node->dtype(), loco::DataType::FLOAT32); // TODO Check Shape of pull_node // TODO Check Shape of push_node + + ASSERT_EQ(loco::pull_node(g.get(), 0), pull_node); + ASSERT_EQ(loco::push_node(g.get(), 0), push_node); }