[Reland] Embedding thrust->cub migration (#63806)
authorXiang Gao <qasdfgtyuiop@gmail.com>
Tue, 24 Aug 2021 16:24:50 +0000 (09:24 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 16:30:32 +0000 (09:30 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63427

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63806

Reviewed By: bdhirsh

Differential Revision: D30498255

Pulled By: ngimel

fbshipit-source-id: 78b7085a92a168cf0163f53dcb712bac922f5235

aten/src/ATen/cuda/cub.cuh
aten/src/ATen/native/cuda/Embedding.cu
aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh
aten/src/ATen/native/cuda/Indexing.cu
aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
aten/src/ATen/native/cuda/Randperm.cu
aten/src/ATen/native/cuda/UniqueCub.cu
torch/testing/_internal/common_nn.py

index 62da28d..38e5852 100644 (file)
@@ -3,6 +3,7 @@
 #include <cstddef>
 #include <type_traits>
 #include <iterator>
+#include <limits>
 
 // include cub in a safe manner, see:
 // https://github.com/pytorch/pytorch/pull/55292
@@ -102,6 +103,8 @@ static inline void sort_keys(
     const key_t *keys_in, key_t *keys_out,
     int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
 ) {
+  TORCH_CHECK(n <= std::numeric_limits<int>::max(),
+    "cub sort does not support sorting more than INT_MAX elements");
   using key_t_ = typename detail::cuda_type<key_t>::type;
 
   const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
@@ -124,6 +127,8 @@ static inline void sort_pairs(
     const value_t *values_in, value_t *values_out,
     int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
 ) {
+  TORCH_CHECK(n <= std::numeric_limits<int>::max(),
+    "cub sort does not support sorting more than INT_MAX elements");
   using key_t_ = typename detail::cuda_type<key_t>::type;
 
   auto allocator = c10::cuda::CUDACachingAllocator::get();
@@ -156,6 +161,10 @@ static inline void segmented_sort_pairs(
     OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
     bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
 ) {
+  TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
+    "cub sort does not support sorting more than INT_MAX elements");
+  TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
+    "cub sort does not support sorting more than INT_MAX elements");
   using key_t_ = typename detail::cuda_type<key_t>::type;
 
   auto allocator = c10::cuda::CUDACachingAllocator::get();
@@ -305,4 +314,12 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
   }
 }
 
-}}}
+template<typename InputIteratorT , typename OutputIteratorT , typename NumSelectedIteratorT >
+inline void unique(InputIteratorT input, OutputIteratorT output, NumSelectedIteratorT num_selected_out, int64_t num_items) {
+  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
+    "cub unique does not support more than INT_MAX elements");
+  CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSelect::Unique,
+    input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
+}
+
+}}}  // namespace at::cuda::cub
index 10a42b8..ba79fa1 100644 (file)
@@ -7,12 +7,9 @@
 
 #include <THC/THCDeviceUtils.cuh>
 #include <THC/THCTensorMathReduce.cuh>
-#include <THC/THCThrustAllocator.cuh>
 #include <THC/THCReduceApplyUtils.cuh>
 
-#include <thrust/execution_policy.h>
-#include <thrust/iterator/constant_iterator.h>
-#include <thrust/unique.h>
+#include <ATen/cuda/cub.cuh>
 
 #include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
@@ -224,14 +221,19 @@ __global__ void renorm_kernel(
 
 } // anonymous namespace
 
-Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices,
+template<typename index_t>
+void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
+
+Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
                                int64_t num_weights, int64_t padding_idx,
                                bool scale_grad_by_freq) {
   auto grad_arg = TensorArg(grad_, "grad", 1);
-  auto indices_arg = TensorArg(indices, "indices", 1);
+  auto indices_arg = TensorArg(indices_, "indices", 1);
   checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
   checkSameGPU("embedding_backward", grad_arg, indices_arg);
 
+  auto indices = indices_.contiguous();
+
   auto num_indices = indices.numel();
   auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)});
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -272,59 +274,16 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
   auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   Tensor count;
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
-    using device_ptr = thrust::device_ptr<index_t>;
-
-    // Sort the inputs into sorted with the corresponding indices; we
-    // don't need a stable or multidimensional sort, so just use Thrust
-    // directly
-    {
-        sorted_indices.copy_(indices);
-
-        auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
-        auto policy = thrust::cuda::par(allocator).on(stream);
-
-        // Fill sortedOrigIndices with sequential indices
-        auto count_iter = thrust::counting_iterator<index_t>(0);
-        auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
-        thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);
-
-        // Sort; a stable sort is not required
-        auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
-        thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data,
-                            LTOp<index_t>());
-    }
+    auto range = at::arange(num_indices, indices.options());
+    int64_t nbits = cuda::cub::get_num_bits(num_weights);
+    cuda::cub::sort_pairs(
+      indices.data_ptr<index_t>(), sorted_indices.data_ptr<index_t>(),
+      range.data_ptr<index_t>(), orig_indices.data_ptr<index_t>(),
+      num_indices, false/*, 0, nbits*/);
 
     if (scale_grad_by_freq) {
       count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-
-      auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
-      auto policy = thrust::cuda::par(allocator).on(stream);
-
-      // Compute an increasing sequence per unique item in sortedIndices:
-      // sorted: 2 5 5 5 7 7 8 9 9
-      //  count: 1 1 2 3 1 2 1 1 2
-      auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
-      auto count_data = device_ptr(count.data_ptr<index_t>());
-      thrust::inclusive_scan_by_key(
-        policy,
-        sorted_data,
-        sorted_data + num_indices,
-        thrust::make_constant_iterator(1),
-        count_data
-      );
-
-      // Take the maximum of each count per unique key in reverse:
-      // sorted: 2 5 5 5 7 7 8 9 9
-      //  count: 1 3 3 3 2 2 1 2 2
-      thrust::inclusive_scan_by_key(
-        policy,
-        thrust::make_reverse_iterator(sorted_data + num_indices),
-        thrust::make_reverse_iterator(sorted_data),
-        thrust::make_reverse_iterator(count_data + num_indices),
-        thrust::make_reverse_iterator(count_data + num_indices),
-        thrust::equal_to<index_t>(),
-        thrust::maximum<index_t>()
-      );
+      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
     }
   });
 
