Revert D14605905: [pytorch][PR] Add return_counts to torch.unique
authorSoumith Chintala <soumith@fb.com>
Wed, 27 Mar 2019 00:14:26 +0000 (17:14 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 27 Mar 2019 00:18:01 +0000 (17:18 -0700)
Differential Revision:
D14605905

Original commit changeset: 555f5a12a8e2

fbshipit-source-id: c7874f5987893e956c022180a37763d88bba38db

aten/src/ATen/native/Unique.cpp
aten/src/ATen/native/cuda/Unique.cu
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/sparse/SparseTensor.cpp
test/test_torch.py
tools/autograd/derivatives.yaml
torch/functional.py
torch/onnx/symbolic.py
torch/tensor.py

index cd59f18..8cc867f 100644 (file)
@@ -14,11 +14,10 @@ namespace native{
 namespace {
 
 template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
+std::tuple<Tensor, Tensor> _unique_cpu_template(
     const Tensor& self,
     const bool sorted,
-    const bool return_inverse,
-    const bool return_counts) {
+    const bool return_inverse) {
   const Tensor& input = self.contiguous();
   const scalar_t* input_data = input.data<scalar_t>();
   std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
@@ -34,8 +33,7 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
   }
 
   Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
-  Tensor counts = at::empty({0}, self.options().dtype(kLong));
-  if (return_inverse || return_counts) {
+  if (return_inverse) {
     inverse_indices.resize_(input.sizes());
     int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
     std::unordered_map<scalar_t, int64_t> inverse_map;
@@ -46,29 +44,21 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
     for (int i = 0; i < input.numel(); ++i) {
       inverse_indices_data[i] = inverse_map[input_data[i]];
     }
-    if (return_counts) {
-      counts.resize_(output.sizes());
-      counts.fill_(0);
-      for (int i = 0; i < input.numel(); ++i) {
-        counts[inverse_map[input_data[i]]] += 1;
-      }
-    }
   }
-  return std::make_tuple(output, inverse_indices, counts);
+  return std::make_tuple(output, inverse_indices);
 }
 
 template<class ForwardIt>
 ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
-  std::vector<int64_t>& indices, Tensor inverse_indices_vec, Tensor counts) {
+  std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
     if (first == last) {
       return last;
     }
     // save to calculate distance to iterators
     ForwardIt begin = first;
 
-    // set first inverse index and count
+    // set first inverse index
     inverse_indices_vec[indices[0]] = 0;
-    counts[0] += 1;
 
     ForwardIt result = first;
     while (++first != last) {
@@ -78,18 +68,16 @@ ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
       int64_t idx_result = std::distance(begin, result);
       int64_t idx_first = std::distance(begin, first);
       inverse_indices_vec[indices[idx_first]] = idx_result;
-      counts[idx_result] += 1;
     }
 
     return ++result;
   }
 
 template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
+std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
     const Tensor& self,
     const int64_t dim,
-    const bool return_inverse,
-    const bool return_counts) {
+    const bool return_inverse) {
   // reshape tensor as [dim, -1]
   Tensor input_flat = self.transpose(dim, 0);
   auto orig_sizes = input_flat.sizes().vec();
@@ -121,12 +109,10 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
   }
 
   Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
-  Tensor counts = at::zeros(indices.size(), self.options().dtype(kLong));
   std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
   auto last = _unique_dim_cpu_impl(
-    input_unbind.begin(), input_unbind.end(), indices, inverse_indices, counts);
+    input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
   input_unbind.erase(last, input_unbind.end());
-  counts = at::narrow(counts, 0, 0, input_unbind.size());
 
   // reshape back
   auto output = at::stack(input_unbind, 0);
@@ -135,23 +121,22 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
   output = output.view(new_sizes);
   output = output.transpose(0, dim);
 
-  return std::make_tuple(output, inverse_indices, counts);
+  return std::make_tuple(output, inverse_indices);
 }
 } // namespace
 
-
-std::tuple<Tensor, Tensor, Tensor>
-_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
-  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
-    return _unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
+std::tuple<Tensor, Tensor>
+_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] {
+    return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
   });
 }
 
-std::tuple<Tensor, Tensor, Tensor>
-_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
+std::tuple<Tensor, Tensor>
+_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     // The current implementation using `dim` always sorts due to unhashable tensors
-    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, return_counts);
+    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
   });
 }
 
