From cb23976f9f304a6db62b612c83aae371a077031f Mon Sep 17 00:00:00 2001
From: Sameer Deshmukh
Date: Fri, 13 Aug 2021 07:31:42 -0700
Subject: [PATCH] Allow 0-dim batch sizes for AdaptiveMaxPool and MaxPool.
 (#62088)

Summary:
This fixes part of https://github.com/pytorch/pytorch/issues/12013, which is summarized concretely in https://github.com/pytorch/pytorch/issues/38115.

This PR allows `MaxPool` and `AdaptiveMaxPool` to accept tensors whose batch size is 0. Some changes have been made to modernize the tests so that they show the name of the C++ function that throws an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62088

Reviewed By: bdhirsh

Differential Revision: D30281285

Pulled By: jbschlosser

fbshipit-source-id: 52bffc67bfe45a78e11e4706b62cce1469eba1b9
---
 aten/src/ATen/native/AdaptiveMaxPooling2d.cpp     | 22 ++++-----
 aten/src/ATen/native/AdaptiveMaxPooling3d.cpp     | 18 ++++----
 aten/src/ATen/native/AveragePool3d.cpp            |  4 +-
 aten/src/ATen/native/DilatedMaxPool3d.cpp         |  6 ++-
 aten/src/ATen/native/MaxPooling.cpp               |  3 +-
 aten/src/ATen/native/Pool.h                       | 21 ++++++---
 aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu |  7 +++
 aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu |  6 +++
 aten/src/ATen/native/cuda/DilatedMaxPool2d.cu     |  6 +++
 aten/src/ATen/native/cuda/DilatedMaxPool3d.cu     | 23 +++++++---
 test/test_nn.py                                   | 55 ++++++++++++++++++++++-
 11 files changed, 131 insertions(+), 40 deletions(-)

diff --git a/aten/src/ATen/native/AdaptiveMaxPooling2d.cpp b/aten/src/ATen/native/AdaptiveMaxPooling2d.cpp
index a25b88f..bc9bc60 100644
--- a/aten/src/ATen/native/AdaptiveMaxPooling2d.cpp
+++ b/aten/src/ATen/native/AdaptiveMaxPooling2d.cpp
@@ -6,18 +6,19 @@ namespace at {
 namespace meta {
 TORCH_META_FUNC(adaptive_max_pool2d) (const Tensor& input, IntArrayRef output_size) {
-  for (int64_t i = 0; i < input.ndimension(); i++) {
+  int ndim = input.ndimension();
+  TORCH_CHECK(ndim == 3 || ndim == 4,
+              "adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: ",
+              input.sizes());
+  for (int64_t i = 1; i < ndim; i++) {
     TORCH_CHECK(input.size(i) > 0,
-        "adaptive_max_pool2d: expected input to have non-empty spatial dimensions, "
+        "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
         "but input has sizes ", input.sizes(), " with dimension ", i, " being empty");
   }
 
-  TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
-      "non-empty 3D or 4D (batch mode) tensor expected for input");
-
   TORCH_CHECK(output_size.size() == 2,
-      "adaptive_max_pool2d: internal error: output_size.size() must be 2");
+      "adaptive_max_pool2d(): internal error: output_size.size() must be 2");
 
   int dimH = 1;
   int64_t sizeB = 1;
@@ -48,16 +49,15 @@ TORCH_META_FUNC(adaptive_max_pool2d) (const Tensor& input, IntArrayRef output_si
 TORCH_META_FUNC(adaptive_max_pool2d_backward)
 (const Tensor& grad_output, const Tensor& input, const Tensor& indices) {
   int64_t ndim = grad_output.ndimension();
-  for (int64_t i = 0; i < ndim; i++) {
+  TORCH_CHECK(ndim == 3 || ndim == 4,
+              "adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: ", grad_output.sizes());
+  for (int64_t i = 1; i < ndim; i++) {
     TORCH_CHECK(grad_output.size(i) > 0,
-        "adaptive_max_pooling2d_backward(): expected grad_output to have non-empty spatial dimensions, "
+        "adaptive_max_pooling2d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, "
         "but grad_output has sizes ", grad_output.sizes(), " with dimension ", i, " being empty");
   }
 
-  TORCH_CHECK((ndim == 3 || ndim == 4),
"non-empty 3D or 4D (batch mode) tensor expected for grad_output"); - TORCH_CHECK(input.dtype() == grad_output.dtype(), "expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype()); diff --git a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp index a59d46a..257670f 100644 --- a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp @@ -7,10 +7,14 @@ namespace at { namespace meta { TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_size) { - for (int64_t i = 0; i < input.ndimension(); i++) { + auto ndim = input.ndimension(); + TORCH_CHECK( + ndim == 4 || ndim == 5, + "adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: ", input.sizes()); + for (int64_t i = 1; i < ndim; i++) { TORCH_CHECK( input.size(i) > 0, - "adaptive_max_pool3d: expected input to have non-empty spatial dimensions, " + "adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, " "but input has sizes ", input.sizes(), " with dimension ", @@ -20,18 +24,14 @@ TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_si } TORCH_CHECK( - (input.ndimension() == 4 || input.ndimension() == 5), - "non-empty 4D or 5D (batch mode) tensor expected for input"); - - TORCH_CHECK( output_size.size() == 3, - "adaptive_max_pool3d: internal error: output_size.size() must be 3"); + "adaptive_max_pool3d(): internal error: output_size.size() must be 3"); int dimD = 0; int64_t sizeB = 1; int64_t sizeD = 0; - if (input.ndimension() == 5) { + if (ndim == 5) { sizeB = input.size(0); dimD++; } @@ -44,7 +44,7 @@ TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_si int64_t osizeW = output_size[2]; /* resize output */ - if (input.ndimension() == 4) { + if (ndim == 4) { set_output(0, {sizeD, osizeT, osizeH, osizeW}, input.options()); /* indices will contain max input locations for each output point */ set_output(1, {sizeD, osizeT, osizeH, osizeW}, input.options().dtype(kLong)); diff --git a/aten/src/ATen/native/AveragePool3d.cpp b/aten/src/ATen/native/AveragePool3d.cpp index 674fffe..658936f 100644 --- a/aten/src/ATen/native/AveragePool3d.cpp +++ b/aten/src/ATen/native/AveragePool3d.cpp @@ -66,6 +66,7 @@ TORCH_META_FUNC(avg_pool3d) ( 1, 1, 1, itime, iheight, iwidth, otime, oheight, owidth, + "avg_pool3d()", /*check_input_size=*/ true); /* resize output */ @@ -131,7 +132,8 @@ TORCH_META_FUNC(avg_pool3d_backward) ( dT, dH, dW, padT, padH, padW, itime, iheight, iwidth, - otime_for_shape_check, oheight_for_shape_check, owidth_for_shape_check); + otime_for_shape_check, oheight_for_shape_check, owidth_for_shape_check, + "avg_pool3d_backward()"); /* resize output */ set_output(0, input.sizes(), input.options()); diff --git a/aten/src/ATen/native/DilatedMaxPool3d.cpp b/aten/src/ATen/native/DilatedMaxPool3d.cpp index 155f7e8..21398c0 100644 --- a/aten/src/ATen/native/DilatedMaxPool3d.cpp +++ b/aten/src/ATen/native/DilatedMaxPool3d.cpp @@ -195,7 +195,8 @@ void max_pool3d_with_indices_out_cpu_template( pT, pH, pW, dilationT, dilationH, dilationW, itime, iheight, iwidth, - otime, oheight, owidth); + otime, oheight, owidth, + "max_pool3d_with_indices_out_cpu_template()"); /* get contiguous input */ Tensor input = input_.contiguous(); @@ -413,7 +414,8 @@ Tensor& max_pool3d_with_indices_backward_out_cpu_template( pT, pH, pW, dilationT, dilationH, dilationW, itime, iheight, iwidth, - otime, oheight, owidth); + otime, oheight, owidth, + 
"max_pool3d_with_indices_backward_out_cpu_template()"); /* backprop */ if (input.ndimension() == 4) /* non-batch mode*/ diff --git a/aten/src/ATen/native/MaxPooling.cpp b/aten/src/ATen/native/MaxPooling.cpp index 682af63..53f2a10 100644 --- a/aten/src/ATen/native/MaxPooling.cpp +++ b/aten/src/ATen/native/MaxPooling.cpp @@ -23,8 +23,7 @@ Tensor max_pool1d_impl( TORCH_CHECK( self.dim() == 2 || self.dim() == 3, - "max_pool1d() input tensor must have 2 or 3 dimensions but got ", - self.dim()); + "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sizes()); TORCH_CHECK( kernel_size.size() == 1, "max_pool1d() kernel_size must be an int or int list of size 1 but got size ", diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index 2aceb4d..5fe979d 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -191,6 +191,7 @@ pool3d_shape_check( int dilationT, int dilationH, int dilationW, int64_t itime, int64_t iheight, int64_t iwidth, int64_t otime, int64_t oheight, int64_t owidth, + const char *fn_name, bool check_input_size=false) { const int64_t ndim = input.ndimension(); @@ -205,8 +206,14 @@ pool3d_shape_check( "dilation should be greater than zero, but got ", "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW); - TORCH_CHECK(input.numel() > 0 && (ndim == 4 || ndim == 5), - "non-empty 4D or 5D (batch mode) tensor expected for input, but got ndim: ", ndim); + TORCH_CHECK(ndim == 4 || ndim == 5, + fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes()); + + for (int64_t i = 1; i < ndim; ++i) { + TORCH_CHECK(input.size(i) > 0, + fn_name, "Expected input to have non-zero size for non-batch dimensions, but got", + input.sizes(), " with dimension ", i, " being empty."); + } if (check_input_size) { // AveragePool3d TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW, @@ -237,7 +244,8 @@ max_pool3d_backward_shape_check( int pT, int pH, int pW, int dilationT, int dilationH, int dilationW, int64_t itime, int64_t iheight, int64_t iwidth, - int64_t otime, int64_t oheight, int64_t owidth) + int64_t otime, int64_t oheight, int64_t owidth, + const char* fn_name) { const int64_t ndim = input.ndimension(); @@ -249,7 +257,7 @@ max_pool3d_backward_shape_check( pT, pH, pW, dilationT, dilationH, dilationW, itime, iheight, iwidth, - otime, oheight, owidth); + otime, oheight, owidth, fn_name); check_dim_size(gradOutput, ndim, ndim-4, nslices); check_dim_size(gradOutput, ndim, ndim-3, otime); @@ -271,7 +279,8 @@ avg_pool3d_backward_shape_check( int dT, int dH, int dW, int pT, int pH, int pW, int64_t itime, int64_t iheight, int64_t iwidth, - int64_t otime, int64_t oheight, int64_t owidth) + int64_t otime, int64_t oheight, int64_t owidth, + const char *fn_name) { const int64_t ndim = input.ndimension(); @@ -284,7 +293,7 @@ avg_pool3d_backward_shape_check( 1, 1, 1, itime, iheight, iwidth, otime, oheight, owidth, - true); + fn_name, true); check_dim_size(gradOutput, ndim, ndim-4, nslices); check_dim_size(gradOutput, ndim, ndim-3, otime); diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu index baba5a0..57c9244 100644 --- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu +++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu @@ -204,6 +204,9 @@ const Tensor& indices) { checkAllSameGPU( __func__, {output_arg, indices_arg, input_arg}); + if (input.numel() == 0) { + return; + } int64_t osizeH = output_size[0]; int64_t osizeW = output_size[1]; @@ -312,6 +315,10 
@@ -312,6 +315,10 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
       __func__, {grad_input_arg, grad_output_arg, input_arg, indices_arg});
 
+  if (gradOutput.numel() == 0) {
+    return;
+  }
+
   bool atomic = true; // suboptimal, but without atomic it doesn't pass the tests
diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
index 591c339..af09268 100644
--- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
+++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
@@ -305,6 +305,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cuda)
   checkAllSameGPU(
       __func__, {output_arg, indices_arg, input_arg});
 
+  if (input.numel() == 0) {
+    return;
+  }
 
   int64_t osizeT = output_size[0];
   int64_t osizeH = output_size[1];
@@ -380,6 +383,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cuda)
   checkAllSameGPU(
       __func__, {grad_input_arg, grad_output_arg, input_arg, indices_arg});
 
+  if (gradOutput.numel() == 0) {
+    return;
+  }
 
   const Tensor gradOutput_ = gradOutput.contiguous();
diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu
index cb77158..4451a78 100644
--- a/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu
+++ b/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu
@@ -301,6 +301,9 @@ const Tensor& indices) {
   TensorArg input_arg{ input_, "input_", 3 };
   checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg});
 
+  if (output.numel() == 0) {
+    return;
+  }
 
   const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
   const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
@@ -424,6 +427,9 @@ const Tensor& gradInput) {
   checkAllSameGPU(__func__,
                   {gradInput_arg, gradOutput_arg, input_arg, indices_arg});
 
+  if (gradOutput_.numel() == 0) {
+    return;
+  }
 
   const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
   const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
index a0e5ecb..36504f6 100644
--- a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
+++ b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
@@ -229,9 +229,6 @@ void max_pool3d_with_indices_out_cuda_template(
   const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
   const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);
 
-  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
-    "non-empty 4D or 5D (batch mode) tensor expected for input");
-
   const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
   const int64_t nslices = input.size(-4);
   const int64_t itime = input.size(-3);
@@ -250,7 +247,8 @@ void max_pool3d_with_indices_out_cuda_template(
     pT, pH, pW,
     dilationT, dilationH, dilationW,
     itime, iheight, iwidth,
-    otime, oheight, owidth);
+    otime, oheight, owidth,
+    "max_pool3d_with_indices_out_cuda_template()");
 
   if (input.ndimension() == 4) {
     output.resize_({ nslices, otime, oheight, owidth});
@@ -261,6 +259,10 @@ void max_pool3d_with_indices_out_cuda_template(
     indices.resize_({nbatch, nslices, otime, oheight, owidth});
   }
 
+  if (input.numel() == 0) {
+    return;
+  }
+
   Tensor work_input = input.contiguous();
   Tensor work_output = output;
   Tensor work_indices = indices;
@@ -338,10 +340,12 @@ void max_pool3d_with_indices_backward_out_cuda_template(
   const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);
 
   TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
-    "non-empty 4D or 5D (batch mode) tensor expected for input");
+    "max_pool3d_with_indices_backward_out_cuda_template(): ",
+    "Expected 4D or 5D input tensor, but got ", input.sizes());
 
   TORCH_CHECK((gradOutput.ndimension() == 4 || gradOutput.ndimension() == 5),
-    "non-empty 4D or 5D (batch mode) tensor expected for gradOutput");
+    "max_pool3d_with_indices_backward_out_cuda_template(): ",
+    "Expected 4D or 5D gradOutput tensor, but got ", gradOutput.sizes());
 
   // Resize and initialize result tensor.
   gradInput.resize_as_(input);
