From 6bcff88d3ea6af9c1b9ed99157a40450ba04088c Mon Sep 17 00:00:00 2001
From: bhushan
Date: Sun, 10 Mar 2019 15:18:48 -0700
Subject: [PATCH] Fix log_softmax and softmax if any dimension is 0-d (#17651)

Summary:
- Test added
  - test_dim_function_empty: softmax and log_softmax on last dimension

fixes: #17262
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17651

Differential Revision: D14349009

Pulled By: gchanan

fbshipit-source-id: b6f728f5c6be8ae7615749e3f0c201886632923e
---
 aten/src/ATen/native/SoftMax.cpp | 19 +++++++++++++++----
 test/test_torch.py               |  2 ++
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp
index 64d259f..60ba37e 100644
--- a/aten/src/ATen/native/SoftMax.cpp
+++ b/aten/src/ATen/native/SoftMax.cpp
@@ -14,9 +14,6 @@ template
 void host_softmax(Tensor output, const Tensor& input, const int64_t dim) {
   int64_t outer_size = 1;
   int64_t dim_size = input.size(dim);
-  if (input.numel() == 0) {
-    return;
-  }
   int64_t inner_size = 1;
   for (int64_t i = 0; i < dim; ++i)
     outer_size *= input.size(i);
@@ -124,7 +121,11 @@ Tensor softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half_to_
   auto input = input_.contiguous();
   Tensor output = at::native::empty_like(input);
   int64_t dim = maybe_wrap_dim(dim_, input.dim());
-  if (input.dim() == 0)
+
+  if (input.numel() == 0) {
+    return output;
+  }
+  if (input.dim() == 0)
     input = input.view(1);
   AT_CHECK(
       dim >= 0 && dim < input.dim(),
@@ -144,6 +145,10 @@ Tensor log_softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half
   auto input = input_.contiguous();
   Tensor output = at::native::empty_like(input);
   int64_t dim = maybe_wrap_dim(dim_, input.dim());
+
+  if (input.numel() == 0) {
+    return output;
+  }
   if (input.dim() == 0)
     input = input.view(1);
   AT_CHECK(
@@ -171,6 +176,9 @@ Tensor softmax_backward_cpu(
   auto output = output_.contiguous();
   Tensor grad_input = at::native::empty_like(grad);

+  if (output.numel() == 0) {
+    return grad_input;
+  }
   if (grad.dim() == 0)
     grad = grad.view(1);
   if (output.dim() == 0)
@@ -200,6 +208,9 @@ Tensor log_softmax_backward_cpu(
   auto output = output_.contiguous();
   Tensor grad_input = at::native::empty_like(grad);

+  if (output.numel() == 0) {
+    return grad_input;
+  }
   if (grad.dim() == 0)
     grad = grad.view(1);
   if (output.dim() == 0)
diff --git a/test/test_torch.py b/test/test_torch.py
index f39b798..4960b23 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -7491,9 +7491,11 @@ class _TestTorchMixin(object):
         # softmax, logsoftmax
         self.assertEqual(x, torch.nn.functional.softmax(x, 0))
         self.assertEqual(x, torch.nn.functional.softmax(x, 2))
+        self.assertEqual(x, torch.nn.functional.softmax(x, 3))

         self.assertEqual(x, torch.nn.functional.log_softmax(x, 0))
         self.assertEqual(x, torch.nn.functional.log_softmax(x, 2))
+        self.assertEqual(x, torch.nn.functional.log_softmax(x, 3))

         # cumsum, cumprod
         self.assertEqual(shape, torch.cumsum(x, 0).shape)
--
2.7.4
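
A quick way to exercise the fixed behavior from Python, as a minimal sketch (not part of the patch): with a tensor that has a zero-sized dimension, softmax and log_softmax along any dimension, including the last one, return an empty tensor of the same shape instead of failing. The tensor shape below is an assumption chosen to mirror the intent of test_dim_function_empty, not copied from the test file.

```python
import torch
import torch.nn.functional as F

# Hypothetical repro: a 4-d tensor whose first and last dimensions are empty
# (the exact shape is an assumption).
x = torch.randn(0, 1, 3, 0)

# With the early return for numel() == 0, softmax/log_softmax over the empty
# last dimension produce an empty tensor of the same shape.
print(F.softmax(x, 3).shape)      # torch.Size([0, 1, 3, 0])
print(F.log_softmax(x, 3).shape)  # torch.Size([0, 1, 3, 0])
```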