EmbeddingBag w/ per_sample_weights CUDA fwd + bwd (#18800)
authorRichard Zou <zou3519@gmail.com>
Wed, 10 Apr 2019 01:08:59 +0000 (18:08 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 10 Apr 2019 01:13:02 +0000 (18:13 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18800
ghimport-source-id: 17f638dea0e1ac9a86ec06b223c60362ed78449c

Reviewed By: cpuhrsch

Differential Revision: D14851422

Pulled By: zou3519

fbshipit-source-id: 27b114e51e66112e4bc9cfc63d1d1ddfa650d347

aten/src/ATen/native/EmbeddingBag.cpp
aten/src/ATen/native/cuda/EmbeddingBag.cu
test/test_nn.py

index ba062b4..526d6b3 100644 (file)
@@ -289,11 +289,6 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices,
   checkScalarType("embedding_bag", offset2bag_arg, kLong);
   checkContiguous("embedding_bag", offset2bag_arg);
 
-  if (per_sample_weights.defined() &&
-      per_sample_weights.device().type() != DeviceType::CPU) {
-    AT_ERROR("NYI: _embedding_bag_backward: per_sample_weights only supported for CPU");
-  }
-
   if (sparse) {
     return at::_embedding_bag_sparse_backward(
         grad, indices, offsets, offset2bag, bag_size_, num_weights,
index 043a981..fa3bf2b 100644 (file)
@@ -24,13 +24,15 @@ namespace native {
 
 namespace {
 
-// This kernel assumes that all input tensors except `weight` are contiguous.
+// This kernel assumes that all input tensors except `weight` and
+// per_sample_weights are contiguous.
 template <typename scalar_t>
 __global__ void EmbeddingBag_updateOutputKernel(
     int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output,
     int64_t *offset2bag, int64_t numIndices, int64_t numBags,
     int64_t featureSize, int64_t weight_stide0, int64_t weight_stride1,
-    int mode, int64_t *bag_size, int64_t *max_indices) {
+    int mode, int64_t *bag_size, int64_t *max_indices,
+    scalar_t* per_sample_weights, int64_t per_sample_weights_stride) {
 
   // the strategy here is that each bag x feature is handled by a single thread
 
@@ -64,7 +66,13 @@ __global__ void EmbeddingBag_updateOutputKernel(
             maxWord = input[emb];
           }
         } else {
-          weightFeatSum += static_cast<accscalar_t>(weightValue);
+          if (per_sample_weights) {
+            accscalar_t scaleWeightBy = static_cast<accscalar_t>(
+                per_sample_weights[emb * per_sample_weights_stride]);
+            weightFeatSum += scaleWeightBy * static_cast<accscalar_t>(weightValue);
+          } else {
+            weightFeatSum += static_cast<accscalar_t>(weightValue);
+          }
         }
 
         bag_size_++;
@@ -106,7 +114,8 @@ template <typename scalar_t>
 __global__ void EmbeddingBag_accGradParametersKernel_sum_avg(
     int64_t *input, int64_t *indices, scalar_t *gradOutput,
     scalar_t *gradWeight, int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
-    int64_t stride, int mode, const int64_t *bag_size) {
+    int64_t stride, int mode, const int64_t *bag_size,
+               scalar_t* per_sample_weights, int64_t per_sample_weights_stride) {
 
   using accscalar_t = acc_type<scalar_t, true>;
   int idx = blockIdx.x * 4 + threadIdx.y;
@@ -134,7 +143,10 @@ __global__ void EmbeddingBag_accGradParametersKernel_sum_avg(
       const int seq_number = offset2bag[origRow];
       const int gradOutputRow = ((int)seq_number) * stride;
 
-      const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
+      accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
+                       if (per_sample_weights) {
+                               scale *= per_sample_weights[origRow * per_sample_weights_stride];
+                       }
 
       accscalar_t gradient[SZ];
       accscalar_t weight[SZ];
@@ -179,7 +191,8 @@ Tensor embedding_bag_backward_cuda_sum_avg(
                                    const Tensor &offset2bag,
                                    const Tensor &bag_size,
                                    int64_t num_weights,
-                                   bool scale_grad_by_freq, int64_t mode) {
+                                   bool scale_grad_by_freq, int64_t mode,
+                                   const Tensor& per_sample_weights) {
 
   auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());
 
@@ -255,7 +268,9 @@ Tensor embedding_bag_backward_cuda_sum_avg(
             grad.data<scalar_t>(), grad_weight.data<scalar_t>(),
             offset2bag.data<int64_t>(),
             count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
-            mode, bag_size.data<int64_t>());
+            mode, bag_size.data<int64_t>(),
+            per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
+            per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
       });
 
   THCudaCheck(cudaGetLastError());
@@ -331,9 +346,6 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
   checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
   checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);
 
-  AT_CHECK(!per_sample_weights.defined(),
-      "NYI: embedding_bag: CUDA per_sample_weights (see issue #4068)");
-
   int64_t numIndices = indices.size(0);
   int64_t numBags = offsets.size(0);
   int64_t featureSize = weight.size(1);
@@ -363,7 +375,9 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
         weight.data<scalar_t>(), output.data<scalar_t>(),
         offset2bag.data<int64_t>(), numIndices, numBags, featureSize,
         weight.stride(0), weight.stride(1), mode, bag_size.data<int64_t>(),
-        mode == MODE_MAX ? max_indices.data<int64_t>() : NULL);
+        mode == MODE_MAX ? max_indices.data<int64_t>() : NULL,
+        per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
+        per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
   });
 
   THCudaCheck(cudaGetLastError());
@@ -391,14 +405,16 @@ Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &ind
   checkSameGPU("embedding_bag_cuda", grad_arg, offsets_arg);
   checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg);
 
-  AT_ASSERT(!per_sample_weights.defined());
 
   switch (mode) {
     case MODE_SUM:
     case MODE_MEAN:
-      return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag, bag_size_, num_weights, scale_grad_by_freq, mode);
+      if (mode == MODE_MEAN)
+        AT_ASSERT(!per_sample_weights.defined());
+      return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag, bag_size_, num_weights, scale_grad_by_freq, mode, per_sample_weights);
 
     case MODE_MAX:
+      AT_ASSERT(!per_sample_weights.defined());
       return embedding_bag_backward_cuda_max(grad, max_indices, num_weights);
 
     default:
index 478f6e5..79391ab 100644 (file)
@@ -2589,8 +2589,12 @@ class TestNN(NNTestCase):
         input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
         offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
         per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)
