#include <tuple>
#include <thrust/unique.h>
#include <thrust/sort.h>
+#include <thrust/scan.h>
+#include <thrust/scatter.h>
+#include <thrust/adjacent_difference.h>
namespace at {
namespace native{
namespace {
-template <typename scalar_t>
-__global__ void inverse_indices_kernel(
- const scalar_t* input_data,
- const scalar_t* output_data,
- int64_t* inverse_indices_data,
- int64_t num_inp,
- int64_t num_out) {
- int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
- int64_t stride = blockDim.x * gridDim.x;
-
- for (int64_t i = idx; i < num_inp * num_out; i += stride) {
- if (input_data[i / num_out] == output_data[i % num_out]){
- inverse_indices_data[i / num_out] = i % num_out;
- }
- }
- }
-
template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_cuda_template(
    const Tensor& self,
    const bool return_inverse) {
Tensor output = input.clone();
output = output.view(-1);
scalar_t* output_data = output.data<scalar_t>();
- thrust::sort(policy, output_data, output_data + num_inp);
- scalar_t* output_end = thrust::unique(policy, output_data, output_data + num_inp);
- int64_t num_out = output_end - output_data;
- output.resize_(num_out);
-
- Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
-
- if (return_inverse) {
- inverse_indices.resize_(input.sizes());
- int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
- int block = 512;
- int grid = std::min<int64_t>((num_inp * num_out + block - 1) / block, 2048L);
- inverse_indices_kernel<<<grid, block, 0, stream>>>(
- input_data, output_data, inverse_indices_data, num_inp, num_out);
+ Tensor inverse_indices;
+ if (!return_inverse) {
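+ // no inverse mapping requested: keep the placeholder empty and just sort in place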
+ inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
+ thrust::sort(policy, output_data, output_data + num_inp);
+ } else {
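+ // Build the inverse mapping in O(n log n) instead of the old
+ // O(num_inp * num_out) kernel: sort the values together with their original
+ // positions, flag run boundaries, prefix-sum the flags, then scatter the
+ // resulting ranks back. E.g. input [1, 3, 1, 2] -> sorted [1, 1, 2, 3] with
+ // positions [0, 2, 3, 1], flags [0, 0, 1, 1], scan [0, 0, 1, 2],
+ // scatter -> inverse [0, 2, 0, 1].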
+ Tensor sorted_indices = at::arange(0, num_inp, self.type().toScalarType(kLong));
+ int64_t* sorted_indices_ptr = sorted_indices.data<int64_t>();
+ thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
+ Tensor inv_loc = at::empty({num_inp}, self.type().toScalarType(kLong));
+ inverse_indices = at::empty({num_inp}, self.type().toScalarType(kLong));
+ int64_t* inv_loc_ptr = inv_loc.data<int64_t>();
+ int64_t* inverse_indices_ptr = inverse_indices.data<int64_t>();
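+ // flag the start of each run of equal values: inv_loc[i] = 1 iff the i-th
+ // sorted element differs from its predecessor (extended __device__ lambda,
+ // so nvcc needs --expt-extended-lambda)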
+ thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr,
+     [=] __device__ (scalar_t a, scalar_t b) -> int64_t { return a != b ? 1 : 0; });
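+ // adjacent_difference copies the first element through unchanged, so reset
+ // inv_loc[0] to 0 before the scan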
+ inv_loc[0] = 0;
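+ // prefix-sum the boundary flags: inv_loc[i] becomes the index of the i-th
+ // sorted element within the unique output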
+ thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
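+ // undo the sort: inverse_indices[sorted_indices[i]] = inv_loc[i] routes each
+ // rank back to the element's original position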
+ thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
+ inverse_indices.resize_(input.sizes());
}
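+ // deduplicate the sorted data in place; num_out is the number of unique values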
+ int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
+ output.resize_(num_out);
THCudaCheck(cudaGetLastError());
return std::tuple<Tensor, Tensor>(output, inverse_indices);
-
}
template <typename scalar_t>