@@ -340,23 +299,23 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
   checkSameGPU("embedding_renorm", self_arg, indices_arg);
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
-  auto policy = thrust::cuda::par(allocator).on(stream);
 
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () {
-    using device_ptr = thrust::device_ptr<index_t>;
 
     auto num_indices = indices.numel();
     auto indices_contig = std::get<0>(indices.sort()).contiguous();
-    auto indices_data = device_ptr(indices_contig.data_ptr<index_t>());
-
     auto unique_indices = at::empty(indices.numel(), indices.options());
-    auto unique_data = device_ptr(unique_indices.data_ptr<index_t>());
-    auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
-    auto num_unique_indices = static_cast<int>(end - unique_data);
+    auto num_unique_indices = at::empty({}, indices.options().dtype(kLong));
+
+    cuda::cub::unique(
+      indices_contig.data_ptr<index_t>(),
+      unique_indices.data_ptr<index_t>(),
+      num_unique_indices.data_ptr<int64_t>(),
+      num_indices
+    );
 
-    dim3 grid(num_unique_indices);
-    dim3 block(128);
+    dim3 grid = num_unique_indices.item<int64_t>();
+    dim3 block = 128;
     int dim = self.stride(0);
 
     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] {
index f06b850..c79bf83 100644 (file)
 #include <THC/THCThrustAllocator.cuh>
 #include <THC/THCAtomics.cuh>
 
-#include <thrust/execution_policy.h>
-#include <thrust/unique.h>
-#include <thrust/device_vector.h>
-
 #pragma once
 
 namespace at {
index 95ab33e..57654f2 100644 (file)
@@ -218,9 +218,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
   std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self, indices, !unsafe);
   int64_t num_indices = linearIndex.numel();
 
-  TORCH_CHECK(num_indices <= std::numeric_limits<int>::max(),
-    "index_put of tensors larger than INT_MAX is not supported yet in pytorch");
-
   if (num_indices > 0 && sliceSize > 0) {
       const bool permuted = !src.is_contiguous();
       auto src_ = permuted ? src.contiguous() : src;
index 582dc9e..446aa08 100644 (file)
@@ -5,6 +5,8 @@
 #include <thrust/device_ptr.h>
 #include <thrust/execution_policy.h>
 #include <thrust/sort.h>
+#include <thrust/unique.h>
+#include <thrust/device_ptr.h>
 
 namespace at { namespace native {
 
@@ -30,4 +32,45 @@ void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_
   thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp<int64_t>());
 }
 
+template<typename index_t>
+void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
+  using device_ptr = thrust::device_ptr<index_t>;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+  auto policy = thrust::cuda::par(allocator).on(stream);
+
+  auto num_indices = count.numel();
+
+  // Compute an increasing sequence per unique item in sortedIndices:
+  // sorted: 2 5 5 5 7 7 8 9 9
+  //  count: 1 1 2 3 1 2 1 1 2
+  auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
+  auto count_data = device_ptr(count.data_ptr<index_t>());
+  thrust::inclusive_scan_by_key(
+    policy,
+    sorted_data,
+    sorted_data + num_indices,
+    thrust::make_constant_iterator(1),
+    count_data
+  );
+
+  // Take the maximum of each count per unique key in reverse:
+  // sorted: 2 5 5 5 7 7 8 9 9
+  //  count: 1 3 3 3 2 2 1 2 2
+  thrust::inclusive_scan_by_key(
+    policy,
+    thrust::make_reverse_iterator(sorted_data + num_indices),
+    thrust::make_reverse_iterator(sorted_data),
+    thrust::make_reverse_iterator(count_data + num_indices),
+    thrust::make_reverse_iterator(count_data + num_indices),
+    thrust::equal_to<index_t>(),
+    thrust::maximum<index_t>()
+  );
+}
+
+template
+void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
+template
+void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);
+
 }}
index 4c5e16a..56b8eb2 100644 (file)
@@ -47,8 +47,6 @@ template <int N> struct alignas(N) OpaqueType { char data[N]; };
 
 Tensor& randperm_out_cuda(int64_t n, c10::optional<Generator> generator, Tensor& result) {
   TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
-  TORCH_CHECK(n <= std::numeric_limits<int>::max(),
-    "randperm of tensors larger than INT_MAX is not supported yet in pytorch");
 
   check_supported_max_int_with_precision(n, result);
 
index 1b9619b..eb31fd2 100644 (file)
@@ -94,13 +94,7 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> compute_unique(
   Tensor length = at::empty({1}, options);
   int64_t num_out;
   if (!return_counts) {
-    CUB_WRAPPER(
-        cub::DeviceSelect::Unique,
-        data,
-        data_out.data_ptr<scalar_t>(),
-        length.data_ptr<int64_t>(),
-        num_inp,
-        stream);
+    cuda::cub::unique(data, data_out.data_ptr<scalar_t>(), length.data_ptr<int64_t>(), num_inp);
     num_out = length.item<int64_t>();
   } else {
     counts.resize_(num_inp);
@@ -135,11 +129,6 @@ std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
 
   auto options = self.options().dtype(kLong);
   int64_t num_inp = self.numel();
-  TORCH_CHECK(
-      num_inp <= INT_MAX,
-      "num_inp ",
-      num_inp,
-      " is too big to be handled by cub");
   Tensor sorted;
   Tensor self_c = self.contiguous();
   if (consecutive) {
index 90024de..e0d09b7 100644 (file)
@@ -2775,6 +2775,14 @@ new_module_tests = [
         check_gradgrad=False,
     ),
     dict(
+        module_name='Embedding',
+        constructor_args=(4, 3),
+        cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
+        input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
+        check_gradgrad=False,
+        desc='discontiguous'
+    ),
+    dict(
         module_name='EmbeddingBag',
         constructor_args=(4, 3),
         cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',