EmbeddingBag CPU forward with per_sample_weights. (#18735)
author Richard Zou <zou3519@gmail.com>
Wed, 10 Apr 2019 01:08:59 +0000 (18:08 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 10 Apr 2019 01:12:55 +0000 (18:12 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18735
ghimport-source-id: d81bef54dafd7167d2451250d7be478d3c013920

Reviewed By: cpuhrsch

Differential Revision: D14851415

Pulled By: zou3519

fbshipit-source-id: cea6039e760ad571b90f0a536e420498f34be325

aten/src/ATen/native/EmbeddingBag.cpp
aten/src/ATen/native/cuda/EmbeddingBag.cu
aten/src/ATen/native/native_functions.yaml
test/test_nn.py
tools/autograd/derivatives.yaml
torch/nn/functional.py
torch/nn/modules/sparse.py
torch/onnx/symbolic.py

index 150fb25..677d7fe 100644 (file)
@@ -39,6 +39,7 @@ static void index_select_add(const Tensor &select_indices,
                              const Tensor &add_indices,
                              const Tensor &src,
                              Tensor &output) {
+  AT_ASSERT(select_indices.numel() == add_indices.numel());
   auto add_indices_data = add_indices.data<int64_t>();
   auto select_indices_data = select_indices.data<int64_t>();
   auto src_data = src.data<T>();
@@ -49,6 +50,7 @@ static void index_select_add(const Tensor &select_indices,
   auto src_stride1 = src.stride(1);
   auto output_stride0 = output.stride(0);
   auto output_stride1 = output.stride(1);
+
   for (int64_t i = 0; i < numel; i++) {
     THBlas_axpy<T>(ddim, 1,
             src_data + src_stride0 * select_indices_data[i], src_stride1,
@@ -56,6 +58,42 @@ static void index_select_add(const Tensor &select_indices,
   }
 }
 
+// This function fuses the following three fns:
+// index_select (using select_indices as the index)
+// mul (scaling by per_sample_weights)
+// index_add (using add_indices as the index)
+template<typename T>
+static void index_select_scale_add(const Tensor &select_indices,
+                                   const Tensor &add_indices,
+                                   const Tensor &scale,
+                                   const Tensor &src,
+                                   Tensor &output) {
+  AT_ASSERT(select_indices.numel() == add_indices.numel());
+  auto add_indices_data = add_indices.data<int64_t>();
+  auto select_indices_data = select_indices.data<int64_t>();
+  auto src_data = src.data<T>();
+  auto output_data = output.data<T>();
+  auto numel = add_indices.numel();
+  int64_t ddim = src.size(1);
+  auto src_stride0 = src.stride(0);
+  auto src_stride1 = src.stride(1);
+  auto output_stride0 = output.stride(0);
+  auto output_stride1 = output.stride(1);
+
+  auto* scale_data = scale.data<T>();
+  auto scale_stride = scale.stride(0);
+
+  // XXX: We could make this faster via vectorization
+  for (int64_t i = 0; i < numel; i++) {
+    auto* src_base = src_data + src_stride0 * select_indices_data[i];
+    auto* output_base = output_data + output_stride0 * add_indices_data[i];
+    auto scale = scale_data[i * scale_stride];
+    for (int64_t j = 0; j < ddim; j++) {
+      output_base[j * output_stride1] += src_base[j * src_stride1] * scale;
+    }
+  }
+}
+
 static void make_bag_size(const Tensor &offsets, const Tensor &indices,
                           const int64_t mode, Tensor &bag_size) {
   if (mode == MODE_MEAN || mode == MODE_MAX) {
@@ -110,7 +148,12 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
 
 template <typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
-  const Tensor& weight, const Tensor &indices, const Tensor& offset2bag, const Tensor& output, const Tensor& bag_size, const Tensor& offsets) {
+    const Tensor& weight,
+    const Tensor& indices,
+    const Tensor& offset2bag,
+    const Tensor& output,
+    const Tensor& bag_size,
+    const Tensor& offsets) {
 
     auto max_indices = at::zeros({offsets.size(0), weight.size(1)}, indices.options());
 
@@ -132,11 +175,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
       auto bag = offset2bag_data[i];
       auto word_idx = indices_data[i];
 
-
       for (int dim = 0; dim < dims; dim++) {
         auto& current_item = output_data[output_stride * bag + dim];
         auto weight_item = weight_data[weight_stride0 * word_idx + dim * weight_stride1];
-
         bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;
 
         if (is_first_for_bag || weight_item > current_item) {
@@ -155,9 +196,10 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
 std::tuple<Tensor, Tensor, Tensor, Tensor>
 embedding_bag(const Tensor &weight, const Tensor &indices,
               const Tensor &offsets, const bool scale_grad_by_freq,
-              const int64_t mode, bool sparse) {
+              const int64_t mode, bool sparse,
+              const Tensor &per_sample_weights) {
   return at::_embedding_bag(weight, indices.contiguous(), offsets.contiguous(),
-                            scale_grad_by_freq, mode, sparse);
+                            scale_grad_by_freq, mode, sparse, per_sample_weights);
   };
 
 // Assumes all input tensors except for `weight` are contiguous.
@@ -165,7 +207,8 @@ embedding_bag(const Tensor &weight, const Tensor &indices,
 std::tuple<Tensor, Tensor, Tensor, Tensor>
 _embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
                   const Tensor &offsets, const bool scale_grad_by_freq,
-                  const int64_t mode, bool sparse) {
+                  const int64_t mode, bool sparse,
+                  const Tensor &per_sample_weights) {
   auto indices_arg = TensorArg(indices, "indices", 1);
   checkScalarType("embedding_bag", indices_arg, kLong);
   auto offsets_arg = TensorArg(offsets, "offsets", 1);
@@ -173,6 +216,16 @@ _embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
   auto weight_arg = TensorArg(weight, "weight", 1);
   checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});
 
+  if (per_sample_weights.defined()) {
+    AT_CHECK(mode == MODE_SUM,
+        "embedding_bag: per_sample_weights only supported with mode='sum'");
+    auto per_input_weights_arg = TensorArg(
+        per_sample_weights,"per_sample_weights", 1);
+    checkSameType("embedding_bag", weight_arg, per_input_weights_arg);
+    AT_ASSERT(per_sample_weights.dim() == 1);
+    AT_ASSERT(per_sample_weights.numel() == indices.numel());
+  }
+
   auto bag_size = at::zeros(offsets.sizes(), indices.options());
   make_bag_size(offsets, indices, mode, bag_size);
 
@@ -191,14 +244,25 @@ _embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
 
   if (mode == MODE_MEAN || mode == MODE_SUM) {
     AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() {
-      index_select_add<scalar_t>(indices, offset2bag, weight, output);
+      if (per_sample_weights.defined()) {
+        AT_ASSERT(mode == MODE_SUM);
+        index_select_scale_add<scalar_t>(
+            indices, offset2bag, per_sample_weights, weight, output);
+      } else {
+        index_select_add<scalar_t>(indices, offset2bag, weight, output);
+      }
     });
     auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
     return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
   } else { // MODE_MAX
+    at::optional<Tensor> maybe_per_sample_weights;
+    if (per_sample_weights.defined()) {
+      maybe_per_sample_weights = per_sample_weights;
+    }
     return AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       weight.scalar_type(), "embedding_bag_cpu_max", [&]() {
-        return embedding_bag_cpu_max<scalar_t>(weight, indices, offset2bag, output, bag_size, offsets);
+        return embedding_bag_cpu_max<scalar_t>(
+            weight, indices, offset2bag, output, bag_size, offsets);
       }
     );
   }
@@ -213,7 +277,8 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices,
                               const Tensor &max_indices_,
                               int64_t num_weights,
                               bool scale_grad_by_freq, int64_t mode,
-                              bool sparse) {
+                              bool sparse,
+                              const Tensor& per_sample_weights) {
   auto indices_arg = TensorArg(indices, "indices", 1);
   checkScalarType("embedding_bag", indices_arg, kLong);
   checkContiguous("embedding_bag", indices_arg);
@@ -224,6 +289,9 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices,
   checkScalarType("embedding_bag", offset2bag_arg, kLong);
   checkContiguous("embedding_bag", offset2bag_arg);
 
+  AT_CHECK(!per_sample_weights.defined(),
+      "NYI: _embedding_bag_backward: per_sample_weights");
+
   if (sparse) {
     return at::_embedding_bag_sparse_backward(
         grad, indices, offsets, offset2bag, bag_size_, num_weights,
index dea987e..a1a76a7 100644 (file)
@@ -321,7 +321,8 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad,
 std::tuple<Tensor, Tensor, Tensor, Tensor>
 _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
                    const Tensor &offsets, const bool scale_grad_by_freq,
-                   const int64_t mode, bool sparse) {
+                   const int64_t mode, bool sparse,
+                   const Tensor& per_sample_weights) {
   auto indices_arg = TensorArg(indices, "indices", 1);
   checkScalarType("embedding_bag_cuda", indices_arg, kLong);
   auto offsets_arg = TensorArg(offsets, "offsets", 1);
@@ -330,6 +331,9 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
   checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
   checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);
 
+  AT_CHECK(!per_sample_weights.defined(),
+      "NYI: embedding_bag: CUDA per_sample_weights (see issue #4068)");
+
   int64_t numIndices = indices.size(0);
   int64_t numBags = offsets.size(0);
   int64_t featureSize = weight.size(1);
index a3406a6..86d9d2e 100644 (file)
 # applying indices = indices.contiguous().
 # The backward functions apply a check that these input tensors are contiguous.
 
-- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor)
+- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None) -> (Tensor, Tensor, Tensor, Tensor)
   matches_jit_signature: True
 
-- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None) -> (Tensor, Tensor, Tensor, Tensor)
   matches_jit_signature: True
   dispatch:
     CPU: _embedding_bag_cpu
     CUDA: _embedding_bag_cuda
 
-- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse) -> Tensor
+- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor per_sample_weights) -> Tensor
   matches_jit_signature: True
 
 - func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq, int mode) -> Tensor
index 33a2ad2..d4f9cb3 100644 (file)
@@ -2321,6 +2321,54 @@ class TestNN(NNTestCase):
         self._test_gumbel_softmax_straight_through(cuda=True, dtype=dtype)
         self._test_gumbel_softmax_grad(cuda=True, dtype=dtype)
 
+    def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None,  # checks EmbeddingBag(N, D) against Embedding(N, D) + an explicit reduction
+                                        mode='mean',
+                                        device='cpu',
+                                        dtype=torch.float,
+                                        test_per_sample_weights=False,  # True: also feed random per-sample weights to the bag
+                                        sparse=True,
+                                        test_backward=True):  # False: stop after the forward comparison
+        es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype)
+        e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype)
+        e.weight.data.copy_(es.weight)  # both modules start from identical weights
+        input = torch.randint(N, (B, L), device=device, dtype=torch.long)  # B bags of L indices each, drawn from [0, N)
+        offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L)  # bag b starts at b * L in the flattened input
+        grad_output = torch.rand(B, D, device=device, dtype=dtype)
+
+        if test_per_sample_weights:
+            per_sample_weights = torch.randn(B, L, device=device, dtype=dtype)
+            output = es(input.view(-1), offsets, per_sample_weights.view(-1))  # weights flattened the same way as the indices
+        else:
+            output = es(input.view(-1), offsets)
+            per_sample_weights = None
+
+        if mode == 'sum':
+            if test_per_sample_weights:
+                ref_output = (e(input) * per_sample_weights.unsqueeze(-1)).sum(1)  # scale each looked-up row, then reduce over the bag
+            else:
+                ref_output = e(input).sum(1)
+        elif mode == 'mean':
+            assert not test_per_sample_weights  # no weighted reference is built for 'mean'
+            ref_output = e(input).mean(1)
+        elif mode == 'max':
+            assert not test_per_sample_weights  # no weighted reference is built for 'max'
+            ref_output = e(input).max(1)[0]  # [0] keeps the values and drops the argmax indices
+
+        self.assertEqual(output, ref_output, dtype2prec[dtype])
+
+        if not test_backward:
+            return
+
+        output.backward(grad_output)
+        ref_output.backward(grad_output)
+        es_weight_grad = es.weight.grad.data
+        if sparse:
+            es_weight_grad = es.weight.grad.data.to_dense()  # densify the sparse gradient before comparing
+
+        # Gradients accumulate larger values than the forward outputs, so allow twice the forward tolerance.
+        needed_prec = dtype2prec[dtype] * 2
+        self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
+
     def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
         # check a known test example
         device = torch.device("cuda") if cuda else torch.device("cpu")
