From 773ce4fbd05c4ebb16eb361cadde0dcf33f38085 Mon Sep 17 00:00:00 2001
From: Vitaly Fedyunin
Date: Wed, 3 Apr 2019 15:26:34 -0700
Subject: [PATCH] Step 1: Secretly add return_counts to unique, and refactor unique_dim for performance (#18648)
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18648
ghimport-source-id: 1cf4a8fe91492621e02217f38cae5d7e0699fb05

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18661 Step 7: remove _unique
* #18655 Step 6: Rename _unique2 to unique and add int? dim
* #18654 Step 5: remove _unique_dim in favor of unique_dim
* #18651 Step 4: add support for unique with dim=None
* #18650 Step 3: Add support for return_counts to torch.unique for dim not None
* #18649 Step 2: Rename _unique_dim2_temporary_will_remove_soon to unique_dim
* **#18648 Step 1: Secretly add return_counts to unique, and refactor unique_dim for performance**

`unique` is fragile: I previously tried to change it in #18391 and #17097; both passed OSS tests but were eventually reverted due to internal failures. My earlier refactoring of unique in #18459 was based on #18391, so once #18391 was reverted I could not continue with #18459. To keep working on #18459, #18391, and #17097 without worrying about internal failures, I am suggesting the following steps for improving `unique` and `unique_dim`. soumith Please take this; there is no need to put #18391 back.

The motivation is to move forward as much as possible without causing any internal failures, so I will divide the work into steps sorted from low to high probability of internal failure. (I don't know what the internal failure is, so I have to guess.) Let's merge this PR stack one by one until we encounter an internal failure.

Step 1: Create two new ATen operators, `_unique2_temporary_will_remove_soon` and `_unique_dim2_temporary_will_remove_soon`, and keep `_unique` and `_unique_dim` unchanged. The backends of the two new operators and of `_unique`/`_unique_dim` are the same; the only difference is that the temporary operators support `return_counts` while `_unique` and `_unique_dim` do not. Step one is mostly #18391 + #18459; the cuda8 errors have been fixed. At this point there is no user-visible API change, so no docs are updated. `torch.unique` does not support `return_counts` yet, and `return_counts` is tested through the newly added temporary operators (see the usage sketch after this step list). Since this step only adds two new ATen operators, there shouldn't be any internal failure.

Step 2: Rename `_unique_dim2_temporary_will_remove_soon` to `unique_dim`. This should cause no internal failure either, because no existing operator changes. The only thing to take care of is deleting `unique_dim` from the Python side, because we don't want users to use it. At this point, C++ users have `return_counts` support for `unique_dim`.

Step 3: Update the docs of `torch.unique` and use `unique_dim` inside `torch.unique` to support `return_counts`. The docs should say that `torch.unique` with dim=None does not support `return_counts` yet. This might cause internal failure.

Step 4: Rename `_unique2_temporary_will_remove_soon` to `_unique2` and use `_unique2` inside `torch.unique` to support `return_counts`. Update the docs to say that `torch.unique` with dim=None now supports `return_counts`. This might cause internal failure.

Step 5: Remove `_unique_dim`. This might cause internal failure.

Step 6: Rename `_unique2` to `unique` and add an optional `dim` argument so that it matches the signature of Python's `torch.unique`. Inside `torch.unique`, use `unique` and get rid of `unique_dim`. Unbind `unique_dim` from Python entirely at codegen time. This is likely to cause internal failure.

Step 7: Remove `_unique`. This is very likely to cause internal failure.
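As a quick orientation for reviewers, here is a minimal usage sketch of the two temporary operators, based on the signatures added to native_functions.yaml and the new tests in this patch (it only runs on a build that includes this patch). The flat example mirrors the updated test tensor; the `y` tensor for the dim variant is an arbitrary illustration, not taken from the patch.

```python
import torch

# Same tensor as in the updated test_unique test.
x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3])

# Flat variant: returns (unique values, inverse indices, counts).
values, inverse, counts = torch._unique2_temporary_will_remove_soon(
    x, sorted=True, return_inverse=True, return_counts=True)
# values  -> tensor([1, 2, 3, 5, 8])
# inverse -> tensor([0, 1, 2, 1, 4, 3, 1, 2])
# counts  -> tensor([1, 3, 2, 1, 1])

# dim variant: unique slices along `dim`, also with counts.
# `y` is an arbitrary example tensor (not from the patch).
y = torch.tensor([[1., 1.], [0., 1.], [1., 1.]])
row_values, row_inverse, row_counts = torch._unique_dim2_temporary_will_remove_soon(
    y, dim=0, sorted=True, return_inverse=True, return_counts=True)

# When return_inverse / return_counts are False, the corresponding outputs are
# still returned (typically as empty placeholder tensors) rather than omitted.
```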
This PR
======

This PR is for step 1. It creates the two new ATen operators, `_unique2_temporary_will_remove_soon` and `_unique_dim2_temporary_will_remove_soon`, implements `return_counts` inside them, and refactors the code for performance improvements.

Please review ngimel VitalyFedyunin. They are mostly copied from #18391 and #18459, so the review should be easy.

Below is a benchmark on a tensor of shape `torch.Size([15320, 2])`:

Before
---------

```python
print(torch.__version__)
%timeit a.unique(dim=0, sorted=True, return_inverse=False); torch.cuda.synchronize()
%timeit a.unique(dim=0, sorted=True, return_inverse=True); torch.cuda.synchronize()
```
```
1.0.1
192 µs ± 1.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
548 ms ± 3.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
```python
print(torch.__version__)
%timeit a.unique(sorted=True, return_inverse=False); torch.cuda.synchronize()
%timeit a.unique(sorted=True, return_inverse=True); torch.cuda.synchronize()
```
```
1.0.1
226 µs ± 929 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
302 µs ± 7.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

After
-------

```python
print(torch.__version__)
%timeit a.unique(dim=0, sorted=True, return_inverse=False); torch.cuda.synchronize()
%timeit a.unique(dim=0, sorted=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted=True, return_inverse=False, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```
```
1.1.0a0+83ab8ac
190 µs ± 2.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
237 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
219 µs ± 2.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
263 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
```python
print(torch.__version__)
%timeit a.unique(sorted=True, return_inverse=False); torch.cuda.synchronize()
%timeit a.unique(sorted=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted=True, return_inverse=False, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```
```
1.1.0a0+83ab8ac
232 µs ± 2.21 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
301 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
264 µs ± 7.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
339 µs ± 9.2 µs per loop (mean ± std. dev.
of 7 runs, 1000 loops each) ``` Differential Revision: D14730905 fbshipit-source-id: 10026b4b98628a8565cc28a13317d29adf1225cc --- aten/src/ATen/native/Unique.cpp | 60 ++++-- aten/src/ATen/native/cuda/Unique.cu | 304 ++++++++++++++++++----------- aten/src/ATen/native/native_functions.yaml | 18 ++ test/test_torch.py | 212 +++++++++++++++----- 4 files changed, 409 insertions(+), 185 deletions(-) diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp index 8cc867f..1dcf85d 100644 --- a/aten/src/ATen/native/Unique.cpp +++ b/aten/src/ATen/native/Unique.cpp @@ -14,10 +14,11 @@ namespace native{ namespace { template -std::tuple _unique_cpu_template( +std::tuple _unique_cpu_template( const Tensor& self, const bool sorted, - const bool return_inverse) { + const bool return_inverse, + const bool return_counts) { const Tensor& input = self.contiguous(); const scalar_t* input_data = input.data(); std::unordered_set set(input_data, input_data + input.numel()); @@ -33,7 +34,8 @@ std::tuple _unique_cpu_template( } Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong)); - if (return_inverse) { + Tensor counts = at::empty({0}, self.options().dtype(kLong)); + if (return_inverse || return_counts) { inverse_indices.resize_(input.sizes()); int64_t* inverse_indices_data = inverse_indices.data(); std::unordered_map inverse_map; @@ -44,21 +46,29 @@ std::tuple _unique_cpu_template( for (int i = 0; i < input.numel(); ++i) { inverse_indices_data[i] = inverse_map[input_data[i]]; } + if (return_counts) { + counts.resize_(output.sizes()); + counts.fill_(0); + for (int i = 0; i < input.numel(); ++i) { + counts[inverse_map[input_data[i]]] += 1; + } + } } - return std::make_tuple(output, inverse_indices); + return std::make_tuple(output, inverse_indices, counts); } template ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last, - std::vector& indices, Tensor inverse_indices_vec) { + std::vector& indices, Tensor inverse_indices_vec, Tensor counts) { if (first == last) { return last; } // save to calculate distance to iterators ForwardIt begin = first; - // set first inverse index + // set first inverse index and count inverse_indices_vec[indices[0]] = 0; + counts[0] += 1; ForwardIt result = first; while (++first != last) { @@ -68,16 +78,18 @@ ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last, int64_t idx_result = std::distance(begin, result); int64_t idx_first = std::distance(begin, first); inverse_indices_vec[indices[idx_first]] = idx_result; + counts[idx_result] += 1; } return ++result; } template -std::tuple _unique_dim_cpu_template( +std::tuple _unique_dim_cpu_template( const Tensor& self, const int64_t dim, - const bool return_inverse) { + const bool return_inverse, + const bool return_counts) { // reshape tensor as [dim, -1] Tensor input_flat = self.transpose(dim, 0); auto orig_sizes = input_flat.sizes().vec(); @@ -109,10 +121,12 @@ std::tuple _unique_dim_cpu_template( } Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong)); + Tensor counts = at::zeros(indices.size(), self.options().dtype(kLong)); std::vector input_unbind = at::unbind(input_sorted, 0); auto last = _unique_dim_cpu_impl( - input_unbind.begin(), input_unbind.end(), indices, inverse_indices); + input_unbind.begin(), input_unbind.end(), indices, inverse_indices, counts); input_unbind.erase(last, input_unbind.end()); + counts = at::narrow(counts, 0, 0, input_unbind.size()); // reshape back auto output = at::stack(input_unbind, 0); @@ -121,14 +135,24 @@ std::tuple 
_unique_dim_cpu_template( output = output.view(new_sizes); output = output.transpose(0, dim); - return std::make_tuple(output, inverse_indices); + return std::make_tuple(output, inverse_indices, counts); } } // namespace + std::tuple _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) { - return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] { - return _unique_cpu_template(self, sorted, return_inverse); + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] { + Tensor output, inverse; + std::tie(output, inverse, std::ignore) = _unique_cpu_template(self, sorted, return_inverse, false); + return std::make_tuple(output, inverse); + }); +} + +std::tuple +_unique2_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] { + return _unique_cpu_template(self, sorted, return_inverse, return_counts); }); } @@ -136,7 +160,17 @@ std::tuple _unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) { return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] { // The current implementation using `dim` always sorts due to unhashable tensors - return _unique_dim_cpu_template(self, dim, return_inverse); + Tensor output, inverse; + std::tie(output, inverse, std::ignore) = _unique_dim_cpu_template(self, dim, return_inverse, false); + return std::make_tuple(output, inverse); + }); +} + +std::tuple +_unique_dim2_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] { + // The current implementation using `dim` always sorts due to unhashable tensors + return _unique_dim_cpu_template(self, dim, return_inverse, return_counts); }); } diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu index 0ba6812..e4945bb 100644 --- a/aten/src/ATen/native/cuda/Unique.cu +++ b/aten/src/ATen/native/cuda/Unique.cu @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -15,148 +16,213 @@ namespace native{ namespace { + +template < + typename policy_t, typename scalar_t, + typename equal_t, typename not_equal_t +> +std::tuple compute_unique( + const policy_t &policy, + scalar_t *data, + int64_t num_inp, + const Tensor &sorted_indices, + const bool return_inverse, + const bool return_counts, + TensorOptions options, + equal_t equal, + not_equal_t not_equal +) { + + // inverse indices + Tensor inverse_indices; + if (!return_inverse) { + inverse_indices = at::empty({0}, options); + } else { + AT_CHECK(sorted_indices.defined(), + "return_inverse is set to true, but sorted_indices is undefined. 
Send a bug report!"); + const int64_t *sorted_indices_ptr = sorted_indices.data(); + Tensor inv_loc = at::empty({num_inp}, options); + inverse_indices = at::empty({num_inp}, options); + int64_t* inv_loc_ptr = inv_loc.data(); + int64_t* inverse_indices_ptr = inverse_indices.data(); + thrust::adjacent_difference(policy, data, data + num_inp, inv_loc_ptr, not_equal); + inv_loc[0] = 0; + thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr); + thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr); + } + + // unique and count + Tensor counts = at::empty({0}, options); + int64_t num_out; + if (!return_counts) { + num_out = thrust::unique(policy, data, data + num_inp, equal) - data; + } else { + Tensor range = at::arange(0, num_inp + 1, options); + int64_t *range_ptr = range.data(); + num_out = thrust::unique_by_key(policy, data, data + num_inp, range_ptr, equal).first - data; + range[num_out] = num_inp; + counts.resize_(num_out); + int64_t* counts_ptr = counts.data(); + thrust::adjacent_difference(policy, range_ptr + 1, range_ptr + num_out + 1, counts_ptr); + } + + THCudaCheck(cudaGetLastError()); + return std::tuple(inverse_indices, counts, num_out); +} + template - std::tuple _unique_cuda_template( - const Tensor& self, - const bool return_inverse) { - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - const Tensor& input = self.contiguous(); - int64_t num_inp = input.numel(); - const scalar_t* input_data = input.data(); - - //sort & unique - Tensor output = input.clone(); - output = output.view(-1); - scalar_t* output_data = output.data(); - Tensor inverse_indices; - if (!return_inverse) { - inverse_indices = at::empty({0}, self.type().toScalarType(kLong)); - thrust::sort(policy, output_data, output_data + num_inp); - } else { - Tensor sorted_indices = at::arange(0, num_inp, self.type().toScalarType(kLong)); - int64_t* sorted_indices_ptr = sorted_indices.data(); - thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr); - Tensor inv_loc = at::empty({num_inp}, self.type().toScalarType(kLong)); - inverse_indices = at::empty({num_inp}, self.type().toScalarType(kLong)); - int64_t* inv_loc_ptr = inv_loc.data(); - int64_t* inverse_indices_ptr = inverse_indices.data(); - thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr, [=] __device__ (scalar_t a, scalar_t b) -> int64_t { if (a != b) {return 1;} else { return 0; }}); - inv_loc[0] = 0; - thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr); - thrust::scatter(policy,inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr); - inverse_indices.resize_(input.sizes()); - } - int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data; - output.resize_(num_out); +std::tuple unique_cuda_template( + const Tensor& self, + const bool return_inverse, + const bool return_counts +) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + auto options = self.options().dtype(kLong); + Tensor output = self.clone().reshape(-1); + int64_t num_inp = output.numel(); + scalar_t* output_data = output.data(); + + Tensor sorted_indices; + if (!return_inverse) { + thrust::sort(policy, 
output_data, output_data + num_inp); + } else { + sorted_indices = at::arange(0, num_inp, options); + int64_t *sorted_indices_ptr = sorted_indices.data(); + thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr); + } - THCudaCheck(cudaGetLastError()); - return std::tuple(output, inverse_indices); + Tensor inverse_indices, counts; + int64_t num_out; + std::tie(inverse_indices, counts, num_out) = compute_unique( + policy, output_data, num_inp, sorted_indices, + return_inverse, return_counts, options, + thrust::equal_to(), + thrust::not_equal_to() + ); + output.resize_(num_out); + + if (return_inverse) { + inverse_indices.resize_(self.sizes()); } + return std::tuple(output, inverse_indices, counts); +} + template - std::tuple _unique_dim_cuda_template( - const Tensor& self, - const int64_t dim, - const bool return_inverse) { - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - Tensor input_flat = self.transpose(dim, 0); - auto orig_sizes = input_flat.sizes().vec(); - input_flat = input_flat.contiguous().view({input_flat.size(0), -1}); - - scalar_t* input_flat_ptr = input_flat.data(); - - Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong)); - int64_t* indices_ptr = indices.data(); - int64_t numel = input_flat.size(1); - - // sort indices using data - thrust::sort(policy, indices_ptr, indices_ptr + indices.numel(), - [=] __device__ (int64_t a, int64_t b) -> bool { - for (int64_t i = 0; i < numel; ++i) { - scalar_t lhs = input_flat_ptr[i + a * numel]; - scalar_t rhs = input_flat_ptr[i + b * numel]; - if (lhs < rhs) { - return true; - } else if (lhs > rhs) { - return false; - } - } - return false; - }); - - Tensor input_sorted = input_flat.index_select(0, indices); - - // get unique tensors - scalar_t* input_sorted_ptr = input_sorted.data(); - Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong)); - int64_t* input_sorted_indices_ptr = input_sorted_indices.data(); - auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(), - [=] __device__ (int64_t a, int64_t b) -> bool { - for (int64_t i = 0; i < numel; ++i) { - scalar_t lhs = input_sorted_ptr[i + a * numel]; - scalar_t rhs = input_sorted_ptr[i + b * numel]; - if (lhs != rhs) { - return false; - } +std::tuple unique_dim_cuda_template( + const Tensor& self, + const int64_t dim, + const bool return_inverse, + const bool return_counts +) { + + /** + * The idea for implementing this is basically the same as unique. + * For unique_dim, we are taking the unique with respect to a index + * tensor, but during the processes, we override the compare and equal + * operator by checking the data underlying it instead. After the + * algorithm, we would use index_select to map the resulting indicies + * to the result on the actual data. 
+ */ + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + int64_t num_inp = self.size(dim); + auto options = self.options().dtype(kLong); + Tensor input_flat = self.transpose(dim, 0).contiguous().view({num_inp, -1}); + int64_t n = input_flat.size(1); + scalar_t *input_flat_ptr = input_flat.data(); + + Tensor indices = at::arange(0, num_inp, options); + int64_t *indices_data = indices.data(); + thrust::sort(policy, indices_data, indices_data + num_inp, + [=] __device__ (int64_t a, int64_t b) -> bool { + for (int64_t i = 0; i < n; ++i) { + scalar_t lhs = input_flat_ptr[i + a * n]; + scalar_t rhs = input_flat_ptr[i + b * n]; + if (lhs < rhs) { + return true; + } else if (lhs > rhs) { + return false; } - return true; - }); - input_sorted_indices.resize_(last - input_sorted_indices_ptr); - Tensor output = input_sorted.index_select(0, input_sorted_indices); - - // reshape back - auto new_sizes = std::vector(orig_sizes); - new_sizes[0] = -1; - output = output.view(new_sizes); - output = output.transpose(0, dim); - - // calculate inverse indices - Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong)); - if (return_inverse) { - int64_t size = self.size(dim); - inverse_indices.resize_(size); - Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong)); - mask[0] = 1; - for (int i = 0; i < input_sorted.size(0) - 1; ++i) { - if (!at::equal(input_sorted[i], input_sorted[i+1])) { - mask[i+1] = 1; - } else { - mask[i+1] = 0; + } + return false; + } + ); + + Tensor inverse_indices, counts; + int64_t num_out; + std::tie(inverse_indices, counts, num_out) = compute_unique( + policy, indices_data, num_inp, indices, + return_inverse, return_counts, options, + [=] __device__ (int64_t a, int64_t b) -> bool { + for (int64_t i = 0; i < n; ++i) { + scalar_t lhs = input_flat_ptr[i + a * n]; + scalar_t rhs = input_flat_ptr[i + b * n]; + if (lhs != rhs) { + return false; } } - - Tensor imask = at::cumsum(mask, 0) - 1; - for (int i = 0; i < indices.size(0); ++i) { - inverse_indices[indices[i]] = imask[i]; + return true; + }, + [=] __device__ (int64_t a, int64_t b) -> int64_t { + for (int64_t i = 0; i < n; ++i) { + scalar_t lhs = input_flat_ptr[i + a * n]; + scalar_t rhs = input_flat_ptr[i + b * n]; + if (lhs != rhs) { + return 1; + } } + return 0; } + ); + indices.resize_(num_out); + + return std::tuple(self.index_select(dim, indices), inverse_indices, counts); +} - THCudaCheck(cudaGetLastError()); - return std::tuple(output, inverse_indices); - } } // namespace + std::tuple _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) { - return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] { + // The current CUDA implementation of unique always sort due to the + // lack of hashtable implementation in thrust + Tensor output, inverse; + std::tie(output, inverse, std::ignore) = unique_cuda_template(self, return_inverse, false); + return std::make_tuple(output, inverse); + }); +} + +std::tuple +_unique2_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] { // The current CUDA implementation of unique always sort due to the // lack of hashtable implementation in thrust - return _unique_cuda_template(self, return_inverse); + return 
unique_cuda_template(self, return_inverse, return_counts); }); } std::tuple _unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) { return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] { - return _unique_dim_cuda_template(self, dim, return_inverse); + Tensor output, inverse; + std::tie(output, inverse, std::ignore) = unique_dim_cuda_template(self, dim, return_inverse, false); + return std::make_tuple(output, inverse); + }); +} + +std::tuple +_unique_dim2_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] { + return unique_dim_cuda_template(self, dim, return_inverse, return_counts); }); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index bcdff2e..439745d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2357,6 +2357,24 @@ CPU: _unique_dim_cpu CUDA: _unique_dim_cuda +# _unique and _unique_dim are fragile and modifying them easily cause internal break +# below two operators are a temporary hack for adding return_counts support +# Please don't rely on these two operators, they will be removed soon + +- func: _unique2_temporary_will_remove_soon(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + matches_jit_signature: True + variants: function + dispatch: + CPU: _unique2_cpu + CUDA: _unique2_cuda + +- func: _unique_dim2_temporary_will_remove_soon(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + matches_jit_signature: True + variants: function + dispatch: + CPU: _unique_dim2_cpu + CUDA: _unique_dim2_cuda + - func: _unsafe_view(Tensor self, int[] size) -> Tensor matches_jit_signature: True diff --git a/test/test_torch.py b/test/test_torch.py index ef21ea6..026ce40 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10506,57 +10506,87 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], torch.set_flush_denormal(False) def test_unique(self): - x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3]) - expected_unique = torch.LongTensor([1, 2, 3, 5, 8]) - expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2]) + def run_test(device): + x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], device=device) + expected_unique = torch.tensor([1, 2, 3, 5, 8], device=device) + expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device) + expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device) - x_unique = torch.unique(x) - self.assertEqual( - expected_unique.tolist(), sorted(x_unique.tolist())) + x_unique = torch.unique(x) + self.assertEqual( + expected_unique.tolist(), sorted(x_unique.tolist())) - x_unique, x_inverse = x.unique(return_inverse=True) - self.assertEqual( - expected_unique.tolist(), sorted(x_unique.tolist())) - self.assertEqual(expected_inverse.numel(), x_inverse.numel()) - - x_unique = x.unique(sorted=True) - self.assertEqual(expected_unique, x_unique) - - x_unique, x_inverse = torch.unique( - x, sorted=True, return_inverse=True) - self.assertEqual(expected_unique, x_unique) - self.assertEqual(expected_inverse, x_inverse) - - # Tests per-element unique on a higher rank tensor. 
- y = x.view(2, 2, 2) - y_unique, y_inverse = y.unique(sorted=True, return_inverse=True) - self.assertEqual(expected_unique, y_unique) - self.assertEqual(expected_inverse.view(y.size()), y_inverse) - - # Tests unique on other types. - int_unique, int_inverse = torch.unique( - torch.IntTensor([2, 1, 2]), sorted=True, return_inverse=True) - self.assertEqual(torch.IntTensor([1, 2]), int_unique) - self.assertEqual(torch.LongTensor([1, 0, 1]), int_inverse) - - double_unique, double_inverse = torch.unique( - torch.DoubleTensor([2., 1.5, 2.1, 2.]), - sorted=True, - return_inverse=True, - ) - self.assertEqual(torch.DoubleTensor([1.5, 2., 2.1]), double_unique) - self.assertEqual(torch.LongTensor([1, 0, 2, 1]), double_inverse) + x_unique, x_inverse = x.unique(return_inverse=True) + self.assertEqual( + expected_unique.tolist(), sorted(x_unique.tolist())) + self.assertEqual(expected_inverse.numel(), x_inverse.numel()) + + x_unique = x.unique(sorted=True) + self.assertEqual(expected_unique, x_unique) + + x_unique, _, x_counts = torch._unique2_temporary_will_remove_soon(x, sorted=True, return_counts=True) + self.assertEqual(expected_counts, x_counts) + + x_unique, x_inverse = torch.unique( + x, sorted=True, return_inverse=True) + self.assertEqual(expected_unique, x_unique) + self.assertEqual(expected_inverse, x_inverse) + + x_unique, x_inverse, x_counts = torch._unique2_temporary_will_remove_soon( + x, sorted=True, return_inverse=True, return_counts=True) + self.assertEqual(expected_unique, x_unique) + self.assertEqual(expected_inverse, x_inverse) + self.assertEqual(expected_counts, x_counts) + + # Tests per-element unique on a higher rank tensor. + y = x.view(2, 2, 2) + y_unique, y_inverse = y.unique(sorted=True, return_inverse=True) + self.assertEqual(expected_unique, y_unique) + self.assertEqual(expected_inverse.view(y.size()), y_inverse) + + y_unique, y_inverse, y_counts = torch._unique2_temporary_will_remove_soon( + y, sorted=True, return_inverse=True, return_counts=True) + self.assertEqual(expected_unique, y_unique) + self.assertEqual(expected_inverse.view(y.size()), y_inverse) + self.assertEqual(expected_counts, y_counts) + + # Tests unique on other types. 
+ int_unique, int_inverse, int_counts = torch._unique2_temporary_will_remove_soon( + torch.tensor([2, 1, 2], dtype=torch.int, device=device), + sorted=True, + return_inverse=True, + return_counts=True + ) + self.assertEqual(torch.tensor([1, 2], dtype=torch.int, device=device), int_unique) + self.assertEqual(torch.tensor([1, 0, 1], dtype=torch.long, device=device), int_inverse) + self.assertEqual(torch.tensor([1, 2], dtype=torch.long, device=device), int_counts) + + double_unique, double_inverse, double_counts = torch._unique2_temporary_will_remove_soon( + torch.tensor([2., 1.5, 2.1, 2.], dtype=torch.double, device=device), + sorted=True, + return_inverse=True, + return_counts=True + ) + self.assertEqual(torch.tensor([1.5, 2., 2.1], dtype=torch.double, device=device), double_unique) + self.assertEqual(torch.tensor([1, 0, 2, 1], dtype=torch.long, device=device), double_inverse) + self.assertEqual(torch.tensor([1, 2, 1], dtype=torch.long, device=device), double_counts) + + byte_unique, byte_inverse, byte_counts = torch._unique2_temporary_will_remove_soon( + torch.tensor([133, 7, 7, 7, 42, 128], dtype=torch.uint8, device=device), + sorted=True, + return_inverse=True, + return_counts=True + ) + self.assertEqual(torch.tensor([7, 42, 128, 133], dtype=torch.uint8, device=device), byte_unique) + self.assertEqual(torch.tensor([3, 0, 0, 0, 1, 2], dtype=torch.long, device=device), byte_inverse) + self.assertEqual(torch.tensor([3, 1, 1, 1], dtype=torch.long, device=device), byte_counts) - byte_unique, byte_inverse = torch.unique( - torch.ByteTensor([133, 7, 7, 7, 42, 128]), - sorted=True, - return_inverse=True, - ) - self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique) - self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse) + run_test(torch.device('cpu')) + if torch.cuda.is_available(): + run_test(torch.device('cuda')) def test_unique_dim(self): - def run_test(dtype=torch.float): + def run_test(dtype=torch.float, device=torch.device('cpu')): x = torch.tensor([[[1., 1.], [0., 1.], [2., 1.], @@ -10564,19 +10594,27 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], [[1., 1.], [0., 1.], [2., 1.], - [0., 1.]]], dtype=dtype) + [0., 1.]]], + dtype=dtype, + device=device) expected_unique_dim0 = torch.tensor([[[1., 1.], [0., 1.], [2., 1.], - [0., 1.]]], dtype=dtype) + [0., 1.]]], + dtype=dtype, + device=device) expected_inverse_dim0 = torch.tensor([0, 0]) + expected_counts_dim0 = torch.tensor([2]) expected_unique_dim1 = torch.tensor([[[0., 1.], [1., 1.], [2., 1.]], [[0., 1.], [1., 1.], - [2., 1.]]], dtype=dtype) + [2., 1.]]], + dtype=dtype, + device=device) expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) + expected_counts_dim1 = torch.tensor([2, 1, 1]) expected_unique_dim2 = torch.tensor([[[1., 1.], [0., 1.], [2., 1.], @@ -10584,37 +10622,105 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], [[1., 1.], [0., 1.], [2., 1.], - [0., 1.]]], dtype=dtype) + [0., 1.]]], + dtype=dtype, + device=device) expected_inverse_dim2 = torch.tensor([0, 1]) + expected_counts_dim2 = torch.tensor([1, 1]) # dim0 x_unique = torch.unique(x, dim=0) self.assertEqual(expected_unique_dim0, x_unique) - x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0) + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=0) + self.assertEqual(expected_unique_dim0, x_unique) + self.assertEqual(expected_inverse_dim0, x_inverse) + + x_unique, _, x_counts = torch._unique_dim2_temporary_will_remove_soon( + x, + return_inverse=False, + return_counts=True, + dim=0) + self.assertEqual(expected_unique_dim0, 
x_unique) + self.assertEqual(expected_counts_dim0, x_counts) + + x_unique, x_inverse, x_counts = torch._unique_dim2_temporary_will_remove_soon( + x, + return_inverse=True, + return_counts=True, + dim=0) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) + self.assertEqual(expected_counts_dim0, x_counts) # dim1 x_unique = torch.unique(x, dim=1) self.assertEqual(expected_unique_dim1, x_unique) - x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1) + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=1) + self.assertEqual(expected_unique_dim1, x_unique) + self.assertEqual(expected_inverse_dim1, x_inverse) + + x_unique, _, x_counts = torch._unique_dim2_temporary_will_remove_soon( + x, + return_inverse=False, + return_counts=True, + dim=1) + self.assertEqual(expected_unique_dim1, x_unique) + self.assertEqual(expected_counts_dim1, x_counts) + + x_unique, x_inverse, x_counts = torch._unique_dim2_temporary_will_remove_soon( + x, + return_inverse=True, + return_counts=True, + dim=1) self.assertEqual(expected_unique_dim1, x_unique) self.assertEqual(expected_inverse_dim1, x_inverse) + self.assertEqual(expected_counts_dim1, x_counts) # dim2 x_unique = torch.unique(x, dim=2) self.assertEqual(expected_unique_dim2, x_unique) - x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2) + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=2) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) + x_unique, _, x_counts = torch._unique_dim2_temporary_will_remove_soon( + x, + return_inverse=False, + return_counts=True, + dim=2) + self.assertEqual(expected_unique_dim2, x_unique) + self.assertEqual(expected_counts_dim2, x_counts) + + x_unique, x_inverse, x_counts = torch._unique_dim2_temporary_will_remove_soon( + x, + return_inverse=True, + return_counts=True, + dim=2) + self.assertEqual(expected_unique_dim2, x_unique) + self.assertEqual(expected_inverse_dim2, x_inverse) + self.assertEqual(expected_counts_dim2, x_counts) + run_test(torch.float) run_test(torch.double) run_test(torch.long) run_test(torch.uint8) + if torch.cuda.is_available(): + run_test(torch.float, torch.device('cuda')) + run_test(torch.double, torch.device('cuda')) + run_test(torch.long, torch.device('cuda')) + run_test(torch.uint8, torch.device('cuda')) @staticmethod def _test_bincount(self, device): -- 2.7.4