'lib/include/torch/csrc/*.h',
'lib/include/torch/csrc/api/include/torch/*.h',
'lib/include/torch/csrc/api/include/torch/data/*.h',
+ 'lib/include/torch/csrc/api/include/torch/data/dataloader/*.h',
'lib/include/torch/csrc/api/include/torch/data/datasets/*.h',
'lib/include/torch/csrc/api/include/torch/data/detail/*.h',
'lib/include/torch/csrc/api/include/torch/data/samplers/*.h',
auto data_loader = torch::data::make_data_loader(
std::move(dataset),
- kBatchSize,
- samplers::StreamSampler(/*epoch_size=*/39));
+ samplers::StreamSampler(/*epoch_size=*/39),
+ kBatchSize);
size_t batch_index = 0;
for (auto& batch : *data_loader) {
size_t index = 0;
auto getter = [&v, &index]() { return S{v.at(index++)}; };
- // Let's say the sequence number matches for the first one, then it should
+ // Let's say the sequence number matches for the first batch, then it should
// return immediately.
- const auto first = sequencer.next(getter);
- ASSERT_EQ(first.value().sequence_number, 0);
+ const auto batch = sequencer.next(getter);
+ ASSERT_EQ(batch.value().sequence_number, 0);
ASSERT_EQ(index, 1);
// Now it should call the getter until it gets the next value.
auto d = D().map(transforms::Stack<Example<>>());
- Example<> first = d.get_batch({0, 1});
- ASSERT_TRUE(first.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
- ASSERT_TRUE(first.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));
+ Example<> batch = d.get_batch({0, 1});
+ ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
+ ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));
Example<> second = d.get_batch({2, 3});
ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
auto d = datasets::TensorDataset(torch::eye(4))
.map(transforms::Stack<TensorExample>());
- TensorExample first = d.get_batch({0, 1});
- ASSERT_TRUE(first.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
+ TensorExample batch = d.get_batch({0, 1});
+ ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
TensorExample second = d.get_batch({2, 3});
ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
using torch::data::detail::Queue;
// First test: push first and then pop in a thread.
{
Queue<int> queue;
queue.push(1);
ASSERT_EQ(future.get(), 1);
}
// Second test: attempt to pop first (and block), then push.
{
Queue<int> queue;
std::thread thread([&queue] {
TEST(DataTest, DataShuttleCanPushAndPopResult) {
torch::data::detail::DataShuttle<int, int> shuttle;
// pop_result() will only attempt to pop if there was a push_job() first.
shuttle.push_job(1);
shuttle.push_job(2);
};
TEST(DataTest, CanUseCustomTypeAsIndexType) {
- const size_t kBatchSize = 10;
+ const int kBatchSize = 10;
auto data_loader = torch::data::make_data_loader(
- TestIndexDataset(23), kBatchSize, TestIndexSampler(23));
+ TestIndexDataset(23), TestIndexSampler(23), kBatchSize);
size_t i = 0;
for (auto batch : *data_loader) {
ASSERT_LT(duration.count(), 1);
}
// https://stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
struct Barrier {
explicit Barrier(size_t target) : counter_(target) {}
void wait() {
// thread (for outside consumption) is not deterministic. Imagine the sampler is
// a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each index
// will be a single "job". Inside the dataloader, worker threads block until a
// job is available. It is not deterministic which worker thread wakes up first
// to dequeue a particular batch. Further, some worker threads may take longer
// than others to read the data for their index. As such, it could be that
// worker thread 2 finishes before all other threads and returns its batch to
// the main thread. In that case, the dataloader iterator would return the datum
// at index 2 first, and afterwards the datum from whatever thread finishes
// next. As such, the user may see data from indices 2, 0, 3, 1. On another run
// of the same dataloader on the same data, threads may be scheduled differently
// and return in order 0, 2, 3, 1. To force this ordering to deterministically
// `SequentialSampler` in the range `0...kNumberOfWorkers-1`. Each worker thread
// has a copy of the dataset, and thus `get_batch()` is called on the
// thread-local copy in each worker. We want to simulate out-of-order completion
// of these threads. For this, we first set a barrier in the `get_batch()`
// method to make sure every worker has some index to fetch assigned. Further,
// each worker thread has a unique ID in `0...kNumberOfWorkers-1`.
// There is a hard-coded ordering, `kOrderInWhichWorkersReturnTheirBatch`, in
TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
auto data_loader = torch::data::make_data_loader(
ordering_test::Dataset{},
+ torch::data::samplers::SequentialSampler(ordering_test::kNumberOfWorkers),
DataLoaderOptions()
.batch_size(1)
.workers(ordering_test::kNumberOfWorkers)
- .enforce_ordering(true),
- torch::data::samplers::SequentialSampler(
- ordering_test::kNumberOfWorkers));
+ .enforce_ordering(true));
std::vector<size_t> output;
for (size_t value : *data_loader) {
output.push_back(value);
}
};
- auto data_loader =
- torch::data::make_data_loader(D{}, DataLoaderOptions().workers(2));
+ auto data_loader = torch::data::make_data_loader(
+ D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2));
auto iterator = data_loader->begin();
try {
std::rethrow_exception(e.original_exception), std::invalid_argument);
}
}
+
+TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
+ const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
+
+ struct D : datasets::StatefulDataset<D, int, size_t> {
+ torch::optional<int> get_batch(size_t) override {
+ if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
+ return counter++;
+ }
+ return torch::nullopt;
+ }
+ torch::optional<size_t> size() const override {
+ return 100;
+ }
+ void reset() override {
+ counter = 0;
+ }
+ int counter = 0;
+ };
+
+ auto data_loader = torch::data::make_data_loader(D{});
+
+ for (size_t i = 0; i < 10; ++i) {
+ const auto number_of_iterations =
+ std::distance(data_loader->begin(), data_loader->end());
+ ASSERT_EQ(
+ number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
+ << "epoch " << i;
+ }
+
+ for (const int i : *data_loader) {
+ ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
+ }
+}
+
+TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
+ const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
+ const int kNumberOfWorkers = 4;
+
+ struct D : datasets::StatefulDataset<D, int, size_t> {
+ torch::optional<int> get_batch(size_t) override {
+ std::lock_guard<std::mutex> lock(mutex);
+ if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
+ return counter++;
+ }
+ return torch::nullopt;
+ }
+ torch::optional<size_t> size() const override {
+ return 100;
+ }
+ void reset() override {
+ counter = 0;
+ }
+ int counter = 0;
+ std::mutex mutex;
+ };
+
+ auto data_loader = torch::data::make_data_loader(
+ torch::data::datasets::make_shared_dataset<D>(),
+ DataLoaderOptions().workers(kNumberOfWorkers));
+
+ for (size_t i = 0; i < 10; ++i) {
+ const auto number_of_iterations =
+ std::distance(data_loader->begin(), data_loader->end());
+ ASSERT_EQ(
+ number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
+ << "epoch " << i;
+ }
+
+ for (const int i : *data_loader) {
+ ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
+ }
+}
+
+TEST(DataLoaderTest, StatefulDatasetWithMap) {
+ const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
+
+ struct D : datasets::StatefulDataset<D, int, size_t> {
+ torch::optional<int> get_batch(size_t) override {
+ if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
+ return counter++;
+ }
+ return torch::nullopt;
+ }
+ torch::optional<size_t> size() const override {
+ return 100;
+ }
+ void reset() override {
+ counter = 0;
+ }
+ int counter = 0;
+ };
+
+ auto data_loader = torch::data::make_data_loader(
+ D().map(transforms::BatchLambda<int, std::string>(
+ [](int x) { return std::to_string(x); }))
+ .map(transforms::BatchLambda<std::string, torch::Tensor>(
+ [](const std::string& x) {
+ return torch::tensor(static_cast<int64_t>(std::stoi(x)));
+ })),
+ DataLoaderOptions{});
+
+ for (size_t i = 0; i < 10; ++i) {
+ const auto number_of_iterations =
+ std::distance(data_loader->begin(), data_loader->end());
+ ASSERT_EQ(
+ number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
+ << "epoch " << i;
+ }
+
+ for (const torch::Tensor& t : *data_loader) {
+ ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
+ }
+}
+
+TEST(DataLoaderTest, StatefulDatasetWithCollate) {
+ const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
+
+ struct D : datasets::StatefulDataset<D> {
+ torch::optional<std::vector<Example<>>> get_batch(
+ size_t batch_size) override {
+ if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
+ counter += batch_size;
+ std::vector<Example<>> batch(
+ /*count=*/batch_size,
+ Example<>{torch::ones(batch_size + 1),
+ torch::zeros(batch_size - 1)});
+ return batch;
+ }
+ return torch::nullopt;
+ }
+ torch::optional<size_t> size() const override {
+ return 100;
+ }
+ void reset() override {
+ counter = 0;
+ }
+ int counter = 0;
+ };
+
+ auto d = D().map(transforms::Stack<Example<>>());
+
+ const size_t kBatchSize = 5;
+
+ // Notice that the `get_batch()` of the dataset returns a vector<Example>, but
+ // the `Stack` collation stacks the tensors into one.
+ torch::optional<Example<>> batch = d.get_batch(kBatchSize);
+ ASSERT_TRUE(batch.has_value());
+ ASSERT_EQ(batch->data.size(0), kBatchSize);
+ ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
+ ASSERT_EQ(batch->target.size(0), kBatchSize);
+ ASSERT_EQ(batch->target.size(1), kBatchSize - 1);
+
+ ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
+ ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
+}
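Taken together, these tests pin down the new user-facing pattern: define a
`StatefulDataset` subclass, hand it to `make_data_loader` without a sampler,
and let the dataset signal epoch boundaries by returning an empty optional.
A minimal sketch of that pattern (the `Counter` dataset and `main` below are
illustrative only, not part of this diff):

#include <torch/torch.h>

#include <iostream>
#include <vector>

// Hypothetical stateful dataset: yields ten single-example batches per
// epoch, then signals exhaustion with an empty optional.
struct Counter : torch::data::datasets::StatefulDataset<Counter> {
  torch::optional<std::vector<torch::data::Example<>>> get_batch(
      size_t /*batch_size*/) override {
    if (count_ == 10) {
      return torch::nullopt; // ends the current epoch
    }
    return std::vector<torch::data::Example<>>{
        {torch::full({1}, count_++), torch::zeros({1})}};
  }
  torch::optional<size_t> size() const override {
    return 10;
  }
  void reset() override {
    count_ = 0; // called by the dataloader at the start of every epoch
  }
  int64_t count_ = 0;
};

int main() {
  // No sampler argument: the stateful overload of `make_data_loader`
  // (see dataloader.h below) is selected at compile time.
  auto loader = torch::data::make_data_loader(Counter{});
  for (auto& batch : *loader) {
    std::cout << batch.front().data.item<int64_t>() << '\n';
  }
}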
#pragma once
-#include <torch/data/dataloader_options.h>
-#include <torch/data/detail/data_shuttle.h>
-#include <torch/data/detail/sequencers.h>
-#include <torch/data/iterator.h>
-#include <torch/data/samplers/random.h>
-#include <torch/data/worker_exception.h>
-#include <torch/types.h>
+#include <torch/data/dataloader/stateful.h>
+#include <torch/data/dataloader/stateless.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
#include <c10/util/Exception.h>
#include <cstddef>
-#include <exception>
#include <memory>
-#include <thread>
#include <type_traits>
#include <utility>
-#include <vector>
namespace torch {
namespace data {
-template <typename Dataset, typename Sampler>
-class DataLoader {
- public:
- using Batch = typename Dataset::BatchType;
- using BatchRequest = typename Sampler::BatchRequestType;
-
- /// Constructs a new `DataLoader` from a `dataset` to sample from, `options`
- /// to configure the `DataLoader` with, and a `sampler` that specifies the
- /// sampling strategy.
- DataLoader(Dataset dataset, DataLoaderOptions options, Sampler sampler)
- : options_(std::move(options)),
- sampler_(std::move(sampler)),
- sequencer_(new_sequencer()) {
- for (size_t w = 0; w < options_.workers; ++w) {
- // Here we copy the dataset into the worker thread closure. Each worker
- // has its own copy of the dataset. This means the dataset must be
- // trivially copiable, or else we don't expect more than one worker to
- // be in use.
- workers_.emplace_back([this, dataset]() mutable {
- this->worker_thread(std::move(dataset));
- });
- }
- if (options_.workers == 0) {
- main_thread_dataset_ = torch::make_unique<Dataset>(std::move(dataset));
- }
- }
-
- virtual ~DataLoader() {
- join();
- }
-
- /// Returns an iterator into the `DataLoader`. The lifetime of the iterator is
- /// bound to the `DataLoader`. In C++ standards language, the category of the
- /// iterator is `OutputIterator`. See
- /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this
- /// means. In short: you may increment the iterator and dereference it, but
- /// cannot go back, or step forward more than one position at a time. When the
- /// `DataLoader` is exhausted, it will compare equal with the special
- /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you
- /// should only use range-for loops to loop over the `DataLoader`, but
- /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(),
- /// output_iterator)` are supported too.
- Iterator<Batch> begin() {
- AT_CHECK(
- shuttle_.in_flight_jobs() == 0,
- "Attempted to get a new DataLoader iterator "
- "while another iterator is not yet exhausted");
- reset();
- return Iterator<Batch>(torch::make_unique<detail::ValidIterator<Batch>>(
- [this] { return this->next(); }));
- }
-
- /// Returns a special "sentinel" iterator that compares equal with a
- /// non-sentinel iterator once the `DataLoader` is exhausted.
- Iterator<Batch> end() {
- return Iterator<Batch>(
- torch::make_unique<detail::SentinelIterator<Batch>>());
- }
-
- /// Joins the `DataLoader`'s worker threads and drains internal queues.
- /// This function may only be invoked from the main thread (in which the
- /// `DataLoader` lives).
- void join() {
- if (joined_) {
- return;
- }
- shuttle_.drain();
- // Send one 'quit' message per worker. Since a worker dies (exits its
- // thread) after receiving this message, each `QuitWorker()` message will be
- // read by exactly one worker.
- for (size_t w = 0; w < options_.workers; ++w) {
- push_job(QuitWorker());
- }
- for (auto& worker : workers_) {
- worker.join();
- }
- joined_ = true;
- }
-
- /// Returns the options with which the `DataLoader` was configured.
- const FullDataLoaderOptions& options() const noexcept {
- return options_;
- }
-
- /// Returns the sampler currently used by the `DataLoader`.
- const Sampler& sampler() const noexcept {
- return sampler_;
- }
-
- /// Returns the sampler currently used by the `DataLoader`.
- Sampler& sampler() noexcept {
- return sampler_;
- }
-
- private:
- /// Simple mix-in to give something a sequence number.
- struct Sequenced {
- Sequenced() = default;
- Sequenced(size_t sqn) : sequence_number(sqn) {}
- size_t sequence_number;
- };
-
- struct QuitWorker {};
-
- /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a
- /// `QuitWorker` object, to indicate the worker should shut down.
- struct Job : Sequenced {
- Job() = default;
- Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
- Job(BatchRequest&& i, size_t sqn)
- : Sequenced(sqn), batch_request(std::move(i)) {}
- optional<QuitWorker> quit;
- optional<BatchRequest> batch_request;
- };
-
- /// The finished result of a job.
- struct Result : Sequenced {
- Result() = default;
- Result(Batch&& b, size_t sqn) : Sequenced(sqn), batch(std::move(b)) {}
- Result(std::exception_ptr exception, size_t sqn)
- : Sequenced(sqn), exception(std::move(exception)) {}
- optional<Batch> batch;
- std::exception_ptr exception;
- };
-
- /// Resets the internal state of the `DataLoader`, optionally pre-fetching
- /// new jobs.
- void reset(bool prefetch = true) {
- shuttle_.drain();
- sampler_.reset();
- sequence_number_ = 0;
- sequencer_ = new_sequencer();
- if (prefetch) {
- this->prefetch();
- }
- }
-
- /// Schedules `requested_jobs` many new batches to be fetched. The actual
- /// number of jobs scheduled may be less if the `DataLoader` exhausts.
- void prefetch(size_t requested_jobs) {
- while (requested_jobs-- > 0) {
- if (auto batch_request = get_batch_request()) {
- push_job(std::move(*batch_request));
- } else {
- break;
- }
- }
- }
-
- /// Schedules the maximum number of jobs (based on the `max_jobs` option).
- void prefetch() {
- prefetch(options_.max_jobs);
- }
-
- /// Returns the next batch of data, or an empty `optional` if the `DataLoader`
- /// is exhausted. This operation will block until a batch is available.
- optional<Batch> next() {
- optional<Batch> batch;
- if (options_.workers > 0) {
- optional<Result> result = sequencer_->next(
- [this] { return this->shuttle_.pop_result(this->options_.timeout); });
- if (result) {
- if (result->exception) {
- throw WorkerException(result->exception);
- } else {
- AT_ASSERT(result->batch.has_value());
- batch = std::move(result->batch);
- prefetch(1);
- }
- }
- } else if (auto batch_request = get_batch_request()) {
- AT_ASSERT(main_thread_dataset_ != nullptr);
- batch = main_thread_dataset_->get_batch(std::move(*batch_request));
- }
- return batch;
- }
- /// The function that worker threads run.
- void worker_thread(Dataset dataset) {
- while (true) {
- auto job = shuttle_.pop_job();
- if (job.quit) {
- break;
- }
- try {
- auto batch = dataset.get_batch(std::move(*job.batch_request));
- shuttle_.push_result({std::move(batch), job.sequence_number});
- } catch (...) {
- shuttle_.push_result({std::current_exception(), job.sequence_number});
- }
- }
- }
-
- optional<BatchRequest> get_batch_request() {
- auto indices = sampler_.next(options_.batch_size);
- if (!indices ||
- (indices->size() < options_.batch_size && options_.drop_last)) {
- return nullopt;
- }
- AT_ASSERT(indices->size() > 0);
- return indices;
- }
-
- template <typename T>
- void push_job(T value) {
- shuttle_.push_job({std::move(value), sequence_number_++});
- }
-
- std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
- if (options_.enforce_ordering) {
- return torch::make_unique<detail::sequencers::OrderedSequencer<Result>>(
- options_.max_jobs);
- }
- return torch::make_unique<detail::sequencers::NoSequencer<Result>>();
- }
-
- /// The options the `DataLoader` was configured with.
- const FullDataLoaderOptions options_;
-
- /// The dataset for the main thread, only has a value if the number of
- /// worker threads was configured as zero, meaning the main thread has to do
- /// all the work (synchronously). NOTE: Really want this to be on the heap
- /// when empty, therefore `unique_ptr` and not `optional`.
- std::unique_ptr<Dataset> main_thread_dataset_;
-
- /// The sampler with which new batch requests are created.
- Sampler sampler_;
-
- /// The sequence number for the *next* batch to be retrieved from the
- /// dataset.
- size_t sequence_number_ = 0;
-
- /// The worker threads, running the `worker_thread()` method.
- std::vector<std::thread> workers_;
-
- /// The `DataShuttle` which takes care of the life cycle of a job.
- detail::DataShuttle<Job, Result> shuttle_;
-
- /// The `Sequencer`, which handles optional ordering of batches.
- std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;
-
- /// True if the `DataLoader` has joined its worker threads.
- bool joined_ = false;
-}; // namespace data
-
-/// Creates a new `DataLoader`, inferring the necessary template types from
-/// the given arguments.
+/// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and
+/// some `options`.
template <typename Dataset, typename Sampler>
-std::unique_ptr<DataLoader<Dataset, Sampler>> make_data_loader(
- Dataset dataset,
- DataLoaderOptions options,
- Sampler sampler) {
- return torch::make_unique<DataLoader<Dataset, Sampler>>(
- std::move(dataset), std::move(options), std::move(sampler));
+torch::disable_if_t<
+ Dataset::is_stateful,
+ std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
+make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) {
+ return torch::make_unique<StatelessDataLoader<Dataset, Sampler>>(
+ std::move(dataset), std::move(sampler), std::move(options));
}
-/// Creates a new `DataLoader`, inferring the necessary template types from
-/// the given arguments.
-template <
- typename Sampler = samplers::RandomSampler,
- typename Dataset,
- typename =
- torch::enable_if_t<std::is_constructible<Sampler, size_t>::value>>
-std::unique_ptr<DataLoader<Dataset, Sampler>> make_data_loader(
+/// Creates a `DataLoader` instance for a stateless `dataset` and some
+/// `options`. A sampler (by default a `RandomSampler`) will be constructed from
+/// the size of the dataset.
+template <typename Sampler = samplers::RandomSampler, typename Dataset>
+torch::disable_if_t<
+ Dataset::is_stateful || !std::is_constructible<Sampler, size_t>::value,
+ std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
+make_data_loader(
Dataset dataset,
DataLoaderOptions options = DataLoaderOptions()) {
const optional<size_t> size = dataset.size();
"Expected the dataset to be sized in "
"order to construct the Sampler");
return make_data_loader(
- std::move(dataset), std::move(options), Sampler(*size));
+ std::move(dataset), Sampler(*size), std::move(options));
}
+/// Creates a `DataLoader` for a stateful `dataset` and some `options`.
+template <typename Dataset, typename = torch::enable_if_t<Dataset::is_stateful>>
+std::unique_ptr<StatefulDataLoader<Dataset>> make_data_loader(
+ Dataset dataset,
+ DataLoaderOptions options = DataLoaderOptions()) {
+ return torch::make_unique<StatefulDataLoader<Dataset>>(
+ std::move(dataset), std::move(options));
+}
} // namespace data
} // namespace torch
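The three overloads above are disambiguated at compile time via
`Dataset::is_stateful` and `torch::disable_if_t`/`enable_if_t`, so the call
site alone picks the loader type. A hedged sketch of the selection rules,
reusing `TensorDataset` from the tests above:

using namespace torch::data; // for brevity in this sketch

// Stateless dataset + explicit sampler -> StatelessDataLoader.
auto a = make_data_loader(
    datasets::TensorDataset(torch::eye(4)),
    samplers::SequentialSampler(4),
    DataLoaderOptions().batch_size(2));

// Stateless dataset alone: the default `Sampler` template argument
// (RandomSampler) is constructed from `dataset.size()`.
auto b = make_data_loader(datasets::TensorDataset(torch::eye(4)));

// Stateful dataset (`is_stateful` is true) -> StatefulDataLoader; no
// sampler is involved at all.
// auto c = make_data_loader(MyStatefulDataset{});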
--- /dev/null
+#pragma once
+
+#include <torch/data/dataloader_options.h>
+#include <torch/data/detail/data_shuttle.h>
+#include <torch/data/detail/sequencers.h>
+#include <torch/data/iterator.h>
+#include <torch/data/samplers/random.h>
+#include <torch/data/worker_exception.h>
+#include <torch/types.h>
+
+#include <torch/csrc/utils/memory.h>
+#include <torch/csrc/utils/variadic.h>
+
+#include <c10/util/Exception.h>
+
+#include <cstddef>
+#include <exception>
+#include <memory>
+#include <thread>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace torch {
+namespace data {
+template <typename Dataset, typename Batch, typename BatchRequest>
+class DataLoaderBase {
+ public:
+ using BatchType = Batch;
+ using BatchRequestType = BatchRequest;
+
+  /// Constructs a new DataLoaderBase from `options` to configure the
+  /// DataLoader with, and an optional `main_thread_dataset`, which is only
+  /// used when the number of worker threads is zero.
+ DataLoaderBase(
+ DataLoaderOptions options,
+ std::unique_ptr<Dataset> main_thread_dataset = nullptr)
+ : options_(std::move(options)),
+ main_thread_dataset_(std::move(main_thread_dataset)),
+ sequencer_(new_sequencer()) {}
+
+ virtual ~DataLoaderBase() {
+ join();
+ }
+
+ /// Returns an iterator into the DataLoader. The lifetime of the iterator is
+ /// bound to the DataLoader. In C++ standards language, the category of the
+ /// iterator is `OutputIterator`. See
+ /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this
+ /// means. In short: you may increment the iterator and dereference it, but
+ /// cannot go back, or step forward more than one position at a time. When the
+ /// DataLoader is exhausted, it will compare equal with the special
+ /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you
+ /// should only use range-for loops to loop over the DataLoader, but
+ /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(),
+ /// output_iterator)` are supported too.
+ Iterator<Batch> begin() {
+ AT_CHECK(
+ shuttle_.in_flight_jobs() == 0,
+ "Attempted to get a new DataLoader iterator "
+ "while another iterator is not yet exhausted");
+ reset();
+ return Iterator<Batch>(torch::make_unique<detail::ValidIterator<Batch>>(
+ [this] { return this->next(); }));
+ }
+
+ /// Returns a special "sentinel" iterator that compares equal with a
+ /// non-sentinel iterator once the DataLoader is exhausted.
+ Iterator<Batch> end() {
+ return Iterator<Batch>(
+ torch::make_unique<detail::SentinelIterator<Batch>>());
+ }
+
+ /// Joins the DataLoader's worker threads and drains internal queues.
+ /// This function may only be invoked from the main thread (in which the
+ /// DataLoader lives).
+ void join() {
+ if (joined_) {
+ return;
+ }
+ shuttle_.drain();
+ // Send one 'quit' message per worker. Since a worker dies (exits its
+ // thread) after receiving this message, each `QuitWorker()` message will be
+ // read by exactly one worker.
+ for (size_t w = 0; w < options_.workers; ++w) {
+ push_job(QuitWorker());
+ }
+ for (auto& worker : workers_) {
+ worker.join();
+ }
+ joined_ = true;
+ }
+
+ /// Returns the options with which the DataLoader was configured.
+ const FullDataLoaderOptions& options() const noexcept {
+ return options_;
+ }
+
+ protected:
+ /// Simple mix-in to give something a sequence number.
+ struct Sequenced {
+ Sequenced() = default;
+ Sequenced(size_t sqn) : sequence_number(sqn) {}
+ size_t sequence_number;
+ };
+
+ struct QuitWorker {};
+
+ /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a
+ /// `QuitWorker` object, to indicate the worker should shut down.
+ struct Job : Sequenced {
+ Job() = default;
+ Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
+ Job(BatchRequest&& i, size_t sqn)
+ : Sequenced(sqn), batch_request(std::move(i)) {}
+ optional<QuitWorker> quit;
+ optional<BatchRequest> batch_request;
+ };
+
+ /// The finished result of a job.
+ struct Result : Sequenced {
+ Result() = default;
+ Result(optional<Batch>&& b, size_t sqn)
+ : Sequenced(sqn), batch(std::move(b)) {}
+ Result(std::exception_ptr exception, size_t sqn)
+ : Sequenced(sqn), exception(std::move(exception)) {}
+ optional<Batch> batch;
+ std::exception_ptr exception;
+ };
+
+ /// Subclass hook for getting the next batch request. The stateless case will
+ /// ask the sampler for a new batch request (e.g. a vector of indices), while
+ /// the stateful one will simply return the batch size.
+ virtual optional<BatchRequestType> get_batch_request() = 0;
+
+ /// Resets the internal state of the DataLoader, optionally pre-fetching
+ /// new jobs.
+ virtual void reset() {
+ shuttle_.drain();
+ sequence_number_ = 0;
+ sequencer_ = new_sequencer();
+ prefetch();
+ }
+
+ /// Schedules `requested_jobs` many new batches to be fetched. The actual
+ /// number of jobs scheduled may be less if the DataLoader exhausts.
+ void prefetch(size_t requested_jobs) {
+ for (size_t r = 0; r < requested_jobs; ++r) {
+ if (auto batch_request = get_batch_request()) {
+ this->push_job(std::move(*batch_request));
+ } else {
+ break;
+ }
+ }
+ }
+
+ /// Schedules the maximum number of jobs (based on the `max_jobs` option).
+ void prefetch() {
+ prefetch(options_.max_jobs);
+ }
+
+ /// Returns the next batch of data, or an empty `optional` if the DataLoader
+ /// is exhausted. This operation will block until a batch is available if one
+ /// is still expected.
+ optional<BatchType> next() {
+ if (options_.workers > 0) {
+ while (optional<Result> result = this->pop_result()) {
+ if (result->exception) {
+ throw WorkerException(result->exception);
+ } else if (result->batch) {
+ prefetch(1);
+ return std::move(result->batch);
+ }
+ }
+ } else if (auto batch_request = get_batch_request()) {
+ return this->main_thread_dataset_->get_batch(std::move(*batch_request));
+ }
+ return nullopt;
+ }
+
+ /// The function that worker threads run.
+ void worker_thread(Dataset& dataset) {
+ while (true) {
+ auto job = shuttle_.pop_job();
+ if (job.quit) {
+ break;
+ }
+ try {
+ auto batch = dataset.get_batch(std::move(*job.batch_request));
+ shuttle_.push_result({std::move(batch), job.sequence_number});
+ } catch (...) {
+ shuttle_.push_result({std::current_exception(), job.sequence_number});
+ }
+ }
+ }
+
+ /// Convenience method that calls `shuttle_.push_job()` with the next sequence
+ /// number.
+ template <typename T>
+ void push_job(T value) {
+ shuttle_.push_job({std::move(value), sequence_number_++});
+ }
+
+ /// Convenience method that gets the next result from the sequencer.
+ optional<Result> pop_result() {
+ return sequencer_->next(
+ [this] { return this->shuttle_.pop_result(this->options_.timeout); });
+ }
+
+ /// Convenience method that creates a new sequencer based on the
+ /// `enforce_ordering` option.
+ std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
+ if (options_.enforce_ordering) {
+ return torch::make_unique<detail::sequencers::OrderedSequencer<Result>>(
+ options_.max_jobs);
+ }
+ return torch::make_unique<detail::sequencers::NoSequencer<Result>>();
+ }
+
+ /// The options the DataLoader was configured with.
+ const FullDataLoaderOptions options_;
+
+ /// The dataset for the main thread, only has a value if the number of
+ /// worker threads was configured as zero, meaning the main thread has to do
+ /// all the work (synchronously). NOTE: Really want this to be on the heap
+ /// when empty, therefore `unique_ptr` and not `optional`.
+ std::unique_ptr<Dataset> main_thread_dataset_;
+
+ /// The sequence number for the *next* batch to be retrieved from the
+ /// dataset.
+ size_t sequence_number_ = 0;
+
+ /// The worker threads, running the `worker_thread()` method.
+ std::vector<std::thread> workers_;
+
+ /// The `DataShuttle` which takes care of the life cycle of a job.
+ detail::DataShuttle<Job, Result> shuttle_;
+
+ /// The `Sequencer`, which handles optional ordering of batches.
+ std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;
+
+ /// True if the DataLoader has joined its worker threads.
+ bool joined_ = false;
+};
+} // namespace data
+} // namespace torch
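As the `begin()` doc-comment says, the iterator is strictly single-pass but
interoperates with standard algorithms. For example (a fragment, assuming a
`loader` built over `std::vector<Example<>>` batches and the usual
<algorithm>/<iterator> includes):

std::vector<std::vector<torch::data::Example<>>> epoch;
std::copy(
    loader->begin(), loader->end(), std::back_inserter(epoch)); // one epoch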
--- /dev/null
+#pragma once
+
+#include <torch/data/dataloader/base.h>
+
+#include <cstddef>
+#include <thread>
+#include <utility>
+
+namespace torch {
+namespace data {
+
+/// A dataloader for stateful datasets.
+///
+/// A dataloader for stateful datasets differs from one for stateless
+/// datasets in that the dataset is shared among worker threads, and that
+/// this dataset is itself responsible for producing batches rather than
+/// depending on a sampler. The statefulness here actually refers to the
+/// dataset. The StatefulDataLoader simply alters the data loading algorithm to
+/// accommodate the stateful, shared nature of the dataset. Note that the dataset
+/// must be thread safe if more than one worker thread is used.
+///
+/// A stateful dataloader is created by calling `make_data_loader` with a
+/// stateful dataset.
+template <typename Dataset>
+class StatefulDataLoader : public DataLoaderBase<
+ Dataset,
+ typename Dataset::BatchType::value_type,
+ typename Dataset::BatchRequestType> {
+ public:
+ using super = DataLoaderBase<
+ Dataset,
+ typename Dataset::BatchType::value_type,
+ typename Dataset::BatchRequestType>;
+ using typename super::BatchRequestType;
+
+ /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`.
+ StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
+ : super(
+ std::move(options),
+ torch::make_unique<Dataset>(std::move(dataset))) {
+ for (size_t w = 0; w < this->options_.workers; ++w) {
+ // As opposed to the stateless case, here all worker threads access the
+ // same underlying dataset.
+ this->workers_.emplace_back(
+ [this] { this->worker_thread(*this->main_thread_dataset_); });
+ }
+ }
+
+ private:
+ /// Resets the internal state of the dataloader and the dataset.
+ void reset() override {
+ this->main_thread_dataset_->reset();
+ // Call the base class method last because it calls `prefetch()`
+ super::reset();
+ }
+
+ /// For stateful datasets, the batch request is always the batch size. The
+ /// dataset is responsible for determining what goes into the batch next.
+ optional<BatchRequestType> get_batch_request() override {
+ return this->options_.batch_size;
+ }
+};
+} // namespace data
+} // namespace torch
--- /dev/null
+#pragma once
+
+#include <torch/data/dataloader/base.h>
+#include <torch/data/worker_exception.h>
+
+#include <torch/csrc/utils/memory.h>
+
+#include <c10/util/Exception.h>
+
+#include <cstddef>
+#include <thread>
+#include <utility>
+
+namespace torch {
+namespace data {
+
+/// A dataloader for stateless datasets.
+///
+/// This dataloader follows the traditional PyTorch dataloader design, whereby a
+/// (possibly) stateful sampler produces *batch requests* for a stateless
+/// dataset, which simply maps batch requests to batches. The batch
+/// request will often be an array of indices, and if the dataset is a simple
+/// image dataset, the dataset would produce the images at those indices.
+template <typename Dataset, typename Sampler>
+class StatelessDataLoader : public DataLoaderBase<
+ Dataset,
+ typename Dataset::BatchType,
+ typename Sampler::BatchRequestType> {
+ public:
+ using super = DataLoaderBase<
+ Dataset,
+ typename Dataset::BatchType,
+ typename Sampler::BatchRequestType>;
+ using typename super::BatchRequestType;
+
+ /// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and
+ /// some `options`.
+ StatelessDataLoader(
+ Dataset dataset,
+ Sampler sampler,
+ DataLoaderOptions options)
+ : super(std::move(options)), sampler_(std::move(sampler)) {
+ for (size_t w = 0; w < this->options_.workers; ++w) {
+ // Here we copy the dataset into the worker thread closure. Each worker
+ // has its own copy of the dataset. This means the dataset must be
+ // trivially copiable, or else we don't expect more than one worker to
+ // be in use.
+ this->workers_.emplace_back(
+ [this, dataset]() mutable { this->worker_thread(dataset); });
+ }
+ if (this->options_.workers == 0) {
+ this->main_thread_dataset_ =
+ torch::make_unique<Dataset>(std::move(dataset));
+ }
+ }
+
+ private:
+ /// Resets the internal state of the dataloader and the sampler.
+ void reset() override {
+ sampler_.reset();
+ // Call the base class method last because it calls `prefetch()`
+ super::reset();
+ }
+
+ /// Queries the sampler for the next batch request (possibly progressing its
+ /// internal state).
+ optional<BatchRequestType> get_batch_request() override {
+ auto indices = sampler_.next(this->options_.batch_size);
+ if (!indices ||
+ (indices->size() < this->options_.batch_size &&
+ this->options_.drop_last)) {
+ return nullopt;
+ }
+ AT_ASSERT(indices->size() > 0);
+ return indices;
+ }
+
+ /// The `Sampler` used to produce batch requests.
+ Sampler sampler_;
+};
+} // namespace data
+} // namespace torch
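To make `get_batch_request()` concrete: with a `SequentialSampler` over five
examples and a batch size of two, the successive batch requests are {0, 1},
{2, 3}, and {4}; the final, short request is suppressed when `drop_last` is
set. A standalone sketch of the sampler side:

torch::data::samplers::SequentialSampler sampler(/*size=*/5);
auto r1 = sampler.next(/*batch_size=*/2); // {0, 1}
auto r2 = sampler.next(2); // {2, 3}
auto r3 = sampler.next(2); // {4}: shorter than batch_size, dropped by the
                           // loader when options.drop_last is true
auto r4 = sampler.next(2); // nullopt: the sampler is exhausted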
namespace torch {
namespace data {
-namespace detail {
-struct FullDataLoaderOptions;
-} // namespace detail
-} // namespace data
-} // namespace torch
-
-namespace torch {
-namespace data {
/// Options to configure a `DataLoader`.
struct DataLoaderOptions {
- friend struct torch::data::detail::FullDataLoaderOptions;
-
DataLoaderOptions() = default;
/* implicit */ DataLoaderOptions(size_t batch_size)
: batch_size_(batch_size) {}
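The implicit `size_t` constructor keeps the common case terse: a bare batch
size converts to `DataLoaderOptions` at the call site, as several of the
tests above rely on. A sketch (with `dataset` and `sampler` assumed to
exist):

// These two calls are equivalent; the integer implicitly constructs
// DataLoaderOptions with batch_size = 64.
auto l1 = torch::data::make_data_loader(dataset, sampler, 64);
auto l2 = torch::data::make_data_loader(
    dataset, sampler, torch::data::DataLoaderOptions().batch_size(64));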
#include <torch/data/datasets/map.h>
#include <torch/data/datasets/mnist.h>
#include <torch/data/datasets/shared.h>
+#include <torch/data/datasets/stateful.h>
#include <torch/data/datasets/tensor.h>
namespace data {
namespace datasets {
template <typename S, typename T>
-struct MapDataset;
+class MapDataset;
template <typename D, typename T>
MapDataset<D, T> map(D, T); // NOLINT
} // namespace datasets
namespace torch {
namespace data {
namespace datasets {
+namespace detail {
+template <typename T>
+struct is_optional : std::false_type {};
+template <typename T>
+struct is_optional<optional<T>> : std::true_type {};
+} // namespace detail
/// A dataset that can yield data only in batches.
template <
using SelfType = Self;
using BatchType = Batch;
using BatchRequestType = BatchRequest;
+ constexpr static bool is_stateful = detail::is_optional<BatchType>::value;
virtual ~BatchDataset() = default;
#include <c10/util/ArrayRef.h>
#include <cstddef>
+#include <type_traits>
#include <utility>
namespace torch {
namespace data {
namespace datasets {
+namespace detail {
+template <bool C, typename T>
+using optional_if_t = typename std::conditional<C, torch::optional<T>, T>::type;
+} // namespace detail
/// A `MapDataset` is a dataset that applies a transform to a source dataset.
template <typename SourceDataset, typename AppliedTransform>
-struct MapDataset : BatchDataset<
- MapDataset<SourceDataset, AppliedTransform>,
- typename AppliedTransform::OutputBatchType,
- typename SourceDataset::BatchRequestType> {
+class MapDataset : public BatchDataset<
+ MapDataset<SourceDataset, AppliedTransform>,
+ detail::optional_if_t<
+ SourceDataset::is_stateful,
+ typename AppliedTransform::OutputBatchType>,
+ typename SourceDataset::BatchRequestType> {
+ public:
using DatasetType = SourceDataset;
using TransformType = AppliedTransform;
using BatchRequestType = typename SourceDataset::BatchRequestType;
- using OutputBatchType = typename TransformType::OutputBatchType;
+ using OutputBatchType = detail::optional_if_t<
+ SourceDataset::is_stateful,
+ typename AppliedTransform::OutputBatchType>;
MapDataset(DatasetType dataset, TransformType transform)
- : dataset(std::move(dataset)), transform(std::move(transform)) {}
+ : dataset_(std::move(dataset)), transform_(std::move(transform)) {}
/// Gets a batch from the source dataset and applies the transform to it,
/// returning the result.
OutputBatchType get_batch(BatchRequestType indices) override {
- return transform.apply_batch(dataset.get_batch(indices));
+ return get_batch_impl(std::move(indices));
}
/// Returns the size of the source dataset.
- optional<size_t> size() const noexcept {
- return dataset.size();
+ optional<size_t> size() const noexcept override {
+ return dataset_.size();
}
- SourceDataset dataset;
- AppliedTransform transform;
+ /// Calls `reset()` on the underlying dataset.
+ /// NOTE: Stateless datasets do not have a reset() method, so a call to this
+ /// method will only compile for stateful datasets (which have a reset()
+ /// method).
+ void reset() {
+ dataset_.reset();
+ }
+
+ /// Returns the underlying dataset.
+ const SourceDataset& dataset() noexcept {
+ return dataset_;
+ }
+
+ /// Returns the transform being applied.
+ const AppliedTransform& transform() noexcept {
+ return transform_;
+ }
+
+ private:
+ /// The implementation of `get_batch()` for the stateless case, which simply
+ /// applies the transform to the output of `get_batch()` from the dataset.
+ template <
+ typename D = SourceDataset,
+ typename = torch::disable_if_t<D::is_stateful>>
+ OutputBatchType get_batch_impl(BatchRequestType indices) {
+ return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
+ }
+
+ /// The implementation of `get_batch()` for the stateful case. Here, we follow
+ /// the semantics of `Optional.map()` in many functional languages, which
+ /// applies a transformation to the optional's content when the optional
+ /// contains a value, and returns a new optional (of a different type) if the
+ /// original optional returned by `get_batch()` was empty.
+ template <typename D = SourceDataset>
+ torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
+ BatchRequestType indices) {
+ if (auto batch = dataset_.get_batch(std::move(indices))) {
+ return transform_.apply_batch(std::move(*batch));
+ }
+ return nullopt;
+ }
+
+ /// The underlying dataset being transformed.
+ SourceDataset dataset_;
+
+  /// The transformation that is applied to batches received from the dataset.
+ AppliedTransform transform_;
};
/// Creates a `MapDataset` with the given dataset and transform.
TransformType transform) {
static_assert(
std::is_same<
- typename DatasetType::BatchType,
+ typename std::conditional<
+ DatasetType::is_stateful,
+ typename DatasetType::BatchType::value_type,
+ typename DatasetType::BatchType>::type,
typename TransformType::InputBatchType>::value,
"BatchType type of dataset does not match input type of transform");
return {std::move(dataset), std::move(transform)};
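The effect of `optional_if_t` is checkable at compile time: mapping a
stateful source boxes the transform's output in an `optional`, so exhaustion
(`nullopt`) propagates through the transform untouched. A hedged sketch (the
`S` dataset is illustrative only; assumes <type_traits> is included):

struct S : torch::data::datasets::StatefulDataset<S> {
  torch::optional<std::vector<torch::data::Example<>>> get_batch(
      size_t) override {
    return torch::nullopt;
  }
  torch::optional<size_t> size() const override {
    return 0;
  }
  void reset() override {}
};

using Mapped = decltype(S().map(
    torch::data::transforms::Stack<torch::data::Example<>>()));
static_assert(
    std::is_same<
        Mapped::OutputBatchType,
        torch::optional<torch::data::Example<>>>::value,
    "stateful MapDataset should yield optional batches");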
return dataset_.get();
}
+ /// Calls `reset()` on the underlying dataset.
+ void reset() {
+ dataset_->reset();
+ }
+
private:
std::shared_ptr<UnderlyingDataset> dataset_;
};
--- /dev/null
+#pragma once
+
+#include <torch/data/datasets/base.h>
+#include <torch/data/example.h>
+
+#include <cstddef>
+#include <vector>
+
+namespace torch {
+namespace data {
+namespace datasets {
+
+/// A stateful dataset is a dataset that maintains some internal state, which
+/// will be `reset()` at the beginning of each epoch. Subclasses can override
+/// the `reset()` method to configure this behavior. Further, the return type of
+/// a stateful dataset's `get_batch()` method is always an `optional`. When the
+/// stateful dataset wants to indicate to the dataloader that its epoch has
+/// ended, it should return an empty optional. The dataloader knows to modify
+/// its implementation based on whether the dataset is stateless or stateful.
+///
+/// Note that when subclassing from `StatefulDataset<Self, T>`, the return
+/// type of `get_batch()`, which the subclass must override, will be
+/// `optional<T>` (i.e. the `T` specified in the `StatefulDataset`
+/// specialization is automatically boxed into an `optional` for the
+/// dataset's `BatchType`).
+template <
+ typename Self,
+ typename Batch = std::vector<Example<>>,
+ typename BatchRequest = size_t>
+class StatefulDataset
+ : public BatchDataset<Self, optional<Batch>, BatchRequest> {
+ public:
+ /// Resets internal state of the dataset.
+ virtual void reset() = 0;
+};
+} // namespace datasets
+} // namespace data
+} // namespace torch
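For example, the boxing described above means a subclass over a custom batch
type `T` overrides the already-boxed signature, exactly as the tests earlier
in this diff do (sketch):

// With StatefulDataset<MyDataset, int, size_t>, BatchType is optional<int>:
struct MyDataset
    : torch::data::datasets::StatefulDataset<MyDataset, int, size_t> {
  torch::optional<int> get_batch(size_t batch_size) override {
    return torch::nullopt; // an empty optional means the epoch has ended
  }
  torch::optional<size_t> size() const override {
    return 0;
  }
  void reset() override {}
};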