@@ -2409,39 +2457,12 @@ class TestNN(NNTestCase):
         self.assertEqual(dense_grad, torch.zeros_like(es.weight))
 
         # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
-        def _test_vs_Embedding(N, D, B, L, max_norm=None):
-            es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype)
-            e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype)
-            e.weight.data.copy_(es.weight)
-            input = torch.randint(N, (B, L), device=device, dtype=torch.long)
-            offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L)
-            grad_output = torch.rand(B, D, device=device, dtype=dtype)
-
-            output = es(input.view(-1), offsets)
-            if mode == 'sum':
-                ref_output = e(input).sum(1)
-            elif mode == 'mean':
-                ref_output = e(input).mean(1)
-            elif mode == 'max':
-                ref_output = e(input).max(1)[0]
-
-            self.assertEqual(output, ref_output, dtype2prec[dtype])
-
-            output.backward(grad_output)
-            ref_output.backward(grad_output)
-            es_weight_grad = es.weight.grad.data
-            if sparse:
-                es_weight_grad = es.weight.grad.data.to_dense()
-
-            # We have more floating point error here because we are dealing with larger numbers
-            needed_prec = dtype2prec[dtype] * 2
-            self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
-
         N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50)
-        _test_vs_Embedding(N, D, B, L)
+        kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype)
+        self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs)
         for max_norm in (None, 3):
             for p in itertools.product([1, 2], repeat=4):
