From ef1f91edad5b3f2efb8993d2f06e5b2bd3caffe9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 20 Aug 2019 11:29:19 +0900 Subject: [PATCH] [loco] Introduce push_node/pull_node helper (#6684) This commit introduces push_node/pull_node helper which serves as an alternative to "node" method in GraphInput/GraphOutput class. Signed-off-by: Jonghyun Park --- compiler/loco/include/loco/IR/Nodes.h | 7 +++++++ compiler/loco/src/IR/Nodes.cpp | 30 ++++++++++++++++++++++++++++++ compiler/loco/src/loco.test.cpp | 3 +++ 3 files changed, 40 insertions(+) diff --git a/compiler/loco/include/loco/IR/Nodes.h b/compiler/loco/include/loco/IR/Nodes.h index 572f5c1..d86ed56 100644 --- a/compiler/loco/include/loco/IR/Nodes.h +++ b/compiler/loco/include/loco/IR/Nodes.h @@ -37,6 +37,7 @@ namespace loco { +class Graph; class GraphInput; class GraphOutput; @@ -86,6 +87,9 @@ private: void link(GraphOutput *, Push *push); +/// @brief Find a Push node with a given output index +Push *push_node(Graph *g, const GraphOutputIndex &index); + /** * @brief Create a value from user data */ @@ -133,6 +137,9 @@ private: void link(GraphInput *, Pull *pull); +/// @brief Find a Pull node with a given input index +Pull *pull_node(Graph *g, const GraphInputIndex &index); + /** * @brief Create a new value identical to its input * diff --git a/compiler/loco/src/IR/Nodes.cpp b/compiler/loco/src/IR/Nodes.cpp index 5b0492c..a7006a7 100644 --- a/compiler/loco/src/IR/Nodes.cpp +++ b/compiler/loco/src/IR/Nodes.cpp @@ -77,6 +77,21 @@ void link(GraphOutput *output, Push *push) push->index(output->index()); } +Push *push_node(Graph *g, const GraphOutputIndex &index) +{ + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + if (auto push = dynamic_cast(g->nodes()->at(n))) + { + if (push->indexed() && push->index() == index) + { + return push; + } + } + } + return nullptr; +} + } // namespace loco /** @@ -161,6 +176,21 @@ void link(GraphInput *input, Pull *pull) pull->index(input->index()); } +Pull *pull_node(Graph *g, const GraphInputIndex &index) +{ + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + if (auto pull = dynamic_cast(g->nodes()->at(n))) + { + if (pull->indexed() && pull->index() == index) + { + return pull; + } + } + } + return nullptr; +} + } // namespace loco /** diff --git a/compiler/loco/src/loco.test.cpp b/compiler/loco/src/loco.test.cpp index 42a2921..4c4f51a 100644 --- a/compiler/loco/src/loco.test.cpp +++ b/compiler/loco/src/loco.test.cpp @@ -102,4 +102,7 @@ TEST(LOCO, identity_network_V2) ASSERT_EQ(pull_node->dtype(), loco::DataType::FLOAT32); // TODO Check Shape of pull_node // TODO Check Shape of push_node + + ASSERT_EQ(loco::pull_node(g.get(), 0), pull_node); + ASSERT_EQ(loco::push_node(g.get(), 0), push_node); } -- 2.7.4