From 99cf819be0f1fb2399dd052236fb92d5e88ece64 Mon Sep 17 00:00:00 2001 From: hyeonseok lee Date: Fri, 25 Jun 2021 15:39:53 +0900 Subject: [PATCH] [graph_core] Implement input_list, output_list for multi input, output - Make input_list, output_list and its getter to support multi input, output Signed-off-by: hyeonseok lee --- nntrainer/graph/graph_core.h | 32 ++++++++++++++++++++++++++++++++ nntrainer/graph/network_graph.cpp | 13 ++----------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/nntrainer/graph/graph_core.h b/nntrainer/graph/graph_core.h index 261aa74..1c7f928 100644 --- a/nntrainer/graph/graph_core.h +++ b/nntrainer/graph/graph_core.h @@ -206,6 +206,36 @@ public: void addLossToSorted(); /** + * @brief getter of graph input nodes with index number + * @param idx + * @return graph node of input node + */ + const std::shared_ptr &getInputNode(unsigned int idx) const { + return input_list[idx]; + } + + /** + * @brief getter of number of input nodes + * @return number of input nodes + */ + unsigned int getNumInputNodes() const { return input_list.size(); } + + /** + * @brief getter of graph output nodes with index number + * @param idx + * @return graph node of output node + */ + const std::shared_ptr &getOutputNode(unsigned int idx) const { + return output_list[idx]; + } + + /** + * @brief getter of number of output nodes + * @return number of output nodes + */ + unsigned int getNumOutputNodes() const { return output_list.size(); } + + /** * @brief Verify if the node exists */ inline bool verifyNode(const std::string &name) { @@ -215,6 +245,8 @@ public: } private: + std::vector> input_list; + std::vector> output_list; std::vector> node_list; /**< Unordered Node List */ std::vector> Sorted; /**< Ordered Node List */ diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index 2303b67..2405f22 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -185,7 +185,7 @@ int NetworkGraph::realizeActivationType( in_node->setProperty({"activation=none"}); lnode->setInputLayers({in_node->getName()}); - /** output layers for layer aobj will be set in setOutputLayers() */ + /** output layers for layer obj will be set in setOutputLayers() */ updateConnectionName(in_node->getName(), lnode->getName()); graph.addNode(lnode, false); @@ -308,7 +308,6 @@ int NetworkGraph::addLossLayer(const std::string &loss_type) { void NetworkGraph::setOutputLayers() { - size_t last_layer_count = 0; for (auto iter_idx = cbegin(); iter_idx != cend(); iter_idx++) { auto &layer_idx = *iter_idx; for (auto iter_i = cbegin(); iter_i != cend(); iter_i++) { @@ -332,15 +331,6 @@ void NetworkGraph::setOutputLayers() { } } } - - if (layer_idx->getNumOutputConnections() == 0) { - last_layer_count += 1; - } - } - - if (last_layer_count != 1) { - throw std::invalid_argument( - "Error: Multiple last layers in the model not supported"); } } @@ -439,6 +429,7 @@ int NetworkGraph::realizeGraph() { } } /// @todo add check that input_layers <-> output_layers does match. + /// @todo check whether graph has a cycle or graph is seperated to subgraph return status; } -- 2.7.4