From: Richard Zou Date: Wed, 10 Apr 2019 01:08:59 +0000 (-0700) Subject: EmbeddingBag w/ per_sample_weights CUDA fwd + bwd (#18800) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~299 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c889ff6cf8c18017490c7ed0e99a524fc5ff7422;p=platform%2Fupstream%2Fpytorch.git EmbeddingBag w/ per_sample_weights CUDA fwd + bwd (#18800) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18800 ghimport-source-id: 17f638dea0e1ac9a86ec06b223c60362ed78449c Reviewed By: cpuhrsch Differential Revision: D14851422 Pulled By: zou3519 fbshipit-source-id: 27b114e51e66112e4bc9cfc63d1d1ddfa650d347 --- diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index ba062b4..526d6b3 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -289,11 +289,6 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices, checkScalarType("embedding_bag", offset2bag_arg, kLong); checkContiguous("embedding_bag", offset2bag_arg); - if (per_sample_weights.defined() && - per_sample_weights.device().type() != DeviceType::CPU) { - AT_ERROR("NYI: _embedding_bag_backward: per_sample_weights only supported for CPU"); - } - if (sparse) { return at::_embedding_bag_sparse_backward( grad, indices, offsets, offset2bag, bag_size_, num_weights, diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 043a981..fa3bf2b 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -24,13 +24,15 @@ namespace native { namespace { -// This kernel assumes that all input tensors except `weight` are contiguous. +// This kernel assumes that all input tensors except `weight` and +// per_sample_weights are contiguous. template __global__ void EmbeddingBag_updateOutputKernel( int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output, int64_t *offset2bag, int64_t numIndices, int64_t numBags, int64_t featureSize, int64_t weight_stide0, int64_t weight_stride1, - int mode, int64_t *bag_size, int64_t *max_indices) { + int mode, int64_t *bag_size, int64_t *max_indices, + scalar_t* per_sample_weights, int64_t per_sample_weights_stride) { // the strategy here is that each bag x feature is handled by a single thread @@ -64,7 +66,13 @@ __global__ void EmbeddingBag_updateOutputKernel( maxWord = input[emb]; } } else { - weightFeatSum += static_cast(weightValue); + if (per_sample_weights) { + accscalar_t scaleWeightBy = static_cast( + per_sample_weights[emb * per_sample_weights_stride]); + weightFeatSum += scaleWeightBy * static_cast(weightValue); + } else { + weightFeatSum += static_cast(weightValue); + } } bag_size_++; @@ -106,7 +114,8 @@ template __global__ void EmbeddingBag_accGradParametersKernel_sum_avg( int64_t *input, int64_t *indices, scalar_t *gradOutput, scalar_t *gradWeight, int64_t *offset2bag, int64_t *count, ptrdiff_t numel, - int64_t stride, int mode, const int64_t *bag_size) { + int64_t stride, int mode, const int64_t *bag_size, + scalar_t* per_sample_weights, int64_t per_sample_weights_stride) { using accscalar_t = acc_type; int idx = blockIdx.x * 4 + threadIdx.y; @@ -134,7 +143,10 @@ __global__ void EmbeddingBag_accGradParametersKernel_sum_avg( const int seq_number = offset2bag[origRow]; const int gradOutputRow = ((int)seq_number) * stride; - const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0; + accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0; + if (per_sample_weights) { + scale *= per_sample_weights[origRow * per_sample_weights_stride]; + } accscalar_t gradient[SZ]; accscalar_t weight[SZ]; @@ -179,7 +191,8 @@ Tensor embedding_bag_backward_cuda_sum_avg( const Tensor &offset2bag, const Tensor &bag_size, int64_t num_weights, - bool scale_grad_by_freq, int64_t mode) { + bool scale_grad_by_freq, int64_t mode, + const Tensor& per_sample_weights) { auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options()); @@ -255,7 +268,9 @@ Tensor embedding_bag_backward_cuda_sum_avg( grad.data(), grad_weight.data(), offset2bag.data(), count.defined() ? count.data() : nullptr, numel, stride, - mode, bag_size.data()); + mode, bag_size.data(), + per_sample_weights.defined() ? per_sample_weights.data() : NULL, + per_sample_weights.defined() ? per_sample_weights.stride(0) : 0); }); THCudaCheck(cudaGetLastError()); @@ -331,9 +346,6 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices, checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg); checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg); - AT_CHECK(!per_sample_weights.defined(), - "NYI: embedding_bag: CUDA per_sample_weights (see issue #4068)"); - int64_t numIndices = indices.size(0); int64_t numBags = offsets.size(0); int64_t featureSize = weight.size(1); @@ -363,7 +375,9 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices, weight.data(), output.data(), offset2bag.data(), numIndices, numBags, featureSize, weight.stride(0), weight.stride(1), mode, bag_size.data(), - mode == MODE_MAX ? max_indices.data() : NULL); + mode == MODE_MAX ? max_indices.data() : NULL, + per_sample_weights.defined() ? per_sample_weights.data() : NULL, + per_sample_weights.defined() ? per_sample_weights.stride(0) : 0); }); THCudaCheck(cudaGetLastError()); @@ -391,14 +405,16 @@ Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &ind checkSameGPU("embedding_bag_cuda", grad_arg, offsets_arg); checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg); - AT_ASSERT(!per_sample_weights.defined()); switch (mode) { case MODE_SUM: case MODE_MEAN: - return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag, bag_size_, num_weights, scale_grad_by_freq, mode); + if (mode == MODE_MEAN) + AT_ASSERT(!per_sample_weights.defined()); + return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag, bag_size_, num_weights, scale_grad_by_freq, mode, per_sample_weights); case MODE_MAX: + AT_ASSERT(!per_sample_weights.defined()); return embedding_bag_backward_cuda_max(grad, max_indices, num_weights); default: diff --git a/test/test_nn.py b/test/test_nn.py index 478f6e5..79391ab 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2589,8 +2589,12 @@ class TestNN(NNTestCase): input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device) offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device) per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device) - with self.assertRaisesRegex(RuntimeError, 'have the same type as'): - es(input, offsets, per_sample_weights) + if device == 'cpu': + with self.assertRaisesRegex(RuntimeError, 'have the same type as'): + es(input, offsets, per_sample_weights) + else: + with self.assertRaisesRegex(RuntimeError, 'expected scalar type'): + es(input, offsets, per_sample_weights) # Failure 2.1: input/per_sample_weights have different sizes (1d input) input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device) @@ -2620,6 +2624,10 @@ class TestNN(NNTestCase): def test_EmbeddingBag_per_sample_weights_failures(self): self._test_EmbeddingBag_per_sample_weights_failures(self) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_EmbeddingBag_per_sample_weights_failures_cuda(self): + self._test_EmbeddingBag_per_sample_weights_failures(self, device='cuda') + @staticmethod def _test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cpu'): def test_per_sample_weights(mode, dtype): @@ -2649,6 +2657,10 @@ class TestNN(NNTestCase): def test_EmbeddingBag_per_sample_weights_and_offsets(self): self._test_EmbeddingBag_per_sample_weights_and_offsets(self) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_EmbeddingBag_per_sample_weights_and_offsets_cuda(self): + self._test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cuda') + @staticmethod def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'): dtypes = (torch.float, torch.double) @@ -2674,6 +2686,10 @@ class TestNN(NNTestCase): self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_EmbeddingBag_per_sample_weights_and_no_offsets_cuda(self): + self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cuda') + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @repeat_test_for_types(ALL_TENSORTYPES) def test_embedding_bag_cuda(self, dtype=torch.float): self._test_EmbeddingBag(True, 'sum', False, dtype)