const Tensor &add_indices,
const Tensor &src,
Tensor &output) {
+ AT_ASSERT(select_indices.numel() == add_indices.numel());
  auto add_indices_data = add_indices.data<int64_t>();
  auto select_indices_data = select_indices.data<int64_t>();
  auto src_data = src.data<T>();
  auto output_data = output.data<T>();
  auto numel = add_indices.numel();
  int64_t ddim = src.size(1);
  auto src_stride0 = src.stride(0);
  auto src_stride1 = src.stride(1);
  auto output_stride0 = output.stride(0);
  auto output_stride1 = output.stride(1);

  for (int64_t i = 0; i < numel; i++) {
    THBlas_axpy<T>(ddim, 1,
            src_data + src_stride0 * select_indices_data[i], src_stride1,
            output_data + output_stride0 * add_indices_data[i], output_stride1);
  }
}
+// This function fuses the following three fns:
+// index_select (using select_indices as the index)
+// mul (scaling by per_sample_weights)
+// index_add (using add_indices as the index)
+template<typename T>
+static void index_select_scale_add(const Tensor &select_indices,
+ const Tensor &add_indices,
+ const Tensor &scale,
+ const Tensor &src,
+ Tensor &output) {
+ AT_ASSERT(select_indices.numel() == add_indices.numel());
+ auto add_indices_data = add_indices.data<int64_t>();
+ auto select_indices_data = select_indices.data<int64_t>();
+ auto src_data = src.data<T>();
+ auto output_data = output.data<T>();
+ auto numel = add_indices.numel();
+ int64_t ddim = src.size(1);
+ auto src_stride0 = src.stride(0);
+ auto src_stride1 = src.stride(1);
+ auto output_stride0 = output.stride(0);
+ auto output_stride1 = output.stride(1);
+
+ auto* scale_data = scale.data<T>();
+ auto scale_stride = scale.stride(0);
+
+ // XXX: We could make this faster via vectorization
+ for (int64_t i = 0; i < numel; i++) {
+ auto* src_base = src_data + src_stride0 * select_indices_data[i];
+ auto* output_base = output_data + output_stride0 * add_indices_data[i];
+ auto scale = scale_data[i * scale_stride];
+ for (int64_t j = 0; j < ddim; j++) {
+ output_base[j * output_stride1] += src_base[j * src_stride1] * scale;
+ }
+ }
+}
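+
+// For intuition only, a rough unfused equivalent of the loop above, written
+// with ATen ops (a sketch, not part of this change):
+//
+//   auto selected = src.index_select(0, select_indices);   // index_select
+//   auto scaled = selected * scale.unsqueeze(1);           // mul
+//   output.index_add_(0, add_indices, scaled);             // index_add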
+
static void make_bag_size(const Tensor &offsets, const Tensor &indices,
const int64_t mode, Tensor &bag_size) {
if (mode == MODE_MEAN || mode == MODE_MAX) {
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
- const Tensor& weight, const Tensor &indices, const Tensor& offset2bag, const Tensor& output, const Tensor& bag_size, const Tensor& offsets) {
+ const Tensor& weight,
+ const Tensor& indices,
+ const Tensor& offset2bag,
+ const Tensor& output,
+ const Tensor& bag_size,
+ const Tensor& offsets) {
auto max_indices = at::zeros({offsets.size(0), weight.size(1)}, indices.options());
auto bag = offset2bag_data[i];
auto word_idx = indices_data[i];
-
for (int dim = 0; dim < dims; dim++) {
auto& current_item = output_data[output_stride * bag + dim];
auto weight_item = weight_data[weight_stride0 * word_idx + dim * weight_stride1];
-
bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;
if (is_first_for_bag || weight_item > current_item) {
std::tuple<Tensor, Tensor, Tensor, Tensor>
embedding_bag(const Tensor &weight, const Tensor &indices,
const Tensor &offsets, const bool scale_grad_by_freq,
- const int64_t mode, bool sparse) {
+ const int64_t mode, bool sparse,
+ const Tensor &per_sample_weights) {
return at::_embedding_bag(weight, indices.contiguous(), offsets.contiguous(),
- scale_grad_by_freq, mode, sparse);
+ scale_grad_by_freq, mode, sparse, per_sample_weights);
};
// Assumes all input tensors except for `weight` are contiguous.
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
const Tensor &offsets, const bool scale_grad_by_freq,
- const int64_t mode, bool sparse) {
+ const int64_t mode, bool sparse,
+ const Tensor &per_sample_weights) {
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
auto offsets_arg = TensorArg(offsets, "offsets", 1);
auto weight_arg = TensorArg(weight, "weight", 1);
checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});
+ if (per_sample_weights.defined()) {
+ AT_CHECK(mode == MODE_SUM,
+ "embedding_bag: per_sample_weights only supported with mode='sum'");
+    auto per_sample_weights_arg = TensorArg(
+        per_sample_weights, "per_sample_weights", 1);
+    checkSameType("embedding_bag", weight_arg, per_sample_weights_arg);
+ AT_ASSERT(per_sample_weights.dim() == 1);
+ AT_ASSERT(per_sample_weights.numel() == indices.numel());
+ }
+
auto bag_size = at::zeros(offsets.sizes(), indices.options());
make_bag_size(offsets, indices, mode, bag_size);
if (mode == MODE_MEAN || mode == MODE_SUM) {
AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() {
- index_select_add<scalar_t>(indices, offset2bag, weight, output);
+ if (per_sample_weights.defined()) {
+ AT_ASSERT(mode == MODE_SUM);
+ index_select_scale_add<scalar_t>(
+ indices, offset2bag, per_sample_weights, weight, output);
+ } else {
+ index_select_add<scalar_t>(indices, offset2bag, weight, output);
+ }
});
auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
} else { // MODE_MAX
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(
weight.scalar_type(), "embedding_bag_cpu_max", [&]() {
- return embedding_bag_cpu_max<scalar_t>(weight, indices, offset2bag, output, bag_size, offsets);
+ return embedding_bag_cpu_max<scalar_t>(
+ weight, indices, offset2bag, output, bag_size, offsets);
}
);
}
const Tensor &max_indices_,
int64_t num_weights,
bool scale_grad_by_freq, int64_t mode,
- bool sparse) {
+ bool sparse,
+ const Tensor& per_sample_weights) {
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
checkContiguous("embedding_bag", indices_arg);
checkScalarType("embedding_bag", offset2bag_arg, kLong);
checkContiguous("embedding_bag", offset2bag_arg);
+ AT_CHECK(!per_sample_weights.defined(),
+ "NYI: _embedding_bag_backward: per_sample_weights");
+
if (sparse) {
return at::_embedding_bag_sparse_backward(
grad, indices, offsets, offset2bag, bag_size_, num_weights,
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
const Tensor &offsets, const bool scale_grad_by_freq,
- const int64_t mode, bool sparse) {
+ const int64_t mode, bool sparse,
+ const Tensor& per_sample_weights) {
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag_cuda", indices_arg, kLong);
auto offsets_arg = TensorArg(offsets, "offsets", 1);
checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);
+ AT_CHECK(!per_sample_weights.defined(),
+ "NYI: embedding_bag: CUDA per_sample_weights (see issue #4068)");
+
int64_t numIndices = indices.size(0);
int64_t numBags = offsets.size(0);
int64_t featureSize = weight.size(1);
# applying indices = indices.contiguous().
# The backward functions apply a check that these input tensors are contiguous.
-- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor)
+- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None) -> (Tensor, Tensor, Tensor, Tensor)
matches_jit_signature: True
-- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None) -> (Tensor, Tensor, Tensor, Tensor)
matches_jit_signature: True
dispatch:
CPU: _embedding_bag_cpu
CUDA: _embedding_bag_cuda
-- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse) -> Tensor
+- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor per_sample_weights) -> Tensor
matches_jit_signature: True
- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq, int mode) -> Tensor
self._test_gumbel_softmax_straight_through(cuda=True, dtype=dtype)
self._test_gumbel_softmax_grad(cuda=True, dtype=dtype)
+ def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None,
+ mode='mean',
+ device='cpu',
+ dtype=torch.float,
+ test_per_sample_weights=False,
+ sparse=True,
+ test_backward=True):
+ es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype)
+ e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype)
+ e.weight.data.copy_(es.weight)
+ input = torch.randint(N, (B, L), device=device, dtype=torch.long)
+ offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L)
+ grad_output = torch.rand(B, D, device=device, dtype=dtype)
+
+ if test_per_sample_weights:
+ per_sample_weights = torch.randn(B, L, device=device, dtype=dtype)
+ output = es(input.view(-1), offsets, per_sample_weights.view(-1))
+ else:
+ output = es(input.view(-1), offsets)
+ per_sample_weights = None
+
+ if mode == 'sum':
+ if test_per_sample_weights:
+ ref_output = (e(input) * per_sample_weights.unsqueeze(-1)).sum(1)
+ else:
+ ref_output = e(input).sum(1)
+ elif mode == 'mean':
+ assert not test_per_sample_weights
+ ref_output = e(input).mean(1)
+ elif mode == 'max':
+ assert not test_per_sample_weights
+ ref_output = e(input).max(1)[0]
+
+ self.assertEqual(output, ref_output, dtype2prec[dtype])
+
+ if not test_backward:
+ return
+
+ output.backward(grad_output)
+ ref_output.backward(grad_output)
+ es_weight_grad = es.weight.grad.data
+ if sparse:
+ es_weight_grad = es.weight.grad.data.to_dense()
+
+ # We have more floating point error here because we are dealing with larger numbers
+ needed_prec = dtype2prec[dtype] * 2
+ self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
+
def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
# check a known test example
device = torch.device("cuda") if cuda else torch.device("cpu")
self.assertEqual(dense_grad, torch.zeros_like(es.weight))
# now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
- def _test_vs_Embedding(N, D, B, L, max_norm=None):
- es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype)
- e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype)
- e.weight.data.copy_(es.weight)
- input = torch.randint(N, (B, L), device=device, dtype=torch.long)
- offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L)
- grad_output = torch.rand(B, D, device=device, dtype=dtype)
-
- output = es(input.view(-1), offsets)
- if mode == 'sum':
- ref_output = e(input).sum(1)
- elif mode == 'mean':
- ref_output = e(input).mean(1)
- elif mode == 'max':
- ref_output = e(input).max(1)[0]
-
- self.assertEqual(output, ref_output, dtype2prec[dtype])
-
- output.backward(grad_output)
- ref_output.backward(grad_output)
- es_weight_grad = es.weight.grad.data
- if sparse:
- es_weight_grad = es.weight.grad.data.to_dense()
-
- # We have more floating point error here because we are dealing with larger numbers
- needed_prec = dtype2prec[dtype] * 2
- self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
-
N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50)
- _test_vs_Embedding(N, D, B, L)
+ kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype)
+ self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs)
for max_norm in (None, 3):
for p in itertools.product([1, 2], repeat=4):
- _test_vs_Embedding(*p, max_norm=max_norm)
+ self._test_EmbeddingBag_vs_Embedding(*p, max_norm=max_norm, **kwargs)
# check that giving illegal input combos raises error
es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
self._test_EmbeddingBag(False, 'sum', True)
self._test_EmbeddingBag(False, 'mean', True)
+ @staticmethod
+ def _embedding_bag_reference_impl(input, weight, offsets=None, mode='sum',
+ per_sample_weights=None):
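+        # Reference semantics (e.g., with offsets [0, 0, 3, 3, 6] over 6 input
+        # indices, the bags have lengths 0, 3, 0, 3 and 0): each bag sums the
+        # scaled embeddings in [offset, next_offset); empty bags reduce to
+        # rows of zeros via sum(0) over an empty slice.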
+ assert mode == 'sum'
+ assert offsets is not None
+ if per_sample_weights is None:
+ per_sample_weights = torch.ones(input.size())
+ assert input.numel() == per_sample_weights.numel()
+
+ bags = []
+ embeddings = weight.index_select(0, input) * per_sample_weights.unsqueeze(1)
+ for index, offset in enumerate(offsets):
+ if index + 1 < len(offsets):
+ next_offset = offsets[index + 1]
+ else:
+ next_offset = len(input)
+ length = next_offset - offset
+ bags.append(embeddings.narrow(0, offset, length).sum(0))
+ return torch.stack(bags)
+
+ @staticmethod
+ def _test_EmbeddingBag_per_sample_weights_failures(self, device='cpu'):
+ # Failure 1: mismatched embeddings / per_sample_weights dtype
+ es = nn.EmbeddingBag(5, 2, mode='sum').to(dtype=torch.float, device=device)
+ input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
+ offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
+ per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)
+ with self.assertRaisesRegex(RuntimeError, 'have the same type as'):
+ es(input, offsets, per_sample_weights)
+
+ # Failure 2.1: input/per_sample_weights have different sizes (1d input)
+ input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
+ offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
+ per_sample_weights = torch.randn(5, dtype=torch.float, device=device)
+ with self.assertRaisesRegex(ValueError, 'same shape as the input'):
+ es(input, offsets, per_sample_weights)
+
+ # Failure 2.2: input/per_sample_weights have different sizes (2d input)
+ input = torch.randint(5, (7, 3), dtype=torch.long, device=device)
+ offsets = None
+ per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device)
+ with self.assertRaisesRegex(ValueError, 'same shape as the input'):
+ es(input, offsets, per_sample_weights)
+
+ # Failure 3: Unsupported per_sample_weights and mode=('max', 'mean')
+ for unsupported_mode in ('max', 'mean'):
+ es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to(
+ dtype=torch.float, device=device)
+ input = torch.randint(5, (7, 3), dtype=torch.long, device=device)
+ offsets = None
+ per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device)
+ with self.assertRaisesRegex(NotImplementedError,
+ "only supported for mode='sum'"):
+ es(input, offsets, per_sample_weights)
+
+ def test_EmbeddingBag_per_sample_weights_failures(self):
+ self._test_EmbeddingBag_per_sample_weights_failures(self)
+
+ @staticmethod
+ def _test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cpu'):
+ def test_per_sample_weights(mode, dtype):
+ es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device)
+ es.weight.data.copy_(
+ torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
+ input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
+ offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)
+ per_sample_weights = torch.randn_like(input, dtype=dtype)
+
+ expected = self._embedding_bag_reference_impl(
+ input, es.weight, offsets, mode, per_sample_weights)
+ result = es(input, offsets, per_sample_weights)
+ self.assertEqual(result, expected)
+
+ dtypes = (torch.float, torch.double)
+ modes = ('sum',)
+ for dtype, mode in itertools.product(dtypes, modes):
+ test_per_sample_weights(mode, dtype)
+
+ def test_EmbeddingBag_per_sample_weights_and_offsets(self):
+ self._test_EmbeddingBag_per_sample_weights_and_offsets(self)
+
+ @staticmethod
+ def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'):
+ dtypes = (torch.float, torch.double)
+ modes = ('sum',)
+ for dtype, mode in itertools.product(dtypes, modes):
+ kwargs = dict(test_per_sample_weights=True, test_backward=False,
+ mode=mode, dtype=dtype, device=device)
+
+ # Simple case
+ self._test_EmbeddingBag_vs_Embedding(2, 3, 5, 7, **kwargs)
+
+ # B * L > 1000
+ self._test_EmbeddingBag_vs_Embedding(2, 5, 53, 23, **kwargs)
+
+            # Large num_embeddings
+ self._test_EmbeddingBag_vs_Embedding(101, 5, 3, 7, **kwargs)
+
+ # Large embedding_dim
+ self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)
+
+ def test_EmbeddingBag_per_sample_weights_and_no_offsets(self):
+ self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self)
+
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
def test_embedding_bag_cuda(self, dtype=torch.float):
- name: embedding_dense_backward(Tensor grad_output, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq)
grad_output: embedding_dense_double_backward(grad, indices)
-- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
+- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, Tensor per_sample_weights)
indices: not_differentiable
offsets: not_differentiable
- weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse)
+ weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights)
+ per_sample_weights: not_differentiable # TODO(rzou): See issue #4068
- name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
indices: not_differentiable
@weak_script
def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
- scale_grad_by_freq=False, mode='mean', sparse=False):
- # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool) -> Tensor
+ scale_grad_by_freq=False, mode='mean', sparse=False,
+ per_sample_weights=None):
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool, Optional[Tensor]) -> Tensor
r"""Computes sums, means or maxes of `bags` of embeddings, without instantiating the
intermediate embeddings.
sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
:class:`torch.nn.Embedding` for more details regarding sparse gradients.
Note: this option is not supported when ``mode="max"``.
+ per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
+ to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights`
+ must have exactly the same shape as input and is treated as having the same
+ :attr:`offsets`, if those are not None.
+
Shape:
- :attr:`weight` (Tensor): the learnable weights of the module of
shape `(num_embeddings, embedding_dim)`
+ - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as
+ :attr:`input`.
+
- :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)`
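+
+    A minimal illustration of :attr:`per_sample_weights` (hypothetical values,
+    ``mode='sum'``)::
+
+        >>> weight = torch.rand(10, 3)
+        >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
+        >>> offsets = torch.tensor([0, 4])
+        >>> per_sample_weights = torch.rand(8)
+        >>> F.embedding_bag(input, weight, offsets, mode='sum',
+        ...                 per_sample_weights=per_sample_weights).size()
+        torch.Size([2, 3])
+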
Examples::
"and should now be `embedding_bag(input, weight, ...)`.")
weight, input = input, weight
+ if per_sample_weights is not None and input.size() != per_sample_weights.size():
+ raise ValueError("embedding_bag: If per_sample_weights ({}) is not None, "
+ "then it must have the same shape as the input ({})"
+ .format(per_sample_weights.shape, input.shape))
+
if input.dim() == 2:
if offsets is not None:
raise ValueError("if input is 2D, then offsets has to be None"
", as input is treated is a mini-batch of"
" fixed length sequences. However, found "
"offsets of type {}".format(type(offsets)))
- else:
- offsets = torch.arange(0, input.numel(), input.size(1),
- dtype=torch.long, device=input.device)
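+        # each row of the 2-D input is one bag; synthesize offsets
+        # [0, N, 2N, ...] with N = input.size(1), then flatten to 1-D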
+ offsets = torch.arange(0, input.numel(), input.size(1),
+ dtype=torch.long, device=input.device)
- input = input.reshape(-1)
+ input = input.reshape(-1)
+ if per_sample_weights is not None:
+ per_sample_weights = per_sample_weights.reshape(-1)
elif input.dim() == 1:
if offsets is None:
raise ValueError("offsets has to be a 1D Tensor but got None")
# remove once script supports set_grad_enabled
_no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
+ if per_sample_weights is not None and mode != 'sum':
+ raise NotImplementedError("embedding_bag: per_sample_weights was not None. "
+ "per_sample_weights is only supported for mode='sum' "
+ "(got mode='{}'). Please open a feature request on GitHub."
+ .format(mode))
+
ret, _, _, _ = torch.embedding_bag(
weight,
input,
offsets,
scale_grad_by_freq,
mode_enum,
- sparse)
+ sparse,
+ per_sample_weights)
return ret
r"""Computes sums or means of 'bags' of embeddings, without instantiating the
intermediate embeddings.
- For bags of constant length, this class
+ For bags of constant length and no :attr:`per_sample_weights`, this class
* with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=0)``,
* with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=0)``,
However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
operations.
+ EmbeddingBag also supports per-sample weights as an argument to the forward
+ pass. This scales the output of the Embedding before performing a weighted
+    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
+ only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
+ :attr:`per_sample_weights`.
+
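+    A minimal sketch of the weighted-sum path (hypothetical values)::
+
+        >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
+        >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
+        >>> offsets = torch.tensor([0, 4])
+        >>> per_sample_weights = torch.rand(8)
+        >>> embedding_sum(input, offsets, per_sample_weights).size()
+        torch.Size([2, 3])
+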
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
the words in the mini-batch. Default ``False``.
Note: this option is not supported when ``mode="max"``.
mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
+ ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
+ into consideration. ``"mean"`` computes the average of the values
+ in the bag, ``"max"`` computes the max value over each bag.
Default: ``"mean"``
sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
Notes for more details regarding sparse gradients. Note: this option is not
weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
initialized from :math:`\mathcal{N}(0, 1)`.
- Inputs: :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional)
+ Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
+        :attr:`per_sample_weights` (Tensor, optional)
- If :attr:`input` is 2D of shape `(B, N)`,
having ``B`` bags. Empty bags (i.e., having 0-length) will have
returned vectors filled by zeros.
+ per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
+ to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
+ must have exactly the same shape as input and is treated as having the same
+ :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
+
Output shape: `(B, embedding_dim)`
Examples::
init.normal_(self.weight)
@weak_script_method
- def forward(self, input, offsets=None):
- # type: (Tensor, Optional[Tensor]) -> Tensor
+ def forward(self, input, offsets=None, per_sample_weights=None):
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
return F.embedding_bag(input, self.weight, offsets,
self.max_norm, self.norm_type,
- self.scale_grad_by_freq, self.mode, self.sparse)
+ self.scale_grad_by_freq, self.mode, self.sparse,
+ per_sample_weights)
def extra_repr(self):
s = '{num_embeddings}, {embedding_dim}'
return g.op("Gather", weight, indices)
-@parse_args('v', 'v', 'v', 'i', 'i', 'i')
+@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v')
def embedding_bag(g,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
- sparse):
+ sparse,
+ per_sample_weights):
+ if not per_sample_weights.node().mustBeNone():
+ raise RuntimeError('Unsupported: ONNX export of embedding_bag '
+ 'with per_sample_weights')
return g.op("ATen",
embedding_matrix,
indices,