Step 4: add support for unique with dim=None (#18651)
authorXiang Gao <qasdfgtyuiop@gmail.com>
Fri, 19 Apr 2019 01:24:50 +0000 (18:24 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 19 Apr 2019 01:28:07 +0000 (18:28 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18651
ghimport-source-id: e11988130a3f9a73529de0b0d08b4ec25fbc639c

Differential Revision: D15000463

Pulled By: VitalyFedyunin

fbshipit-source-id: 9e258e473dea6a3fc2307da2119b887ba3f7934a

aten/src/ATen/native/native_functions.yaml
test/test_torch.py
torch/functional.py

index 10e2265..8a07de1 100644 (file)
     CUDA: _unique_cuda
 
 - func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
-  matches_jit_signature: True
   variants: function
   dispatch:
     CPU: unique_dim_cpu
 # the below operator is a temporary hack for adding return_counts support
 # Please don't rely on these two operators, they will be removed soon
 
-- func: _unique2_temporary_will_remove_soon(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+- func: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
   variants: function
   dispatch:
     CPU: _unique2_cpu
index aaba09d..b4d95bc 100644 (file)
@@ -10659,7 +10659,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             x_unique = x.unique(sorted=True)
             self.assertEqual(expected_unique, x_unique)
 
-            x_unique, _, x_counts = torch._unique2_temporary_will_remove_soon(x, sorted=True, return_counts=True)
+            x_unique, x_counts = torch.unique(x, sorted=True, return_counts=True)
             self.assertEqual(expected_counts, x_counts)
 
             x_unique, x_inverse = torch.unique(
@@ -10667,7 +10667,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(expected_unique, x_unique)
             self.assertEqual(expected_inverse, x_inverse)
 
-            x_unique, x_inverse, x_counts = torch._unique2_temporary_will_remove_soon(
+            x_unique, x_inverse, x_counts = torch.unique(
                 x, sorted=True, return_inverse=True, return_counts=True)
             self.assertEqual(expected_unique, x_unique)
             self.assertEqual(expected_inverse, x_inverse)
@@ -10679,14 +10679,14 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(expected_unique, y_unique)
             self.assertEqual(expected_inverse.view(y.size()), y_inverse)
 
-            y_unique, y_inverse, y_counts = torch._unique2_temporary_will_remove_soon(
+            y_unique, y_inverse, y_counts = torch.unique(
                 y, sorted=True, return_inverse=True, return_counts=True)
             self.assertEqual(expected_unique, y_unique)
             self.assertEqual(expected_inverse.view(y.size()), y_inverse)
             self.assertEqual(expected_counts, y_counts)
 
             # Tests unique on other types.
-            int_unique, int_inverse, int_counts = torch._unique2_temporary_will_remove_soon(
+            int_unique, int_inverse, int_counts = torch.unique(
                 torch.tensor([2, 1, 2], dtype=torch.int, device=device),
                 sorted=True,
                 return_inverse=True,
@@ -10696,7 +10696,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(torch.tensor([1, 0, 1], dtype=torch.long, device=device), int_inverse)
             self.assertEqual(torch.tensor([1, 2], dtype=torch.long, device=device), int_counts)
 
-            double_unique, double_inverse, double_counts = torch._unique2_temporary_will_remove_soon(
+            double_unique, double_inverse, double_counts = torch.unique(
                 torch.tensor([2., 1.5, 2.1, 2.], dtype=torch.double, device=device),
                 sorted=True,
                 return_inverse=True,
@@ -10706,7 +10706,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(torch.tensor([1, 0, 2, 1], dtype=torch.long, device=device), double_inverse)
             self.assertEqual(torch.tensor([1, 2, 1], dtype=torch.long, device=device), double_counts)
 
-            byte_unique, byte_inverse, byte_counts = torch._unique2_temporary_will_remove_soon(
+            byte_unique, byte_inverse, byte_counts = torch.unique(
                 torch.tensor([133, 7, 7, 7, 42, 128], dtype=torch.uint8, device=device),
                 sorted=True,
                 return_inverse=True,
index 7b589ab..aae20d5 100644 (file)
@@ -400,7 +400,7 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No
         return_inverse (bool): Whether to also return the indices for where
             elements in the original input ended up in the returned unique list.
         return_counts (bool): Whether to also return the counts for each unique
-            element. Currently only supported when `dim` is not None.
+            element.
         dim (int): the dimension to apply unique. If ``None``, the unique of the
             flattened input is returned. default: ``None``
 
@@ -451,13 +451,11 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No
             return_counts=return_counts,
         )
     else:
-        if return_counts:
-            raise NotImplementedError(
-                "torch.unique currently does not support return_counts with dim not None")
-        output, inverse_indices = torch._unique(
+        output, inverse_indices, counts = torch._unique2(
             input,
             sorted=sorted,
-            return_inverse=return_inverse
+            return_inverse=return_inverse,
+            return_counts=return_counts,
         )
     if return_inverse and return_counts:
         return output, inverse_indices, counts