Support stateful dataset (#15096)
authorPeter Goldsborough <psag@fb.com>
Mon, 24 Dec 2018 14:23:32 +0000 (06:23 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 24 Dec 2018 14:26:40 +0000 (06:26 -0800)
Summary:
Currently re-implements the dataloader for stateful datasets. Outstanding work:
- Refactor DataLoader and DataLoader2 to have common base classes and only differ in specifi pieces of logic,
- Figure out how to not duplicate the `MapDataset` logic for stateful vs. non-stateful
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15096

Differential Revision: D13522043

Pulled By: goldsborough

fbshipit-source-id: 08e461ca51783047f11facc4d27dfa2e4f1e4c2a

12 files changed:
setup.py
test/cpp/api/dataloader.cpp
torch/csrc/api/include/torch/data/dataloader.h
torch/csrc/api/include/torch/data/dataloader/base.h [new file with mode: 0644]
torch/csrc/api/include/torch/data/dataloader/stateful.h [new file with mode: 0644]
torch/csrc/api/include/torch/data/dataloader/stateless.h [new file with mode: 0644]
torch/csrc/api/include/torch/data/dataloader_options.h
torch/csrc/api/include/torch/data/datasets.h
torch/csrc/api/include/torch/data/datasets/base.h
torch/csrc/api/include/torch/data/datasets/map.h
torch/csrc/api/include/torch/data/datasets/shared.h
torch/csrc/api/include/torch/data/datasets/stateful.h [new file with mode: 0644]

index 4d327ba..5759ab2 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -962,6 +962,7 @@ if __name__ == '__main__':
                 'lib/include/torch/csrc/*.h',
                 'lib/include/torch/csrc/api/include/torch/*.h',
                 'lib/include/torch/csrc/api/include/torch/data/*.h',
+                'lib/include/torch/csrc/api/include/torch/data/dataloader/*.h',
                 'lib/include/torch/csrc/api/include/torch/data/datasets/*.h',
                 'lib/include/torch/csrc/api/include/torch/data/detail/*.h',
                 'lib/include/torch/csrc/api/include/torch/data/samplers/*.h',
index 0d8d546..461dfe5 100644 (file)
@@ -84,8 +84,8 @@ TEST(DataTest, InfiniteStreamDataset) {
 
   auto data_loader = torch::data::make_data_loader(
       std::move(dataset),
-      kBatchSize,
-      samplers::StreamSampler(/*epoch_size=*/39));
+      samplers::StreamSampler(/*epoch_size=*/39),
+      kBatchSize);
 
   size_t batch_index = 0;
   for (auto& batch : *data_loader) {
@@ -128,10 +128,10 @@ TEST(DataTest, OrderedSequencerReOrdersValues) {
   size_t index = 0;
   auto getter = [&v, &index]() { return S{v.at(index++)}; };
 
-  // Let's say the sequence number matches for the first one, then it should
+  // Let's say the sequence number matches for the batch one, then it should
   // return immediately.
-  const auto first = sequencer.next(getter);
-  ASSERT_EQ(first.value().sequence_number, 0);
+  const auto batch = sequencer.next(getter);
+  ASSERT_EQ(batch.value().sequence_number, 0);
   ASSERT_EQ(index, 1);
 
   // Now it should call the getter until it gets the next value.
@@ -385,9 +385,9 @@ TEST(DataTest, StackTransformWorksForExample) {
 
   auto d = D().map(transforms::Stack<Example<>>());
 
-  Example<> first = d.get_batch({0, 1});
-  ASSERT_TRUE(first.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
-  ASSERT_TRUE(first.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));
+  Example<> batch = d.get_batch({0, 1});
+  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
+  ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));
 
   Example<> second = d.get_batch({2, 3});
   ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
@@ -398,8 +398,8 @@ TEST(DataTest, StackTransformWorksForTensorExample) {
   auto d = datasets::TensorDataset(torch::eye(4))
                .map(transforms::Stack<TensorExample>());
 
-  TensorExample first = d.get_batch({0, 1});
-  ASSERT_TRUE(first.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
+  TensorExample batch = d.get_batch({0, 1});
+  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
 
   TensorExample second = d.get_batch({2, 3});
   ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
@@ -504,7 +504,7 @@ TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
 TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
   using torch::data::detail::Queue;
 
-  // First test: push first and the pop in thread.
+  // First test: push batch and the pop in thread.
   {
     Queue<int> queue;
     queue.push(1);
@@ -513,7 +513,7 @@ TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
     ASSERT_EQ(future.get(), 1);
   }
 
-  // Second test: attempt to pop first (and block), then push.
+  // Second test: attempt to pop batch (and block), then push.
   {
     Queue<int> queue;
     std::thread thread([&queue] {
@@ -544,7 +544,7 @@ TEST(DataTest, DataShuttleCanPushAndPopJob) {
 
 TEST(DataTest, DataShuttleCanPushAndPopResult) {
   torch::data::detail::DataShuttle<int, int> shuttle;
-  // pop_result() will only attempt to pop if there was a push_job() first.
+  // pop_result() will only attempt to pop if there was a push_job() batch.
   shuttle.push_job(1);
   shuttle.push_job(2);
 
@@ -672,9 +672,9 @@ struct TestIndexSampler : public samplers::Sampler<TestIndex> {
 };
 
 TEST(DataTest, CanUseCustomTypeAsIndexType) {
-  const size_t kBatchSize = 10;
+  const int kBatchSize = 10;
   auto data_loader = torch::data::make_data_loader(
-      TestIndexDataset(23), kBatchSize, TestIndexSampler(23));
+      TestIndexDataset(23), TestIndexSampler(23), kBatchSize);
 
   size_t i = 0;
   for (auto batch : *data_loader) {
@@ -948,7 +948,7 @@ TEST(DataLoaderTest, RespectsTimeout) {
   ASSERT_LT(duration.count(), 1);
 }
 
-// https://stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
+// stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
 struct Barrier {
   explicit Barrier(size_t target) : counter_(target) {}
   void wait() {
@@ -973,12 +973,12 @@ struct Barrier {
 // thread (for outside consumption) is not deterministic. Imagine the sampler is
 // a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each index
 // will be a single "job". Inside the dataloader, worker threads block until a
-// job is available. It is not deterministic which worker thread wakes up first
+// job is available. It is not deterministic which worker thread wakes up batch
 // to dequeue a particular batch. Further, some worker threads may take longer
 // than others to read the data for their index. As such, it could be that
 // worker thread 2 finishes before all other threads and returns its batch to
 // the main thread. In that case, the dataloader iterator would return the datum
-// at index 2 first, and afterwards the datum from whatever thread finishes
+// at index 2 batch, and afterwards the datum from whatever thread finishes
 // next. As such, the user may see data from indices 2, 0, 3, 1. On another run
 // of the same dataloader on the same data, threads may be scheduled differently
 // and return in order 0, 2, 3, 1. To force this ordering to deterministically
@@ -996,7 +996,7 @@ struct Barrier {
 // `SequentialSampler` in the range `0...kNumberOfWorkers-1`. Each worker thread
 // has a copy of the dataset, and thus `get_batch()` is called on the
 // thread-local copy in each worker. We want to simulate out-of-order completion
-// of these threads. For this, we first set a barrier in the `get_batch()`
+// of these threads. For this, we batch set a barrier in the `get_batch()`
 // method to make sure every worker has some index to fetch assigned. Further,
 // each worker thread has a unique ID in `0...kNumberOfWorkers-1`.
 // There is a hard-coded ordering, `kOrderInWhichWorkersReturnTheirBatch`, in
@@ -1057,12 +1057,11 @@ struct Dataset : datasets::BatchDataset<Dataset, size_t> {
 TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
   auto data_loader = torch::data::make_data_loader(
       ordering_test::Dataset{},
+      torch::data::samplers::SequentialSampler(ordering_test::kNumberOfWorkers),
       DataLoaderOptions()
           .batch_size(1)
           .workers(ordering_test::kNumberOfWorkers)
-          .enforce_ordering(true),
-      torch::data::samplers::SequentialSampler(
-          ordering_test::kNumberOfWorkers));
+          .enforce_ordering(true));
   std::vector<size_t> output;
   for (size_t value : *data_loader) {
     output.push_back(value);
@@ -1104,8 +1103,8 @@ TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
     }
   };
 
-  auto data_loader =
-      torch::data::make_data_loader(D{}, DataLoaderOptions().workers(2));
+  auto data_loader = torch::data::make_data_loader(
+      D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2));
   auto iterator = data_loader->begin();
 
   try {
@@ -1119,3 +1118,159 @@ TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
         std::rethrow_exception(e.original_exception), std::invalid_argument);
   }
 }
+
+TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
+  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
+
+  struct D : datasets::StatefulDataset<D, int, size_t> {
+    torch::optional<int> get_batch(size_t) override {
+      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
+        return counter++;
+      }
+      return torch::nullopt;
+    }
+    torch::optional<size_t> size() const override {
+      return 100;
+    }
+    void reset() override {
+      counter = 0;
+    }
+    int counter = 0;
+  };
+
+  auto data_loader = torch::data::make_data_loader(D{});
+
+  for (size_t i = 0; i < 10; ++i) {
+    const auto number_of_iterations =
+        std::distance(data_loader->begin(), data_loader->end());
+    ASSERT_EQ(
+        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
+        << "epoch " << i;
+  }
+
+  for (const int i : *data_loader) {
+    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
+  }
+}
+
+TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
+  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
+  const int kNumberOfWorkers = 4;
+
+  struct D : datasets::StatefulDataset<D, int, size_t> {
+    torch::optional<int> get_batch(size_t) override {
+      std::lock_guard<std::mutex> lock(mutex);
+      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
+        return counter++;
+      }
+      return torch::nullopt;
+    }
+    torch::optional<size_t> size() const override {
+      return 100;
+    }
+    void reset() override {
+      counter = 0;
+    }
+    int counter = 0;
+    std::mutex mutex;
+  };
+
+  auto data_loader = torch::data::make_data_loader(
+      torch::data::datasets::make_shared_dataset<D>(),
+      DataLoaderOptions().workers(kNumberOfWorkers));
+
+  for (size_t i = 0; i < 10; ++i) {
+    const auto number_of_iterations =
+        std::distance(data_loader->begin(), data_loader->end());
+    ASSERT_EQ(
+        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
+        << "epoch " << i;
+  }
+
+  for (const int i : *data_loader) {
+    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
+  }
+}
+
+TEST(DataLoaderTest, StatefulDatasetWithMap) {
+  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
+
+  struct D : datasets::StatefulDataset<D, int, size_t> {
+    torch::optional<int> get_batch(size_t) override {
+      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
+        return counter++;
+      }
+      return torch::nullopt;
+    }
+    torch::optional<size_t> size() const override {
+      return 100;
+    }
+    void reset() override {
+      counter = 0;
+    }
+    int counter = 0;
+  };
+
+  auto data_loader = torch::data::make_data_loader(
+      D().map(transforms::BatchLambda<int, std::string>(
+                  [](int x) { return std::to_string(x); }))
+          .map(transforms::BatchLambda<std::string, torch::Tensor>(
+              [](const std::string& x) {
+                return torch::tensor(static_cast<int64_t>(std::stoi(x)));
+              })),
+      DataLoaderOptions{});
+
+  for (size_t i = 0; i < 10; ++i) {
+    const auto number_of_iterations =
+        std::distance(data_loader->begin(), data_loader->end());
+    ASSERT_EQ(
+        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
+        << "epoch " << i;
+  }
+
+  for (const torch::Tensor& t : *data_loader) {
+    ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
+  }
+}
+
+TEST(DataLoaderTest, StatefulDatasetWithCollate) {
+  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
+
+  struct D : datasets::StatefulDataset<D> {
+    torch::optional<std::vector<Example<>>> get_batch(
+        size_t batch_size) override {
+      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
+        counter += batch_size;
+        std::vector<Example<>> batch(
+            /*count=*/batch_size,
+            Example<>{torch::ones(batch_size + 1),
+                      torch::zeros(batch_size - 1)});
+        return batch;
+      }
+      return torch::nullopt;
+    }
+    torch::optional<size_t> size() const override {
+      return 100;
+    }
+    void reset() override {
+      counter = 0;
+    }
+    int counter = 0;
+  };
+
+  auto d = D().map(transforms::Stack<Example<>>());
+
+  const size_t kBatchSize = 5;
+
+  // Notice that the `get_batch()` of the dataset returns a vector<Example>, but
+  // the `Stack` collation stacks the tensors into one.
+  torch::optional<Example<>> batch = d.get_batch(kBatchSize);
+  ASSERT_TRUE(batch.has_value());
+  ASSERT_EQ(batch->data.size(0), kBatchSize);
+  ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
+  ASSERT_EQ(batch->target.size(0), kBatchSize);
+  ASSERT_EQ(batch->target.size(1), kBatchSize - 1);
+
+  ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
+  ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
+}
index 61f0c2c..14da9ed 100644 (file)
@@ -1,12 +1,7 @@
 #pragma once
 
-#include <torch/data/dataloader_options.h>
-#include <torch/data/detail/data_shuttle.h>
-#include <torch/data/detail/sequencers.h>
-#include <torch/data/iterator.h>
-#include <torch/data/samplers/random.h>
-#include <torch/data/worker_exception.h>
-#include <torch/types.h>
+#include <torch/data/dataloader/stateful.h>
+#include <torch/data/dataloader/stateless.h>
 
 #include <torch/csrc/utils/memory.h>
 #include <torch/csrc/utils/variadic.h>
 #include <c10/util/Exception.h>
 
 #include <cstddef>
-#include <exception>
 #include <memory>
-#include <thread>
 #include <type_traits>
 #include <utility>
-#include <vector>
 
 namespace torch {
 namespace data {
-template <typename Dataset, typename Sampler>
-class DataLoader {
- public:
-  using Batch = typename Dataset::BatchType;
-  using BatchRequest = typename Sampler::BatchRequestType;
-
-  /// Constructs a new `DataLoader` from a `dataset` to sample from, `options`
-  /// to configure the `DataLoader` with, and a `sampler` that specifies the
-  /// sampling strategy.
-  DataLoader(Dataset dataset, DataLoaderOptions options, Sampler sampler)
-      : options_(std::move(options)),
-        sampler_(std::move(sampler)),
-        sequencer_(new_sequencer()) {
-    for (size_t w = 0; w < options_.workers; ++w) {
-      // Here we copy the dataset into the worker thread closure. Each worker
-      // has its own copy of the dataset. This means the dataset must be
-      // trivially copiable, or else we don't expect more than one worker to
-      // be in use.
-      workers_.emplace_back([this, dataset]() mutable {
-        this->worker_thread(std::move(dataset));
-      });
-    }
-    if (options_.workers == 0) {
-      main_thread_dataset_ = torch::make_unique<Dataset>(std::move(dataset));
-    }
-  }
-
-  virtual ~DataLoader() {
-    join();
-  }
-
-  /// Returns an iterator into the `DataLoader`. The lifetime of the iterator is
-  /// bound to the `DataLoader`. In C++ standards language, the category of the
-  /// iterator is `OutputIterator`. See
-  /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this
-  /// means. In short: you may increment the iterator and dereference it, but
-  /// cannot go back, or step forward more than one position at a time. When the
-  /// `DataLoader` is exhausted, it will compare equal with the special
-  /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you
-  /// should only use range-for loops to loop over the `DataLoader`, but
-  /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(),
-  /// output_iterator)`  are supported too.
-  Iterator<Batch> begin() {
-    AT_CHECK(
-        shuttle_.in_flight_jobs() == 0,
-        "Attempted to get a new DataLoader iterator "
-        "while another iterator is not yet exhausted");
-    reset();
-    return Iterator<Batch>(torch::make_unique<detail::ValidIterator<Batch>>(
-        [this] { return this->next(); }));
-  }
-
-  /// Returns a special "sentinel" iterator that compares equal with a
-  /// non-sentinel iterator once the `DataLoader` is exhausted.
-  Iterator<Batch> end() {
-    return Iterator<Batch>(
-        torch::make_unique<detail::SentinelIterator<Batch>>());
-  }
-
-  /// Joins the `DataLoader`'s worker threads and drains internal queues.
-  /// This function may only be invoked from the main thread (in which the
-  /// `DataLoader` lives).
-  void join() {
-    if (joined_) {
-      return;
-    }
-    shuttle_.drain();
-    // Send one 'quit' message per worker. Since a worker dies (exits its
-    // thread) after receiving this message, each `QuitWorker()` message will be
-    // read by exactly one worker.
-    for (size_t w = 0; w < options_.workers; ++w) {
-      push_job(QuitWorker());
-    }
-    for (auto& worker : workers_) {
-      worker.join();
-    }
-    joined_ = true;
-  }
-
-  /// Returns the options with which the `DataLoader` was configured.
-  const FullDataLoaderOptions& options() const noexcept {
-    return options_;
-  }
-
-  /// Returns the sampler currently used by the `DataLoader`.
-  const Sampler& sampler() const noexcept {
-    return sampler_;
-  }
-
-  /// Returns the sampler currently used by the `DataLoader`.
-  Sampler& sampler() noexcept {
-    return sampler_;
-  }
-
- private:
-  /// Simple mix-in to give something a sequence number.
-  struct Sequenced {
-    Sequenced() = default;
-    Sequenced(size_t sqn) : sequence_number(sqn) {}
-    size_t sequence_number;
-  };
-
-  struct QuitWorker {};
-
-  /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a
-  /// `QuitWorker` object, to indicate the worker should shut down.
-  struct Job : Sequenced {
-    Job() = default;
-    Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
-    Job(BatchRequest&& i, size_t sqn)
-        : Sequenced(sqn), batch_request(std::move(i)) {}
-    optional<QuitWorker> quit;
-    optional<BatchRequest> batch_request;
-  };
-
-  /// The finished result of a job.
-  struct Result : Sequenced {
-    Result() = default;
-    Result(Batch&& b, size_t sqn) : Sequenced(sqn), batch(std::move(b)) {}
-    Result(std::exception_ptr exception, size_t sqn)
-        : Sequenced(sqn), exception(std::move(exception)) {}
-    optional<Batch> batch;
-    std::exception_ptr exception;
-  };
-
-  /// Resets the internal state of the `DataLoader`, optionally pre-fetching
-  /// new jobs.
-  void reset(bool prefetch = true) {
-    shuttle_.drain();
-    sampler_.reset();
-    sequence_number_ = 0;
-    sequencer_ = new_sequencer();
-    if (prefetch) {
-      this->prefetch();
-    }
-  }
-
-  /// Schedules `requested_jobs` many new batches to be fetched. The actual
-  /// number of jobs scheduled may be less if the `DataLoader` exhausts.
-  void prefetch(size_t requested_jobs) {
-    while (requested_jobs-- > 0) {
-      if (auto batch_request = get_batch_request()) {
-        push_job(std::move(*batch_request));
-      } else {
-        break;
-      }
-    }
-  }
-
-  /// Schedules the maximum number of jobs (based on the `max_jobs` option).
-  void prefetch() {
-    prefetch(options_.max_jobs);
-  }
-
-  /// Returns the next batch of data, or an empty `optional` if the `DataLoader`
-  /// is exhausted. This operation will block until a batch is available.
-  optional<Batch> next() {
-    optional<Batch> batch;
-    if (options_.workers > 0) {
-      optional<Result> result = sequencer_->next(
-          [this] { return this->shuttle_.pop_result(this->options_.timeout); });
-      if (result) {
-        if (result->exception) {
-          throw WorkerException(result->exception);
-        } else {
-          AT_ASSERT(result->batch.has_value());
-          batch = std::move(result->batch);
-          prefetch(1);
-        }
-      }
-    } else if (auto batch_request = get_batch_request()) {
-      AT_ASSERT(main_thread_dataset_ != nullptr);
-      batch = main_thread_dataset_->get_batch(std::move(*batch_request));
-    }
-    return batch;
-  }
 
-  /// The function that worker threads run.
-  void worker_thread(Dataset dataset) {
-    while (true) {
-      auto job = shuttle_.pop_job();
-      if (job.quit) {
-        break;
-      }
-      try {
-        auto batch = dataset.get_batch(std::move(*job.batch_request));
-        shuttle_.push_result({std::move(batch), job.sequence_number});
-      } catch (...) {
-        shuttle_.push_result({std::current_exception(), job.sequence_number});
-      }
-    }
-  }
-
-  optional<BatchRequest> get_batch_request() {
-    auto indices = sampler_.next(options_.batch_size);
-    if (!indices ||
-        (indices->size() < options_.batch_size && options_.drop_last)) {
-      return nullopt;
-    }
-    AT_ASSERT(indices->size() > 0);
-    return indices;
-  }
-
-  template <typename T>
-  void push_job(T value) {
-    shuttle_.push_job({std::move(value), sequence_number_++});
-  }
-
-  std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
-    if (options_.enforce_ordering) {
-      return torch::make_unique<detail::sequencers::OrderedSequencer<Result>>(
-          options_.max_jobs);
-    }
-    return torch::make_unique<detail::sequencers::NoSequencer<Result>>();
-  }
-
-  /// The options the `DataLoader` was configured with.
-  const FullDataLoaderOptions options_;
-
-  /// The dataset for the main thread, only has a value if the number of
-  /// worker threads was configured as zero, meaning the main thread has to do
-  /// all the work (synchronously). NOTE: Really want this to be on the heap
-  /// when empty, therefore `unique_ptr` and not `optional`.
-  std::unique_ptr<Dataset> main_thread_dataset_;
-
-  /// The sampler with which new batch requests are created.
-  Sampler sampler_;
-
-  /// The sequence number for the *next* batch to be retrieved from the
-  /// dataset.
-  size_t sequence_number_ = 0;
-
-  /// The worker threads, running the `worker_thread()` method.
-  std::vector<std::thread> workers_;
-
-  /// The `DataShuttle` which takes care of the life cycle of a job.
-  detail::DataShuttle<Job, Result> shuttle_;
-
-  /// The `Sequencer`, which handles optional ordering of batches.
-  std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;
-
-  /// True if the `DataLoader` has joined its worker threads.
-  bool joined_ = false;
-}; // namespace data
-
-/// Creates a new `DataLoader`, inferring the necessary template types from
-/// the given arguments.
+/// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and
+/// some `options`.
 template <typename Dataset, typename Sampler>
-std::unique_ptr<DataLoader<Dataset, Sampler>> make_data_loader(
-    Dataset dataset,
-    DataLoaderOptions options,
-    Sampler sampler) {
-  return torch::make_unique<DataLoader<Dataset, Sampler>>(
-      std::move(dataset), std::move(options), std::move(sampler));
+torch::disable_if_t<
+    Dataset::is_stateful,
+    std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
+make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) {
+  return torch::make_unique<StatelessDataLoader<Dataset, Sampler>>(
+      std::move(dataset), std::move(sampler), std::move(options));
 }
 
-/// Creates a new `DataLoader`, inferring the necessary template types from
-/// the given arguments.
-template <
-    typename Sampler = samplers::RandomSampler,
-    typename Dataset,
-    typename =
-        torch::enable_if_t<std::is_constructible<Sampler, size_t>::value>>
-std::unique_ptr<DataLoader<Dataset, Sampler>> make_data_loader(
+/// Creates a `DataLoader` instance for a stateless `dataset` and some
+/// `options`. A sampler (by default a `RandomSampler`) will be constructed from
+/// the size of the dataset.
+template <typename Sampler = samplers::RandomSampler, typename Dataset>
+torch::disable_if_t<
+    Dataset::is_stateful || !std::is_constructible<Sampler, size_t>::value,
+    std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
+make_data_loader(
     Dataset dataset,
     DataLoaderOptions options = DataLoaderOptions()) {
   const optional<size_t> size = dataset.size();
@@ -295,8 +43,16 @@ std::unique_ptr<DataLoader<Dataset, Sampler>> make_data_loader(
       "Expected the dataset to be sized in "
       "order to construct the Sampler");
   return make_data_loader(
-      std::move(dataset), std::move(options), Sampler(*size));
+      std::move(dataset), Sampler(*size), std::move(options));
 }
 
+/// Creates a `DataLoader` for a stateful `dataset` and some `options`.
+template <typename Dataset, typename = torch::enable_if_t<Dataset::is_stateful>>
+std::unique_ptr<StatefulDataLoader<Dataset>> make_data_loader(
+    Dataset dataset,
+    DataLoaderOptions options = DataLoaderOptions()) {
+  return torch::make_unique<StatefulDataLoader<Dataset>>(
+      std::move(dataset), std::move(options));
+}
 } // namespace data
 } // namespace torch
diff --git a/torch/csrc/api/include/torch/data/dataloader/base.h b/torch/csrc/api/include/torch/data/dataloader/base.h
new file mode 100644 (file)
index 0000000..1ec478b
--- /dev/null
@@ -0,0 +1,246 @@
+#pragma once
+
+#include <torch/data/dataloader_options.h>
+#include <torch/data/detail/data_shuttle.h>
+#include <torch/data/detail/sequencers.h>
+#include <torch/data/iterator.h>
+#include <torch/data/samplers/random.h>
+#include <torch/data/worker_exception.h>
+#include <torch/types.h>
+
+#include <torch/csrc/utils/memory.h>
+#include <torch/csrc/utils/variadic.h>
+
+#include <c10/util/Exception.h>
+
+#include <cstddef>
+#include <exception>
+#include <memory>
+#include <thread>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace torch {
+namespace data {
+template <typename Dataset, typename Batch, typename BatchRequest>
+class DataLoaderBase {
+ public:
+  using BatchType = Batch;
+  using BatchRequestType = BatchRequest;
+
+  /// Constructs a new DataLoader from a `dataset` to sample from, `options`
+  /// to configure the DataLoader with, and a `sampler` that specifies the
+  /// sampling strategy.
+  DataLoaderBase(
+      DataLoaderOptions options,
+      std::unique_ptr<Dataset> main_thread_dataset = nullptr)
+      : options_(std::move(options)),
+        main_thread_dataset_(std::move(main_thread_dataset)),
+        sequencer_(new_sequencer()) {}
+
+  virtual ~DataLoaderBase() {
+    join();
+  }
+
+  /// Returns an iterator into the DataLoader. The lifetime of the iterator is
+  /// bound to the DataLoader. In C++ standards language, the category of the
+  /// iterator is `OutputIterator`. See
+  /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this
+  /// means. In short: you may increment the iterator and dereference it, but
+  /// cannot go back, or step forward more than one position at a time. When the
+  /// DataLoader is exhausted, it will compare equal with the special
+  /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you
+  /// should only use range-for loops to loop over the DataLoader, but
+  /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(),
+  /// output_iterator)`  are supported too.
+  Iterator<Batch> begin() {
+    AT_CHECK(
+        shuttle_.in_flight_jobs() == 0,
+        "Attempted to get a new DataLoader iterator "
+        "while another iterator is not yet exhausted");
+    reset();
+    return Iterator<Batch>(torch::make_unique<detail::ValidIterator<Batch>>(
+        [this] { return this->next(); }));
+  }
+
+  /// Returns a special "sentinel" iterator that compares equal with a
+  /// non-sentinel iterator once the DataLoader is exhausted.
+  Iterator<Batch> end() {
+    return Iterator<Batch>(
+        torch::make_unique<detail::SentinelIterator<Batch>>());
+  }
+
+  /// Joins the DataLoader's worker threads and drains internal queues.
+  /// This function may only be invoked from the main thread (in which the
+  /// DataLoader lives).
+  void join() {
+    if (joined_) {
+      return;
+    }
+    shuttle_.drain();
+    // Send one 'quit' message per worker. Since a worker dies (exits its
+    // thread) after receiving this message, each `QuitWorker()` message will be
+    // read by exactly one worker.
+    for (size_t w = 0; w < options_.workers; ++w) {
+      push_job(QuitWorker());
+    }
+    for (auto& worker : workers_) {
+      worker.join();
+    }
+    joined_ = true;
+  }
+
+  /// Returns the options with which the DataLoader was configured.
+  const FullDataLoaderOptions& options() const noexcept {
+    return options_;
+  }
+
+ protected:
+  /// Simple mix-in to give something a sequence number.
+  struct Sequenced {
+    Sequenced() = default;
+    Sequenced(size_t sqn) : sequence_number(sqn) {}
+    size_t sequence_number;
+  };
+
+  struct QuitWorker {};
+
+  /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a
+  /// `QuitWorker` object, to indicate the worker should shut down.
+  struct Job : Sequenced {
+    Job() = default;
+    Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
+    Job(BatchRequest&& i, size_t sqn)
+        : Sequenced(sqn), batch_request(std::move(i)) {}
+    optional<QuitWorker> quit;
+    optional<BatchRequest> batch_request;
+  };
+
+  /// The finished result of a job.
+  struct Result : Sequenced {
+    Result() = default;
+    Result(optional<Batch>&& b, size_t sqn)
+        : Sequenced(sqn), batch(std::move(b)) {}
+    Result(std::exception_ptr exception, size_t sqn)
+        : Sequenced(sqn), exception(std::move(exception)) {}
+    optional<Batch> batch;
+    std::exception_ptr exception;
+  };
+
+  /// Subclass hook for getting the next batch request. The stateless case will
+  /// ask the sampler for a new batch request (e.g. a vector of indices), while
+  /// the stateful one will simply return the batch size.
+  virtual optional<BatchRequestType> get_batch_request() = 0;
+
+  /// Resets the internal state of the DataLoader, optionally pre-fetching
+  /// new jobs.
+  virtual void reset() {
+    shuttle_.drain();
+    sequence_number_ = 0;
+    sequencer_ = new_sequencer();
+    prefetch();
+  }
+
+  /// Schedules `requested_jobs` many new batches to be fetched. The actual
+  /// number of jobs scheduled may be less if the DataLoader exhausts.
+  void prefetch(size_t requested_jobs) {
+    for (size_t r = 0; r < requested_jobs; ++r) {
+      if (auto batch_request = get_batch_request()) {
+        this->push_job(std::move(*batch_request));
+      } else {
+        break;
+      }
+    }
+  }
+
+  /// Schedules the maximum number of jobs (based on the `max_jobs` option).
+  void prefetch() {
+    prefetch(options_.max_jobs);
+  }
+
+  /// Returns the next batch of data, or an empty `optional` if the DataLoader
+  /// is exhausted. This operation will block until a batch is available if one
+  /// is still expected.
+  optional<BatchType> next() {
+    if (options_.workers > 0) {
+      while (optional<Result> result = this->pop_result()) {
+        if (result->exception) {
+          throw WorkerException(result->exception);
+        } else if (result->batch) {
+          prefetch(1);
+          return std::move(result->batch);
+        }
+      }
+    } else if (auto batch_request = get_batch_request()) {
+      return this->main_thread_dataset_->get_batch(std::move(*batch_request));
+    }
+    return nullopt;
+  }
+
+  /// The function that worker threads run.
+  void worker_thread(Dataset& dataset) {
+    while (true) {
+      auto job = shuttle_.pop_job();
+      if (job.quit) {
+        break;
+      }
+      try {
+        auto batch = dataset.get_batch(std::move(*job.batch_request));
+        shuttle_.push_result({std::move(batch), job.sequence_number});
+      } catch (...) {
+        shuttle_.push_result({std::current_exception(), job.sequence_number});
+      }
+    }
+  }
+
+  /// Convenience method that calls `shuttle_.push_job()` with the next sequence
+  /// number.
+  template <typename T>
+  void push_job(T value) {
+    shuttle_.push_job({std::move(value), sequence_number_++});
+  }
+
+  /// Convenience method that gets the next result from the sequencer.
+  optional<Result> pop_result() {
+    return sequencer_->next(
+        [this] { return this->shuttle_.pop_result(this->options_.timeout); });
+  }
+
+  /// Convenience method that creates a new sequencer based on the
+  /// `enforce_ordering` option.
+  std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
+    if (options_.enforce_ordering) {
+      return torch::make_unique<detail::sequencers::OrderedSequencer<Result>>(
+          options_.max_jobs);
+    }
+    return torch::make_unique<detail::sequencers::NoSequencer<Result>>();
+  }
+
+  /// The options the DataLoader was configured with.
+  const FullDataLoaderOptions options_;
+
+  /// The dataset for the main thread, only has a value if the number of
+  /// worker threads was configured as zero, meaning the main thread has to do
+  /// all the work (synchronously). NOTE: Really want this to be on the heap
+  /// when empty, therefore `unique_ptr` and not `optional`.
+  std::unique_ptr<Dataset> main_thread_dataset_;
+
+  /// The sequence number for the *next* batch to be retrieved from the
+  /// dataset.
+  size_t sequence_number_ = 0;
+
+  /// The worker threads, running the `worker_thread()` method.
+  std::vector<std::thread> workers_;
+
+  /// The `DataShuttle` which takes care of the life cycle of a job.
+  detail::DataShuttle<Job, Result> shuttle_;
+
+  /// The `Sequencer`, which handles optional ordering of batches.
+  std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;
+
+  /// True if the DataLoader has joined its worker threads.
+  bool joined_ = false;
+};
+} // namespace data
+} // namespace torch
diff --git a/torch/csrc/api/include/torch/data/dataloader/stateful.h b/torch/csrc/api/include/torch/data/dataloader/stateful.h
new file mode 100644 (file)
index 0000000..1c35c0b
--- /dev/null
@@ -0,0 +1,64 @@
+#pragma once
+
+#include <torch/data/dataloader/base.h>
+
+#include <cstddef>
+#include <thread>
+#include <utility>
+
+namespace torch {
+namespace data {
+
+/// A dataloader for stateful datasets.
+///
+/// A dataloader for stateful datatasets differs from one for stateless
+/// datasets one in that the dataset is shared among worker threads, and that
+/// this dataset is itself responsible for producing batches rather than
+/// depending on a sampler. The statefulness here actually refers to the
+/// dataset. The StatefulDataLoader simply alters the data loading algorithm to
+/// accomodate the stateful, shared nature of the dataset. Note that the dataset
+/// must be thread safe if more than one worker thread is used.
+///
+/// A stateful dataloader is created by calling `make_data_loader` with a
+/// stateful dataset.
+template <typename Dataset>
+class StatefulDataLoader : public DataLoaderBase<
+                               Dataset,
+                               typename Dataset::BatchType::value_type,
+                               typename Dataset::BatchRequestType> {
+ public:
+  using super = DataLoaderBase<
+      Dataset,
+      typename Dataset::BatchType::value_type,
+      typename Dataset::BatchRequestType>;
+  using typename super::BatchRequestType;
+
+  /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`.
+  StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
+      : super(
+            std::move(options),
+            torch::make_unique<Dataset>(std::move(dataset))) {
+    for (size_t w = 0; w < this->options_.workers; ++w) {
+      // As opposed to the stateless case, here all worker threads access the
+      // same underlying dataset.
+      this->workers_.emplace_back(
+          [this] { this->worker_thread(*this->main_thread_dataset_); });
+    }
+  }
+
+ private:
+  /// Resets the internal state of the dataloader and the dataset.
+  void reset() override {
+    this->main_thread_dataset_->reset();
+    // Call the base class method last because it calls `prefetch()`
+    super::reset();
+  }
+
+  /// For stateful datasets, the batch request is always the batch size. The
+  /// dataset is responsible for determining what goes into the batch next.
+  optional<BatchRequestType> get_batch_request() override {
+    return this->options_.batch_size;
+  }
+};
+} // namespace data
+} // namespace torch
diff --git a/torch/csrc/api/include/torch/data/dataloader/stateless.h b/torch/csrc/api/include/torch/data/dataloader/stateless.h
new file mode 100644 (file)
index 0000000..fba7ccb
--- /dev/null
@@ -0,0 +1,82 @@
+#pragma once
+
+#include <torch/data/dataloader/base.h>
+#include <torch/data/worker_exception.h>
+
+#include <torch/csrc/utils/memory.h>
+
+#include <c10/util/Exception.h>
+
+#include <cstddef>
+#include <thread>
+#include <utility>
+
+namespace torch {
+namespace data {
+
+/// A dataloader for stateless datasets.
+///
+/// This dataloader follows the traditional PyTorch dataloader design, whereby a
+/// (posssibly) stateful sampler produces *batch requests* for a stateless
+/// dataset, which acts as a simple batch request to batch mapping. The batch
+/// request will often be an array of indices, and if the dataset is a simple
+/// image dataset, the dataset would produce the images at those indices.
+template <typename Dataset, typename Sampler>
+class StatelessDataLoader : public DataLoaderBase<
+                                Dataset,
+                                typename Dataset::BatchType,
+                                typename Sampler::BatchRequestType> {
+ public:
+  using super = DataLoaderBase<
+      Dataset,
+      typename Dataset::BatchType,
+      typename Sampler::BatchRequestType>;
+  using typename super::BatchRequestType;
+
+  /// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and
+  /// some `options`.
+  StatelessDataLoader(
+      Dataset dataset,
+      Sampler sampler,
+      DataLoaderOptions options)
+      : super(std::move(options)), sampler_(std::move(sampler)) {
+    for (size_t w = 0; w < this->options_.workers; ++w) {
+      // Here we copy the dataset into the worker thread closure. Each worker
+      // has its own copy of the dataset. This means the dataset must be
+      // trivially copiable, or else we don't expect more than one worker to
+      // be in use.
+      this->workers_.emplace_back(
+          [this, dataset]() mutable { this->worker_thread(dataset); });
+    }
+    if (this->options_.workers == 0) {
+      this->main_thread_dataset_ =
+          torch::make_unique<Dataset>(std::move(dataset));
+    }
+  }
+
+ private:
+  /// Resets the internal state of the dataloader and the sampler.
+  void reset() override {
+    sampler_.reset();
+    // Call the base class method last because it calls `prefetch()`
+    super::reset();
+  }
+
+  /// Queries the sampler for the next batch request (possibly progressing its
+  /// internal state).
+  optional<BatchRequestType> get_batch_request() override {
+    auto indices = sampler_.next(this->options_.batch_size);
+    if (!indices ||
+        (indices->size() < this->options_.batch_size &&
+         this->options_.drop_last)) {
+      return nullopt;
+    }
+    AT_ASSERT(indices->size() > 0);
+    return indices;
+  }
+
+  /// The `Sampler` used to produce batch requests.
+  Sampler sampler_;
+};
+} // namespace data
+} // namespace torch
index e16e2a5..4e0442f 100644 (file)
@@ -8,19 +8,9 @@
 
 namespace torch {
 namespace data {
-namespace detail {
-struct FullDataLoaderOptions;
-} // namespace detail
-} // namespace data
-} // namespace torch
-
-namespace torch {
-namespace data {
 
 /// Options to configure a `DataLoader`.
 struct DataLoaderOptions {
-  friend struct torch::data::detail::FullDataLoaderOptions;
-
   DataLoaderOptions() = default;
   /* implicit */ DataLoaderOptions(size_t batch_size)
       : batch_size_(batch_size) {}
index ba5f645..82c31fe 100644 (file)
@@ -4,4 +4,5 @@
 #include <torch/data/datasets/map.h>
 #include <torch/data/datasets/mnist.h>
 #include <torch/data/datasets/shared.h>
+#include <torch/data/datasets/stateful.h>
 #include <torch/data/datasets/tensor.h>
index 2ca1128..9d5ca15 100644 (file)
@@ -15,7 +15,7 @@ namespace torch {
 namespace data {
 namespace datasets {
 template <typename S, typename T>
-struct MapDataset;
+class MapDataset;
 template <typename D, typename T>
 MapDataset<D, T> map(D, T); // NOLINT
 } // namespace datasets
@@ -25,6 +25,12 @@ MapDataset<D, T> map(D, T); // NOLINT
 namespace torch {
 namespace data {
 namespace datasets {
+namespace detail {
+template <typename T>
+struct is_optional : std::false_type {};
+template <typename T>
+struct is_optional<optional<T>> : std::true_type {};
+} // namespace detail
 
 /// A dataset that can yield data only in batches.
 template <
@@ -36,6 +42,7 @@ class BatchDataset {
   using SelfType = Self;
   using BatchType = Batch;
   using BatchRequestType = BatchRequest;
+  constexpr static bool is_stateful = detail::is_optional<BatchType>::value;
 
   virtual ~BatchDataset() = default;
 
index ec46681..ac8b6c8 100644 (file)
@@ -6,39 +6,94 @@
 #include <c10/util/ArrayRef.h>
 
 #include <cstddef>
+#include <type_traits>
 #include <utility>
 
 namespace torch {
 namespace data {
 namespace datasets {
+namespace detail {
+template <bool C, typename T>
+using optional_if_t = typename std::conditional<C, torch::optional<T>, T>::type;
+} // namespace detail
 
 /// A `MapDataset` is a dataset that applies a transform to a source dataset.
 template <typename SourceDataset, typename AppliedTransform>
-struct MapDataset : BatchDataset<
-                        MapDataset<SourceDataset, AppliedTransform>,
-                        typename AppliedTransform::OutputBatchType,
-                        typename SourceDataset::BatchRequestType> {
+class MapDataset : public BatchDataset<
+                       MapDataset<SourceDataset, AppliedTransform>,
+                       detail::optional_if_t<
+                           SourceDataset::is_stateful,
+                           typename AppliedTransform::OutputBatchType>,
+                       typename SourceDataset::BatchRequestType> {
+ public:
   using DatasetType = SourceDataset;
   using TransformType = AppliedTransform;
   using BatchRequestType = typename SourceDataset::BatchRequestType;
-  using OutputBatchType = typename TransformType::OutputBatchType;
+  using OutputBatchType = detail::optional_if_t<
+      SourceDataset::is_stateful,
+      typename AppliedTransform::OutputBatchType>;
 
   MapDataset(DatasetType dataset, TransformType transform)
-      : dataset(std::move(dataset)), transform(std::move(transform)) {}
+      : dataset_(std::move(dataset)), transform_(std::move(transform)) {}
 
   /// Gets a batch from the source dataset and applies the transform to it,
   /// returning the result.
   OutputBatchType get_batch(BatchRequestType indices) override {
-    return transform.apply_batch(dataset.get_batch(indices));
+    return get_batch_impl(std::move(indices));
   }
 
   /// Returns the size of the source dataset.
-  optional<size_t> size() const noexcept {
-    return dataset.size();
+  optional<size_t> size() const noexcept override {
+    return dataset_.size();
   }
 
-  SourceDataset dataset;
-  AppliedTransform transform;
+  /// Calls `reset()` on the underlying dataset.
+  /// NOTE: Stateless datasets do not have a reset() method, so a call to this
+  /// method will only compile for stateful datasets (which have a reset()
+  /// method).
+  void reset() {
+    dataset_.reset();
+  }
+
+  /// Returns the underlying dataset.
+  const SourceDataset& dataset() noexcept {
+    return dataset_;
+  }
+
+  /// Returns the transform being applied.
+  const AppliedTransform& transform() noexcept {
+    return transform_;
+  }
+
+ private:
+  /// The implementation of `get_batch()` for the stateless case, which simply
+  /// applies the transform to the output of `get_batch()` from the dataset.
+  template <
+      typename D = SourceDataset,
+      typename = torch::disable_if_t<D::is_stateful>>
+  OutputBatchType get_batch_impl(BatchRequestType indices) {
+    return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
+  }
+
+  /// The implementation of `get_batch()` for the stateful case. Here, we follow
+  /// the semantics of `Optional.map()` in many functional languages, which
+  /// applies a transformation to the optional's content when the optional
+  /// contains a value, and returns a new optional (of a different type)  if the
+  /// original optional returned by `get_batch()` was empty.
+  template <typename D = SourceDataset>
+  torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
+      BatchRequestType indices) {
+    if (auto batch = dataset_.get_batch(std::move(indices))) {
+      return transform_.apply_batch(std::move(*batch));
+    }
+    return nullopt;
+  }
+
+  /// The underlying dataset being transformed.
+  SourceDataset dataset_;
+
+  // The transformation that is applied to batches received from the dataset.
+  AppliedTransform transform_;
 };
 
 /// Creates a `MapDataset` with the given dataset and transform.
@@ -48,7 +103,10 @@ MapDataset<DatasetType, TransformType> map(
     TransformType transform) {
   static_assert(
       std::is_same<
-          typename DatasetType::BatchType,
+          typename std::conditional<
+              DatasetType::is_stateful,
+              typename DatasetType::BatchType::value_type,
+              typename DatasetType::BatchType>::type,
           typename TransformType::InputBatchType>::value,
       "BatchType type of dataset does not match input type of transform");
   return {std::move(dataset), std::move(transform)};
index d804304..0972147 100644 (file)
@@ -62,6 +62,11 @@ class SharedBatchDataset : public BatchDataset<
     return dataset_.get();
   }
 
+  /// Calls `reset()` on the underlying dataset.
+  void reset() {
+    dataset_->reset();
+  }
+
  private:
   std::shared_ptr<UnderlyingDataset> dataset_;
 };
diff --git a/torch/csrc/api/include/torch/data/datasets/stateful.h b/torch/csrc/api/include/torch/data/datasets/stateful.h
new file mode 100644 (file)
index 0000000..eba22cb
--- /dev/null
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <torch/data/datasets/base.h>
+#include <torch/data/example.h>
+
+#include <cstddef>
+#include <vector>
+
+namespace torch {
+namespace data {
+namespace datasets {
+
+/// A stateful dataset is a dataset that maintains some internal state, which
+/// will be `reset()` at the beginning of each epoch. Subclasses can override
+/// the `reset()` method to configure this behavior. Further, the return type of
+/// a stateful dataset's `get_batch()` method is always an `optional`. When the
+/// stateful dataset wants to indicate to the dataloader that its epoch has
+/// ended, it should return an empty optional. The dataloader knows to modify
+/// its implementation based on whether the dataset is stateless or stateful.
+///
+/// Note that when subclassing a from `StatefulDataset<Self, T>`, the return
+/// type of `get_batch()`, which the subclass must override, will be
+/// `optional<T>` (i.e. the type specified in the `StatefulDataset` specialization is automatically boxed into an `optional` for the datast's `BatchType`).
+template <
+    typename Self,
+    typename Batch = std::vector<Example<>>,
+    typename BatchRequest = size_t>
+class StatefulDataset
+    : public BatchDataset<Self, optional<Batch>, BatchRequest> {
+ public:
+  /// Resets internal state of the dataset.
+  virtual void reset() = 0;
+};
+} // namespace datasets
+} // namespace data
+} // namespace torch