-                _test_vs_Embedding(*p, max_norm=max_norm)
+                self._test_EmbeddingBag_vs_Embedding(*p, max_norm=max_norm, **kwargs)
 
         # check that giving illegal input combos raises error
         es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
@@ -2535,6 +2556,110 @@ class TestNN(NNTestCase):
         self._test_EmbeddingBag(False, 'sum', True)
         self._test_EmbeddingBag(False, 'mean', True)
 
+    @staticmethod
+    def _embedding_bag_reference_impl(input, weight, offsets=None, mode='sum',
+                                      per_sample_weights=None):
+        assert mode == 'sum'
+        assert offsets is not None
+        if per_sample_weights is None:
+            per_sample_weights = torch.ones(input.size())
+        assert input.numel() == per_sample_weights.numel()
+
+        bags = []
+        embeddings = weight.index_select(0, input) * per_sample_weights.unsqueeze(1)
+        for index, offset in enumerate(offsets):
+            if index + 1 < len(offsets):
+                next_offset = offsets[index + 1]
+            else:
+                next_offset = len(input)
+            length = next_offset - offset
+            bags.append(embeddings.narrow(0, offset, length).sum(0))
+        return torch.stack(bags)
+
+    @staticmethod
+    def _test_EmbeddingBag_per_sample_weights_failures(self, device='cpu'):  # static: the test case is passed in explicitly as `self`
+        # Failure 1: mismatched embeddings / per_sample_weights dtype
+        es = nn.EmbeddingBag(5, 2, mode='sum').to(dtype=torch.float, device=device)
+        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
+        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
+        per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)  # double weights against a float embedding table
+        with self.assertRaisesRegex(RuntimeError, 'have the same type as'):
+            es(input, offsets, per_sample_weights)
+
+        # Failure 2.1: input/per_sample_weights have different sizes (1d input)
+        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
+        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
+        per_sample_weights = torch.randn(5, dtype=torch.float, device=device)  # 5 weights for 6 indices
+        with self.assertRaisesRegex(ValueError, 'same shape as the input'):
+            es(input, offsets, per_sample_weights)
+
+        # Failure 2.2: input/per_sample_weights have different sizes (2d input)
+        input = torch.randint(5, (7, 3), dtype=torch.long, device=device)
+        offsets = None
+        per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device)  # same element count, but flat (21,) vs. (7, 3)
+        with self.assertRaisesRegex(ValueError, 'same shape as the input'):
+            es(input, offsets, per_sample_weights)
+
+        # Failure 3: Unsupported per_sample_weights and mode=('max', 'mean')
+        for unsupported_mode in ('max', 'mean'):
+            es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to(
+                dtype=torch.float, device=device)
+            input = torch.randint(5, (7, 3), dtype=torch.long, device=device)
+            offsets = None
+            per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device)  # shape is valid here; only the mode is rejected
+            with self.assertRaisesRegex(NotImplementedError,
+                                        "only supported for mode='sum'"):
+                es(input, offsets, per_sample_weights)
+
+    def test_EmbeddingBag_per_sample_weights_failures(self):  # entry point: runs the static failure checks with the default device ('cpu')
+        self._test_EmbeddingBag_per_sample_weights_failures(self)
+
+    @staticmethod
+    def _test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cpu'):  # static: the test case is passed in explicitly as `self`
+        def test_per_sample_weights(mode, dtype):  # one forward comparison against the reference implementation
+            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device)
+            es.weight.data.copy_(
+                torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))  # deterministic weights 1..10 in the table's shape
+            input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
+            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)  # includes empty bags: repeated offsets and a final offset equal to len(input)
+            per_sample_weights = torch.randn_like(input, dtype=dtype)
+
+            expected = self._embedding_bag_reference_impl(
+                input, es.weight, offsets, mode, per_sample_weights)
+            result = es(input, offsets, per_sample_weights)
+            self.assertEqual(result, expected)
+
+        dtypes = (torch.float, torch.double)
+        modes = ('sum',)  # the only mode exercised with per_sample_weights here
+        for dtype, mode in itertools.product(dtypes, modes):
+            test_per_sample_weights(mode, dtype)
+
+    def test_EmbeddingBag_per_sample_weights_and_offsets(self):  # entry point: runs the static offsets check with the default device ('cpu')
+        self._test_EmbeddingBag_per_sample_weights_and_offsets(self)
+
+    @staticmethod
+    def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'):  # static: the test case is passed in explicitly as `self`
+        dtypes = (torch.float, torch.double)
+        modes = ('sum',)  # the only mode exercised with per_sample_weights here
+        for dtype, mode in itertools.product(dtypes, modes):
+            kwargs = dict(test_per_sample_weights=True, test_backward=False,  # forward comparison only
+                          mode=mode, dtype=dtype, device=device)
+
+            # Simple case (positional arguments below are N, D, B, L)
+            self._test_EmbeddingBag_vs_Embedding(2, 3, 5, 7, **kwargs)
+
+            # B * L > 1000 (53 * 23 = 1219 indices)
+            self._test_EmbeddingBag_vs_Embedding(2, 5, 53, 23, **kwargs)
+
+            # Large num_embedding
+            self._test_EmbeddingBag_vs_Embedding(101, 5, 3, 7, **kwargs)
+
+            # Large embedding_dim
+            self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)
+
+    def test_EmbeddingBag_per_sample_weights_and_no_offsets(self):  # entry point: runs the static no-offsets check with the default device ('cpu')
+        self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self)
+
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @repeat_test_for_types(ALL_TENSORTYPES)
     def test_embedding_bag_cuda(self, dtype=torch.float):
