namespace {
-// This kernel assumes that all input tensors except `weight` are contiguous.
+// This kernel assumes that all input tensors except `weight` and
+// `per_sample_weights` are contiguous.
template <typename scalar_t>
__global__ void EmbeddingBag_updateOutputKernel(
int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output,
int64_t *offset2bag, int64_t numIndices, int64_t numBags,
int64_t featureSize, int64_t weight_stide0, int64_t weight_stride1,
- int mode, int64_t *bag_size, int64_t *max_indices) {
+ int mode, int64_t *bag_size, int64_t *max_indices,
+ scalar_t* per_sample_weights, int64_t per_sample_weights_stride) {
// the strategy here is that each bag x feature is handled by a single thread
maxWord = input[emb];
}
} else {
- weightFeatSum += static_cast<accscalar_t>(weightValue);
+ if (per_sample_weights) {
+ accscalar_t scaleWeightBy = static_cast<accscalar_t>(
+ per_sample_weights[emb * per_sample_weights_stride]);
+ weightFeatSum += scaleWeightBy * static_cast<accscalar_t>(weightValue);
+ } else {
+ weightFeatSum += static_cast<accscalar_t>(weightValue);
+ }
}
bag_size_++;
__global__ void EmbeddingBag_accGradParametersKernel_sum_avg(
int64_t *input, int64_t *indices, scalar_t *gradOutput,
scalar_t *gradWeight, int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
- int64_t stride, int mode, const int64_t *bag_size) {
+ int64_t stride, int mode, const int64_t *bag_size,
+ scalar_t* per_sample_weights, int64_t per_sample_weights_stride) {
using accscalar_t = acc_type<scalar_t, true>;
int idx = blockIdx.x * 4 + threadIdx.y;
const int seq_number = offset2bag[origRow];
const int gradOutputRow = ((int)seq_number) * stride;
- const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
+ accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
+ if (per_sample_weights) {
+ scale *= per_sample_weights[origRow * per_sample_weights_stride];
+ }
accscalar_t gradient[SZ];
accscalar_t weight[SZ];
const Tensor &offset2bag,
const Tensor &bag_size,
int64_t num_weights,
- bool scale_grad_by_freq, int64_t mode) {
+ bool scale_grad_by_freq, int64_t mode,
+ const Tensor& per_sample_weights) {
auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());
grad.data<scalar_t>(), grad_weight.data<scalar_t>(),
offset2bag.data<int64_t>(),
count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
- mode, bag_size.data<int64_t>());
+ mode, bag_size.data<int64_t>(),
+ per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
+ per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
});
THCudaCheck(cudaGetLastError());
checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);
- AT_CHECK(!per_sample_weights.defined(),
- "NYI: embedding_bag: CUDA per_sample_weights (see issue #4068)");
-
int64_t numIndices = indices.size(0);
int64_t numBags = offsets.size(0);
int64_t featureSize = weight.size(1);
weight.data<scalar_t>(), output.data<scalar_t>(),
offset2bag.data<int64_t>(), numIndices, numBags, featureSize,
weight.stride(0), weight.stride(1), mode, bag_size.data<int64_t>(),
- mode == MODE_MAX ? max_indices.data<int64_t>() : NULL);
+ mode == MODE_MAX ? max_indices.data<int64_t>() : NULL,
+ per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
+ per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
});
THCudaCheck(cudaGetLastError());
checkSameGPU("embedding_bag_cuda", grad_arg, offsets_arg);
checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg);
- AT_ASSERT(!per_sample_weights.defined());
switch (mode) {
case MODE_SUM:
case MODE_MEAN:
- return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag, bag_size_, num_weights, scale_grad_by_freq, mode);
+ if (mode == MODE_MEAN)
+ AT_ASSERT(!per_sample_weights.defined());
+ return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag, bag_size_, num_weights, scale_grad_by_freq, mode, per_sample_weights);
case MODE_MAX:
+ AT_ASSERT(!per_sample_weights.defined());
return embedding_bag_backward_cuda_max(grad, max_indices, num_weights);
default:
input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)
- with self.assertRaisesRegex(RuntimeError, 'have the same type as'):
- es(input, offsets, per_sample_weights)
+ if device == 'cpu':
+ with self.assertRaisesRegex(RuntimeError, 'have the same type as'):
+ es(input, offsets, per_sample_weights)
+ else:
+ with self.assertRaisesRegex(RuntimeError, 'expected scalar type'):
+ es(input, offsets, per_sample_weights)
# Failure 2.1: input/per_sample_weights have different sizes (1d input)
input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
def test_EmbeddingBag_per_sample_weights_failures(self):
self._test_EmbeddingBag_per_sample_weights_failures(self)
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ def test_EmbeddingBag_per_sample_weights_failures_cuda(self):
+ self._test_EmbeddingBag_per_sample_weights_failures(self, device='cuda')
+
@staticmethod
def _test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cpu'):
def test_per_sample_weights(mode, dtype):
def test_EmbeddingBag_per_sample_weights_and_offsets(self):
self._test_EmbeddingBag_per_sample_weights_and_offsets(self)
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ def test_EmbeddingBag_per_sample_weights_and_offsets_cuda(self):
+ self._test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cuda')
+
@staticmethod
def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'):
dtypes = (torch.float, torch.double)
self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ def test_EmbeddingBag_per_sample_weights_and_no_offsets_cuda(self):
+ self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cuda')
+
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
def test_embedding_bag_cuda(self, dtype=torch.float):
self._test_EmbeddingBag(True, 'sum', False, dtype)