-        with self.assertRaisesRegex(RuntimeError, 'have the same type as'):
-            es(input, offsets, per_sample_weights)
+        if device == 'cpu':
+            with self.assertRaisesRegex(RuntimeError, 'have the same type as'):
+                es(input, offsets, per_sample_weights)
+        else:
+            with self.assertRaisesRegex(RuntimeError, 'expected scalar type'):
+                es(input, offsets, per_sample_weights)
 
         # Failure 2.1: input/per_sample_weights have different sizes (1d input)
         input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
@@ -2620,6 +2624,10 @@ class TestNN(NNTestCase):
     def test_EmbeddingBag_per_sample_weights_failures(self):
         self._test_EmbeddingBag_per_sample_weights_failures(self)
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_EmbeddingBag_per_sample_weights_failures_cuda(self):
+        self._test_EmbeddingBag_per_sample_weights_failures(self, device='cuda')
+
     @staticmethod
     def _test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cpu'):
         def test_per_sample_weights(mode, dtype):
@@ -2649,6 +2657,10 @@ class TestNN(NNTestCase):
     def test_EmbeddingBag_per_sample_weights_and_offsets(self):
         self._test_EmbeddingBag_per_sample_weights_and_offsets(self)
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_EmbeddingBag_per_sample_weights_and_offsets_cuda(self):
+        self._test_EmbeddingBag_per_sample_weights_and_offsets(self, device='cuda')
+
     @staticmethod
     def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'):
         dtypes = (torch.float, torch.double)
@@ -2674,6 +2686,10 @@ class TestNN(NNTestCase):
         self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self)
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_EmbeddingBag_per_sample_weights_and_no_offsets_cuda(self):
+        self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cuda')
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @repeat_test_for_types(ALL_TENSORTYPES)
     def test_embedding_bag_cuda(self, dtype=torch.float):
         self._test_EmbeddingBag(True, 'sum', False, dtype)