From: Xiang Gao
Date: Wed, 10 Apr 2019 14:33:15 +0000 (-0700)
Subject: Add torch.unique_consecutive (#19060)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~292
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ea2405c7dc4d69477d41d4347b6b90af9a42e7ee;p=platform%2Fupstream%2Fpytorch.git

Add torch.unique_consecutive (#19060)

Summary:
Fixes: https://github.com/pytorch/pytorch/issues/19045

Please review: VitalyFedyunin ngimel

This is independent of the #18649 series. It will cause merge conflicts in the #18649 series, but please merge this first and I will resolve the conflicts there.

The new feature is exposed in `_unique2_temporary_will_remove_soon` and `_unique_dim2_temporary_will_remove_soon`, but not in `torch.unique` yet. I will take care of the public API after the #18649 series is merged completely.
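For reference, a minimal illustration of the new semantics (the example values come from the docstring added in `torch/functional.py` below): unlike `torch.unique`, which deduplicates globally, only adjacent duplicates are collapsed.

```python
import torch

x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
print(torch.unique(x))              # tensor([1, 2, 3])        -- global deduplication
print(torch.unique_consecutive(x))  # tensor([1, 2, 3, 1, 2])  -- only adjacent duplicates removed
```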
Benchmark on a tensor of shape `torch.Size([15320, 2])`:

```python
print(torch.__version__)
print()
a = tensor.sort().values.to('cpu')
print('cpu, sorted_input=False:')
%timeit torch._unique2_temporary_will_remove_soon(a)
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True)
%timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True)
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True)
print()
print('cpu, sorted_input=True:')
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True)
print()
a = a.to('cuda')
print('cuda, sorted_input=False:')
%timeit torch._unique2_temporary_will_remove_soon(a); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True); torch.cuda.synchronize()
print()
print('cuda, sorted_input=True:')
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```

```
1.1.0a0+2addccc

cpu, sorted_input=False:
340 µs ± 5.88 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
717 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
52.3 ms ± 2.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
52.3 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

cpu, sorted_input=True:
32.8 µs ± 285 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
49.9 µs ± 557 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
51.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
78 µs ± 782 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cuda, sorted_input=False:
213 µs ± 1.52 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
291 µs ± 3.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
250 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
321 µs ± 1.59 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

cuda, sorted_input=True:
45.6 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
110 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
82 µs ± 857 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
143 µs ± 409 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

```python
print(torch.__version__)
print()
a1, a2 = tensor.unbind(1)
indices = (a1 * tensor.max() + a2).sort().indices
a = tensor.index_select(0, indices).to('cpu')
print('cpu, sorted_input=False:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True)
print()
print('cpu, sorted_input=True:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True, return_counts=True)
print()
a = a.to('cuda')
print('cuda, sorted_input=False:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True); torch.cuda.synchronize()
print()
print('cuda, sorted_input=True:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```

```
cpu, sorted_input=False:
55.4 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.8 ms ± 616 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.2 ms ± 402 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.1 ms ± 725 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

cpu, sorted_input=True:
54.7 ms ± 585 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.2 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.5 ms ± 865 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.9 ms ± 577 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

cuda, sorted_input=False:
171 µs ± 783 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
220 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
203 µs ± 2.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
251 µs ± 2.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

cuda, sorted_input=True:
59.6 µs ± 757 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
113 µs ± 431 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
93.2 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
147 µs ± 2.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
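Reading the numbers above: on this input the pre-sorted (`sorted_input=True`) path is roughly an order of magnitude faster than the full unique on CPU (340 µs vs. 32.8 µs) and several times faster on CUDA (213 µs vs. 45.6 µs). The `%timeit` lines are IPython magics; a rough standard-library equivalent, shown here against the public `torch.unique_consecutive` added by this PR rather than the temporary ops, and with a hypothetical stand-in for the input (the original `tensor` is not shown in the summary), could look like this sketch:

```python
import timeit
import torch

# Hypothetical stand-in input; the original `tensor` used above is not shown.
a = torch.randint(0, 10, (15320, 2)).flatten().sort().values

def bench(fn, number=1000, repeat=7):
    # Report the best per-call time over several repeats, similar in spirit to %timeit.
    best = min(timeit.repeat(fn, number=number, repeat=repeat)) / number
    print('{:.1f} us per loop (best of {} runs, {} loops each)'.format(best * 1e6, repeat, number))

bench(lambda: torch.unique_consecutive(a))
bench(lambda: torch.unique_consecutive(a, return_inverse=True, return_counts=True))
```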
The CPU implementation of `unique_dim` is very slow (see https://github.com/pytorch/pytorch/issues/18987), but this PR does not attempt to address that issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19060

Differential Revision: D14866909

Pulled By: ezyang

fbshipit-source-id: d20012cec68c37b05cf770a6f4d6524f910b950f
---
diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index 1dcf85d..2888097 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -14,16 +14,21 @@ namespace native{
 namespace {

 template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
+std::tuple<Tensor, Tensor, Tensor> unique_cpu_template(
     const Tensor& self,
     const bool sorted,
     const bool return_inverse,
     const bool return_counts) {
   const Tensor& input = self.contiguous();
   const scalar_t* input_data = input.data<scalar_t>();
-  std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
-  Tensor output = at::empty({static_cast<int64_t>(set.size())}, input.options());
-  scalar_t* output_data = output.data<scalar_t>();
+  int64_t numel = input.numel();
+  Tensor output;
+  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
+  Tensor counts = at::empty({0}, self.options().dtype(kLong));
+
+  std::unordered_set<scalar_t> set(input_data, input_data + numel);
+  output = at::empty({static_cast<int64_t>(set.size())}, input.options());
+  scalar_t *output_data = output.data<scalar_t>();

   if (sorted) {
     std::vector<scalar_t> vec(set.begin(), set.end());
@@ -33,8 +38,6 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
     std::copy(set.begin(), set.end(), output_data);
   }

-  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
-  Tensor counts = at::empty({0}, self.options().dtype(kLong));
   if (return_inverse || return_counts) {
     inverse_indices.resize_(input.sizes());
     int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
@@ -43,13 +46,13 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
     for (int i = 0; i < output.numel(); ++i) {
       inverse_map[output_data[i]] = i;
     }
-    for (int i = 0; i < input.numel(); ++i) {
+    for (int i = 0; i < numel; ++i) {
       inverse_indices_data[i] = inverse_map[input_data[i]];
     }
     if (return_counts) {
       counts.resize_(output.sizes());
       counts.fill_(0);
-      for (int i = 0; i < input.numel(); ++i) {
+      for (int i = 0; i < numel; ++i) {
         counts[inverse_map[input_data[i]]] += 1;
       }
     }
@@ -57,6 +60,57 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
   return std::make_tuple(output, inverse_indices, counts);
 }

+template <typename scalar_t>
+std::tuple<Tensor, Tensor, Tensor> unique_consecutive_cpu_template(
+    const Tensor& self,
+    const bool return_inverse,
+    const bool return_counts) {
+  const Tensor& input = self.contiguous();
+  const scalar_t* input_data = input.data<scalar_t>();
+  int64_t numel = input.numel();
+  Tensor output = at::empty({numel}, input.options());
+  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
+  Tensor counts = at::empty({0}, self.options().dtype(kLong));
+
+  scalar_t *output_data = output.data<scalar_t>();
+  int64_t *inverse_data = nullptr;
+  int64_t *counts_data = nullptr;
+  if (numel > 0) {
+    *output_data = *input_data;
+  }
+  if (return_inverse) {
+    inverse_indices.resize_(input.sizes());
+    inverse_data = inverse_indices.data<int64_t>();
+  }
+  if (return_counts) {
+    counts.resize_(input.sizes());
+    counts_data = counts.data<int64_t>();
+  }
+  scalar_t *p = output_data;
+  int64_t *q = counts_data;
+  int64_t last = 0;
+  for (int64_t i = 0; i < numel; i++) {
+    if (input_data[i] != *p) {
+      *(++p) = input_data[i];
+      if (return_counts) {
+        *(q++) = i - last;
+        last = i;
+      }
+    }
+    if (return_inverse) {
+      inverse_data[i] = p - output_data;
+    }
+  }
+  int64_t output_size = p - output_data + 1;
+  if (return_counts && numel > 0) {
+    *q = numel - last;
+    counts.resize_({output_size});
+  }
+  output.resize_({output_size});
+
+  return std::make_tuple(output, inverse_indices, counts);
+}
+
 template <class ForwardIt>
 ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
   std::vector<int64_t>& indices, Tensor inverse_indices_vec, Tensor counts) {
@@ -88,6 +142,7 @@ template <typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
     const Tensor& self,
     const int64_t dim,
+    const bool consecutive,
     const bool return_inverse,
     const bool return_counts) {
   // reshape tensor as [dim, -1]
@@ -101,23 +156,30 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
   scalar_t* input_flat_ptr = ((scalar_t*)input_flat.data_ptr());

   // sort indices using data
-  std::sort(indices.begin(), indices.end(),
-    [&](int64_t a, int64_t b) -> bool {
-      for (int64_t i = 0; i < numel; ++i) {
-        scalar_t lhs = input_flat_ptr[i + a * numel];
-        scalar_t rhs = input_flat_ptr[i + b * numel];
-        if (lhs < rhs) {
-          return true;
-        } else if (lhs > rhs) {
-          return false;
+  if (!consecutive) {
+    std::sort(indices.begin(), indices.end(),
+      [&](int64_t a, int64_t b) -> bool {
+        for (int64_t i = 0; i < numel; ++i) {
+          scalar_t lhs = input_flat_ptr[i + a * numel];
+          scalar_t rhs = input_flat_ptr[i + b * numel];
+          if (lhs < rhs) {
+            return true;
+          } else if (lhs > rhs) {
+            return false;
+          }
         }
-      }
-      return false;
-    });
+        return false;
+      });
+  }

-  Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.options());
-  for (int i = 0; i < indices.size(); ++i) {
-    input_sorted[i] = input_flat[indices[i]];
+  Tensor input_sorted;
+  if (!consecutive) {
+    input_sorted = at::empty(input_flat.sizes(), input_flat.options());
+    for (int i = 0; i < indices.size(); ++i) {
+      input_sorted[i] = input_flat[indices[i]];
+    }
+  } else {
+    input_sorted = input_flat;
   }

   Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
@@ -137,6 +199,7 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
   return std::make_tuple(output, inverse_indices, counts);
 }

+
 } // namespace

@@ -144,7 +207,7 @@ std::tuple<Tensor, Tensor>
 _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
     Tensor output, inverse;
-    std::tie(output, inverse, std::ignore) = _unique_cpu_template<scalar_t>(self, sorted, return_inverse, false);
+    std::tie(output, inverse, std::ignore) = unique_cpu_template<scalar_t>(self, sorted, return_inverse, false);
     return std::make_tuple(output, inverse);
   });
 }
@@ -152,7 +215,7 @@ _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
 std::tuple<Tensor, Tensor, Tensor>
 _unique2_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
-    return _unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
+    return unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
   });
 }

@@ -161,7 +224,7 @@ _unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     // The current implementation using `dim` always sorts due to unhashable tensors
     Tensor output, inverse;
-    std::tie(output, inverse, std::ignore) = _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, false);
+    std::tie(output, inverse, std::ignore) = _unique_dim_cpu_template<scalar_t>(self, dim, false, return_inverse, false);
     return std::make_tuple(output, inverse);
   });
 }
@@ -170,9 +233,26 @@ std::tuple<Tensor, Tensor, Tensor>
 _unique_dim2_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     // The current implementation using `dim` always sorts due to unhashable tensors
-    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, return_counts);
+    return _unique_dim_cpu_template<scalar_t>(self, dim, false, return_inverse, return_counts);
   });
 }

+std::tuple<Tensor, Tensor, Tensor>
+unique_dim_consecutive_cpu(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
+    return _unique_dim_cpu_template<scalar_t>(self, dim, true, return_inverse, return_counts);
+  });
+}
+
+std::tuple<Tensor, Tensor, Tensor>
+unique_consecutive_cpu(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
+  if (!dim.has_value()) {
+    return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+      return unique_consecutive_cpu_template<scalar_t>(self, return_inverse, return_counts);
+    });
+  }
+  return unique_dim_consecutive_cpu(self, dim.value(), return_inverse, return_counts);
+}
+
 }  // namespace native
 }  // namespace at
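As a reading aid rather than part of the patch, the single-pass algorithm in `unique_consecutive_cpu_template` can be sketched in plain Python: a cursor advances through the output whenever the value changes, each element's inverse index is the current cursor position, and counts are the distances between change points.

```python
def unique_consecutive_reference(values, return_inverse=False, return_counts=False):
    """Plain-Python mirror of the single-pass loop in unique_consecutive_cpu_template."""
    output, inverse, counts = [], [], []
    last = 0  # start index of the current run
    for i, v in enumerate(values):
        if i == 0 or v != output[-1]:
            if i > 0:
                counts.append(i - last)  # close the previous run
                last = i
            output.append(v)             # open a new run
        inverse.append(len(output) - 1)  # every element maps to the current run
    if values:
        counts.append(len(values) - last)  # close the final run
    result = [output]
    if return_inverse:
        result.append(inverse)
    if return_counts:
        result.append(counts)
    return tuple(result) if len(result) > 1 else output

print(unique_consecutive_reference([1, 1, 2, 2, 3, 1, 1, 2],
                                   return_inverse=True, return_counts=True))
# ([1, 2, 3, 1, 2], [0, 0, 1, 1, 2, 3, 3, 4], [2, 2, 1, 2, 1])
```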
diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu
index e4945bb..734fa66 100644
--- a/aten/src/ATen/native/cuda/Unique.cu
+++ b/aten/src/ATen/native/cuda/Unique.cu
@@ -73,6 +73,7 @@ std::tuple<Tensor, Tensor, int64_t> compute_unique(
 template <typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
     const Tensor& self,
+    const bool consecutive,
     const bool return_inverse,
     const bool return_counts
 ) {
@@ -88,11 +89,15 @@ std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(

   Tensor sorted_indices;
   if (!return_inverse) {
-    thrust::sort(policy, output_data, output_data + num_inp);
+    if (!consecutive) {
+      thrust::sort(policy, output_data, output_data + num_inp);
+    }
   } else {
     sorted_indices = at::arange(0, num_inp, options);
-    int64_t *sorted_indices_ptr = sorted_indices.data<int64_t>();
-    thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
+    if (!consecutive) {
+      int64_t *sorted_indices_ptr = sorted_indices.data<int64_t>();
+      thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
+    }
   }

   Tensor inverse_indices, counts;
@@ -116,6 +121,7 @@ template <typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor> unique_dim_cuda_template(
     const Tensor& self,
     const int64_t dim,
+    const bool consecutive,
     const bool return_inverse,
     const bool return_counts
 ) {
@@ -141,20 +147,22 @@ std::tuple<Tensor, Tensor, Tensor> unique_dim_cuda_template(

   Tensor indices = at::arange(0, num_inp, options);
   int64_t *indices_data = indices.data<int64_t>();
-  thrust::sort(policy, indices_data, indices_data + num_inp,
-    [=] __device__ (int64_t a, int64_t b) -> bool {
-      for (int64_t i = 0; i < n; ++i) {
-        scalar_t lhs = input_flat_ptr[i + a * n];
-        scalar_t rhs = input_flat_ptr[i + b * n];
-        if (lhs < rhs) {
-          return true;
-        } else if (lhs > rhs) {
-          return false;
+  if (!consecutive) {
+    thrust::sort(policy, indices_data, indices_data + num_inp,
+      [=] __device__ (int64_t a, int64_t b) -> bool {
+        for (int64_t i = 0; i < n; ++i) {
+          scalar_t lhs = input_flat_ptr[i + a * n];
+          scalar_t rhs = input_flat_ptr[i + b * n];
+          if (lhs < rhs) {
+            return true;
+          } else if (lhs > rhs) {
+            return false;
+          }
         }
+        return false;
       }
-      return false;
-    }
-  );
+    );
+  }

   Tensor inverse_indices, counts;
   int64_t num_out;
@@ -196,7 +204,7 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
     Tensor output, inverse;
-    std::tie(output, inverse, std::ignore) = unique_cuda_template<scalar_t>(self, return_inverse, false);
+    std::tie(output, inverse, std::ignore) = unique_cuda_template<scalar_t>(self, false, return_inverse, false);
     return std::make_tuple(output, inverse);
   });
 }
@@ -206,7 +214,7 @@ _unique2_cuda(const Tensor& self, const bool sorted, const bool return_inverse,
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
-    return unique_cuda_template<scalar_t>(self, return_inverse, return_counts);
+    return unique_cuda_template<scalar_t>(self, false, return_inverse, return_counts);
   });
 }

@@ -214,7 +222,7 @@ std::tuple<Tensor, Tensor>
 _unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     Tensor output, inverse;
-    std::tie(output, inverse, std::ignore) = unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, false);
+    std::tie(output, inverse, std::ignore) = unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, false);
     return std::make_tuple(output, inverse);
   });
 }
@@ -222,9 +230,28 @@ _unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const
 std::tuple<Tensor, Tensor, Tensor>
 _unique_dim2_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
-    return unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, return_counts);
+    return unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, return_counts);
   });
 }

+std::tuple<Tensor, Tensor, Tensor>
+unique_dim_consecutive_cuda(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
+    return unique_dim_cuda_template<scalar_t>(self, dim, true, return_inverse, return_counts);
+  });
+}
+
+std::tuple<Tensor, Tensor, Tensor>
+unique_consecutive_cuda(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
+  if (!dim.has_value()) {
+    return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+      // The current CUDA implementation of unique always sort due to the
+      // lack of hashtable implementation in thrust
+      return unique_cuda_template<scalar_t>(self, true, return_inverse, return_counts);
+    });
+  }
+  return unique_dim_consecutive_cuda(self, dim.value(), return_inverse, return_counts);
+}
+
 }  // namespace native
 }  // namespace at
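The CUDA path reuses the existing uniquing kernels and, when `consecutive` is true, simply skips the Thrust sort, so runs are detected in whatever order the data already has. As a reading aid (a sketch for illustration, not the code in this patch), the same computation can be expressed with vectorized tensor ops: mark change points, cumulatively sum them for the inverse mapping, and difference the change indices for counts.

```python
import torch

def unique_consecutive_sketch(x):
    """Sort-free consecutive unique for a 1-D tensor, expressed with tensor ops."""
    if x.numel() == 0:
        empty = torch.empty(0, dtype=torch.long, device=x.device)
        return x, empty, empty
    # A position starts a new run if it differs from its predecessor.
    change = torch.ones(x.numel(), dtype=torch.bool, device=x.device)
    change[1:] = x[1:] != x[:-1]
    output = x[change]                          # first element of every run
    inverse = change.long().cumsum(0) - 1       # run index of every position
    starts = torch.nonzero(change).flatten()    # where each run starts
    ends = torch.cat([starts[1:], torch.tensor([x.numel()], device=x.device)])
    counts = ends - starts                      # run lengths
    return output, inverse, counts

x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
print(unique_consecutive_sketch(x))
# (tensor([1, 2, 3, 1, 2]), tensor([0, 0, 1, 1, 2, 3, 3, 4]), tensor([2, 2, 1, 2, 1]))
```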
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 1499071..852add5 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2428,6 +2428,20 @@
     CPU: _unique_dim_cpu
     CUDA: _unique_dim_cuda

+- func: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: unique_consecutive_cpu
+    CUDA: unique_consecutive_cuda
+
+- func: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: unique_dim_consecutive_cpu
+    CUDA: unique_dim_consecutive_cuda
+
 # _unique and _unique_dim are fragile and modifying them easily cause internal break
 # below two operators are a temporary hack for adding return_counts support
 # Please don't rely on these two operators, they will be removed soon
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index af10e8e..0af3c52 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -451,6 +451,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: unfold
    .. automethod:: uniform_
    .. automethod:: unique
+   .. automethod:: unique_consecutive
    .. automethod:: unsqueeze
    .. automethod:: unsqueeze_
    .. automethod:: values
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 039fd0f..67f1955 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -225,6 +225,7 @@ Reduction Ops
 .. autofunction:: std
 .. autofunction:: sum
 .. autofunction:: unique
+.. autofunction:: unique_consecutive
 .. autofunction:: var

diff --git a/test/test_torch.py b/test/test_torch.py
index 2a4c576..5adce86 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -10611,6 +10611,28 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertEqual(torch.tensor([3, 0, 0, 0, 1, 2], dtype=torch.long, device=device), byte_inverse)
             self.assertEqual(torch.tensor([3, 1, 1, 1], dtype=torch.long, device=device), byte_counts)

+            # test consecutive version
+            z = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], device=device)
+            expected_z_unique = torch.tensor([1, 2, 5, 2, 3], device=device)
+            expected_z_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device)
+            expected_z_counts = torch.tensor([1, 3, 2, 2, 1], device=device)
+
+            z_unique = torch.unique_consecutive(z)
+            self.assertEqual(z_unique, expected_z_unique)
+
+            z_unique, z_inverse = torch.unique_consecutive(z, return_inverse=True)
+            self.assertEqual(z_unique, expected_z_unique)
+            self.assertEqual(z_inverse, expected_z_inverse)
+
+            z_unique, z_counts = torch.unique_consecutive(z, return_counts=True)
+            self.assertEqual(z_unique, expected_z_unique)
+            self.assertEqual(z_counts, expected_z_counts)
+
+            z_unique, z_inverse, z_counts = torch.unique_consecutive(z, return_inverse=True, return_counts=True)
+            self.assertEqual(z_unique, expected_z_unique)
+            self.assertEqual(z_inverse, expected_z_inverse)
+            self.assertEqual(z_counts, expected_z_counts)
+
         run_test(torch.device('cpu'))
         if torch.cuda.is_available():
             run_test(torch.device('cuda'))
@@ -10742,6 +10764,37 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertEqual(expected_inverse_dim2, x_inverse)
             self.assertEqual(expected_counts_dim2, x_counts)

+            # test consecutive version
+            y = torch.tensor(
+                [[0, 1],
+                 [0, 1],
+                 [0, 1],
+                 [1, 2],
+                 [1, 2],
+                 [3, 4],
+                 [0, 1],
+                 [0, 1],
+                 [3, 4],
+                 [1, 2]],
+                dtype=dtype,
+                device=device
+            )
+            expected_y_unique = torch.tensor(
+                [[0, 1],
+                 [1, 2],
+                 [3, 4],
+                 [0, 1],
+                 [3, 4],
+                 [1, 2]],
+                dtype=dtype,
+                device=device
+            )
+            expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=dtype, device=device)
+            expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=dtype, device=device)
+            y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0)
+            self.assertEqual(expected_y_inverse, y_inverse)
+            self.assertEqual(expected_y_counts, y_counts)
+
         run_test(torch.float)
         run_test(torch.double)
         run_test(torch.long)
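The `dim=0` expectations above can be cross-checked by hand: `unique_consecutive(..., dim=0)` collapses runs of identical consecutive rows. A small sketch using the same `y` as the test (device/dtype handling omitted):

```python
import torch

y = torch.tensor([[0, 1], [0, 1], [0, 1], [1, 2], [1, 2],
                  [3, 4], [0, 1], [0, 1], [3, 4], [1, 2]])

# Collapse runs of identical consecutive rows, as unique_consecutive(..., dim=0) does.
rows, counts = [], []
for row in y.tolist():
    if not rows or row != rows[-1]:
        rows.append(row)
        counts.append(1)
    else:
        counts[-1] += 1

print(rows)    # [[0, 1], [1, 2], [3, 4], [0, 1], [3, 4], [1, 2]]
print(counts)  # [3, 2, 1, 2, 1, 1]
```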
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index e147010..2e7fe79 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -22,7 +22,7 @@ SKIP_PYTHON_BINDINGS = [
     '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
     '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
     '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
-    'index',
+    'index', 'unique_dim_consecutive',
     '_indexCopy_', 'max_values', 'min_values',
     '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
     '_th_.*', '_thnn_.*',
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index cd95a25..0d5ad32 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -72,6 +72,7 @@ blacklist = [
     'tensordot',
     'norm',
     'split',
+    'unique_consecutive',
     # These are handled specially by python_arg_parser.cpp
     'add',
     'add_',
diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in
index a521c45..9a748c9 100644
--- a/torch/__init__.pyi.in
+++ b/torch/__init__.pyi.in
@@ -87,6 +87,7 @@ class Tensor:
              center=True, pad_mode='reflect', normalized=False, onesided=True): ...
     def split(self, split_size, dim=0): ...
     def unique(self, sorted=True, return_inverse=False, dim=None): ...
+    def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ...
     def lu(self, pivot=True, get_infos=False): ...
     ${function_hints}

diff --git a/torch/functional.py b/torch/functional.py
index 4a13658..fd0e47e 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -28,6 +28,7 @@ __all__ = [
     'tensordot',
     'trtrs',
     'unique',
+    'unique_consecutive',
 ]


@@ -449,6 +450,67 @@ def unique(input, sorted=True, return_inverse=False, dim=None):
     return output


+def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None):
+    r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+    .. note:: This function is different from :func:`torch.unique` in the sense that this function
+        only eliminates consecutive duplicate values. This semantics is similar to `std::unique`
+        in C++.
+
+    Arguments:
+        input (Tensor): the input tensor
+        return_inverse (bool): Whether to also return the indices for where
+            elements in the original input ended up in the returned unique list.
+        return_counts (bool): Whether to also return the counts for each unique
+            element.
+        dim (int): the dimension to apply unique. If ``None``, the unique of the
+            flattened input is returned. default: ``None``
+
+    Returns:
+        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
+
+            - **output** (*Tensor*): the output list of unique scalar elements.
+            - **inverse_indices** (*Tensor*): (optional) if
+              :attr:`return_inverse` is True, there will be an additional
+              returned tensor (same shape as input) representing the indices
+              for where elements in the original input map to in the output;
+              otherwise, this function will only return a single tensor.
+            - **counts** (*Tensor*): (optional) if
+              :attr:`return_counts` is True, there will be an additional
+              returned tensor (same shape as output or output.size(dim),
+              if dim was specified) representing the number of occurrences
+              for each unique value or tensor.
+
+    Example::
+
+        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
+        >>> output = torch.unique_consecutive(x)
+        >>> output
+        tensor([1, 2, 3, 1, 2])
+
+        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
+        >>> output
+        tensor([1, 2, 3, 1, 2])
+        >>> inverse_indices
+        tensor([0, 0, 1, 1, 2, 3, 3, 4])
+
+        >>> output, counts = torch.unique_consecutive(x, return_counts=True)
+        >>> output
+        tensor([1, 2, 3, 1, 2])
+        >>> counts
+        tensor([2, 2, 1, 2, 1])
+    """
+    output, inverse_indices, counts = torch._C._VariableFunctions.unique_consecutive(
+        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+    if return_inverse and return_counts:
+        return output, inverse_indices, counts
+    if return_inverse:
+        return output, inverse_indices
+    if return_counts:
+        return output, counts
+    return output
+
+
 def tensordot(a, b, dims=2):
     r"""Returns a contraction of a and b over multiple dimensions.
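The wrapper above always receives the full `(output, inverse_indices, counts)` triple from the native op and trims it according to the flags, so the caller sees a plain tensor, a pair, or a triple. For example:

```python
import torch

x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])

# The native op always produces (output, inverse_indices, counts);
# the Python wrapper returns only what the caller asked for.
out = torch.unique_consecutive(x)                                       # Tensor
out, inv = torch.unique_consecutive(x, return_inverse=True)             # 2-tuple
out, cnt = torch.unique_consecutive(x, return_counts=True)              # 2-tuple
out, inv, cnt = torch.unique_consecutive(x, return_inverse=True,
                                          return_counts=True)           # 3-tuple
```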
diff --git a/torch/tensor.py b/torch/tensor.py
index a788c50..1f98a61 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -368,6 +368,13 @@ class Tensor(torch._C._TensorBase):
         else:
             return output

+    def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
+        r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+        See :func:`torch.unique_consecutive`
+        """
+        return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+
     def __rsub__(self, other):
         return _C._VariableFunctions.rsub(self, other)
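With the method added above, the same functionality is available directly on tensors; it simply forwards to `torch.unique_consecutive`. For example:

```python
import torch

x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
values, counts = x.unique_consecutive(return_counts=True)
print(values)  # tensor([1, 2, 3, 1, 2])
print(counts)  # tensor([2, 2, 1, 2, 1])
```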