[neuralnet] check stop_cb by every layer
authorhyeonseok lee <hs89.lee@samsung.com>
Fri, 14 Oct 2022 09:28:20 +0000 (18:28 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 14 Oct 2022 12:36:30 +0000 (21:36 +0900)
 - Check stop_cb before doing forwarding/calcGrading/calcDerivative by every layer
   to stop training more quickly.

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/graph/network_graph.cpp
nntrainer/graph/network_graph.h
nntrainer/models/neuralnet.cpp
nntrainer/models/neuralnet.h

index cb0ae10..8d98a24 100644 (file)
@@ -329,8 +329,10 @@ void NetworkGraph::applyGradients(
   }
 }
 
-sharedConstTensors NetworkGraph::forwarding(bool training) {
-  for (auto iter = cbegin(); iter != cend(); iter++) {
+sharedConstTensors
+NetworkGraph::forwarding(bool training,
+                         std::function<bool(void *userdata)> stop_cb) {
+  for (auto iter = cbegin(); iter != cend() && !stop_cb(nullptr); iter++) {
     auto const &ln = *iter;
     PROFILE_TIME_START(profile_keys.at(ln->getType()));
 
@@ -356,7 +358,8 @@ sharedConstTensors NetworkGraph::forwarding(bool training) {
 void NetworkGraph::backwarding(
   int iteration,
   std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
-  std::function<void(Weight &, int)> &apply_grad_clip_op) const {
+  std::function<void(Weight &, int)> &apply_grad_clip_op,
+  std::function<bool(void *userdata)> stop_cb) const {
   /**
    * last layer backwarding is run out of this loop
    */
@@ -374,7 +377,7 @@ void NetworkGraph::backwarding(
     throw std::runtime_error(
       "Error: last layer does not accept label, we can't train");
 
-  for (auto iter = iter_begin; iter != iter_end; iter++) {
+  for (auto iter = iter_begin; iter != iter_end && !stop_cb(nullptr); iter++) {
     auto &ln = *iter;
     PROFILE_TIME_START(profile_keys.at(ln->getType()));
     backwarding_op(ln, iteration);
@@ -1074,7 +1077,9 @@ std::vector<Tensor> NetworkGraph::getOutputTensors() const {
 
 void NetworkGraph::flushCache() { tensor_manager->flushCache(); }
 
-void NetworkGraph::flushCacheExcept(unsigned int order) { tensor_manager->flushCacheExcept(order); }
+void NetworkGraph::flushCacheExcept(unsigned int order) {
+  tensor_manager->flushCacheExcept(order);
+}
 
 void NetworkGraph::requestOptimizerVariable(
   std::function<std::vector<TensorDim>(const TensorDim &)> cb,
index 06a0e80..1272c1f 100644 (file)
@@ -166,7 +166,9 @@ public:
    * @param[in] training true if forwarding is on training
    * @retval output tensors
    */
-  sharedConstTensors forwarding(bool training = false);
+  sharedConstTensors forwarding(bool training = false,
+                                std::function<bool(void *userdata)> stop_cb =
+                                  [](void *user_data) { return false; });
 
   /**
    * @brief     backwarding the network graph
@@ -177,7 +179,10 @@ public:
   void backwarding(
     int iteration,
     std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
-    std::function<void(Weight &, int)> &apply_grad_clip_op) const;
+    std::function<void(Weight &, int)> &apply_grad_clip_op,
+    std::function<bool(void *userdata)> stop_cb = [](void *user_data) {
+      return false;
+    }) const;
 
   /**
    * @brief     get begin iterator for the graph
index ea4e1d3..0228ff7 100644 (file)
@@ -252,8 +252,10 @@ NeuralNetwork::~NeuralNetwork() = default;
 /**
  * @brief     forward propagation using layers object which has layer
  */
-sharedConstTensors NeuralNetwork::forwarding(bool training) {
-  return model_graph.forwarding(training);
+sharedConstTensors
+NeuralNetwork::forwarding(bool training,
+                          std::function<bool(void *userdata)> stop_cb) {
+  return model_graph.forwarding(training, stop_cb);
 }
 
 /**
@@ -281,14 +283,15 @@ sharedConstTensors NeuralNetwork::forwarding(sharedConstTensors input,
  *            Call backwarding function of layer in reverse order
  *            No need to call at first Input Layer (No data to be updated)
  */
-void NeuralNetwork::backwarding(int iteration) {
+void NeuralNetwork::backwarding(int iteration,
+                                std::function<bool(void *userdata)> stop_cb) {
 
 #ifdef DEBUG
   NNTR_THROW_IF(!opt, std::invalid_argument) << "optimizer is null!";
 #endif
 
   std::function<void(std::shared_ptr<LayerNode>, int)> backwarding_op =
-    [this](std::shared_ptr<LayerNode> node, int iteration) -> void {
+    [this, stop_cb](std::shared_ptr<LayerNode> node, int iteration) -> void {
     /**
      * Do not change this order:
      * 1. calcGradient
@@ -323,6 +326,10 @@ void NeuralNetwork::backwarding(int iteration) {
 
     model_graph.flushCacheExcept(std::get<2>(node->getExecutionOrder()));
 
+    if (stop_cb(nullptr)) {
+      return;
+    }
+
     if (node->needsCalcDerivative())
       node->calcDerivative();
 
@@ -348,7 +355,8 @@ void NeuralNetwork::backwarding(int iteration) {
     opt_->applyGradient(opt_context);
   };
 
-  model_graph.backwarding(iteration, backwarding_op, apply_grad_clip_op);
+  model_graph.backwarding(iteration, backwarding_op, apply_grad_clip_op,
+                          stop_cb);
 }
 
 void NeuralNetwork::save(const std::string &file_path,
@@ -764,8 +772,8 @@ int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb) {
                                              DataBuffer &buffer) {
     model_graph.flushCache();
 
-    forwarding(true);
-    backwarding(iter++);
+    forwarding(true, stop_cb);
+    backwarding(iter++, stop_cb);
 
     if (!stop_cb(nullptr)) {
       std::cout << "#" << epoch_idx << "/" << getEpochs();
index 70334e6..6bc57ce 100644 (file)
@@ -222,7 +222,9 @@ public:
   /**
    * @brief     Forward Propagation of the neural network
    */
-  sharedConstTensors forwarding(bool training = true);
+  sharedConstTensors forwarding(bool training = true,
+                                std::function<bool(void *userdata)> stop_cb =
+                                  [](void *user_data) { return false; });
 
   /**
    * @brief     Forward Propagation of the neural network
@@ -238,7 +240,8 @@ public:
    * @brief     Backward Propagation of the neural network
    * @param[in] iteration Iteration Number for the optimizer
    */
-  void backwarding(int iteration);
+  void backwarding(int iteration, std::function<bool(void *userdata)> stop_cb =
+                                    [](void *user_data) { return false; });
 
   /**
    * @copydoc Model::save(const std::string &file_path, ml::train::ModelFormat