BatchNorm autodiff re-enabled (#57321)
authorjiej <jiej@nvidia.com>
Sat, 21 Aug 2021 16:05:04 +0000 (09:05 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sat, 21 Aug 2021 16:07:31 +0000 (09:07 -0700)
Summary:
Turns on BN in autodiff:

1. outputs an empty tensor for running stats to by pass autodiff issue on None;
2. fixing BN inference backward in cudnn & miopen, where backward falls back to native batchnorm kernel instead;

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57321

Reviewed By: albanD, ngimel

Differential Revision: D30250419

Pulled By: jansel

fbshipit-source-id: a62553789c20fb50a820003a056f40d9d642dfaa

aten/src/ATen/native/Normalization.cpp
aten/src/ATen/native/cuda/Normalization.cu
aten/src/ATen/native/cudnn/BatchNorm.cpp
aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
test/test_jit.py
torch/csrc/jit/runtime/symbolic_script.cpp
torch/testing/_internal/jit_metaprogramming_utils.py

index 40ee1d5..611faf0 100644 (file)
@@ -240,7 +240,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
     grad_weight = at::empty_like(weight, at::MemoryFormat::Contiguous);
   }
   if (grad_input_mask[2]) {
-    grad_bias = at::empty_like(weight, at::MemoryFormat::Contiguous);
+    grad_bias = at::empty({input.size(1)}, input.options());
   }
 
   // since we are directly manipulating pointers in contiguous path,
@@ -416,6 +416,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
 
   auto num_features = input.sizes()[1];
+
+  if (input.numel() == 0) {
+    Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+    auto options = input.options().dtype(
+        at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda()));
+    auto save_mean = at::empty({num_features}, options);
+    auto save_invstd = at::empty({num_features}, options);
+
+    // don't return view of input, don't return empty tensor because it will break gradient chain
+    auto out = input.clone();
+    if (weight.defined()) out = out * weight[0];
+    if (bias.defined()) out = out + bias[0];
+    return std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t>(
+        out, save_mean, save_invstd, reserve, 0);
+  }
+
   if (running_mean.defined()) {
     check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
   } else if (!training) {
@@ -508,7 +524,30 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
   const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
   const Tensor& save_var_transform = c10::value_or_else(save_var_transform_opt, [] {return Tensor();});
 
-  if (impl_index == 0) {
+  if (input.numel() == 0) {
+    std::vector<int64_t> dims(input.dim() - 1);
+    dims[0] = 0;
+    std::iota(dims.begin() + 1, dims.end(), 2);
+
+    // don't return empty tensor because it will break gradient chain
+    Tensor grad_input;
+    Tensor grad_weight;
+    Tensor grad_bias;
+    if (output_mask[2]) {
+      grad_bias = grad_output.sum(dims);
+    }
+    if (output_mask[1]) {
+      grad_weight = (grad_output * input).sum(dims);
+    }
+    if (output_mask[0] && weight.defined()) {
+      grad_input = grad_output * weight[0];
+    }
+    return std::make_tuple(grad_input, grad_weight, grad_bias);
+  }
+
+  // backward in inference mode is not supported in cudnn, fallback to native
+  // TODO: verify the same thing in miopen
+  if (impl_index == 0 || (!train)) {
     return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
   } else if (impl_index == 1) {
     // TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
@@ -528,13 +567,6 @@ Tensor batch_norm(
   const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
   const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-  if (input.numel()==0){
-    //don't return view of input, don't return empty tensor because it will break gradient chain
-    auto out = input.clone();
-    if (weight.defined()) out = out * weight[0];
-    if (bias.defined()) out = out + bias[0];
-    return out;
-  }
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                 training, momentum, eps, cudnn_enabled));
 }
@@ -602,7 +634,9 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
 
   return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm", [&] {
       if (!train) {
-        return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, {}, {}, running_mean, running_var, train, eps);
+        auto save_mean = at::empty({0}, self.options());
+        auto save_var = at::empty({0}, self.options());
+        return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps);
       } else {
         auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, InvStd>(self, running_mean, running_var, momentum, eps);
         return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps);
index dff3f69..0238b1b 100644 (file)
@@ -487,7 +487,8 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_o
   // save_mean and save_invstd, so it needs recalculated.
   const auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true);
   Tensor mean;
