Fix windows test hang (#17778)
authorxuzhu <xzhu1900@gmail.com>
Tue, 12 Mar 2019 08:43:45 +0000 (01:43 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Mar 2019 08:50:49 +0000 (01:50 -0700)
Summary:
This PR resolves two concurrent issues discovered when running the test in windows. Details about the windows test can be found here: https://github.com/pytorch/pytorch/issues/17609

The change covers two fixes:
1. update running_preloaders_ upfront before creating worker thread to prevent underflow.
2. add a lock when updating stop_ to prevent dead lock in condition variable cv_write_.

The fix has been tested on both Windows and Linux. With --gtest_repeat=1000, the tests runs smoothly without issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17778

Differential Revision: D14404910

Pulled By: soumith

fbshipit-source-id: 2fbb8007e4b0bce4613e9a9fd31b8aace1bbfa8d

torch/csrc/api/include/torch/data/datasets/chunk.h

index 8326816..b5d66f8 100644 (file)
@@ -48,8 +48,7 @@ class BatchDataBuffer {
       size_t queue_capacity)
       : batch_size_(batch_size),
         example_sampler_(example_sampler),
-        queue_capacity_(queue_capacity),
-        stop_(false) {}
+        queue_capacity_(queue_capacity) {}
 
   /// Return batch data from the queue. Called from the ChunkDataset main
   /// thread.
@@ -60,10 +59,10 @@ class BatchDataBuffer {
       // loaded (i.e. the dataset is exhausted for this epoch)
       return (
           this->total_example_count_in_queue_ >= batch_size_ ||
-          this->stop_.load());
+          this->stop_);
     });
     if (batch_queue_.empty()) {
-      AT_ASSERT(this->stop_.load());
+      AT_ASSERT(stop_);
       // All batches have been retrieved. Return an empty batch.
       return nullopt;
     }
@@ -88,10 +87,9 @@ class BatchDataBuffer {
     cv_write_.wait(lock, [this] {
       // stop loading if we have preloaded enough data.
       return this->total_example_count_in_queue_ < this->queue_capacity_ ||
-          stop_.load();
+          this->stop_;
     });
-
-    if (stop_.load()){
+    if (stop_) {
       // When stop_ is true, it means no further chunk loading is necessary.
       // Return without any further processing.
       return;
@@ -149,10 +147,11 @@ class BatchDataBuffer {
     std::unique_lock<std::mutex> lock(queue_mutex_);
     cv_write_.wait(lock, [this] {
       // stop loading if we have preloaded enough data.
-      return this->total_example_count_in_queue_ < this->queue_capacity_ || stop_.load();
+      return (
+        this->total_example_count_in_queue_ < this->queue_capacity_ ||
+        this->stop_);
     });
-
-    if (stop_.load()){
+    if (stop_){
       // When stop_ is true, it means this current thread needs to be tore down,
       // the batch buffer will be discarded, so no need to enqueue any new
       // exceptions.
@@ -165,14 +164,27 @@ class BatchDataBuffer {
   }
 
   void stop(){
-    stop_ = true;
+    {
+      // Hold the lock before changing stop_ to prevent a race condition which can
+      // cause a deadlock.
+      // To be more specific, conditional variable cv_write_ waits on predicate
+      // stop_ in add_chunk_data(). The wait happens in two steps: 1) while still
+      // holding the lock, check if predicate is true; 2) if it is true, proceeds,
+      // otherwise, release the lock and wait until notified. Without holding a
+      // lock, cv_write_'s notification can happen in between step 1) and 2). In
+      // that case, as cv_write_ is not in waiting status yet, so the notification
+      // is lost and cv_write_ will sleep forever.
+      // By taking a lock before changing predicate stop_, it is ensured updating
+      // and evaluating stop_ always happen in a synchronized way
+      std::lock_guard<std::mutex> lock(queue_mutex_);
+      stop_ = true;
+    }
 
     // notify all writers, wake them from wait to exit current method.
     cv_write_.notify_all();
     // notify all readers too.
     cv_read_.notify_all();
   }
-  
   /// The batch size is needed to create batches from the chunk data. Similar to
   /// regular dataloader where the batches are created with prefetches,
   /// BatchDataBuffer perform the batch creation using the provided batch size.
@@ -217,7 +229,7 @@ class BatchDataBuffer {
   // preloader to finish previous work before tearing down the thread, the
   // preloader could be still waiting for the conditional variable, thus cause
   // the program to hang. This boolean is used to break this waiting condition.
-  std::atomic<bool> stop_;
+  bool stop_ = false;
 };
 } // namespace detail
 
@@ -321,7 +333,7 @@ class ChunkDataset final
   /// This will clear any internal state and starts the internal prefetching
   /// mechanism for the chunk dataset.
   void reset() override {
-    // We need this to support partial data reads via dataloader iterator.     
+    // We need this to support partial data reads via dataloader iterator.
     if (batch_buffer_) {
       batch_buffer_->stop();
     }
@@ -345,9 +357,9 @@ class ChunkDataset final
     quit_worker_ = false;
 
     AT_ASSERT(running_preloaders_ == 0);
+    running_preloaders_ = options_.preloader_count_;
     for (size_t i = 0; i < options_.preloader_count_; ++i) {
       preload_threads_.emplace_back([this, i]() { this->preloader(i); });
-      ++running_preloaders_;
     }
   }
 
@@ -384,6 +396,7 @@ class ChunkDataset final
         batch_buffer_->add_chunk_data(std::current_exception());
       }
     }
+    AT_ASSERT(running_preloaders_.load() > 0);
     --running_preloaders_;
     if (running_preloaders_.load() == 0) {
       // all preloaders are completed, so we can notify the batch_buffer.
@@ -394,7 +407,7 @@ class ChunkDataset final
   /// Block the current thread until the workers finish execution and exit.
   void free_workers() {
     if (!quit_worker_.load()) {
-      quit_worker_ = true;      
+      quit_worker_ = true;
       for (auto& worker_thread : preload_threads_) {
         worker_thread.join();
       }