Step 3: Add support for return_counts to torch.unique for dim not None (#18650)
authorXiang Gao <qasdfgtyuiop@gmail.com>
Tue, 16 Apr 2019 20:55:37 +0000 (13:55 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 16 Apr 2019 21:06:45 +0000 (14:06 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18650
ghimport-source-id: 75759c95e6c48e27c172b919097dbc40c6bfb5e6

Differential Revision: D14892319

Pulled By: VitalyFedyunin

fbshipit-source-id: ec5d1b80fc879d273ac5a534434fd648468dda1e

aten/src/ATen/native/native_functions.yaml
test/test_torch.py
torch/functional.py
torch/tensor.py

index 1f860b2..64224df 100644 (file)
     CPU: _unique_cpu
     CUDA: _unique_cuda
 
+- func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: unique_dim_cpu
+    CUDA: unique_dim_cuda
+
+# _unique_dim is deprecated and will be removed in the future. Please use unique_dim
+
 - func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
   variants: function
   dispatch:
     CUDA: unique_dim_consecutive_cuda
 
 # _unique and _unique_dim are fragile and modifying them easily cause internal break
-# below two operators are a temporary hack for adding return_counts support
+# the below operator is a temporary hack for adding return_counts support
 # Please don't rely on these two operators, they will be removed soon
 
 - func: _unique2_temporary_will_remove_soon(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
     CPU: _unique2_cpu
     CUDA: _unique2_cuda
 
-- func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
-  matches_jit_signature: True
-  variants: function
-  dispatch:
-    CPU: unique_dim_cpu
-    CUDA: unique_dim_cuda
-
 - func: _unsafe_view(Tensor self, int[] size) -> Tensor
 
 - func: unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
index b7ba9f7..5449aec 100644 (file)
@@ -10711,7 +10711,6 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
 
     def test_unique_dim(self):
         self.assertFalse(hasattr(torch, 'unique_dim'))
-        torch.unique_dim = torch._C._VariableFunctions.unique_dim
 
         def run_test(dtype=torch.float, device=torch.device('cpu')):
             x = torch.tensor([[[1., 1.],
@@ -10766,7 +10765,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(expected_unique_dim0, x_unique)
             self.assertEqual(expected_inverse_dim0, x_inverse)
 
-            x_unique, _, x_counts = torch.unique_dim(
+            x_unique, x_counts = torch.unique(
                 x,
                 return_inverse=False,
                 return_counts=True,
@@ -10774,7 +10773,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(expected_unique_dim0, x_unique)
             self.assertEqual(expected_counts_dim0, x_counts)
 
-            x_unique, x_inverse, x_counts = torch.unique_dim(
+            x_unique, x_inverse, x_counts = torch.unique(
                 x,
                 return_inverse=True,
                 return_counts=True,
@@ -10794,7 +10793,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(expected_unique_dim1, x_unique)
             self.assertEqual(expected_inverse_dim1, x_inverse)
 
-            x_unique, _, x_counts = torch.unique_dim(
+            x_unique, x_counts = torch.unique(
                 x,
                 return_inverse=False,
                 return_counts=True,
@@ -10802,7 +10801,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(expected_unique_dim1, x_unique)
             self.assertEqual(expected_counts_dim1, x_counts)
 
-            x_unique, x_inverse, x_counts = torch.unique_dim(
+            x_unique, x_inverse, x_counts = torch.unique(
                 x,
                 return_inverse=True,
                 return_counts=True,
@@ -10822,7 +10821,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(expected_unique_dim2, x_unique)
             self.assertEqual(expected_inverse_dim2, x_inverse)
 
-            x_unique, _, x_counts = torch.unique_dim(
+            x_unique, x_counts = torch.unique(
                 x,
                 return_inverse=False,
                 return_counts=True,
@@ -10830,7 +10829,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(expected_unique_dim2, x_unique)
             self.assertEqual(expected_counts_dim2, x_counts)
 
-            x_unique, x_inverse, x_counts = torch.unique_dim(
+            x_unique, x_inverse, x_counts = torch.unique(
                 x,
                 return_inverse=True,
                 return_counts=True,
index 5f0f525..7b589ab 100644 (file)
@@ -390,8 +390,8 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None,
 del torch.unique_dim
 
 
-def unique(input, sorted=True, return_inverse=False, dim=None):
-    r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
+def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+    r"""Returns the unique elements of the input tensor.
 
     Arguments:
         input (Tensor): the input tensor
@@ -399,18 +399,26 @@ def unique(input, sorted=True, return_inverse=False, dim=None):
             before returning as output.
         return_inverse (bool): Whether to also return the indices for where
             elements in the original input ended up in the returned unique list.
+        return_counts (bool): Whether to also return the counts for each unique
+            element. Currently only supported when `dim` is not None.
         dim (int): the dimension to apply unique. If ``None``, the unique of the
             flattened input is returned. default: ``None``
 
     Returns:
-        (Tensor, Tensor (optional)): A tensor or a tuple of tensors containing
+        (Tensor, Tensor (optional) Tensor (optional))::
+        A tensor or a tuple of tensors containing
 
             - **output** (*Tensor*): the output list of unique scalar elements.
             - **inverse_indices** (*Tensor*): (optional) if
-              :attr:`return_inverse` is True, there will be a
-              2nd returned tensor (same shape as input) representing the indices
+              :attr:`return_inverse` is True, there will be an additional
+              returned tensor (same shape as input) representing the indices
               for where elements in the original input map to in the output;
               otherwise, this function will only return a single tensor.
+            - **counts** (*Tensor*): (optional) if
+              :attr:`return_counts` is True, there will be an additional
+              returned tensor (same shape as output or output.size(dim),
+              if dim was specified) representing the number of occurrences
+              for each unique value or tensor.
 
     Example::
 
@@ -435,20 +443,28 @@ def unique(input, sorted=True, return_inverse=False, dim=None):
 
     """
     if dim is not None:
-        output, inverse_indices = torch._unique_dim(
+        output, inverse_indices, counts = torch._C._VariableFunctions.unique_dim(
             input,
             dim,
             sorted=sorted,
-            return_inverse=return_inverse
+            return_inverse=return_inverse,
+            return_counts=return_counts,
         )
     else:
+        if return_counts:
+            raise NotImplementedError(
+                "torch.unique currently does not support return_counts with dim not None")
         output, inverse_indices = torch._unique(
             input,
             sorted=sorted,
-            return_inverse=return_inverse,
+            return_inverse=return_inverse
         )
-    if return_inverse:
+    if return_inverse and return_counts:
+        return output, inverse_indices, counts
+    elif return_inverse:
         return output, inverse_indices
+    elif return_counts:
+        return output, counts
     else:
         return output
 
index 1f98a61..c549a70 100644 (file)
@@ -345,28 +345,12 @@ class Tensor(torch._C._TensorBase):
         else:
             return super(Tensor, self).split_with_sizes(split_size, dim)
 
-    def unique(self, sorted=True, return_inverse=False, dim=None):
-        r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
+    def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
+        r"""Returns the unique elements of the input tensor.
 
         See :func:`torch.unique`
         """
-        if dim is not None:
-            output, inverse_indices = torch._unique_dim(
-                self,
-                sorted=sorted,
-                return_inverse=return_inverse,
-                dim=dim
-            )
-        else:
-            output, inverse_indices = torch._unique(
-                self,
-                sorted=sorted,
-                return_inverse=return_inverse
-            )
-        if return_inverse:
-            return output, inverse_indices
-        else:
-            return output
+        return torch.unique(self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
 
     def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
         r"""Eliminates all but the first element from every consecutive group of equivalent elements.