index 204c9fc..0ba6812 100644 (file)
@@ -16,10 +16,9 @@ namespace native{
 namespace {
 
 template <typename scalar_t>
-  std::tuple<Tensor, Tensor, Tensor> _unique_cuda_template(
+  std::tuple<Tensor, Tensor> _unique_cuda_template(
     const Tensor& self,
-    const bool return_inverse,
-    const bool return_counts) {
+    const bool return_inverse) {
 
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
@@ -29,7 +28,7 @@ template <typename scalar_t>
     int64_t num_inp = input.numel();
     const scalar_t* input_data = input.data<scalar_t>();
 
-    //sort
+    //sort & unique
     Tensor output = input.clone();
     output = output.view(-1);
     scalar_t* output_data = output.data<scalar_t>();
@@ -48,36 +47,21 @@ template <typename scalar_t>
         thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr, [=] __device__ (scalar_t a, scalar_t b) -> int64_t { if (a != b) {return 1;} else { return 0; }});
         inv_loc[0] = 0;
         thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
-        thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
+        thrust::scatter(policy,inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
         inverse_indices.resize_(input.sizes());
     }
-
-    // unique
-    Tensor counts = at::empty({0}, self.options().dtype(kLong));
-    if (!return_counts) {
-      int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
-      output.resize_(num_out);
-    } else {
-      Tensor sorted_indices = at::arange(0, num_inp + 1, self.type().toScalarType(kLong));
-      int64_t* sorted_indices_ptr = sorted_indices.data<int64_t>();
-      int64_t num_out = thrust::unique_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr).first - output_data;
-      sorted_indices[num_out] = num_inp;
-      output.resize_(num_out);
-      counts.resize_(num_out);
-      int64_t* counts_ptr = counts.data<int64_t>();
-      thrust::adjacent_difference(policy, sorted_indices_ptr + 1, sorted_indices_ptr + num_out + 1, counts_ptr);
-    }
+    int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
+    output.resize_(num_out);
 
     THCudaCheck(cudaGetLastError());
-    return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
+    return std::tuple<Tensor, Tensor>(output, inverse_indices);
   }
 
 template <typename scalar_t>
-  std::tuple<Tensor, Tensor, Tensor> _unique_dim_cuda_template(
+  std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
     const Tensor& self,
     const int64_t dim,
-    const bool return_inverse,
-    const bool return_counts) {
+    const bool return_inverse) {
 
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
@@ -89,7 +73,7 @@ template <typename scalar_t>
 
     scalar_t* input_flat_ptr = input_flat.data<scalar_t>();
 
-    Tensor indices = at::arange(0, input_flat.size(0), self.options().dtype(kLong));
+    Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
     int64_t* indices_ptr = indices.data<int64_t>();
     int64_t numel = input_flat.size(1);
 
@@ -112,7 +96,7 @@ template <typename scalar_t>
 
     // get unique tensors
     scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
-    Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.options().dtype(kLong));
+    Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
     int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
     auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
       [=] __device__ (int64_t a, int64_t b) -> bool {
@@ -134,13 +118,12 @@ template <typename scalar_t>
     output = output.view(new_sizes);
     output = output.transpose(0, dim);
 
-    // calculate inverse indices and counts
-    Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
-    Tensor counts = at::zeros(output.size(dim), self.options().dtype(kLong));
-    if (return_inverse || return_counts) {
+    // calculate inverse indices
+    Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
+    if (return_inverse) {
       int64_t size = self.size(dim);
       inverse_indices.resize_(size);
-      Tensor mask = at::empty(input_sorted.size(0), self.options().dtype(kLong));
+      Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
       mask[0] = 1;
       for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
         if (!at::equal(input_sorted[i], input_sorted[i+1])) {
@@ -153,29 +136,27 @@ template <typename scalar_t>
       Tensor imask = at::cumsum(mask, 0) - 1;
       for (int i = 0; i < indices.size(0); ++i) {
         inverse_indices[indices[i]] = imask[i];
-        counts[inverse_indices[indices[i]]] += 1;
       }
     }
 
     THCudaCheck(cudaGetLastError());
-    return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
+    return std::tuple<Tensor, Tensor>(output, inverse_indices);
   }
 } // namespace
 
-
-std::tuple<Tensor, Tensor, Tensor>
-_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
-  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+std::tuple<Tensor, Tensor>
+_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
-    return _unique_cuda_template<scalar_t>(self, return_inverse, return_counts);
+    return _unique_cuda_template<scalar_t>(self, return_inverse);
   });
 }
 
