// 3 chunks, and when exhausted the value is already incremented.
ASSERT_EQ(chunk_sampler.index(), 3);
}
+
+TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
+ const size_t prefetch_count = 2;
+ const size_t batch_size = 5;
+ // A cache size of 2 * batch_size makes the preloader threads block while
+ // trying to fill the batch buffer, waiting for `get_batch()` calls that
+ // this test deliberately never makes.
+ const size_t cache_size = 10;
+
+ DummyChunkDataReader data_reader;
+ samplers::SequentialSampler sampler(0);
+ datasets::SharedBatchDataset<datasets::ChunkDataset<
+ DummyChunkDataReader,
+ samplers::SequentialSampler,
+ samplers::SequentialSampler>>
+ dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
+ DummyChunkDataReader,
+ samplers::SequentialSampler,
+ samplers::SequentialSampler>>(
+ data_reader,
+ sampler,
+ sampler,
+ datasets::ChunkDatasetOptions(
+ prefetch_count, batch_size, cache_size));
+
+ // Not used further; obtained only to mirror the setup of the sibling tests.
+ samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
+
+ auto data_loader = torch::data::make_data_loader(
+ dataset.map(transforms::BatchLambda<std::vector<int>, int>(
+ [](std::vector<int> batch) {
+ return std::accumulate(batch.begin(), batch.end(), 0);
+ })),
+ DataLoaderOptions(batch_size).workers(0));
+ // Create the iterator but perform no iteration. The chunk preloaders are
+ // blocked trying to fill the batch buffer, and nothing drains it. The test
+ // passes if destruction still exits cleanly instead of hanging on join.
+ auto iterator = data_loader->begin();
+}
running_preloaders_(0) {}
virtual ~ChunkDataset() {
+ // Stop the batch buffer before joining the workers: preloader threads may
+ // be blocked pushing into a full buffer that nothing is draining, and
+ // stop() presumably wakes them so free_workers() can join without
+ // deadlocking (exercised by the ChunkDatasetDoesNotHang test).
+ if (batch_buffer_) {
+ batch_buffer_->stop();
+ }
free_workers();
}
/// This will clear any internal state and starts the internal prefetching
/// mechanism for the chunk dataset.
void reset() override {
+ // We need this to support partial data reads via dataloader iterator.
+ if (batch_buffer_) {
+ batch_buffer_->stop();
+ }
// free workers from previous reset if there is any.
free_workers();
preload_threads_.clear();