Fix cuda softmax backward with empty input (#17259)
authorTongzhou Wang <tongzhou.wang.1994@gmail.com>
Wed, 20 Feb 2019 00:33:16 +0000 (16:33 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 20 Feb 2019 00:41:52 +0000 (16:41 -0800)
Summary:
Fixes #17256
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17259

Differential Revision: D14142196

Pulled By: soumith

fbshipit-source-id: 1f2dc202951b59b43da27684f9f924314bcd3040

aten/src/ATen/native/cuda/SoftMax.cu
test/test_nn.py

index 649fcb9..b08cad5 100644 (file)
@@ -550,8 +550,11 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
 template<template<typename, typename, typename> class Epilogue>
 Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t dim_, bool half_to_float){
   int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
+  Tensor gI = half_to_float ? at::empty_like(grad_, grad_.options().dtype(ScalarType::Half)) : at::empty_like(grad_);
+  if (grad_.numel() == 0) {
+    return gI;
+  }
   auto grad = grad_.contiguous();
-  Tensor gI = half_to_float ? at::empty_like(grad, grad.options().dtype(ScalarType::Half)) : at::empty_like(grad);
   static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float");
   if (grad.dim() == 0) grad = grad.view(1);
   AT_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions");
index 34c0060..9d61732 100644 (file)
@@ -2140,6 +2140,30 @@ class TestNN(NNTestCase):
         # should be bitwise equal
         self.assertEqual(input.grad, inputf.grad.to(dtype), prec=0)
 
+    def _test_softmax_backward(self, device):
+        if device.type == 'cuda':
+            dtypes = [torch.float]
+            # FIXME: add torch.half after https://github.com/pytorch/pytorch/issues/17261 is fixed
+        else:
+            dtypes = [torch.float]
+        # FIXME: add (10, 0) after https://github.com/pytorch/pytorch/issues/17262 is fixed
+        sizes = [(0, 10), (32, 20)]
+        for fn in [F.softmax, F.log_softmax]:
+            for dtype in dtypes:
+                for size in sizes:
+                    input = torch.rand(size, device=device, dtype=dtype, requires_grad=True)
+                    output = fn(input, dtype=torch.float, dim=1).sum()
+                    grad_input, = torch.autograd.grad(output, input, create_graph=True)
+                    grad_input.sum().backward()
+
+    def test_softmax_backward(self):
+        self._test_softmax_backward(torch.device('cpu'))
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @skipIfRocm
+    def test_softmax_backward_cuda(self):
+        self._test_softmax_backward(torch.device('cuda'))
+
     def _test_gumbel_softmax_st_shapes(self, cuda, dtype, shape, dim, count_expected):
         logits = torch.randn(shape, dtype=torch.float)
         logits = logits.to(dtype)