fix batchnorm2d issue when input is non-contiguous (#63392)
author mingfeima <mingfei.ma@intel.com>
Tue, 24 Aug 2021 15:22:47 +0000 (08:22 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 15:24:01 +0000 (08:24 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63392

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30476317

Pulled By: VitalyFedyunin

fbshipit-source-id: 03055a0aec21cf2c029b6f32315da2b09cb722d0

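Background (see https://github.com/pytorch/pytorch/issues/63224): an input like the one in the new test below is channels-last contiguous but not NCHW contiguous, yet `suggest_memory_format()` can still report `Contiguous` for it. The CPU batch norm path previously both allocated its output and dispatched its kernel off that suggestion, so such inputs could be read with the wrong stride assumptions. A minimal reproduction sketch (shapes taken from the test added in this patch; the pre-fix failure mode is implied, not shown):

    import torch

    # Permuted so the tensor is channels-last contiguous but not NCHW contiguous;
    # suggest_memory_format() on the C++ side may still say Contiguous for it.
    x = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1).permute(0, 2, 1, 3)
    assert not x.is_contiguous()
    assert x.is_contiguous(memory_format=torch.channels_last)

    bn = torch.nn.BatchNorm2d(2).eval()
    out = bn(x)                      # CPU fast path; previously could misread this layout
    ref = bn(x.contiguous())         # dense NCHW reference
    print(torch.allclose(out, ref))  # expected: True with this fix
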
aten/src/ATen/native/Normalization.cpp
aten/src/ATen/native/cpu/batch_norm_kernel.cpp
test/test_nn.py

diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 611faf0..25ae1a7 100644
@@ -74,6 +74,13 @@ static inline bool is_contiguous(const Tensor& t) {
   return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast);
 }
 
+// For some ambiguous cases, it is possible a channels last contiguous Tensor has
+//   `suggest_memory_format` of Contiguous.
+// See https://github.com/pytorch/pytorch/issues/63224 for details.
+static inline MemoryFormat suggest_memory_format_contig(const Tensor& t) {
+  return t.is_contiguous() ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+}
+
 template<typename scalar_t>
 std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
     const Tensor& input, const Tensor& weight, const Tensor& bias,
@@ -87,10 +94,9 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
       && running_mean.is_contiguous()
       && running_var.is_contiguous();
 
-  Tensor output = at::empty_like(input, input.suggest_memory_format());
-
   // inference contiguous path
   if (all_contiguous) {
+    Tensor output = at::empty_like(input, suggest_memory_format_contig(input));
     batch_norm_cpu_stub(kCPU, output, input, weight, bias,
         save_mean, save_invstd, running_mean, running_var, train, eps);
     return std::make_tuple(output, save_mean, save_invstd);
@@ -120,6 +126,7 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
   auto b = bias.defined() ? as_nd(bias) :
       at::detail::scalar_tensor_static(0, input.scalar_type(), kCPU);
 
+  Tensor output = at::empty_like(input, input.suggest_memory_format());
   auto iter = TensorIteratorConfig()
     .add_output(output)
     .add_input(input)
@@ -250,6 +257,9 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
       && input.suggest_memory_format() == grad_out_.suggest_memory_format();
 
   if (all_contiguous) {
+    if (grad_input_mask[0]) {
+      grad_input = at::empty_like(input, suggest_memory_format_contig(input));
+    }
     batch_norm_cpu_backward_stub(kCPU, grad_input, grad_weight, grad_bias,
         grad_out_, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
     return std::make_tuple(grad_input, grad_weight, grad_bias);
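The `suggest_memory_format_contig` helper introduced above resolves that ambiguity on the already-validated fast path: if the tensor passes the plain `is_contiguous()` check the output is allocated as `Contiguous`, otherwise (it already passed the `is_contiguous(ChannelsLast)` check inside `all_contiguous`) it is allocated as `ChannelsLast`, matching the `is_contiguous`-based dispatch added to the kernels below. A rough Python mirror of that decision (hypothetical helper, for illustration only):

    import torch

    def suggest_memory_format_contig(t):
        # Only meaningful when t is already known to be contiguous in one of
        # the two formats, as the all_contiguous check guarantees.
        if t.is_contiguous():
            return torch.contiguous_format
        return torch.channels_last

For the permuted input from the reproduction above this returns torch.channels_last, so the fast-path output is allocated in the same layout the channels-last kernel writes, which is what the new test asserts.
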
diff --git a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp
index 2d12755..7503760 100644
@@ -611,48 +611,38 @@ void batch_norm_cpu_backward_channels_last_impl(Tensor& grad_input, Tensor& grad
 void batch_norm_cpu_kernel(Tensor& output, const Tensor& input,
     const Tensor& weight, const Tensor& bias, const Tensor& save_mean,  const Tensor& save_invstd,
     const Tensor& running_mean, const Tensor& running_var, bool train, double eps) {
-  switch (input.suggest_memory_format()) {
-    case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_contiguous", [&] {
-        batch_norm_cpu_contiguous_impl<scalar_t>(output, input, weight, bias,
-            save_mean, save_invstd, running_mean, running_var, train, eps);
-      });
-      break;
-    }
-    case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_channels_last", [&] {
-        batch_norm_cpu_channels_last_impl<scalar_t>(output, input, weight, bias,
-            save_mean, save_invstd, running_mean, running_var, train, eps);
-      });
-      break;
-    }
-    default:
-      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+  if (input.is_contiguous()) {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_contiguous", [&] {
+      batch_norm_cpu_contiguous_impl<scalar_t>(output, input, weight, bias,
+          save_mean, save_invstd, running_mean, running_var, train, eps);
+    });
+  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_channels_last", [&] {
+      batch_norm_cpu_channels_last_impl<scalar_t>(output, input, weight, bias,
+          save_mean, save_invstd, running_mean, running_var, train, eps);
+    });
+  } else {
+    TORCH_CHECK(false, "batch_norm_cpu_kernel: expecting input to be contiguous.");
   }
 }
 
 void batch_norm_cpu_collect_stats_kernel(
     Tensor& mean, Tensor& var_sum, const Tensor& input) {
   int64_t image_size = input.numel() / input.size(0) / input.size(1);
-  switch (input.suggest_memory_format()) {
-    case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_contiguous", [&] {
-        if (image_size == 1) { // NC11 is also channels last
-          batch_norm_cpu_collect_stats_channels_last_impl<scalar_t>(mean, var_sum, input);
-        } else {
-          batch_norm_cpu_collect_stats_contiguous_impl<scalar_t>(mean, var_sum, input);
-        }
-      });
-      break;
-    }
-    case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_channels_last", [&] {
+  if (input.is_contiguous()) {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_contiguous", [&] {
+      if (image_size == 1) { // NC11 is also channels last
         batch_norm_cpu_collect_stats_channels_last_impl<scalar_t>(mean, var_sum, input);
-      });
-      break;
-    }
-    default:
-      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+      } else {
+        batch_norm_cpu_collect_stats_contiguous_impl<scalar_t>(mean, var_sum, input);
+      }
+    });
+  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_channels_last", [&] {
+      batch_norm_cpu_collect_stats_channels_last_impl<scalar_t>(mean, var_sum, input);
+    });
+  } else {
+    TORCH_CHECK(false, "batch_norm_cpu_collect_stats_kernel: expecting input to be contiguous.");
   }
 }
 
@@ -661,28 +651,23 @@ void batch_norm_cpu_backward_kernel(Tensor& grad_input, Tensor& grad_weight, Ten
     const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
     bool train, double eps) {
   int64_t image_size = input.numel() / input.size(0) / input.size(1);
-  switch (input.suggest_memory_format()) {
-    case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_contiguous", [&] {
-        if (image_size == 1) { // NC11 is also channels last
-          batch_norm_cpu_backward_channels_last_impl<scalar_t>(grad_input, grad_weight, grad_bias,
-              grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
-        } else {
-          batch_norm_cpu_backward_contiguous_impl<scalar_t>(grad_input, grad_weight, grad_bias,
-              grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
-        }
-      });
-      break;
-    }
-    case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_channels_last", [&] {
+  if (input.is_contiguous()) {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_contiguous", [&] {
+      if (image_size == 1) { // NC11 is also channels last
         batch_norm_cpu_backward_channels_last_impl<scalar_t>(grad_input, grad_weight, grad_bias,
             grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
-      });
-      break;
-    }
-    default:
-      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+      } else {
+        batch_norm_cpu_backward_contiguous_impl<scalar_t>(grad_input, grad_weight, grad_bias,
+            grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
+      }
+    });
+  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_channels_last", [&] {
+      batch_norm_cpu_backward_channels_last_impl<scalar_t>(grad_input, grad_weight, grad_bias,
+          grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
+    });
+  } else {
+    TORCH_CHECK(false, "batch_norm_cpu_backward_kernel: expecting input to be contiguous.");
   }
 }
 
diff --git a/test/test_nn.py b/test/test_nn.py
index f4691e6..07a2b48 100644
@@ -8929,6 +8929,25 @@ class TestNN(NNTestCase):
         helper(self, (4, 1, 9, 9))
         helper(self, (4, 9, 1, 1))
 
+    def test_batchnorm_non_contig_cpu(self):
+        input = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1).cpu()
+        input = input.permute(0, 2, 1, 3)
+
+        bn = torch.nn.BatchNorm2d(2).cpu().float().eval()
+        bn.weight.data.uniform_()
+        bn.bias.data.uniform_()
+
+        ref_input = input.detach().clone().contiguous()
+        ref_bn = nn.BatchNorm2d(2).cpu().float().eval()
+        ref_bn.load_state_dict(bn.state_dict())
+
+        out = bn(input)
+        ref_out = ref_bn(ref_input)
+
+        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+        self.assertTrue(ref_out.is_contiguous())
+        self.assertEqual(out, ref_out)
+
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
     @skipIfRocm