From 447d74a0747be7686dc44bb9db57ca45954fc865 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 9 Apr 2019 18:09:01 -0700 Subject: [PATCH] EmbeddingBag w/ differentiable per_sample_weights (#18957) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18957 ghimport-source-id: 7396ca08b137ea40f04285764a9d9a6d4f19227e Reviewed By: cpuhrsch Differential Revision: D14856526 Pulled By: zou3519 fbshipit-source-id: 949faea219c7c02ad981b1db610a477194d3f5c9 --- aten/src/ATen/native/EmbeddingBag.cpp | 77 ++++++++++++++++++++++++++- aten/src/ATen/native/cuda/EmbeddingBag.cu | 84 ++++++++++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 5 ++ aten/src/TH/THBlasUtils.h | 11 ++++ test/test_nn.py | 33 +++++++++--- tools/autograd/derivatives.yaml | 2 +- 6 files changed, 201 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index 526d6b3..9fa2bdb 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -1,6 +1,7 @@ #include -#include #include +#include +#include #include @@ -461,6 +462,78 @@ Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indi return index_grad_weight; } +template +Tensor _embedding_bag_per_sample_weights_backward_cpu_template( + const Tensor& grad, + const Tensor& weight, // NB: embedding table, not per_sample_weights + const Tensor& indices, + const Tensor& offset2bag, + int64_t mode) { + AT_CHECK( + mode == MODE_SUM, + "embedding_bag_backward: per_sample_weights only supported for mode='sum'"); + + AT_ASSERT(grad.dim() == 2) + auto embedding_features = grad.size(1); + + AT_ASSERT(indices.dim() == 1); + auto num_samples = indices.size(0); + + AT_ASSERT(weight.dim() == 2); + AT_ASSERT(weight.size(1) == embedding_features); + + auto output = at::zeros({num_samples}, grad.options()); + + auto indices_arg = TensorArg(indices, "indices", 1); + checkScalarType("embedding_bag", indices_arg, kLong); + checkContiguous("embedding_bag", indices_arg); + auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1); + checkScalarType("embedding_bag", offset2bag_arg, kLong); + checkContiguous("embedding_bag", offset2bag_arg); + + auto grad_data = grad.data(); + auto grad_stride0 = grad.stride(0); + auto grad_stride1 = grad.stride(1); + + auto weight_data = weight.data(); + auto weight_stride0 = weight.stride(0); + auto weight_stride1 = weight.stride(1); + + auto indices_data = indices.data(); + + // The following are contiguous + auto output_data = output.data(); + auto offset2bag_data = offset2bag.data(); + + // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number. + parallel_for(0, num_samples, 64, [&](int64_t begin, int64_t end) { + for (int64_t sample_idx = begin; sample_idx < end; sample_idx++) { + auto bag_idx = offset2bag_data[sample_idx]; + auto embedding_idx = indices_data[sample_idx]; + + output_data[sample_idx] = THBlas_dot( + embedding_features, + grad_data + grad_stride0 * bag_idx, grad_stride1, + weight_data + weight_stride0 * embedding_idx, weight_stride1); + } + }); + return output; +} + +Tensor _embedding_bag_per_sample_weights_backward_cpu( + const Tensor& grad, + const Tensor& weight, // NB: embedding table, not per_sample_weights + const Tensor& indices, + const Tensor& offset2bag, + int64_t mode) { + return AT_DISPATCH_FLOATING_TYPES( + grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu", [&]() { + return _embedding_bag_per_sample_weights_backward_cpu_template( + grad, weight, indices, offset2bag, mode); + } + ); +} + Tensor _embedding_bag_sparse_backward( const Tensor &grad_, const Tensor &indices, const Tensor &offsets, const Tensor &offset2bag, const Tensor &bag_size_, int64_t num_weights, @@ -476,7 +549,7 @@ Tensor _embedding_bag_sparse_backward( offset2bag, bag_size_); if (per_sample_weights.defined()) { AT_ASSERT(mode == MODE_SUM); - index_grad *= per_sample_weights.unsqueeze(1); + index_grad.mul_(per_sample_weights.unsqueeze(1)); } return native::embedding_backward(index_grad, indices, num_weights, -1, scale_grad_by_freq, true); diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index fa3bf2b..af19f8a 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -423,5 +423,89 @@ Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &ind } } +template +__inline__ __device__ +static scalar_t warpReduceSum(scalar_t val) { + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) + val += WARP_SHFL_DOWN(val, offset); + return val; +} + +template +__global__ static void _embedding_bag_per_sample_weights_backward_kernel( + const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1, + const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1, + const int64_t* indices, // contiguous + const int64_t* offset2bag, // contiguous + int64_t num_samples, + int64_t embedding_features, + scalar_t* output) { + using accscalar_t = acc_type; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + const int warp = idx / WARP_SIZE; + const int thread_in_warp = idx % WARP_SIZE; + const int num_warps = blockDim.x * gridDim.x / WARP_SIZE; + + // Each warp is responsible for the accumulation of one sample. + // This involves doing one dot product between grad[bag_idx] and weight[embedding_idx]. + for (int sample_idx = warp; sample_idx < num_samples; sample_idx += num_warps) { + accscalar_t result = 0.; + const int bag_idx = (int)offset2bag[sample_idx]; + const int embedding_idx = (int)indices[sample_idx]; + for (int feature_idx = thread_in_warp; feature_idx < embedding_features; + feature_idx += WARP_SIZE) { + result += + grad[grad_stride0 * bag_idx + grad_stride1 * feature_idx] * + weight[weight_stride0 * embedding_idx + weight_stride1 * feature_idx]; + } + result = warpReduceSum(result); + if (thread_in_warp == 0) { + output[sample_idx] = result; + } + } +} + +Tensor _embedding_bag_per_sample_weights_backward_cuda( + const Tensor& grad, + const Tensor& weight, // NB: embedding table, not per_sample_weights + const Tensor& indices, + const Tensor& offset2bag, + int64_t mode) { + AT_CHECK( + mode == MODE_SUM, + "embedding_bag_backward: per_sample_weights only supported for mode='sum'"); + + AT_ASSERT(grad.dim() == 2) + auto embedding_features = grad.size(1); + + AT_ASSERT(indices.dim() == 1); + auto num_samples = indices.size(0); + + AT_ASSERT(weight.dim() == 2); + AT_ASSERT(weight.size(1) == embedding_features); + + const int threads_per_block = 1024; + const int warps_per_block = threads_per_block / WARP_SIZE; + + dim3 block(threads_per_block); + dim3 grid((num_samples + warps_per_block - 1) / warps_per_block); + + auto output = at::empty({num_samples}, grad.options()); + AT_DISPATCH_FLOATING_TYPES( + grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() { + _embedding_bag_per_sample_weights_backward_kernel + <<>>( + grad.data(), grad.stride(0), grad.stride(1), + weight.data(), weight.stride(0), weight.stride(1), + indices.data(), + offset2bag.data(), + num_samples, + embedding_features, + output.data()); + } + ); + return output; +} + } } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 6cd5c20..43ece56 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -843,6 +843,11 @@ CPU: _embedding_bag_dense_backward_cpu CUDA: _embedding_bag_dense_backward_cuda +- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offset2bag, int mode) -> Tensor + dispatch: + CPU: _embedding_bag_per_sample_weights_backward_cpu + CUDA: _embedding_bag_per_sample_weights_backward_cuda + - func: empty(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor matches_jit_signature: True cpu_half: True diff --git a/aten/src/TH/THBlasUtils.h b/aten/src/TH/THBlasUtils.h index 86efc40..ffc3982 100644 --- a/aten/src/TH/THBlasUtils.h +++ b/aten/src/TH/THBlasUtils.h @@ -30,3 +30,14 @@ inline void THBlas_copy(int64_t n, T *x, int64_t incx, T *y, int64_t incy); } AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(COPY_SPECIALIZATION) + +template +inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy); + +#define DOT_SPECIALIZATION(ctype,name,_1) \ + template<> \ + inline ctype THBlas_dot(int64_t n, ctype *x, int64_t incx, ctype *y, int64_t incy) { \ + return TH ## name ## Blas_dot(n, x, incx, y, incy); \ + } + +AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(DOT_SPECIALIZATION) diff --git a/test/test_nn.py b/test/test_nn.py index 79391ab..2ac6188 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2326,6 +2326,7 @@ class TestNN(NNTestCase): device='cpu', dtype=torch.float, test_per_sample_weights=False, + trainable_per_sample_weights=False, sparse=False, test_backward=True, backward_prec=None): @@ -2340,14 +2341,18 @@ class TestNN(NNTestCase): # To prevent large gradients, weights should sum to 1 for each bag per_sample_weights = \ torch.randn(B, L, device=device, dtype=dtype).softmax(dim=-1) + per_sample_weights_reference = \ + per_sample_weights.clone().requires_grad_(trainable_per_sample_weights) + per_sample_weights.requires_grad_(trainable_per_sample_weights) output = es(input.view(-1), offsets, per_sample_weights.view(-1)) else: output = es(input.view(-1), offsets) per_sample_weights = None + per_sample_weights_reference = None if mode == 'sum': if test_per_sample_weights: - ref_output = (e(input) * per_sample_weights.unsqueeze(-1)).sum(1) + ref_output = (e(input) * per_sample_weights_reference.unsqueeze(-1)).sum(1) else: ref_output = e(input).sum(1) elif mode == 'mean': @@ -2375,6 +2380,9 @@ class TestNN(NNTestCase): needed_prec = backward_prec self.assertEqual(es_weight_grad, e.weight.grad, needed_prec) + if test_per_sample_weights and trainable_per_sample_weights: + self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad) + def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double): # check a known test example device = torch.device("cuda") if cuda else torch.device("cpu") @@ -2630,17 +2638,20 @@ class TestNN(NNTestCase): @staticmethod def _test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cpu'): - def test_per_sample_weights(mode, dtype): + def test_per_sample_weights(mode, dtype, trainable_scale): es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device) es.weight.data.copy_( torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight)) input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long) offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long) - per_sample_weights = torch.randn_like(input, dtype=dtype) + per_sample_weights = torch.randn_like(input, dtype=dtype) \ + .requires_grad_(trainable_scale) + ref_per_sample_weights = \ + per_sample_weights.detach().requires_grad_(trainable_scale) reference_weights = es.weight.detach().requires_grad_() expected = self._embedding_bag_reference_impl( - input, reference_weights, offsets, mode, per_sample_weights) + input, reference_weights, offsets, mode, ref_per_sample_weights) result = es(input, offsets, per_sample_weights) self.assertEqual(result, expected) @@ -2648,11 +2659,14 @@ class TestNN(NNTestCase): result.backward(grad) expected.backward(grad) self.assertEqual(es.weight.grad, reference_weights.grad) + if trainable_scale: + self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad) dtypes = (torch.float, torch.double) modes = ('sum',) - for dtype, mode in itertools.product(dtypes, modes): - test_per_sample_weights(mode, dtype) + trainable_scale = (True, False) + for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale): + test_per_sample_weights(mode, dtype, trainable) def test_EmbeddingBag_per_sample_weights_and_offsets(self): self._test_EmbeddingBag_per_sample_weights_and_offsets(self) @@ -2666,9 +2680,12 @@ class TestNN(NNTestCase): dtypes = (torch.float, torch.double) modes = ('sum',) sparsity = (True, False) - for dtype, mode, sparse in itertools.product(dtypes, modes, sparsity): + trainable_scale = (True, False) + for dtype, mode, sparse, trainable_per_sample_weights in \ + itertools.product(dtypes, modes, sparsity, trainable_scale): kwargs = dict(test_per_sample_weights=True, device=device, - mode=mode, dtype=dtype, sparse=sparse) + mode=mode, dtype=dtype, sparse=sparse, + trainable_per_sample_weights=trainable_per_sample_weights) # Simple case self._test_EmbeddingBag_vs_Embedding(2, 3, 5, 7, **kwargs) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 00c7a3b..4d7b568 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -955,7 +955,7 @@ indices: not_differentiable offsets: not_differentiable weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights) - per_sample_weights: not_differentiable # TODO(rzou): See issue #4068 + per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, result1, mode) - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type) indices: not_differentiable -- 2.7.4