From 7205ca02107059443bddd14322d2b9ed8562c60b Mon Sep 17 00:00:00 2001
From: Sameer Deshmukh
Date: Wed, 8 Sep 2021 08:40:01 -0700
Subject: [PATCH] Change MaxUnpool to accept tensors with 0-dim batch sizes.
 (#64082)

Summary:
Part of the fix for https://github.com/pytorch/pytorch/issues/38115.

Changes the `MaxUnpool` module to work with 0-dimension batch sizes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64082

Reviewed By: mrshenli

Differential Revision: D30793907

Pulled By: jbschlosser

fbshipit-source-id: d21aa665be5aa18f592b39ef7b4e3cbc632e21ed
---
 aten/src/ATen/native/MaxUnpooling.cpp     | 38 +++++++++++++-----
 aten/src/ATen/native/cuda/MaxUnpooling.cu | 67 ++++++++++++++++++++-----------
 test/test_nn.py                           | 34 ++++++++++++++++
 3 files changed, 107 insertions(+), 32 deletions(-)

diff --git a/aten/src/ATen/native/MaxUnpooling.cpp b/aten/src/ATen/native/MaxUnpooling.cpp
index 9987408..ec96601 100644
--- a/aten/src/ATen/native/MaxUnpooling.cpp
+++ b/aten/src/ATen/native/MaxUnpooling.cpp
@@ -25,7 +25,11 @@ Tensor& max_unpooling2d_forward_out_cpu(
       self_.sizes() == indices_.sizes(),
       "Shape of indices should match shape of input");
 
-  TORCH_CHECK(self_.numel() > 0, "Input must be non-empty");
+  for (int64_t i = 1; i < self_.ndimension(); ++i) {
+    TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cpu(): ",
+                "Expected input to have non-zero size for non-batch dimensions, but got ",
+                self_.sizes(), " with dimension ", i , " being empty.");
+  }
 
   auto memory_format = self_.suggest_memory_format();
   auto self = self_.contiguous(memory_format);
@@ -41,7 +45,10 @@ Tensor& max_unpooling2d_forward_out_cpu(
   }
   output.zero_();
 
-  max_unpool2d_kernel(kCPU, output, self, indices);
+  if (output.numel() != 0) {
+    max_unpool2d_kernel(kCPU, output, self, indices);
+  }
+
   return output;
 };
 
@@ -60,7 +67,8 @@ static void max_unpooling3d_shape_check(
     const Tensor& indices,
     IntArrayRef output_size,
     IntArrayRef stride,
-    IntArrayRef padding) {
+    IntArrayRef padding,
+    const char *fn_name) {
   int64_t oT = output_size[0];
   int64_t oH = output_size[1];
   int64_t oW = output_size[2];
@@ -84,7 +92,11 @@ static void max_unpooling3d_shape_check(
       input.sizes() == indices.sizes(),
       "Shape of indices should match shape of input");
 
-  TORCH_CHECK(input.numel() > 0, "Input must be non-empty");
+  for (int64_t i = 1; i < input.ndimension(); ++i) {
+    TORCH_CHECK(input.size(i) > 0, fn_name,
+                ": Expected input to have non-zero size for non-batch dimensions, but got ",
+                input.sizes(), " with dimension ", i , " being empty.");
+  }
 
   TORCH_CHECK(
       stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
@@ -144,7 +156,7 @@ Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_,
   auto indices = indices_.contiguous();
 
   max_unpooling3d_shape_check(
-      self_, Tensor(), indices_, output_size, stride, padding);
+      self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cpu()");
 
   if (self_.ndimension() == 5) {
     output.resize_({self.size(0), self.size(1), oT, oH, oW});
@@ -152,8 +164,10 @@ Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_,
     output.resize_({self.size(0), oT, oH, oW});
   }
   output.zero_();
+  if (output.numel() != 0) {
+    max_unpool3d_kernel(kCPU, output, self, indices);
+  }
 
-  max_unpool3d_kernel(kCPU, output, self, indices);
   return output;
 }
 
@@ -207,7 +221,10 @@ Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
         grad_output.size(dimw));
   }
 
-  max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
+  if (grad_input.numel() != 0) {
+    max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
+  }
+
   return grad_input;
 }
 
@@ -240,7 +257,7 @@ Tensor& max_unpooling3d_backward_out_cpu(
   int64_t dimw = ndim == 4 ? 3 : 4;
 
   max_unpooling3d_shape_check(
-      self, grad_output_, indices_, output_size, stride, padding);
+      self, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cpu()");
 
   /* get contiguous gradOutput */
   auto grad_output = grad_output_.contiguous();
@@ -266,7 +283,10 @@ Tensor& max_unpooling3d_backward_out_cpu(
         grad_output.size(dimw));
   }
 
-  max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
+  if (grad_input.numel() != 0) {
+    max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
+  }
+
   return grad_input;
 }
 
diff --git a/aten/src/ATen/native/cuda/MaxUnpooling.cu b/aten/src/ATen/native/cuda/MaxUnpooling.cu
index e67f8e7..7c6d746 100644
--- a/aten/src/ATen/native/cuda/MaxUnpooling.cu
+++ b/aten/src/ATen/native/cuda/MaxUnpooling.cu
@@ -114,7 +114,11 @@ Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
   checkAllSameGPU(
       "max_unpooling2d_forward_out_cuda", {output_arg, self_arg, indices_arg});
 
-  TORCH_CHECK(self_.numel() > 0, "Input must be non-empty tensor");
+  for (int64_t i = 1; i < self_.ndimension(); ++i) {
+    TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cuda(): ",
+                "Expected input to have non-zero size for non-batch dimensions, but got ",
+                self_.sizes(), " with dimension ", i , " being empty.");
+  }
 
   TORCH_CHECK(
       (self_.ndimension() == 3 || self_.ndimension() == 4),
@@ -152,24 +156,26 @@ Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
   output.zero_();
 
   auto count = self.numel();
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half,
-      self.scalar_type(), "max_unpooling2d_forward_kernel", ([&] {
-        max_unpooling2d_forward_kernel<<<
-            GET_BLOCKS(count),
-            CUDA_NUM_THREADS,
-            0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            self.numel(),
-            self.data_ptr<scalar_t>(),
-            indices.data_ptr<int64_t>(),
-            numChannels,
-            inputHeight,
-            inputWidth,
-            oheight,
-            owidth,
-            output.data_ptr<scalar_t>());
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      }));
+  if (count != 0) {
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half,
+        self.scalar_type(), "max_unpooling2d_forward_kernel", ([&] {
+          max_unpooling2d_forward_kernel<<<
+              GET_BLOCKS(count),
+              CUDA_NUM_THREADS,
+              0,
+              at::cuda::getCurrentCUDAStream()>>>(
+              self.numel(),
+              self.data_ptr<scalar_t>(),
+              indices.data_ptr<int64_t>(),
+              numChannels,
+              inputHeight,
+              inputWidth,
+              oheight,
+              owidth,
+              output.data_ptr<scalar_t>());
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }));
+  }
   if (self.ndimension() == 3) {
     output.resize_({numChannels, oheight, owidth});
   }
@@ -191,7 +197,8 @@ static void max_unpooling3d_shape_check(
     const Tensor& indices,
     IntArrayRef output_size,
     IntArrayRef stride,
-    IntArrayRef padding) {
+    IntArrayRef padding,
+    const char *fn_name) {
   int64_t oT = output_size[0];
   int64_t oH = output_size[1];
   int64_t oW = output_size[2];
@@ -215,7 +222,11 @@ static void max_unpooling3d_shape_check(
       input.sizes() == indices.sizes(),
       "Shape of indices should match shape of input");
 
-  TORCH_CHECK(input.numel() > 0, "Input must be non-empty");
+  for (int64_t i = 1; i < input.ndimension(); ++i) {
+    TORCH_CHECK(input.size(i) > 0, fn_name,
+                ": Expected input to have non-zero size for non-batch dimensions, but got ",
+                input.sizes(), " with dimension ", i , " being empty.");
+  }
 
   TORCH_CHECK(
       stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
@@ -268,7 +279,7 @@ Tensor& max_unpooling3d_forward_out_cuda(const Tensor& self_,
                                          Tensor& output) {
   TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
   max_unpooling3d_shape_check(
-      self_, Tensor(), indices_, output_size, stride, padding);
+      self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cuda()");
   int64_t oT = output_size[0];
   int64_t oH = output_size[1];
   int64_t oW = output_size[2];
@@ -318,6 +329,10 @@ Tensor& max_unpooling3d_forward_out_cuda(const Tensor& self_,
         indices.size(4)});
   }
 
+  if (self.numel() == 0) {
+    return output;
+  }
+
   int totalZ = inputTime * inputSlices * batchSize;
   int offsetZ = 0;
   dim3 block(32, 8);
@@ -426,6 +441,9 @@ at::Tensor& max_unpooling2d_backward_out_cuda(const Tensor& grad_output_,
   grad_input.zero_();
 
   int64_t count = self.numel();
+  if (count == 0) {
+    return grad_input;
+  }
 
   AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half,
       self.scalar_type(), "max_unpooling2d_backward_kernel", ([&] {
@@ -471,7 +489,7 @@ at::Tensor& max_unpooling3d_backward_out_cuda(const Tensor& grad_output_,
   int64_t oW = output_size[2];
 
   max_unpooling3d_shape_check(
-      self_, grad_output_, indices_, output_size, stride, padding);
+      self_, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cuda()");
 
   int batchSize = 0;
   int inputSlices = 0;
@@ -521,6 +539,9 @@ at::Tensor& max_unpooling3d_backward_out_cuda(const Tensor& grad_output_,
         indices.size(3),
         indices.size(4)});
   }
+  if (grad_input.numel() == 0) {
+    return grad_input;
+  }
   int totalZ = inputTime * inputSlices * batchSize;
   int offsetZ = 0;
 
diff --git a/test/test_nn.py b/test/test_nn.py
index 2d66477..cc702df 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13764,6 +13764,40 @@ class TestNNDeviceType(NNTestCase):
             mod(inp)
 
     @onlyOnCPUAndCUDA
+    def test_MaxUnpool_zero_batch_dim(self, device):
+        pool = torch.nn.MaxPool1d(2, stride=2, return_indices=True).to(device)
+        unpool = torch.nn.MaxUnpool1d(2, stride=2).to(device)
+        inp = torch.randn(0, 10, 10, requires_grad=True, device=device)
+        output, indices = pool(inp)
+        output.requires_grad_(True)
+        unpool_out = unpool(output, indices)
+        unpool_out.sum().backward()
+
+        self.assertEqual(inp.grad, torch.zeros_like(inp))
+        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))
+
+        pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True).to(device)
+        unpool = torch.nn.MaxUnpool2d(2, stride=2).to(device)
+        inp = torch.randn(0, 10, 10, 10, requires_grad=True, device=device)
+        output, indices = pool(inp)
+        unpool_out = unpool(output, indices)
+        unpool_out.sum().backward()
+
+        self.assertEqual(inp.grad, torch.zeros_like(inp))
+        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))
+
+        pool = torch.nn.MaxPool3d(2, stride=2, return_indices=True).to(device)
+        unpool = torch.nn.MaxUnpool3d(2, stride=2).to(device)
+        inp = torch.randn(0, 10, 10, 10, 10, requires_grad=True, device=device)
+        output, indices = pool(inp)
+        output.requires_grad_(True)
+        unpool_out = unpool(output, indices)
+        unpool_out.sum().backward()
+
+        self.assertEqual(inp.grad, torch.zeros_like(inp))
+        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))
+
+    @onlyOnCPUAndCUDA
     def test_AdaptiveMaxPool_zero_batch_dim(self, device):
        inp = torch.randn(0, 16, 50, device=device)
        mod = torch.nn.AdaptiveMaxPool1d(3).to(device)
-- 
2.7.4
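
For reference, the behavior this patch enables can be exercised from Python roughly as follows. This is a minimal sketch, not part of the patch: it assumes a PyTorch build that already includes this change, and the shapes mirror the new test_MaxUnpool_zero_batch_dim test.

import torch

# Sketch: assumes a PyTorch build containing this patch (MaxUnpool with 0-dim batch).
pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True)
unpool = torch.nn.MaxUnpool2d(2, stride=2)

# A batch of size 0: shape (0, C, H, W). Previously this raised
# "Input must be non-empty"; now it produces an empty output.
inp = torch.randn(0, 10, 10, 10, requires_grad=True)
out, indices = pool(inp)
unpool_out = unpool(out, indices)   # shape (0, 10, 10, 10); the underlying kernel is skipped
unpool_out.sum().backward()         # backward also works; gradients are empty/zero

print(unpool_out.shape)                               # torch.Size([0, 10, 10, 10])
print(torch.equal(inp.grad, torch.zeros_like(inp)))   # True

# Non-batch dimensions must still be non-zero; the new shape check
# rejects them and names the offending dimension.
try:
    bad = torch.randn(2, 0, 10, 10)
    bad_indices = torch.zeros(2, 0, 10, 10, dtype=torch.int64)
    unpool(bad, bad_indices)
except RuntimeError as err:
    print(err)  # "... Expected input to have non-zero size for non-batch dimensions ..."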