[Dataset] Add mt support to iteration Queue
authorJihoon Lee <jhoon.it.lee@samsung.com>
Wed, 18 Aug 2021 05:20:40 +0000 (14:20 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 27 Aug 2021 11:44:58 +0000 (20:44 +0900)
This path adds async support to iteration queue and ending mechanism

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/dataset/batch_queue.cpp
nntrainer/dataset/batch_queue.h

index 7a64e13ef11dc6eb5515ddf1a0eae6ee1e008b42..17733552cee892295c1f45e1a3a62f6a13840025 100644 (file)
@@ -42,6 +42,7 @@ void BatchQueue::wait_and_push(T &&data) noexcept {
   std::unique_lock<std::shared_mutex> lk(q_mutex);
   q_writer_cv.wait(lk, [this] { return q.size() != queue_capacity; });
   q.push(std::make_unique<T>(data));
+  lk.unlock();
   q_reader_cv.notify_one();
 }
 
@@ -53,6 +54,7 @@ std::unique_ptr<BatchQueue::T> BatchQueue::wait_and_pop() noexcept {
   /// popped right away
   auto ptr = std::move(q.front());
   q.pop();
+  lk.unlock();
   q_writer_cv.notify_one();
 
   return ptr;
@@ -71,7 +73,9 @@ bool BatchQueue::isEmpty() const {
 IterationQueue::IterationQueue(
   unsigned int num_slots, const std::vector<ml::train::TensorDim> &input_dims,
   const std::vector<ml::train::TensorDim> &label_dims) :
-  being_filled(nullptr) {
+  being_filled(nullptr),
+  num_being_filled(0),
+  flow_state(IterationQueue::FlowState::FLOW_STATE_OPEN) {
   NNTR_THROW_IF(num_slots == 0, std::invalid_argument)
     << "number of slots must be more then zero";
 
@@ -82,14 +86,35 @@ IterationQueue::IterationQueue(
   }
 }
 
+IterationQueue::~IterationQueue() {
+  std::scoped_lock lg(empty_mutex, filled_mutex);
+
+  /// if an iteration is not included in either empty_q or filled_q, that
+  /// means it's either being filled or being served. Which means it will be
+  /// dangerous to destroy @a this, we might want to wait on the destructor if
+  /// we can assure this can stay no except
+  if (empty_q.size() + filled_q.size() < iterations.size()) {
+    ml_logw(
+      "Destroying the iteration queue, while some buffers are being used");
+  }
+}
+
 ScopedView<Sample> IterationQueue::requestEmpty() {
-  if (being_filled == nullptr) {
-    if (empty_q.empty()) {
-      throw std::invalid_argument(
-        "empty_q empty"); /// this is temporary measure
-    }
-    being_filled = empty_q.front();
-    empty_q.pop();
+  std::scoped_lock lg(empty_mutex);
+  NNTR_THROW_IF(flow_state != FlowState::FLOW_STATE_OPEN, std::invalid_argument)
+    << "Calling requestEmpty() after notifyEndOfRequestEmpty() breaks "
+       "invariant";
+
+  /// below is useful information when debugging iteration queue, but there will
+  /// be to much log if we turn the log on. so leaving it as a comment for now.
+  // std::cout << "[requestEmpty] empty_q.size(): " << empty_q.size()
+  // << " being_filled: " << num_being_filled
+  // << " filled_q.size():  " << filled_q.size() << '\n';
+
+  if (being_filled == nullptr ||
+      current_iterator + 1 == being_filled->get().end()) {
+    being_filled = empty_q.waitAndPop();
+    num_being_filled++;
     current_iterator = being_filled->get().begin();
   } else {
     current_iterator++;
@@ -99,42 +124,76 @@ ScopedView<Sample> IterationQueue::requestEmpty() {
                                  [current_being_filed = this->being_filled] {
                                    current_being_filed->markSampleFilled();
                                  });
-
-  if (current_iterator + 1 == being_filled->get().end()) {
-    being_filled = nullptr;
-  }
-
   return view;
 }
 
 ScopedView<Iteration> IterationQueue::requestFilled() {
-  if (filled_q.empty()) {
-    throw std::invalid_argument("filled_q empty"); /// this is temporary measure
+  std::scoped_lock lock(filled_mutex);
+
+  /// below is useful information when debugging iteration queue, but there will
+  /// be to much log if we turn the log on. so leaving it as a comment for now.
+  // std::cout << "[requestFilled] empty_q.size(): " << empty_q.size()
+  // << " num being filled: " << num_being_filled
+  // << " filled_q.size(): " << filled_q.size() << '\n';
+  if (flow_state == FlowState::FLOW_STATE_STOPPED) {
+    return ScopedView<Iteration>(nullptr);
+  }
+
+  auto iteration = filled_q.waitAndPop();
+  if (iteration == nullptr) {
+    NNTR_THROW_IF(flow_state != FlowState::FLOW_STATE_STOP_REQUESTED,
+                  std::runtime_error)
+      << "the queue has either already stopped or running, but trying stopping "
+         "without requesting stop, queue size: "
+      << iterations.size() << " num currently empty: " << empty_q.size()
+      << " num being filled: " << num_being_filled
+      << " filled_q.size(): " << filled_q.size();
+
+    flow_state = FlowState::FLOW_STATE_STOPPED;
+    return ScopedView<Iteration>(nullptr);
   }
 
-  auto iteration = filled_q.front();
-  filled_q.pop();
   return ScopedView<Iteration>(&iteration->get(),
                                [this, iteration] { markEmpty(iteration); });
 }
 
+void IterationQueue::notifyEndOfRequestEmpty() {
+  std::unique_lock lg(empty_mutex);
+  NNTR_THROW_IF(flow_state != FlowState::FLOW_STATE_OPEN, std::invalid_argument)
+    << "notifyEndOfRequestEmpty() must be called once";
+
+  /// below is useful information when debugging iteration queue, but there will
+  /// be to much log if we turn the log on. so leaving it as a comment for now.
+  // std::cout << "[notifyEnd] empty_q.size(): " << empty_q.size()
+  //           << " num being filled: " << num_being_filled
+  //           << " filled_q.size(): " << filled_q.size() << '\n';
+
+  flow_state = FlowState::FLOW_STATE_STOP_REQUESTED;
+  if (being_filled) {
+    being_filled->setEndSample(current_iterator + 1);
+  }
+  notify_emptied_cv.wait(lg, [this] { return num_being_filled == 0; });
+  filled_q.push(nullptr);
+}
+
 IterationQueue::MarkableIteration::MarkableIteration(
   const std::vector<ml::train::TensorDim> &input_dims,
   const std::vector<ml::train::TensorDim> &label_dims, IterationQueue *iq) :
+  num_observed(0),
   iteration(input_dims, label_dims),
-  iq(iq),
-  num_observed(0) {}
+  iq(iq) {}
 
 IterationQueue::MarkableIteration::MarkableIteration(MarkableIteration &&rhs) :
+  num_observed(rhs.num_observed),
   iteration(std::move(rhs.iteration)),
-  iq(rhs.iq),
-  num_observed(rhs.num_observed) {}
+  iq(rhs.iq) {}
 
 IterationQueue::MarkableIteration &IterationQueue::MarkableIteration::
 operator=(MarkableIteration &&rhs) {
   if (this == &rhs) {
     return *this;
   }
+  std::scoped_lock lock(this->notify_mutex, rhs.notify_mutex);
   std::swap(iteration, rhs.iteration);
   std::swap(iq, rhs.iq);
   std::swap(num_observed, rhs.num_observed);
@@ -142,7 +201,11 @@ operator=(MarkableIteration &&rhs) {
 }
 
 void IterationQueue::markFilled(MarkableIteration *iteration) /** noexcept */ {
+  std::unique_lock lg(empty_mutex);
+  num_being_filled--;
   filled_q.push(iteration);
+  lg.unlock();
+  notify_emptied_cv.notify_all();
 }
 
 void IterationQueue::markEmpty(MarkableIteration *iteration) /** noexcept */ {
@@ -150,7 +213,7 @@ void IterationQueue::markEmpty(MarkableIteration *iteration) /** noexcept */ {
 }
 
 void IterationQueue::MarkableIteration::markSampleFilled() {
-  std::lock_guard notify_lock_guard(notify_mutex);
+  std::scoped_lock notify_lock_guard(notify_mutex);
   num_observed++;
   if (num_observed == iteration.batch()) {
     iq->markFilled(this);
@@ -158,4 +221,29 @@ void IterationQueue::MarkableIteration::markSampleFilled() {
   }
 }
 
+void IterationQueue::MarkableIteration::setEndSample(
+  std::vector<Sample>::iterator sample_iterator) {
+  std::scoped_lock notify_lock_guard(notify_mutex);
+  auto old_batch = iteration.batch();
+  if (sample_iterator != iteration.end()) {
+    iteration.setEndSample(sample_iterator);
+  }
+  auto new_batch = iteration.batch();
+  /// if batch has changed, check if this batch is partially filled and should
+  /// be moved
+  if (old_batch != new_batch && num_observed == new_batch) {
+#if DEBUG
+    NNTR_THROW_IF_CLEANUP(iq->empty_mutex.try_lock(), std::runtime_error,
+                          [iq] { iq->empty_mutex.unlock(); })
+      << "iteration queue must be locked already but empty_mutex is not "
+         "locked.";
+#endif
+    /// warning: iq has to be locked with iq->empty_mutex
+    iq->num_being_filled--;
+    iq->filled_q.push(this);
+    iq->notify_emptied_cv.notify_all();
+    num_observed = 0;
+  }
+}
+
 } // namespace nntrainer
index 0bd476762cd5e8392f37e597b95bd79e89325eff..50e437343a66a59dce26f564c82f4df222eb69d6 100644 (file)
@@ -20,6 +20,7 @@
 #include <memory>
 #include <queue>
 #include <shared_mutex>
+#include <stdexcept>
 #include <tuple>
 
 #include <data_iteration.h>
@@ -99,6 +100,74 @@ private:
   std::queue<std::unique_ptr<T>> q;
 };
 
+/**
+ * @brief Thread Safe Queue implementation dedicated for the non-owing pointer
+ *
+ * @tparam original type of the view (T * will be pushed and pop)
+ */
+template <typename T> class ViewQueue {
+public:
+  /**
+   * @brief Construct a new queue
+   */
+  ViewQueue() : q() {}
+
+  /**
+   * @brief push data to queue
+   *
+   * @param data data to put
+   */
+  void push(T *data) {
+    {
+      std::unique_lock<std::shared_mutex> lk(q_mutex);
+      q.push(data);
+    }
+
+    q_cv.notify_one();
+  }
+
+  /**
+   * @brief pop data from the queue, wait if empty
+   * @note when fail to get, this will return nullptr (eg) when interrupt
+   * happens)
+   * @return T* view of the data
+   */
+  T *waitAndPop() {
+    std::unique_lock<std::shared_mutex> lk(q_mutex);
+    q_cv.wait(lk, [this] { return !q.empty(); });
+    auto ptr = q.front();
+    q.pop();
+
+    return ptr;
+  }
+
+  /**
+   * @brief check if current queue is empty
+   *
+   * @return bool true if empty
+   */
+  bool isEmpty() const {
+    std::shared_lock<std::shared_mutex> lk(q_mutex);
+    return q.empty();
+  }
+
+  /**
+   * @brief check if current queue is empty
+   *
+   * @return bool true if empty
+   */
+  typename std::queue<T *>::size_type size() const {
+    std::shared_lock<std::shared_mutex> lk(q_mutex);
+    return q.size();
+  }
+
+private:
+  mutable std::shared_mutex q_mutex;
+  std::condition_variable_any q_cv;
+
+  std::queue<T *> q;
+};
+
 /**
  * @brief A view container that calls a callback on destruct
  * @note the callback must be noexcept, and the given underlying data must
@@ -112,10 +181,9 @@ public:
    * @brief Construct a new Scoped View object
    *
    * @param data_ reference of the underlying data
-   * @param on_notify_ callback to be called on exit, this is not copied but
-   * reused
+   * @param on_notify_ callback to be called on exit
    */
-  ScopedView(T *data_, std::function<void(void)> &&on_notify_) :
+  ScopedView(T *data_, std::function<void(void)> &&on_notify_ = nullptr) :
     data(data_),
     on_notify(std::forward<std::function<void(void)>>(on_notify_)) {}
 
@@ -125,15 +193,28 @@ public:
   ScopedView(ScopedView &&rhs) = default;
   ScopedView &operator=(ScopedView &&rhs) = default;
 
+  /**
+   * @brief check if scoped view contains any underlying data
+   *
+   * @return bool if data is empty
+   */
+  bool isEmpty() { return data == nullptr; }
+
   /**
    * @brief Destroy the Scoped View object, callback is called at this time
    *
    */
   ~ScopedView() {
     try {
-      on_notify();
+      if (std::uncaught_exceptions()) {
+        /// NYI, add on_error handler here
+      } else {
+        if (on_notify) {
+          on_notify();
+        }
+      }
     } catch (...) {
-      ml_loge("while notifiying, error happened");
+      ml_loge("while handling on_notify or on_error, error happened");
     }
   }
 
@@ -159,7 +240,7 @@ private:
 
 /**
  * @brief Iteration queue that owns the buffer for input / labels
- * @detail
+ * @details
  *
  * - requestEmpty() will give a ScopedView<sample>
  *     Destructing the returned object will notify the iteration that is done
@@ -169,8 +250,15 @@ private:
  *     Destructing this will notify the queue that is done used (internally
  * calls IterationQueue::markEmpty())
  *
+ * @details For an iteration there can be four state.
+ * 1. The buffer is empty, waiting to be filled (will be in empty_q)
+ * 2. The buffer is being filled sample by sample, waiting to be marked as
+ * filled.
+ * 3. The buffer is filled, waiting to be served (will be in filled_q)
+ * 4. The buffer is being served, waiting to be marked as emptied.
  * @todo apply this to the databuffer
- * @todo prepare thread safe queue and apply
+ * @todo handle error case: 1. when ScopedView<Sample> has met throw
+ *                          2. when ScopedView<Iteration> has met throw
  */
 class IterationQueue {
 public:
@@ -187,20 +275,30 @@ public:
                  const std::vector<ml::train::TensorDim> &input_dims,
                  const std::vector<ml::train::TensorDim> &label_dims);
 
+  /**
+   * @brief Destroy the Iteration Queue object
+   *
+   */
+  ~IterationQueue();
+
   /**
    * @brief request empty sample from the queue.
-   * @note There is race condition between requesting empty, race condition with
-   * mark_ready should be handled by using MT_safe queue.
-   * @return ScopedView<Sample> sample view. Destroying the returned object will
+   * @note User must check if ScopedView actually has a value by calling
+   * copedView::isEmpty()
+   * @return ScopedView<Sample> sample view. ScopedView::isEmpty() == true
+   * if there is no more data coming. Destroying the returned object will
    * signal the queue that the sample is filled.
    */
   ScopedView<Sample> requestEmpty();
 
   /**
    * @brief request filled iteration from the queue.
-   * @note race condition here can be handled by using MT_safe queue
-   * @return ScopedView<Iteration> Ieration view. Destroying the returned object
-   * will signal the queue that the sample is done using.
+   * @note User must check if ScopedView actually has a value by calling
+   * copedView::isEmpty()
+   * @return ScopedView<Iteration> Ieration view. ScopedView::isEmpty() == true
+   * if there is no more data coming. Destroying the returned object will
+   * signal the queue that the sample is done using.
+   *
    */
   ScopedView<Iteration> requestFilled();
 
@@ -218,6 +316,17 @@ public:
    */
   unsigned int batch() { return iterations.front().get().batch(); }
 
+  /**
+   * @brief notifyEndOfRequest, when the producing by requestEmpty has finished.
+   * @note It is important that the owner of this class must ensure that there
+   * will be no more requestEmpty call after this. This means that, in case of
+   * multiple workers, the manager of the worker(producer) must know every
+   * producer has finished. and call this function other than each worker call
+   * this function.
+   *
+   */
+  void notifyEndOfRequestEmpty();
+
 private:
   /**
    * @brief A wrapper object around @c Iteration which marks filled when filling
@@ -260,6 +369,15 @@ private:
      */
     void markSampleFilled() /** noexcept */;
 
+    /**
+     * @brief update end sample to the given iterator and mark last
+     * @note after updating end iterator, this can be markFilled() if every
+     * sample is already filled
+     *
+     * @param iterator non-inclusive iterator to mark the last
+     */
+    void setEndSample(std::vector<Sample>::iterator sample_iterator);
+
     /**
      * @brief get underlying iteration
      *
@@ -268,10 +386,22 @@ private:
     Iteration &get() { return iteration; }
 
   private:
-    mutable std::mutex notify_mutex;
-    Iteration iteration;
-    IterationQueue *iq;
-    unsigned int num_observed;
+    unsigned int num_observed; /**< number of observed samples which were passed
+                                  to the callee and notified done filling */
+    mutable std::mutex
+      notify_mutex;      /**< mutex which should be locked when try to notify */
+    Iteration iteration; /**< underlying iteration that this class owns */
+    IterationQueue *iq;  /**< view of iteration queue */
+  };
+
+  /**
+   * @brief Queue running state enum
+   *
+   */
+  enum class FlowState {
+    FLOW_STATE_OPEN = 0,           /**< nothing */
+    FLOW_STATE_STOP_REQUESTED = 1, /**< request stop */
+    FLOW_STATE_STOPPED = 2,        /**< queue is fully stopped */
   };
 
   /**
@@ -288,14 +418,25 @@ private:
    */
   void markEmpty(MarkableIteration *iteration) /** noexcept */;
 
-  std::vector<MarkableIteration> iterations; /** allocated iterations */
-  MarkableIteration *being_filled; /**< iteration that is being filled now */
-
-  std::vector<Sample>::iterator current_iterator;
-
-  /// @todo use mt safe queue
-  std::queue<MarkableIteration *> empty_q;  /** iterations to be filled */
-  std::queue<MarkableIteration *> filled_q; /** iterations to be served */
+  std::vector<MarkableIteration> iterations; /**< allocated iterations */
+  MarkableIteration *being_filled; /**< last iteration that is being filled */
+  std::vector<Sample>::iterator
+    current_iterator; /**< current sample iteration of being_filled */
+
+  mutable std::mutex empty_mutex; /**< mutex to be used when it is mutually
+                                     exclusive to the requesting empty slots */
+  unsigned int
+    num_being_filled; /**< number of iteration that is in being_filled state */
+  mutable std::mutex
+    filled_mutex; /**< mutex to be used when it is mutually exclusive to the
+                     requesting filled slots */
+  std::condition_variable_any
+    notify_emptied_cv;  /**< conditional variable to wait based on the
+                           num_being_filled */
+  FlowState flow_state; /**< flow state of the queue */
+
+  ViewQueue<MarkableIteration> empty_q;  /**< iterations to be filled */
+  ViewQueue<MarkableIteration> filled_q; /**< iterations to be served */
 };
 
 } // namespace nntrainer