From e926f75b0bc66c789365cb1c48ba41e8447b97fb Mon Sep 17 00:00:00 2001
From: jiej
Date: Sat, 21 Aug 2021 09:05:04 -0700
Subject: [PATCH] BatchNorm autodiff re-enabled (#57321)

Summary:
Turns on BN in autodiff:
1. outputs an empty tensor for running stats to bypass the autodiff issue on None;
2. fixes BN inference backward in cudnn & miopen by falling back to the native batch norm kernel instead.
(An illustrative usage sketch, not part of this patch, is appended after the diff.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57321

Reviewed By: albanD, ngimel

Differential Revision: D30250419

Pulled By: jansel

fbshipit-source-id: a62553789c20fb50a820003a056f40d9d642dfaa
---
 aten/src/ATen/native/Normalization.cpp             | 54 +++++++++++++++----
 aten/src/ATen/native/cuda/Normalization.cu         |  6 ++-
 aten/src/ATen/native/cudnn/BatchNorm.cpp           |  3 ++
 aten/src/ATen/native/miopen/BatchNorm_miopen.cpp   |  2 +
 test/test_jit.py                                   | 62 ++++++++++++++++++++++
 torch/csrc/jit/runtime/symbolic_script.cpp         |  2 +-
 .../testing/_internal/jit_metaprogramming_utils.py | 35 +++++++++++-
 7 files changed, 149 insertions(+), 15 deletions(-)

diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 40ee1d5..611faf0 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -240,7 +240,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
     grad_weight = at::empty_like(weight, at::MemoryFormat::Contiguous);
   }
   if (grad_input_mask[2]) {
-    grad_bias = at::empty_like(weight, at::MemoryFormat::Contiguous);
+    grad_bias = at::empty({input.size(1)}, input.options());
   }
 
   // since we are directly manipulating pointers in contiguous path,
@@ -416,6 +416,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
 
   auto num_features = input.sizes()[1];
+
+  if (input.numel() == 0) {
+    Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+    auto options = input.options().dtype(
+        at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda()));
+    auto save_mean = at::empty({num_features}, options);
+    auto save_invstd = at::empty({num_features}, options);
+
+    // don't return view of input, don't return empty tensor because it will break gradient chain
+    auto out = input.clone();
+    if (weight.defined()) out = out * weight[0];
+    if (bias.defined()) out = out + bias[0];
+    return std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t>(
+        out, save_mean, save_invstd, reserve, 0);
+  }
+
   if (running_mean.defined()) {
     check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
   } else if (!training) {
@@ -508,7 +524,30 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
   const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
   const Tensor& save_var_transform = c10::value_or_else(save_var_transform_opt, [] {return Tensor();});
 
-  if (impl_index == 0) {
+  if (input.numel() == 0) {
+    std::vector<int64_t> dims(input.dim() - 1);
+    dims[0] = 0;
+    std::iota(dims.begin() + 1, dims.end(), 2);
+
+    // don't return empty tensor because it will break gradient chain
+    Tensor grad_input;
+    Tensor grad_weight;
+    Tensor grad_bias;
+    if (output_mask[2]) {
+      grad_bias = grad_output.sum(dims);
+    }
+    if (output_mask[1]) {
+      grad_weight = (grad_output * input).sum(dims);
+    }
+    if (output_mask[0] && weight.defined()) {
+      grad_input = grad_output * weight[0];
+    }
+    return std::make_tuple(grad_input, grad_weight, grad_bias);
+  }
+
+  // backward in inference mode is not supported in cudnn, fallback to native
+  // TODO: verify the same thing in miopen
+  if (impl_index == 0 || (!train)) {
     return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
   } else if (impl_index == 1) {
     // TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
@@ -528,13 +567,6 @@ Tensor batch_norm(
   const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
   const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-  if (input.numel()==0){
-    //don't return view of input, don't return empty tensor because it will break gradient chain
-    auto out = input.clone();
-    if (weight.defined()) out = out * weight[0];
-    if (bias.defined()) out = out + bias[0];
-    return out;
-  }
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                 training, momentum, eps, cudnn_enabled));
 }
@@ -602,7 +634,9 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
 
   return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm", [&] {
       if (!train) {
-        return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, {}, {}, running_mean, running_var, train, eps);
+        auto save_mean = at::empty({0}, self.options());
+        auto save_var = at::empty({0}, self.options());
+        return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps);
       } else {
         auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, InvStd>(self, running_mean, running_var, momentum, eps);
         return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps);
diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu
index dff3f69..0238b1b 100644
--- a/aten/src/ATen/native/cuda/Normalization.cu
+++ b/aten/src/ATen/native/cuda/Normalization.cu
@@ -487,7 +487,8 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_o
   // save_mean and save_invstd, so it needs recalculated.
   const auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true);
   Tensor mean;
-  if (save_mean->defined()) {
+  TORCH_INTERNAL_ASSERT(save_mean->defined(), "save_mean should always be defined\n");
+  if (save_mean->numel() != 0) {
     mean = *save_mean;
   } else if (needs_reduction) {
     TORCH_CHECK(!train && running_mean->defined());
@@ -496,7 +497,8 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_o
   }
 
   Tensor invstd;
-  if (save_invstd->defined()) {
+  TORCH_INTERNAL_ASSERT(save_invstd->defined(), "save_invstd should always be defined\n");
+  if (save_invstd->numel() != 0) {
     invstd = *save_invstd;
   } else {
     TORCH_CHECK(!train && running_var->defined());
diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp
index 3a34e32..1c70aa3 100644
--- a/aten/src/ATen/native/cudnn/BatchNorm.cpp
+++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp
@@ -212,6 +212,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
 #endif // CUDNN_VERSION >= 7400
   } else {
     reserve = at::empty({0}, input->options().dtype(kByte));
+    // This keeps a consistent output with native_batch_norm
+    save_mean = at::empty({0}, weight_t.options());
+    save_var = at::empty({0}, weight_t.options());
     AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
       handle, mode, &one, &zero,
       idesc.desc(), input->data_ptr(),
diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
index d78fe07..28e20e9 100644
--- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
+++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
@@ -120,6 +120,8 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
       save_mean.data_ptr(),
       save_var.data_ptr()));
   } else {
+    save_mean = at::empty({0}, weight_t.options());
+    save_var = at::empty({0}, weight_t.options());
     MIOPEN_CHECK(miopenBatchNormalizationForwardInference(
       handle, mode, &one, &zero,
       idesc.desc(), input->data_ptr(),
diff --git a/test/test_jit.py b/test/test_jit.py
index 2dd0d47..06afe65 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10774,6 +10774,68 @@ dedent """
         self.assertEqual(w.grad, w_ref.grad)
         self.assertEqual(b.grad, b_ref.grad)
 
+    @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix")
+    def test_batch_norm_inference_backward_cuda(self):
+        with enable_profiling_mode_for_profiling_tests():
+            class MyBatchNorm(torch.nn.Module):
+                def __init__(self, num_features, affine, track_running_stats):
+                    super(MyBatchNorm, self).__init__()
+                    self.bn = torch.nn.BatchNorm2d(
+                        num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float()
+
+                def forward(self, x: torch.Tensor):
+                    o = self.bn(x)
+                    o = torch.nn.functional.relu(o)
+                    return o
+
+            batch = 4
+            c = 2
+            hw = 3
+            # Initialize param and input values
+            x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
+            grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
+
+            training = False
+            affine = True
+            track_running_stats = True
+
+            module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda()
+            ref_module = MyBatchNorm(c, affine, track_running_stats).cuda()
+            module.eval()
+            ref_module.eval()
+
+            jit_module = torch.jit.script(module)
+            ref_module.load_state_dict(module.state_dict())
+
+            x = x_init.detach().clone()
+            x.requires_grad_()
+            x_ref = x_init.detach().clone()
+            x_ref.requires_grad_()
+
+            # Test symbolic differentiation
+            # Run Forward and Backward thrice to trigger autodiff graph
+            for i in range(0, 3):
+                y = jit_module(x)
+                y.backward(grad)
+            x.grad.zero_()
+
+            module.bn.running_mean.zero_()
+            module.bn.running_var.fill_(1.0)
+            ref_module.bn.running_mean.zero_()
+            ref_module.bn.running_var.fill_(1.0)
+
+            # run jitted module
+            y = jit_module(x)
+            y.backward(grad)
+            # reference computation
+            y_ref = ref_module(x_ref)
+            y_ref.backward(grad)
+
+            self.assertEqual(y_ref, y)
+            self.assertEqual(x.grad, x_ref.grad)
+            self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean)
+            self.assertEqual(module.bn.running_var, ref_module.bn.running_var)
+
     def test_zeros(self):
         class M(torch.jit.ScriptModule):
             __constants__ = ['d']
diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp
index 453a83c..29ce74a 100644
--- a/torch/csrc/jit/runtime/symbolic_script.cpp
+++ b/torch/csrc/jit/runtime/symbolic_script.cpp
@@ -1117,7 +1117,7 @@ const std::vector<std::string> functions = {
             return result, backward
         )",
         R"(
-        def batch_norm_disabled(input : Tensor,
+        def batch_norm(input : Tensor,
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
                        running_mean : Optional[Tensor],
diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py
index a21717b..350866c 100644
--- a/torch/testing/_internal/jit_metaprogramming_utils.py
+++ b/torch/testing/_internal/jit_metaprogramming_utils.py
@@ -109,8 +109,39 @@ nn_functional_tests = [
     ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
     ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
     ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
-    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
-     '', (False, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S),
+        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
+        'training', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (0, S, S, S),
+        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+        'size_zero', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (0, S, S, S),
+        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+        'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S),
+        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+        'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            None, non_differentiable(torch.ones(S)), True, ),
+        'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            non_differentiable(torch.randn(S)), None, True, ),
+        'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            None, None, False, ),
+        'inference', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
+        'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            None, non_differentiable(torch.ones(S)), False, ),
+        'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            non_differentiable(torch.randn(S)), None, False, ),
+        'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
     ('layer_norm', (S, S, S, S), ([5],), '', (True, ['aten::native_layer_norm'])),
-- 
2.7.4
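
Illustrative usage sketch (not part of the patch). It mirrors the new test_batch_norm_inference_backward_cuda test in a smaller, self-contained form: a scripted module containing BatchNorm2d is run in eval mode a few times so the profiling executor can build the differentiable batch_norm subgraph enabled by the symbolic_script.cpp change, then its output and input gradient are compared against an eager-mode reference. The module and function names (EvalBN, main) and the CPU fallback are assumptions of this sketch; exercising the cudnn/miopen inference-backward fallback specifically requires a CUDA/ROCm build.

# Illustrative sketch only -- not part of the patch above; EvalBN and main are
# hypothetical names. Compares scripted eval-mode BatchNorm backward to eager.
import torch


class EvalBN(torch.nn.Module):
    def __init__(self, num_features: int):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(self.bn(x))


def main() -> None:
    # The cudnn/miopen inference-backward fallback is only exercised on a GPU
    # build; the autodiff formula itself is also used on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    eager = EvalBN(2).to(device).eval()
    scripted = torch.jit.script(EvalBN(2)).to(device).eval()
    scripted.load_state_dict(eager.state_dict())

    x_eager = torch.randn(4, 2, 3, 3, device=device, requires_grad=True)
    x_scripted = x_eager.detach().clone().requires_grad_()
    grad = torch.randn(4, 2, 3, 3, device=device)

    # Warm up: the profiling executor needs a few runs before it swaps in the
    # differentiable batch_norm subgraph.
    for _ in range(3):
        scripted(x_scripted).backward(grad)
    x_scripted.grad.zero_()

    scripted_out = scripted(x_scripted)
    scripted_out.backward(grad)
    eager_out = eager(x_eager)
    eager_out.backward(grad)

    assert torch.allclose(scripted_out, eager_out)
    assert torch.allclose(x_scripted.grad, x_eager.grad)


if __name__ == "__main__":
    main()

The warm-up loop matters because, under the profiling executor, the differentiable graph is only constructed after the initial profiled runs; this is also why the new test runs forward/backward three times before taking the measured backward.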