[graph] update getter of input/output dims
authorJihoon Lee <jhoon.it.lee@samsung.com>
Wed, 6 Oct 2021 16:19:09 +0000 (01:19 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 8 Oct 2021 09:30:56 +0000 (18:30 +0900)
This patch update input/output dims to properly reflect model input,
output dimensions, not just a single object.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/graph/network_graph.cpp
nntrainer/graph/network_graph.h

index 8619f67..5cc35f7 100644 (file)
@@ -496,17 +496,19 @@ sharedConstTensors NetworkGraph::forwarding(bool training) const {
 }
 
 std::vector<TensorDim> NetworkGraph::getInputDimension() const {
-  NNTR_THROW_IF(this->empty(), std::invalid_argument)
-    << "[NetworkGraph] the graph has no node!";
-  return getSortedLayerNode(0)->getInputDimensions();
+  NNTR_THROW_IF(input_dims.empty(), std::invalid_argument)
+    << "[NetworkGraph] the graph has no node identified as input!";
+  return input_dims;
 }
 
 unsigned int NetworkGraph::getBatchSize() const { return batch_size; }
 
 std::vector<TensorDim> NetworkGraph::getOutputDimension() const {
-  NNTR_THROW_IF(this->empty(), std::invalid_argument)
-    << "[NetworkGraph] the graph has no node!";
-  return getSortedLayerNode(graph.size() - 1)->getOutputDimensions();
+  NNTR_THROW_IF(label_dims.empty(), std::invalid_argument)
+    << "[NetworkGraph] the graph has no node identified as output!";
+  /// for now, outputting label_dims works, later label dim will be different
+  /// from output dimension
+  return label_dims;
 }
 
 std::vector<std::shared_ptr<LayerNode>>
@@ -857,6 +859,7 @@ int NetworkGraph::initialize(
       << num_input;
 
     input_list.push_back(node->getInput(0).getName());
+    input_dims.push_back(node->getInputDimensions()[0]);
   };
 
   auto is_label_node = [](LayerNode *node) { return node->requireLabel(); };
@@ -872,6 +875,7 @@ int NetworkGraph::initialize(
       << num_label;
 
     label_list.push_back(node->getOutputGrad(0).getName());
+    label_dims.push_back(node->getOutputDimensions()[0]);
   };
 
   auto identify_external_tensors = [this](const std::vector<std::string> &names,
index 217bf66..8f564ed 100644 (file)
@@ -95,6 +95,7 @@ public:
    * @brief     Swap function for the class
    */
   friend void swap(NetworkGraph &lhs, NetworkGraph &rhs) {
+    /// @fixme this swap function need maintenance
     using std::swap;
 
     swap(lhs.graph, rhs.graph);
@@ -382,8 +383,11 @@ private:
   unsigned int batch_size;     /**< current batch_size */
   // std::vector<Var_Grad *> label_list; /**< var_grads for the labels */
   // std::vector<Var_Grad *> input_list; /**< var_grads for the inputs */
-  std::vector<std::string> label_list; /**< var_grads for the labels */
-  std::vector<std::string> input_list; /**< var_grads for the inputs */
+  std::vector<std::string> label_list; /**< identifier for the model labels */
+  std::vector<std::string> input_list; /**< identifier for the model inputs */
+  std::vector<TensorDim> label_dims;   /**< graph label dimensions */
+  std::vector<TensorDim> input_dims;   /**< graph input dimensions */
+
   ExecutionMode exec_mode; /**< execution mode with which the graph has been
                               currently set or previously set */