index ab6bb84..00c7a3b 100644 (file)
 - name: embedding_dense_backward(Tensor grad_output, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq)
   grad_output: embedding_dense_double_backward(grad, indices)
 
-- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
+- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, Tensor per_sample_weights)
   indices: not_differentiable
   offsets: not_differentiable
-  weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse)
+  weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights)
+  per_sample_weights: not_differentiable # TODO(rzou): See issue #4068
 
 - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
   indices: not_differentiable
index 937e624..7c4a8f3 100644 (file)
@@ -1502,8 +1502,9 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
 
 @weak_script
 def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
-                  scale_grad_by_freq=False, mode='mean', sparse=False):
-    # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool) -> Tensor
+                  scale_grad_by_freq=False, mode='mean', sparse=False,
+                  per_sample_weights=None):
+    # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool, Optional[Tensor]) -> Tensor
     r"""Computes sums, means or maxes of `bags` of embeddings, without instantiating the
     intermediate embeddings.
 
@@ -1530,6 +1531,11 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
         sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
                                  :class:`torch.nn.Embedding` for more details regarding sparse gradients.
                                  Note: this option is not supported when ``mode="max"``.
+        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
+            to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights`
+            must have exactly the same shape as input and is treated as having the same
+            :attr:`offsets`, if those are not None.
+
 
     Shape:
 
@@ -1553,6 +1559,9 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
         - :attr:`weight` (Tensor): the learnable weights of the module of
           shape `(num_embeddings, embedding_dim)`
 
+        - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as
+          :attr:`input`.
+
         - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)`
 
     Examples::
@@ -1575,17 +1584,23 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                       "and should now be `embedding_bag(input, weight, ...)`.")
         weight, input = input, weight
 
+    if per_sample_weights is not None and input.size() != per_sample_weights.size():
+        raise ValueError("embedding_bag: If per_sample_weights ({}) is not None, "
+                         "then it must have the same shape as the input ({})"
+                         .format(per_sample_weights.shape, input.shape))
+
     if input.dim() == 2:
         if offsets is not None:
             raise ValueError("if input is 2D, then offsets has to be None"
                              ", as input is treated is a mini-batch of"
                              " fixed length sequences. However, found "
                              "offsets of type {}".format(type(offsets)))
-        else:
-            offsets = torch.arange(0, input.numel(), input.size(1),
-                                   dtype=torch.long, device=input.device)
+        offsets = torch.arange(0, input.numel(), input.size(1),
+                               dtype=torch.long, device=input.device)
 
-            input = input.reshape(-1)
+        input = input.reshape(-1)
+        if per_sample_weights is not None:
+            per_sample_weights = per_sample_weights.reshape(-1)
     elif input.dim() == 1:
         if offsets is None:
             raise ValueError("offsets has to be a 1D Tensor but got None")
@@ -1628,13 +1643,20 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
         # remove once script supports set_grad_enabled
         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
 
+    if per_sample_weights is not None and mode != 'sum':
+        raise NotImplementedError("embedding_bag: per_sample_weights was not None. "
+                                  "per_sample_weights is only supported for mode='sum' "
+                                  "(got mode='{}'). Please open a feature request on GitHub."
+                                  .format(mode))
+
     ret, _, _, _ = torch.embedding_bag(
         weight,
         input,
         offsets,
         scale_grad_by_freq,
         mode_enum,
-        sparse)
+        sparse,
+        per_sample_weights)
     return ret
 
 
index cdd359e..325302c 100644 (file)
@@ -178,7 +178,7 @@ class EmbeddingBag(Module):
     r"""Computes sums or means of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
 
