}
}
}
+
+TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
+  const size_t prefetch_count = 2;
+  const size_t batch_size = 5;
+
+  DummyChunkDataReader data_reader;
+  samplers::SequentialSampler sampler(0);
+  // A shared dataset is needed so that both the data loader and this test can
+  // hold the same ChunkDataset instance and observe its chunk sampler.
+  datasets::SharedBatchDataset<datasets::ChunkDataset<
+      DummyChunkDataReader,
+      samplers::SequentialSampler,
+      samplers::SequentialSampler>>
+      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
+          DummyChunkDataReader,
+          samplers::SequentialSampler,
+          samplers::SequentialSampler>>(
+          data_reader,
+          sampler,
+          sampler,
+          datasets::ChunkDatasetOptions(prefetch_count, batch_size));
+
+  // The accessor under test: a mutable reference to the sampler that picks
+  // which chunk is loaded next.
+  samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
+
+  auto data_loader = torch::data::make_data_loader(
+      dataset.map(transforms::BatchLambda<std::vector<int>, int>(
+          [](std::vector<int> batch) {
+            return std::accumulate(batch.begin(), batch.end(), 0);
+          })),
+      DataLoaderOptions(batch_size).workers(0));
+
+  // before we start, the index should be 0.
+  ASSERT_EQ(chunk_sampler.index(), 0);
+
+  size_t sum = 0;
+  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
+       ++iterator) {
+    sum += *iterator;
+  }
+  ASSERT_EQ(sum, 595); // sum([0, 35))
+  // 3 chunks, and when exhausted the value is already incremented.
+  ASSERT_EQ(chunk_sampler.index(), 3);
+}
using BatchRequestType = typename ExampleSampler::BatchRequestType;
BatchDataBuffer(
- size_t num_chunks,
size_t batch_size,
ExampleSampler& example_sampler,
size_t queue_capacity)
- : remaining_chunk_count_(num_chunks),
- batch_size_(batch_size),
+ : batch_size_(batch_size),
example_sampler_(example_sampler),
queue_capacity_(queue_capacity),
stop_(false) {}
// loaded (i.e. the dataset is exhausted for this epoch)
return (
this->total_example_count_in_queue_ >= batch_size_ ||
- this->remaining_chunk_count_ == 0);
+ this->stop_.load());
});
if (batch_queue_.empty()) {
- AT_ASSERT(remaining_chunk_count_ == 0);
-
+ AT_ASSERT(this->stop_.load());
// All batches have been retrieved. Return an empty batch.
return nullopt;
}
return batch.batch_data;
}
- // skip one chunk
- void skip_chunk() {
- std::unique_lock<std::mutex> lock(queue_mutex_);
- AT_ASSERT(remaining_chunk_count_ > 0);
- remaining_chunk_count_--;
- lock.unlock();
- cv_read_.notify_all();
- }
-
/// Push preloaded chunks to batch queue. Called from the ChunkDataset worker
/// threads.
void add_chunk_data(UnwrappedBatchType data) {
std::unique_lock<std::mutex> lock(queue_mutex_);
cv_write_.wait(lock, [this] {
// stop loading if we have preloaded enough data.
- return this->total_example_count_in_queue_ < this->queue_capacity_ || stop_.load();
+ return this->total_example_count_in_queue_ < this->queue_capacity_ ||
+ stop_.load();
});
if (stop_.load()){
- // When stop_ is true, it means this current thread needs to be tore down.
+ // When stop_ is true, it means no further chunk loading is necessary.
// Return without any further processing.
return;
}
batch_queue_.emplace(std::move(current_batch));
}
total_example_count_in_queue_ += data_size;
-
- AT_ASSERT(remaining_chunk_count_ > 0);
- remaining_chunk_count_--;
-
lock.unlock();
cv_read_.notify_all();
}
}
batch_queue_.emplace(e_ptr);
-
- AT_ASSERT(remaining_chunk_count_ > 0);
- remaining_chunk_count_--;
lock.unlock();
cv_read_.notify_all();
}
// notify all writers, wake them from wait to exit current method.
cv_write_.notify_all();
+ // notify all readers too.
+ cv_read_.notify_all();
}
-
- /// count of remaining chunk to be loaded. It is initialized with the total
- /// chunk count and it decreases when a chunk data is retrieved. When this reaches
- /// to 0, no more chunk needs to be loaded.
- size_t remaining_chunk_count_ = 0;
-
+
/// The batch size is needed to create batches from the chunk data. Similar to
/// regular dataloader where the batches are created with prefetches,
/// BatchDataBuffer perform the batch creation using the provided batch size.
chunk_sampler_(std::move(chunk_sampler)),
example_sampler_(std::move(example_sampler)),
options_(std::move(options)),
- quit_worker_(false) {
- }
+ quit_worker_(false),
+ running_preloaders_(0) {}
virtual ~ChunkDataset() {
free_workers();
chunk_reader_.reset();
- size_t chunks_to_load = chunk_reader_.chunk_count();
- chunk_sampler_.reset(chunks_to_load);
+ chunk_sampler_.reset(chunk_reader_.chunk_count());
// Throw out any existing cached batch in the buffer and re-creates a new
// chunk buffer.
batch_buffer_ = torch::make_unique<
detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
- chunks_to_load,
options_.batch_size_,
example_sampler_,
options_.cache_size_);
// create new workers for this new epoch.
quit_worker_ = false;
+ AT_ASSERT(running_preloaders_ == 0);
for (size_t i = 0; i < options_.preloader_count_; ++i) {
- preload_threads_.emplace_back(
- [this, i]() { this->preloader(i); });
+ preload_threads_.emplace_back([this, i]() { this->preloader(i); });
+ ++running_preloaders_;
}
}
return torch::nullopt;
}
+  // Provides a reference to the chunk sampler. Used mainly in distributed
+  // data loading to set the epoch number for the sampler.
+  ChunkSamplerType& chunk_sampler() {
+    return chunk_sampler_;
+  }
+
private:
  /// running on worker thread to preload chunk data.
  void preloader(size_t id) {
    while (!quit_worker_.load()) {
      try {
        size_t chunk_id = 0;
-        if (auto chunk_sampler_result = chunk_sampler_.next(1)) {
-          chunk_id = chunk_sampler_result.value()[0];
-        } else {
-          break;
+        {
+          // chunk_sampler_ is shared by all preloader threads; serialize
+          // next() calls so each chunk id is handed out exactly once.
+          std::lock_guard<std::mutex> lock(chunk_index_guard_);
+          if (auto chunk_sampler_result = chunk_sampler_.next(1)) {
+            chunk_id = chunk_sampler_result.value()[0];
+          } else {
+            break;
+          }
        }
        UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_id);
-        if (data.empty()) {
-          // if the chunk is empty, skip the current chunk data and move on to
-          // the next.
-          batch_buffer_->skip_chunk();
-        }
-        else {
+        if (!data.empty()) { // skip empty chunks.
          batch_buffer_->add_chunk_data(std::move(data));
        }
      } catch (...) {
        batch_buffer_->add_chunk_data(std::current_exception());
      }
    }
+    // This preloader is finished (sampler exhausted or shutdown requested).
+    --running_preloaders_;
+    if (running_preloaders_.load() == 0) {
+      // all preloaders are completed, so we can notify the batch_buffer.
+      batch_buffer_->stop();
+    }
  }
/// Block the current thread until the workers finish execution and exit.
void free_workers() {
if (!quit_worker_.load()) {
- quit_worker_ = true;
- if(batch_buffer_){
- batch_buffer_->stop();
- }
+ quit_worker_ = true;
for (auto& worker_thread : preload_threads_) {
worker_thread.join();
}
ChunkReader chunk_reader_;
// chunk sampler to shuffle different chunks
- samplers::LockedSampler<ChunkSamplerType> chunk_sampler_;
+ ChunkSamplerType chunk_sampler_;
// example sampler to shuffle examples in a specific chunk
ExampleSamplerType example_sampler_;
// indicate whether the worker thread can be teared down
std::atomic<bool> quit_worker_;
+
+  // Keeps track of running preloaders to notify the batch buffer. A value of
+  // 0 indicates that chunk loading is completed.
+ std::atomic<size_t> running_preloaders_;
+
+ // mutex to synchronize chunk sampler next() call.
+ std::mutex chunk_index_guard_;
};
} // namespace datasets
} // namespace data
TORCH_API virtual void load(serialize::InputArchive& archive) = 0;
};
-/// Wraps a provided sampler to make it thread safe.
-template <typename OriginalSampler>
-class LockedSampler
- : public Sampler<typename OriginalSampler::BatchRequestType> {
- public:
- using BatchRequestType = typename OriginalSampler::BatchRequestType;
-
- explicit LockedSampler(OriginalSampler sampler) : sampler_(std::move(sampler)) {}
-
- void reset(optional<size_t> new_size) override {
- std::lock_guard<std::mutex> lock(this->mutex_);
- sampler_.reset(new_size);
- }
-
- optional<BatchRequestType> next(size_t batch_size) override {
- std::lock_guard<std::mutex> lock(this->mutex_);
- return sampler_.next(batch_size);
- }
-
- void save(serialize::OutputArchive& archive) const override {
- std::lock_guard<std::mutex> lock(this->mutex_);
- sampler_.save(archive);
- }
-
- void load(serialize::InputArchive& archive) override {
- std::lock_guard<std::mutex> lock(this->mutex_);
- sampler_.load(archive);
- }
-
- private:
- // member variable for multi-threading lock.
- // declare it to be mutable for locking in const member function.
- mutable std::mutex mutex_;
- OriginalSampler sampler_;
-};
} // namespace samplers
} // namespace data
} // namespace torch