-std::tuple<Tensor, Tensor, Tensor>
-_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
+std::tuple<Tensor, Tensor>
+_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
-    return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, return_counts);
+    return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
   });
 }
 
index 0d4da52..152a2ae 100644 (file)
   matches_jit_signature: True
   variants: method
 
-- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function
   dispatch:
     CPU: _unique_cpu
     CUDA: _unique_cuda
 
-- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function
   dispatch:
index 25389a5..d3278c6 100644 (file)
@@ -301,7 +301,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
     indices = nz.clone();
   } else {
     Tensor i = nz.narrow(0, 0, sparse_dim);
-    std::tie(indices, std::ignore, std::ignore) = _unique_dim(i, 1);
+    std::tie(indices, std::ignore) = _unique_dim(i, 1);
     indices = indices.contiguous();  // many sparse CUDA kernels require contiguity, see issue #12633
   }
 
index f895516..1718ec4 100644 (file)
@@ -10362,7 +10362,6 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3])
         expected_unique = torch.LongTensor([1, 2, 3, 5, 8])
         expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2])
-        expected_counts = torch.LongTensor([1, 3, 2, 1, 1])
 
         x_unique = torch.unique(x)
         self.assertEqual(
@@ -10376,62 +10375,38 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         x_unique = x.unique(sorted=True)
         self.assertEqual(expected_unique, x_unique)
 
-        x_unique, x_counts = x.unique(sorted=True, return_counts=True)
-        self.assertEqual(expected_counts, x_counts)
-
         x_unique, x_inverse = torch.unique(
             x, sorted=True, return_inverse=True)
         self.assertEqual(expected_unique, x_unique)
         self.assertEqual(expected_inverse, x_inverse)
 
-        x_unique, x_inverse, x_counts = torch.unique(
-            x, sorted=True, return_inverse=True, return_counts=True)
-        self.assertEqual(expected_unique, x_unique)
-        self.assertEqual(expected_inverse, x_inverse)
-        self.assertEqual(expected_counts, x_counts)
-
         # Tests per-element unique on a higher rank tensor.
         y = x.view(2, 2, 2)
         y_unique, y_inverse = y.unique(sorted=True, return_inverse=True)
         self.assertEqual(expected_unique, y_unique)
         self.assertEqual(expected_inverse.view(y.size()), y_inverse)
 
-        y_unique, y_inverse, y_counts = y.unique(
-            sorted=True, return_inverse=True, return_counts=True)
-        self.assertEqual(expected_unique, y_unique)
-        self.assertEqual(expected_inverse.view(y.size()), y_inverse)
-        self.assertEqual(expected_counts, y_counts)
-
         # Tests unique on other types.
-        int_unique, int_inverse, int_counts = torch.unique(
-            torch.IntTensor([2, 1, 2]),
-            sorted=True,
-            return_inverse=True,
-            return_counts=True
-        )
+        int_unique, int_inverse = torch.unique(
+            torch.IntTensor([2, 1, 2]), sorted=True, return_inverse=True)
         self.assertEqual(torch.IntTensor([1, 2]), int_unique)
         self.assertEqual(torch.LongTensor([1, 0, 1]), int_inverse)
-        self.assertEqual(torch.LongTensor([1, 2]), int_counts)
 
-        double_unique, double_inverse, double_counts = torch.unique(
+        double_unique, double_inverse = torch.unique(
             torch.DoubleTensor([2., 1.5, 2.1, 2.]),
             sorted=True,
             return_inverse=True,
-            return_counts=True
         )
         self.assertEqual(torch.DoubleTensor([1.5, 2., 2.1]), double_unique)
         self.assertEqual(torch.LongTensor([1, 0, 2, 1]), double_inverse)
-        self.assertEqual(torch.LongTensor([1, 2, 1]), double_counts)
 
-        byte_unique, byte_inverse, byte_counts = torch.unique(
+        byte_unique, byte_inverse = torch.unique(
             torch.ByteTensor([133, 7, 7, 7, 42, 128]),
             sorted=True,
             return_inverse=True,
-            return_counts=True
         )
         self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
         self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)
-        self.assertEqual(torch.LongTensor([3, 1, 1, 1]), byte_counts)
 
     def test_unique_dim(self):
         def run_test(dtype=torch.float):
@@ -10448,7 +10423,6 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
                                                   [2., 1.],
                                                   [0., 1.]]], dtype=dtype)
             expected_inverse_dim0 = torch.tensor([0, 0])
