Remove chunk count check on the ChunkBuffer (#16868)
author Jaliya Ekanayake <jaliyaek@microsoft.com>
Wed, 13 Feb 2019 18:26:15 +0000 (10:26 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 13 Feb 2019 19:09:42 +0000 (11:09 -0800)
Summary:
Previously, the ChunkBuffer depended on the remaining chunk count to signal the end of data loading. This does not work with distributed samplers, where each sampler loads only a subset of the chunks. This refactor removes the ChunkBuffer's dependency on the remaining chunk count: the dataset now tracks how many preloader threads are still running and tells the buffer to stop once the last one finishes.
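
The buffer-side mechanics can be sketched in isolation: instead of counting down a precomputed chunk total, readers block until data arrives or a stop flag is raised, and the last preloader to finish raises that flag. Below is a minimal, self-contained sketch of this scheme; Buffer, producer, push and pop are illustrative stand-ins, not the actual BatchDataBuffer API.

// Minimal sketch of the new signaling scheme (illustrative names only;
// Buffer stands in for BatchDataBuffer, producers for the preloader threads).
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

class Buffer {
 public:
  void push(int value) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push(value);
    }
    cv_read_.notify_all();
  }

  // Blocks until data is available or producers signaled completion;
  // returns false once the buffer is drained and stopped.
  bool pop(int& out) {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_read_.wait(lock, [this] { return !queue_.empty() || stop_.load(); });
    if (queue_.empty()) {
      return false;  // dataset exhausted for this epoch
    }
    out = queue_.front();
    queue_.pop();
    return true;
  }

  void stop() {
    stop_ = true;
    cv_read_.notify_all();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_read_;
  std::queue<int> queue_;
  std::atomic<bool> stop_{false};
};

int main() {
  Buffer buffer;
  std::atomic<size_t> running_producers{2};  // mirrors running_preloaders_

  auto producer = [&](int base) {
    for (int i = 0; i < 3; ++i) {
      buffer.push(base + i);
    }
    // The last producer to finish signals completion; no chunk count needed.
    if (--running_producers == 0) {
      buffer.stop();
    }
  };

  std::thread t1(producer, 0);
  std::thread t2(producer, 100);

  int value;
  while (buffer.pop(value)) {
    std::cout << value << '\n';
  }
  t1.join();
  t2.join();
}

The end-of-epoch condition becomes "buffer drained and all producers finished", which holds no matter how many chunks this particular sampler was assigned.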
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16868

Differential Revision: D14066517

Pulled By: goldsborough

fbshipit-source-id: 293dfe282ceff326dff0876c2f75c2ee4f4463e2

test/cpp/api/dataloader.cpp
torch/csrc/api/include/torch/data/datasets/chunk.h
torch/csrc/api/include/torch/data/samplers/base.h

index 844eaed..0e17233 100644
@@ -1620,3 +1620,44 @@ TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
     }
   }
 }
+
+TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
+  const size_t prefetch_count = 2;
+  const size_t batch_size = 5;
+
+  DummyChunkDataReader data_reader;
+  samplers::SequentialSampler sampler(0);
+  datasets::SharedBatchDataset<datasets::ChunkDataset<
+      DummyChunkDataReader,
+      samplers::SequentialSampler,
+      samplers::SequentialSampler>>
+      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
+          DummyChunkDataReader,
+          samplers::SequentialSampler,
+          samplers::SequentialSampler>>(
+          data_reader,
+          sampler,
+          sampler,
+          datasets::ChunkDatasetOptions(prefetch_count, batch_size));
+
+  samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
+
+  auto data_loader = torch::data::make_data_loader(
+      dataset.map(transforms::BatchLambda<std::vector<int>, int>(
+          [](std::vector<int> batch) {
+            return std::accumulate(batch.begin(), batch.end(), 0);
+          })),
+      DataLoaderOptions(batch_size).workers(0));
+
+  // before we start, the index should be 0.
+  ASSERT_EQ(chunk_sampler.index(), 0);
+
+  size_t sum = 0;
+  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
+       ++iterator) {
+    sum += *iterator;
+  }
+  ASSERT_EQ(sum, 595); // sum([0, 35))
+  // All 3 chunks were consumed, so the sampler's index has advanced to 3.
+  ASSERT_EQ(chunk_sampler.index(), 3);
+}
index ae8e1c8..1507dc3 100644
@@ -43,12 +43,10 @@ class BatchDataBuffer {
   using BatchRequestType = typename ExampleSampler::BatchRequestType;
 
   BatchDataBuffer(
-      size_t num_chunks,
       size_t batch_size,
       ExampleSampler& example_sampler,
       size_t queue_capacity)
-      : remaining_chunk_count_(num_chunks),
-        batch_size_(batch_size),
+      : batch_size_(batch_size),
         example_sampler_(example_sampler),
         queue_capacity_(queue_capacity),
         stop_(false) {}
@@ -62,11 +60,10 @@ class BatchDataBuffer {
       // loaded (i.e. the dataset is exhausted for this epoch)
       return (
           this->total_example_count_in_queue_ >= batch_size_ ||
-          this->remaining_chunk_count_ == 0);
+          this->stop_.load());
     });
     if (batch_queue_.empty()) {
-      AT_ASSERT(remaining_chunk_count_ == 0);
-
+      AT_ASSERT(this->stop_.load());
       // All batches have been retrieved. Return an empty batch.
       return nullopt;
     }
@@ -84,26 +81,18 @@ class BatchDataBuffer {
     return batch.batch_data;
   }
 
-  // skip one chunk
-  void skip_chunk() {
-    std::unique_lock<std::mutex> lock(queue_mutex_);
-    AT_ASSERT(remaining_chunk_count_ > 0);
-    remaining_chunk_count_--;
-    lock.unlock();
-    cv_read_.notify_all();
-  }
-
   /// Push preloaded chunks to batch queue. Called from the ChunkDataset worker
   /// threads.
   void add_chunk_data(UnwrappedBatchType data) {
     std::unique_lock<std::mutex> lock(queue_mutex_);
     cv_write_.wait(lock, [this] {
       // stop loading if we have preloaded enough data.
-      return this->total_example_count_in_queue_ < this->queue_capacity_ || stop_.load();
+      return this->total_example_count_in_queue_ < this->queue_capacity_ ||
+          stop_.load();
     });
 
     if (stop_.load()){
-      // When stop_ is true, it means this current thread needs to be tore down.
+      // When stop_ is true, it means no further chunk loading is necessary.
       // Return without any further processing.
       return;
     }
@@ -150,10 +139,6 @@ class BatchDataBuffer {
       batch_queue_.emplace(std::move(current_batch));
     }
     total_example_count_in_queue_ += data_size;
-
-    AT_ASSERT(remaining_chunk_count_ > 0);
-    remaining_chunk_count_--;
-
     lock.unlock();
     cv_read_.notify_all();
   }
@@ -175,9 +160,6 @@ class BatchDataBuffer {
     }
 
     batch_queue_.emplace(e_ptr);
-
-    AT_ASSERT(remaining_chunk_count_ > 0);
-    remaining_chunk_count_--;
     lock.unlock();
     cv_read_.notify_all();
   }
@@ -187,13 +169,10 @@ class BatchDataBuffer {
 
     // notify all writers, wake them from wait to exit current method.
     cv_write_.notify_all();
+    // notify all readers too.
+    cv_read_.notify_all();
   }
-
-  /// count of remaining chunk to be loaded. It is initialized with the total
-  /// chunk count and it decreases when a chunk data is retrieved. When this reaches
-  /// to 0, no more chunk needs to be loaded.
-  size_t remaining_chunk_count_ = 0;
-
+
   /// The batch size is needed to create batches from the chunk data. Similar to
   /// regular dataloader where the batches are created with prefetches,
   /// BatchDataBuffer perform the batch creation using the provided batch size.
@@ -310,8 +289,8 @@ class ChunkDataset final
         chunk_sampler_(std::move(chunk_sampler)),
         example_sampler_(std::move(example_sampler)),
         options_(std::move(options)),
-        quit_worker_(false) {
-  }
+        quit_worker_(false),
+        running_preloaders_(0) {}
 
   virtual ~ChunkDataset() {
     free_workers();
@@ -344,14 +323,12 @@ class ChunkDataset final
 
     chunk_reader_.reset();
 
-    size_t chunks_to_load = chunk_reader_.chunk_count();
-    chunk_sampler_.reset(chunks_to_load);
+    chunk_sampler_.reset(chunk_reader_.chunk_count());
 
     // Throw out any existing cached batch in the buffer and re-creates a new
     // chunk buffer.
     batch_buffer_ = torch::make_unique<
         detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
-        chunks_to_load,
         options_.batch_size_,
         example_sampler_,
         options_.cache_size_);
@@ -359,9 +336,10 @@ class ChunkDataset final
     // create new workers for this new epoch.
     quit_worker_ = false;
 
+    AT_ASSERT(running_preloaders_ == 0);
     for (size_t i = 0; i < options_.preloader_count_; ++i) {
-      preload_threads_.emplace_back(
-          [this, i]() { this->preloader(i); });
+      // Increment before spawning so a fast-exiting worker cannot
+      // decrement the counter below zero.
+      ++running_preloaders_;
+      preload_threads_.emplace_back([this, i]() { this->preloader(i); });
     }
   }
 
@@ -370,39 +348,45 @@ class ChunkDataset final
     return torch::nullopt;
   }
 
+  // provide a reference to the chunk sampler. Used mainly in distributed
+  // data loading to set the epoch number for the sampler.
+  ChunkSamplerType& chunk_sampler() {
+    return chunk_sampler_;
+  }
+
  private:
   /// running on worker thread to preload chunk data.
   void preloader(size_t id) {
     while (!quit_worker_.load()) {
       try {
         size_t chunk_id = 0;
-        if (auto chunk_sampler_result = chunk_sampler_.next(1)) {
-          chunk_id = chunk_sampler_result.value()[0];
-        } else {
-          break;
+        {
+          std::lock_guard<std::mutex> lock(chunk_index_guard_);
+          if (auto chunk_sampler_result = chunk_sampler_.next(1)) {
+            chunk_id = chunk_sampler_result.value()[0];
+          } else {
+            break;
+          }
         }
         UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_id);
-        if (data.empty()) {
-          // if the chunk is empty, skip the current chunk data and move on to
-          // the next.
-          batch_buffer_->skip_chunk();
-        }
-        else {
+        if (!data.empty()) { // skip empty chunks.
           batch_buffer_->add_chunk_data(std::move(data));
         }
       } catch (...) {
         batch_buffer_->add_chunk_data(std::current_exception());
       }
     }
+    --running_preloaders_;
+    if (running_preloaders_.load() == 0) {
+      // all preloaders have completed, so we can notify the batch buffer.
+      batch_buffer_->stop();
+    }
   }
 
   /// Block the current thread until the workers finish execution and exit.
   void free_workers() {
     if (!quit_worker_.load()) {
-      quit_worker_ = true;
-      if(batch_buffer_){
-        batch_buffer_->stop();
-      }
+      quit_worker_ = true;
       for (auto& worker_thread : preload_threads_) {
         worker_thread.join();
       }
@@ -416,7 +400,7 @@ class ChunkDataset final
   ChunkReader chunk_reader_;
 
   // chunk sampler to shuffle different chunks
-  samplers::LockedSampler<ChunkSamplerType> chunk_sampler_;
+  ChunkSamplerType chunk_sampler_;
 
   // example sampler to shuffle examples in a specific chunk
   ExampleSamplerType example_sampler_;
@@ -433,6 +417,13 @@ class ChunkDataset final
 
   // indicate whether the worker thread can be torn down
   std::atomic<bool> quit_worker_;
+
+  // keep track of running preloaders to notify the batch buffer. A value of
+  // 0 indicates that chunk loading is complete.
+  std::atomic<size_t> running_preloaders_;
+
+  // mutex to synchronize the chunk sampler's next() call.
+  std::mutex chunk_index_guard_;
 };
 } // namespace datasets
 } // namespace data
index d576208..b9f8ac3 100644
@@ -42,41 +42,6 @@ class Sampler {
   TORCH_API virtual void load(serialize::InputArchive& archive) = 0;
 };
 
-/// Wraps a provided sampler to make it thread safe.
-template <typename OriginalSampler>
-class LockedSampler
-    : public Sampler<typename OriginalSampler::BatchRequestType> {
- public:
-  using BatchRequestType = typename OriginalSampler::BatchRequestType;
-
-  explicit LockedSampler(OriginalSampler sampler) : sampler_(std::move(sampler)) {}
-
-  void reset(optional<size_t> new_size) override {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    sampler_.reset(new_size);
-  }
-
-  optional<BatchRequestType> next(size_t batch_size) override {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    return sampler_.next(batch_size);
-  }
-
-  void save(serialize::OutputArchive& archive) const override {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    sampler_.save(archive);
-  }
-
-  void load(serialize::InputArchive& archive) override {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    sampler_.load(archive);
-  }
-
- private:
-  // member variable for multi-threading lock.
-  // declare it to be mutable for locking in const member function.
-  mutable std::mutex mutex_;
-  OriginalSampler sampler_;
-};
 } // namespace samplers
 } // namespace data
 } // namespace torch
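
With the LockedSampler wrapper deleted, thread safety lives at the single contended call site instead: ChunkDataset holds chunk_index_guard_ only around chunk_sampler_.next(). Below is a minimal sketch of that pattern, using a hypothetical ToySampler as a simplified stand-in for a non-thread-safe sampler (not a torch class).

#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

// Simplified stand-in for a non-thread-safe sequential sampler.
class ToySampler {
 public:
  explicit ToySampler(size_t size) : size_(size) {}
  // Writes the next index into `out`; returns false when exhausted.
  bool next(size_t& out) {
    if (index_ >= size_) return false;
    out = index_++;
    return true;
  }
 private:
  size_t size_;
  size_t index_ = 0;
};

int main() {
  ToySampler sampler(16);
  std::mutex sampler_guard;  // plays the role of chunk_index_guard_
  std::vector<std::thread> workers;
  for (int w = 0; w < 4; ++w) {
    workers.emplace_back([&] {
      size_t chunk_id;
      while (true) {
        {
          // Serialize only the index draw; the expensive chunk read
          // below runs concurrently across workers.
          std::lock_guard<std::mutex> lock(sampler_guard);
          if (!sampler.next(chunk_id)) break;
        }
        // ... chunk_reader_.read_chunk(chunk_id) would run here ...
      }
    });
  }
  for (auto& t : workers) {
    t.join();
  }
}

Serializing only the index draw keeps the expensive part, reading the chunk, concurrent across workers, which is why a wrapper that locked every sampler method was no longer needed.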