Summary:
Fixes: https://github.com/pytorch/pytorch/issues/19045
Please review: VitalyFedyunin ngimel
This is independent on the #18649 series. This will cause merge conflicts in #18649 series, but please merge this first, and I will resolve the merge conflicts there.
The new feature is exposed in `_unique2_temporary_will_remove_soon` and `_unique_dim2_temporary_will_remove_soon`. But not at `torch.unique` yet. I will take care of the API after #18649 series get merged completely.
Benchmark on a tensor of shape `torch.Size([15320, 2])`:
```python
print(torch.__version__)
print()
a = tensor.sort().values.to('cpu')
print('cpu, sorted_input=False:')
%timeit torch._unique2_temporary_will_remove_soon(a)
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True)
%timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True)
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True)
print()
print('cpu, sorted_input=True:')
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True)
print()
a = a.to('cuda')
print('cuda, sorted_input=False:')
%timeit torch._unique2_temporary_will_remove_soon(a); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True); torch.cuda.synchronize()
print()
print('cuda, sorted_input=True:')
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```
```
1.1.0a0+2addccc
cpu, sorted_input=False:
340 µs ± 5.88 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
717 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
52.3 ms ± 2.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
52.3 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
cpu, sorted_input=True:
32.8 µs ± 285 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
49.9 µs ± 557 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
51.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
78 µs ± 782 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cuda, sorted_input=False:
213 µs ± 1.52 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
291 µs ± 3.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
250 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
321 µs ± 1.59 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cuda, sorted_input=True:
45.6 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
110 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
82 µs ± 857 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
143 µs ± 409 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
```python
print(torch.__version__)
print()
a1, a2 = tensor.unbind(1)
indices = (a1 * tensor.max() + a2).sort().indices
a = tensor.index_select(0, indices).to('cpu')
print('cpu, sorted_input=False:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True)
print()
print('cpu, sorted_input=True:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True, return_counts=True)
print()
a = a.to('cuda')
print('cuda, sorted_input=False:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True); torch.cuda.synchronize()
print()
print('cuda, sorted_input=True:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```
```
cpu, sorted_input=False:
55.4 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.8 ms ± 616 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.2 ms ± 402 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.1 ms ± 725 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
cpu, sorted_input=True:
54.7 ms ± 585 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.2 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.5 ms ± 865 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.9 ms ± 577 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
cuda, sorted_input=False:
171 µs ± 783 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
220 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
203 µs ± 2.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
251 µs ± 2.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cuda, sorted_input=True:
59.6 µs ± 757 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
113 µs ± 431 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
93.2 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
147 µs ± 2.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
The CPU implementation of `unique_dim` is super slow, see https://github.com/pytorch/pytorch/issues/18987, but this PR will not worry about this issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19060
Differential Revision:
D14866909
Pulled By: ezyang
fbshipit-source-id:
d20012cec68c37b05cf770a6f4d6524f910b950f
namespace {
template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
+std::tuple<Tensor, Tensor, Tensor> unique_cpu_template(
const Tensor& self,
const bool sorted,
const bool return_inverse,
const bool return_counts) {
const Tensor& input = self.contiguous();
const scalar_t* input_data = input.data<scalar_t>();
- std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
- Tensor output = at::empty({static_cast<int64_t>(set.size())}, input.options());
- scalar_t* output_data = output.data<scalar_t>();
+ int64_t numel = input.numel();
+ Tensor output;
+ Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
+ Tensor counts = at::empty({0}, self.options().dtype(kLong));
+
+ std::unordered_set<scalar_t> set(input_data, input_data + numel);
+ output = at::empty({static_cast<int64_t>(set.size())}, input.options());
+ scalar_t *output_data = output.data<scalar_t>();
if (sorted) {
std::vector<scalar_t> vec(set.begin(), set.end());
std::copy(set.begin(), set.end(), output_data);
}
- Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
- Tensor counts = at::empty({0}, self.options().dtype(kLong));
if (return_inverse || return_counts) {
inverse_indices.resize_(input.sizes());
int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
for (int i = 0; i < output.numel(); ++i) {
inverse_map[output_data[i]] = i;
}
- for (int i = 0; i < input.numel(); ++i) {
+ for (int i = 0; i < numel; ++i) {
inverse_indices_data[i] = inverse_map[input_data[i]];
}
if (return_counts) {
counts.resize_(output.sizes());
counts.fill_(0);
- for (int i = 0; i < input.numel(); ++i) {
+ for (int i = 0; i < numel; ++i) {
counts[inverse_map[input_data[i]]] += 1;
}
}
return std::make_tuple(output, inverse_indices, counts);
}
+template <typename scalar_t>
+std::tuple<Tensor, Tensor, Tensor> unique_consecutive_cpu_template(
+ const Tensor& self,
+ const bool return_inverse,
+ const bool return_counts) {
+ const Tensor& input = self.contiguous();
+ const scalar_t* input_data = input.data<scalar_t>();
+ int64_t numel = input.numel();
+ Tensor output = at::empty({numel}, input.options());
+ Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
+ Tensor counts = at::empty({0}, self.options().dtype(kLong));
+
+ scalar_t *output_data = output.data<scalar_t>();
+ int64_t *inverse_data = nullptr;
+ int64_t *counts_data = nullptr;
+ if (numel > 0) {
+ *output_data = *input_data;
+ }
+ if (return_inverse) {
+ inverse_indices.resize_(input.sizes());
+ inverse_data = inverse_indices.data<int64_t>();
+ }
+ if (return_counts) {
+ counts.resize_(input.sizes());
+ counts_data = counts.data<int64_t>();
+ }
+ scalar_t *p = output_data;
+ int64_t *q = counts_data;
+ int64_t last = 0;
+ for (int64_t i = 0; i < numel; i++) {
+ if (input_data[i] != *p) {
+ *(++p) = input_data[i];
+ if (return_counts) {
+ *(q++) = i - last;
+ last = i;
+ }
+ }
+ if (return_inverse) {
+ inverse_data[i] = p - output_data;
+ }
+ }
+ int64_t output_size = p - output_data + 1;
+ if (return_counts && numel > 0) {
+ *q = numel - last;
+ counts.resize_({output_size});
+ }
+ output.resize_({output_size});
+
+ return std::make_tuple(output, inverse_indices, counts);
+}
+
template<class ForwardIt>
ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
std::vector<int64_t>& indices, Tensor inverse_indices_vec, Tensor counts) {
std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
const Tensor& self,
const int64_t dim,
+ const bool consecutive,
const bool return_inverse,
const bool return_counts) {
// reshape tensor as [dim, -1]
scalar_t* input_flat_ptr = ((scalar_t*)input_flat.data_ptr());
// sort indices using data
- std::sort(indices.begin(), indices.end(),
- [&](int64_t a, int64_t b) -> bool {
- for (int64_t i = 0; i < numel; ++i) {
- scalar_t lhs = input_flat_ptr[i + a * numel];
- scalar_t rhs = input_flat_ptr[i + b * numel];
- if (lhs < rhs) {
- return true;
- } else if (lhs > rhs) {
- return false;
+ if (!consecutive) {
+ std::sort(indices.begin(), indices.end(),
+ [&](int64_t a, int64_t b) -> bool {
+ for (int64_t i = 0; i < numel; ++i) {
+ scalar_t lhs = input_flat_ptr[i + a * numel];
+ scalar_t rhs = input_flat_ptr[i + b * numel];
+ if (lhs < rhs) {
+ return true;
+ } else if (lhs > rhs) {
+ return false;
+ }
}
- }
- return false;
- });
+ return false;
+ });
+ }
- Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.options());
- for (int i = 0; i < indices.size(); ++i) {
- input_sorted[i] = input_flat[indices[i]];
+ Tensor input_sorted;
+ if (!consecutive) {
+ input_sorted = at::empty(input_flat.sizes(), input_flat.options());
+ for (int i = 0; i < indices.size(); ++i) {
+ input_sorted[i] = input_flat[indices[i]];
+ }
+ } else {
+ input_sorted = input_flat;
}
Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
return std::make_tuple(output, inverse_indices, counts);
}
+
} // namespace
_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
Tensor output, inverse;
- std::tie(output, inverse, std::ignore) = _unique_cpu_template<scalar_t>(self, sorted, return_inverse, false);
+ std::tie(output, inverse, std::ignore) = unique_cpu_template<scalar_t>(self, sorted, return_inverse, false);
return std::make_tuple(output, inverse);
});
}
std::tuple<Tensor, Tensor, Tensor>
_unique2_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
- return _unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
+ return unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
});
}
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
// The current implementation using `dim` always sorts due to unhashable tensors
Tensor output, inverse;
- std::tie(output, inverse, std::ignore) = _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, false);
+ std::tie(output, inverse, std::ignore) = _unique_dim_cpu_template<scalar_t>(self, dim, false, return_inverse, false);
return std::make_tuple(output, inverse);
});
}
_unique_dim2_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
// The current implementation using `dim` always sorts due to unhashable tensors
- return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, return_counts);
+ return _unique_dim_cpu_template<scalar_t>(self, dim, false, return_inverse, return_counts);
});
}
+std::tuple<Tensor, Tensor, Tensor>
+unique_dim_consecutive_cpu(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
+ return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
+ return _unique_dim_cpu_template<scalar_t>(self, dim, true, return_inverse, return_counts);
+ });
+}
+
+std::tuple<Tensor, Tensor, Tensor>
+unique_consecutive_cpu(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
+ if (!dim.has_value()) {
+ return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+ return unique_consecutive_cpu_template<scalar_t>(self, return_inverse, return_counts);
+ });
+ }
+ return unique_dim_consecutive_cpu(self, dim.value(), return_inverse, return_counts);
+}
+
} // namespace native
} // namespace at
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
const Tensor& self,
+ const bool consecutive,
const bool return_inverse,
const bool return_counts
) {
Tensor sorted_indices;
if (!return_inverse) {
- thrust::sort(policy, output_data, output_data + num_inp);
+ if (!consecutive) {
+ thrust::sort(policy, output_data, output_data + num_inp);
+ }
} else {
sorted_indices = at::arange(0, num_inp, options);
- int64_t *sorted_indices_ptr = sorted_indices.data<int64_t>();
- thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
+ if (!consecutive) {
+ int64_t *sorted_indices_ptr = sorted_indices.data<int64_t>();
+ thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
+ }
}
Tensor inverse_indices, counts;
std::tuple<Tensor, Tensor, Tensor> unique_dim_cuda_template(
const Tensor& self,
const int64_t dim,
+ const bool consecutive,
const bool return_inverse,
const bool return_counts
) {
Tensor indices = at::arange(0, num_inp, options);
int64_t *indices_data = indices.data<int64_t>();
- thrust::sort(policy, indices_data, indices_data + num_inp,
- [=] __device__ (int64_t a, int64_t b) -> bool {
- for (int64_t i = 0; i < n; ++i) {
- scalar_t lhs = input_flat_ptr[i + a * n];
- scalar_t rhs = input_flat_ptr[i + b * n];
- if (lhs < rhs) {
- return true;
- } else if (lhs > rhs) {
- return false;
+ if (!consecutive) {
+ thrust::sort(policy, indices_data, indices_data + num_inp,
+ [=] __device__ (int64_t a, int64_t b) -> bool {
+ for (int64_t i = 0; i < n; ++i) {
+ scalar_t lhs = input_flat_ptr[i + a * n];
+ scalar_t rhs = input_flat_ptr[i + b * n];
+ if (lhs < rhs) {
+ return true;
+ } else if (lhs > rhs) {
+ return false;
+ }
}
+ return false;
}
- return false;
- }
- );
+ );
+ }
Tensor inverse_indices, counts;
int64_t num_out;
// The current CUDA implementation of unique always sort due to the
// lack of hashtable implementation in thrust
Tensor output, inverse;
- std::tie(output, inverse, std::ignore) = unique_cuda_template<scalar_t>(self, return_inverse, false);
+ std::tie(output, inverse, std::ignore) = unique_cuda_template<scalar_t>(self, false, return_inverse, false);
return std::make_tuple(output, inverse);
});
}
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
// The current CUDA implementation of unique always sort due to the
// lack of hashtable implementation in thrust
- return unique_cuda_template<scalar_t>(self, return_inverse, return_counts);
+ return unique_cuda_template<scalar_t>(self, false, return_inverse, return_counts);
});
}
_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
Tensor output, inverse;
- std::tie(output, inverse, std::ignore) = unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, false);
+ std::tie(output, inverse, std::ignore) = unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, false);
return std::make_tuple(output, inverse);
});
}
std::tuple<Tensor, Tensor, Tensor>
_unique_dim2_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
- return unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, return_counts);
+ return unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, return_counts);
});
}
+std::tuple<Tensor, Tensor, Tensor>
+unique_dim_consecutive_cuda(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
+ return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
+ return unique_dim_cuda_template<scalar_t>(self, dim, true, return_inverse, return_counts);
+ });
+}
+
+std::tuple<Tensor, Tensor, Tensor>
+unique_consecutive_cuda(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
+ if (!dim.has_value()) {
+ return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+ // The current CUDA implementation of unique always sort due to the
+ // lack of hashtable implementation in thrust
+ return unique_cuda_template<scalar_t>(self, true, return_inverse, return_counts);
+ });
+ }
+ return unique_dim_consecutive_cuda(self, dim.value(), return_inverse, return_counts);
+}
+
} // namespace native
} // namespace at
CPU: _unique_dim_cpu
CUDA: _unique_dim_cuda
+- func: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
+ matches_jit_signature: True
+ variants: function
+ dispatch:
+ CPU: unique_consecutive_cpu
+ CUDA: unique_consecutive_cuda
+
+- func: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+ matches_jit_signature: True
+ variants: function
+ dispatch:
+ CPU: unique_dim_consecutive_cpu
+ CUDA: unique_dim_consecutive_cuda
+
# _unique and _unique_dim are fragile and modifying them easily cause internal break
# below two operators are a temporary hack for adding return_counts support
# Please don't rely on these two operators, they will be removed soon
.. automethod:: unfold
.. automethod:: uniform_
.. automethod:: unique
+ .. automethod:: unique_consecutive
.. automethod:: unsqueeze
.. automethod:: unsqueeze_
.. automethod:: values
.. autofunction:: std
.. autofunction:: sum
.. autofunction:: unique
+.. autofunction:: unique_consecutive
.. autofunction:: var
self.assertEqual(torch.tensor([3, 0, 0, 0, 1, 2], dtype=torch.long, device=device), byte_inverse)
self.assertEqual(torch.tensor([3, 1, 1, 1], dtype=torch.long, device=device), byte_counts)
+ # test consecutive version
+ z = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], device=device)
+ expected_z_unique = torch.tensor([1, 2, 5, 2, 3], device=device)
+ expected_z_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device)
+ expected_z_counts = torch.tensor([1, 3, 2, 2, 1], device=device)
+
+ z_unique = torch.unique_consecutive(z)
+ self.assertEqual(z_unique, expected_z_unique)
+
+ z_unique, z_inverse = torch.unique_consecutive(z, return_inverse=True)
+ self.assertEqual(z_unique, expected_z_unique)
+ self.assertEqual(z_inverse, expected_z_inverse)
+
+ z_unique, z_counts = torch.unique_consecutive(z, return_counts=True)
+ self.assertEqual(z_unique, expected_z_unique)
+ self.assertEqual(z_counts, expected_z_counts)
+
+ z_unique, z_inverse, z_counts = torch.unique_consecutive(z, return_inverse=True, return_counts=True)
+ self.assertEqual(z_unique, expected_z_unique)
+ self.assertEqual(z_inverse, expected_z_inverse)
+ self.assertEqual(z_counts, expected_z_counts)
+
run_test(torch.device('cpu'))
if torch.cuda.is_available():
run_test(torch.device('cuda'))
self.assertEqual(expected_inverse_dim2, x_inverse)
self.assertEqual(expected_counts_dim2, x_counts)
+ # test consecutive version
+ y = torch.tensor(
+ [[0, 1],
+ [0, 1],
+ [0, 1],
+ [1, 2],
+ [1, 2],
+ [3, 4],
+ [0, 1],
+ [0, 1],
+ [3, 4],
+ [1, 2]],
+ dtype=dtype,
+ device=device
+ )
+ expected_y_unique = torch.tensor(
+ [[0, 1],
+ [1, 2],
+ [3, 4],
+ [0, 1],
+ [3, 4],
+ [1, 2]],
+ dtype=dtype,
+ device=device
+ )
+ expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=dtype, device=device)
+ expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=dtype, device=device)
+ y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0)
+ self.assertEqual(expected_y_inverse, y_inverse)
+ self.assertEqual(expected_y_counts, y_counts)
+
run_test(torch.float)
run_test(torch.double)
run_test(torch.long)
'.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
'_arange.*', '_range.*', '_linspace.*', '_logspace.*',
'_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
- 'index',
+ 'index', 'unique_dim_consecutive',
'_indexCopy_', 'max_values', 'min_values',
'_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
'_th_.*', '_thnn_.*',
'tensordot',
'norm',
'split',
+ 'unique_consecutive',
# These are handled specially by python_arg_parser.cpp
'add',
'add_',
center=True, pad_mode='reflect', normalized=False, onesided=True): ...
def split(self, split_size, dim=0): ...
def unique(self, sorted=True, return_inverse=False, dim=None): ...
+ def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ...
def lu(self, pivot=True, get_infos=False): ...
${function_hints}
'tensordot',
'trtrs',
'unique',
+ 'unique_consecutive',
]
return output
+def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None):
+ r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+ .. note:: This function is different from :func:`torch.unique` in the sense that this function
+ only eliminates consecutive duplicate values. This semantics is similar to `std::unique`
+ in C++.
+
+ Arguments:
+ input (Tensor): the input tensor
+ return_inverse (bool): Whether to also return the indices for where
+ elements in the original input ended up in the returned unique list.
+ return_counts (bool): Whether to also return the counts for each unique
+ element.
+ dim (int): the dimension to apply unique. If ``None``, the unique of the
+ flattened input is returned. default: ``None``
+
+ Returns:
+ (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
+
+ - **output** (*Tensor*): the output list of unique scalar elements.
+ - **inverse_indices** (*Tensor*): (optional) if
+ :attr:`return_inverse` is True, there will be an additional
+ returned tensor (same shape as input) representing the indices
+ for where elements in the original input map to in the output;
+ otherwise, this function will only return a single tensor.
+ - **counts** (*Tensor*): (optional) if
+ :attr:`return_counts` is True, there will be an additional
+ returned tensor (same shape as output or output.size(dim),
+ if dim was specified) representing the number of occurrences
+ for each unique value or tensor.
+
+ Example::
+
+ >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
+ >>> output = torch.unique_consecutive(x)
+ >>> output
+ tensor([1, 2, 3, 1, 2])
+
+ >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
+ >>> output
+ tensor([1, 2, 3, 1, 2])
+ >>> inverse_indices
+ tensor([0, 0, 1, 1, 2, 3, 3, 4])
+
+ >>> output, counts = torch.unique_consecutive(x, return_counts=True)
+ >>> output
+ tensor([1, 2, 3, 1, 2])
+ >>> counts
+ tensor([2, 2, 1, 2, 1])
+ """
+ output, inverse_indices, counts = torch._C._VariableFunctions.unique_consecutive(
+ input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+ if return_inverse and return_counts:
+ return output, inverse_indices, counts
+ if return_inverse:
+ return output, inverse_indices
+ if return_counts:
+ return output, counts
+ return output
+
+
def tensordot(a, b, dims=2):
r"""Returns a contraction of a and b over multiple dimensions.
else:
return output
+ def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
+ r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+ See :func:`torch.unique_consecutive`
+ """
+ return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+
def __rsub__(self, other):
return _C._VariableFunctions.rsub(self, other)