[NEURALNET] Refine forwarding operation
author Jiho Chu <jiho.chu@samsung.com>
Fri, 3 Mar 2023 06:29:41 +0000 (15:29 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Sun, 12 Mar 2023 23:00:06 +0000 (08:00 +0900)
This patch refines the forwarding operation in the neural network class.

The code depth of the forward and backward operations does not match.
For the backwarding operation, a backwarding_op callback is passed to
the graph, which handles the operation there. To keep the same depth
for forwarding, this patch adds a forwarding_op function and passes it
to the graph class, as sketched below.
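As a minimal, self-contained sketch of the callback-injection pattern
(the types below are simplified stand-ins; the real code uses
NetworkGraph, LayerNode, and std::shared_ptr, as shown in the diff):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Simplified stand-in for LayerNode.
    struct Node {
      std::string name;
      void forwarding(bool training) {
        std::cout << "forward " << name
                  << (training ? " (training)" : "") << '\n';
      }
    };

    // Simplified stand-in for NetworkGraph: it only drives iteration;
    // the per-node work is injected, as backwarding_op already is.
    struct Graph {
      std::vector<Node> nodes;
      void forwarding(bool training,
                      std::function<void(Node &, bool)> forwarding_op) {
        for (auto &node : nodes)
          forwarding_op(node, training);
      }
    };

    int main() {
      Graph g{{{"conv"}, {"fc"}}};
      // The model supplies the per-node op (profiling, cache flushing,
      // the actual forward call), mirroring the backwarding path.
      g.forwarding(true, [](Node &node, bool training) {
        node.forwarding(training);
      });
      return 0;
    }

With this split, the graph owns only the iteration order while the
model decides what each forward step does.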

Related Issue:
#2108

Signed-off-by: Jiho Chu <jiho.chu@samsung.com>
nntrainer/graph/network_graph.cpp
nntrainer/graph/network_graph.h
nntrainer/models/neuralnet.cpp

index 47b4953..0384cb9 100644 (file)
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -344,19 +344,15 @@ void NetworkGraph::applyGradients(
   }
 }
 
-sharedConstTensors
-NetworkGraph::forwarding(bool training,
-                         std::function<bool(void *userdata)> stop_cb) {
+sharedConstTensors NetworkGraph::forwarding(
+  bool training,
+  std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op,
+  std::function<bool(void *userdata)> stop_cb) {
   for (auto iter = cbegin(); iter != cend() && !stop_cb(nullptr); iter++) {
-    auto const &ln = *iter;
-    PROFILE_TIME_START(profile_keys.at(ln->getType()));
-    PROFILE_MEM_ANNOTATE("Forwarding for layer: " + ln->getName());
-
-    auto f = std::get<0>(ln->getExecutionOrder());
-    flushCacheExcept(f);
-
-    ln->forwarding(training);
+    auto &ln = *iter;
 
+    PROFILE_TIME_START(profile_keys.at(ln->getType()));
+    forwarding_op(*iter, training);
     PROFILE_TIME_END(profile_keys.at(ln->getType()));
   }
 
index 0ba38d5..362f86d 100644 (file)
--- a/nntrainer/graph/network_graph.h
+++ b/nntrainer/graph/network_graph.h
@@ -170,9 +170,12 @@ public:
    * @param[in] training true if forwarding is on training
    * @retval output tensors
    */
-  sharedConstTensors forwarding(bool training = false,
-                                std::function<bool(void *userdata)> stop_cb =
-                                  [](void *user_data) { return false; });
+  sharedConstTensors forwarding(
+    bool training = false,
+    std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op =
+      [](std::shared_ptr<LayerNode>, bool) {},
+    std::function<bool(void *userdata)> stop_cb =
+      [](void *user_data) { return false; });
 
   /**
    * @brief     backwarding the network graph
index 03a217d..feb1284 100644 (file)
--- a/nntrainer/models/neuralnet.cpp
+++ b/nntrainer/models/neuralnet.cpp
@@ -267,7 +267,18 @@ NeuralNetwork::~NeuralNetwork() { deallocate(); }
 sharedConstTensors
 NeuralNetwork::forwarding(bool training,
                           std::function<bool(void *userdata)> stop_cb) {
-  return model_graph.forwarding(training, stop_cb);
+  std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op =
+    [this, stop_cb](std::shared_ptr<LayerNode> node, bool training) -> void {
+    (void)this;
+    PROFILE_MEM_ANNOTATE("Forwarding for layer: " + node->getName());
+
+    auto f = std::get<0>(node->getExecutionOrder());
+    model_graph.flushCacheExcept(f);
+
+    node->forwarding(training);
+  };
+
+  return model_graph.forwarding(training, forwarding_op, stop_cb);
 }
 
 /**