From e2730ddb21164f1cb36330289ab2cd0629c06dc8 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Mon, 25 Mar 2019 20:30:33 -0700
Subject: [PATCH] Add return_counts to torch.unique (#18391)

Summary:
Fixes: https://github.com/pytorch/pytorch/issues/12598

This PR was originally authored by ptrblck at
https://github.com/pytorch/pytorch/pull/15495, but since there was no update
for months after the requested changes, I cloned that branch and resolved the
code reviews here. Hope everything is good now.

In particular, the implementation of counts is changed from ptrblck's original
algorithm to the one ngimel suggested, i.e. using `unique_by_key` and
`adjacent_difference` (see the sketch after the patch below).

The current implementation of `_unique_dim` is VERY slow for computing the
inverse index and counts, see https://github.com/pytorch/pytorch/issues/18405.
I will refactor `_unique_dim` in a later PR. For this PR, please allow me to
keep the implementation as is.

cc: ptrblck ezyang ngimel colesbury

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18391

Reviewed By: soumith

Differential Revision: D14605905

Pulled By: VitalyFedyunin

fbshipit-source-id: 555f5a12a8e28c38b10dfccf1b6bb16c030bfdce
---
 aten/src/ATen/native/Unique.cpp              |  49 ++++++++-----
 aten/src/ATen/native/cuda/Unique.cu          |  65 +++++++++++------
 aten/src/ATen/native/native_functions.yaml   |   4 +-
 aten/src/ATen/native/sparse/SparseTensor.cpp |   2 +-
 test/test_torch.py                           | 105 +++++++++++++++++++++++++--
 tools/autograd/derivatives.yaml              |   2 +-
 torch/functional.py                          |  32 +++++---
 torch/onnx/symbolic.py                       |   7 +-
 torch/tensor.py                              |  16 ++--
 9 files changed, 214 insertions(+), 68 deletions(-)

diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index 8cc867f..cd59f18 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -14,10 +14,11 @@ namespace native{
 namespace {

 template <typename scalar_t>
-std::tuple<Tensor, Tensor> _unique_cpu_template(
+std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
     const Tensor& self,
     const bool sorted,
-    const bool return_inverse) {
+    const bool return_inverse,
+    const bool return_counts) {
   const Tensor& input = self.contiguous();
   const scalar_t* input_data = input.data<scalar_t>();
   std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
@@ -33,7 +34,8 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
   }

   Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
-  if (return_inverse) {
+  Tensor counts = at::empty({0}, self.options().dtype(kLong));
+  if (return_inverse || return_counts) {
     inverse_indices.resize_(input.sizes());
     int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
     std::unordered_map<scalar_t, int64_t> inverse_map;
@@ -44,21 +46,29 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
     for (int i = 0; i < input.numel(); ++i) {
       inverse_indices_data[i] = inverse_map[input_data[i]];
     }
+    if (return_counts) {
+      counts.resize_(output.sizes());
+      counts.fill_(0);
+      for (int i = 0; i < input.numel(); ++i) {
+        counts[inverse_map[input_data[i]]] += 1;
+      }
+    }
   }
-  return std::make_tuple(output, inverse_indices);
+  return std::make_tuple(output, inverse_indices, counts);
 }

 template <class ForwardIt>
 ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
-  std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
+  std::vector<int64_t>& indices, Tensor inverse_indices_vec, Tensor counts) {
   if (first == last) {
     return last;
   }
   // save to calculate distance to iterators
   ForwardIt begin = first;

-  // set first inverse index
+  // set first inverse index and count
   inverse_indices_vec[indices[0]] = 0;
+  counts[0] += 1;

   ForwardIt result = first;
   while (++first != last) {
@@ -68,16 +78,18 @@ ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
     int64_t idx_result = std::distance(begin, result);
     int64_t idx_first = std::distance(begin, first);
     inverse_indices_vec[indices[idx_first]] = idx_result;
+    counts[idx_result] += 1;
   }
   return ++result;
 }

 template <typename scalar_t>
-std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
+std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
     const Tensor& self,
     const int64_t dim,
-    const bool return_inverse) {
+    const bool return_inverse,
+    const bool return_counts) {
   // reshape tensor as [dim, -1]
   Tensor input_flat = self.transpose(dim, 0);
   auto orig_sizes = input_flat.sizes().vec();
@@ -109,10 +121,12 @@ std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
   }

   Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
+  Tensor counts = at::zeros(indices.size(), self.options().dtype(kLong));
   std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
   auto last = _unique_dim_cpu_impl(
-    input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
+    input_unbind.begin(), input_unbind.end(), indices, inverse_indices, counts);
   input_unbind.erase(last, input_unbind.end());
+  counts = at::narrow(counts, 0, 0, input_unbind.size());

   // reshape back
   auto output = at::stack(input_unbind, 0);
@@ -121,22 +135,23 @@ std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
   output = output.view(new_sizes);
   output = output.transpose(0, dim);

-  return std::make_tuple(output, inverse_indices);
+  return std::make_tuple(output, inverse_indices, counts);
 }
 } // namespace

-std::tuple<Tensor, Tensor>
-_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
-  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] {
-    return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
+
+std::tuple<Tensor, Tensor, Tensor>
+_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+    return _unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
   });
 }

-std::tuple<Tensor, Tensor>
-_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+std::tuple<Tensor, Tensor, Tensor>
+_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     // The current implementation using `dim` always sorts due to unhashable tensors
-    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
+    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, return_counts);
   });
 }
diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu
index 0ba6812..204c9fc 100644
--- a/aten/src/ATen/native/cuda/Unique.cu
+++ b/aten/src/ATen/native/cuda/Unique.cu
@@ -16,9 +16,10 @@ namespace native{
 namespace {

 template <typename scalar_t>
-  std::tuple<Tensor, Tensor> _unique_cuda_template(
+  std::tuple<Tensor, Tensor, Tensor> _unique_cuda_template(
     const Tensor& self,
-    const bool return_inverse) {
+    const bool return_inverse,
+    const bool return_counts) {

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
@@ -28,7 +29,7 @@ template <typename scalar_t>
   int64_t num_inp = input.numel();
   const scalar_t* input_data = input.data<scalar_t>();

-  //sort & unique
+  //sort
   Tensor output = input.clone();
   output = output.view(-1);
   scalar_t* output_data = output.data<scalar_t>();
@@ -47,21 +48,36 @@ template <typename scalar_t>
     thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr, [=] __device__ (scalar_t a, scalar_t b) -> int64_t { if (a != b) {return 1;} else { return 0; }});
     inv_loc[0] = 0;
     thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
-    thrust::scatter(policy,inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
+    thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
     inverse_indices.resize_(input.sizes());
   }
-  int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
-  output.resize_(num_out);
+
+  // unique
+  Tensor counts = at::empty({0}, self.options().dtype(kLong));
+  if (!return_counts) {
+    int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
+    output.resize_(num_out);
+  } else {
+    Tensor sorted_indices = at::arange(0, num_inp + 1, self.type().toScalarType(kLong));
+    int64_t* sorted_indices_ptr = sorted_indices.data<int64_t>();
+    int64_t num_out = thrust::unique_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr).first - output_data;
+    sorted_indices[num_out] = num_inp;
+    output.resize_(num_out);
+    counts.resize_(num_out);
+    int64_t* counts_ptr = counts.data<int64_t>();
+    thrust::adjacent_difference(policy, sorted_indices_ptr + 1, sorted_indices_ptr + num_out + 1, counts_ptr);
+  }

   THCudaCheck(cudaGetLastError());
-  return std::tuple<Tensor, Tensor>(output, inverse_indices);
+  return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
 }

 template <typename scalar_t>
-  std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
+  std::tuple<Tensor, Tensor, Tensor> _unique_dim_cuda_template(
     const Tensor& self,
     const int64_t dim,
-    const bool return_inverse) {
+    const bool return_inverse,
+    const bool return_counts) {

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
@@ -73,7 +89,7 @@ template <typename scalar_t>

   scalar_t* input_flat_ptr = input_flat.data<scalar_t>();

-  Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
+  Tensor indices = at::arange(0, input_flat.size(0), self.options().dtype(kLong));
   int64_t* indices_ptr = indices.data<int64_t>();
   int64_t numel = input_flat.size(1);

@@ -96,7 +112,7 @@ template <typename scalar_t>

   // get unique tensors
   scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
-  Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
+  Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.options().dtype(kLong));
   int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
   auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(), [=] __device__ (int64_t a, int64_t b) -> bool {
@@ -118,12 +134,13 @@ template <typename scalar_t>
   output = output.view(new_sizes);
   output = output.transpose(0, dim);

-  // calculate inverse indices
-  Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
-  if (return_inverse) {
+  // calculate inverse indices and counts
+  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
+  Tensor counts = at::zeros(output.size(dim), self.options().dtype(kLong));
+  if (return_inverse || return_counts) {
     int64_t size = self.size(dim);
     inverse_indices.resize_(size);
-    Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
+    Tensor mask = at::empty(input_sorted.size(0), self.options().dtype(kLong));
     mask[0] = 1;
     for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
       if (!at::equal(input_sorted[i], input_sorted[i+1])) {
@@ -136,27 +153,29 @@ template <typename scalar_t>
     Tensor imask = at::cumsum(mask, 0) - 1;
     for (int i = 0; i < indices.size(0); ++i) {
       inverse_indices[indices[i]] = imask[i];
+      counts[inverse_indices[indices[i]]] += 1;
     }
   }
   THCudaCheck(cudaGetLastError());
-  return std::tuple<Tensor, Tensor>(output, inverse_indices);
+  return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
 }
 } // namespace

-std::tuple<Tensor, Tensor>
-_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
-  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] {
+
+std::tuple<Tensor, Tensor, Tensor>
+_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
-    return _unique_cuda_template<scalar_t>(self, return_inverse);
+    return _unique_cuda_template<scalar_t>(self, return_inverse, return_counts);
   });
 }

-std::tuple<Tensor, Tensor>
-_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+std::tuple<Tensor, Tensor, Tensor>
+_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
-    return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
+    return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, return_counts);
   });
 }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d3dd911..bcdb704 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2327,14 +2327,14 @@
   matches_jit_signature: True
   variants: method

-- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
+- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
   matches_jit_signature: True
   variants: function
   dispatch:
     CPU: _unique_cpu
     CUDA: _unique_cuda

-- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
+- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
   matches_jit_signature: True
   variants: function
   dispatch:
diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp
index d3278c6..25389a5 100644
--- a/aten/src/ATen/native/sparse/SparseTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -301,7 +301,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
     indices = nz.clone();
   } else {
     Tensor i = nz.narrow(0, 0, sparse_dim);
-    std::tie(indices, std::ignore) = _unique_dim(i, 1);
+    std::tie(indices, std::ignore, std::ignore) = _unique_dim(i, 1);
     indices = indices.contiguous();  // many sparse CUDA kernels require contiguity, see issue #12633
   }
diff --git a/test/test_torch.py b/test/test_torch.py
index 989c1f2..e337797 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -10184,6 +10184,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
         x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3])
         expected_unique = torch.LongTensor([1, 2, 3, 5, 8])
         expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2])
+        expected_counts = torch.LongTensor([1, 3, 2, 1, 1])

         x_unique = torch.unique(x)
         self.assertEqual(
@@ -10197,38 +10198,62 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
         x_unique = x.unique(sorted=True)
         self.assertEqual(expected_unique, x_unique)

+        x_unique, x_counts = x.unique(sorted=True, return_counts=True)
+        self.assertEqual(expected_counts, x_counts)
+
         x_unique, x_inverse = torch.unique(
             x, sorted=True, return_inverse=True)
         self.assertEqual(expected_unique, x_unique)
         self.assertEqual(expected_inverse, x_inverse)

+        x_unique, x_inverse, x_counts = torch.unique(
+            x, sorted=True, return_inverse=True, return_counts=True)
+        self.assertEqual(expected_unique, x_unique)
+        self.assertEqual(expected_inverse, x_inverse)
+        self.assertEqual(expected_counts, x_counts)
+
         # Tests per-element unique on a higher rank tensor.
         y = x.view(2, 2, 2)
         y_unique, y_inverse = y.unique(sorted=True, return_inverse=True)
         self.assertEqual(expected_unique, y_unique)
         self.assertEqual(expected_inverse.view(y.size()), y_inverse)

+        y_unique, y_inverse, y_counts = y.unique(
+            sorted=True, return_inverse=True, return_counts=True)
+        self.assertEqual(expected_unique, y_unique)
+        self.assertEqual(expected_inverse.view(y.size()), y_inverse)
+        self.assertEqual(expected_counts, y_counts)
+
         # Tests unique on other types.
-        int_unique, int_inverse = torch.unique(
-            torch.IntTensor([2, 1, 2]), sorted=True, return_inverse=True)
+        int_unique, int_inverse, int_counts = torch.unique(
+            torch.IntTensor([2, 1, 2]),
+            sorted=True,
+            return_inverse=True,
+            return_counts=True
+        )
         self.assertEqual(torch.IntTensor([1, 2]), int_unique)
         self.assertEqual(torch.LongTensor([1, 0, 1]), int_inverse)
+        self.assertEqual(torch.LongTensor([1, 2]), int_counts)

-        double_unique, double_inverse = torch.unique(
+        double_unique, double_inverse, double_counts = torch.unique(
             torch.DoubleTensor([2., 1.5, 2.1, 2.]),
             sorted=True,
             return_inverse=True,
+            return_counts=True
         )
         self.assertEqual(torch.DoubleTensor([1.5, 2., 2.1]), double_unique)
         self.assertEqual(torch.LongTensor([1, 0, 2, 1]), double_inverse)
+        self.assertEqual(torch.LongTensor([1, 2, 1]), double_counts)

-        byte_unique, byte_inverse = torch.unique(
+        byte_unique, byte_inverse, byte_counts = torch.unique(
             torch.ByteTensor([133, 7, 7, 7, 42, 128]),
             sorted=True,
             return_inverse=True,
+            return_counts=True
         )
         self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
         self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)
+        self.assertEqual(torch.LongTensor([3, 1, 1, 1]), byte_counts)

     def test_unique_dim(self):
         def run_test(dtype=torch.float):
@@ -10245,6 +10270,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
                                               [2., 1.],
                                               [0., 1.]]], dtype=dtype)
             expected_inverse_dim0 = torch.tensor([0, 0])
+            expected_counts_dim0 = torch.tensor([2])
             expected_unique_dim1 = torch.tensor([[[0., 1.],
                                                   [1., 1.],
                                                   [2., 1.]],
@@ -10252,6 +10278,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
                                                   [1., 1.],
                                                   [2., 1.]]], dtype=dtype)
             expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
+            expected_counts_dim1 = torch.tensor([2, 1, 1])
             expected_unique_dim2 = torch.tensor([[[1., 1.],
                                                   [0., 1.],
                                                   [2., 1.],
@@ -10261,30 +10288,94 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
                                                   [2., 1.],
                                                   [0., 1.]]], dtype=dtype)
             expected_inverse_dim2 = torch.tensor([0, 1])
+            expected_counts_dim2 = torch.tensor([1, 1])

             # dim0
             x_unique = torch.unique(x, dim=0)
             self.assertEqual(expected_unique_dim0, x_unique)

-            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
+            x_unique, x_inverse = torch.unique(
+                x,
+                return_inverse=True,
+                return_counts=False,
+                dim=0)
             self.assertEqual(expected_unique_dim0, x_unique)
             self.assertEqual(expected_inverse_dim0, x_inverse)

+            x_unique, x_counts = torch.unique(
+                x,
+                return_inverse=False,
+                return_counts=True,
+                dim=0)
+            self.assertEqual(expected_unique_dim0, x_unique)
+            self.assertEqual(expected_counts_dim0, x_counts)
+
+            x_unique, x_inverse, x_counts = torch.unique(
+                x,
+                return_inverse=True,
+                return_counts=True,
+                dim=0)
+            self.assertEqual(expected_unique_dim0, x_unique)
+            self.assertEqual(expected_inverse_dim0, x_inverse)
+            self.assertEqual(expected_counts_dim0, x_counts)
+
             # dim1
             x_unique = torch.unique(x, dim=1)
             self.assertEqual(expected_unique_dim1, x_unique)

-            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
+            x_unique, x_inverse = torch.unique(
+                x,
+                return_inverse=True,
+                return_counts=False,
+                dim=1)
+            self.assertEqual(expected_unique_dim1, x_unique)
+            self.assertEqual(expected_inverse_dim1, x_inverse)
+
+            x_unique, x_counts = torch.unique(
+                x,
+                return_inverse=False,
+                return_counts=True,
+                dim=1)
+            self.assertEqual(expected_unique_dim1, x_unique)
+            self.assertEqual(expected_counts_dim1, x_counts)
+
+            x_unique, x_inverse, x_counts = torch.unique(
+                x,
+                return_inverse=True,
+                return_counts=True,
+                dim=1)
             self.assertEqual(expected_unique_dim1, x_unique)
             self.assertEqual(expected_inverse_dim1, x_inverse)
+            self.assertEqual(expected_counts_dim1, x_counts)

             # dim2
             x_unique = torch.unique(x, dim=2)
             self.assertEqual(expected_unique_dim2, x_unique)

-            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
+            x_unique, x_inverse = torch.unique(
+                x,
+                return_inverse=True,
+                return_counts=False,
+                dim=2)
+            self.assertEqual(expected_unique_dim2, x_unique)
+            self.assertEqual(expected_inverse_dim2, x_inverse)
+
+            x_unique, x_counts = torch.unique(
+                x,
+                return_inverse=False,
+                return_counts=True,
+                dim=2)
+            self.assertEqual(expected_unique_dim2, x_unique)
+            self.assertEqual(expected_counts_dim2, x_counts)
+
+            x_unique, x_inverse, x_counts = torch.unique(
+                x,
+                return_inverse=True,
+                return_counts=True,
+                dim=2)
             self.assertEqual(expected_unique_dim2, x_unique)
             self.assertEqual(expected_inverse_dim2, x_inverse)
+            self.assertEqual(expected_counts_dim2, x_counts)

         run_test(torch.float)
         run_test(torch.double)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index b7acf82..b9a10d6 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -871,7 +871,7 @@
 - name: uniform_(Tensor self, double from, double to, Generator generator)
   self: zeros_like(grad)

-- name: _unique(Tensor self, bool sorted, bool return_inverse)
+- name: _unique(Tensor self, bool sorted, bool return_inverse, bool return_counts)
   self: not_implemented("_unique")

 - name: _unsafe_view(Tensor self, IntArrayRef size)
diff --git a/torch/functional.py b/torch/functional.py
index 580227c..9cd9e1c 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -374,8 +374,8 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None,
     return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)


-def unique(input, sorted=True, return_inverse=False, dim=None):
-    r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
+def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+    r"""Returns the unique elements of the input tensor.

     Arguments:
         input (Tensor): the input tensor
@@ -383,18 +383,26 @@
             before returning as output.
         return_inverse (bool): Whether to also return the indices for where
             elements in the original input ended up in the returned unique list.
+        return_counts (bool): Whether to also return the counts for each unique
+            element.
         dim (int): the dimension to apply unique. If ``None``, the unique of the
             flattened input is returned.
             default: ``None``

     Returns:
-        (Tensor, Tensor (optional)): A tensor or a tuple of tensors containing
+        (Tensor, Tensor (optional), Tensor (optional)):
+        A tensor or a tuple of tensors containing

             - **output** (*Tensor*): the output list of unique scalar elements.
             - **inverse_indices** (*Tensor*): (optional) if
-              :attr:`return_inverse` is True, there will be a
-              2nd returned tensor (same shape as input) representing the indices
+              :attr:`return_inverse` is True, there will be an additional
+              returned tensor (same shape as input) representing the indices
               for where elements in the original input map to in the output;
               otherwise, this function will only return a single tensor.
+            - **counts** (*Tensor*): (optional) if
+              :attr:`return_counts` is True, there will be an additional
+              returned tensor (same shape as output or output.size(dim),
+              if dim was specified) representing the number of occurrences
+              for each unique value or tensor.

     Example::
@@ -419,20 +427,26 @@

     """
     if dim is not None:
-        output, inverse_indices = torch._unique_dim(
+        output, inverse_indices, counts = torch._unique_dim(
            input,
            dim,
            sorted=sorted,
-           return_inverse=return_inverse
+           return_inverse=return_inverse,
+           return_counts=return_counts
        )
    else:
-        output, inverse_indices = torch._unique(
+        output, inverse_indices, counts = torch._unique(
            input,
            sorted=sorted,
            return_inverse=return_inverse,
+           return_counts=return_counts
        )
-    if return_inverse:
+    if return_inverse and return_counts:
+        return output, inverse_indices, counts
+    elif return_inverse:
        return output, inverse_indices
+    elif return_counts:
+        return output, counts
    else:
        return output
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 468ca46..2821b41 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -1205,10 +1205,11 @@ def conv_tbc(g, input, weight, bias, pad):
     return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)


-@parse_args('v', 'i', 'i')
-def _unique(g, input, sorted, return_inverse):
+@parse_args('v', 'i', 'i', 'i')
+def _unique(g, input, sorted, return_inverse, return_counts):
     return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
-                return_inverse_i=return_inverse, outputs=2)
+                return_inverse_i=return_inverse, return_counts_i=return_counts,
+                outputs=3)


 # Metaprogram symbolics for each ATen native specialized cast operator.
diff --git a/torch/tensor.py b/torch/tensor.py
index bf239b3..f1ec022 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -315,26 +315,32 @@ class Tensor(torch._C._TensorBase):
         else:
             return super(Tensor, self).split_with_sizes(split_size, dim)

-    def unique(self, sorted=True, return_inverse=False, dim=None):
+    def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
         r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
         See :func:`torch.unique`
         """
         if dim is not None:
-            output, inverse_indices = torch._unique_dim(
+            output, inverse_indices, counts = torch._unique_dim(
                 self,
                 sorted=sorted,
                 return_inverse=return_inverse,
+                return_counts=return_counts,
                 dim=dim
             )
         else:
-            output, inverse_indices = torch._unique(
+            output, inverse_indices, counts = torch._unique(
                 self,
                 sorted=sorted,
-                return_inverse=return_inverse
+                return_inverse=return_inverse,
+                return_counts=return_counts
             )
-        if return_inverse:
+        if return_inverse and return_counts:
+            return output, inverse_indices, counts
+        elif return_inverse:
             return output, inverse_indices
+        elif return_counts:
+            return output, counts
         else:
             return output
--
2.7.4
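
For illustration, here is how the extended API behaves once the patch is applied. This sketch is not part of the patch; the expected values in the comments are exactly the ones asserted in the updated `test_torch.py` above:

```python
import torch

x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3])

# Each flag independently adds one extra output to the returned tuple.
output, inverse, counts = torch.unique(
    x, sorted=True, return_inverse=True, return_counts=True)

print(output)   # tensor([1, 2, 3, 5, 8])
print(inverse)  # tensor([0, 1, 2, 1, 4, 3, 1, 2])
print(counts)   # tensor([1, 3, 2, 1, 1])

# counts is consistent with the inverse mapping: tallying how often each
# unique slot is referenced reproduces it, which is essentially what the
# CPU implementation above does with its inverse_map loop.
assert torch.equal(counts, torch.bincount(inverse.flatten()))
```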
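The counting scheme the summary credits to ngimel (thrust::unique_by_key over the sorted data, a sentinel index at the end, then thrust::adjacent_difference) can also be sketched in plain Python. This is an illustrative re-implementation of the idea, not code from the patch, and `counts_from_sorted` is a hypothetical helper name:

```python
import torch

def counts_from_sorted(sorted_vals):
    """Run-length counts of a 1-D sorted tensor, mirroring the CUDA path:
    keep the position of the first element of every run of equal values
    (what unique_by_key leaves in sorted_indices), append the total length
    as a sentinel (sorted_indices[num_out] = num_inp), and take adjacent
    differences so each difference is one unique value's count."""
    n = sorted_vals.numel()
    # mark positions where a new run of equal values starts
    is_run_start = torch.ones(n, dtype=torch.bool)
    is_run_start[1:] = sorted_vals[1:] != sorted_vals[:-1]
    firsts = torch.arange(n)[is_run_start]
    # sentinel n closes the last run
    boundaries = torch.cat([firsts, torch.tensor([n])])
    return boundaries[1:] - boundaries[:-1]  # adjacent_difference

vals, _ = torch.sort(torch.tensor([1, 2, 3, 2, 8, 5, 2, 3]))
print(counts_from_sorted(vals))  # tensor([1, 3, 2, 1, 1])
```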