#include <cstddef>
#include <type_traits>
#include <iterator>
+#include <limits>
// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292
const key_t *keys_in, key_t *keys_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
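+  // cub's device-wide primitives take num_items as a signed int, so inputs
+  // beyond INT_MAX cannot be expressed in its API; the same guard protects
+  // the other wrappers below.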
+ TORCH_CHECK(n <= std::numeric_limits<int>::max(),
+ "cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;
const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
const value_t *values_in, value_t *values_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
+ TORCH_CHECK(n <= std::numeric_limits<int>::max(),
+ "cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;
auto allocator = c10::cuda::CUDACachingAllocator::get();
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
+ TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
+ "cub sort does not support sorting more than INT_MAX elements");
+ TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
+ "cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;
auto allocator = c10::cuda::CUDACachingAllocator::get();
}
}
-}}}
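+// Thin wrapper around cub::DeviceSelect::Unique: copies the first element of
+// each run of consecutive equal items from input to output and writes the
+// number of items kept to num_selected_out. Only adjacent duplicates are
+// collapsed, so callers wanting a global unique must pass sorted input.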
+template<typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
+inline void unique(InputIteratorT input, OutputIteratorT output, NumSelectedIteratorT num_selected_out, int64_t num_items) {
+ TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
+ "cub unique does not support more than INT_MAX elements");
+ CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSelect::Unique,
+ input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
+}
+
+}}} // namespace at::cuda::cub
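+
+// A minimal usage sketch (hypothetical tensors; assumes `in` is a sorted,
+// contiguous int32 CUDA tensor):
+//   at::Tensor out = at::empty_like(in);
+//   at::Tensor n_out = at::empty({}, in.options().dtype(at::kLong));
+//   at::cuda::cub::unique(in.data_ptr<int>(), out.data_ptr<int>(),
+//                         n_out.data_ptr<int64_t>(), in.numel());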
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCTensorMathReduce.cuh>
-#include <THC/THCThrustAllocator.cuh>
#include <THC/THCReduceApplyUtils.cuh>
-#include <thrust/execution_policy.h>
-#include <thrust/iterator/constant_iterator.h>
-#include <thrust/unique.h>
+#include <ATen/cuda/cub.cuh>
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
} // anonymous namespace
-Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices,
+template<typename index_t>
+void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
+
+Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
int64_t num_weights, int64_t padding_idx,
bool scale_grad_by_freq) {
auto grad_arg = TensorArg(grad_, "grad", 1);
- auto indices_arg = TensorArg(indices, "indices", 1);
+ auto indices_arg = TensorArg(indices_, "indices", 1);
checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
checkSameGPU("embedding_backward", grad_arg, indices_arg);
+ auto indices = indices_.contiguous();
+
auto num_indices = indices.numel();
auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)});
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor count;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
- using device_ptr = thrust::device_ptr<index_t>;
-
- // Sort the inputs into sorted with the corresponding indices; we
- // don't need a stable or multidimensional sort, so just use Thrust
- // directly
- {
- sorted_indices.copy_(indices);
-
- auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
- auto policy = thrust::cuda::par(allocator).on(stream);
-
- // Fill sortedOrigIndices with sequential indices
- auto count_iter = thrust::counting_iterator<index_t>(0);
- auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
- thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);
-
- // Sort; a stable sort is not required
- auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
- thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data,
- LTOp<index_t>());
- }
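+    // Sort indices with cub's radix sort, pairing each index with its
+    // position 0..num_indices-1 so orig_indices records where every sorted
+    // element came from; a stable sort is not required.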
+ auto range = at::arange(num_indices, indices.options());
+ int64_t nbits = cuda::cub::get_num_bits(num_weights);
+ cuda::cub::sort_pairs(
+ indices.data_ptr<index_t>(), sorted_indices.data_ptr<index_t>(),
+ range.data_ptr<index_t>(), orig_indices.data_ptr<index_t>(),
+      num_indices, false, 0, nbits); // valid indices fit in nbits bits
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-
- auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
- auto policy = thrust::cuda::par(allocator).on(stream);
-
- // Compute an increasing sequence per unique item in sortedIndices:
- // sorted: 2 5 5 5 7 7 8 9 9
- // count: 1 1 2 3 1 2 1 1 2
- auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
- auto count_data = device_ptr(count.data_ptr<index_t>());
- thrust::inclusive_scan_by_key(
- policy,
- sorted_data,
- sorted_data + num_indices,
- thrust::make_constant_iterator(1),
- count_data
- );
-
- // Take the maximum of each count per unique key in reverse:
- // sorted: 2 5 5 5 7 7 8 9 9
- // count: 1 3 3 3 2 2 1 2 2
- thrust::inclusive_scan_by_key(
- policy,
- thrust::make_reverse_iterator(sorted_data + num_indices),
- thrust::make_reverse_iterator(sorted_data),
- thrust::make_reverse_iterator(count_data + num_indices),
- thrust::make_reverse_iterator(count_data + num_indices),
- thrust::equal_to<index_t>(),
- thrust::maximum<index_t>()
- );
+ embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
}
});
checkSameGPU("embedding_renorm", self_arg, indices_arg);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
- auto policy = thrust::cuda::par(allocator).on(stream);
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () {
- using device_ptr = thrust::device_ptr<index_t>;
auto num_indices = indices.numel();
auto indices_contig = std::get<0>(indices.sort()).contiguous();
- auto indices_data = device_ptr(indices_contig.data_ptr<index_t>());
-
auto unique_indices = at::empty(indices.numel(), indices.options());
- auto unique_data = device_ptr(unique_indices.data_ptr<index_t>());
- auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
- auto num_unique_indices = static_cast<int>(end - unique_data);
+ auto num_unique_indices = at::empty({}, indices.options().dtype(kLong));
+
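+    // DeviceSelect::Unique only removes adjacent duplicates; indices_contig
+    // was sorted above, so the result is the set of globally unique indices.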
+ cuda::cub::unique(
+ indices_contig.data_ptr<index_t>(),
+ unique_indices.data_ptr<index_t>(),
+ num_unique_indices.data_ptr<int64_t>(),
+ num_indices
+ );
- dim3 grid(num_unique_indices);
- dim3 block(128);
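+  // .item() synchronizes with the GPU to fetch the count produced by cub;
+  // the value is needed on the host to size the launch grid.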
+ dim3 grid = num_unique_indices.item<int64_t>();
+ dim3 block = 128;
int dim = self.stride(0);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] {
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCAtomics.cuh>
-#include <thrust/execution_policy.h>
-#include <thrust/unique.h>
-#include <thrust/device_vector.h>
-
#pragma once
namespace at {
std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self, indices, !unsafe);
int64_t num_indices = linearIndex.numel();
- TORCH_CHECK(num_indices <= std::numeric_limits<int>::max(),
- "index_put of tensors larger than INT_MAX is not supported yet in pytorch");
-
if (num_indices > 0 && sliceSize > 0) {
const bool permuted = !src.is_contiguous();
auto src_ = permuted ? src.contiguous() : src;
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
+#include <thrust/unique.h>
+#include <thrust/scan.h>
+#include <thrust/functional.h>
+#include <thrust/iterator/constant_iterator.h>
+#include <thrust/iterator/reverse_iterator.h>
namespace at { namespace native {
thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp<int64_t>());
}
+template<typename index_t>
+void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
+ using device_ptr = thrust::device_ptr<index_t>;
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
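+  // Run thrust on the current stream, with temporary allocations routed
+  // through PyTorch's CUDA caching allocator rather than raw cudaMalloc.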
+ auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+ auto policy = thrust::cuda::par(allocator).on(stream);
+
+ auto num_indices = count.numel();
+
+ // Compute an increasing sequence per unique item in sortedIndices:
+ // sorted: 2 5 5 5 7 7 8 9 9
+ // count: 1 1 2 3 1 2 1 1 2
+ auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
+ auto count_data = device_ptr(count.data_ptr<index_t>());
+ thrust::inclusive_scan_by_key(
+ policy,
+ sorted_data,
+ sorted_data + num_indices,
+ thrust::make_constant_iterator(1),
+ count_data
+ );
+
+ // Take the maximum of each count per unique key in reverse:
+ // sorted: 2 5 5 5 7 7 8 9 9
+ // count: 1 3 3 3 2 2 1 2 2
+ thrust::inclusive_scan_by_key(
+ policy,
+ thrust::make_reverse_iterator(sorted_data + num_indices),
+ thrust::make_reverse_iterator(sorted_data),
+ thrust::make_reverse_iterator(count_data + num_indices),
+ thrust::make_reverse_iterator(count_data + num_indices),
+ thrust::equal_to<index_t>(),
+ thrust::maximum<index_t>()
+ );
+}
+
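+// Explicit instantiations for the index types dispatched via
+// AT_DISPATCH_INDEX_TYPES in Embedding.cu, keeping thrust usage confined to
+// this translation unit.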
+template
+void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
+template
+void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);
+
}}
Tensor& randperm_out_cuda(int64_t n, c10::optional<Generator> generator, Tensor& result) {
  TORCH_CHECK(n >= 0, "n must be non-negative, got ", n);
- TORCH_CHECK(n <= std::numeric_limits<int>::max(),
- "randperm of tensors larger than INT_MAX is not supported yet in pytorch");
check_supported_max_int_with_precision(n, result);
Tensor length = at::empty({1}, options);
int64_t num_out;
if (!return_counts) {
- CUB_WRAPPER(
- cub::DeviceSelect::Unique,
- data,
- data_out.data_ptr<scalar_t>(),
- length.data_ptr<int64_t>(),
- num_inp,
- stream);
+ cuda::cub::unique(data, data_out.data_ptr<scalar_t>(), length.data_ptr<int64_t>(), num_inp);
num_out = length.item<int64_t>();
} else {
counts.resize_(num_inp);
auto options = self.options().dtype(kLong);
int64_t num_inp = self.numel();
- TORCH_CHECK(
- num_inp <= INT_MAX,
- "num_inp ",
- num_inp,
- " is too big to be handled by cub");
Tensor sorted;
Tensor self_c = self.contiguous();
if (consecutive) {
check_gradgrad=False,
),
dict(
+ module_name='Embedding',
+ constructor_args=(4, 3),
+ cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
+ input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
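+  # expand() yields non-contiguous (stride-0) indices, exercising the
+  # indices_.contiguous() path in embedding_dense_backward_cuda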
+ check_gradgrad=False,
+ desc='discontiguous'
+ ),
+ dict(
module_name='EmbeddingBag',
constructor_args=(4, 3),
cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',