improve performance of unique with inverse indices (#16145)
authorNatalia Gimelshein <ngimelshein@nvidia.com>
Fri, 18 Jan 2019 22:53:32 +0000 (14:53 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 18 Jan 2019 22:56:39 +0000 (14:56 -0800)
Summary:
Partial fix for #15804, only w/o dim.
For jcjohnson benchmarking script I'm getting the following results on V100:
Before:
```
unning with N = 10000, M = 10000
cuda (no inverse): 0.98 ms
cpu (no inverse): 0.96 ms
cuda (with inverse): 1.07 ms
cpu (with inverse): 1.76 ms

Running with N = 10000, M = 100000
cuda (no inverse): 0.76 ms
cpu (no inverse): 1.53 ms
cuda (with inverse): 1.23 ms
cpu (with inverse): 3.02 ms

Running with N = 100000, M = 100000
cuda (no inverse): 1.28 ms
cpu (no inverse): 11.22 ms
cuda (with inverse): 69.76 ms
cpu (with inverse): 20.28 ms

Running with N = 100000, M = 1000000
cuda (no inverse): 0.78 ms
cpu (no inverse): 18.78 ms
cuda (with inverse): 133.45 ms
cpu (with inverse): 34.09 ms

Running with N = 500000, M = 500000
cuda (no inverse): 1.43 ms
cpu (no inverse): 61.13 ms
cuda (with inverse): 3315.18 ms
cpu (with inverse): 104.57 ms

Running with N = 500000, M = 5000000
cuda (no inverse): 0.86 ms
cpu (no inverse): 96.44 ms
cuda (with inverse): 5209.93 ms
cpu (with inverse): 176.10 ms
```
After
```
Running with N = 10000, M = 10000
cuda (no inverse): 1.04 ms
cpu (no inverse): 0.94 ms
cuda (with inverse): 0.64 ms
cpu (with inverse): 1.76 ms

Running with N = 10000, M = 100000
cuda (no inverse): 0.77 ms
cpu (no inverse): 1.55 ms
cuda (with inverse): 0.58 ms
cpu (with inverse): 2.79 ms

Running with N = 100000, M = 100000
cuda (no inverse): 1.30 ms
cpu (no inverse): 14.15 ms
cuda (with inverse): 1.63 ms
cpu (with inverse): 20.90 ms

Running with N = 100000, M = 1000000
cuda (no inverse): 0.82 ms
cpu (no inverse): 18.63 ms
cuda (with inverse): 0.61 ms
cpu (with inverse): 33.52 ms

Running with N = 500000, M = 500000
cuda (no inverse): 1.51 ms
cpu (no inverse): 59.81 ms
cuda (with inverse): 1.23 ms
cpu (with inverse): 110.69 ms

Running with N = 500000, M = 5000000
cuda (no inverse): 0.92 ms
cpu (no inverse): 104.26 ms
cuda (with inverse): 0.84 ms
cpu (with inverse): 187.12 ms
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16145

Differential Revision: D13738821

Pulled By: soumith

fbshipit-source-id: 0811fb4ade47e3b466cebbc124e3f3333a986749

aten/src/ATen/native/cuda/Unique.cu

index 1687b6b..828fb48 100644 (file)
@@ -7,28 +7,13 @@
 #include <tuple>
 #include <thrust/unique.h>
 #include <thrust/sort.h>
+#include <thrust/scan.h>
+#include <thrust/scatter.h>
 
 namespace at {
 namespace native{
 
 namespace {
-template <typename scalar_t>
-__global__ void inverse_indices_kernel(
-    const scalar_t* input_data,
-    const scalar_t* output_data,
-    int64_t* inverse_indices_data,
-    int64_t num_inp,
-    int64_t num_out) {
-    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int64_t stride = blockDim.x * gridDim.x;
-
-    for (int64_t i = idx; i < num_inp * num_out; i += stride) {
-      if (input_data[i / num_out] == output_data[i % num_out]){
-        inverse_indices_data[i / num_out] = i % num_out;
-      }
-    }
-  }
-
 
 template <typename scalar_t>
   std::tuple<Tensor, Tensor> _unique_cuda_template(
@@ -47,25 +32,29 @@ template <typename scalar_t>
     Tensor output = input.clone();
     output = output.view(-1);
     scalar_t* output_data = output.data<scalar_t>();
-    thrust::sort(policy, output_data, output_data + num_inp);
-    scalar_t* output_end = thrust::unique(policy, output_data, output_data + num_inp);
-    int64_t num_out = output_end - output_data;
-    output.resize_(num_out);
-
-    Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
-
-    if (return_inverse) {
-      inverse_indices.resize_(input.sizes());
-      int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
-      int block = 512;
-      int grid = std::min<int64_t>((num_inp * num_out + block - 1) / block, 2048L);
-      inverse_indices_kernel<<<grid, block, 0, stream>>>(
-        input_data, output_data, inverse_indices_data, num_inp, num_out);
+    Tensor inverse_indices;
+    if (!return_inverse) {
+        inverse_indices = at::empty({0},  self.type().toScalarType(kLong));
+        thrust::sort(policy, output_data, output_data + num_inp);
+    } else {
+        Tensor sorted_indices = at::arange(0, num_inp, self.type().toScalarType(kLong));
+        int64_t* sorted_indices_ptr = sorted_indices.data<int64_t>();
+        thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
+        Tensor inv_loc = at::empty({num_inp}, self.type().toScalarType(kLong));
+        inverse_indices = at::empty({num_inp}, self.type().toScalarType(kLong));
+        int64_t* inv_loc_ptr = inv_loc.data<int64_t>();
+        int64_t* inverse_indices_ptr = inverse_indices.data<int64_t>();
+        thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr, [=] __device__ (scalar_t a, scalar_t b) -> int64_t { if (a != b) {return 1;} else { return 0; }});
+        inv_loc[0] = 0;
+        thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
+        thrust::scatter(policy,inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
+        inverse_indices.resize_(input.sizes());
     }
+    int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
+    output.resize_(num_out);
 
     THCudaCheck(cudaGetLastError());
     return std::tuple<Tensor, Tensor>(output, inverse_indices);
-
   }
 
 template <typename scalar_t>