#include <THC/THCDeviceUtils.cuh>
#include <THC/THCTensorMathReduce.cuh>
-#include <THC/THCThrustAllocator.cuh>
#include <THC/THCAtomics.cuh>
-#include <thrust/execution_policy.h>
-#include <thrust/unique.h>
-#include <thrust/iterator/constant_iterator.h>
-#include <thrust/device_vector.h>
-
+#include <ATen/cuda/cub.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
namespace at {
namespace native {
// Forward declaration — presumably defined in another .cu TU (TODO confirm
// which file provides the definition). It replaces the removed Thrust
// inclusive_scan_by_key / reverse-max-scan pair below: given `sorted_indices`,
// it fills `count` so that every element of a run of equal indices holds the
// run's total length (the per-index occurrence count used when
// scale_grad_by_freq is set).
+template<typename index_t>
+void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
+
namespace {
// Reduction-mode tag checked against the `mode` argument of the backward
// entry point (this file's function is the sum/avg path); 0 == summation
// over each bag.
constexpr int MODE_SUM = 0;
}
}
-
-
Tensor embedding_bag_backward_cuda_sum_avg(
const Tensor &grad,
- const Tensor &indices,
+ const Tensor &indices_,
const Tensor &offset2bag,
const Tensor &bag_size,
int64_t num_weights,
bool scale_grad_by_freq, int64_t mode,
const Tensor& per_sample_weights,
int64_t padding_idx) {
+ auto indices = indices_.contiguous();
auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- ptrdiff_t numel = indices.numel();
+ ptrdiff_t num_indices = indices.numel();
- if (numel == 0) {
+ if (num_indices == 0) {
// all empty bags
return at::zeros({num_weights, grad.size(1)}, grad.options());
}
Tensor count;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
- using device_ptr = thrust::device_ptr<index_t>;
-
- // Sort the inputs into sorted with the corresponding indices; we
- // don't need a stable or multidimensional sort, so just use Thrust
- // directly
- {
- sorted_indices.copy_(indices);
-
- auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
- auto policy = thrust::cuda::par(allocator).on(stream);
-
- // Fill sortedOrigIndices with sequential indices
- auto count_iter = thrust::counting_iterator<index_t>(0);
- auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
- thrust::copy(policy, count_iter, count_iter + numel, orig_data);
-
- // Sort; a stable sort is not required
- auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
- thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data,
- LTOp<index_t>());
- }
+ auto range = at::arange(num_indices, indices.options());
+ int64_t nbits = cuda::cub::get_num_bits(num_weights);
+ cuda::cub::sort_pairs(
+ indices.data_ptr<index_t>(), sorted_indices.data_ptr<index_t>(),
+ range.data_ptr<index_t>(), orig_indices.data_ptr<index_t>(),
+ num_indices, false/*, 0, nbits*/);
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-
- auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
- auto policy = thrust::cuda::par(allocator).on(stream);
-
- // Compute an increasing sequence per unique item in sortedIndices:
- // sorted: 2 5 5 5 7 7 8 9 9
- // count: 1 1 2 3 1 2 1 1 2
- auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
- auto count_data = device_ptr(count.data_ptr<index_t>());
- thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel,
- thrust::make_constant_iterator(1),
- count_data);
-
- // Take the maximum of each count per unique key in reverse:
- // sorted: 2 5 5 5 7 7 8 9 9
- // count: 1 3 3 3 2 2 1 2 2
- thrust::inclusive_scan_by_key(
- policy, thrust::make_reverse_iterator(sorted_data + numel),
- thrust::make_reverse_iterator(sorted_data),
- thrust::make_reverse_iterator(count_data + numel),
- thrust::make_reverse_iterator(count_data + numel),
- thrust::equal_to<index_t>(), thrust::maximum<index_t>());
+ embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
}
});
return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,