From 816048e7e6953389d8fc1dd706c39613f437e2c7 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 13 Sep 2021 19:49:33 -0700 Subject: [PATCH] EmbeddingBag sort thrust->cub (#64498) Summary: Partially fixes https://github.com/pytorch/pytorch/issues/57505 Also fixes a warning I found when compiling: ``` /home/gaoxiang/pytorch-cub/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu(7): warning: inline qualifier ignored for "__global__" function ``` I also updated the bfloat16 guard to CUDA 11.5 Pull Request resolved: https://github.com/pytorch/pytorch/pull/64498 Reviewed By: mruberry Differential Revision: D30917077 Pulled By: ngimel fbshipit-source-id: fb9df08fd469038478a563014b5af7452b4b28c0 --- aten/src/ATen/cuda/cub.cuh | 6 +- aten/src/ATen/native/cuda/EmbeddingBag.cu | 69 +++++----------------- .../c10d/quantization/quantization_gpu.cu | 4 +- torch/testing/_internal/common_nn.py | 8 +++ 4 files changed, 30 insertions(+), 57 deletions(-) diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh index 38e5852..26f8047 100644 --- a/aten/src/ATen/cuda/cub.cuh +++ b/aten/src/ATen/cuda/cub.cuh @@ -55,7 +55,11 @@ struct cuda_type { using type = __half; }; -#if defined(CUDA_VERSION) && CUDA_VERSION >= 99999 +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11050 +// cub sort support for __nv_bfloat16 is added to cub 1.13 in +// https://github.com/NVIDIA/cub/pull/306 and according to +// https://github.com/NVIDIA/cub#releases, 1.13 is included in +// CUDA Toolkit 11.5 // waiting for https://github.com/NVIDIA/cub/pull/306 to land on CUDA template<> diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 3509468..8d1ef8b 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -7,14 +7,9 @@ #include #include -#include #include -#include -#include -#include -#include - +#include #include #include #include @@ -24,6 +19,9 @@ namespace at { namespace native { +template +void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); + namespace { constexpr int MODE_SUM = 0; @@ -149,25 +147,24 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean( } } - - Tensor embedding_bag_backward_cuda_sum_avg( const Tensor &grad, - const Tensor &indices, + const Tensor &indices_, const Tensor &offset2bag, const Tensor &bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const Tensor& per_sample_weights, int64_t padding_idx) { + auto indices = indices_.contiguous(); auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - ptrdiff_t numel = indices.numel(); + ptrdiff_t num_indices = indices.numel(); - if (numel == 0) { + if (num_indices == 0) { // all empty bags return at::zeros({num_weights, grad.size(1)}, grad.options()); } @@ -179,52 +176,16 @@ Tensor embedding_bag_backward_cuda_sum_avg( Tensor count; AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { - using device_ptr = thrust::device_ptr; - - // Sort the inputs into sorted with the corresponding indices; we - // don't need a stable or multidimensional sort, so just use Thrust - // directly - { - sorted_indices.copy_(indices); - - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - // Fill sortedOrigIndices with sequential indices - auto count_iter = thrust::counting_iterator(0); - auto orig_data = device_ptr(orig_indices.data_ptr()); - thrust::copy(policy, count_iter, count_iter + numel, orig_data); - - // Sort; a stable sort is not required - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data, - LTOp()); - } + auto range = at::arange(num_indices, indices.options()); + int64_t nbits = cuda::cub::get_num_bits(num_weights); + cuda::cub::sort_pairs( + indices.data_ptr(), sorted_indices.data_ptr(), + range.data_ptr(), orig_indices.data_ptr(), + num_indices, false/*, 0, nbits*/); if (scale_grad_by_freq) { count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - // Compute an increasing sequence per unique item in sortedIndices: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 1 2 3 1 2 1 1 2 - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - auto count_data = device_ptr(count.data_ptr()); - thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel, - thrust::make_constant_iterator(1), - count_data); - - // Take the maximum of each count per unique key in reverse: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 3 3 3 2 2 1 2 2 - thrust::inclusive_scan_by_key( - policy, thrust::make_reverse_iterator(sorted_data + numel), - thrust::make_reverse_iterator(sorted_data), - thrust::make_reverse_iterator(count_data + numel), - thrust::make_reverse_iterator(count_data + numel), - thrust::equal_to(), thrust::maximum()); + embedding_dense_backward_cuda_scan(sorted_indices, count); } }); return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu b/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu index 5590e03..7b78f5f 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu +++ b/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu @@ -4,7 +4,7 @@ #include // FP32 -> BF16 kernel -__global__ inline void _float_to_bfloat16_cuda_kernel( +__global__ void _float_to_bfloat16_cuda_kernel( const float* __restrict__ input, const int nrows, const int ncols, @@ -26,7 +26,7 @@ __global__ inline void _float_to_bfloat16_cuda_kernel( } // BF16 -> FP32 kernel -__global__ inline void _bfloat16_to_float_cuda_kernel( +__global__ void _bfloat16_to_float_cuda_kernel( const uint16_t* __restrict__ input, const int nrows, const int ncols, diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index f3cdc47..58f30d8 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -2792,6 +2792,14 @@ new_module_tests = [ ), dict( module_name='EmbeddingBag', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', + input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), + check_gradgrad=False, + desc='discontiguous', + ), + dict( + module_name='EmbeddingBag', constructor_args=(4, 3, None, 2., False, 'sum'), cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''', -- 2.7.4