[loco] Introduce push_node/pull_node helper (#6684)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 20 Aug 2019 02:29:19 +0000 (11:29 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 20 Aug 2019 02:29:19 +0000 (11:29 +0900)
This commit introduces push_node/pull_node helper which serves as an
alternative to "node"  method in GraphInput/GraphOutput class.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/loco/include/loco/IR/Nodes.h
compiler/loco/src/IR/Nodes.cpp
compiler/loco/src/loco.test.cpp

index 572f5c1..d86ed56 100644 (file)
@@ -37,6 +37,7 @@
 namespace loco
 {
 
+class Graph;
 class GraphInput;
 class GraphOutput;
 
@@ -86,6 +87,9 @@ private:
 
 void link(GraphOutput *, Push *push);
 
+/// @brief Find a Push node with a given output index
+Push *push_node(Graph *g, const GraphOutputIndex &index);
+
 /**
  * @brief Create a value from user data
  */
@@ -133,6 +137,9 @@ private:
 
 void link(GraphInput *, Pull *pull);
 
+/// @brief Find a Pull node with a given input index
+Pull *pull_node(Graph *g, const GraphInputIndex &index);
+
 /**
  * @brief Create a new value identical to its input
  *
index 5b0492c..a7006a7 100644 (file)
@@ -77,6 +77,21 @@ void link(GraphOutput *output, Push *push)
   push->index(output->index());
 }
 
+Push *push_node(Graph *g, const GraphOutputIndex &index)
+{
+  for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+  {
+    if (auto push = dynamic_cast<Push *>(g->nodes()->at(n)))
+    {
+      if (push->indexed() && push->index() == index)
+      {
+        return push;
+      }
+    }
+  }
+  return nullptr;
+}
+
 } // namespace loco
 
 /**
@@ -161,6 +176,21 @@ void link(GraphInput *input, Pull *pull)
   pull->index(input->index());
 }
 
+Pull *pull_node(Graph *g, const GraphInputIndex &index)
+{
+  for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+  {
+    if (auto pull = dynamic_cast<Pull *>(g->nodes()->at(n)))
+    {
+      if (pull->indexed() && pull->index() == index)
+      {
+        return pull;
+      }
+    }
+  }
+  return nullptr;
+}
+
 } // namespace loco
 
 /**
index 42a2921..4c4f51a 100644 (file)
@@ -102,4 +102,7 @@ TEST(LOCO, identity_network_V2)
   ASSERT_EQ(pull_node->dtype(), loco::DataType::FLOAT32);
   // TODO Check Shape of pull_node
   // TODO Check Shape of push_node
+
+  ASSERT_EQ(loco::pull_node(g.get(), 0), pull_node);
+  ASSERT_EQ(loco::push_node(g.get(), 0), push_node);
 }