From 5b7cdc5a3ddb9a1a3d46d05b2925b5b4713b0025 Mon Sep 17 00:00:00 2001
From: mingfeima
Date: Mon, 23 Aug 2021 22:53:35 -0700
Subject: [PATCH] add channels last for GroupNorm (#49821)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49821

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D26007053

Pulled By: VitalyFedyunin

fbshipit-source-id: 34a48d5d3b66a159febf3c3d96748fbaba1b9e31
---
 aten/src/ATen/native/cpu/group_norm_kernel.cpp | 162 ++++++++++++++++++++++---
 aten/src/ATen/native/group_norm.cpp            |  17 ++-
 test/test_nn.py                                |  34 ++++++
 3 files changed, 193 insertions(+), 20 deletions(-)

diff --git a/aten/src/ATen/native/cpu/group_norm_kernel.cpp b/aten/src/ATen/native/cpu/group_norm_kernel.cpp
index 290a631..fb8db7e 100644
--- a/aten/src/ATen/native/cpu/group_norm_kernel.cpp
+++ b/aten/src/ATen/native/cpu/group_norm_kernel.cpp
@@ -74,6 +74,136 @@ void GroupNormKernelImplInternal(
   });
 }
 
+template <typename T>
+void GroupNormKernelImplChannelsLastInternal(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t N,
+    int64_t C,
+    int64_t HxW,
+    int64_t group,
+    T eps,
+    Tensor& Y,
+    Tensor& mean,
+    Tensor& rstd) {
+  TORCH_CHECK(X.numel() == N * C * HxW);
+  TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
+  TORCH_CHECK(!beta.defined() || beta.numel() == C);
+  const int64_t G = group;
+  const int64_t D = C / G;
+  const T* X_data = X.data_ptr<T>();
+  const T* gamma_data = gamma.defined() ? gamma.data_ptr<T>() : nullptr;
+  const T* beta_data = beta.defined() ? beta.data_ptr<T>() : nullptr;
+  T* Y_data = Y.data_ptr<T>();
+  T* mean_data = mean.data_ptr<T>();
+  T* rstd_data = rstd.data_ptr<T>();
+  const T s = T(1) / static_cast<T>(D * HxW);
+  const bool gamma_null = (gamma_data == nullptr);
+  const bool beta_null = beta_data == nullptr;
+
+  // temp buffer holding x and x^2
+  Tensor buffer = at::empty({N, 2 * C}, X.options()).zero_();
+  T* buffer_data = buffer.data_ptr<T>();
+
+  using Vec = vec::Vectorized<T>;
+  at::parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
+    constexpr int64_t K = Vec::size();
+    const int64_t inner_size = C / K * K;
+    for (int64_t n = start; n < end; ++n) {
+      T* mean_ptr = buffer_data + n * 2 * C;
+      T* rstd_ptr = mean_ptr + C;
+      for (int64_t i = 0; i < HxW; ++i) {
+        const T* X_ptr = X_data + n * HxW * C + i * C;
+        for (int64_t j = 0; j < inner_size; j += K) {
+          const Vec x_vec = Vec::loadu(X_ptr + j);
+          Vec mean_vec = Vec::loadu(mean_ptr + j) + x_vec;
+          Vec rstd_vec = Vec::loadu(rstd_ptr + j) + x_vec * x_vec;
+          mean_vec.store(mean_ptr + j);
+          rstd_vec.store(rstd_ptr + j);
+        }
+        for (int64_t j = inner_size; j < C; ++j) {
+          mean_ptr[j] += X_ptr[j];
+          rstd_ptr[j] += X_ptr[j] * X_ptr[j];
+        }
+      }
+
+      for (int64_t g = 0; g < G; ++g) {
+        T mean_val = T(0);
+        T rstd_val = T(0);
+        for (int64_t d = 0; d < D; ++d) {
+          mean_val += mean_ptr[g * D + d];
+          rstd_val += rstd_ptr[g * D + d];
+        }
+        mean_val *= s;
+        rstd_val = std::max(rstd_val * s - mean_val * mean_val, T(0));
+        rstd_val = T(1) / std::sqrt(rstd_val + eps);
+
+        // continue to use the temp buffer for the mean and rstd values,
+        // so that the following math can be vectorized over the entire C dimension.
+        for (int64_t d = 0; d < D; ++d) {
+          mean_ptr[g * D + d] = mean_val;
+          rstd_ptr[g * D + d] = rstd_val;
+        }
+
+        mean_data[n * G + g] = mean_val;
+        rstd_data[n * G + g] = rstd_val;
+      }
+
+      // expand gamma_null and beta_null to reduce if-else on the critical path.
+      if (!gamma_null && !beta_null) {
+        for (int64_t i = 0; i < HxW; ++i) {
+          const T* X_ptr = X_data + n * HxW * C + i * C;
+          T* Y_ptr = Y_data + n * HxW * C + i * C;
+          for (int64_t j = 0; j < inner_size; j += K) {
+            Vec scale_vec = Vec::loadu(rstd_ptr + j) * Vec::loadu(gamma_data + j);
+            Vec bias_vec = Vec::loadu(beta_data + j) - scale_vec * Vec::loadu(mean_ptr + j);
+            Vec y_vec = scale_vec * Vec::loadu(X_ptr + j) + bias_vec;
+            y_vec.store(Y_ptr + j);
+          }
+          for (int64_t j = inner_size; j < C; ++j) {
+            T scale = rstd_ptr[j] * gamma_data[j];
+            T bias = -scale * mean_ptr[j] + beta_data[j];
+            Y_ptr[j] = scale * X_ptr[j] + bias;
+          }
+        }
+      } else if (gamma_null && beta_null) {
+        for (int64_t i = 0; i < HxW; ++i) {
+          const T* X_ptr = X_data + n * HxW * C + i * C;
+          T* Y_ptr = Y_data + n * HxW * C + i * C;
+          for (int64_t j = 0; j < inner_size; j += K) {
+            Vec scale_vec = Vec::loadu(rstd_ptr + j);
+            Vec y_vec = scale_vec * Vec::loadu(X_ptr + j) - scale_vec * Vec::loadu(mean_ptr + j);
+            y_vec.store(Y_ptr + j);
+          }
+          for (int64_t j = inner_size; j < C; ++j) {
+            T scale = rstd_ptr[j];
+            Y_ptr[j] = scale * X_ptr[j] - scale * mean_ptr[j];
+          }
+        }
+      } else {
+        for (int64_t i = 0; i < HxW; ++i) {
+          const T* X_ptr = X_data + n * HxW * C + i * C;
+          T* Y_ptr = Y_data + n * HxW * C + i * C;
+          for (int64_t j = 0; j < inner_size; j += K) {
+            Vec gamma_vec = gamma_null ? Vec(1) : Vec::loadu(gamma_data + j);
+            Vec beta_vec = beta_null ? Vec(0) : Vec::loadu(beta_data + j);
+            Vec scale_vec = Vec::loadu(rstd_ptr + j) * gamma_vec;
+            Vec bias_vec = beta_vec - scale_vec * Vec::loadu(mean_ptr + j);
+            Vec y_vec = scale_vec * Vec::loadu(X_ptr + j) + bias_vec;
+            y_vec.store(Y_ptr + j);
+          }
+          for (int64_t j = inner_size; j < C; ++j) {
+            T scale = rstd_ptr[j] * (gamma_null ? T(1) : gamma_data[j]);
+            T bias = -scale * mean_ptr[j] + (beta_null ? T(0) : beta_data[j]);
+            Y_ptr[j] = scale * X_ptr[j] + bias;
+          }
+        }
+      }
+    }
+  });
+}
+
 void GroupNormKernelImpl(
     const Tensor& X,
     const Tensor& gamma,
@@ -86,20 +216,24 @@ void GroupNormKernelImpl(
     Tensor& Y,
     Tensor& mean,
     Tensor& rstd) {
-  AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GroupNormKernelImpl", [&]() {
-    GroupNormKernelImplInternal<scalar_t>(
-        X,
-        gamma,
-        beta,
-        N,
-        C,
-        HxW,
-        group,
-        static_cast<scalar_t>(eps),
-        Y,
-        mean,
-        rstd);
-  });
+  switch (X.suggest_memory_format()) {
+    case at::MemoryFormat::Contiguous: {
+      AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GroupNormKernelImpl", [&]() {
+        GroupNormKernelImplInternal<scalar_t>(
+            X, gamma, beta, N, C, HxW, group, static_cast<scalar_t>(eps), Y, mean, rstd);
+      });
+      break;
+    }
+    case at::MemoryFormat::ChannelsLast: {
+      AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GroupNormKernelImpl", [&]() {
+        GroupNormKernelImplChannelsLastInternal<scalar_t>(
+            X, gamma, beta, N, C, HxW, group, static_cast<scalar_t>(eps), Y, mean, rstd);
+      });
+      break;
+    }
+    default:
+      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+  }
 }
 
 template <typename T>
diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp
index 3a60d19..5533780 100644
--- a/aten/src/ATen/native/group_norm.cpp
+++ b/aten/src/ATen/native/group_norm.cpp
@@ -31,7 +31,10 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm(
   const Tensor& gamma = *gamma_maybe_owned;
   const Tensor& beta = c10::value_or_else(beta_opt, [] { return Tensor(); });
 
-  TORCH_CHECK(X.is_contiguous());
+  auto memory_format = X.device().is_cpu() ?
+      X.suggest_memory_format() : at::MemoryFormat::Contiguous;
+
+  TORCH_CHECK(X.is_contiguous(memory_format));
 
   Tensor Y = at::native::empty_like(
       X,
@@ -39,7 +42,7 @@
       c10::nullopt /* layout */,
       c10::nullopt /* device */,
       c10::nullopt /* pin_memory */,
-      LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+      memory_format);
   Tensor mean = at::empty({N, group}, X.options());
   Tensor rstd = at::empty({N, group}, X.options());
   GroupNormKernel(
@@ -73,7 +76,7 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm_backward(
         c10::nullopt /* layout */,
         c10::nullopt /* device */,
         c10::nullopt /* pin_memory */,
-        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+        at::MemoryFormat::Contiguous);
   }
   if (grad_input_mask[1]) {
     dgamma = at::native::empty_like(
@@ -82,7 +85,7 @@
         c10::nullopt /* layout */,
         c10::nullopt /* device */,
         c10::nullopt /* pin_memory */,
-        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+        at::MemoryFormat::Contiguous);
   }
   if (grad_input_mask[2]) {
     dbeta = at::native::empty_like(
@@ -91,7 +94,7 @@
         c10::nullopt /* layout */,
         c10::nullopt /* device */,
         c10::nullopt /* pin_memory */,
-        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+        at::MemoryFormat::Contiguous);
   }
   GroupNormBackwardKernel(
       X.device().type(),
@@ -153,7 +156,9 @@ Tensor group_norm(
       c10::multiply_integers(input_shape.cbegin() + 2, input_shape.cend());
 
   const Tensor kEmpty;
-  const auto& X = input.is_contiguous() ? input : input.contiguous();
+  auto memory_format = input.suggest_memory_format();
+  const auto& X = input.device().is_cpu() ?
+      input.contiguous(memory_format) : input.contiguous();
   const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty;
   const auto& beta = bias.defined() ? bias.contiguous() : kEmpty;
   TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
diff --git a/test/test_nn.py b/test/test_nn.py
index bb109cf..f4691e6 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13026,6 +13026,40 @@ class TestNNDeviceType(NNTestCase):
         with torch.backends.cudnn.flags(enabled=False):
             self._test_module_empty_input(mod, inp)
 
+    @onlyCPU
+    @dtypes(torch.float, torch.double)
+    def test_groupnorm_nhwc(self, device, dtype):
+        def helper(self, size, groups):
+            channels = size[1]
+            input = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
+            input = input.contiguous(memory_format=torch.channels_last)
+            input.retain_grad()
+            grad = torch.randn(size, dtype=dtype, device=device)
+            grad = grad.contiguous(memory_format=torch.channels_last)
+            gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
+            gn.weight.data.uniform_()
+            gn.bias.data.uniform_()
+
+            ref_input = input.detach().clone().contiguous().requires_grad_(True)
+            ref_grad = grad.detach().clone().contiguous()
+            ref_gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
+            ref_gn.load_state_dict(gn.state_dict())
+
+            out = gn(input)
+            out.backward(grad)
+            ref_out = ref_gn(ref_input)
+            ref_out.backward(ref_grad)
+
+            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+            self.assertTrue(ref_out.is_contiguous())
+            self.assertEqual(out, ref_out)
+            self.assertEqual(gn.weight.grad, ref_gn.weight.grad)
+            self.assertEqual(gn.bias.grad, ref_gn.bias.grad)
+            self.assertEqual(input.grad, ref_input.grad)
+
+        helper(self, (4, 8, 10, 10), 4)
+        helper(self, (2, 30, 9, 9), 3)
+
     @onlyOnCPUAndCUDA
     def test_GroupNorm_numeric(self, device):
         def group_norm_ref(X, gamma, beta, groups, channels, eps):
-- 
2.7.4
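
Editor's note (not part of the patch): a minimal usage sketch of the channels-last CPU path this change enables, mirroring the added test_groupnorm_nhwc test. It assumes a CPU build of PyTorch with this patch applied; sizes and group count are illustrative only.

    # Exercise GroupNorm on a channels-last (NHWC) CPU tensor.
    import torch
    import torch.nn as nn

    # NCHW shape (4, 8, 10, 10), stored in channels-last memory layout.
    x = torch.randn(4, 8, 10, 10).contiguous(memory_format=torch.channels_last)
    gn = nn.GroupNorm(num_groups=4, num_channels=8)

    y = gn(x)

    # With this patch, the CPU kernel dispatches on suggest_memory_format(),
    # so the output keeps the channels-last layout instead of being made contiguous.
    assert y.is_contiguous(memory_format=torch.channels_last)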