[layernode] Update getNumInputs/Outputs
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 1 Jul 2021 08:37:15 +0000 (17:37 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 22 Jul 2021 11:47:24 +0000 (20:47 +0900)
Separate getNumInputs/Outputs semantics for inputs/outputs
and for connections as a node. Number of inputs/outputs will always be
more than 1, but number of input/output connections can be 0 for
input/output nodes of the graph.
This patch separates the two concepts, and its usage.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/compiler/tflite_opnode.cpp
nntrainer/graph/graph_core.cpp
nntrainer/graph/network_graph.cpp
nntrainer/layers/layer_node.h
nntrainer/models/neuralnet.cpp
nntrainer/tensor/var_grad.h
nntrainer/tensor/weight.h
test/unittest/unittest_nntrainer_models.cpp

index 784f6ef..f9ae09e 100644 (file)
 namespace nntrainer {
 
 void TfOpNode::setInOut(const LayerNode &layer) {
-  auto &in = layer.getInputLayers();
-  is_input = std::find(in.begin(), in.end(), "__data__") != in.end();
-
-  auto &out = layer.getOutputLayers();
-  is_output = std::find(out.begin(), out.end(), "__exit__") != out.end();
+  is_input = layer.getInputLayers().size() == 0;
+  is_output = layer.getOutputLayers().size() == 0;
 }
 
 void TfOpNode::setInputs(
index c183cf3..e79def2 100644 (file)
@@ -55,8 +55,6 @@ void GraphCore::makeAdjacencyList(
   /** make the connections */
   for (auto &node : node_list) {
     for (auto const &in_conn : node->getInputConnections()) {
-      if (istrequal(in_conn, "__data__"))
-        continue;
       unsigned int to_node_id = getNode(in_conn)->getIndex();
       adj[to_node_id].push_back(node);
     }
index c243815..0b6fe9f 100644 (file)
@@ -77,7 +77,7 @@ void NetworkGraph::addDefaultInputLayers() {
   for (auto iter = cbegin() + 1; iter != cend(); iter++) {
     auto layer = *iter;
     auto prev_layer = *(iter - 1);
-    if (layer->getNumInputs() == 0) {
+    if (layer->getNumInputConnections() == 0) {
       layer->addInputLayers(prev_layer->getName());
     }
   }
@@ -111,7 +111,7 @@ int NetworkGraph::realizeMultiInputType(
    * Multi-input works with time distribution layer by itself
    *
    */
-  if (in_node->getNumInputs() <= 1)
+  if (in_node->getNumInputConnections() <= 1)
     return ML_ERROR_NONE;
 
   // TODO: this can be addition or concat layer - add support
@@ -197,7 +197,7 @@ int NetworkGraph::realizeMultiOutputType(
    *
    */
 
-  if (in_node->getNumOutputs() <= 1)
+  if (in_node->getNumOutputConnections() <= 1)
     return ML_ERROR_NONE;
 
   std::shared_ptr<LayerNode> lnode = createLayerNode(OutputLayer::type);
@@ -208,7 +208,7 @@ int NetworkGraph::realizeMultiOutputType(
 
   in_node->setOutputLayers({lnode->getName()});
 
-  for (unsigned int i = 0; i < in_node->getNumOutputs(); ++i) {
+  for (unsigned int i = 0; i < in_node->getNumOutputConnections(); ++i) {
     updateConnectionName(in_node->getName(), lnode->getName());
   }
 
@@ -307,10 +307,11 @@ void NetworkGraph::setOutputLayers() {
       auto &layer_i = *iter_i;
       if (istrequal(layer_i->getName(), layer_idx->getName()))
         continue;
-      for (unsigned int j = 0; j < layer_i->getNumInputs(); ++j) {
+      for (unsigned int j = 0; j < layer_i->getNumInputConnections(); ++j) {
         if (istrequal(layer_i->getInputLayers()[j], layer_idx->getName())) {
           bool already_exist = false;
-          for (unsigned int k = 0; k < layer_idx->getNumOutputs(); ++k) {
+          for (unsigned int k = 0; k < layer_idx->getNumOutputConnections();
+               ++k) {
             if (istrequal(layer_idx->getOutputLayers()[k],
                           layer_i->getName())) {
               already_exist = true;
@@ -324,7 +325,7 @@ void NetworkGraph::setOutputLayers() {
       }
     }
 
-    if (layer_idx->getNumOutputs() == 0) {
+    if (layer_idx->getNumOutputConnections() == 0) {
       last_layer_count += 1;
     }
   }
@@ -381,7 +382,7 @@ int NetworkGraph::realizeGraph() {
 
     /** If a layer does not has input nodes, then it must have input dimension
      */
-    if (lnode->getNumInputs() == 0) {
+    if (lnode->getNumInputConnections() == 0) {
       for (unsigned int i = 0; i < lnode->getInputDimensions().size(); ++i) {
         if (lnode->getInputDimensions()[i].getDataLen() == 0) {
           ml_loge("Input Dimension must be set");
@@ -389,8 +390,6 @@ int NetworkGraph::realizeGraph() {
           NN_RETURN_STATUS();
         }
       }
-
-      lnode->setInputLayers({"__data__"});
     }
 
     if (lnode->getType() != AdditionLayer::type &&
@@ -594,7 +593,7 @@ void NetworkGraph::inPlaceOptimize(Manager &manager) {
     if (l->supportInPlace()) {
       /** @note assumes layer to be optimized is only for single in/out tensor
        */
-      if (layer_node->getNumInputs() != 1)
+      if (layer_node->getNumInputConnections() != 1)
         throw std::runtime_error("Internal error in the formed graph");
 
       auto prev_node = getLayerNode(layer_node->getInputLayers()[0]);
@@ -733,7 +732,8 @@ int NetworkGraph::initialize(std::shared_ptr<Manager> manager) {
         auto in_layer_node = getLayerNode(input_layers[i]);
 
         unsigned int location = 0;
-        for (unsigned int j = 0; j < in_layer_node->getNumOutputs(); ++j) {
+        for (unsigned int j = 0; j < in_layer_node->getNumOutputConnections();
+             ++j) {
           if (in_layer_node->getOutputLayers()[j] == lnode->getName()) {
             location = j;
             break;
@@ -789,14 +789,14 @@ int NetworkGraph::initialize(std::shared_ptr<Manager> manager) {
         input_map.insert({output_layers[i], {}});
 
       unsigned int j = 0;
-      for (; j < out_layer_node->getNumInputs(); ++j) {
+      for (; j < out_layer_node->getNumInputConnections(); ++j) {
         if (out_layer_node->getInputLayers()[j] == lnode->getName()) {
           break;
         }
       }
 
       auto &in_map = input_map.at(output_layers[i]);
-      in_map.resize(out_layer_node->getNumInputs());
+      in_map.resize(out_layer_node->getNumInputConnections());
       in_map[j] = outputs[i];
     }
 #else
@@ -810,7 +810,8 @@ int NetworkGraph::initialize(std::shared_ptr<Manager> manager) {
         auto in_layer_node = getLayerNode(input_layers[i]);
 
         unsigned int location = 0;
-        for (unsigned int j = 0; j < in_layer_node->getNumOutputs(); ++j) {
+        for (unsigned int j = 0; j < in_layer_node->getNumOutputConnections();
+             ++j) {
           if (in_layer_node->getOutputLayers()[j] == lnode->getName()) {
             location = j;
             break;
index 4a5ef2a..f89ef7f 100644 (file)
@@ -313,16 +313,33 @@ public:
   ActivationType getActivationType() const;
 
   /**
+   * @brief     Get number of input connections
+   * @retval    number of inputs
+   */
+  unsigned int getNumInputConnections() const { return input_layers.size(); }
+
+  /**
+   * @brief     Get number of output connections
+   * @retval    number of outputs
+   */
+  unsigned int getNumOutputConnections() const { return output_layers.size(); }
+
+  /**
    * @brief     Get number of inputs
    * @retval    number of inputs
    */
-  unsigned int getNumInputs() const { return input_layers.size(); }
+  unsigned int getNumInputs() const { return input_dim.size(); }
 
   /**
    * @brief     Get number of outputs
    * @retval    number of outputs
    */
-  unsigned int getNumOutputs() const { return output_layers.size(); }
+  unsigned int getNumOutputs() const {
+    if (finalized)
+      return init_context.getOutputDimensions().size();
+    else
+      return getNumOutputConnections();
+  }
 
   /**
    * @brief Get the number of weights
@@ -449,8 +466,8 @@ public:
    */
   Weight getWeightWrapper(unsigned int idx) {
     if (layerv1 == nullptr) {
-      return Weight(run_context.getWeight(idx),
-            run_context.getWeightGrad(idx), run_context.getWeightName(idx));
+      return Weight(run_context.getWeight(idx), run_context.getWeightGrad(idx),
+                    run_context.getWeightName(idx));
     } else {
       return getLayer()->getWeightsRef()[idx];
     }
index 243db27..92736f4 100644 (file)
@@ -230,19 +230,19 @@ sharedConstTensors NeuralNetwork::forwarding(sharedConstTensors input,
     << " label_batch: " << label[0]->batch() << " target_batch: " << batch_size;
 
   auto fill_label = [&label](auto const &layer_node) {
-    NNTR_THROW_IF(label.size() != layer_node->getOutputDimensions().size(),
+    NNTR_THROW_IF(label.size() != layer_node->getNumOutputs(),
                   std::invalid_argument)
       << "label size does not match with the layer requirements"
       << " layer: " << layer_node->getName() << " label size: " << label.size()
       << " requirements size: " << layer_node->getNumOutputs();
 
-    for (unsigned int i = 0; i < layer_node->getOutputDimensions().size(); i++) {
+    for (unsigned int i = 0; i < layer_node->getNumOutputs(); i++) {
       layer_node->getOutputGrad(i) = *label[i];
     }
   };
 
   auto clear_label = [](auto const &layer_node) {
-    for (unsigned int i = 0; i < layer_node->getOutputDimensions().size(); i++) {
+    for (unsigned int i = 0; i < layer_node->getNumOutputs(); i++) {
       layer_node->getOutputGrad(i) = Tensor();
     }
   };
index 9185af6..0359c74 100644 (file)
@@ -75,19 +75,18 @@ public:
    * @param g Already created gradient object
    * @param n Name for this Var_Grad
    *
-   * @note This API is not recommended for usage and must be used for internal uses only,
-   * as Var_Grad does not own the tensors v and g,
-   * and can go invalid if the owner of these tensors free the tensors.
+   * @note This API is not recommended for usage and must be used for internal
+   * uses only, as Var_Grad does not own the tensors v and g, and can go invalid
+   * if the owner of these tensors free the tensors.
    */
-  explicit Var_Grad(const Tensor &v,
-      const Tensor &g, const std::string &n = ""):
+  explicit Var_Grad(const Tensor &v, const Tensor &g,
+                    const std::string &n = "") :
     dim(v.getDim()),
     var(std::make_shared<Tensor>(v.getSharedDataTensor(dim, 0, false))),
     grad(std::make_shared<Tensor>(g.getSharedDataTensor(dim, 0, false))),
     trainable(!g.uninitialized()),
     alloc_now(v.isAllocated()),
-    name(n) {
-    }
+    name(n) {}
 
   /**
    * @brief Copy constructor for Var_Grad
index a36fae3..e67e846 100644 (file)
@@ -103,7 +103,6 @@ public:
            std::get<5>(spec) // Name
     ) {}
 
-
   /**
    * @brief Construct a new Weight object
    *
@@ -114,15 +113,13 @@ public:
    * @note This is primarily used to created wrapper of variable extracted from
    * context. If needed, add support for regularizer, and opt_vars.
    *
-   * @note This API is not recommended for usage and must be used for internal uses only,
-   * as Weight does not own the tensors v and g,
-   * and can go invalid if the owner of these tensors free the tensors.
+   * @note This API is not recommended for usage and must be used for internal
+   * uses only, as Weight does not own the tensors v and g, and can go invalid
+   * if the owner of these tensors free the tensors.
    */
-  explicit Weight(const Tensor &v,
-      const Tensor &g, const std::string &n = ""):
+  explicit Weight(const Tensor &v, const Tensor &g, const std::string &n = "") :
     Var_Grad(v, g, n) {}
 
-
   /**
    * @copydoc var_grad::initializeVariable(const Tensor &)
    */
index f982ee8..94a631b 100644 (file)
@@ -266,7 +266,7 @@ void NodeWatcher::forward(int iteration, NodeWatcher &next_node) {
   std::string err_msg = ss.str();
 
   std::vector<nntrainer::Tensor> out;
-  for (unsigned int idx = 0; idx < node->getNumOutputs(); idx ++) {
+  for (unsigned int idx = 0; idx < node->getNumOutputs(); idx++) {
     out.push_back(node->getOutput(idx));
   }
 
@@ -287,7 +287,7 @@ void NodeWatcher::backward(int iteration, bool verify_deriv, bool verify_grad) {
   std::string err_msg = ss.str();
 
   std::vector<nntrainer::Tensor> out;
-  for (unsigned int idx = 0; idx < node->getNumInputs(); idx ++) {
+  for (unsigned int idx = 0; idx < node->getNumInputs(); idx++) {
     out.push_back(node->getInputGrad(idx));
   }
 
@@ -1345,7 +1345,8 @@ INSTANTIATE_TEST_CASE_P(
 // TEST(nntrainerModels, read_save_01_n) {
 //   nntrainer::NeuralNetwork NN;
 //   std::shared_ptr<nntrainer::LayerNode> layer_node =
-//     nntrainer::createLayerNode(nntrainer::InputLayer::type, {"input_shape=1:1:62720", "normalization=true"});
+//     nntrainer::createLayerNode(nntrainer::InputLayer::type,
+//     {"input_shape=1:1:62720", "normalization=true"});
 //
 //   EXPECT_NO_THROW(NN.addLayer(layer_node));
 //   EXPECT_NO_THROW(NN.setProperty({"loss=mse"}));