-            expected_counts_dim0 = torch.tensor([2])
             expected_unique_dim1 = torch.tensor([[[0., 1.],
                                                   [1., 1.],
                                                   [2., 1.]],
@@ -10456,7 +10430,6 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
                                                   [1., 1.],
                                                   [2., 1.]]], dtype=dtype)
             expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
-            expected_counts_dim1 = torch.tensor([2, 1, 1])
             expected_unique_dim2 = torch.tensor([[[1., 1.],
                                                   [0., 1.],
                                                   [2., 1.],
@@ -10466,94 +10439,30 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
                                                   [2., 1.],
                                                   [0., 1.]]], dtype=dtype)
             expected_inverse_dim2 = torch.tensor([0, 1])
-            expected_counts_dim2 = torch.tensor([1, 1])
 
             # dim0
             x_unique = torch.unique(x, dim=0)
             self.assertEqual(expected_unique_dim0, x_unique)
 
-            x_unique, x_inverse = torch.unique(
-                x,
-                return_inverse=True,
-                return_counts=False,
-                dim=0)
+            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
             self.assertEqual(expected_unique_dim0, x_unique)
             self.assertEqual(expected_inverse_dim0, x_inverse)
 
-            x_unique, x_counts = torch.unique(
-                x,
-                return_inverse=False,
-                return_counts=True,
-                dim=0)
-            self.assertEqual(expected_unique_dim0, x_unique)
-            self.assertEqual(expected_counts_dim0, x_counts)
-
-            x_unique, x_inverse, x_counts = torch.unique(
-                x,
-                return_inverse=True,
-                return_counts=True,
-                dim=0)
-            self.assertEqual(expected_unique_dim0, x_unique)
-            self.assertEqual(expected_inverse_dim0, x_inverse)
-            self.assertEqual(expected_counts_dim0, x_counts)
-
             # dim1
             x_unique = torch.unique(x, dim=1)
             self.assertEqual(expected_unique_dim1, x_unique)
 
-            x_unique, x_inverse = torch.unique(
-                x,
-                return_inverse=True,
-                return_counts=False,
-                dim=1)
-            self.assertEqual(expected_unique_dim1, x_unique)
-            self.assertEqual(expected_inverse_dim1, x_inverse)
-
-            x_unique, x_counts = torch.unique(
-                x,
-                return_inverse=False,
-                return_counts=True,
-                dim=1)
-            self.assertEqual(expected_unique_dim1, x_unique)
-            self.assertEqual(expected_counts_dim1, x_counts)
-
-            x_unique, x_inverse, x_counts = torch.unique(
-                x,
-                return_inverse=True,
-                return_counts=True,
-                dim=1)
+            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
             self.assertEqual(expected_unique_dim1, x_unique)
             self.assertEqual(expected_inverse_dim1, x_inverse)
-            self.assertEqual(expected_counts_dim1, x_counts)
 
             # dim2
             x_unique = torch.unique(x, dim=2)
             self.assertEqual(expected_unique_dim2, x_unique)
 
-            x_unique, x_inverse = torch.unique(
-                x,
-                return_inverse=True,
-                return_counts=False,
-                dim=2)
-            self.assertEqual(expected_unique_dim2, x_unique)
-            self.assertEqual(expected_inverse_dim2, x_inverse)
-
-            x_unique, x_counts = torch.unique(
-                x,
-                return_inverse=False,
-                return_counts=True,
-                dim=2)
-            self.assertEqual(expected_unique_dim2, x_unique)
-            self.assertEqual(expected_counts_dim2, x_counts)
-
-            x_unique, x_inverse, x_counts = torch.unique(
-                x,
-                return_inverse=True,
-                return_counts=True,
-                dim=2)
+            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
             self.assertEqual(expected_unique_dim2, x_unique)
             self.assertEqual(expected_inverse_dim2, x_inverse)
-            self.assertEqual(expected_counts_dim2, x_counts)
 
         run_test(torch.float)
         run_test(torch.double)
index 0a460a2..4a59185 100644 (file)
 - name: uniform_(Tensor self, double from, double to, Generator generator)
   self: zeros_like(grad)
 