-    For bags of constant length, this class
+    For bags of constant length and no :attr:`per_sample_weights`, this class
 
         * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=0)``,
         * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=0)``,
@@ -187,6 +187,12 @@ class EmbeddingBag(Module):
     However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
     operations.
 
+    EmbeddingBag also supports per-sample weights as an argument to the forward
+    pass. This scales the output of the Embedding before performing a weighted
+    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
+    only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
+    :attr:`per_sample_weights`.
+
     Args:
         num_embeddings (int): size of the dictionary of embeddings
         embedding_dim (int): the size of each embedding vector
@@ -197,6 +203,9 @@ class EmbeddingBag(Module):
                                                 the words in the mini-batch. Default ``False``.
                                                 Note: this option is not supported when ``mode="max"``.
         mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
+                                 ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
+                                 into consideration. ``"mean"`` computes the average of the values
+                                 in the bag, ``"max"`` computes the max value over each bag.
                                  Default: ``"mean"``
         sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
                                  Notes for more details regarding sparse gradients. Note: this option is not
@@ -206,7 +215,8 @@ class EmbeddingBag(Module):
         weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
                          initialized from :math:`\mathcal{N}(0, 1)`.
 
-    Inputs: :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional)
+    Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
+        :attr:`per_sample_weights` (Tensor, optional)
 
         - If :attr:`input` is 2D of shape `(B, N)`,
 
