From 0147cbd92256938876eb4465606ccd59e32e4db5 Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Tue, 23 Nov 2021 19:19:05 +0900 Subject: [PATCH] [props] Extract connection This patch extract connection from common_properties to have more room to handle connections freely **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- Applications/Custom/LayerClient/jni/Android.mk | 1 + Applications/LogisticRegression/jni/Android.mk | 1 + Applications/ProductRatings/jni/Android.mk | 1 + .../ReinforcementLearning/DeepQ/jni/Android.mk | 1 + Applications/Resnet/jni/Android.mk | 1 + .../CIFAR_Classification/jni/Android.mk | 1 + .../Draw_Classification/jni/Android.mk | 1 + Applications/VGG/jni/Android.mk | 1 + debian/nntrainer-dev.install | 1 + jni/Android.mk | 2 + nntrainer/compiler/multiout_realizer.cpp | 12 +- nntrainer/graph/connection.cpp | 56 ++++++++++ nntrainer/graph/connection.h | 121 +++++++++++++++++++++ nntrainer/graph/meson.build | 5 +- nntrainer/layers/common_properties.cpp | 39 +------ nntrainer/layers/common_properties.h | 87 +-------------- nntrainer/layers/layer_node.cpp | 56 ++++++---- nntrainer/layers/layer_node.h | 22 +++- packaging/nntrainer.spec | 1 + test/unittest/unittest_common_properties.cpp | 5 +- 20 files changed, 258 insertions(+), 157 deletions(-) create mode 100644 nntrainer/graph/connection.cpp create mode 100644 nntrainer/graph/connection.h diff --git a/Applications/Custom/LayerClient/jni/Android.mk b/Applications/Custom/LayerClient/jni/Android.mk index 2a6330a..56dd428 100644 --- a/Applications/Custom/LayerClient/jni/Android.mk +++ b/Applications/Custom/LayerClient/jni/Android.mk @@ -17,6 +17,7 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \ $(NNTRAINER_ROOT)/nntrainer/dataset \ $(NNTRAINER_ROOT)/nntrainer/models \ $(NNTRAINER_ROOT)/nntrainer/layers \ + $(NNTRAINER_ROOT)/nntrainer/compiler \ $(NNTRAINER_ROOT)/nntrainer/graph \ $(NNTRAINER_ROOT)/nntrainer/utils \ $(NNTRAINER_ROOT)/nntrainer/optimizers \ diff --git a/Applications/LogisticRegression/jni/Android.mk b/Applications/LogisticRegression/jni/Android.mk index cdda9ea..56e58b6 100644 --- a/Applications/LogisticRegression/jni/Android.mk +++ b/Applications/LogisticRegression/jni/Android.mk @@ -16,6 +16,7 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \ $(NNTRAINER_ROOT)/nntrainer/dataset \ $(NNTRAINER_ROOT)/nntrainer/models \ $(NNTRAINER_ROOT)/nntrainer/layers \ + $(NNTRAINER_ROOT)/nntrainer/compiler \ $(NNTRAINER_ROOT)/nntrainer/graph \ $(NNTRAINER_ROOT)/nntrainer/optimizers \ $(NNTRAINER_ROOT)/nntrainer/tensor \ diff --git a/Applications/ProductRatings/jni/Android.mk b/Applications/ProductRatings/jni/Android.mk index e349988..7a475d6 100644 --- a/Applications/ProductRatings/jni/Android.mk +++ b/Applications/ProductRatings/jni/Android.mk @@ -20,6 +20,7 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer/include \ $(NNTRAINER_ROOT)/nntrainer \ $(NNTRAINER_ROOT)/nntrainer/models \ $(NNTRAINER_ROOT)/nntrainer/layers \ + $(NNTRAINER_ROOT)/nntrainer/compiler \ $(NNTRAINER_ROOT)/nntrainer/graph \ $(NNTRAINER_ROOT)/nntrainer/utils \ $(NNTRAINER_ROOT)/nntrainer/optimizers \ diff --git a/Applications/ReinforcementLearning/DeepQ/jni/Android.mk b/Applications/ReinforcementLearning/DeepQ/jni/Android.mk index 7d74e0e..67173ad 100644 --- a/Applications/ReinforcementLearning/DeepQ/jni/Android.mk +++ b/Applications/ReinforcementLearning/DeepQ/jni/Android.mk @@ -16,6 +16,7 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \ $(NNTRAINER_ROOT)/nntrainer/dataset \ $(NNTRAINER_ROOT)/nntrainer/models \ $(NNTRAINER_ROOT)/nntrainer/layers \ + $(NNTRAINER_ROOT)/nntrainer/compiler \ $(NNTRAINER_ROOT)/nntrainer/graph \ $(NNTRAINER_ROOT)/nntrainer/optimizers \ $(NNTRAINER_ROOT)/nntrainer/tensor \ diff --git a/Applications/Resnet/jni/Android.mk b/Applications/Resnet/jni/Android.mk index 1d11aa4..c122a3a 100644 --- a/Applications/Resnet/jni/Android.mk +++ b/Applications/Resnet/jni/Android.mk @@ -16,6 +16,7 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \ $(NNTRAINER_ROOT)/nntrainer/dataset \ $(NNTRAINER_ROOT)/nntrainer/models \ $(NNTRAINER_ROOT)/nntrainer/layers \ + $(NNTRAINER_ROOT)/nntrainer/compiler \ $(NNTRAINER_ROOT)/nntrainer/graph \ $(NNTRAINER_ROOT)/nntrainer/optimizers \ $(NNTRAINER_ROOT)/nntrainer/tensor \ diff --git a/Applications/TransferLearning/CIFAR_Classification/jni/Android.mk b/Applications/TransferLearning/CIFAR_Classification/jni/Android.mk index 98f7da4..22be25c 100644 --- a/Applications/TransferLearning/CIFAR_Classification/jni/Android.mk +++ b/Applications/TransferLearning/CIFAR_Classification/jni/Android.mk @@ -17,6 +17,7 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \ $(NNTRAINER_ROOT)/nntrainer/models \ $(NNTRAINER_ROOT)/nntrainer/graph \ $(NNTRAINER_ROOT)/nntrainer/layers \ + $(NNTRAINER_ROOT)/nntrainer/compiler \ $(NNTRAINER_ROOT)/nntrainer/optimizers \ $(NNTRAINER_ROOT)/nntrainer/tensor \ $(NNTRAINER_ROOT)/nntrainer/utils \ diff --git a/Applications/TransferLearning/Draw_Classification/jni/Android.mk b/Applications/TransferLearning/Draw_Classification/jni/Android.mk index 3c82957..9e933db 100644 --- a/Applications/TransferLearning/Draw_Classification/jni/Android.mk +++ b/Applications/TransferLearning/Draw_Classification/jni/Android.mk @@ -14,6 +14,7 @@ endif NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \ $(NNTRAINER_ROOT)/nntrainer/dataset \ $(NNTRAINER_ROOT)/nntrainer/layers \ + $(NNTRAINER_ROOT)/nntrainer/compiler \ $(NNTRAINER_ROOT)/nntrainer/models \ $(NNTRAINER_ROOT)/nntrainer/graph \ $(NNTRAINER_ROOT)/nntrainer/tensor \ diff --git a/Applications/VGG/jni/Android.mk b/Applications/VGG/jni/Android.mk index 01354a4..f44e65c 100644 --- a/Applications/VGG/jni/Android.mk +++ b/Applications/VGG/jni/Android.mk @@ -17,6 +17,7 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \ $(NNTRAINER_ROOT)/nntrainer/models \ $(NNTRAINER_ROOT)/nntrainer/graph \ $(NNTRAINER_ROOT)/nntrainer/layers \ + $(NNTRAINER_ROOT)/nntrainer/compiler \ $(NNTRAINER_ROOT)/nntrainer/optimizers \ $(NNTRAINER_ROOT)/nntrainer/tensor \ $(NNTRAINER_ROOT)/nntrainer/utils \ diff --git a/debian/nntrainer-dev.install b/debian/nntrainer-dev.install index 4c8031e..cca6b4a 100644 --- a/debian/nntrainer-dev.install +++ b/debian/nntrainer-dev.install @@ -1,5 +1,6 @@ # node exporter and its dependencies /usr/include/nntrainer/nntrainer_error.h +/usr/include/nntrainer/connection.h /usr/include/nntrainer/common_properties.h /usr/include/nntrainer/base_properties.h /usr/include/nntrainer/node_exporter.h diff --git a/jni/Android.mk b/jni/Android.mk index 16a19cd..9db746c 100644 --- a/jni/Android.mk +++ b/jni/Android.mk @@ -187,6 +187,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \ $(NNTRAINER_ROOT)/nntrainer/layers/reduce_mean_layer.cpp \ $(NNTRAINER_ROOT)/nntrainer/graph/network_graph.cpp \ $(NNTRAINER_ROOT)/nntrainer/graph/graph_core.cpp \ + $(NNTRAINER_ROOT)/nntrainer/graph/connection.cpp \ $(NNTRAINER_ROOT)/nntrainer/optimizers/optimizer_context.cpp \ $(NNTRAINER_ROOT)/nntrainer/optimizers/optimizer_devel.cpp \ $(NNTRAINER_ROOT)/nntrainer/optimizers/optimizer_impl.cpp \ @@ -275,6 +276,7 @@ CCAPI_NNTRAINER_SRCS := $(NNTRAINER_ROOT)/api/ccapi/src/factory.cpp CCAPI_NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \ $(NNTRAINER_ROOT)/nntrainer/dataset \ + $(NNTRAINER_ROOT)/nntrainer/compiler \ $(NNTRAINER_ROOT)/nntrainer/layers \ $(NNTRAINER_ROOT)/nntrainer/models \ $(NNTRAINER_ROOT)/nntrainer/tensor \ diff --git a/nntrainer/compiler/multiout_realizer.cpp b/nntrainer/compiler/multiout_realizer.cpp index fd9d98e..b079949 100644 --- a/nntrainer/compiler/multiout_realizer.cpp +++ b/nntrainer/compiler/multiout_realizer.cpp @@ -39,8 +39,8 @@ MultioutRealizer::realize(const GraphRepresentation &reference) { for (unsigned int i = 0, num_nodes = node->getNumInputConnections(); i < num_nodes; ++i) { - props::InputConnection c(props::Connection( - node->getInputConnectionName(i), node->getInputConnectionIndex(i))); + Connection c(node->getInputConnectionName(i), + node->getInputConnectionIndex(i)); auto uniq_name = to_string(c); [[maybe_unused]] auto [iter, result] = freq_map.try_emplace(uniq_name, 0); iter->second++; @@ -54,20 +54,18 @@ MultioutRealizer::realize(const GraphRepresentation &reference) { multiout_nodes; for (auto &[con_name, freq] : freq_map) { - props::InputConnection con; - from_string(con_name, con); - /// @note freq < 1 should never happen as the map entry is not created. /// but if it happens multiout realizer is not interested in checking if it /// is a dangled or actually an output. So there is no assurance done at /// this point. Some other class must check if the given graph is formed in /// a correct way. + Connection con(con_name); if (freq <= 1) { continue; } - std::string id = con.get().getName(); - auto idx = con.get().getIndex(); + std::string id = con.getName(); + auto idx = con.getIndex(); std::stringstream ss; /// {connection_name}/generated_out_{index} diff --git a/nntrainer/graph/connection.cpp b/nntrainer/graph/connection.cpp new file mode 100644 index 0000000..853c56a --- /dev/null +++ b/nntrainer/graph/connection.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2021 Jihoon Lee + * + * @file connection.cpp + * @date 23 Nov 2021 + * @see https://github.com/nnstreamer/nntrainer + * @author Jihoon Lee + * @bug No known bugs except for NYI items + * @brief Connection class and related utility functions + */ +#include + +#include + +namespace {} + +namespace nntrainer { +Connection::Connection(const std::string &layer_name, unsigned int idx) : + index(idx), + name(props::Name(layer_name).get()) {} + +Connection::Connection(const std::string &string_representation) { + auto &sr = string_representation; + auto pos = sr.find_first_of('('); + auto idx = 0u; + auto name_part = sr.substr(0, pos); + + if (pos != std::string::npos) { + NNTR_THROW_IF(sr.back() != ')', std::invalid_argument) + << "failed to parse connection invalid format: " << sr; + + auto idx_part = sr.substr(pos + 1, sr.length() - 1); + idx = str_converter::from_string(idx_part); + } + + index = idx; + name = props::Name(name_part); +} + +Connection::Connection(const Connection &rhs) = default; +Connection &Connection::operator=(const Connection &rhs) = default; +Connection::Connection(Connection &&rhs) noexcept = default; +Connection &Connection::operator=(Connection &&rhs) noexcept = default; + +bool Connection::operator==(const Connection &rhs) const noexcept { + return index == rhs.index and name == rhs.name; +} + +std::string Connection::toString() const { + std::stringstream ss; + ss << getName() << '(' << getIndex() << ')'; + return ss.str(); +} + +}; // namespace nntrainer diff --git a/nntrainer/graph/connection.h b/nntrainer/graph/connection.h new file mode 100644 index 0000000..160ddd0 --- /dev/null +++ b/nntrainer/graph/connection.h @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2021 Jihoon Lee + * + * @file connection.h + * @date 23 Nov 2020 + * @see https://github.com/nnstreamer/nntrainer + * @author Jihoon Lee + * @bug No known bugs except for NYI items + * @brief Connection class and related utility functions + */ +#ifndef __CONNECTION_H__ +#define __CONNECTION_H__ + +#include + +namespace nntrainer { +/** + * @brief RAII class to define a connection + * + */ +class Connection { +public: + /** + * @brief Construct a new Connection object + * + * @param layer_name layer identifier + */ + Connection(const std::string &layer_name, unsigned int idx); + + /** + * @brief Construct a new Connection object from string representation + * string representation is format of {layer_name, idx}; + * + * @param string_representation string format of {layer_name}({idx}) + */ + explicit Connection(const std::string &string_representation); + + /** + * @brief Construct a new Connection object + * + * @param rhs rhs to copy + */ + Connection(const Connection &rhs); + + /** + * @brief Copy assignment operator + * + * @param rhs rhs to copy + * @return Connection& + */ + Connection &operator=(const Connection &rhs); + + /** + * @brief Move Construct Connection object + * + * @param rhs rhs to move + */ + Connection(Connection &&rhs) noexcept; + + /** + * @brief Move assign a connection operator + * + * @param rhs rhs to move + * @return Connection& + */ + Connection &operator=(Connection &&rhs) noexcept; + + /** + * @brief string representation of connection + * + * @return std::string string format of {name}({idx}) + */ + std::string toString() const; + + /** + * @brief Get the index + * + * @return unsigned index + */ + const unsigned getIndex() const { return index; } + + /** + * @brief Get the index + * + * @return unsigned index + */ + unsigned &getIndex() { return index; } + + /** + * @brief Get the Layer name object + * + * @return const Name& name of layer + */ + const std::string &getName() const { return name; } + + /** + * @brief Get the Layer name object + * + * @return Name& name of layer + */ + std::string &getName() { return name; } + + /** + * + * @brief operator== + * + * @param rhs right side to compare + * @return true if equal + * @return false if not equal + */ + bool operator==(const Connection &rhs) const noexcept; + +private: + unsigned index; + std::string name; +}; + +} // namespace nntrainer + +#endif // __CONNECTION_H__ \ No newline at end of file diff --git a/nntrainer/graph/meson.build b/nntrainer/graph/meson.build index e0e6ad2..ce635d9 100644 --- a/nntrainer/graph/meson.build +++ b/nntrainer/graph/meson.build @@ -1,9 +1,10 @@ graph_sources = [ 'network_graph.cpp', - 'graph_core.cpp' + 'graph_core.cpp', + 'connection.cpp' ] -graph_headers = [] +graph_headers = ['connection.h'] foreach s : graph_sources nntrainer_sources += meson.current_source_dir() / s diff --git a/nntrainer/layers/common_properties.cpp b/nntrainer/layers/common_properties.cpp index d307c9d..27b5431 100644 --- a/nntrainer/layers/common_properties.cpp +++ b/nntrainer/layers/common_properties.cpp @@ -81,19 +81,6 @@ InputConnection::InputConnection() : nntrainer::Property() {} InputConnection::InputConnection(const Connection &value) : nntrainer::Property(value) {} /**< default value if any */ -Connection::Connection(const std::string &layer_name, unsigned int idx) : - index(idx), - name(layer_name) {} - -Connection::Connection(const Connection &rhs) = default; -Connection &Connection::operator=(const Connection &rhs) = default; -Connection::Connection(Connection &&rhs) noexcept = default; -Connection &Connection::operator=(Connection &&rhs) noexcept = default; - -bool Connection::operator==(const Connection &rhs) const noexcept { - return index == rhs.index and name == rhs.name; -}; - Epsilon::Epsilon(float value) { set(value); } bool Epsilon::isValid(const float &value) const { return value > 0.0f; } @@ -310,31 +297,15 @@ void GenericShape::set(const TensorDim &value) { } // namespace props template <> -std::string -str_converter::to_string( - const props::Connection &value) { - std::stringstream ss; - ss << value.getName().get() << '(' << value.getIndex() << ')'; - return ss.str(); +std::string str_converter::to_string( + const Connection &value) { + return value.toString(); } template <> -props::Connection -str_converter::from_string( +Connection str_converter::from_string( const std::string &value) { - auto pos = value.find_first_of('('); - auto idx = 0u; - auto name_part = value.substr(0, pos); - - if (pos != std::string::npos) { - NNTR_THROW_IF(value.back() != ')', std::invalid_argument) - << "failed to parse connection invalid format: " << value; - - auto idx_part = value.substr(pos + 1, value.length() - 1); - idx = str_converter::from_string(idx_part); - } - - return props::Connection(name_part, idx); + return Connection(value); } } // namespace nntrainer diff --git a/nntrainer/layers/common_properties.h b/nntrainer/layers/common_properties.h index ea136be..1f84257 100644 --- a/nntrainer/layers/common_properties.h +++ b/nntrainer/layers/common_properties.h @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -138,92 +139,6 @@ public: }; /** - * @brief RAII class to define the connection - * - */ -class Connection { -public: - /** - * @brief Construct a new Connection object - * - * @param layer_name layer identifier - */ - Connection(const std::string &layer_name, unsigned int idx); - - /** - * @brief Construct a new Connection object - * - * @param rhs rhs to copy - */ - Connection(const Connection &rhs); - - /** - * @brief Copy assignment operator - * - * @param rhs rhs to copy - * @return Connection& - */ - Connection &operator=(const Connection &rhs); - - /** - * @brief Move Construct Connection object - * - * @param rhs rhs to move - */ - Connection(Connection &&rhs) noexcept; - - /** - * @brief Move assign a connection operator - * - * @param rhs rhs to move - * @return Connection& - */ - Connection &operator=(Connection &&rhs) noexcept; - - /** - * @brief Get the index - * - * @return unsigned index - */ - const unsigned getIndex() const { return index; } - - /** - * @brief Get the index - * - * @return unsigned index - */ - unsigned &getIndex() { return index; } - - /** - * @brief Get the Layer name object - * - * @return const Name& name of layer - */ - const Name &getName() const { return name; } - - /** - * @brief Get the Layer name object - * - * @return Name& name of layer - */ - Name &getName() { return name; } - - /** - * - * @brief operator== - * - * @param rhs right side to compare - * @return true if equal - * @return false if not equal - */ - bool operator==(const Connection &rhs) const noexcept; - -private: - unsigned index; - Name name; -}; - -/** * @brief Connection prop tag type * */ diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp index d9e109d..f809338 100644 --- a/nntrainer/layers/layer_node.cpp +++ b/nntrainer/layers/layer_node.cpp @@ -29,9 +29,12 @@ #include namespace nntrainer { + +#ifdef PROFILE static constexpr const char *FORWARD_SUFFIX = ":forward"; static constexpr const char *CALC_DERIV_SUFFIX = ":calcDeriv"; static constexpr const char *CALC_GRAD_SUFFIX = ":calcGrad"; +#endif namespace props { @@ -166,7 +169,7 @@ LayerNode::LayerNode(std::unique_ptr &&l) : inplace(InPlace::NONE), needs_calc_derivative(false), needs_calc_gradient(false), - output_layers(new std::vector()), + output_layers(), run_context(nullptr), layer_node_props( new PropsType(props::Name(), props::Distribute(), props::Trainable(), {}, @@ -239,8 +242,7 @@ std::ostream &operator<<(std::ostream &out, const LayerNode &l) { }; print_vector(input_layers, " input_layers"); - /// @todo enable this - // print_vector(l.output_layers, "output_layers"); + // print_vector(l.output_layers, "output_layers"); return out; } @@ -260,7 +262,7 @@ unsigned int LayerNode::getNumInputConnections() const { } unsigned int LayerNode::getNumOutputConnections() const { - return output_layers->size(); + return output_layers.size(); } const std::vector LayerNode::getInputLayers() const { @@ -270,16 +272,21 @@ const std::vector LayerNode::getInputLayers() const { names.reserve(input_layers.size()); std::transform(input_layers.begin(), input_layers.end(), std::back_inserter(names), - [](const props::Connection &con) { return con.getName(); }); + [](const Connection &con) { return con.getName(); }); return names; } const std::vector LayerNode::getOutputLayers() const { std::vector names; - names.reserve(output_layers->size()); - std::transform(output_layers->begin(), output_layers->end(), - std::back_inserter(names), - [](const props::Connection &con) { return con.getName(); }); + names.reserve(output_layers.size()); + + for (auto &output_layer : output_layers) { + if (output_layer == nullptr) { + ml_logw("intermediate output is empty for layer: %s", getName().c_str()); + continue; + } + names.push_back(output_layer->getName()); + } return names; } @@ -342,11 +349,11 @@ nntrainer::Layer *LayerNode::getLayer() { void LayerNode::addInputLayers(const std::string &in_layer) { auto &input_layers = std::get>(*layer_node_props); - input_layers.emplace_back(props::Connection(in_layer, 0)); + input_layers.emplace_back(Connection(in_layer, 0)); } void LayerNode::addOutputLayers(const std::string &out_layer) { - output_layers->emplace_back(out_layer, 0); + output_layers.emplace_back(new Connection(out_layer, 0)); } void LayerNode::setInputLayers(const std::vector &layers) { @@ -356,17 +363,16 @@ void LayerNode::setInputLayers(const std::vector &layers) { input_layers.reserve(layers.size()); std::transform(layers.begin(), layers.end(), std::back_inserter(input_layers), [](const std::string &id) { - return props::Connection{id, 0}; + return Connection{id, 0}; }); } void LayerNode::setOutputLayers(const std::vector &layers) { - output_layers->clear(); - output_layers->reserve(layers.size()); - std::transform(layers.begin(), layers.end(), - std::back_inserter(*output_layers), [](const std::string &id) { - return props::Connection{id, 0}; - }); + output_layers.clear(); + output_layers.reserve(layers.size()); + std::transform( + layers.begin(), layers.end(), std::back_inserter(output_layers), + [](const std::string &id) { return std::make_unique(id); }); } bool LayerNode::hasInputShapeProperty() const { @@ -499,8 +505,8 @@ InitLayerContext LayerNode::finalize(const std::vector &input_dims) { layer_node_props_realization = std::make_unique( props::Flatten(), props::Activation()); - auto num_outputs = output_layers->size(); - if (output_layers->empty()) { + auto num_outputs = output_layers.size(); + if (output_layers.empty()) { num_outputs = 1; } @@ -699,9 +705,13 @@ void LayerNode::remapConnections( remap_fn(name, idx); } - for (auto &output_layer : *output_layers) { - auto &name = output_layer.getName(); - auto &idx = output_layer.getIndex(); + for (auto &output_layer : output_layers) { + if (output_layer == nullptr) { + continue; + } + + auto &name = output_layer->getName(); + auto &idx = output_layer->getIndex(); remap_fn(name, idx); } } diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h index 53997f5..d801b36 100644 --- a/nntrainer/layers/layer_node.h +++ b/nntrainer/layers/layer_node.h @@ -37,7 +37,7 @@ namespace nntrainer { class Layer; - +class Connection; class Exporter; enum class ExportMethods; @@ -49,7 +49,6 @@ class Loss; class InputShape; class Activation; class SharedFrom; -class Connection; class InputConnection; class ClipGradByGlobalNorm; } // namespace props @@ -168,6 +167,23 @@ public: void setInputConnectionName(unsigned nth, const std::string &name); /** + * @brief Set the Output Connection + * @note Each output must be identified only ONCE. + * @note when nth comes, getNumOutput() expends to nth + 1 as resize occurs. + * Please also notice none identified intermediate output (or mismatch between + * actual number of out tensors and output) is allowed but will produce + * warning, this implies that the output is not used else where. + * @throw std::invalid_argument when trying to identify output + * more then once + * + * @param nth nth output + * @param name name of the output bound connection + * @param index index of the output bound connection + */ + void setOutputConnection(unsigned nth, const std::string &name, + unsigned index); + + /** * @brief Get the input connections for this node * * @return list of name of the nodes which form input connections @@ -742,7 +758,7 @@ private: calcDerivative */ bool needs_calc_gradient; /**< cache if this layer needs to do calcGradient */ - std::unique_ptr> + std::vector> output_layers; /**< output layer names */ std::unique_ptr diff --git a/packaging/nntrainer.spec b/packaging/nntrainer.spec index 303b82c..6dc820e 100644 --- a/packaging/nntrainer.spec +++ b/packaging/nntrainer.spec @@ -445,6 +445,7 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/ %files devel # node exporter and its dependencies %{_includedir}/nntrainer/nntrainer_error.h +%{_includedir}/nntrainer/connection.h %{_includedir}/nntrainer/common_properties.h %{_includedir}/nntrainer/base_properties.h %{_includedir}/nntrainer/node_exporter.h diff --git a/test/unittest/unittest_common_properties.cpp b/test/unittest/unittest_common_properties.cpp index 6e93fc0..242a880 100644 --- a/test/unittest/unittest_common_properties.cpp +++ b/test/unittest/unittest_common_properties.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -104,7 +105,7 @@ TEST(NameProperty, mustStartWithAlphaNumeric_01_n) { TEST(InputConnection, setPropertyValid_p) { using namespace nntrainer::props; { - InputConnection expected(Connection("a", 0)); + InputConnection expected(nntrainer::Connection("a", 0)); InputConnection actual; nntrainer::from_string("A", actual); @@ -113,7 +114,7 @@ TEST(InputConnection, setPropertyValid_p) { } { - InputConnection expected(Connection("a", 2)); + InputConnection expected(nntrainer::Connection("a", 2)); InputConnection actual; nntrainer::from_string("a(2)", actual); -- 2.7.4