NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
$(NNTRAINER_ROOT)/nntrainer/models/model_loader.cpp \
$(NNTRAINER_ROOT)/nntrainer/models/dynamic_training_optimization.cpp \
+ $(NNTRAINER_ROOT)/nntrainer/dataset/batch_queue.cpp \
$(NNTRAINER_ROOT)/nntrainer/dataset/databuffer.cpp \
$(NNTRAINER_ROOT)/nntrainer/dataset/databuffer_factory.cpp \
$(NNTRAINER_ROOT)/nntrainer/dataset/databuffer_func.cpp \
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file batch_queue.cpp
+ * @date 13 July 2021
+ * @brief This file contains thread safe queue
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+#include <batch_queue.h>
+#include <chrono>
+
+#include <mutex>
+#include <nntrainer_error.h>
+#include <shared_mutex>
+
+using namespace std::literals::chrono_literals;
+
+namespace nntrainer {
+
+BatchQueue::BatchQueue(unsigned int queue_capacity_) :
+ queue_capacity(queue_capacity_) {
+ NNTR_THROW_IF(queue_capacity == 0, std::invalid_argument)
+ << "queue capacity of zero not supported!";
+}
+
+BatchQueue::BatchQueue(const BatchQueue &rhs) :
+ queue_capacity(rhs.queue_capacity) {}
+
+BatchQueue &BatchQueue::operator=(const BatchQueue &rhs) {
+ if (this == &rhs) {
+ return *this;
+ }
+ this->queue_capacity = rhs.queue_capacity;
+ return *this;
+}
+
+void BatchQueue::wait_and_push(T &&data) noexcept {
+ std::unique_lock<std::shared_mutex> lk(q_mutex);
+ q_writer_cv.wait(lk, [this] { return q.size() != queue_capacity; });
+ q.push(std::make_unique<T>(data));
+ q_reader_cv.notify_one();
+}
+
+std::unique_ptr<BatchQueue::T> BatchQueue::wait_and_pop() noexcept {
+ std::unique_lock<std::shared_mutex> lk(q_mutex);
+ q_reader_cv.wait(lk, [this] { return !q.empty(); });
+
+ /// @note this invalidates q.front(), but it is okay because it is locked and
+ /// popped right away
+ auto ptr = std::move(q.front());
+ q.pop();
+ q_writer_cv.notify_one();
+
+ return ptr;
+}
+
+bool BatchQueue::isFull() const {
+ std::shared_lock<std::shared_mutex> lk(q_mutex);
+ return queue_capacity == q.size();
+}
+
+bool BatchQueue::isEmpty() const {
+ std::shared_lock<std::shared_mutex> lk(q_mutex);
+ return q.empty();
+}
+
+} // namespace nntrainer
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file batch_queue.h
+ * @date 13 July 2021
+ * @brief This file contains thread safe queue for batch
+ * @note This file is made to easily extend to type T, although it didn't to
+ * save compile time
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+#ifndef __BATCH_QUEUE_H__
+#define __BATCH_QUEUE_H__
+
+#include <queue>
+
+#include <condition_variable>
+#include <data_producers.h>
+#include <memory>
+#include <shared_mutex>
+
+namespace nntrainer {
+
+/**
+ * @brief Thread Safe batch queue Queue
+ *
+ */
+class BatchQueue {
+public:
+ using T = DataProducer::Iteration; /**< Iteration as type T to leave room to
+ extend the class to type T */
+
+ /**
+ * @brief Construct a new batch queue Queue object
+ * @note this is not the size of buffersize, but it is @a
+ * buffersize/batch_size the size never changes after the BatchQueue has been
+ * created
+ * @param queue_capacity_ max queue size
+ */
+ BatchQueue(unsigned int queue_capacity_);
+
+ /**
+ * @brief Construct a new batch queue Queue object
+ * @note this does not copy the original queue, but only queue size
+ * @param rhs batch queue queue to copy
+ */
+ BatchQueue(const BatchQueue &rhs);
+
+ /**
+ * @brief Copy-assign batch queue queue
+ *
+ * @param rhs batch queue queue to copy
+ * @return BatchQueue& new queue
+ */
+ BatchQueue &operator=(const BatchQueue &rhs);
+
+ /**
+ * @brief push data to queue, if the batch queue is full, wait if full
+ *
+ * @param data data to put inside the batch queue
+ */
+ void wait_and_push(T &&data) noexcept;
+
+ /**
+ * @brief pop data from the queue, wait if empty
+ *
+ * @return std::unique_ptr<T> data
+ */
+ std::unique_ptr<T> wait_and_pop() noexcept;
+
+ /**
+ * @brief check if current queue is full
+ *
+ * @return bool true if full
+ */
+ bool isFull() const;
+
+ /**
+ * @brief check if current queue is empty
+ *
+ * @return bool true if empty
+ */
+ bool isEmpty() const;
+
+private:
+ unsigned int queue_capacity;
+ mutable std::shared_mutex q_mutex;
+ std::condition_variable_any q_reader_cv;
+ std::condition_variable_any q_writer_cv;
+
+ /**
+ * @todo consider using circular buffer if this is too slow
+ *
+ */
+ std::queue<std::unique_ptr<T>> q;
+};
+
+} // namespace nntrainer
+
+#endif // __BATCH_QUEUE_H__
dataset_sources = [
+ 'batch_queue.cpp',
'databuffer.cpp',
'databuffer_factory.cpp',
'databuffer_file.cpp',
'data_producer_common_tests.cpp',
'unittest_random_data_producers.cpp',
'unittest_func_data_producer.cpp',
- 'unittest_raw_file_data_producer.cpp'
+ 'unittest_raw_file_data_producer.cpp',
+ 'unittest_batch_queue.cpp'
]
test_target += producer_targets
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file unittest_batch_queue.cpp
+ * @date 12 July 2021
+ * @brief Batch Queue Test
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+
+#include <gtest/gtest.h>
+
+#include <batch_queue.h>
+#include <tensor.h>
+
+#include <future>
+#include <thread>
+#include <tuple>
+#include <vector>
+
+nntrainer::DataProducer::Iteration data(size_t key) {
+ return {true, std::vector<nntrainer::Tensor>(key), {}};
+};
+
+void test_data(const nntrainer::DataProducer::Iteration &dat,
+ size_t expected_key) {
+ EXPECT_EQ(std::get<1>(dat).size(), expected_key);
+}
+
+TEST(BatchQueue, pushPop_p) {
+ nntrainer::BatchQueue bq(1);
+
+ EXPECT_NO_THROW(bq.wait_and_push(data(1)));
+ auto result = bq.wait_and_pop();
+ test_data(*result, 1);
+}
+
+TEST(BatchQueue, threadedPushPops_p) {
+ /** preparing primitives */
+ using namespace std::chrono_literals;
+ auto push_after = [](nntrainer::BatchQueue &bq, const auto &duration,
+ size_t key) {
+ std::this_thread::sleep_for(duration);
+ EXPECT_NO_THROW(bq.wait_and_push(data(key)));
+ };
+ auto pop_after = [](nntrainer::BatchQueue &bq, const auto &duration,
+ size_t key) {
+ std::this_thread::sleep_for(duration);
+ auto result = bq.wait_and_pop();
+ test_data(*result, key);
+ };
+
+ std::vector<std::future<void>> futures;
+ {
+ futures.clear();
+ /// 0 -> push(1)
+ /// 250ms -> pop(1)
+ nntrainer::BatchQueue bq(1);
+ futures.push_back(
+ std::async(std::launch::async, push_after, std::ref(bq), 0ms, 1));
+ futures.push_back(
+ std::async(std::launch::async, pop_after, std::ref(bq), 250ms, 1));
+ for (auto &future : futures) {
+ future.get();
+ }
+ }
+
+ {
+ futures.clear();
+ /// 0 -> pop(1)
+ /// 250ms -> push(1)
+ nntrainer::BatchQueue bq(1);
+ futures.push_back(
+ std::async(std::launch::async, pop_after, std::ref(bq), 0ms, 1));
+ futures.push_back(
+ std::async(std::launch::async, push_after, std::ref(bq), 250ms, 1));
+ for (auto &future : futures) {
+ future.get();
+ }
+ }
+
+ {
+ futures.clear();
+ /// 0 -> push(1)
+ /// 300ms -> push(2)
+ /// 300ms -> pop(1)
+ /// 500ms -> push(3)
+ /// 600ms -> push(4) (waits)
+ /// 750ms -> pop(2)
+ /// 1000ms-> pop(3)
+ nntrainer::BatchQueue bq(2);
+ futures.push_back(
+ std::async(std::launch::async, push_after, std::ref(bq), 0ms, 1));
+ futures.push_back(
+ std::async(std::launch::async, push_after, std::ref(bq), 300ms, 2));
+ futures.push_back(
+ std::async(std::launch::async, pop_after, std::ref(bq), 300ms, 1));
+ futures.push_back(
+ std::async(std::launch::async, push_after, std::ref(bq), 500ms, 3));
+ futures.push_back(
+ std::async(std::launch::async, push_after, std::ref(bq), 600ms, 4));
+ futures.push_back(
+ std::async(std::launch::async, pop_after, std::ref(bq), 750ms, 2));
+ futures.push_back(
+ std::async(std::launch::async, pop_after, std::ref(bq), 1000ms, 3));
+ for (auto &future : futures) {
+ future.get();
+ }
+ }
+}