From: Xiang Gao Date: Tue, 24 Aug 2021 16:24:50 +0000 (-0700) Subject: [Reland] Embedding thrust->cub migration (#63806) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~774 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=227cb268bccd22feb8aa8651773a202ec1e09c7f;p=platform%2Fupstream%2Fpytorch.git [Reland] Embedding thrust->cub migration (#63806) Summary: Fixes https://github.com/pytorch/pytorch/issues/63427 Pull Request resolved: https://github.com/pytorch/pytorch/pull/63806 Reviewed By: bdhirsh Differential Revision: D30498255 Pulled By: ngimel fbshipit-source-id: 78b7085a92a168cf0163f53dcb712bac922f5235 --- diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh index 62da28d..38e5852 100644 --- a/aten/src/ATen/cuda/cub.cuh +++ b/aten/src/ATen/cuda/cub.cuh @@ -3,6 +3,7 @@ #include #include #include +#include // include cub in a safe manner, see: // https://github.com/pytorch/pytorch/pull/55292 @@ -102,6 +103,8 @@ static inline void sort_keys( const key_t *keys_in, key_t *keys_out, int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8 ) { + TORCH_CHECK(n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); using key_t_ = typename detail::cuda_type::type; const key_t_ *keys_in_ = reinterpret_cast(keys_in); @@ -124,6 +127,8 @@ static inline void sort_pairs( const value_t *values_in, value_t *values_out, int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8 ) { + TORCH_CHECK(n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); using key_t_ = typename detail::cuda_type::type; auto allocator = c10::cuda::CUDACachingAllocator::get(); @@ -156,6 +161,10 @@ static inline void segmented_sort_pairs( OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8 ) { + TORCH_CHECK(num_elements <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + TORCH_CHECK(num_segments <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); using key_t_ = typename detail::cuda_type::type; auto allocator = c10::cuda::CUDACachingAllocator::get(); @@ -305,4 +314,12 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT } } -}}} +template +inline void unique(InputIteratorT input, OutputIteratorT output, NumSelectedIteratorT num_selected_out, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub unique does not support more than INT_MAX elements"); + CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSelect::Unique, + input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream()); +} + +}}} // namespace at::cuda::cub diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 10a42b8..ba79fa1 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -7,12 +7,9 @@ #include #include -#include #include -#include -#include -#include +#include #include #include @@ -224,14 +221,19 @@ __global__ void renorm_kernel( } // anonymous namespace -Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices, +template +void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); + +Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { auto grad_arg = TensorArg(grad_, "grad", 1); - auto indices_arg = TensorArg(indices, "indices", 1); + auto indices_arg = TensorArg(indices_, "indices", 1); checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt}); checkSameGPU("embedding_backward", grad_arg, indices_arg); + auto indices = indices_.contiguous(); + auto num_indices = indices.numel(); auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)}); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -272,59 +274,16 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); Tensor count; AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { - using device_ptr = thrust::device_ptr; - - // Sort the inputs into sorted with the corresponding indices; we - // don't need a stable or multidimensional sort, so just use Thrust - // directly - { - sorted_indices.copy_(indices); - - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - // Fill sortedOrigIndices with sequential indices - auto count_iter = thrust::counting_iterator(0); - auto orig_data = device_ptr(orig_indices.data_ptr()); - thrust::copy(policy, count_iter, count_iter + num_indices, orig_data); - - // Sort; a stable sort is not required - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, - LTOp()); - } + auto range = at::arange(num_indices, indices.options()); + int64_t nbits = cuda::cub::get_num_bits(num_weights); + cuda::cub::sort_pairs( + indices.data_ptr(), sorted_indices.data_ptr(), + range.data_ptr(), orig_indices.data_ptr(), + num_indices, false/*, 0, nbits*/); if (scale_grad_by_freq) { count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - // Compute an increasing sequence per unique item in sortedIndices: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 1 2 3 1 2 1 1 2 - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - auto count_data = device_ptr(count.data_ptr()); - thrust::inclusive_scan_by_key( - policy, - sorted_data, - sorted_data + num_indices, - thrust::make_constant_iterator(1), - count_data - ); - - // Take the maximum of each count per unique key in reverse: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 3 3 3 2 2 1 2 2 - thrust::inclusive_scan_by_key( - policy, - thrust::make_reverse_iterator(sorted_data + num_indices), - thrust::make_reverse_iterator(sorted_data), - thrust::make_reverse_iterator(count_data + num_indices), - thrust::make_reverse_iterator(count_data + num_indices), - thrust::equal_to(), - thrust::maximum() - ); + embedding_dense_backward_cuda_scan(sorted_indices, count); } }); @@ -340,23 +299,23 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, checkSameGPU("embedding_renorm", self_arg, indices_arg); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () { - using device_ptr = thrust::device_ptr; auto num_indices = indices.numel(); auto indices_contig = std::get<0>(indices.sort()).contiguous(); - auto indices_data = device_ptr(indices_contig.data_ptr()); - auto unique_indices = at::empty(indices.numel(), indices.options()); - auto unique_data = device_ptr(unique_indices.data_ptr()); - auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data); - auto num_unique_indices = static_cast(end - unique_data); + auto num_unique_indices = at::empty({}, indices.options().dtype(kLong)); + + cuda::cub::unique( + indices_contig.data_ptr(), + unique_indices.data_ptr(), + num_unique_indices.data_ptr(), + num_indices + ); - dim3 grid(num_unique_indices); - dim3 block(128); + dim3 grid = num_unique_indices.item(); + dim3 block = 128; int dim = self.stride(0); AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] { diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh index f06b850..c79bf83 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh +++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh @@ -10,10 +10,6 @@ #include #include -#include -#include -#include - #pragma once namespace at { diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 95ab33e..57654f2 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -218,9 +218,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List::max(), - "index_put of tensors larger than INT_MAX is not supported yet in pytorch"); - if (num_indices > 0 && sliceSize > 0) { const bool permuted = !src.is_contiguous(); auto src_ = permuted ? src.contiguous() : src; diff --git a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu b/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu index 582dc9e..446aa08 100644 --- a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu +++ b/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu @@ -5,6 +5,8 @@ #include #include #include +#include +#include namespace at { namespace native { @@ -30,4 +32,45 @@ void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_ thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp()); } +template +void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) { + using device_ptr = thrust::device_ptr; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + auto num_indices = count.numel(); + + // Compute an increasing sequence per unique item in sortedIndices: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 1 2 3 1 2 1 1 2 + auto sorted_data = device_ptr(sorted_indices.data_ptr()); + auto count_data = device_ptr(count.data_ptr()); + thrust::inclusive_scan_by_key( + policy, + sorted_data, + sorted_data + num_indices, + thrust::make_constant_iterator(1), + count_data + ); + + // Take the maximum of each count per unique key in reverse: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 3 3 3 2 2 1 2 2 + thrust::inclusive_scan_by_key( + policy, + thrust::make_reverse_iterator(sorted_data + num_indices), + thrust::make_reverse_iterator(sorted_data), + thrust::make_reverse_iterator(count_data + num_indices), + thrust::make_reverse_iterator(count_data + num_indices), + thrust::equal_to(), + thrust::maximum() + ); +} + +template +void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); +template +void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); + }} diff --git a/aten/src/ATen/native/cuda/Randperm.cu b/aten/src/ATen/native/cuda/Randperm.cu index 4c5e16a..56b8eb2 100644 --- a/aten/src/ATen/native/cuda/Randperm.cu +++ b/aten/src/ATen/native/cuda/Randperm.cu @@ -47,8 +47,6 @@ template struct alignas(N) OpaqueType { char data[N]; }; Tensor& randperm_out_cuda(int64_t n, c10::optional generator, Tensor& result) { TORCH_CHECK(n >= 0, "n must be non-negative, got", n); - TORCH_CHECK(n <= std::numeric_limits::max(), - "randperm of tensors larger than INT_MAX is not supported yet in pytorch"); check_supported_max_int_with_precision(n, result); diff --git a/aten/src/ATen/native/cuda/UniqueCub.cu b/aten/src/ATen/native/cuda/UniqueCub.cu index 1b9619b..eb31fd2 100644 --- a/aten/src/ATen/native/cuda/UniqueCub.cu +++ b/aten/src/ATen/native/cuda/UniqueCub.cu @@ -94,13 +94,7 @@ std::tuple compute_unique( Tensor length = at::empty({1}, options); int64_t num_out; if (!return_counts) { - CUB_WRAPPER( - cub::DeviceSelect::Unique, - data, - data_out.data_ptr(), - length.data_ptr(), - num_inp, - stream); + cuda::cub::unique(data, data_out.data_ptr(), length.data_ptr(), num_inp); num_out = length.item(); } else { counts.resize_(num_inp); @@ -135,11 +129,6 @@ std::tuple unique_cuda_template( auto options = self.options().dtype(kLong); int64_t num_inp = self.numel(); - TORCH_CHECK( - num_inp <= INT_MAX, - "num_inp ", - num_inp, - " is too big to be handled by cub"); Tensor sorted; Tensor self_c = self.contiguous(); if (consecutive) { diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 90024de..e0d09b7 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -2775,6 +2775,14 @@ new_module_tests = [ check_gradgrad=False, ), dict( + module_name='Embedding', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', + input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), + check_gradgrad=False, + desc='discontiguous' + ), + dict( module_name='EmbeddingBag', constructor_args=(4, 3), cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',