-- name: _unique(Tensor self, bool sorted, bool return_inverse, bool return_counts)
+- name: _unique(Tensor self, bool sorted, bool return_inverse)
   self: not_implemented("_unique")
 
 - name: _unsafe_view(Tensor self, IntArrayRef size)
index 9cd9e1c..580227c 100644 (file)
@@ -374,8 +374,8 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None,
     return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
 
 
-def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
-    r"""Returns the unique elements of the input tensor.
+def unique(input, sorted=True, return_inverse=False, dim=None):
+    r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
 
     Arguments:
         input (Tensor): the input tensor
@@ -383,26 +383,18 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No
             before returning as output.
         return_inverse (bool): Whether to also return the indices for where
             elements in the original input ended up in the returned unique list.
-        return_counts (bool): Whether to also return the counts for each unique
-            element.
         dim (int): the dimension to apply unique. If ``None``, the unique of the
             flattened input is returned. default: ``None``
 
     Returns:
-        (Tensor, Tensor (optional) Tensor (optional)):
-        A tensor or a tuple of tensors containing
+        (Tensor, Tensor (optional)): A tensor or a tuple of tensors containing
 
             - **output** (*Tensor*): the output list of unique scalar elements.
             - **inverse_indices** (*Tensor*): (optional) if
-              :attr:`return_inverse` is True, there will be an additional
-              returned tensor (same shape as input) representing the indices
+              :attr:`return_inverse` is True, there will be a
+              2nd returned tensor (same shape as input) representing the indices
               for where elements in the original input map to in the output;
               otherwise, this function will only return a single tensor.
-              - **counts** (*Tensor*): (optional) if
-              :attr:`return_counts` is True, there will be an additional
-              returned tensor (same shape as output or output.size(dim),
-              if dim was specified) representing the number of occurences
-              for each unique value or tensor.
 
     Example::
 
@@ -427,26 +419,20 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No
 
     """
     if dim is not None:
-        output, inverse_indices, counts = torch._unique_dim(
+        output, inverse_indices = torch._unique_dim(
             input,
             dim,
             sorted=sorted,
-            return_inverse=return_inverse,
-            return_counts=return_counts
+            return_inverse=return_inverse
         )
     else:
-        output, inverse_indices, counts = torch._unique(
+        output, inverse_indices = torch._unique(
             input,
             sorted=sorted,
             return_inverse=return_inverse,
-            return_counts=return_counts
         )
-    if return_inverse and return_counts:
-        return output, inverse_indices, counts
-    elif return_inverse:
+    if return_inverse:
         return output, inverse_indices
-    elif return_counts:
-        return output, counts
     else:
         return output
 
index 288f362..157b895 100644 (file)
@@ -1205,11 +1205,10 @@ def conv_tbc(g, input, weight, bias, pad):
     return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
 
 
-@parse_args('v', 'i', 'i', 'i')
-def _unique(g, input, sorted, return_inverse, return_counts):
+@parse_args('v', 'i', 'i')
+def _unique(g, input, sorted, return_inverse):
     return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
-                return_inverse_i=return_inverse, return_counts_i=return_counts,
-                outputs=3)
+                return_inverse_i=return_inverse, outputs=2)
 
 
 # Metaprogram symbolics for each ATen native specialized cast operator.
index f1ec022..bf239b3 100644 (file)
@@ -315,32 +315,26 @@ class Tensor(torch._C._TensorBase):
         else:
             return super(Tensor, self).split_with_sizes(split_size, dim)
 
-    def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
+    def unique(self, sorted=True, return_inverse=False, dim=None):
         r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
 
         See :func:`torch.unique`
         """
         if dim is not None:
-            output, inverse_indices, counts = torch._unique_dim(
+            output, inverse_indices = torch._unique_dim(
                 self,
                 sorted=sorted,
                 return_inverse=return_inverse,
-                return_counts=return_counts,
                 dim=dim
             )
         else:
-            output, inverse_indices, counts = torch._unique(
+            output, inverse_indices = torch._unique(
                 self,
                 sorted=sorted,
-                return_inverse=return_inverse,
-                return_counts=return_counts
+                return_inverse=return_inverse
             )
-        if return_inverse and return_counts:
-            return output, inverse_indices, counts
-        elif return_inverse:
+        if return_inverse:
             return output, inverse_indices
-        elif return_counts:
-            return output, counts
         else:
             return output