[Dataset] Add thread safe queue for the buffer
author Jihoon Lee <jhoon.it.lee@samsung.com>
Tue, 13 Jul 2021 11:18:06 +0000 (20:18 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Tue, 3 Aug 2021 06:21:20 +0000 (15:21 +0900)
Currently, databuffer does not queue any iterations; it requests a buffer
on an on-demand basis.
This patch adds a thread-safe queue for the buffer so that multiple
buffers can be queued.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
jni/Android.mk
nntrainer/dataset/batch_queue.cpp [new file with mode: 0644]
nntrainer/dataset/batch_queue.h [new file with mode: 0644]
nntrainer/dataset/meson.build
test/unittest/datasets/meson.build
test/unittest/datasets/unittest_batch_queue.cpp [new file with mode: 0644]

index 970989c..f8b234f 100644 (file)
@@ -125,6 +125,7 @@ include $(CLEAR_VARS)
 NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/models/model_loader.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/models/dynamic_training_optimization.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/dataset/batch_queue.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/dataset/databuffer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/dataset/databuffer_factory.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/dataset/databuffer_func.cpp \
diff --git a/nntrainer/dataset/batch_queue.cpp b/nntrainer/dataset/batch_queue.cpp
new file mode 100644 (file)
index 0000000..915268f
--- /dev/null
@@ -0,0 +1,71 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file   batch_queue.cpp
+ * @date   13 July 2021
+ * @brief  This file contains thread safe queue
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+#include <batch_queue.h>
+#include <chrono>
+
+#include <mutex>
+#include <nntrainer_error.h>
+#include <shared_mutex>
+#include <utility>
+
+using namespace std::literals::chrono_literals;
+
+namespace nntrainer {
+
+/**
+ * @brief Construct an empty queue that stores @a queue_capacity_ as its
+ * capacity
+ * @param queue_capacity_ capacity of the queue, must not be zero
+ * @note a capacity of zero is reported through NNTR_THROW_IF with
+ * std::invalid_argument
+ */
+BatchQueue::BatchQueue(unsigned int queue_capacity_) :
+  queue_capacity(queue_capacity_) {
+  /// the member is already initialized here, so the check reads the stored
+  /// value
+  NNTR_THROW_IF(queue_capacity == 0, std::invalid_argument)
+    << "queue capacity of zero not supported!";
+}
+
+/**
+ * @brief Copy construct a queue, taking over only the capacity of @a rhs
+ * @note the entries, mutex and condition variables of @a rhs are not copied,
+ * the new queue starts with default constructed ones
+ * @param rhs queue to take the capacity from
+ */
+BatchQueue::BatchQueue(const BatchQueue &rhs) :
+  queue_capacity(rhs.queue_capacity) {}
+
+/**
+ * @brief Copy-assign the capacity of @a rhs, entries of both queues are left
+ * as they are
+ * @note NOTE(review): queue_capacity is written here without taking q_mutex;
+ * confirm that callers never assign while other threads use this queue
+ * @param rhs queue to take the capacity from
+ * @return BatchQueue& reference to this queue
+ */
+BatchQueue &BatchQueue::operator=(const BatchQueue &rhs) {
+  /// self-assignment leaves everything untouched
+  if (this != &rhs) {
+    queue_capacity = rhs.queue_capacity;
+  }
+  return *this;
+}
+
+/**
+ * @brief Push @a data to the back of the queue, blocking while the queue has
+ * no free slot
+ * @note declared noexcept although std::make_unique allocates; an allocation
+ * failure therefore ends in std::terminate
+ * @param data iteration to enqueue, its content is moved into the queue
+ */
+void BatchQueue::wait_and_push(T &&data) noexcept {
+  std::unique_lock<std::shared_mutex> lk(q_mutex);
+  /// `<` rather than `!=`: copy-assignment can lower queue_capacity below the
+  /// number of queued entries, writers must keep waiting in that case as well
+  q_writer_cv.wait(lk, [this] { return q.size() < queue_capacity; });
+  /// @a data is a named rvalue reference and thus an lvalue in this scope,
+  /// without std::move the whole iteration would be copied
+  q.push(std::make_unique<T>(std::move(data)));
+  /// exactly one entry became available, so waking a single reader is enough
+  q_reader_cv.notify_one();
+}
+
+/**
+ * @brief Take the oldest entry out of the queue, blocking while the queue is
+ * empty
+ * @return std::unique_ptr<T> entry that was at the front of the queue
+ */
+std::unique_ptr<BatchQueue::T> BatchQueue::wait_and_pop() noexcept {
+  std::unique_lock<std::shared_mutex> guard(q_mutex);
+  /// explicit form of a predicate wait, the condition is re-checked after
+  /// every wakeup
+  while (q.empty()) {
+    q_reader_cv.wait(guard);
+  }
+
+  /// moving out leaves a null pointer at the front; this is fine because the
+  /// slot is popped right below while the lock is still held
+  std::unique_ptr<T> head = std::move(q.front());
+  q.pop();
+  /// one slot has been freed, let a single waiting writer proceed
+  q_writer_cv.notify_one();
+
+  return head;
+}
+
+/**
+ * @brief Check whether the number of queued entries has reached the capacity
+ * @return bool true if the queue holds queue_capacity entries or more
+ */
+bool BatchQueue::isFull() const {
+  /// shared lock, const observers only read the queue
+  std::shared_lock<std::shared_mutex> lk(q_mutex);
+  /// `>=` rather than `==`: copy-assignment can lower queue_capacity below the
+  /// number of queued entries, which has to count as full as well
+  return q.size() >= queue_capacity;
+}
+
+/**
+ * @brief Check whether the queue currently holds no entry
+ * @return bool true if nothing is queued
+ */
+bool BatchQueue::isEmpty() const {
+  /// shared lock, const observers only read the queue
+  std::shared_lock<std::shared_mutex> guard(q_mutex);
+  const bool has_no_entry = q.empty();
+  return has_no_entry;
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/dataset/batch_queue.h b/nntrainer/dataset/batch_queue.h
new file mode 100644 (file)
index 0000000..b440bfa
--- /dev/null
@@ -0,0 +1,103 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file   batch_queue.h
+ * @date   13 July 2021
+ * @brief  This file contains thread safe queue for batch
+ * @note   This file is made to easily extend to type T, although it didn't to
+ * save compile time
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+#ifndef __BATCH_QUEUE_H__
+#define __BATCH_QUEUE_H__
+
+#include <queue>
+
+#include <condition_variable>
+#include <data_producers.h>
+#include <memory>
+#include <shared_mutex>
+
+namespace nntrainer {
+
+/**
+ * @brief Thread safe batch queue with a fixed capacity
+ *
+ */
+class BatchQueue {
+public:
+  using T = DataProducer::Iteration; /**< Iteration as type T to leave room to
+                                        extend the class to type T */
+
+  /**
+   * @brief Construct a new batch queue object
+   * @note the capacity is not the buffer size itself, but @a
+   * buffersize/batch_size; it is set here and only replaced by copy-assignment
+   * @param queue_capacity_ max queue size, zero is not supported
+   */
+  BatchQueue(unsigned int queue_capacity_);
+
+  /**
+   * @brief Copy construct a new batch queue object
+   * @note this does not copy the queued entries, but only the queue capacity
+   * @param rhs batch queue to copy the capacity from
+   */
+  BatchQueue(const BatchQueue &rhs);
+
+  /**
+   * @brief Copy-assign a batch queue
+   * @note this does not copy the queued entries, but only the queue capacity
+   *
+   * @param rhs batch queue to copy the capacity from
+   * @return BatchQueue& this queue
+   */
+  BatchQueue &operator=(const BatchQueue &rhs);
+
+  /**
+   * @brief push data to the queue, wait while the queue is full
+   *
+   * @param data data to put inside the batch queue
+   */
+  void wait_and_push(T &&data) noexcept;
+
+  /**
+   * @brief pop data from the queue, wait while the queue is empty
+   *
+   * @return std::unique_ptr<T> data taken from the front of the queue
+   */
+  std::unique_ptr<T> wait_and_pop() noexcept;
+
+  /**
+   * @brief check if current queue is full
+   *
+   * @return bool true if the queue holds queue_capacity entries
+   */
+  bool isFull() const;
+
+  /**
+   * @brief check if current queue is empty
+   *
+   * @return bool true if empty
+   */
+  bool isEmpty() const;
+
+private:
+  unsigned int queue_capacity;       /**< max number of entries in @a q */
+  mutable std::shared_mutex q_mutex; /**< guards @a q, mutable so that the
+                                        const observers can lock it */
+  std::condition_variable_any q_reader_cv; /**< wait_and_pop waits on this,
+                                              notified after a push */
+  std::condition_variable_any q_writer_cv; /**< wait_and_push waits on this,
+                                              notified after a pop */
+
+  /**
+   * @todo consider using circular buffer if this is too slow
+   *
+   */
+  std::queue<std::unique_ptr<T>> q; /**< queued entries, oldest at the front */
+};
+
+} // namespace nntrainer
+
+#endif // __BATCH_QUEUE_H__
index e3cebe7..459bdc2 100644 (file)
@@ -1,4 +1,5 @@
 dataset_sources = [
+  'batch_queue.cpp',
   'databuffer.cpp',
   'databuffer_factory.cpp',
   'databuffer_file.cpp',
index 64de934..bff21ff 100644 (file)
@@ -6,7 +6,8 @@ producer_targets = [
   'data_producer_common_tests.cpp',
   'unittest_random_data_producers.cpp',
   'unittest_func_data_producer.cpp',
-  'unittest_raw_file_data_producer.cpp'
+  'unittest_raw_file_data_producer.cpp',
+  'unittest_batch_queue.cpp'
 ]
 
 test_target += producer_targets
diff --git a/test/unittest/datasets/unittest_batch_queue.cpp b/test/unittest/datasets/unittest_batch_queue.cpp
new file mode 100644 (file)
index 0000000..f0b2882
--- /dev/null
@@ -0,0 +1,112 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file unittest_batch_queue.cpp
+ * @date 12 July 2021
+ * @brief Batch Queue Test
+ * @see        https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+
+#include <gtest/gtest.h>
+
+#include <batch_queue.h>
+#include <tensor.h>
+
+#include <future>
+#include <thread>
+#include <tuple>
+#include <vector>
+
+/**
+ * @brief Build a test iteration that can be told apart by @a key
+ * @note the first element is set to true, the second element holds @a key
+ * default constructed tensors and the third element is left empty
+ * @param key number of tensors to put into the second element
+ * @return nntrainer::DataProducer::Iteration iteration tagged with @a key
+ */
+nntrainer::DataProducer::Iteration data(size_t key) {
+  return {true, std::vector<nntrainer::Tensor>(key), {}};
+}
+
+/**
+ * @brief Expect that @a dat carries @a expected_key tensors in its second
+ * element
+ * @param dat iteration to inspect
+ * @param expected_key expected size of the second element of @a dat
+ */
+void test_data(const nntrainer::DataProducer::Iteration &dat,
+               size_t expected_key) {
+  /// non-fatal expectation, a mismatch is reported but the caller continues
+  EXPECT_EQ(std::get<1>(dat).size(), expected_key);
+}
+
+/**
+ * @brief Single threaded round trip: push one entry, pop it and check its key
+ */
+TEST(BatchQueue, pushPop_p) {
+  /// capacity of one is enough, the queue is empty when the push happens
+  nntrainer::BatchQueue bq(1);
+  EXPECT_NO_THROW(bq.wait_and_push(data(1)));
+
+  const auto popped = bq.wait_and_pop();
+  test_data(*popped, 1);
+}
+
+/**
+ * @brief Pushes and pops issued from separate threads at staggered points in
+ * time; every pop checks the key of the entry it received
+ */
+TEST(BatchQueue, threadedPushPops_p) {
+  using namespace std::chrono_literals;
+
+  /** producer task: sleep for @a duration, then push an entry tagged @a key */
+  auto push_after = [](nntrainer::BatchQueue &bq, const auto &duration,
+                       size_t key) {
+    std::this_thread::sleep_for(duration);
+    EXPECT_NO_THROW(bq.wait_and_push(data(key)));
+  };
+  /** consumer task: sleep for @a duration, then pop and expect @a key */
+  auto pop_after = [](nntrainer::BatchQueue &bq, const auto &duration,
+                      size_t key) {
+    std::this_thread::sleep_for(duration);
+    auto result = bq.wait_and_pop();
+    test_data(*result, key);
+  };
+
+  std::vector<std::future<void>> futures;
+  /** start @a task on its own thread and remember its future */
+  auto launch = [&futures](const auto &task, nntrainer::BatchQueue &queue,
+                           const auto &delay, size_t key) {
+    futures.push_back(
+      std::async(std::launch::async, task, std::ref(queue), delay, key));
+  };
+  /** wait for every started task, then forget the finished futures */
+  auto join_all = [&futures] {
+    for (auto &future : futures) {
+      future.get();
+    }
+    futures.clear();
+  };
+
+  {
+    /// 0     -> push(1)
+    /// 250ms -> pop(1)
+    nntrainer::BatchQueue bq(1);
+    launch(push_after, bq, 0ms, 1);
+    launch(pop_after, bq, 250ms, 1);
+    join_all();
+  }
+
+  {
+    /// 0     -> pop(1) (waits for the push)
+    /// 250ms -> push(1)
+    nntrainer::BatchQueue bq(1);
+    launch(pop_after, bq, 0ms, 1);
+    launch(push_after, bq, 250ms, 1);
+    join_all();
+  }
+
+  {
+    /// 0     -> push(1)
+    /// 300ms -> push(2)
+    /// 300ms -> pop(1)
+    /// 500ms -> push(3)
+    /// 600ms -> push(4) (waits)
+    /// 750ms -> pop(2)
+    /// 1000ms-> pop(3)
+    nntrainer::BatchQueue bq(2);
+    launch(push_after, bq, 0ms, 1);
+    launch(push_after, bq, 300ms, 2);
+    launch(pop_after, bq, 300ms, 1);
+    launch(push_after, bq, 500ms, 3);
+    launch(push_after, bq, 600ms, 4);
+    launch(pop_after, bq, 750ms, 2);
+    launch(pop_after, bq, 1000ms, 3);
+    join_all();
+  }
+}