From 5190a104aa7496450c1a311d0d8a533e00a4a237 Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Tue, 23 Nov 2021 22:10:05 +0900 Subject: [PATCH] [graph] implement setOutputConnections THis patch implement setOutputConnections. Now, every connection has defined place to be, we can pin point where the connection has to go. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- nntrainer/graph/network_graph.cpp | 35 +++++++++++------------------------ nntrainer/graph/network_graph.h | 2 +- nntrainer/layers/layer_node.cpp | 9 +++++++++ nntrainer/layers/layer_node.h | 14 +++++++------- 4 files changed, 28 insertions(+), 32 deletions(-) diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index b4decc2..b508879 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -46,7 +46,7 @@ int NetworkGraph::compile(const std::string &loss_type) { NN_RETURN_STATUS(); try { - setOutputLayers(); + setOutputConnections(); } catch (std::exception &e) { ml_loge("setting output layer failed, reason: %s", e.what()); return ML_ERROR_INVALID_PARAMETER; @@ -164,29 +164,16 @@ int NetworkGraph::addLossLayer(const std::string &loss_type_) { return ML_ERROR_NONE; } -void NetworkGraph::setOutputLayers() { - for (auto iter_idx = cbegin(); iter_idx != cend(); iter_idx++) { - auto &layer_idx = *iter_idx; - for (auto iter_i = cbegin(); iter_i != cend(); iter_i++) { - auto &layer_i = *iter_i; - if (istrequal(layer_i->getName(), layer_idx->getName())) - continue; - for (unsigned int j = 0; j < layer_i->getNumInputConnections(); ++j) { - if (istrequal(layer_i->getInputLayers()[j], layer_idx->getName())) { - bool already_exist = false; - for (unsigned int k = 0; k < layer_idx->getNumOutputConnections(); - ++k) { - if (istrequal(layer_idx->getOutputLayers()[k], - layer_i->getName())) { - already_exist = true; - break; - } - } - - if (!already_exist) - layer_idx->addOutputLayers(layer_i->getName()); - } - } +void NetworkGraph::setOutputConnections() { + for (auto layer_iter = cbegin(); layer_iter != cend(); layer_iter++) { + const auto &node = *layer_iter; + for (auto i = 0u, num_inode = node->getNumInputConnections(); i < num_inode; + ++i) { + const auto &name = node->getInputConnectionName(i); + const auto &idx = node->getInputConnectionIndex(i); + + auto node_setting_output = getLayerNode(name); + node_setting_output->setOutputConnection(idx, node->getName(), i); } } } diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h index 8ba7b88..23a13de 100644 --- a/nntrainer/graph/network_graph.h +++ b/nntrainer/graph/network_graph.h @@ -413,7 +413,7 @@ private: /** * @brief set output connections for all the layers */ - void setOutputLayers(); + void setOutputConnections(); /** * @brief Ensure that layer has a name. diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp index f809338..84be39d 100644 --- a/nntrainer/layers/layer_node.cpp +++ b/nntrainer/layers/layer_node.cpp @@ -222,6 +222,15 @@ void LayerNode::setInputConnectionName(unsigned nth, const std::string &name) { input_layers.at(nth).get().getName() = name; } +void LayerNode::setOutputConnection(unsigned nth, const std::string &name, + unsigned index) { + if (nth >= output_layers.size()) { + output_layers.resize(nth + 1); + } + + output_layers[nth] = std::make_unique(name, index); +} + const std::string LayerNode::getName() const noexcept { auto &name = std::get(*layer_node_props); return name.empty() ? "" : name.get(); diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h index d801b36..2bda7fa 100644 --- a/nntrainer/layers/layer_node.h +++ b/nntrainer/layers/layer_node.h @@ -411,13 +411,6 @@ public: } /** - * @brief Get the Input Layers object - * - * @return const std::vector - */ - const std::vector getInputLayers() const; - - /** * @brief Get the Output Layers object * * @return const std::vector @@ -749,6 +742,13 @@ public: bool needsCalcGradient() { return needs_calc_gradient; } private: + /** + * @brief Get the Input Layers object + * + * @return const std::vector + */ + const std::vector getInputLayers() const; + std::unique_ptr layer; /**< The actual object in the graph node */ -- 2.7.4