Change MaxUnpool to accept tensors with 0-dim batch sizes. (#64082)
authorSameer Deshmukh <sameer.deshmukh93@gmail.com>
Wed, 8 Sep 2021 15:40:01 +0000 (08:40 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 15:41:09 +0000 (08:41 -0700)
Summary:
Part of the fix for https://github.com/pytorch/pytorch/issues/38115.

Changes the `MaxUnpool` module to work with 0-dimensions batch sizes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64082

Reviewed By: mrshenli

Differential Revision: D30793907

Pulled By: jbschlosser

fbshipit-source-id: d21aa665be5aa18f592b39ef7b4e3cbc632e21ed

aten/src/ATen/native/MaxUnpooling.cpp
aten/src/ATen/native/cuda/MaxUnpooling.cu
test/test_nn.py

index 9987408..ec96601 100644 (file)
@@ -25,7 +25,11 @@ Tensor& max_unpooling2d_forward_out_cpu(
       self_.sizes() == indices_.sizes(),
       "Shape of indices should match shape of input");
 
-  TORCH_CHECK(self_.numel() > 0, "Input must be non-empty");
+  for (int64_t i = 1; i < self_.ndimension(); ++i) {
+    TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cpu(): ",
+                "Expected input to have non-zero size for non-batch dimensions, but got ",
+                self_.sizes(), " with dimension ", i , " being empty.");
+  }
 
   auto memory_format = self_.suggest_memory_format();
   auto self = self_.contiguous(memory_format);
@@ -41,7 +45,10 @@ Tensor& max_unpooling2d_forward_out_cpu(
   }
   output.zero_();
 
-  max_unpool2d_kernel(kCPU, output, self, indices);
+  if (output.numel() != 0) {
+    max_unpool2d_kernel(kCPU, output, self, indices);
+  }
+
   return output;
 };
 
@@ -60,7 +67,8 @@ static void max_unpooling3d_shape_check(
     const Tensor& indices,
     IntArrayRef output_size,
     IntArrayRef stride,
-    IntArrayRef padding) {
+    IntArrayRef padding,
+    const char *fn_name) {
   int64_t oT = output_size[0];
   int64_t oH = output_size[1];
   int64_t oW = output_size[2];
@@ -84,7 +92,11 @@ static void max_unpooling3d_shape_check(
       input.sizes() == indices.sizes(),
       "Shape of indices should match shape of input");
 
-  TORCH_CHECK(input.numel() > 0, "Input must be non-empty");
+  for (int64_t i = 1; i < input.ndimension(); ++i) {
+    TORCH_CHECK(input.size(i) > 0, fn_name,
+                ": Expected input to have non-zero size for non-batch dimensions, but got ",
+                input.sizes(), " with dimension ", i , " being empty.");
+  }
 
   TORCH_CHECK(
       stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
@@ -144,7 +156,7 @@ Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_,
   auto indices = indices_.contiguous();
 
   max_unpooling3d_shape_check(
-      self_, Tensor(), indices_, output_size, stride, padding);
+    self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cpu()");
 
   if (self_.ndimension() == 5) {
     output.resize_({self.size(0), self.size(1), oT, oH, oW});
@@ -152,8 +164,10 @@ Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_,
     output.resize_({self.size(0), oT, oH, oW});
   }
   output.zero_();
+  if (output.numel() != 0) {
+    max_unpool3d_kernel(kCPU, output, self, indices);
+  }
 
-  max_unpool3d_kernel(kCPU, output, self, indices);
   return output;
 }
 
@@ -207,7 +221,10 @@ Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
         grad_output.size(dimw));
   }
 
-  max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
+  if (grad_input.numel() != 0) {
+    max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
+  }
+
   return grad_input;
 }
 
@@ -240,7 +257,7 @@ Tensor& max_unpooling3d_backward_out_cpu(
   int64_t dimw = ndim == 4 ? 3 : 4;
 
   max_unpooling3d_shape_check(
-      self, grad_output_, indices_, output_size, stride, padding);
+   self, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cpu()");
 
   /* get contiguous gradOutput */
   auto grad_output = grad_output_.contiguous();
@@ -266,7 +283,10 @@ Tensor& max_unpooling3d_backward_out_cpu(
         grad_output.size(dimw));
   }
 
-  max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
+  if (grad_input.numel() != 0) {
+    max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
+  }
+
   return grad_input;
 }
 
index e67f8e7..7c6d746 100644 (file)
@@ -114,7 +114,11 @@ Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
   checkAllSameGPU(
       "max_unpooling2d_forward_out_cuda", {output_arg, self_arg, indices_arg});
 
-  TORCH_CHECK(self_.numel() > 0, "Input must be non-empty tensor");
+  for (int64_t i = 1; i < self_.ndimension(); ++i) {
+    TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cuda(): ",
+                "Expected input to have non-zero size for non-batch dimensions, but got ",
+                self_.sizes(), " with dimension ", i , " being empty.");
+  }
 
   TORCH_CHECK(
       (self_.ndimension() == 3 || self_.ndimension() == 4),
@@ -152,24 +156,26 @@ Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
   output.zero_();
 
   auto count = self.numel();
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half,
-      self.scalar_type(), "max_unpooling2d_forward_kernel", ([&] {
-        max_unpooling2d_forward_kernel<<<
-            GET_BLOCKS(count),
-            CUDA_NUM_THREADS,
-            0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            self.numel(),
-            self.data_ptr<scalar_t>(),
-            indices.data_ptr<int64_t>(),
-            numChannels,
-            inputHeight,
-            inputWidth,
-            oheight,
-            owidth,
-            output.data_ptr<scalar_t>());
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      }));
+  if (count != 0) {
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half,
+        self.scalar_type(), "max_unpooling2d_forward_kernel", ([&] {
+          max_unpooling2d_forward_kernel<<<
+              GET_BLOCKS(count),
+              CUDA_NUM_THREADS,
+              0,
+              at::cuda::getCurrentCUDAStream()>>>(
+              self.numel(),
+              self.data_ptr<scalar_t>(),
+              indices.data_ptr<int64_t>(),
+              numChannels,
+              inputHeight,
+              inputWidth,
+              oheight,
+              owidth,
+              output.data_ptr<scalar_t>());
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }));
+  }
   if (self.ndimension() == 3) {
     output.resize_({numChannels, oheight, owidth});
   }
