EmbeddingBag sort thrust->cub (#64498)
authorXiang Gao <qasdfgtyuiop@gmail.com>
Tue, 14 Sep 2021 02:49:33 +0000 (19:49 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 02:51:12 +0000 (19:51 -0700)
Summary:
Partially fixes https://github.com/pytorch/pytorch/issues/57505

Also fixes a warning I found when compiling:
```
/home/gaoxiang/pytorch-cub/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu(7): warning: inline qualifier ignored for "__global__" function
```
I also updated the bfloat16 guard to CUDA 11.5

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64498

Reviewed By: mruberry

Differential Revision: D30917077

Pulled By: ngimel

fbshipit-source-id: fb9df08fd469038478a563014b5af7452b4b28c0

aten/src/ATen/cuda/cub.cuh
aten/src/ATen/native/cuda/EmbeddingBag.cu
torch/csrc/distributed/c10d/quantization/quantization_gpu.cu
torch/testing/_internal/common_nn.py

index 38e5852..26f8047 100644 (file)
@@ -55,7 +55,11 @@ struct cuda_type<c10::Half> {
   using type = __half;
 };
 
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 99999
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11050
+// cub sort support for __nv_bfloat16 is added to cub 1.13 in
+// https://github.com/NVIDIA/cub/pull/306 and according to
+// https://github.com/NVIDIA/cub#releases, 1.13 is included in
+// CUDA Toolkit 11.5
 
 // waiting for https://github.com/NVIDIA/cub/pull/306 to land on CUDA
 template<>
index 3509468..8d1ef8b 100644 (file)
@@ -7,14 +7,9 @@
 
 #include <THC/THCDeviceUtils.cuh>
 #include <THC/THCTensorMathReduce.cuh>
-#include <THC/THCThrustAllocator.cuh>
 #include <THC/THCAtomics.cuh>
 
-#include <thrust/execution_policy.h>
-#include <thrust/unique.h>
-#include <thrust/iterator/constant_iterator.h>
-#include <thrust/device_vector.h>
-
+#include <ATen/cuda/cub.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
 #include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
 #include <ATen/native/cuda/KernelUtils.cuh>
@@ -24,6 +19,9 @@
 namespace at {
 namespace native {
 
+template<typename index_t>
+void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
+
 namespace {
 
 constexpr int MODE_SUM = 0;
@@ -149,25 +147,24 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
   }
 }
 
-
-
 Tensor embedding_bag_backward_cuda_sum_avg(
                                    const Tensor &grad,
-                                   const Tensor &indices,
+                                   const Tensor &indices_,
                                    const Tensor &offset2bag,
                                    const Tensor &bag_size,
                                    int64_t num_weights,
                                    bool scale_grad_by_freq, int64_t mode,
                                    const Tensor& per_sample_weights,
                                    int64_t padding_idx) {
+  auto indices = indices_.contiguous();
 
   auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  ptrdiff_t numel = indices.numel();
+  ptrdiff_t num_indices = indices.numel();
 
-  if (numel == 0) {
+  if (num_indices == 0) {
     // all empty bags
     return at::zeros({num_weights, grad.size(1)}, grad.options());
   }
@@ -179,52 +176,16 @@ Tensor embedding_bag_backward_cuda_sum_avg(
   Tensor count;
 
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
-    using device_ptr = thrust::device_ptr<index_t>;
-
-    // Sort the inputs into sorted with the corresponding indices; we
-    // don't need a stable or multidimensional sort, so just use Thrust
-    // directly
-    {
-      sorted_indices.copy_(indices);
-
-      auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
-      auto policy = thrust::cuda::par(allocator).on(stream);
-
-      // Fill sortedOrigIndices with sequential indices
-      auto count_iter = thrust::counting_iterator<index_t>(0);
-      auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
-      thrust::copy(policy, count_iter, count_iter + numel, orig_data);
-
-      // Sort; a stable sort is not required
-      auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
-      thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data,
-                          LTOp<index_t>());
-    }
+    auto range = at::arange(num_indices, indices.options());
+    int64_t nbits = cuda::cub::get_num_bits(num_weights);
+    cuda::cub::sort_pairs(
+      indices.data_ptr<index_t>(), sorted_indices.data_ptr<index_t>(),
+      range.data_ptr<index_t>(), orig_indices.data_ptr<index_t>(),
+      num_indices, false/*, 0, nbits*/);
 
     if (scale_grad_by_freq) {
       count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-
-      auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
-      auto policy = thrust::cuda::par(allocator).on(stream);
-
-      // Compute an increasing sequence per unique item in sortedIndices:
-      // sorted: 2 5 5 5 7 7 8 9 9
-      //  count: 1 1 2 3 1 2 1 1 2
-      auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
-      auto count_data = device_ptr(count.data_ptr<index_t>());
-      thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel,
-                                    thrust::make_constant_iterator(1),
-                                    count_data);
-
-      // Take the maximum of each count per unique key in reverse:
-      // sorted: 2 5 5 5 7 7 8 9 9
-      //  count: 1 3 3 3 2 2 1 2 2
-      thrust::inclusive_scan_by_key(
-          policy, thrust::make_reverse_iterator(sorted_data + numel),
-          thrust::make_reverse_iterator(sorted_data),
-          thrust::make_reverse_iterator(count_data + numel),
-          thrust::make_reverse_iterator(count_data + numel),
-          thrust::equal_to<index_t>(), thrust::maximum<index_t>());
+      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
     }
   });
   return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
index 5590e03..7b78f5f 100644 (file)
@@ -4,7 +4,7 @@
 #include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
 
 // FP32 -> BF16 kernel
-__global__ inline void _float_to_bfloat16_cuda_kernel(
+__global__ void _float_to_bfloat16_cuda_kernel(
     const float* __restrict__ input,
     const int nrows,
     const int ncols,
@@ -26,7 +26,7 @@ __global__ inline void _float_to_bfloat16_cuda_kernel(
 }
 
 // BF16 -> FP32 kernel
-__global__ inline void _bfloat16_to_float_cuda_kernel(
+__global__ void _bfloat16_to_float_cuda_kernel(
     const uint16_t* __restrict__ input,
     const int nrows,
     const int ncols,
index f3cdc47..58f30d8 100644 (file)
@@ -2792,6 +2792,14 @@ new_module_tests = [
     ),
     dict(
         module_name='EmbeddingBag',
+        constructor_args=(4, 3),
+        cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
+        input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
+        check_gradgrad=False,
+        desc='discontiguous',
+    ),
+    dict(
+        module_name='EmbeddingBag',
         constructor_args=(4, 3, None, 2., False, 'sum'),
         cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
                                 .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''',