From 5b0dfd0f8aff50e2fce8f2f1fe6f2ef0594a9e25 Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Mon, 30 Aug 2021 12:14:09 -0700
Subject: [PATCH] Fix bad use of channels last kernel in sync batch norm backward (#64100)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/64039

There are two distinct problems here.
1. If `grad_output` is channels last but `input` is not, then `input` would be read as if it were channels last, so the wrong values were read.
2. `use_channels_last_kernels` doesn't guarantee that `suggest_memory_format` will actually return channels last, so use `empty_like` instead so the strides always match.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64100

Reviewed By: mruberry

Differential Revision: D30622127

Pulled By: ngimel

fbshipit-source-id: e28cc57215596817f1432fcdd6c49d69acfedcf2
---
 aten/src/ATen/native/cuda/Normalization.cu  |  4 ++-
 aten/src/ATen/native/cuda/Normalization.cuh |  6 +++--
 test/test_nn.py                             | 42 +++++++++++++++++++++++++++++
 3 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu
index 0238b1b..1d4d1cc 100644
--- a/aten/src/ATen/native/cuda/Normalization.cu
+++ b/aten/src/ATen/native/cuda/Normalization.cu
@@ -648,7 +648,9 @@ Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, c
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
 
-  if (at::cuda::detail::canUse32BitIndexMath(self) && batch_norm_use_channels_last_kernels(self)){
+  if (at::cuda::detail::canUse32BitIndexMath(self) &&
+      batch_norm_use_channels_last_kernels(self) &&
+      batch_norm_use_channels_last_kernels(input)) {
     return batch_norm_backward_elemt_channels_last_cuda_template(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
   }
 
diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh
index af074f5..6daa2b0 100644
--- a/aten/src/ATen/native/cuda/Normalization.cuh
+++ b/aten/src/ATen/native/cuda/Normalization.cuh
@@ -1649,7 +1649,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
   const auto stride = input.sizes()[1];
   const auto reduction_size = input.numel() / stride;
 
-  at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format());
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
 
   dim3 block;
   dim3 grid;
@@ -1716,7 +1717,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
   const auto reduction_size = input.numel() / stride;
   auto norm_fct = 1.0 / reduction_size;
 
-  at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format());
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
 
   dim3 block;
   dim3 grid;
diff --git a/test/test_nn.py b/test/test_nn.py
index bb4dd59..c9815db 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -11193,6 +11193,48 @@ class TestNN(NNTestCase):
                 self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key])
 
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    def test_sync_batchnorm_backward_elemt(self):
+        device = 'cuda'
+        saved_input = torch.rand(2, 3, 2, 1, device=device)
+        grad_output = torch.rand(2, 3, 2, 1, device=device)
+        mean = torch.rand(3, device=device)
+        invstd = torch.rand(3, device=device)
+        weight = torch.rand(3, device=device)
+        sum_dy = torch.rand(3, device=device)
+        sum_dy_xmu = torch.rand(3, device=device)
+        count_tensor = torch.tensor([5, 5, 5], dtype=torch.int32, device=device)
+
+        gI_contiguous = torch.batch_norm_backward_elemt(
+            grad_output,
+            saved_input,
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count_tensor
+        )
+
+        # Test that batch_norm_backward_elemt gives the same answer for all
+        # combinations of contiguous and channels_last inputs
+        for a, b in [
+            (torch.channels_last, torch.contiguous_format),
+            (torch.contiguous_format, torch.channels_last),
+            (torch.channels_last, torch.channels_last),
+        ]:
+            gI_actual = torch.batch_norm_backward_elemt(
+                grad_output.contiguous(memory_format=a),
+                saved_input.contiguous(memory_format=b),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu,
+                count_tensor
+            )
+            self.assertEqual(gI_actual, gI_contiguous)
+
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
     def test_sync_batchnorm_accuracy_cuda(self):
         # The target of this test is to test the functionality and accuracy of
         # those single-GPU cuda kernels used in SyncBatchNorm
-- 
2.7.4
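
As a quick sanity check outside the test suite, here is a minimal standalone sketch of the first problem described in the summary (channels-last grad_output paired with a contiguous input). It assumes a CUDA build of PyTorch and calls the same internal torch.batch_norm_backward_elemt op as the new test; the shapes and the names reference/mismatched are illustrative only.

import torch

device = 'cuda'
saved_input = torch.rand(2, 3, 2, 1, device=device)   # contiguous NCHW
grad_output = torch.rand(2, 3, 2, 1, device=device)
mean = torch.rand(3, device=device)
invstd = torch.rand(3, device=device)
weight = torch.rand(3, device=device)
sum_dy = torch.rand(3, device=device)
sum_dy_xmu = torch.rand(3, device=device)
count = torch.tensor([5, 5, 5], dtype=torch.int32, device=device)

# Reference result with both grad_output and saved_input contiguous.
reference = torch.batch_norm_backward_elemt(
    grad_output, saved_input, mean, invstd, weight, sum_dy, sum_dy_xmu, count)

# Problem 1 from the summary: grad_output is channels last but saved_input is
# not. Before this patch the channels-last kernel was chosen based on
# grad_output alone, so saved_input was read with the wrong strides.
mismatched = torch.batch_norm_backward_elemt(
    grad_output.contiguous(memory_format=torch.channels_last),
    saved_input, mean, invstd, weight, sum_dy, sum_dy_xmu, count)

print(torch.allclose(reference, mismatched))  # expected: True with the fix applied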