From 66628f78b780726c8bd1d29a2cc798712d211cd9 Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Tue, 26 Mar 2019 17:14:26 -0700 Subject: [PATCH] Revert D14605905: [pytorch][PR] Add return_counts to torch.unique Differential Revision: D14605905 Original commit changeset: 555f5a12a8e2 fbshipit-source-id: c7874f5987893e956c022180a37763d88bba38db --- aten/src/ATen/native/Unique.cpp | 49 +++++-------- aten/src/ATen/native/cuda/Unique.cu | 65 ++++++----------- aten/src/ATen/native/native_functions.yaml | 4 +- aten/src/ATen/native/sparse/SparseTensor.cpp | 2 +- test/test_torch.py | 105 ++------------------------- tools/autograd/derivatives.yaml | 2 +- torch/functional.py | 32 +++----- torch/onnx/symbolic.py | 7 +- torch/tensor.py | 16 ++-- 9 files changed, 68 insertions(+), 214 deletions(-) diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp index cd59f18..8cc867f 100644 --- a/aten/src/ATen/native/Unique.cpp +++ b/aten/src/ATen/native/Unique.cpp @@ -14,11 +14,10 @@ namespace native{ namespace { template -std::tuple _unique_cpu_template( +std::tuple _unique_cpu_template( const Tensor& self, const bool sorted, - const bool return_inverse, - const bool return_counts) { + const bool return_inverse) { const Tensor& input = self.contiguous(); const scalar_t* input_data = input.data(); std::unordered_set set(input_data, input_data + input.numel()); @@ -34,8 +33,7 @@ std::tuple _unique_cpu_template( } Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong)); - Tensor counts = at::empty({0}, self.options().dtype(kLong)); - if (return_inverse || return_counts) { + if (return_inverse) { inverse_indices.resize_(input.sizes()); int64_t* inverse_indices_data = inverse_indices.data(); std::unordered_map inverse_map; @@ -46,29 +44,21 @@ std::tuple _unique_cpu_template( for (int i = 0; i < input.numel(); ++i) { inverse_indices_data[i] = inverse_map[input_data[i]]; } - if (return_counts) { - counts.resize_(output.sizes()); - counts.fill_(0); - for (int i = 0; i < input.numel(); ++i) { - counts[inverse_map[input_data[i]]] += 1; - } - } } - return std::make_tuple(output, inverse_indices, counts); + return std::make_tuple(output, inverse_indices); } template ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last, - std::vector& indices, Tensor inverse_indices_vec, Tensor counts) { + std::vector& indices, Tensor inverse_indices_vec) { if (first == last) { return last; } // save to calculate distance to iterators ForwardIt begin = first; - // set first inverse index and count + // set first inverse index inverse_indices_vec[indices[0]] = 0; - counts[0] += 1; ForwardIt result = first; while (++first != last) { @@ -78,18 +68,16 @@ ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last, int64_t idx_result = std::distance(begin, result); int64_t idx_first = std::distance(begin, first); inverse_indices_vec[indices[idx_first]] = idx_result; - counts[idx_result] += 1; } return ++result; } template -std::tuple _unique_dim_cpu_template( +std::tuple _unique_dim_cpu_template( const Tensor& self, const int64_t dim, - const bool return_inverse, - const bool return_counts) { + const bool return_inverse) { // reshape tensor as [dim, -1] Tensor input_flat = self.transpose(dim, 0); auto orig_sizes = input_flat.sizes().vec(); @@ -121,12 +109,10 @@ std::tuple _unique_dim_cpu_template( } Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong)); - Tensor counts = at::zeros(indices.size(), self.options().dtype(kLong)); std::vector input_unbind = at::unbind(input_sorted, 0); auto last = _unique_dim_cpu_impl( - input_unbind.begin(), input_unbind.end(), indices, inverse_indices, counts); + input_unbind.begin(), input_unbind.end(), indices, inverse_indices); input_unbind.erase(last, input_unbind.end()); - counts = at::narrow(counts, 0, 0, input_unbind.size()); // reshape back auto output = at::stack(input_unbind, 0); @@ -135,23 +121,22 @@ std::tuple _unique_dim_cpu_template( output = output.view(new_sizes); output = output.transpose(0, dim); - return std::make_tuple(output, inverse_indices, counts); + return std::make_tuple(output, inverse_indices); } } // namespace - -std::tuple -_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) { - return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] { - return _unique_cpu_template(self, sorted, return_inverse, return_counts); +std::tuple +_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] { + return _unique_cpu_template(self, sorted, return_inverse); }); } -std::tuple -_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) { +std::tuple +_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) { return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] { // The current implementation using `dim` always sorts due to unhashable tensors - return _unique_dim_cpu_template(self, dim, return_inverse, return_counts); + return _unique_dim_cpu_template(self, dim, return_inverse); }); } diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu index 204c9fc..0ba6812 100644 --- a/aten/src/ATen/native/cuda/Unique.cu +++ b/aten/src/ATen/native/cuda/Unique.cu @@ -16,10 +16,9 @@ namespace native{ namespace { template - std::tuple _unique_cuda_template( + std::tuple _unique_cuda_template( const Tensor& self, - const bool return_inverse, - const bool return_counts) { + const bool return_inverse) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); @@ -29,7 +28,7 @@ template int64_t num_inp = input.numel(); const scalar_t* input_data = input.data(); - //sort + //sort & unique Tensor output = input.clone(); output = output.view(-1); scalar_t* output_data = output.data(); @@ -48,36 +47,21 @@ template thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr, [=] __device__ (scalar_t a, scalar_t b) -> int64_t { if (a != b) {return 1;} else { return 0; }}); inv_loc[0] = 0; thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr); - thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr); + thrust::scatter(policy,inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr); inverse_indices.resize_(input.sizes()); } - - // unique - Tensor counts = at::empty({0}, self.options().dtype(kLong)); - if (!return_counts) { - int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data; - output.resize_(num_out); - } else { - Tensor sorted_indices = at::arange(0, num_inp + 1, self.type().toScalarType(kLong)); - int64_t* sorted_indices_ptr = sorted_indices.data(); - int64_t num_out = thrust::unique_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr).first - output_data; - sorted_indices[num_out] = num_inp; - output.resize_(num_out); - counts.resize_(num_out); - int64_t* counts_ptr = counts.data(); - thrust::adjacent_difference(policy, sorted_indices_ptr + 1, sorted_indices_ptr + num_out + 1, counts_ptr); - } + int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data; + output.resize_(num_out); THCudaCheck(cudaGetLastError()); - return std::tuple(output, inverse_indices, counts); + return std::tuple(output, inverse_indices); } template - std::tuple _unique_dim_cuda_template( + std::tuple _unique_dim_cuda_template( const Tensor& self, const int64_t dim, - const bool return_inverse, - const bool return_counts) { + const bool return_inverse) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); @@ -89,7 +73,7 @@ template scalar_t* input_flat_ptr = input_flat.data(); - Tensor indices = at::arange(0, input_flat.size(0), self.options().dtype(kLong)); + Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong)); int64_t* indices_ptr = indices.data(); int64_t numel = input_flat.size(1); @@ -112,7 +96,7 @@ template // get unique tensors scalar_t* input_sorted_ptr = input_sorted.data(); - Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.options().dtype(kLong)); + Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong)); int64_t* input_sorted_indices_ptr = input_sorted_indices.data(); auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(), [=] __device__ (int64_t a, int64_t b) -> bool { @@ -134,13 +118,12 @@ template output = output.view(new_sizes); output = output.transpose(0, dim); - // calculate inverse indices and counts - Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong)); - Tensor counts = at::zeros(output.size(dim), self.options().dtype(kLong)); - if (return_inverse || return_counts) { + // calculate inverse indices + Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong)); + if (return_inverse) { int64_t size = self.size(dim); inverse_indices.resize_(size); - Tensor mask = at::empty(input_sorted.size(0), self.options().dtype(kLong)); + Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong)); mask[0] = 1; for (int i = 0; i < input_sorted.size(0) - 1; ++i) { if (!at::equal(input_sorted[i], input_sorted[i+1])) { @@ -153,29 +136,27 @@ template Tensor imask = at::cumsum(mask, 0) - 1; for (int i = 0; i < indices.size(0); ++i) { inverse_indices[indices[i]] = imask[i]; - counts[inverse_indices[indices[i]]] += 1; } } THCudaCheck(cudaGetLastError()); - return std::tuple(output, inverse_indices, counts); + return std::tuple(output, inverse_indices); } } // namespace - -std::tuple -_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) { - return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] { +std::tuple +_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] { // The current CUDA implementation of unique always sort due to the // lack of hashtable implementation in thrust - return _unique_cuda_template(self, return_inverse, return_counts); + return _unique_cuda_template(self, return_inverse); }); } -std::tuple -_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) { +std::tuple +_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) { return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] { - return _unique_dim_cuda_template(self, dim, return_inverse, return_counts); + return _unique_dim_cuda_template(self, dim, return_inverse); }); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0d4da52..152a2ae 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2339,14 +2339,14 @@ matches_jit_signature: True variants: method -- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) +- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) matches_jit_signature: True variants: function dispatch: CPU: _unique_cpu CUDA: _unique_cuda -- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) +- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) matches_jit_signature: True variants: function dispatch: diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 25389a5..d3278c6 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -301,7 +301,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){ indices = nz.clone(); } else { Tensor i = nz.narrow(0, 0, sparse_dim); - std::tie(indices, std::ignore, std::ignore) = _unique_dim(i, 1); + std::tie(indices, std::ignore) = _unique_dim(i, 1); indices = indices.contiguous(); // many sparse CUDA kernels require contiguity, see issue #12633 } diff --git a/test/test_torch.py b/test/test_torch.py index f895516..1718ec4 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10362,7 +10362,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3]) expected_unique = torch.LongTensor([1, 2, 3, 5, 8]) expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2]) - expected_counts = torch.LongTensor([1, 3, 2, 1, 1]) x_unique = torch.unique(x) self.assertEqual( @@ -10376,62 +10375,38 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], x_unique = x.unique(sorted=True) self.assertEqual(expected_unique, x_unique) - x_unique, x_counts = x.unique(sorted=True, return_counts=True) - self.assertEqual(expected_counts, x_counts) - x_unique, x_inverse = torch.unique( x, sorted=True, return_inverse=True) self.assertEqual(expected_unique, x_unique) self.assertEqual(expected_inverse, x_inverse) - x_unique, x_inverse, x_counts = torch.unique( - x, sorted=True, return_inverse=True, return_counts=True) - self.assertEqual(expected_unique, x_unique) - self.assertEqual(expected_inverse, x_inverse) - self.assertEqual(expected_counts, x_counts) - # Tests per-element unique on a higher rank tensor. y = x.view(2, 2, 2) y_unique, y_inverse = y.unique(sorted=True, return_inverse=True) self.assertEqual(expected_unique, y_unique) self.assertEqual(expected_inverse.view(y.size()), y_inverse) - y_unique, y_inverse, y_counts = y.unique( - sorted=True, return_inverse=True, return_counts=True) - self.assertEqual(expected_unique, y_unique) - self.assertEqual(expected_inverse.view(y.size()), y_inverse) - self.assertEqual(expected_counts, y_counts) - # Tests unique on other types. - int_unique, int_inverse, int_counts = torch.unique( - torch.IntTensor([2, 1, 2]), - sorted=True, - return_inverse=True, - return_counts=True - ) + int_unique, int_inverse = torch.unique( + torch.IntTensor([2, 1, 2]), sorted=True, return_inverse=True) self.assertEqual(torch.IntTensor([1, 2]), int_unique) self.assertEqual(torch.LongTensor([1, 0, 1]), int_inverse) - self.assertEqual(torch.LongTensor([1, 2]), int_counts) - double_unique, double_inverse, double_counts = torch.unique( + double_unique, double_inverse = torch.unique( torch.DoubleTensor([2., 1.5, 2.1, 2.]), sorted=True, return_inverse=True, - return_counts=True ) self.assertEqual(torch.DoubleTensor([1.5, 2., 2.1]), double_unique) self.assertEqual(torch.LongTensor([1, 0, 2, 1]), double_inverse) - self.assertEqual(torch.LongTensor([1, 2, 1]), double_counts) - byte_unique, byte_inverse, byte_counts = torch.unique( + byte_unique, byte_inverse = torch.unique( torch.ByteTensor([133, 7, 7, 7, 42, 128]), sorted=True, return_inverse=True, - return_counts=True ) self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique) self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse) - self.assertEqual(torch.LongTensor([3, 1, 1, 1]), byte_counts) def test_unique_dim(self): def run_test(dtype=torch.float): @@ -10448,7 +10423,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], [2., 1.], [0., 1.]]], dtype=dtype) expected_inverse_dim0 = torch.tensor([0, 0]) - expected_counts_dim0 = torch.tensor([2]) expected_unique_dim1 = torch.tensor([[[0., 1.], [1., 1.], [2., 1.]], @@ -10456,7 +10430,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], [1., 1.], [2., 1.]]], dtype=dtype) expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) - expected_counts_dim1 = torch.tensor([2, 1, 1]) expected_unique_dim2 = torch.tensor([[[1., 1.], [0., 1.], [2., 1.], @@ -10466,94 +10439,30 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], [2., 1.], [0., 1.]]], dtype=dtype) expected_inverse_dim2 = torch.tensor([0, 1]) - expected_counts_dim2 = torch.tensor([1, 1]) # dim0 x_unique = torch.unique(x, dim=0) self.assertEqual(expected_unique_dim0, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - return_counts=False, - dim=0) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) - x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=0) - self.assertEqual(expected_unique_dim0, x_unique) - self.assertEqual(expected_counts_dim0, x_counts) - - x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=0) - self.assertEqual(expected_unique_dim0, x_unique) - self.assertEqual(expected_inverse_dim0, x_inverse) - self.assertEqual(expected_counts_dim0, x_counts) - # dim1 x_unique = torch.unique(x, dim=1) self.assertEqual(expected_unique_dim1, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - return_counts=False, - dim=1) - self.assertEqual(expected_unique_dim1, x_unique) - self.assertEqual(expected_inverse_dim1, x_inverse) - - x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=1) - self.assertEqual(expected_unique_dim1, x_unique) - self.assertEqual(expected_counts_dim1, x_counts) - - x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=1) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1) self.assertEqual(expected_unique_dim1, x_unique) self.assertEqual(expected_inverse_dim1, x_inverse) - self.assertEqual(expected_counts_dim1, x_counts) # dim2 x_unique = torch.unique(x, dim=2) self.assertEqual(expected_unique_dim2, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - return_counts=False, - dim=2) - self.assertEqual(expected_unique_dim2, x_unique) - self.assertEqual(expected_inverse_dim2, x_inverse) - - x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=2) - self.assertEqual(expected_unique_dim2, x_unique) - self.assertEqual(expected_counts_dim2, x_counts) - - x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=2) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) - self.assertEqual(expected_counts_dim2, x_counts) run_test(torch.float) run_test(torch.double) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 0a460a2..4a59185 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -871,7 +871,7 @@ - name: uniform_(Tensor self, double from, double to, Generator generator) self: zeros_like(grad) -- name: _unique(Tensor self, bool sorted, bool return_inverse, bool return_counts) +- name: _unique(Tensor self, bool sorted, bool return_inverse) self: not_implemented("_unique") - name: _unsafe_view(Tensor self, IntArrayRef size) diff --git a/torch/functional.py b/torch/functional.py index 9cd9e1c..580227c 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -374,8 +374,8 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None, return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided) -def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None): - r"""Returns the unique elements of the input tensor. +def unique(input, sorted=True, return_inverse=False, dim=None): + r"""Returns the unique scalar elements of the input tensor as a 1-D tensor. Arguments: input (Tensor): the input tensor @@ -383,26 +383,18 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No before returning as output. return_inverse (bool): Whether to also return the indices for where elements in the original input ended up in the returned unique list. - return_counts (bool): Whether to also return the counts for each unique - element. dim (int): the dimension to apply unique. If ``None``, the unique of the flattened input is returned. default: ``None`` Returns: - (Tensor, Tensor (optional) Tensor (optional)): - A tensor or a tuple of tensors containing + (Tensor, Tensor (optional)): A tensor or a tuple of tensors containing - **output** (*Tensor*): the output list of unique scalar elements. - **inverse_indices** (*Tensor*): (optional) if - :attr:`return_inverse` is True, there will be an additional - returned tensor (same shape as input) representing the indices + :attr:`return_inverse` is True, there will be a + 2nd returned tensor (same shape as input) representing the indices for where elements in the original input map to in the output; otherwise, this function will only return a single tensor. - - **counts** (*Tensor*): (optional) if - :attr:`return_counts` is True, there will be an additional - returned tensor (same shape as output or output.size(dim), - if dim was specified) representing the number of occurences - for each unique value or tensor. Example:: @@ -427,26 +419,20 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No """ if dim is not None: - output, inverse_indices, counts = torch._unique_dim( + output, inverse_indices = torch._unique_dim( input, dim, sorted=sorted, - return_inverse=return_inverse, - return_counts=return_counts + return_inverse=return_inverse ) else: - output, inverse_indices, counts = torch._unique( + output, inverse_indices = torch._unique( input, sorted=sorted, return_inverse=return_inverse, - return_counts=return_counts ) - if return_inverse and return_counts: - return output, inverse_indices, counts - elif return_inverse: + if return_inverse: return output, inverse_indices - elif return_counts: - return output, counts else: return output diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 288f362..157b895 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -1205,11 +1205,10 @@ def conv_tbc(g, input, weight, bias, pad): return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad) -@parse_args('v', 'i', 'i', 'i') -def _unique(g, input, sorted, return_inverse, return_counts): +@parse_args('v', 'i', 'i') +def _unique(g, input, sorted, return_inverse): return g.op("ATen", input, operator_s="_unique", sorted_i=sorted, - return_inverse_i=return_inverse, return_counts_i=return_counts, - outputs=3) + return_inverse_i=return_inverse, outputs=2) # Metaprogram symbolics for each ATen native specialized cast operator. diff --git a/torch/tensor.py b/torch/tensor.py index f1ec022..bf239b3 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -315,32 +315,26 @@ class Tensor(torch._C._TensorBase): else: return super(Tensor, self).split_with_sizes(split_size, dim) - def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None): + def unique(self, sorted=True, return_inverse=False, dim=None): r"""Returns the unique scalar elements of the tensor as a 1-D tensor. See :func:`torch.unique` """ if dim is not None: - output, inverse_indices, counts = torch._unique_dim( + output, inverse_indices = torch._unique_dim( self, sorted=sorted, return_inverse=return_inverse, - return_counts=return_counts, dim=dim ) else: - output, inverse_indices, counts = torch._unique( + output, inverse_indices = torch._unique( self, sorted=sorted, - return_inverse=return_inverse, - return_counts=return_counts + return_inverse=return_inverse ) - if return_inverse and return_counts: - return output, inverse_indices, counts - elif return_inverse: + if return_inverse: return output, inverse_indices - elif return_counts: - return output, counts else: return output -- 2.7.4