From: Richard Zou Date: Wed, 10 Apr 2019 01:08:59 +0000 (-0700) Subject: EmbeddingBag CPU forward with per_sample_weights. (#18735) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~301 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2a2007e5aca575eb495e42dbf4404451a351eeca;p=platform%2Fupstream%2Fpytorch.git EmbeddingBag CPU forward with per_sample_weights. (#18735) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18735 ghimport-source-id: d81bef54dafd7167d2451250d7be478d3c013920 Reviewed By: cpuhrsch Differential Revision: D14851415 Pulled By: zou3519 fbshipit-source-id: cea6039e760ad571b90f0a536e420498f34be325 --- diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index 150fb25..677d7fe 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -39,6 +39,7 @@ static void index_select_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &src, Tensor &output) { + AT_ASSERT(select_indices.numel() == add_indices.numel()); auto add_indices_data = add_indices.data(); auto select_indices_data = select_indices.data(); auto src_data = src.data(); @@ -49,6 +50,7 @@ static void index_select_add(const Tensor &select_indices, auto src_stride1 = src.stride(1); auto output_stride0 = output.stride(0); auto output_stride1 = output.stride(1); + for (int64_t i = 0; i < numel; i++) { THBlas_axpy(ddim, 1, src_data + src_stride0 * select_indices_data[i], src_stride1, @@ -56,6 +58,42 @@ static void index_select_add(const Tensor &select_indices, } } +// This function fuses the following three fns: +// index_select (using select_indices as the index) +// mul (scaling by per_sample_weights) +// index_add (using add_indices as the index) +template +static void index_select_scale_add(const Tensor &select_indices, + const Tensor &add_indices, + const Tensor &scale, + const Tensor &src, + Tensor &output) { + AT_ASSERT(select_indices.numel() == 
add_indices.numel()); + auto add_indices_data = add_indices.data(); + auto select_indices_data = select_indices.data(); + auto src_data = src.data(); + auto output_data = output.data(); + auto numel = add_indices.numel(); + int64_t ddim = src.size(1); + auto src_stride0 = src.stride(0); + auto src_stride1 = src.stride(1); + auto output_stride0 = output.stride(0); + auto output_stride1 = output.stride(1); + + auto* scale_data = scale.data(); + auto scale_stride = scale.stride(0); + + // XXX: We could make this faster via vectorization + for (int64_t i = 0; i < numel; i++) { + auto* src_base = src_data + src_stride0 * select_indices_data[i]; + auto* output_base = output_data + output_stride0 * add_indices_data[i]; + auto scale = scale_data[i * scale_stride]; + for (int64_t j = 0; j < ddim; j++) { + output_base[j * output_stride1] += src_base[j * src_stride1] * scale; + } + } +} + static void make_bag_size(const Tensor &offsets, const Tensor &indices, const int64_t mode, Tensor &bag_size) { if (mode == MODE_MEAN || mode == MODE_MAX) { @@ -110,7 +148,12 @@ static Tensor apply_bag_size_backward(const Tensor &offsets, template std::tuple embedding_bag_cpu_max( - const Tensor& weight, const Tensor &indices, const Tensor& offset2bag, const Tensor& output, const Tensor& bag_size, const Tensor& offsets) { + const Tensor& weight, + const Tensor& indices, + const Tensor& offset2bag, + const Tensor& output, + const Tensor& bag_size, + const Tensor& offsets) { auto max_indices = at::zeros({offsets.size(0), weight.size(1)}, indices.options()); @@ -132,11 +175,9 @@ std::tuple embedding_bag_cpu_max( auto bag = offset2bag_data[i]; auto word_idx = indices_data[i]; - for (int dim = 0; dim < dims; dim++) { auto& current_item = output_data[output_stride * bag + dim]; auto weight_item = weight_data[weight_stride0 * word_idx + dim * weight_stride1]; - bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag; if (is_first_for_bag || weight_item > current_item) { @@ -155,9 +196,10 
@@ std::tuple embedding_bag_cpu_max( std::tuple embedding_bag(const Tensor &weight, const Tensor &indices, const Tensor &offsets, const bool scale_grad_by_freq, - const int64_t mode, bool sparse) { + const int64_t mode, bool sparse, + const Tensor &per_sample_weights) { return at::_embedding_bag(weight, indices.contiguous(), offsets.contiguous(), - scale_grad_by_freq, mode, sparse); + scale_grad_by_freq, mode, sparse, per_sample_weights); }; // Assumes all input tensors except for `weight` are contiguous. @@ -165,7 +207,8 @@ embedding_bag(const Tensor &weight, const Tensor &indices, std::tuple _embedding_bag_cpu(const Tensor &weight, const Tensor &indices, const Tensor &offsets, const bool scale_grad_by_freq, - const int64_t mode, bool sparse) { + const int64_t mode, bool sparse, + const Tensor &per_sample_weights) { auto indices_arg = TensorArg(indices, "indices", 1); checkScalarType("embedding_bag", indices_arg, kLong); auto offsets_arg = TensorArg(offsets, "offsets", 1); @@ -173,6 +216,16 @@ _embedding_bag_cpu(const Tensor &weight, const Tensor &indices, auto weight_arg = TensorArg(weight, "weight", 1); checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble}); + if (per_sample_weights.defined()) { + AT_CHECK(mode == MODE_SUM, + "embedding_bag: per_sample_weights only supported with mode='sum'"); + auto per_input_weights_arg = TensorArg( + per_sample_weights,"per_sample_weights", 1); + checkSameType("embedding_bag", weight_arg, per_input_weights_arg); + AT_ASSERT(per_sample_weights.dim() == 1); + AT_ASSERT(per_sample_weights.numel() == indices.numel()); + } + auto bag_size = at::zeros(offsets.sizes(), indices.options()); make_bag_size(offsets, indices, mode, bag_size); @@ -191,14 +244,25 @@ _embedding_bag_cpu(const Tensor &weight, const Tensor &indices, if (mode == MODE_MEAN || mode == MODE_SUM) { AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() { - index_select_add(indices, offset2bag, weight, output); + if 
(per_sample_weights.defined()) { + AT_ASSERT(mode == MODE_SUM); + index_select_scale_add( + indices, offset2bag, per_sample_weights, weight, output); + } else { + index_select_add(indices, offset2bag, weight, output); + } }); auto ret = apply_bag_size(offsets, indices, mode, output, bag_size); return std::tuple(ret, offset2bag, bag_size, bag_size); } else { // MODE_MAX + at::optional maybe_per_sample_weights; + if (per_sample_weights.defined()) { + maybe_per_sample_weights = per_sample_weights; + } return AT_DISPATCH_FLOATING_TYPES_AND_HALF( weight.scalar_type(), "embedding_bag_cpu_max", [&]() { - return embedding_bag_cpu_max(weight, indices, offset2bag, output, bag_size, offsets); + return embedding_bag_cpu_max( + weight, indices, offset2bag, output, bag_size, offsets); } ); } @@ -213,7 +277,8 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices, const Tensor &max_indices_, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, - bool sparse) { + bool sparse, + const Tensor& per_sample_weights) { auto indices_arg = TensorArg(indices, "indices", 1); checkScalarType("embedding_bag", indices_arg, kLong); checkContiguous("embedding_bag", indices_arg); @@ -224,6 +289,9 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices, checkScalarType("embedding_bag", offset2bag_arg, kLong); checkContiguous("embedding_bag", offset2bag_arg); + AT_CHECK(!per_sample_weights.defined(), + "NYI: _embedding_bag_backward: per_sample_weights"); + if (sparse) { return at::_embedding_bag_sparse_backward( grad, indices, offsets, offset2bag, bag_size_, num_weights, diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index dea987e..a1a76a7 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -321,7 +321,8 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad, std::tuple _embedding_bag_cuda(const Tensor &weight, const Tensor &indices, const 
Tensor &offsets, const bool scale_grad_by_freq, - const int64_t mode, bool sparse) { + const int64_t mode, bool sparse, + const Tensor& per_sample_weights) { auto indices_arg = TensorArg(indices, "indices", 1); checkScalarType("embedding_bag_cuda", indices_arg, kLong); auto offsets_arg = TensorArg(offsets, "offsets", 1); @@ -330,6 +331,9 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices, checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg); checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg); + AT_CHECK(!per_sample_weights.defined(), + "NYI: embedding_bag: CUDA per_sample_weights (see issue #4068)"); + int64_t numIndices = indices.size(0); int64_t numBags = offsets.size(0); int64_t featureSize = weight.size(1); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a3406a6..86d9d2e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -823,16 +823,16 @@ # applying indices = indices.contiguous(). # The backward functions apply a check that these input tensors are contiguous. -- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor) +- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None) -> (Tensor, Tensor, Tensor, Tensor) matches_jit_signature: True -- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor) +- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? 
per_sample_weights=None) -> (Tensor, Tensor, Tensor, Tensor) matches_jit_signature: True dispatch: CPU: _embedding_bag_cpu CUDA: _embedding_bag_cuda -- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse) -> Tensor +- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor per_sample_weights) -> Tensor matches_jit_signature: True - func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq, int mode) -> Tensor diff --git a/test/test_nn.py b/test/test_nn.py index 33a2ad2..d4f9cb3 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2321,6 +2321,54 @@ class TestNN(NNTestCase): self._test_gumbel_softmax_straight_through(cuda=True, dtype=dtype) self._test_gumbel_softmax_grad(cuda=True, dtype=dtype) + def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None, + mode='mean', + device='cpu', + dtype=torch.float, + test_per_sample_weights=False, + sparse=True, + test_backward=True): + es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype) + e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype) + e.weight.data.copy_(es.weight) + input = torch.randint(N, (B, L), device=device, dtype=torch.long) + offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L) + grad_output = torch.rand(B, D, device=device, dtype=dtype) + + if test_per_sample_weights: + per_sample_weights = torch.randn(B, L, device=device, dtype=dtype) + output = es(input.view(-1), offsets, per_sample_weights.view(-1)) + else: + output = es(input.view(-1), offsets) + per_sample_weights = None + + if mode == 'sum': + if test_per_sample_weights: + ref_output = 
(e(input) * per_sample_weights.unsqueeze(-1)).sum(1) + else: + ref_output = e(input).sum(1) + elif mode == 'mean': + assert not test_per_sample_weights + ref_output = e(input).mean(1) + elif mode == 'max': + assert not test_per_sample_weights + ref_output = e(input).max(1)[0] + + self.assertEqual(output, ref_output, dtype2prec[dtype]) + + if not test_backward: + return + + output.backward(grad_output) + ref_output.backward(grad_output) + es_weight_grad = es.weight.grad.data + if sparse: + es_weight_grad = es.weight.grad.data.to_dense() + + # We have more floating point error here because we are dealing with larger numbers + needed_prec = dtype2prec[dtype] * 2 + self.assertEqual(es_weight_grad, e.weight.grad, needed_prec) + def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double): # check a known test example device = torch.device("cuda") if cuda else torch.device("cpu") @@ -2409,39 +2457,12 @@ class TestNN(NNTestCase): self.assertEqual(dense_grad, torch.zeros_like(es.weight)) # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length - def _test_vs_Embedding(N, D, B, L, max_norm=None): - es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype) - e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype) - e.weight.data.copy_(es.weight) - input = torch.randint(N, (B, L), device=device, dtype=torch.long) - offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L) - grad_output = torch.rand(B, D, device=device, dtype=dtype) - - output = es(input.view(-1), offsets) - if mode == 'sum': - ref_output = e(input).sum(1) - elif mode == 'mean': - ref_output = e(input).mean(1) - elif mode == 'max': - ref_output = e(input).max(1)[0] - - self.assertEqual(output, ref_output, dtype2prec[dtype]) - - output.backward(grad_output) - ref_output.backward(grad_output) - es_weight_grad = es.weight.grad.data - if sparse: - es_weight_grad = es.weight.grad.data.to_dense() - - # We have more floating point error here 
because we are dealing with larger numbers - needed_prec = dtype2prec[dtype] * 2 - self.assertEqual(es_weight_grad, e.weight.grad, needed_prec) - N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50) - _test_vs_Embedding(N, D, B, L) + kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype) + self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs) for max_norm in (None, 3): for p in itertools.product([1, 2], repeat=4): - _test_vs_Embedding(*p, max_norm=max_norm) + self._test_EmbeddingBag_vs_Embedding(*p, max_norm=max_norm, **kwargs) # check that giving illegal input combos raises error es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse) @@ -2535,6 +2556,110 @@ class TestNN(NNTestCase): self._test_EmbeddingBag(False, 'sum', True) self._test_EmbeddingBag(False, 'mean', True) + @staticmethod + def _embedding_bag_reference_impl(input, weight, offsets=None, mode='sum', + per_sample_weights=None): + assert mode == 'sum' + assert offsets is not None + if per_sample_weights is None: + per_sample_weights = torch.ones(input.size()) + assert input.numel() == per_sample_weights.numel() + + bags = [] + embeddings = weight.index_select(0, input) * per_sample_weights.unsqueeze(1) + for index, offset in enumerate(offsets): + if index + 1 < len(offsets): + next_offset = offsets[index + 1] + else: + next_offset = len(input) + length = next_offset - offset + bags.append(embeddings.narrow(0, offset, length).sum(0)) + return torch.stack(bags) + + @staticmethod + def _test_EmbeddingBag_per_sample_weights_failures(self, device='cpu'): + # Failure 1: mismatched embeddings / per_sample_weights dtype + es = nn.EmbeddingBag(5, 2, mode='sum').to(dtype=torch.float, device=device) + input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device) + offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device) + per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device) + with 
self.assertRaisesRegex(RuntimeError, 'have the same type as'): + es(input, offsets, per_sample_weights) + + # Failure 2.1: input/per_sample_weights have different sizes (1d input) + input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device) + offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device) + per_sample_weights = torch.randn(5, dtype=torch.float, device=device) + with self.assertRaisesRegex(ValueError, 'same shape as the input'): + es(input, offsets, per_sample_weights) + + # Failure 2.2: input/per_sample_weights have different sizes (2d input) + input = torch.randint(5, (7, 3), dtype=torch.long, device=device) + offsets = None + per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device) + with self.assertRaisesRegex(ValueError, 'same shape as the input'): + es(input, offsets, per_sample_weights) + + # Failure 3: Unsupported per_sample_weights and mode=('max', 'mean') + for unsupported_mode in ('max', 'mean'): + es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to( + dtype=torch.float, device=device) + input = torch.randint(5, (7, 3), dtype=torch.long, device=device) + offsets = None + per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device) + with self.assertRaisesRegex(NotImplementedError, + "only supported for mode='sum'"): + es(input, offsets, per_sample_weights) + + def test_EmbeddingBag_per_sample_weights_failures(self): + self._test_EmbeddingBag_per_sample_weights_failures(self) + + @staticmethod + def _test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cpu'): + def test_per_sample_weights(mode, dtype): + es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device) + es.weight.data.copy_( + torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight)) + input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long) + offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long) + per_sample_weights = torch.randn_like(input, dtype=dtype) + + 
expected = self._embedding_bag_reference_impl( + input, es.weight, offsets, mode, per_sample_weights) + result = es(input, offsets, per_sample_weights) + self.assertEqual(result, expected) + + dtypes = (torch.float, torch.double) + modes = ('sum',) + for dtype, mode in itertools.product(dtypes, modes): + test_per_sample_weights(mode, dtype) + + def test_EmbeddingBag_per_sample_weights_and_offsets(self): + self._test_EmbeddingBag_per_sample_weights_and_offsets(self) + + @staticmethod + def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'): + dtypes = (torch.float, torch.double) + modes = ('sum',) + for dtype, mode in itertools.product(dtypes, modes): + kwargs = dict(test_per_sample_weights=True, test_backward=False, + mode=mode, dtype=dtype, device=device) + + # Simple case + self._test_EmbeddingBag_vs_Embedding(2, 3, 5, 7, **kwargs) + + # B * L > 1000 + self._test_EmbeddingBag_vs_Embedding(2, 5, 53, 23, **kwargs) + + # Large num_embedding + self._test_EmbeddingBag_vs_Embedding(101, 5, 3, 7, **kwargs) + + # Large embedding_dim + self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs) + + def test_EmbeddingBag_per_sample_weights_and_no_offsets(self): + self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @repeat_test_for_types(ALL_TENSORTYPES) def test_embedding_bag_cuda(self, dtype=torch.float): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index ab6bb84..00c7a3b 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -951,10 +951,11 @@ - name: embedding_dense_backward(Tensor grad_output, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) grad_output: embedding_dense_double_backward(grad, indices) -- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) +- name: _embedding_bag(Tensor weight, Tensor indices, Tensor 
offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, Tensor per_sample_weights) indices: not_differentiable offsets: not_differentiable - weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse) + weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights) + per_sample_weights: not_differentiable # TODO(rzou): See issue #4068 - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type) indices: not_differentiable diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 937e624..7c4a8f3 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1502,8 +1502,9 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., @weak_script def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, - scale_grad_by_freq=False, mode='mean', sparse=False): - # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool) -> Tensor + scale_grad_by_freq=False, mode='mean', sparse=False, + per_sample_weights=None): + # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool, Optional[Tensor]) -> Tensor r"""Computes sums, means or maxes of `bags` of embeddings, without instantiating the intermediate embeddings. @@ -1530,6 +1531,11 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under :class:`torch.nn.Embedding` for more details regarding sparse gradients. Note: this option is not supported when ``mode="max"``. + per_sample_weights (Tensor, optional): a tensor of float / double weights, or None + to indicate all weights should be taken to be 1. 
If specified, :attr:`per_sample_weights` + must have exactly the same shape as input and is treated as having the same + :attr:`offsets`, if those are not None. + Shape: @@ -1553,6 +1559,9 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, - :attr:`weight` (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` + - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as + :attr:`input`. + - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)` Examples:: @@ -1575,17 +1584,23 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, "and should now be `embedding_bag(input, weight, ...)`.") weight, input = input, weight + if per_sample_weights is not None and input.size() != per_sample_weights.size(): + raise ValueError("embedding_bag: If per_sample_weights ({}) is not None, " + "then it must have the same shape as the input ({})" + .format(per_sample_weights.shape, input.shape)) + if input.dim() == 2: if offsets is not None: raise ValueError("if input is 2D, then offsets has to be None" ", as input is treated is a mini-batch of" " fixed length sequences. 
However, found " "offsets of type {}".format(type(offsets))) - else: - offsets = torch.arange(0, input.numel(), input.size(1), - dtype=torch.long, device=input.device) + offsets = torch.arange(0, input.numel(), input.size(1), + dtype=torch.long, device=input.device) - input = input.reshape(-1) + input = input.reshape(-1) + if per_sample_weights is not None: + per_sample_weights = per_sample_weights.reshape(-1) elif input.dim() == 1: if offsets is None: raise ValueError("offsets has to be a 1D Tensor but got None") @@ -1628,13 +1643,20 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, # remove once script supports set_grad_enabled _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + if per_sample_weights is not None and mode != 'sum': + raise NotImplementedError("embedding_bag: per_sample_weights was not None. " + "per_sample_weights is only supported for mode='sum' " + "(got mode='{}'). Please open a feature request on GitHub." + .format(mode)) + ret, _, _, _ = torch.embedding_bag( weight, input, offsets, scale_grad_by_freq, mode_enum, - sparse) + sparse, + per_sample_weights) return ret diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index cdd359e..325302c 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -178,7 +178,7 @@ class EmbeddingBag(Module): r"""Computes sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings. - For bags of constant length, this class + For bags of constant length and no :attr:`per_sample_weights`, this class * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=0)``, * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=0)``, @@ -187,6 +187,12 @@ class EmbeddingBag(Module): However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these operations. 
+ EmbeddingBag also supports per-sample weights as an argument to the forward + pass. This scales the output of the Embedding before performing a weighted + reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the + only supported ``mode`` is ``"sum"``, which computes a weighted sum according to + :attr:`per_sample_weights`. + Args: num_embeddings (int): size of the dictionary of embeddings embedding_dim (int): the size of each embedding vector @@ -197,6 +203,9 @@ class EmbeddingBag(Module): the words in the mini-batch. Default ``False``. Note: this option is not supported when ``mode="max"``. mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. + ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights` + into consideration. ``"mean"`` computes the average of the values + in the bag, ``"max"`` computes the max value over each bag. Default: ``"mean"`` sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not @@ -206,7 +215,8 @@ class EmbeddingBag(Module): weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` initialized from :math:`\mathcal{N}(0, 1)`. - Inputs: :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional) + Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and + :attr:`per_sample_weights` (Tensor, optional) - If :attr:`input` is 2D of shape `(B, N)`, @@ -223,6 +233,12 @@ class EmbeddingBag(Module): having ``B`` bags. Empty bags (i.e., having 0-length) will have returned vectors filled by zeros. + per_sample_weights (Tensor, optional): a tensor of float / double weights, or None + to indicate all weights should be taken to be ``1``. 
If specified, :attr:`per_sample_weights` + must have exactly the same shape as input and is treated as having the same + :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``. + + Output shape: `(B, embedding_dim)` Examples:: @@ -262,11 +278,12 @@ class EmbeddingBag(Module): init.normal_(self.weight) @weak_script_method - def forward(self, input, offsets=None): - # type: (Tensor, Optional[Tensor]) -> Tensor + def forward(self, input, offsets=None, per_sample_weights=None): + # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor return F.embedding_bag(input, self.weight, offsets, self.max_norm, self.norm_type, - self.scale_grad_by_freq, self.mode, self.sparse) + self.scale_grad_by_freq, self.mode, self.sparse, + per_sample_weights) def extra_repr(self): s = '{num_embeddings}, {embedding_dim}' diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 30ec2b9..cecc9ee 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -423,14 +423,18 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): return g.op("Gather", weight, indices) -@parse_args('v', 'v', 'v', 'i', 'i', 'i') +@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v') def embedding_bag(g, embedding_matrix, indices, offsets, scale_grad_by_freq, mode, - sparse): + sparse, + per_sample_weights): + if not per_sample_weights.node().mustBeNone(): + raise RuntimeError('Unsupported: ONNX export of embedding_bag ' + 'with per_sample_weights') return g.op("ATen", embedding_matrix, indices,