From 44cb43bcc168a559040beb20b266bc2806d30385 Mon Sep 17 00:00:00 2001 From: Jaliya Ekanayake Date: Thu, 29 Nov 2018 07:04:52 -0800 Subject: [PATCH] Jaliyae/samplers (#13870) Summary: Make Samplers optionally accept new size in their reset() method. This helps dataloader or dataset to reset the sampler for an epoch or a chunk of data with different sizes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/13870 Differential Revision: D13240120 Pulled By: soumith fbshipit-source-id: 19c53f8be13c0fdcf504f0637b0d3e6009a8e599 --- test/cpp/api/dataloader.cpp | 39 +++++++++++++++++++++- torch/csrc/api/include/torch/data/samplers/base.h | 6 ++-- .../csrc/api/include/torch/data/samplers/random.h | 8 +++-- .../api/include/torch/data/samplers/sequential.h | 4 +-- .../csrc/api/include/torch/data/samplers/stream.h | 4 +-- torch/csrc/api/src/data/samplers/random.cpp | 5 +-- torch/csrc/api/src/data/samplers/sequential.cpp | 5 ++- torch/csrc/api/src/data/samplers/stream.cpp | 5 ++- 8 files changed, 62 insertions(+), 14 deletions(-) diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp index 56cee13..0d8d546 100644 --- a/test/cpp/api/dataloader.cpp +++ b/test/cpp/api/dataloader.cpp @@ -207,6 +207,19 @@ TEST(DataTest, SequentialSamplerResetsWell) { ASSERT_FALSE(sampler.next(2).has_value()); } +TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) { + samplers::SequentialSampler sampler(5); + ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4})); + ASSERT_FALSE(sampler.next(2).has_value()); + sampler.reset(7); + ASSERT_EQ( + sampler.next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6})); + ASSERT_FALSE(sampler.next(2).has_value()); + sampler.reset(3); + ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2})); + ASSERT_FALSE(sampler.next(2).has_value()); +} + TEST(DataTest, CanSaveAndLoadSequentialSampler) { { samplers::SequentialSampler a(10); @@ -272,6 +285,18 @@ TEST(DataTest, RandomSamplerResetsWell) { 
ASSERT_FALSE(sampler.next(2).has_value()); } +TEST(DataTest, RandomSamplerResetsWithNewSizeWell) { + samplers::RandomSampler sampler(5); + ASSERT_EQ(sampler.next(5).value().size(), 5); + ASSERT_FALSE(sampler.next(2).has_value()); + sampler.reset(7); + ASSERT_EQ(sampler.next(7).value().size(), 7); + ASSERT_FALSE(sampler.next(2).has_value()); + sampler.reset(3); + ASSERT_EQ(sampler.next(3).value().size(), 3); + ASSERT_FALSE(sampler.next(2).has_value()); +} + TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) { { samplers::RandomSampler a(10); @@ -320,6 +345,18 @@ TEST(DataTest, StreamSamplerResetsWell) { ASSERT_FALSE(sampler.next(2).has_value()); } +TEST(DataTest, StreamSamplerResetsWithNewSizeWell) { + samplers::StreamSampler sampler(/*epoch_size=*/5); + ASSERT_EQ(sampler.next(5).value().size(), 5); + ASSERT_FALSE(sampler.next(2).has_value()); + sampler.reset(7); + ASSERT_EQ(sampler.next(7).value().size(), 7); + ASSERT_FALSE(sampler.next(2).has_value()); + sampler.reset(3); + ASSERT_EQ(sampler.next(3).value().size(), 3); + ASSERT_FALSE(sampler.next(2).has_value()); +} + TEST(DataTest, TensorDatasetConstructsFromSingleTensor) { datasets::TensorDataset dataset(torch::eye(5)); ASSERT_TRUE( @@ -618,7 +655,7 @@ struct TestIndexDataset struct TestIndexSampler : public samplers::Sampler { explicit TestIndexSampler(size_t size) : size_(size) {} - void reset() override {} + void reset(torch::optional<size_t> new_size = torch::nullopt) override {} torch::optional next(size_t batch_size) override { if (index_ >= size_) { return torch::nullopt; diff --git a/torch/csrc/api/include/torch/data/samplers/base.h b/torch/csrc/api/include/torch/data/samplers/base.h index 1767d65..ed2aa30 100644 --- a/torch/csrc/api/include/torch/data/samplers/base.h +++ b/torch/csrc/api/include/torch/data/samplers/base.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include #include @@ -27,7 +27,9 @@ class Sampler { /// Resets the `Sampler`'s internal state. 
/// Typically called before a new epoch. - TORCH_API virtual void reset() = 0; + + /// Optionally, accepts a new size when resetting the sampler. + TORCH_API virtual void reset(optional<size_t> new_size) = 0; /// Returns the next index if possible, or an empty optional if the /// sampler is exhausted for this epoch. diff --git a/torch/csrc/api/include/torch/data/samplers/random.h b/torch/csrc/api/include/torch/data/samplers/random.h index b18a36b..f0e6a86 100644 --- a/torch/csrc/api/include/torch/data/samplers/random.h +++ b/torch/csrc/api/include/torch/data/samplers/random.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include #include @@ -26,10 +26,12 @@ class RandomSampler : public Sampler<> { /// The constructor will eagerly allocate all required indices, which is the /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored /// indices. You can change it to influence memory usage. - TORCH_API explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64); + TORCH_API explicit RandomSampler( + int64_t size, + Dtype index_dtype = torch::kInt64); /// Resets the `RandomSampler` to a new set of indices. - TORCH_API void reset() override; + TORCH_API void reset(optional<size_t> new_size = nullopt) override; /// Returns the next batch of indices. TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override; diff --git a/torch/csrc/api/include/torch/data/samplers/sequential.h b/torch/csrc/api/include/torch/data/samplers/sequential.h index bd14d3b..5f83014 100644 --- a/torch/csrc/api/include/torch/data/samplers/sequential.h +++ b/torch/csrc/api/include/torch/data/samplers/sequential.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include #include @@ -26,7 +26,7 @@ class SequentialSampler : public Sampler<> { TORCH_API explicit SequentialSampler(size_t size); /// Resets the `SequentialSampler` to zero. 
- TORCH_API void reset() override; + TORCH_API void reset(optional<size_t> new_size = nullopt) override; /// Returns the next batch of indices. TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override; diff --git a/torch/csrc/api/include/torch/data/samplers/stream.h b/torch/csrc/api/include/torch/data/samplers/stream.h index fefc301..6f376ac 100644 --- a/torch/csrc/api/include/torch/data/samplers/stream.h +++ b/torch/csrc/api/include/torch/data/samplers/stream.h @@ -1,9 +1,9 @@ #pragma once +#include #include #include #include -#include #include @@ -39,7 +39,7 @@ class StreamSampler : public Sampler { TORCH_API explicit StreamSampler(size_t epoch_size); /// Resets the internal state of the sampler. - TORCH_API void reset() override; + TORCH_API void reset(optional<size_t> new_size = nullopt) override; /// Returns a `BatchSize` object with the number of elements to fetch in the /// next batch. This number is the minimum of the supplied `batch_size` and diff --git a/torch/csrc/api/src/data/samplers/random.cpp b/torch/csrc/api/src/data/samplers/random.cpp index 4ea975b..0edbc8c 100644 --- a/torch/csrc/api/src/data/samplers/random.cpp +++ b/torch/csrc/api/src/data/samplers/random.cpp @@ -12,10 +12,11 @@ namespace samplers { RandomSampler::RandomSampler(int64_t size, Dtype index_dtype) : indices_(torch::randperm(size, index_dtype)) {} -void RandomSampler::reset() { +void RandomSampler::reset(optional<size_t> new_size) { // This allocates a new chunk of memory every time (just FYI). It should be // amortized over the entire epoch hopefully. 
- indices_ = torch::randperm(indices_.numel(), indices_.options()); + const auto size = new_size.value_or(static_cast<size_t>(indices_.numel())); + indices_ = torch::randperm(size, indices_.options()); index_ = 0; } diff --git a/torch/csrc/api/src/data/samplers/sequential.cpp b/torch/csrc/api/src/data/samplers/sequential.cpp index 3072346..9c294cb 100644 --- a/torch/csrc/api/src/data/samplers/sequential.cpp +++ b/torch/csrc/api/src/data/samplers/sequential.cpp @@ -11,7 +11,10 @@ namespace data { namespace samplers { SequentialSampler::SequentialSampler(size_t size) : size_(size) {} -void SequentialSampler::reset() { +void SequentialSampler::reset(optional<size_t> new_size) { + if (new_size.has_value()) { + size_ = *new_size; + } index_ = 0; } diff --git a/torch/csrc/api/src/data/samplers/stream.cpp b/torch/csrc/api/src/data/samplers/stream.cpp index 2ac1755..6972846 100644 --- a/torch/csrc/api/src/data/samplers/stream.cpp +++ b/torch/csrc/api/src/data/samplers/stream.cpp @@ -20,7 +20,10 @@ BatchSize::operator size_t() const noexcept { StreamSampler::StreamSampler(size_t epoch_size) : epoch_size_(epoch_size) {} -void StreamSampler::reset() { +void StreamSampler::reset(optional<size_t> new_size) { + if (new_size.has_value()) { + epoch_size_ = *new_size; + } examples_retrieved_so_far_ = 0; } -- 2.7.4