Allow 0-dim batch sizes for AdaptiveMaxPool and MaxPool. (#62088)
authorSameer Deshmukh <sameer.deshmukh93@gmail.com>
Fri, 13 Aug 2021 14:31:42 +0000 (07:31 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 14:33:17 +0000 (07:33 -0700)
Summary:
This issue fixes a part of https://github.com/pytorch/pytorch/issues/12013, which is summarized concretely in  https://github.com/pytorch/pytorch/issues/38115.

This PR allows `MaxPool` and `AdaptiveMaxPool` to accept tensors whose batch size is 0. Some changes have been made to modernize the tests so that they will show the name of C++ function that throws an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62088

Reviewed By: bdhirsh

Differential Revision: D30281285

Pulled By: jbschlosser

fbshipit-source-id: 52bffc67bfe45a78e11e4706b62cce1469eba1b9

aten/src/ATen/native/AdaptiveMaxPooling2d.cpp
aten/src/ATen/native/AdaptiveMaxPooling3d.cpp
aten/src/ATen/native/AveragePool3d.cpp
aten/src/ATen/native/DilatedMaxPool3d.cpp
aten/src/ATen/native/MaxPooling.cpp
aten/src/ATen/native/Pool.h
aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
aten/src/ATen/native/cuda/DilatedMaxPool2d.cu
aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
test/test_nn.py

index a25b88f..bc9bc60 100644 (file)
@@ -6,18 +6,19 @@
 namespace at {
 namespace meta {
 TORCH_META_FUNC(adaptive_max_pool2d) (const Tensor& input, IntArrayRef output_size) {
-  for (int64_t i = 0; i < input.ndimension(); i++) {
+  int ndim = input.ndimension();
+  TORCH_CHECK(ndim == 3 || ndim == 4,
+              "adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: ",
+              input.sizes());
+  for (int64_t i = 1; i < ndim; i++) {
     TORCH_CHECK(input.size(i) > 0,
-        "adaptive_max_pool2d: expected input to have non-empty spatial dimensions, "
+        "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
         "but input has sizes ", input.sizes(), " with dimension ", i,
         " being empty");
   }
 
-  TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
-      "non-empty 3D or 4D (batch mode) tensor expected for input");
-
   TORCH_CHECK(output_size.size() == 2,
-      "adaptive_max_pool2d: internal error: output_size.size() must be 2");
+      "adaptive_max_pool2d(): internal error: output_size.size() must be 2");
 
   int dimH = 1;
   int64_t sizeB = 1;
@@ -48,16 +49,15 @@ TORCH_META_FUNC(adaptive_max_pool2d) (const Tensor& input, IntArrayRef output_si
 TORCH_META_FUNC(adaptive_max_pool2d_backward)
 (const Tensor& grad_output, const Tensor& input, const Tensor& indices) {
   int64_t ndim = grad_output.ndimension();
-  for (int64_t i = 0; i < ndim; i++) {
+  TORCH_CHECK(ndim == 3 || ndim == 4,
+    "adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: ", grad_output.sizes());
+  for (int64_t i = 1; i < ndim; i++) {
     TORCH_CHECK(grad_output.size(i) > 0,
-      "adaptive_max_pooling2d_backward(): expected grad_output to have non-empty spatial dimensions, "
+      "adaptive_max_pooling2d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, "
       "but grad_output has sizes ", grad_output.sizes(), " with dimension ", i,
       " being empty");
   }
 
-  TORCH_CHECK((ndim == 3 || ndim == 4),
-    "non-empty 3D or 4D (batch mode) tensor expected for grad_output");
-
   TORCH_CHECK(input.dtype() == grad_output.dtype(),
     "expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype());
 
index a59d46a..257670f 100644 (file)
@@ -7,10 +7,14 @@
 namespace at {
 namespace meta {
 TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_size) {
-  for (int64_t i = 0; i < input.ndimension(); i++) {
+  auto ndim = input.ndimension();
+  TORCH_CHECK(
+    ndim == 4 || ndim == 5,
+    "adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: ", input.sizes());
+  for (int64_t i = 1; i < ndim; i++) {
     TORCH_CHECK(
         input.size(i) > 0,
-        "adaptive_max_pool3d: expected input to have non-empty spatial dimensions, "
+        "adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
         "but input has sizes ",
         input.sizes(),
         " with dimension ",
@@ -20,18 +24,14 @@ TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_si
   }
 
   TORCH_CHECK(
-      (input.ndimension() == 4 || input.ndimension() == 5),
-      "non-empty 4D or 5D (batch mode) tensor expected for input");
-
-  TORCH_CHECK(
       output_size.size() == 3,
-      "adaptive_max_pool3d: internal error: output_size.size() must be 3");
+      "adaptive_max_pool3d(): internal error: output_size.size() must be 3");
 
   int dimD = 0;
   int64_t sizeB = 1;
   int64_t sizeD = 0;
 
-  if (input.ndimension() == 5) {
+  if (ndim == 5) {
     sizeB = input.size(0);
     dimD++;
   }
@@ -44,7 +44,7 @@ TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_si
   int64_t osizeW = output_size[2];
 
   /* resize output */
-  if (input.ndimension() == 4) {
+  if (ndim == 4) {
     set_output(0, {sizeD, osizeT, osizeH, osizeW}, input.options());
     /* indices will contain max input locations for each output point */
     set_output(1, {sizeD, osizeT, osizeH, osizeW}, input.options().dtype(kLong));
index 674fffe..658936f 100644 (file)
@@ -66,6 +66,7 @@ TORCH_META_FUNC(avg_pool3d) (
     1, 1, 1,
     itime, iheight, iwidth,
     otime, oheight, owidth,
+    "avg_pool3d()",
     /*check_input_size=*/ true);
 
   /* resize output */
@@ -131,7 +132,8 @@ TORCH_META_FUNC(avg_pool3d_backward) (
     dT, dH, dW,
     padT, padH, padW,
     itime, iheight, iwidth,
-    otime_for_shape_check, oheight_for_shape_check, owidth_for_shape_check);
+    otime_for_shape_check, oheight_for_shape_check, owidth_for_shape_check,
+    "avg_pool3d_backward()");
 
   /* resize output */
   set_output(0, input.sizes(), input.options());
index 155f7e8..21398c0 100644 (file)
@@ -195,7 +195,8 @@ void max_pool3d_with_indices_out_cpu_template(
     pT, pH, pW,
     dilationT, dilationH, dilationW,
     itime, iheight, iwidth,
-    otime, oheight, owidth);
+    otime, oheight, owidth,
+    "max_pool3d_with_indices_out_cpu_template()");
 
   /* get contiguous input */
   Tensor input = input_.contiguous();
@@ -413,7 +414,8 @@ Tensor& max_pool3d_with_indices_backward_out_cpu_template(
     pT, pH, pW,
     dilationT, dilationH, dilationW,
     itime, iheight, iwidth,
-    otime, oheight, owidth);
+    otime, oheight, owidth,
+    "max_pool3d_with_indices_backward_out_cpu_template()");
 
   /* backprop */
   if (input.ndimension() == 4) /* non-batch mode*/
index 682af63..53f2a10 100644 (file)
@@ -23,8 +23,7 @@ Tensor max_pool1d_impl(
 
   TORCH_CHECK(
       self.dim() == 2 || self.dim() == 3,
-      "max_pool1d() input tensor must have 2 or 3 dimensions but got ",
-      self.dim());
+      "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sizes());
   TORCH_CHECK(
       kernel_size.size() == 1,
       "max_pool1d() kernel_size must be an int or int list of size 1 but got size ",
index 2aceb4d..5fe979d 100644 (file)
@@ -191,6 +191,7 @@ pool3d_shape_check(
   int dilationT, int dilationH, int dilationW,
   int64_t itime, int64_t iheight, int64_t iwidth,
   int64_t otime, int64_t oheight, int64_t owidth,
+  const char *fn_name,
   bool check_input_size=false)
 {
   const int64_t ndim = input.ndimension();
@@ -205,8 +206,14 @@ pool3d_shape_check(
               "dilation should be greater than zero, but got ",
               "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW);
 
-  TORCH_CHECK(input.numel() > 0 && (ndim == 4 || ndim == 5),
-              "non-empty 4D or 5D (batch mode) tensor expected for input, but got ndim: ", ndim);
+  TORCH_CHECK(ndim == 4 || ndim == 5,
+              fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes());
+
+  for (int64_t i = 1; i < ndim; ++i) {
+    TORCH_CHECK(input.size(i) > 0,
+                fn_name, "Expected input to have non-zero size for non-batch dimensions, but got",
+                input.sizes(), " with dimension ", i, " being empty.");
+  }
 
   if (check_input_size) { // AveragePool3d
     TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
@@ -237,7 +244,8 @@ max_pool3d_backward_shape_check(
   int pT, int pH, int pW,
   int dilationT, int dilationH, int dilationW,
   int64_t itime, int64_t iheight, int64_t iwidth,
-  int64_t otime, int64_t oheight, int64_t owidth)
+  int64_t otime, int64_t oheight, int64_t owidth,
+  const char* fn_name)
 {
   const int64_t ndim = input.ndimension();
 
@@ -249,7 +257,7 @@ max_pool3d_backward_shape_check(
     pT, pH, pW,
     dilationT, dilationH, dilationW,
     itime, iheight, iwidth,
-    otime, oheight, owidth);
+    otime, oheight, owidth, fn_name);
 
   check_dim_size(gradOutput, ndim, ndim-4, nslices);
   check_dim_size(gradOutput, ndim, ndim-3, otime);
@@ -271,7 +279,8 @@ avg_pool3d_backward_shape_check(
   int dT, int dH, int dW,
   int pT, int pH, int pW,
   int64_t itime, int64_t iheight, int64_t iwidth,
-  int64_t otime, int64_t oheight, int64_t owidth)
+  int64_t otime, int64_t oheight, int64_t owidth,
+  const char *fn_name)
 {
   const int64_t ndim = input.ndimension();
 
@@ -284,7 +293,7 @@ avg_pool3d_backward_shape_check(
     1, 1, 1,
     itime, iheight, iwidth,
     otime, oheight, owidth,
-    true);
+    fn_name, true);
 
   check_dim_size(gradOutput, ndim, ndim-4, nslices);
   check_dim_size(gradOutput, ndim, ndim-3, otime);
index baba5a0..57c9244 100644 (file)
@@ -204,6 +204,9 @@ const Tensor& indices) {
 
   checkAllSameGPU(
       __func__, {output_arg, indices_arg, input_arg});
+  if (input.numel() == 0) {
+    return;
+  }
 
   int64_t osizeH = output_size[0];
   int64_t osizeW = output_size[1];
@@ -312,6 +315,10 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
       __func__,
       {grad_input_arg, grad_output_arg, input_arg, indices_arg});
 
+  if (gradOutput.numel() == 0) {
+    return;
+  }
+
   bool atomic =
       true; // suboptimal, but without atomic it doesn't pass the tests
 
index 591c339..af09268 100644 (file)
@@ -305,6 +305,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cuda)
 
   checkAllSameGPU(
       __func__, {output_arg, indices_arg, input_arg});
+  if (input.numel() == 0) {
+    return;
+  }
 
   int64_t osizeT = output_size[0];
   int64_t osizeH = output_size[1];
@@ -380,6 +383,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cuda)
   checkAllSameGPU(
       __func__,
       {grad_input_arg, grad_output_arg, input_arg, indices_arg});
+  if (gradOutput.numel() == 0) {
+    return;
+  }
 
   const Tensor gradOutput_ = gradOutput.contiguous();
 
index cb77158..4451a78 100644 (file)
@@ -301,6 +301,9 @@ const Tensor& indices) {
   TensorArg input_arg{ input_, "input_", 3 };
 
   checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg});
+  if (output.numel() == 0) {
+    return;
+  }
 
   const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
   const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
@@ -424,6 +427,9 @@ const Tensor& gradInput) {
 
   checkAllSameGPU(__func__,
                   {gradInput_arg, gradOutput_arg, input_arg, indices_arg});
+  if (gradOutput_.numel() == 0) {
+    return;
+  }
 
   const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
   const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
index a0e5ecb..36504f6 100644 (file)
@@ -229,9 +229,6 @@ void max_pool3d_with_indices_out_cuda_template(
   const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
   const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);
 
-  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
-    "non-empty 4D or 5D (batch mode) tensor expected for input");
-
   const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
   const int64_t nslices = input.size(-4);
   const int64_t itime = input.size(-3);
@@ -250,7 +247,8 @@ void max_pool3d_with_indices_out_cuda_template(
     pT, pH, pW,
     dilationT, dilationH, dilationW,
     itime, iheight, iwidth,
-    otime, oheight, owidth);
+    otime, oheight, owidth,
+    "max_pool3d_with_indices_out_cuda_template()");
 
   if (input.ndimension() == 4) {
     output.resize_({ nslices, otime, oheight, owidth});
@@ -261,6 +259,10 @@ void max_pool3d_with_indices_out_cuda_template(
     indices.resize_({nbatch, nslices, otime, oheight, owidth});
   }
 
+  if (input.numel() == 0) {
+    return;
+  }
+
   Tensor work_input = input.contiguous();
   Tensor work_output = output;
   Tensor work_indices = indices;
@@ -338,10 +340,12 @@ void max_pool3d_with_indices_backward_out_cuda_template(
   const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);
 
   TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
-    "non-empty 4D or 5D (batch mode) tensor expected for input");
+    "max_pool2d_with_indices_backward_out_cuda_template(): ",
+    "Expected 4D or 5D input tensor, but got ", input.sizes());
 
   TORCH_CHECK((gradOutput.ndimension() == 4 || gradOutput.ndimension() == 5),
-    "non-empty 4D or 5D (batch mode) tensor expected for gradOutput");
+    "max_pool2d_with_indices_backward_out_cuda_template(): ",
+    "Expected 4D or 5D gradOutput tensor, but got ", gradOutput.sizes());
 
   // Resize and initialize result tensor.
   gradInput.resize_as_(input);
@@ -368,7 +372,12 @@ void max_pool3d_with_indices_backward_out_cuda_template(
     pT, pH, pW,
     dilationT, dilationH, dilationW,
     itime, iheight, iwidth,
-    otime, oheight, owidth);
+    otime, oheight, owidth,
+    "max_pool3d_with_indices_backward_out_cuda_template()");
+
+  if (gradOutput.numel() == 0) {
+    return;
+  }
 
   Tensor work_grad_input = gradInput;
   Tensor work_grad_output = gradOutput.contiguous();
index 75c4d43..5784779 100644 (file)
@@ -13355,6 +13355,57 @@ class TestNNDeviceType(NNTestCase):
             unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
             unfold(inp)
 
+    @onlyOnCPUAndCUDA
+    def test_MaxPool_zero_batch_dim(self, device):
+        inp = torch.randn(0, 16, 50, device=device)
+        mod = torch.nn.MaxPool1d(3, stride=2).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        # 1D is supposed to be okay with 0 numel() inputs so dont test
+        # error raising for that case.
+
+        inp = torch.randn(0, 16, 50, 32, device=device)
+        mod = torch.nn.MaxPool2d(3, stride=2).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.randn(1, 0, 50, 32, device=device)
+            mod(inp)
+
+        inp = torch.ones(0, 16, 50, 44, 31, device=device)
+        mod = torch.nn.MaxPool3d(3, stride=2).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.ones(1, 0, 50, 44, 31, device=device)
+            mod(inp)
+
+    @onlyOnCPUAndCUDA
+    def test_AdaptiveMaxPool_zero_batch_dim(self, device):
+        inp = torch.randn(0, 16, 50, device=device)
+        mod = torch.nn.AdaptiveMaxPool1d(3).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.randn(1, 0, 50, device=device)
+            mod(inp)
+
+        inp = torch.randn(0, 16, 50, 32, device=device)
+        mod = torch.nn.AdaptiveMaxPool2d(3).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.randn(1, 0, 50, 32, device=device)
+            mod(inp)
+
+        inp = torch.ones(0, 16, 50, 44, 31, device=device)
+        mod = torch.nn.AdaptiveMaxPool3d(3).to(device)
+        self._test_module_empty_input(mod, inp, check_size=False)
+
+        with self.assertRaisesRegex(RuntimeError, "Expected"):
+            inp = torch.ones(1, 0, 50, 44, 31, device=device)
+            mod(inp)
+
     @onlyCUDA
     @dtypes(torch.float, torch.double)
     @tf32_on_and_off(0.005)
@@ -13852,8 +13903,8 @@ class TestNNDeviceType(NNTestCase):
                 model(torch.tensor(x, device=device, dtype=dtype))
 
         # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
-        check(0, (1,), "input tensor must have 2 or 3 dimensions but got 0")
-        check([], (1,), "input tensor must have 2 or 3 dimensions but got 1")
+        check(0, (1,), "Expected 2D or 3D input tensor, but got")
+        check([], (1,), "Expected 2D or 3D input tensor, but got")
         check([[]], (1, 0), "stride must be greater than zero, but got 0")
         check([[]], (1, 1, -1), "padding must be non-negative, but got -1")
         check([[]], (1, 1, 2), "padding should be at most half of kernel size, but got padding=2 and kernel_size=1")