namespace loco
{
+class Graph;
class GraphInput;
class GraphOutput;
void link(GraphOutput *, Push *push);
+/// @brief Find a Push node with a given output index
+Push *push_node(Graph *g, const GraphOutputIndex &index);
+
/**
* @brief Create a value from user data
*/
void link(GraphInput *, Pull *pull);
+/// @brief Find a Pull node with a given input index
+Pull *pull_node(Graph *g, const GraphInputIndex &index);
+
/**
* @brief Create a new value identical to its input
*
push->index(output->index());
}
+Push *push_node(Graph *g, const GraphOutputIndex &index)
+{
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ if (auto push = dynamic_cast<Push *>(g->nodes()->at(n)))
+ {
+ if (push->indexed() && push->index() == index)
+ {
+ return push;
+ }
+ }
+ }
+ return nullptr;
+}
+
} // namespace loco
/**
pull->index(input->index());
}
+Pull *pull_node(Graph *g, const GraphInputIndex &index)
+{
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ if (auto pull = dynamic_cast<Pull *>(g->nodes()->at(n)))
+ {
+ if (pull->indexed() && pull->index() == index)
+ {
+ return pull;
+ }
+ }
+ }
+ return nullptr;
+}
+
} // namespace loco
/**
ASSERT_EQ(pull_node->dtype(), loco::DataType::FLOAT32);
// TODO Check Shape of pull_node
// TODO Check Shape of push_node
+
+ ASSERT_EQ(loco::pull_node(g.get(), 0), pull_node);
+ ASSERT_EQ(loco::push_node(g.get(), 0), push_node);
}