From: Xiang Gao Date: Tue, 16 Apr 2019 20:55:37 +0000 (-0700) Subject: Step 3: Add support for return_counts to torch.unique for dim not None (#18650) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~212 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=df67969e6b8846e901e65febd859a1293d8f0509;p=platform%2Fupstream%2Fpytorch.git Step 3: Add support for return_counts to torch.unique for dim not None (#18650) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18650 ghimport-source-id: 75759c95e6c48e27c172b919097dbc40c6bfb5e6 Differential Revision: D14892319 Pulled By: VitalyFedyunin fbshipit-source-id: ec5d1b80fc879d273ac5a534434fd648468dda1e --- diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1f860b2..64224df 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1898,6 +1898,15 @@ CPU: _unique_cpu CUDA: _unique_cuda +- func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + matches_jit_signature: True + variants: function + dispatch: + CPU: unique_dim_cpu + CUDA: unique_dim_cuda + +# _unique_dim is deprecated and will be removed in the future. 
Please use unique_dim + - func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) variants: function dispatch: @@ -1917,7 +1926,7 @@ CUDA: unique_dim_consecutive_cuda # _unique and _unique_dim are fragile and modifying them easily cause internal break -# below two operators are a temporary hack for adding return_counts support +# the below operator is a temporary hack for adding return_counts support # Please don't rely on these two operators, they will be removed soon - func: _unique2_temporary_will_remove_soon(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) @@ -1926,13 +1935,6 @@ CPU: _unique2_cpu CUDA: _unique2_cuda -- func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) - matches_jit_signature: True - variants: function - dispatch: - CPU: unique_dim_cpu - CUDA: unique_dim_cuda - - func: _unsafe_view(Tensor self, int[] size) -> Tensor - func: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) diff --git a/test/test_torch.py b/test/test_torch.py index b7ba9f7..5449aec 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10711,7 +10711,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], def test_unique_dim(self): self.assertFalse(hasattr(torch, 'unique_dim')) - torch.unique_dim = torch._C._VariableFunctions.unique_dim def run_test(dtype=torch.float, device=torch.device('cpu')): x = torch.tensor([[[1., 1.], @@ -10766,7 +10765,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) - x_unique, _, x_counts = torch.unique_dim( + x_unique, x_counts = torch.unique( x, return_inverse=False, return_counts=True, @@ -10774,7 +10773,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_counts_dim0, x_counts) - x_unique, x_inverse, 
x_counts = torch.unique_dim( + x_unique, x_inverse, x_counts = torch.unique( x, return_inverse=True, return_counts=True, @@ -10794,7 +10793,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(expected_unique_dim1, x_unique) self.assertEqual(expected_inverse_dim1, x_inverse) - x_unique, _, x_counts = torch.unique_dim( + x_unique, x_counts = torch.unique( x, return_inverse=False, return_counts=True, @@ -10802,7 +10801,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(expected_unique_dim1, x_unique) self.assertEqual(expected_counts_dim1, x_counts) - x_unique, x_inverse, x_counts = torch.unique_dim( + x_unique, x_inverse, x_counts = torch.unique( x, return_inverse=True, return_counts=True, @@ -10822,7 +10821,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) - x_unique, _, x_counts = torch.unique_dim( + x_unique, x_counts = torch.unique( x, return_inverse=False, return_counts=True, @@ -10830,7 +10829,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_counts_dim2, x_counts) - x_unique, x_inverse, x_counts = torch.unique_dim( + x_unique, x_inverse, x_counts = torch.unique( x, return_inverse=True, return_counts=True, diff --git a/torch/functional.py b/torch/functional.py index 5f0f525..7b589ab 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -390,8 +390,8 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None, del torch.unique_dim -def unique(input, sorted=True, return_inverse=False, dim=None): - r"""Returns the unique scalar elements of the input tensor as a 1-D tensor. +def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None): + r"""Returns the unique elements of the input tensor. 
 Arguments:
     input (Tensor): the input tensor
@@ -399,18 +399,26 @@
             before returning as output.
         return_inverse (bool): Whether to also return the indices for where
             elements in the original input ended up in the returned unique list.
+        return_counts (bool): Whether to also return the counts for each unique
+            element. Currently only supported when `dim` is not None.
         dim (int): the dimension to apply unique. If ``None``, the unique of the
             flattened input is returned. default: ``None``

     Returns:
-        (Tensor, Tensor (optional)): A tensor or a tuple of tensors containing
+        (Tensor, Tensor (optional), Tensor (optional))::
+            A tensor or a tuple of tensors containing

             - **output** (*Tensor*): the output list of unique scalar elements.
             - **inverse_indices** (*Tensor*): (optional) if
-              :attr:`return_inverse` is True, there will be a
-              2nd returned tensor (same shape as input) representing the indices
+              :attr:`return_inverse` is True, there will be an additional
+              returned tensor (same shape as input) representing the indices
               for where elements in the original input map to in the output;
               otherwise, this function will only return a single tensor.
+            - **counts** (*Tensor*): (optional) if
+              :attr:`return_counts` is True, there will be an additional
+              returned tensor (same shape as output or output.size(dim),
+              if dim was specified) representing the number of occurrences
+              for each unique value or tensor.
 Example::

@@ -435,20 +443,28 @@
     """
     if dim is not None:
-        output, inverse_indices = torch._unique_dim(
+        output, inverse_indices, counts = torch._C._VariableFunctions.unique_dim(
             input,
             dim,
             sorted=sorted,
-            return_inverse=return_inverse
+            return_inverse=return_inverse,
+            return_counts=return_counts,
         )
     else:
+        if return_counts:
+            raise NotImplementedError(
+                "torch.unique currently does not support return_counts when dim is None")
         output, inverse_indices = torch._unique(
             input,
             sorted=sorted,
-            return_inverse=return_inverse,
+            return_inverse=return_inverse
         )
-    if return_inverse:
+    if return_inverse and return_counts:
+        return output, inverse_indices, counts
+    elif return_inverse:
         return output, inverse_indices
+    elif return_counts:
+        return output, counts
     else:
         return output
diff --git a/torch/tensor.py b/torch/tensor.py
index 1f98a61..c549a70 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -345,28 +345,12 @@ class Tensor(torch._C._TensorBase):
         else:
             return super(Tensor, self).split_with_sizes(split_size, dim)

-    def unique(self, sorted=True, return_inverse=False, dim=None):
-        r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
See :func:`torch.unique` """ - if dim is not None: - output, inverse_indices = torch._unique_dim( - self, - sorted=sorted, - return_inverse=return_inverse, - dim=dim - ) - else: - output, inverse_indices = torch._unique( - self, - sorted=sorted, - return_inverse=return_inverse - ) - if return_inverse: - return output, inverse_indices - else: - return output + return torch.unique(self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim) def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None): r"""Eliminates all but the first element from every consecutive group of equivalent elements.