Add torch.unique_consecutive (#19060)
authorXiang Gao <qasdfgtyuiop@gmail.com>
Wed, 10 Apr 2019 14:33:15 +0000 (07:33 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 10 Apr 2019 14:36:08 +0000 (07:36 -0700)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/19045

Please review: VitalyFedyunin ngimel

This is independent on the #18649 series. This will cause merge conflicts in #18649 series, but please merge this first, and I will resolve the merge conflicts there.

The new feature is exposed in `_unique2_temporary_will_remove_soon` and `_unique_dim2_temporary_will_remove_soon`. But not at `torch.unique` yet. I will take care of the API after #18649 series get merged completely.

Benchmark on a tensor of shape `torch.Size([15320, 2])`:

```python
print(torch.__version__)
print()
a = tensor.sort().values.to('cpu')
print('cpu, sorted_input=False:')
%timeit torch._unique2_temporary_will_remove_soon(a)
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True)
%timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True)
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True)
print()
print('cpu, sorted_input=True:')
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True)
print()
a = a.to('cuda')
print('cuda, sorted_input=False:')
%timeit torch._unique2_temporary_will_remove_soon(a); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True); torch.cuda.synchronize()
print()
print('cuda, sorted_input=True:')
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```

```
1.1.0a0+2addccc

cpu, sorted_input=False:
340 µs ± 5.88 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
717 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
52.3 ms ± 2.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
52.3 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

cpu, sorted_input=True:
32.8 µs ± 285 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
49.9 µs ± 557 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
51.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
78 µs ± 782 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

cuda, sorted_input=False:
213 µs ± 1.52 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
291 µs ± 3.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
250 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
321 µs ± 1.59 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

cuda, sorted_input=True:
45.6 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
110 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
82 µs ± 857 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
143 µs ± 409 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

```python
print(torch.__version__)
print()
a1, a2 = tensor.unbind(1)
indices = (a1 * tensor.max() + a2).sort().indices
a = tensor.index_select(0, indices).to('cpu')
print('cpu, sorted_input=False:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True)
print()
print('cpu, sorted_input=True:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True, return_counts=True)
print()
a = a.to('cuda')
print('cuda, sorted_input=False:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True); torch.cuda.synchronize()
print()
print('cuda, sorted_input=True:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```

```
cpu, sorted_input=False:
55.4 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.8 ms ± 616 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.2 ms ± 402 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.1 ms ± 725 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

cpu, sorted_input=True:
54.7 ms ± 585 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.2 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.5 ms ± 865 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.9 ms ± 577 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

cuda, sorted_input=False:
171 µs ± 783 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
220 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
203 µs ± 2.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
251 µs ± 2.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

cuda, sorted_input=True:
59.6 µs ± 757 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
113 µs ± 431 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
93.2 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
147 µs ± 2.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
The CPU implementation of `unique_dim` is super slow, see https://github.com/pytorch/pytorch/issues/18987, but this PR will not worry about this issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19060

Differential Revision: D14866909

Pulled By: ezyang

fbshipit-source-id: d20012cec68c37b05cf770a6f4d6524f910b950f

aten/src/ATen/native/Unique.cpp
aten/src/ATen/native/cuda/Unique.cu
aten/src/ATen/native/native_functions.yaml
docs/source/tensors.rst
docs/source/torch.rst
test/test_torch.py
tools/autograd/gen_python_functions.py
tools/pyi/gen_pyi.py
torch/__init__.pyi.in
torch/functional.py
torch/tensor.py

index 1dcf85d..2888097 100644 (file)
@@ -14,16 +14,21 @@ namespace native{
 namespace {
 
 template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
+std::tuple<Tensor, Tensor, Tensor> unique_cpu_template(
     const Tensor& self,
     const bool sorted,
     const bool return_inverse,
     const bool return_counts) {
   const Tensor& input = self.contiguous();
   const scalar_t* input_data = input.data<scalar_t>();
-  std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
-  Tensor output = at::empty({static_cast<int64_t>(set.size())}, input.options());
-  scalar_t* output_data = output.data<scalar_t>();
+  int64_t numel = input.numel();
+  Tensor output;
+  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
+  Tensor counts = at::empty({0}, self.options().dtype(kLong));
+
+  std::unordered_set<scalar_t> set(input_data, input_data + numel);
+  output = at::empty({static_cast<int64_t>(set.size())}, input.options());
+  scalar_t *output_data = output.data<scalar_t>();
 
   if (sorted) {
     std::vector<scalar_t> vec(set.begin(), set.end());
@@ -33,8 +38,6 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
     std::copy(set.begin(), set.end(), output_data);
   }
 
-  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
-  Tensor counts = at::empty({0}, self.options().dtype(kLong));
   if (return_inverse || return_counts) {
     inverse_indices.resize_(input.sizes());
     int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
@@ -43,13 +46,13 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
     for (int i = 0; i < output.numel(); ++i) {
       inverse_map[output_data[i]] = i;
     }
-    for (int i = 0; i < input.numel(); ++i) {
+    for (int i = 0; i < numel; ++i) {
       inverse_indices_data[i] = inverse_map[input_data[i]];
     }
     if (return_counts) {
       counts.resize_(output.sizes());
       counts.fill_(0);
-      for (int i = 0; i < input.numel(); ++i) {
+      for (int i = 0; i < numel; ++i) {
         counts[inverse_map[input_data[i]]] += 1;
       }
     }
@@ -57,6 +60,57 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
   return std::make_tuple(output, inverse_indices, counts);
 }
 
+template <typename scalar_t>
+std::tuple<Tensor, Tensor, Tensor> unique_consecutive_cpu_template(
+    const Tensor& self,
+    const bool return_inverse,
+    const bool return_counts) {
+  const Tensor& input = self.contiguous();
+  const scalar_t* input_data = input.data<scalar_t>();
+  int64_t numel = input.numel();
+  Tensor output = at::empty({numel}, input.options());
+  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
+  Tensor counts = at::empty({0}, self.options().dtype(kLong));
+
+  scalar_t *output_data = output.data<scalar_t>();
+  int64_t *inverse_data = nullptr;
+  int64_t *counts_data = nullptr;
+  if (numel > 0) {
+    *output_data = *input_data;
+  }
+  if (return_inverse) {
+    inverse_indices.resize_(input.sizes());
+    inverse_data = inverse_indices.data<int64_t>();
+  }
+  if (return_counts) {
+    counts.resize_(input.sizes());
+    counts_data = counts.data<int64_t>();
+  }
+  scalar_t *p = output_data;
+  int64_t *q = counts_data;
+  int64_t last = 0;
+  for (int64_t i = 0; i < numel; i++) {
+    if (input_data[i] != *p) {
+      *(++p) = input_data[i];
+      if (return_counts) {
+        *(q++) = i - last;
+        last = i;
+      }
+    }
+    if (return_inverse) {
+      inverse_data[i] = p - output_data;
+    }
+  }
+  int64_t output_size = p - output_data + 1;
+  if (return_counts && numel > 0) {
+    *q = numel - last;
+    counts.resize_({output_size});
+  }
+  output.resize_({output_size});
+
+  return std::make_tuple(output, inverse_indices, counts);
+}
+
 template<class ForwardIt>
 ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
   std::vector<int64_t>& indices, Tensor inverse_indices_vec, Tensor counts) {
@@ -88,6 +142,7 @@ template <typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
     const Tensor& self,
     const int64_t dim,
+    const bool consecutive,
     const bool return_inverse,
     const bool return_counts) {
   // reshape tensor as [dim, -1]
@@ -101,23 +156,30 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
   scalar_t* input_flat_ptr = ((scalar_t*)input_flat.data_ptr());
 
   // sort indices using data
-  std::sort(indices.begin(), indices.end(),
-    [&](int64_t a, int64_t b) -> bool {
-      for (int64_t i = 0; i < numel; ++i) {
-        scalar_t lhs = input_flat_ptr[i + a * numel];
-        scalar_t rhs = input_flat_ptr[i + b * numel];
-        if (lhs < rhs) {
-          return true;
-        } else if (lhs > rhs) {
-          return false;
+  if (!consecutive) {
+    std::sort(indices.begin(), indices.end(),
+      [&](int64_t a, int64_t b) -> bool {
+        for (int64_t i = 0; i < numel; ++i) {
+          scalar_t lhs = input_flat_ptr[i + a * numel];
+          scalar_t rhs = input_flat_ptr[i + b * numel];
+          if (lhs < rhs) {
+            return true;
+          } else if (lhs > rhs) {
+            return false;
+          }
         }
-      }
-      return false;
-    });
+        return false;
+      });
+  }
 
-  Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.options());
-  for (int i = 0; i < indices.size(); ++i) {
-    input_sorted[i] = input_flat[indices[i]];
+  Tensor input_sorted;
+  if (!consecutive) {
+    input_sorted = at::empty(input_flat.sizes(), input_flat.options());
+    for (int i = 0; i < indices.size(); ++i) {
+      input_sorted[i] = input_flat[indices[i]];
+    }
+  } else {
+    input_sorted = input_flat;
   }
 
   Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
@@ -137,6 +199,7 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
 
   return std::make_tuple(output, inverse_indices, counts);
 }
+
 } // namespace
 
 
@@ -144,7 +207,7 @@ std::tuple<Tensor, Tensor>
 _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
     Tensor output, inverse;
-    std::tie(output, inverse, std::ignore) = _unique_cpu_template<scalar_t>(self, sorted, return_inverse, false);
+    std::tie(output, inverse, std::ignore) = unique_cpu_template<scalar_t>(self, sorted, return_inverse, false);
     return std::make_tuple(output, inverse);
   });
 }
@@ -152,7 +215,7 @@ _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
 std::tuple<Tensor, Tensor, Tensor>
 _unique2_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
-    return _unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
+    return unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
   });
 }
 
@@ -161,7 +224,7 @@ _unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     // The current implementation using `dim` always sorts due to unhashable tensors
     Tensor output, inverse;
-    std::tie(output, inverse, std::ignore) = _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, false);
+    std::tie(output, inverse, std::ignore) = _unique_dim_cpu_template<scalar_t>(self, dim, false, return_inverse, false);
     return std::make_tuple(output, inverse);
   });
 }
@@ -170,9 +233,26 @@ std::tuple<Tensor, Tensor, Tensor>
 _unique_dim2_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     // The current implementation using `dim` always sorts due to unhashable tensors
-    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, return_counts);
+    return _unique_dim_cpu_template<scalar_t>(self, dim, false, return_inverse, return_counts);
   });
 }
 
+std::tuple<Tensor, Tensor, Tensor>
+unique_dim_consecutive_cpu(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
+    return _unique_dim_cpu_template<scalar_t>(self, dim, true, return_inverse, return_counts);
+  });
+}
+
+std::tuple<Tensor, Tensor, Tensor>
+unique_consecutive_cpu(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
+  if (!dim.has_value()) {
+    return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+      return unique_consecutive_cpu_template<scalar_t>(self, return_inverse, return_counts);
+    });
+  }
+  return unique_dim_consecutive_cpu(self, dim.value(), return_inverse, return_counts);
+}
+
 }  // namespace native
 }  // namespace at
index e4945bb..734fa66 100644 (file)
@@ -73,6 +73,7 @@ std::tuple<Tensor, Tensor, int64_t> compute_unique(
 template <typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
   const Tensor& self,
+  const bool consecutive,
   const bool return_inverse,
   const bool return_counts
 ) {
@@ -88,11 +89,15 @@ std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
 
   Tensor sorted_indices;
   if (!return_inverse) {
-    thrust::sort(policy, output_data, output_data + num_inp);
+    if (!consecutive) {
+      thrust::sort(policy, output_data, output_data + num_inp);
+    }
   } else {
     sorted_indices = at::arange(0, num_inp, options);
-    int64_t *sorted_indices_ptr = sorted_indices.data<int64_t>();
-    thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
+    if (!consecutive) {
+      int64_t *sorted_indices_ptr = sorted_indices.data<int64_t>();
+      thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
+    }
   }
 
   Tensor inverse_indices, counts;
@@ -116,6 +121,7 @@ template <typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor> unique_dim_cuda_template(
   const Tensor& self,
   const int64_t dim,
+  const bool consecutive,
   const bool return_inverse,
   const bool return_counts
 ) {
@@ -141,20 +147,22 @@ std::tuple<Tensor, Tensor, Tensor> unique_dim_cuda_template(
 
   Tensor indices = at::arange(0, num_inp, options);
   int64_t *indices_data = indices.data<int64_t>();
-  thrust::sort(policy, indices_data, indices_data + num_inp,
-    [=] __device__ (int64_t a, int64_t b) -> bool {
-      for (int64_t i = 0; i < n; ++i) {
-        scalar_t lhs = input_flat_ptr[i + a * n];
-        scalar_t rhs = input_flat_ptr[i + b * n];
-        if (lhs < rhs) {
-          return true;
-        } else if (lhs > rhs) {
-          return false;
+  if (!consecutive) {
+    thrust::sort(policy, indices_data, indices_data + num_inp,
+      [=] __device__ (int64_t a, int64_t b) -> bool {
+        for (int64_t i = 0; i < n; ++i) {
+          scalar_t lhs = input_flat_ptr[i + a * n];
+          scalar_t rhs = input_flat_ptr[i + b * n];
+          if (lhs < rhs) {
+            return true;
+          } else if (lhs > rhs) {
+            return false;
+          }
         }
+        return false;
       }
-      return false;
-    }
-  );
+    );
+  }
 
   Tensor inverse_indices, counts;
   int64_t num_out;
@@ -196,7 +204,7 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
     Tensor output, inverse;
-    std::tie(output, inverse, std::ignore) = unique_cuda_template<scalar_t>(self, return_inverse, false);
+    std::tie(output, inverse, std::ignore) = unique_cuda_template<scalar_t>(self, false, return_inverse, false);
     return std::make_tuple(output, inverse);
   });
 }
@@ -206,7 +214,7 @@ _unique2_cuda(const Tensor& self, const bool sorted, const bool return_inverse,
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
-    return unique_cuda_template<scalar_t>(self, return_inverse, return_counts);
+    return unique_cuda_template<scalar_t>(self, false, return_inverse, return_counts);
   });
 }
 
@@ -214,7 +222,7 @@ std::tuple<Tensor, Tensor>
 _unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     Tensor output, inverse;
-    std::tie(output, inverse, std::ignore) = unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, false);
+    std::tie(output, inverse, std::ignore) = unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, false);
     return std::make_tuple(output, inverse);
   });
 }
@@ -222,9 +230,28 @@ _unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const
 std::tuple<Tensor, Tensor, Tensor>
 _unique_dim2_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
-    return unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, return_counts);
+    return unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, return_counts);
   });
 }
 
+std::tuple<Tensor, Tensor, Tensor>
+unique_dim_consecutive_cuda(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
+    return unique_dim_cuda_template<scalar_t>(self, dim, true, return_inverse, return_counts);
+  });
+}
+
+std::tuple<Tensor, Tensor, Tensor>
+unique_consecutive_cuda(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
+  if (!dim.has_value()) {
+    return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+      // The current CUDA implementation of unique always sort due to the
+      // lack of hashtable implementation in thrust
+      return unique_cuda_template<scalar_t>(self, true, return_inverse, return_counts);
+    });
+  }
+  return unique_dim_consecutive_cuda(self, dim.value(), return_inverse, return_counts);
+}
+
 }  // namespace native
 }  // namespace at
index 1499071..852add5 100644 (file)
     CPU: _unique_dim_cpu
     CUDA: _unique_dim_cuda
 
+- func: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: unique_consecutive_cpu
+    CUDA: unique_consecutive_cuda
+
+- func: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: unique_dim_consecutive_cpu
+    CUDA: unique_dim_consecutive_cuda
+
 # _unique and _unique_dim are fragile and modifying them easily cause internal break
 # below two operators are a temporary hack for adding return_counts support
 # Please don't rely on these two operators, they will be removed soon
index af10e8e..0af3c52 100644 (file)
@@ -451,6 +451,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: unfold
    .. automethod:: uniform_
    .. automethod:: unique
+   .. automethod:: unique_consecutive
    .. automethod:: unsqueeze
    .. automethod:: unsqueeze_
    .. automethod:: values
index 039fd0f..67f1955 100644 (file)
@@ -225,6 +225,7 @@ Reduction Ops
 .. autofunction:: std
 .. autofunction:: sum
 .. autofunction:: unique
+.. autofunction:: unique_consecutive
 .. autofunction:: var
 
 
index 2a4c576..5adce86 100644 (file)
@@ -10611,6 +10611,28 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(torch.tensor([3, 0, 0, 0, 1, 2], dtype=torch.long, device=device), byte_inverse)
             self.assertEqual(torch.tensor([3, 1, 1, 1], dtype=torch.long, device=device), byte_counts)
 
+            # test consecutive version
+            z = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], device=device)
+            expected_z_unique = torch.tensor([1, 2, 5, 2, 3], device=device)
+            expected_z_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device)
+            expected_z_counts = torch.tensor([1, 3, 2, 2, 1], device=device)
+
+            z_unique = torch.unique_consecutive(z)
+            self.assertEqual(z_unique, expected_z_unique)
+
+            z_unique, z_inverse = torch.unique_consecutive(z, return_inverse=True)
+            self.assertEqual(z_unique, expected_z_unique)
+            self.assertEqual(z_inverse, expected_z_inverse)
+
+            z_unique, z_counts = torch.unique_consecutive(z, return_counts=True)
+            self.assertEqual(z_unique, expected_z_unique)
+            self.assertEqual(z_counts, expected_z_counts)
+
+            z_unique, z_inverse, z_counts = torch.unique_consecutive(z, return_inverse=True, return_counts=True)
+            self.assertEqual(z_unique, expected_z_unique)
+            self.assertEqual(z_inverse, expected_z_inverse)
+            self.assertEqual(z_counts, expected_z_counts)
+
         run_test(torch.device('cpu'))
         if torch.cuda.is_available():
             run_test(torch.device('cuda'))
@@ -10742,6 +10764,37 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(expected_inverse_dim2, x_inverse)
             self.assertEqual(expected_counts_dim2, x_counts)
 
+            # test consecutive version
+            y = torch.tensor(
+                [[0, 1],
+                 [0, 1],
+                 [0, 1],
+                 [1, 2],
+                 [1, 2],
+                 [3, 4],
+                 [0, 1],
+                 [0, 1],
+                 [3, 4],
+                 [1, 2]],
+                dtype=dtype,
+                device=device
+            )
+            expected_y_unique = torch.tensor(
+                [[0, 1],
+                 [1, 2],
+                 [3, 4],
+                 [0, 1],
+                 [3, 4],
+                 [1, 2]],
+                dtype=dtype,
+                device=device
+            )
+            expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=dtype, device=device)
+            expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=dtype, device=device)
+            y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0)
+            self.assertEqual(expected_y_inverse, y_inverse)
+            self.assertEqual(expected_y_counts, y_counts)
+
         run_test(torch.float)
         run_test(torch.double)
         run_test(torch.long)
index e147010..2e7fe79 100644 (file)
@@ -22,7 +22,7 @@ SKIP_PYTHON_BINDINGS = [
     '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
     '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
     '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
-    'index',
+    'index', 'unique_dim_consecutive',
     '_indexCopy_', 'max_values', 'min_values',
     '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
     '_th_.*', '_thnn_.*',
index cd95a25..0d5ad32 100644 (file)
@@ -72,6 +72,7 @@ blacklist = [
     'tensordot',
     'norm',
     'split',
+    'unique_consecutive',
     # These are handled specially by python_arg_parser.cpp
     'add',
     'add_',
index a521c45..9a748c9 100644 (file)
@@ -87,6 +87,7 @@ class Tensor:
              center=True, pad_mode='reflect', normalized=False, onesided=True): ...
     def split(self, split_size, dim=0): ...
     def unique(self, sorted=True, return_inverse=False, dim=None): ...
+    def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ...
     def lu(self, pivot=True, get_infos=False): ...
 
 ${function_hints}
index 4a13658..fd0e47e 100644 (file)
@@ -28,6 +28,7 @@ __all__ = [
     'tensordot',
     'trtrs',
     'unique',
+    'unique_consecutive',
 ]
 
 
@@ -449,6 +450,67 @@ def unique(input, sorted=True, return_inverse=False, dim=None):
         return output
 
 
+def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None):
+    r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+    .. note:: This function is different from :func:`torch.unique` in the sense that this function
+        only eliminates consecutive duplicate values. This semantics is similar to `std::unique`
+        in C++.
+
+    Arguments:
+        input (Tensor): the input tensor
+        return_inverse (bool): Whether to also return the indices for where
+            elements in the original input ended up in the returned unique list.
+        return_counts (bool): Whether to also return the counts for each unique
+            element.
+        dim (int): the dimension to apply unique. If ``None``, the unique of the
+            flattened input is returned. default: ``None``
+
+    Returns:
+        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
+
+            - **output** (*Tensor*): the output list of unique scalar elements.
+            - **inverse_indices** (*Tensor*): (optional) if
+              :attr:`return_inverse` is True, there will be an additional
+              returned tensor (same shape as input) representing the indices
+              for where elements in the original input map to in the output;
+              otherwise, this function will only return a single tensor.
+            - **counts** (*Tensor*): (optional) if
+              :attr:`return_counts` is True, there will be an additional
+              returned tensor (same shape as output or output.size(dim),
+              if dim was specified) representing the number of occurrences
+              for each unique value or tensor.
+
+    Example::
+
+        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
+        >>> output = torch.unique_consecutive(x)
+        >>> output
+        tensor([1, 2, 3, 1, 2])
+
+        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
+        >>> output
+        tensor([1, 2, 3, 1, 2])
+        >>> inverse_indices
+        tensor([0, 0, 1, 1, 2, 3, 3, 4])
+
+        >>> output, counts = torch.unique_consecutive(x, return_counts=True)
+        >>> output
+        tensor([1, 2, 3, 1, 2])
+        >>> counts
+        tensor([2, 2, 1, 2, 1])
+    """
+    output, inverse_indices, counts = torch._C._VariableFunctions.unique_consecutive(
+        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+    if return_inverse and return_counts:
+        return output, inverse_indices, counts
+    if return_inverse:
+        return output, inverse_indices
+    if return_counts:
+        return output, counts
+    return output
+
+
 def tensordot(a, b, dims=2):
     r"""Returns a contraction of a and b over multiple dimensions.
 
index a788c50..1f98a61 100644 (file)
@@ -368,6 +368,13 @@ class Tensor(torch._C._TensorBase):
         else:
             return output
 
+    def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
+        r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+        See :func:`torch.unique_consecutive`
+        """
+        return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+
     def __rsub__(self, other):
         return _C._VariableFunctions.rsub(self, other)