[sub-plugin] Modify synchronization mechanism between push_data and getSample
authorhyunil park <hyunil46.park@samsung.com>
Tue, 29 Aug 2023 09:23:25 +0000 (18:23 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 12 Sep 2023 01:21:23 +0000 (10:21 +0900)
Modify synchronization mechanism between push_data and getSample
- Remove some member variable
- Add member function to check queue

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyunil park <hyunil46.park@samsung.com>
nnstreamer/tensor_trainer/tensor_trainer_nntrainer.cc
nnstreamer/tensor_trainer/tensor_trainer_nntrainer.hh

index 9b06345fc086f1ccfb4215065389207ecf29b8aa..d942ee6c77eeb522b6ce7e6a269174b425a16b35 100644 (file)
@@ -113,6 +113,29 @@ void nntrainer_thread_func(NNTrainer::NNTrainerTrain *nntrainer) {
   nntrainer->trainModel();
 }
 
+/**
+ * @brief Check if queue is empty
+ */
+bool NNTrainer::InputTensorsInfo::isQueueEmpty() {
+  ml_logd("front:%d, rear:%d", queue_front, queue_rear);
+  if (queue_front == queue_rear) {
+    ml_logd("queue is empty ");
+    return TRUE;
+  }
+  return FALSE;
+}
+
+/**
+ * @brief Check if queue is full
+ */
+bool NNTrainer::InputTensorsInfo::isQueueFull() {
+  if (((queue_rear + 1) % queue_size) == queue_front) {
+    ml_logd("queue is full");
+    return TRUE;
+  }
+  return FALSE;
+}
+
 /**
  * @brief push_data function
  * tensor_trainer call this function to push tensor data.
@@ -164,6 +187,10 @@ static int nntrainer_model_push_data(const GstTensorTrainerFramework *fw,
   ml_logd("number of inputs(%d) and labels(%d)", nntrainer->num_inputs,
           nntrainer->num_labels);
 
+  std::unique_lock<std::mutex> lock(data->queue_lock);
+  data->data_full.wait(lock, [&] { return !data->isQueueFull(); });
+  ml_logd("nntrainer_model_push_data condition is met");
+
   unsigned int idx = 0, i = 0;
   i = data->queue_rear;
   ml_logd("Insert, queue_rear : %d", i);
@@ -182,30 +209,18 @@ static int nntrainer_model_push_data(const GstTensorTrainerFramework *fw,
   }
 
   data->push_count++;
-  data->queue_count++;
   data->queue_rear = (data->queue_rear + 1) % data->queue_size;
-  ml_logd("front:%d, rear:%d, filled:%d", data->queue_front, data->queue_rear,
-          data->queue_count);
-
-  if (data->is_data_wait_locked && data->queue_count > 0) {
-    data->data_wait.notify_one();
-    ml_logd("send signal");
-  }
-
-  if (data->queue_count == data->queue_size) {
-    data->is_data_full_locked = TRUE;
-    ml_logd("locked, data is full");
-    std::unique_lock<std::mutex> lock(data->data_full_lock);
-    data->data_full.wait(lock);
-    ml_logd("unlocked, queue is empty");
-    data->is_data_full_locked = FALSE;
-  }
+  int queue_count = (data->queue_rear - data->queue_front + data->queue_size) %
+                    data->queue_size;
 
+  ml_logd("front:%d, rear:%d, filled:%d", data->queue_front, data->queue_rear,
+          queue_count);
   ml_logd("(pop/push: %d/%d)", data->pop_count, data->push_count);
   ml_logd("T-pushed: %d/%d, V-pushed:%d/%d\n",
           nntrainer->train_data->push_count, nntrainer->num_training_samples,
           nntrainer->valid_data->push_count, nntrainer->num_validation_samples);
-
+  lock.unlock();
+  data->data_empty.notify_one();
   ml_logd("<leaved>");
   return 0;
 }
@@ -214,22 +229,15 @@ void NNTrainer::InputTensorsInfo::getSample(float **input, float **label,
                                             bool *last) {
   ml_logd("<called>");
   ml_logd("(pop/push: %d/%d)", pop_count, push_count);
+
   pid_t pid = getpid();
   pid_t tid = syscall(SYS_gettid);
+  ml_logd("pid[%d], tid[%d]", pid, tid);
 
-  /* After the epoch ends, the sub-plugin has no data yet to send. */
-  if (push_count == 0) {
-    ml_logd("locked, need to wait for more data, "
-            "After the epoch ends, the sub-plugin has no data yet to send.");
-    std::unique_lock<std::mutex> lock(data_wait_lock);
-    is_data_wait_locked = TRUE;
-    data_wait.wait(lock);
-    ml_logd("unlocked, get data");
-  }
+  std::unique_lock<std::mutex> lock(queue_lock);
+  data_empty.wait(lock, [this] { return !isQueueEmpty(); });
+  ml_logd("getSample condition is met");
 
-  ml_logd("<called>");
-  ml_logd("pid[%d], tid[%d]", pid, tid);
-  ml_logd("front:%d, rear:%d", queue_front, queue_rear);
   ml_logd("num_inputs: %d, num_labels: %d", num_inputs, num_labels);
 
   unsigned int i = 0;
@@ -248,7 +256,6 @@ void NNTrainer::InputTensorsInfo::getSample(float **input, float **label,
   }
 
   pop_count++;
-  queue_count--;
   queue_front = (queue_front + 1) % queue_size;
 
   ml_logd("(pop/push: %d/%d)", pop_count, push_count);
@@ -260,24 +267,10 @@ void NNTrainer::InputTensorsInfo::getSample(float **input, float **label,
     pop_count = 0;
   }
 
-  if (is_data_full_locked && queue_count > 0) {
-    data_full.notify_one();
-    ml_logd("send signal");
-  }
+  int queue_count = (queue_rear - queue_front + queue_size) % queue_size;
   ml_logd("front:%d, rear:%d, filled:%d", queue_front, queue_rear, queue_count);
-
-  /* epoch is complete */
-  if (pop_count == 0)
-    return;
-
-  /* to avoid dead lock, check is_data_full_locked */
-  if (!is_data_full_locked && queue_count == 0) {
-    ml_logd("locked, need to wait for more data");
-    std::unique_lock<std::mutex> lock(data_wait_lock);
-    is_data_wait_locked = TRUE;
-    data_wait.wait(lock);
-    ml_logd("unlocked, get data");
-  }
+  lock.unlock();
+  data_full.notify_one();
 
   ml_logd("<leave>");
   return;
@@ -350,11 +343,8 @@ NNTrainer::InputTensorsInfo::InputTensorsInfo(unsigned int _total_num_samples,
                                               unsigned int _num_inputs,
                                               unsigned int _num_labels,
                                               unsigned int _tensors_size[]) :
-  is_data_wait_locked(0),
-  is_data_full_locked(0),
   queue_front(0),
   queue_rear(0),
-  queue_count(0),
   push_count(0),
   pop_count(0),
   total_num_samples(_total_num_samples),
index d8521b1e4a5eba478c67bf4847430ee4e114fdb7..93e6aba88c287a92e957cf910e71d8bce25006a5 100644 (file)
@@ -15,6 +15,7 @@
 
 #include <condition_variable>
 #include <model.h>
+#include <mutex>
 #include <nnstreamer_plugin_api.h>
 #include <nnstreamer_plugin_api_trainer.h>
 #include <vector>
@@ -141,17 +142,13 @@ public:
    * @brief Destroy the InputTensorsInfo object
    */
   ~InputTensorsInfo();
-
-  bool is_data_wait_locked;
-  bool is_data_full_locked;
   unsigned int queue_size;
   unsigned int queue_front;
   unsigned int queue_rear;
-  unsigned int queue_count; /**< The number of data in queue */
-  unsigned int push_count;  /**< The number of samples pushed to queue by
-                               NNStreamer(tensor_trainer) */
-  unsigned int pop_count;   /**< The number of pop from the queue for pushing
-                               samples to nntrainer */
+  unsigned int push_count; /**< The number of samples pushed to queue by
+                              NNStreamer(tensor_trainer) */
+  unsigned int pop_count;  /**< The number of pop from the queue for pushing
+                              samples to nntrainer */
   unsigned int
     input_size[NNS_TENSOR_SIZE_LIMIT]; /**< feature size * data type */
   unsigned int label_size[NNS_TENSOR_SIZE_LIMIT];
@@ -165,10 +162,9 @@ public:
   std::vector<TensorData>
     tensor_data; /**< Manage multiple inputs and labels data */
 
-  std::mutex data_wait_lock;
-  std::mutex data_full_lock;
-  std::condition_variable data_wait;
+  std::mutex queue_lock;
   std::condition_variable data_full;
+  std::condition_variable data_empty;
 
   /**
    * @brief get sample data
@@ -178,5 +174,15 @@ public:
    * @param last set TRUE if data is last
    */
   void getSample(float **input, float **label, bool *last);
+
+  /**
+   * @brief Check if queue is empty
+   */
+  bool isQueueEmpty();
+
+  /**
+   * @brief Check if queue is full
+   */
+  bool isQueueFull();
 };
 } // namespace NNTrainer