Fix log_softmax and softmax if any dimension is 0-d (#17651)
authorbhushan <bhushan.s.94@gmail.com>
Sun, 10 Mar 2019 22:18:48 +0000 (15:18 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 10 Mar 2019 22:25:58 +0000 (15:25 -0700)
Summary:
- Test added
- test_dim_function_empty: softmax and log_softmax on last dimension

fixes: #17262
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17651

Differential Revision: D14349009

Pulled By: gchanan

fbshipit-source-id: b6f728f5c6be8ae7615749e3f0c201886632923e

aten/src/ATen/native/SoftMax.cpp
test/test_torch.py

index 64d259f..60ba37e 100644 (file)
@@ -14,9 +14,6 @@ template <typename scalar_t, bool LogSoftMax>
 void host_softmax(Tensor output, const Tensor& input, const int64_t dim) {
   int64_t outer_size = 1;
   int64_t dim_size = input.size(dim);
-  if (input.numel() == 0) {
-    return;
-  }
   int64_t inner_size = 1;
   for (int64_t i = 0; i < dim; ++i)
     outer_size *= input.size(i);
@@ -124,7 +121,11 @@ Tensor softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half_to_
   auto input = input_.contiguous();
   Tensor output = at::native::empty_like(input);
   int64_t dim = maybe_wrap_dim(dim_, input.dim());
-  if (input.dim() == 0)
+
+  if (input.numel() == 0) {
+    return output;
+  }
+ if (input.dim() == 0)
     input = input.view(1);
   AT_CHECK(
       dim >= 0 && dim < input.dim(),
@@ -144,6 +145,10 @@ Tensor log_softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half
   auto input = input_.contiguous();
   Tensor output = at::native::empty_like(input);
   int64_t dim = maybe_wrap_dim(dim_, input.dim());
+
+  if (input.numel() == 0) {
+    return output;
+  }
   if (input.dim() == 0)
     input = input.view(1);
   AT_CHECK(
@@ -171,6 +176,9 @@ Tensor softmax_backward_cpu(
   auto output = output_.contiguous();
   Tensor grad_input = at::native::empty_like(grad);
 
+  if (output.numel() == 0) {
+    return grad_input;
+  }
   if (grad.dim() == 0)
     grad = grad.view(1);
   if (output.dim() == 0)
@@ -200,6 +208,9 @@ Tensor log_softmax_backward_cpu(
   auto output = output_.contiguous();
   Tensor grad_input = at::native::empty_like(grad);
 
+  if (output.numel() == 0) {
+    return grad_input;
+  }
   if (grad.dim() == 0)
     grad = grad.view(1);
   if (output.dim() == 0)
index f39b798..4960b23 100644 (file)
@@ -7491,9 +7491,11 @@ class _TestTorchMixin(object):
             # softmax, logsoftmax
             self.assertEqual(x, torch.nn.functional.softmax(x, 0))
             self.assertEqual(x, torch.nn.functional.softmax(x, 2))
+            self.assertEqual(x, torch.nn.functional.softmax(x, 3))
 
             self.assertEqual(x, torch.nn.functional.log_softmax(x, 0))
             self.assertEqual(x, torch.nn.functional.log_softmax(x, 2))
+            self.assertEqual(x, torch.nn.functional.log_softmax(x, 3))
 
             # cumsum, cumprod
             self.assertEqual(shape, torch.cumsum(x, 0).shape)