Enable half for CUDA dense EmbeddingBag backward. (#19293)
author Richard Zou <zou3519@gmail.com>
Tue, 16 Apr 2019 15:51:01 +0000 (08:51 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 16 Apr 2019 15:57:20 +0000 (08:57 -0700)
Summary:
I audited the relevant kernel and saw that it does its accumulation in float,
so enabling half precision should be numerically fine.
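
As a rough illustration of the path this enables (a minimal sketch, not part of the
change; it assumes a CUDA device is available and the shapes are arbitrary): dense
EmbeddingBag backward in half precision with trainable per_sample_weights now
dispatches instead of erroring out.

    import torch
    import torch.nn.functional as F

    # Half-precision embedding weights and per-sample weights on CUDA (dense, mode='sum').
    weight = torch.randn(10, 3, device='cuda', dtype=torch.half, requires_grad=True)
    indices = torch.randint(0, 10, (4, 5), device='cuda')  # 2D input, no offsets
    per_sample_weights = torch.randn(4, 5, device='cuda', dtype=torch.half,
                                     requires_grad=True)

    out = F.embedding_bag(indices, weight, mode='sum',
                          per_sample_weights=per_sample_weights)
    out.sum().backward()
    # The per_sample_weights gradient kernel was previously dispatched only for
    # float/double; with AT_DISPATCH_FLOATING_TYPES_AND_HALF it also covers torch.half.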
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19293

Differential Revision: D14942274

Pulled By: zou3519

fbshipit-source-id: 36996ba0fbb29fbfb12b27bfe9c0ad1eb012ba3c

aten/src/ATen/native/cuda/EmbeddingBag.cu
test/test_nn.py

diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index 677d04f..8e0f6e8 100644
@@ -491,7 +491,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda(
   dim3 grid((num_samples + warps_per_block - 1) / warps_per_block);
 
   auto output = at::empty({num_samples}, grad.options());
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
       _embedding_bag_per_sample_weights_backward_kernel<scalar_t>
         <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
diff --git a/test/test_nn.py b/test/test_nn.py
index c093502..d489e50 100644
@@ -2381,7 +2381,8 @@ class TestNN(NNTestCase):
         self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
 
         if test_per_sample_weights and trainable_per_sample_weights:
-            self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad)
+            self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad,
+                             dtype2prec[dtype])
 
     def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
         # check a known test example
@@ -2653,16 +2654,21 @@ class TestNN(NNTestCase):
             expected = self._embedding_bag_reference_impl(
                 input, reference_weights, offsets, mode, ref_per_sample_weights)
             result = es(input, offsets, per_sample_weights)
-            self.assertEqual(result, expected)
+            self.assertEqual(result, expected, prec=dtype2prec[dtype])
 
             grad = torch.randn_like(expected)
             result.backward(grad)
             expected.backward(grad)
-            self.assertEqual(es.weight.grad, reference_weights.grad)
+            self.assertEqual(es.weight.grad, reference_weights.grad,
+                             dtype2prec[dtype])
             if trainable_scale:
-                self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad)
+                self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
+                                 prec=dtype2prec[dtype])
 
-        dtypes = (torch.float, torch.double)
+        if device == 'cuda':
+            dtypes = (torch.float, torch.double, torch.half)
+        else:
+            dtypes = (torch.float, torch.double)
         modes = ('sum',)
         trainable_scale = (True, False)
         for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale):
@@ -2677,12 +2683,7 @@ class TestNN(NNTestCase):
 
     @staticmethod
     def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'):
-        dtypes = (torch.float, torch.double)
-        modes = ('sum',)
-        sparsity = (True, False)
-        trainable_scale = (True, False)
-        for dtype, mode, sparse, trainable_per_sample_weights in \
-                itertools.product(dtypes, modes, sparsity, trainable_scale):
+        def run_tests(dtype, mode, sparse, trainable_per_sample_weights):
             kwargs = dict(test_per_sample_weights=True, device=device,
                           mode=mode, dtype=dtype, sparse=sparse,
                           trainable_per_sample_weights=trainable_per_sample_weights)
@@ -2699,6 +2700,24 @@ class TestNN(NNTestCase):
             # Large embedding_dim
             self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)
 
+        dtypes = (torch.float, torch.double)
+        modes = ('sum',)
+        sparsity = (True, False)
+        trainable_scale = (True, False)
+        for dtype, mode, sparse, trainable_per_sample_weights in \
+                itertools.product(dtypes, modes, sparsity, trainable_scale):
+            run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+
+        # Test CUDA Dense on half precision
+        if device == 'cuda':
+            dtypes = (torch.half,)
+            modes = ('sum',)
+            sparsity = (False,)
+            trainable_scale = (True, False)
+            for dtype, mode, sparse, trainable_per_sample_weights in \
+                    itertools.product(dtypes, modes, sparsity, trainable_scale):
+                run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+
     def test_EmbeddingBag_per_sample_weights_and_no_offsets(self):
         self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self)