@@ -223,6 +233,12 @@ class EmbeddingBag(Module):
           having ``B`` bags. Empty bags (i.e., having 0-length) will have
           returned vectors filled by zeros.
 
+        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
+            to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
+            must have exactly the same shape as input and is treated as having the same
+            :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
+
+
     Output shape: `(B, embedding_dim)`
 
     Examples::
@@ -262,11 +278,12 @@ class EmbeddingBag(Module):
         init.normal_(self.weight)
 
     @weak_script_method
-    def forward(self, input, offsets=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input, offsets=None, per_sample_weights=None):
+        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
         return F.embedding_bag(input, self.weight, offsets,
                                self.max_norm, self.norm_type,
-                               self.scale_grad_by_freq, self.mode, self.sparse)
+                               self.scale_grad_by_freq, self.mode, self.sparse,
+                               per_sample_weights)
 
     def extra_repr(self):
         s = '{num_embeddings}, {embedding_dim}'
index 30ec2b9..cecc9ee 100644 (file)
@@ -423,14 +423,18 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
     return g.op("Gather", weight, indices)
 
 
-@parse_args('v', 'v', 'v', 'i', 'i', 'i')
+@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v')
 def embedding_bag(g,
                   embedding_matrix,
                   indices,
                   offsets,
                   scale_grad_by_freq,
                   mode,
-                  sparse):
+                  sparse,
+                  per_sample_weights):
+    if not per_sample_weights.node().mustBeNone():
+        raise RuntimeError('Unsupported: ONNX export of embedding_bag '
+                           'with per_sample_weights')
     return g.op("ATen",
                 embedding_matrix,
                 indices,