-  if (save_mean->defined()) {
+  TORCH_INTERNAL_ASSERT(save_mean->defined(), "save_mean should always be defined\n");
+  if (save_mean->numel() != 0) {
     mean = *save_mean;
   } else if (needs_reduction) {
     TORCH_CHECK(!train && running_mean->defined());
@@ -496,7 +497,8 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_o
   }
 
   Tensor invstd;
-  if (save_invstd->defined()) {
+  TORCH_INTERNAL_ASSERT(save_invstd->defined(), "save_invstd should always be defined\n");
+  if (save_invstd->numel() != 0) {
     invstd = *save_invstd;
   } else {
     TORCH_CHECK(!train && running_var->defined());
index 3a34e32..1c70aa3 100644 (file)
@@ -212,6 +212,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
 #endif // CUDNN_VERSION >= 7400
   } else {
     reserve = at::empty({0}, input->options().dtype(kByte));
+    // This keeps a consistent output with native_batch_norm
+    save_mean = at::empty({0}, weight_t.options());
+    save_var = at::empty({0}, weight_t.options());
     AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
       handle, mode, &one, &zero,
       idesc.desc(), input->data_ptr(),
index d78fe07..28e20e9 100644 (file)
@@ -120,6 +120,8 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
       save_mean.data_ptr(),
       save_var.data_ptr()));
   } else {
+    save_mean = at::empty({0}, weight_t.options());
+    save_var = at::empty({0}, weight_t.options());
     MIOPEN_CHECK(miopenBatchNormalizationForwardInference(
       handle, mode, &one, &zero,
       idesc.desc(), input->data_ptr(),
index 2dd0d47..06afe65 100644 (file)
@@ -10774,6 +10774,68 @@ dedent """
         self.assertEqual(w.grad, w_ref.grad)
         self.assertEqual(b.grad, b_ref.grad)
 
+    @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix")
+    def test_batch_norm_inference_backward_cuda(self):
+        with enable_profiling_mode_for_profiling_tests():
+            class MyBatchNorm(torch.nn.Module):
+                def __init__(self, num_features, affine, track_running_stats):
+                    super(MyBatchNorm, self).__init__()
+                    self.bn = torch.nn.BatchNorm2d(
+                        num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float()
+
+                def forward(self, x: torch.Tensor):
+                    o = self.bn(x)
+                    o = torch.nn.functional.relu(o)
+                    return o
+
+            batch = 4
+            c = 2
+            hw = 3
+            # Initialize param and input values
+            x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
+            grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
+
+            training = False
+            affine = True
+            track_running_stats = True
+
+            module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda()
+            ref_module = MyBatchNorm(c, affine, track_running_stats).cuda()
+            module.eval()
+            ref_module.eval()
+
+            jit_module = torch.jit.script(module)
+            ref_module.load_state_dict(module.state_dict())
+
+            x = x_init.detach().clone()
+            x.requires_grad_()
+            x_ref = x_init.detach().clone()
+            x_ref.requires_grad_()
+
+            # Test symbolic differentiation
+            # Run Forward and Backward thrice to trigger autodiff graph
+            for i in range(0, 3):
+                y = jit_module(x)
+                y.backward(grad)
+            x.grad.zero_()
+
+            module.bn.running_mean.zero_()
+            module.bn.running_var.fill_(1.0)
+            ref_module.bn.running_mean.zero_()
+            ref_module.bn.running_var.fill_(1.0)
+
+            # run jitted module
+            y = jit_module(x)
+            y.backward(grad)
+            # reference computation
+            y_ref = ref_module(x_ref)
+            y_ref.backward(grad)
+
+            self.assertEqual(y_ref, y)
+            self.assertEqual(x.grad, x_ref.grad)
+            self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean)
+            self.assertEqual(module.bn.running_var, ref_module.bn.running_var)
+
     def test_zeros(self):
         class M(torch.jit.ScriptModule):
             __constants__ = ['d']
index 453a83c..29ce74a 100644 (file)
@@ -1117,7 +1117,7 @@ const std::vector<std::string> functions = {
             return result, backward
     )",
     R"(
-        def batch_norm_disabled(input : Tensor,
+        def batch_norm(input : Tensor,
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
                        running_mean : Optional[Tensor],
index a21717b..350866c 100644 (file)
@@ -109,8 +109,39 @@ nn_functional_tests = [
     ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
     ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
     ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
-    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
-        '', (False, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S),
+        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
+        'training', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (0, S, S, S),
+        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+        'size_zero', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (0, S, S, S),
+        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+        'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S),
+        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+        'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            None, non_differentiable(torch.ones(S)), True, ),
+        'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            non_differentiable(torch.randn(S)), None, True, ),
+        'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            None, None, False, ),
+        'inference', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
+        'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            None, non_differentiable(torch.ones(S)), False, ),
+        'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                            non_differentiable(torch.randn(S)), None, False, ),
+        'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
     ('layer_norm', (S, S, S, S), ([5],), '',
      (True, ['aten::native_layer_norm'])),