@@ -368,7 +372,12 @@ void max_pool3d_with_indices_backward_out_cuda_template(
     pT, pH, pW,
     dilationT, dilationH, dilationW,
     itime, iheight, iwidth,
-    otime, oheight, owidth);
+    otime, oheight, owidth,
+    "max_pool3d_with_indices_backward_out_cuda_template()");
+
+  if (gradOutput.numel() == 0) {
+    return;
+  }
 
   Tensor work_grad_input = gradInput;
   Tensor work_grad_output = gradOutput.contiguous();
diff --git a/test/test_nn.py b/test/test_nn.py
index 75c4d43..5784779 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13355,6 +13355,57 @@ class TestNNDeviceType(NNTestCase):
         unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
         unfold(inp)
 
+    @onlyOnCPUAndCUDA
+    def test_MaxPool_zero_batch_dim(self, device):
+        inp = torch.randn(0, 16, 50, device=device)
+        mod = torch.nn.MaxPool1d(3, stride=2).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        # 1D is supposed to be okay with 0 numel() inputs so don't test
+        # error raising for that case.
+
+        inp = torch.randn(0, 16, 50, 32, device=device)
+        mod = torch.nn.MaxPool2d(3, stride=2).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.randn(1, 0, 50, 32, device=device)
+            mod(inp)
+
+        inp = torch.ones(0, 16, 50, 44, 31, device=device)
+        mod = torch.nn.MaxPool3d(3, stride=2).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.ones(1, 0, 50, 44, 31, device=device)
+            mod(inp)
+
+    @onlyOnCPUAndCUDA
+    def test_AdaptiveMaxPool_zero_batch_dim(self, device):
+        inp = torch.randn(0, 16, 50, device=device)
+        mod = torch.nn.AdaptiveMaxPool1d(3).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.randn(1, 0, 50, device=device)
+            mod(inp)
+
+        inp = torch.randn(0, 16, 50, 32, device=device)
+        mod = torch.nn.AdaptiveMaxPool2d(3).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.randn(1, 0, 50, 32, device=device)
+            mod(inp)
+
+        inp = torch.ones(0, 16, 50, 44, 31, device=device)
+        mod = torch.nn.AdaptiveMaxPool3d(3).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.ones(1, 0, 50, 44, 31, device=device)
+            mod(inp)
+
     @onlyCUDA
     @dtypes(torch.float, torch.double)
     @tf32_on_and_off(0.005)
@@ -13852,8 +13903,8 @@ class TestNNDeviceType(NNTestCase):
             model(torch.tensor(x, device=device, dtype=dtype))
 
         # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
-        check(0, (1,), "input tensor must have 2 or 3 dimensions but got 0")
-        check([], (1,), "input tensor must have 2 or 3 dimensions but got 1")
but got") + check([], (1,), "Expected 2D or 3D input tensor, but got") check([[]], (1, 0), "stride must be greater than zero, but got 0") check([[]], (1, 1, -1), "padding must be non-negative, but got -1") check([[]], (1, 1, 2), "padding should be at most half of kernel size, but got padding=2 and kernel_size=1") -- 2.7.4