@@ -191,7 +197,8 @@ static void max_unpooling3d_shape_check(
     const Tensor& indices,
     IntArrayRef output_size,
     IntArrayRef stride,
-    IntArrayRef padding) {
+    IntArrayRef padding,
+    const char *fn_name) {
   int64_t oT = output_size[0];
   int64_t oH = output_size[1];
   int64_t oW = output_size[2];
@@ -215,7 +222,11 @@ static void max_unpooling3d_shape_check(
       input.sizes() == indices.sizes(),
       "Shape of indices should match shape of input");
 
-  TORCH_CHECK(input.numel() > 0, "Input must be non-empty");
+  for (int64_t i = 1; i < input.ndimension(); ++i) {
+    TORCH_CHECK(input.size(i) > 0, fn_name,
+                ": Expected input to have non-zero size for non-batch dimensions, but got ",
+                input.sizes(), " with dimension ", i , " being empty.");
+  }
 
   TORCH_CHECK(
       stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
@@ -268,7 +279,7 @@ Tensor& max_unpooling3d_forward_out_cuda(const Tensor& self_,
     Tensor& output) {
   TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
   max_unpooling3d_shape_check(
-      self_, Tensor(), indices_, output_size, stride, padding);
+    self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cuda()");
 
   int64_t oT = output_size[0];
   int64_t oH = output_size[1];
@@ -318,6 +329,10 @@ Tensor& max_unpooling3d_forward_out_cuda(const Tensor& self_,
                                indices.size(4)});
   }
 
+  if (self.numel() == 0) {
+    return output;
+  }
+
   int totalZ = inputTime * inputSlices * batchSize;
   int offsetZ = 0;
   dim3 block(32, 8);
@@ -426,6 +441,9 @@ at::Tensor& max_unpooling2d_backward_out_cuda(const Tensor& grad_output_,
   grad_input.zero_();
 
   int64_t count = self.numel();
+  if (count == 0) {
+    return grad_input;
+  }
 
   AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half,
       self.scalar_type(), "max_unpooling2d_backward_kernel", ([&] {
@@ -471,7 +489,7 @@ at::Tensor& max_unpooling3d_backward_out_cuda(const Tensor& grad_output_,
   int64_t oW = output_size[2];
 
   max_unpooling3d_shape_check(
-      self_, grad_output_, indices_, output_size, stride, padding);
+    self_, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cuda()");
 
   int batchSize = 0;
   int inputSlices = 0;
@@ -521,6 +539,9 @@ at::Tensor& max_unpooling3d_backward_out_cuda(const Tensor& grad_output_,
                                indices.size(3),
                                indices.size(4)});
   }
+  if (grad_input.numel() == 0) {
+    return grad_input;
+  }
 
   int totalZ = inputTime * inputSlices * batchSize;
   int offsetZ = 0;
index 2d66477..cc702df 100644 (file)
@@ -13764,6 +13764,40 @@ class TestNNDeviceType(NNTestCase):
             mod(inp)
 
     @onlyOnCPUAndCUDA
+    def test_MaxUnpool_zero_batch_dim(self, device):
+        pool = torch.nn.MaxPool1d(2, stride=2, return_indices=True).to(device)
+        unpool = torch.nn.MaxUnpool1d(2, stride=2).to(device)
+        inp = torch.randn(0, 10, 10, requires_grad=True, device=device)
+        output, indices = pool(inp)
+        output.requires_grad_(True)
+        unpool_out = unpool(output, indices)
+        unpool_out.sum().backward()
+
+        self.assertEqual(inp.grad, torch.zeros_like(inp))
+        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))
+
+        pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True).to(device)
+        unpool = torch.nn.MaxUnpool2d(2, stride=2).to(device)
+        inp = torch.randn(0, 10, 10, 10, requires_grad=True, device=device)
+        output, indices = pool(inp)
+        unpool_out = unpool(output, indices)
+        unpool_out.sum().backward()
+
+        self.assertEqual(inp.grad, torch.zeros_like(inp))
+        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))
+
+        pool = torch.nn.MaxPool3d(2, stride=2, return_indices=True).to(device)
+        unpool = torch.nn.MaxUnpool3d(2, stride=2).to(device)
+        inp = torch.randn(0, 10, 10, 10, 10, requires_grad=True, device=device)
+        output, indices = pool(inp)
+        output.requires_grad_(True)
+        unpool_out = unpool(output, indices)
+        unpool_out.sum().backward()
+
+        self.assertEqual(inp.grad, torch.zeros_like(inp))
+        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))
+
+    @onlyOnCPUAndCUDA
     def test_AdaptiveMaxPool_zero_batch_dim(self, device):
         inp = torch.randn(0, 16, 50, device=device)
         mod = torch.nn.AdaptiveMaxPool1d(3).to(device)