Fix bad use of channels last kernel in sync batch norm backward (#64100)
author Peter Bell <peterbell10@live.co.uk>
Mon, 30 Aug 2021 19:14:09 +0000 (12:14 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 19:16:30 +0000 (12:16 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/64039

There are two distinct problems here.
1. If `grad_output` is channels last but `input` is not, then `input` would be read as if it were channels last, producing wrong values (see the first sketch below).
2. `use_channels_last_kernels` doesn't guarantee that `suggest_memory_format` will actually return channels last, so use plain `empty_like` instead so that the output strides always match the input's (see the second sketch below).
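
A minimal sketch of why (1) corrupts values: reinterpreting a contiguous buffer with channels-last strides reads different elements. This is purely illustrative (it runs on CPU and uses arbitrary values); the shapes match the test added below.

```python
import torch

# The same buffer indexed with contiguous vs. channels-last strides
# yields different elements -- which is what happened when the kernel
# chosen for a channels-last grad_output also read a contiguous input.
x = torch.arange(12.).reshape(2, 3, 2, 1)             # NCHW, contiguous
xl = x.contiguous(memory_format=torch.channels_last)  # same values, NHWC layout
print(x.stride(), xl.stride())                        # (6, 2, 1, 1) (6, 1, 3, 3)

# Reading x's memory as if it had channels-last strides scrambles it:
wrong = torch.as_strided(x, x.size(), xl.stride())
print(torch.equal(wrong, x))                          # False
```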

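For (2), a sketch of the ambiguity, assuming stock PyTorch stride inference (the printed strides are what current builds produce): a tensor with size-1 spatial dims can pass the channels-last contiguity check even though its actual strides differ from freshly computed channels-last strides, while plain `empty_like` preserves a dense input's strides exactly.

```python
import torch

# Size-1 dims make memory-format inference ambiguous: this plain
# contiguous tensor also counts as channels-last contiguous, yet
# recomputed channels-last strides differ from its actual strides.
x = torch.rand(2, 3, 1, 1)
print(x.is_contiguous(memory_format=torch.channels_last))  # True (size-1 dims skipped)
print(x.stride())                                          # (3, 1, 1, 1)

# Default empty_like copies a dense input's strides; an explicit
# memory format recomputes them and need not match.
print(torch.empty_like(x).stride())                                     # (3, 1, 1, 1)
print(torch.empty_like(x, memory_format=torch.channels_last).stride())  # (3, 1, 3, 3)
```
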
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64100

Reviewed By: mruberry

Differential Revision: D30622127

Pulled By: ngimel

fbshipit-source-id: e28cc57215596817f1432fcdd6c49d69acfedcf2

aten/src/ATen/native/cuda/Normalization.cu
aten/src/ATen/native/cuda/Normalization.cuh
test/test_nn.py

diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu
index 0238b1b..1d4d1cc 100644
@@ -648,7 +648,9 @@ Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, c
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
 
-  if (at::cuda::detail::canUse32BitIndexMath(self) && batch_norm_use_channels_last_kernels(self)){
+  if (at::cuda::detail::canUse32BitIndexMath(self) &&
+      batch_norm_use_channels_last_kernels(self) &&
+      batch_norm_use_channels_last_kernels(input)) {
     return batch_norm_backward_elemt_channels_last_cuda_template(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
   }
 
diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh
index af074f5..6daa2b0 100644
@@ -1649,7 +1649,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
   const auto stride = input.sizes()[1];
   const auto reduction_size = input.numel() / stride;
 
-  at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format());
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
 
   dim3 block;
   dim3 grid;
@@ -1716,7 +1717,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
   const auto reduction_size = input.numel() / stride;
   auto norm_fct = 1.0 / reduction_size;
 
-  at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format());
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
 
   dim3 block;
   dim3 grid;
diff --git a/test/test_nn.py b/test/test_nn.py
index bb4dd59..c9815db 100644
@@ -11193,6 +11193,48 @@ class TestNN(NNTestCase):
                 self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key])
 
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    def test_sync_batchnorm_backward_elemt(self):
+        device = 'cuda'
+        saved_input = torch.rand(2, 3, 2, 1, device=device)
+        grad_output = torch.rand(2, 3, 2, 1, device=device)
+        mean = torch.rand(3, device=device)
+        invstd = torch.rand(3, device=device)
+        weight = torch.rand(3, device=device)
+        sum_dy = torch.rand(3, device=device)
+        sum_dy_xmu = torch.rand(3, device=device)
+        count_tensor = torch.tensor([5, 5, 5], dtype=torch.int32, device=device)
+
+        gI_contiguous = torch.batch_norm_backward_elemt(
+            grad_output,
+            saved_input,
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count_tensor
+        )
+
+        # Test that batch_norm_backward_elemt gives the same answer for all
+        # combinations of contiguous and channels_last inputs
+        for a, b in [
+                (torch.channels_last, torch.contiguous_format),
+                (torch.contiguous_format, torch.channels_last),
+                (torch.channels_last, torch.channels_last),
+        ]:
+            gI_actual = torch.batch_norm_backward_elemt(
+                grad_output.contiguous(memory_format=a),
+                saved_input.contiguous(memory_format=b),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu,
+                count_tensor
+            )
+            self.assertEqual(gI_actual, gI_contiguous)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
     def test_sync_batchnorm_accuracy_cuda(self):
         # The target of this test is to test the functionality and accuracy of
         #   those single-GPU cuda kernels used in SyncBatchNorm