From 3b29cbaf861d5daaff7c7506b62a252b9c7d5614 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 16 Apr 2019 08:51:01 -0700 Subject: [PATCH] Enable half for CUDA dense EmbeddingBag backward. (#19293) Summary: I audited the relevant kernel and saw it accumulates a good deal into float so it should be fine. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19293 Differential Revision: D14942274 Pulled By: zou3519 fbshipit-source-id: 36996ba0fbb29fbfb12b27bfe9c0ad1eb012ba3c --- aten/src/ATen/native/cuda/EmbeddingBag.cu | 2 +- test/test_nn.py | 41 ++++++++++++++++++++++--------- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 677d04f..8e0f6e8 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -491,7 +491,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda( dim3 grid((num_samples + warps_per_block - 1) / warps_per_block); auto output = at::empty({num_samples}, grad.options()); - AT_DISPATCH_FLOATING_TYPES( + AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() { _embedding_bag_per_sample_weights_backward_kernel <<>>( diff --git a/test/test_nn.py b/test/test_nn.py index c093502..d489e50 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2381,7 +2381,8 @@ class TestNN(NNTestCase): self.assertEqual(es_weight_grad, e.weight.grad, needed_prec) if test_per_sample_weights and trainable_per_sample_weights: - self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad) + self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad, + dtype2prec[dtype]) def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double): # check a known test example @@ -2653,16 +2654,21 @@ class TestNN(NNTestCase): expected = self._embedding_bag_reference_impl( input, reference_weights, offsets, mode, ref_per_sample_weights) result = es(input, offsets, per_sample_weights) - self.assertEqual(result, expected) + self.assertEqual(result, expected, prec=dtype2prec[dtype]) grad = torch.randn_like(expected) result.backward(grad) expected.backward(grad) - self.assertEqual(es.weight.grad, reference_weights.grad) + self.assertEqual(es.weight.grad, reference_weights.grad, + dtype2prec[dtype]) if trainable_scale: - self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad) + self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad, + prec=dtype2prec[dtype]) - dtypes = (torch.float, torch.double) + if device == 'cuda': + dtypes = (torch.float, torch.double, torch.half) + else: + dtypes = (torch.float, torch.double) modes = ('sum',) trainable_scale = (True, False) for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale): @@ -2677,12 +2683,7 @@ class TestNN(NNTestCase): @staticmethod def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'): - dtypes = (torch.float, torch.double) - modes = ('sum',) - sparsity = (True, False) - trainable_scale = (True, False) - for dtype, mode, sparse, trainable_per_sample_weights in \ - itertools.product(dtypes, modes, sparsity, trainable_scale): + def run_tests(dtype, mode, sparse, trainable_per_sample_weights): kwargs = dict(test_per_sample_weights=True, device=device, mode=mode, dtype=dtype, sparse=sparse, trainable_per_sample_weights=trainable_per_sample_weights) @@ -2699,6 +2700,24 @@ class TestNN(NNTestCase): # Large embedding_dim self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs) + dtypes = (torch.float, torch.double) + modes = ('sum',) + sparsity = (True, False) + trainable_scale = (True, False) + for dtype, mode, sparse, trainable_per_sample_weights in \ + itertools.product(dtypes, modes, sparsity, trainable_scale): + run_tests(dtype, mode, sparse, trainable_per_sample_weights) + + # Test CUDA Dense on half precision + if device == 'cuda': + dtypes = (torch.half,) + modes = ('sum',) + sparsity = (False,) + trainable_scale = (True, False) + for dtype, mode, sparse, trainable_per_sample_weights in \ + itertools.product(dtypes, modes, sparsity, trainable_scale): + run_tests(dtype, mode, sparse, trainable_per_sample_weights) + def test_EmbeddingBag_per_sample_weights_and_no_offsets(self): self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self) -- 2.7.4