From ce6fe50158c631f7f54b6df3ac91632cd41d48ea Mon Sep 17 00:00:00 2001
From: Natalia Gimelshein
Date: Thu, 19 Aug 2021 13:00:08 -0700
Subject: [PATCH] Revert embedding thrust->cub migration (#63451)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/63427

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63451

Reviewed By: mruberry

Differential Revision: D30398482

Pulled By: ngimel

fbshipit-source-id: e153786d204215555a6571688eabae712facad7e
---
 aten/src/ATen/cuda/cub.cuh                       | 19 +----
 aten/src/ATen/native/cuda/Embedding.cu           | 85 ++++++++++++++++------
 .../ATen/native/cuda/EmbeddingBackwardKernel.cuh |  4 +
 aten/src/ATen/native/cuda/Indexing.cu            |  3 +
 aten/src/ATen/native/cuda/LegacyThrustHelpers.cu | 43 -----------
 aten/src/ATen/native/cuda/Randperm.cu            |  2 +
 aten/src/ATen/native/cuda/UniqueCub.cu           | 13 +++-
 7 files changed, 86 insertions(+), 83 deletions(-)

diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh
index 38e5852..62da28d 100644
--- a/aten/src/ATen/cuda/cub.cuh
+++ b/aten/src/ATen/cuda/cub.cuh
@@ -3,7 +3,6 @@
 #include
 #include
 #include
-#include
 
 // include cub in a safe manner, see:
 // https://github.com/pytorch/pytorch/pull/55292
@@ -103,8 +102,6 @@ static inline void sort_keys(
     const key_t *keys_in, key_t *keys_out,
     int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
 ) {
-  TORCH_CHECK(n <= std::numeric_limits::max(),
-    "cub sort does not support sorting more than INT_MAX elements");
   using key_t_ = typename detail::cuda_type::type;
 
   const key_t_ *keys_in_ = reinterpret_cast(keys_in);
@@ -127,8 +124,6 @@ static inline void sort_pairs(
     const value_t *values_in, value_t *values_out,
     int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
 ) {
-  TORCH_CHECK(n <= std::numeric_limits::max(),
-    "cub sort does not support sorting more than INT_MAX elements");
   using key_t_ = typename detail::cuda_type::type;
 
   auto allocator = c10::cuda::CUDACachingAllocator::get();
@@ -161,10 +156,6 @@ static inline void segmented_sort_pairs(
     OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
     bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
 ) {
-  TORCH_CHECK(num_elements <= std::numeric_limits::max(),
-    "cub sort does not support sorting more than INT_MAX elements");
-  TORCH_CHECK(num_segments <= std::numeric_limits::max(),
-    "cub sort does not support sorting more than INT_MAX elements");
   using key_t_ = typename detail::cuda_type::type;
 
   auto allocator = c10::cuda::CUDACachingAllocator::get();
@@ -314,12 +305,4 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
   }
 }
 
-template
-inline void unique(InputIteratorT input, OutputIteratorT output, NumSelectedIteratorT num_selected_out, int64_t num_items) {
-  TORCH_CHECK(num_items <= std::numeric_limits::max(),
-    "cub unique does not support more than INT_MAX elements");
-  CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSelect::Unique,
-    input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
-}
-
-}}} // namespace at::cuda::cub
+}}}
diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index 100ffbd..10a42b8 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -7,9 +7,12 @@
 #include
 #include
+#include
 
 #include
-#include
+#include
+#include
+#include
 
 #include
 #include
 
@@ -221,9 +224,6 @@ __global__ void renorm_kernel(
 
 } // anonymous namespace
 
-template
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
-
 Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices,
                                      int64_t num_weights, int64_t padding_idx,
                                      bool scale_grad_by_freq) {
@@ -272,16 +272,59 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
   auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   Tensor count;
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
-    auto range = at::arange(num_indices, indices.options());
-    int64_t nbits = cuda::cub::get_num_bits(num_weights);
-    cuda::cub::sort_pairs(
-      indices.data_ptr(), sorted_indices.data_ptr(),
-      range.data_ptr(), orig_indices.data_ptr(),
-      num_indices, false, 0, nbits);
+    using device_ptr = thrust::device_ptr;
+
+    // Sort the inputs into sorted with the corresponding indices; we
+    // don't need a stable or multidimensional sort, so just use Thrust
+    // directly
+    {
+      sorted_indices.copy_(indices);
+
+      auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+      auto policy = thrust::cuda::par(allocator).on(stream);
+
+      // Fill sortedOrigIndices with sequential indices
+      auto count_iter = thrust::counting_iterator(0);
+      auto orig_data = device_ptr(orig_indices.data_ptr());
+      thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);
+
+      // Sort; a stable sort is not required
+      auto sorted_data = device_ptr(sorted_indices.data_ptr());
+      thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data,
+                          LTOp());
+    }
 
     if (scale_grad_by_freq) {
       count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-      embedding_dense_backward_cuda_scan(sorted_indices, count);
+
+      auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+      auto policy = thrust::cuda::par(allocator).on(stream);
+
+      // Compute an increasing sequence per unique item in sortedIndices:
+      // sorted: 2 5 5 5 7 7 8 9 9
+      //  count: 1 1 2 3 1 2 1 1 2
+      auto sorted_data = device_ptr(sorted_indices.data_ptr());
+      auto count_data = device_ptr(count.data_ptr());
+      thrust::inclusive_scan_by_key(
+        policy,
+        sorted_data,
+        sorted_data + num_indices,
+        thrust::make_constant_iterator(1),
+        count_data
+      );
+
+      // Take the maximum of each count per unique key in reverse:
+      // sorted: 2 5 5 5 7 7 8 9 9
+      //  count: 1 3 3 3 2 2 1 2 2
+      thrust::inclusive_scan_by_key(
+        policy,
+        thrust::make_reverse_iterator(sorted_data + num_indices),
+        thrust::make_reverse_iterator(sorted_data),
+        thrust::make_reverse_iterator(count_data + num_indices),
+        thrust::make_reverse_iterator(count_data + num_indices),
+        thrust::equal_to(),
+        thrust::maximum()
+      );
     }
   });
 
@@ -297,23 +340,23 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
   checkSameGPU("embedding_renorm", self_arg, indices_arg);
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+  auto policy = thrust::cuda::par(allocator).on(stream);
 
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () {
+    using device_ptr = thrust::device_ptr;
     auto num_indices = indices.numel();
     auto indices_contig = std::get<0>(indices.sort()).contiguous();
-    auto unique_indices = at::empty(indices.numel(), indices.options());
-    auto num_unique_indices = at::empty({}, indices.options().dtype(kLong));
+    auto indices_data = device_ptr(indices_contig.data_ptr());
 
-    cuda::cub::unique(
-      indices_contig.data_ptr(),
-      unique_indices.data_ptr(),
-      num_unique_indices.data_ptr(),
-      num_indices
-    );
+    auto unique_indices = at::empty(indices.numel(), indices.options());
+    auto unique_data = device_ptr(unique_indices.data_ptr());
+    auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
+    auto num_unique_indices = static_cast(end - unique_data);
 
-    dim3 grid = num_unique_indices.item();
-    dim3 block = 128;
+    dim3 grid(num_unique_indices);
+    dim3 block(128);
     int dim = self.stride(0);
 
     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] {
diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh
index c79bf83..f06b850 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh
+++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh
@@ -10,6 +10,10 @@
 #include
 #include
 
+#include
+#include
+#include
+
 #pragma once
 
 namespace at {
diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 57654f2..95ab33e 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -218,6 +218,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List
+  TORCH_CHECK(linearIndex.numel() <= std::numeric_limits::max(),
+    "index_put of tensors larger than INT_MAX is not supported yet in pytorch");
+
   if (num_indices > 0 && sliceSize > 0) {
     const bool permuted = !src.is_contiguous();
     auto src_ = permuted ? src.contiguous() : src;
diff --git a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu b/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
index 446aa08..582dc9e 100644
--- a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
+++ b/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
@@ -5,8 +5,6 @@
 #include
 #include
 #include
-#include
-#include
 
 namespace at { namespace native {
 
@@ -32,45 +30,4 @@ void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_
   thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp());
 }
 
-template
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
-  using device_ptr = thrust::device_ptr;
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
-  auto policy = thrust::cuda::par(allocator).on(stream);
-
-  auto num_indices = count.numel();
-
-  // Compute an increasing sequence per unique item in sortedIndices:
-  // sorted: 2 5 5 5 7 7 8 9 9
-  //  count: 1 1 2 3 1 2 1 1 2
-  auto sorted_data = device_ptr(sorted_indices.data_ptr());
-  auto count_data = device_ptr(count.data_ptr());
-  thrust::inclusive_scan_by_key(
-    policy,
-    sorted_data,
-    sorted_data + num_indices,
-    thrust::make_constant_iterator(1),
-    count_data
-  );
-
-  // Take the maximum of each count per unique key in reverse:
-  // sorted: 2 5 5 5 7 7 8 9 9
-  //  count: 1 3 3 3 2 2 1 2 2
-  thrust::inclusive_scan_by_key(
-    policy,
-    thrust::make_reverse_iterator(sorted_data + num_indices),
-    thrust::make_reverse_iterator(sorted_data),
-    thrust::make_reverse_iterator(count_data + num_indices),
-    thrust::make_reverse_iterator(count_data + num_indices),
-    thrust::equal_to(),
-    thrust::maximum()
-  );
-}
-
-template
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
-template
-void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
-
 }}
diff --git a/aten/src/ATen/native/cuda/Randperm.cu b/aten/src/ATen/native/cuda/Randperm.cu
index 56b8eb2..4c5e16a 100644
--- a/aten/src/ATen/native/cuda/Randperm.cu
+++ b/aten/src/ATen/native/cuda/Randperm.cu
@@ -47,6 +47,8 @@ template struct alignas(N) OpaqueType { char data[N]; };
 
 Tensor& randperm_out_cuda(int64_t n, c10::optional generator, Tensor& result) {
   TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
+  TORCH_CHECK(n <= std::numeric_limits::max(),
+    "randperm of tensors larger than INT_MAX is not supported yet in pytorch");
 
   check_supported_max_int_with_precision(n, result);
 
diff --git a/aten/src/ATen/native/cuda/UniqueCub.cu b/aten/src/ATen/native/cuda/UniqueCub.cu
index eb31fd2..1b9619b 100644
--- a/aten/src/ATen/native/cuda/UniqueCub.cu
+++ b/aten/src/ATen/native/cuda/UniqueCub.cu
@@ -94,7 +94,13 @@ std::tuple compute_unique(
   Tensor length = at::empty({1}, options);
   int64_t num_out;
   if (!return_counts) {
-    cuda::cub::unique(data, data_out.data_ptr(), length.data_ptr(), num_inp);
+    CUB_WRAPPER(
+        cub::DeviceSelect::Unique,
+        data,
+        data_out.data_ptr(),
+        length.data_ptr(),
+        num_inp,
+        stream);
     num_out = length.item();
   } else {
     counts.resize_(num_inp);
@@ -129,6 +135,11 @@ std::tuple unique_cuda_template(
 
   auto options = self.options().dtype(kLong);
   int64_t num_inp = self.numel();
+  TORCH_CHECK(
+      num_inp <= INT_MAX,
+      "num_inp ",
+      num_inp,
+      " is too big to be handled by cub");
   Tensor sorted;
   Tensor self_c = self.contiguous();
   if (consecutive) {
-- 
2.7.4
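
Note on the restored scale_grad_by_freq path: the pair of thrust::inclusive_scan_by_key calls turns a sorted index list into a per-element occurrence count (the forward pass numbers each element within its run of equal keys, the reverse pass with maximum() broadcasts the run total back over the run). Below is a minimal standalone sketch of that pattern, not part of the patch; it uses plain int keys, a host std::vector for seeding, and host-side printing purely for illustration, whereas the kernel above runs the same scans over the dispatched index_t data of the sorted_indices and count tensors under a THCThrustAllocator-backed policy on the current CUDA stream.

#include <cstdio>
#include <vector>
#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/reverse_iterator.h>

int main() {
  // Same example as the comments in the patch: sorted indices with duplicates.
  std::vector<int> h_sorted = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  thrust::device_vector<int> sorted(h_sorted.begin(), h_sorted.end());
  thrust::device_vector<int> count(sorted.size());

  // Pass 1: running position of each element within its run of equal keys.
  // count becomes: 1 1 2 3 1 2 1 1 2
  thrust::inclusive_scan_by_key(
      sorted.begin(), sorted.end(),
      thrust::make_constant_iterator(1),
      count.begin());

  // Pass 2: scan the same ranges in reverse with maximum(), in place, so that
  // every element of a run ends up holding the total size of its run.
  // count becomes: 1 3 3 3 2 2 1 2 2
  thrust::inclusive_scan_by_key(
      thrust::make_reverse_iterator(sorted.end()),
      thrust::make_reverse_iterator(sorted.begin()),
      thrust::make_reverse_iterator(count.end()),
      thrust::make_reverse_iterator(count.end()),
      thrust::equal_to<int>(),
      thrust::maximum<int>());

  for (int c : count) std::printf("%d ", c);  // prints: 1 3 3 3 2 2 1 2 2
  std::printf("\n");
  return 0;
}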