From e1750754c80c87385061995bc46c2770f56a1e39 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 18 Apr 2019 18:24:50 -0700 Subject: [PATCH] Step 4: add support for unique with dim=None (#18651) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18651 ghimport-source-id: e11988130a3f9a73529de0b0d08b4ec25fbc639c Differential Revision: D15000463 Pulled By: VitalyFedyunin fbshipit-source-id: 9e258e473dea6a3fc2307da2119b887ba3f7934a --- aten/src/ATen/native/native_functions.yaml | 3 +-- test/test_torch.py | 12 ++++++------ torch/functional.py | 10 ++++------ 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 10e2265..8a07de1 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1916,7 +1916,6 @@ CUDA: _unique_cuda - func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) - matches_jit_signature: True variants: function dispatch: CPU: unique_dim_cpu @@ -1946,7 +1945,7 @@ # the below operator is a temporary hack for adding return_counts support # Please don't rely on these two operators, they will be removed soon -- func: _unique2_temporary_will_remove_soon(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) +- func: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) variants: function dispatch: CPU: _unique2_cpu diff --git a/test/test_torch.py b/test/test_torch.py index aaba09d..b4d95bc 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10659,7 +10659,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], x_unique = x.unique(sorted=True) self.assertEqual(expected_unique, x_unique) - x_unique, _, x_counts = torch._unique2_temporary_will_remove_soon(x, sorted=True, return_counts=True) + x_unique, x_counts = torch.unique(x, sorted=True, return_counts=True) self.assertEqual(expected_counts, x_counts) x_unique, x_inverse = torch.unique( @@ -10667,7 +10667,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(expected_unique, x_unique) self.assertEqual(expected_inverse, x_inverse) - x_unique, x_inverse, x_counts = torch._unique2_temporary_will_remove_soon( + x_unique, x_inverse, x_counts = torch.unique( x, sorted=True, return_inverse=True, return_counts=True) self.assertEqual(expected_unique, x_unique) self.assertEqual(expected_inverse, x_inverse) @@ -10679,14 +10679,14 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(expected_unique, y_unique) self.assertEqual(expected_inverse.view(y.size()), y_inverse) - y_unique, y_inverse, y_counts = torch._unique2_temporary_will_remove_soon( + y_unique, y_inverse, y_counts = torch.unique( y, sorted=True, return_inverse=True, return_counts=True) self.assertEqual(expected_unique, y_unique) self.assertEqual(expected_inverse.view(y.size()), y_inverse) self.assertEqual(expected_counts, y_counts) # Tests unique on other types. - int_unique, int_inverse, int_counts = torch._unique2_temporary_will_remove_soon( + int_unique, int_inverse, int_counts = torch.unique( torch.tensor([2, 1, 2], dtype=torch.int, device=device), sorted=True, return_inverse=True, @@ -10696,7 +10696,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(torch.tensor([1, 0, 1], dtype=torch.long, device=device), int_inverse) self.assertEqual(torch.tensor([1, 2], dtype=torch.long, device=device), int_counts) - double_unique, double_inverse, double_counts = torch._unique2_temporary_will_remove_soon( + double_unique, double_inverse, double_counts = torch.unique( torch.tensor([2., 1.5, 2.1, 2.], dtype=torch.double, device=device), sorted=True, return_inverse=True, @@ -10706,7 +10706,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(torch.tensor([1, 0, 2, 1], dtype=torch.long, device=device), double_inverse) self.assertEqual(torch.tensor([1, 2, 1], dtype=torch.long, device=device), double_counts) - byte_unique, byte_inverse, byte_counts = torch._unique2_temporary_will_remove_soon( + byte_unique, byte_inverse, byte_counts = torch.unique( torch.tensor([133, 7, 7, 7, 42, 128], dtype=torch.uint8, device=device), sorted=True, return_inverse=True, diff --git a/torch/functional.py b/torch/functional.py index 7b589ab..aae20d5 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -400,7 +400,7 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No return_inverse (bool): Whether to also return the indices for where elements in the original input ended up in the returned unique list. return_counts (bool): Whether to also return the counts for each unique - element. Currently only supported when `dim` is not None. + element. dim (int): the dimension to apply unique. If ``None``, the unique of the flattened input is returned. default: ``None`` @@ -451,13 +451,11 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No return_counts=return_counts, ) else: - if return_counts: - raise NotImplementedError( - "torch.unique currently does not support return_counts with dim not None") - output, inverse_indices = torch._unique( + output, inverse_indices, counts = torch._unique2( input, sorted=sorted, - return_inverse=return_inverse + return_inverse=return_inverse, + return_counts=return_counts, ) if return_inverse and return_counts: return output, inverse_indices, counts -- 2.7.4