From c5ed31e4a7550bfe5a4893b3803ee7fdf1b31f53 Mon Sep 17 00:00:00 2001
From: mingfeima
Date: Sun, 29 Aug 2021 18:35:37 -0700
Subject: [PATCH] add channel last support for MaxUnpool2d (#49984)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49984

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D26007051

Pulled By: VitalyFedyunin

fbshipit-source-id: 6c54751ade4092e03c1651aaa60380f7d6e92f6b
---
 aten/src/ATen/native/MaxUnpooling.cpp        | 434 ++++-----------------------
 aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp | 385 ++++++++++++++++++++++++
 aten/src/ATen/native/cpu/MaxUnpoolKernel.h   |  16 +
 test/test_nn.py                              |  31 ++
 tools/build_variables.bzl                    |   1 +
 5 files changed, 486 insertions(+), 381 deletions(-)
 create mode 100644 aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp
 create mode 100644 aten/src/ATen/native/cpu/MaxUnpoolKernel.h

diff --git a/aten/src/ATen/native/MaxUnpooling.cpp b/aten/src/ATen/native/MaxUnpooling.cpp
index b3c0194..9987408 100644
--- a/aten/src/ATen/native/MaxUnpooling.cpp
+++ b/aten/src/ATen/native/MaxUnpooling.cpp
@@ -1,90 +1,17 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <ATen/Parallel.h>
-#include <c10/util/Exception.h>
+#include <ATen/native/cpu/MaxUnpoolKernel.h>
 
 namespace at {
 namespace native {
 
-template <typename scalar_t>
-Tensor max_unpooling2d_forward_out_cpu_frame(
-    Tensor& output,
-    const Tensor& input,
-    const Tensor& indices,
-    int64_t oheight,
-    int64_t owidth) {
-  int64_t numBatch = 1;
-  int64_t dimc = 0;
-  int64_t dimh = 1;
-  int64_t dimw = 2;
-  if (input.ndimension() == 4) {
-    numBatch = input.size(0);
-    dimc++;
-    dimh++;
-    dimw++;
-  }
-  int64_t numChannels = input.size(dimc);
-  int64_t inputHeight = input.size(dimh);
-  int64_t inputWidth = input.size(dimw);
-
-  auto* rawInput = input.data_ptr<scalar_t>();
-  auto* rawIndices = indices.data_ptr<int64_t>();
-  auto* rawOutput = output.data_ptr<scalar_t>();
-
-  at::internal::lazy_init_num_threads();
-
-  for (int64_t n = 0; n < numBatch; n++) {
-    int64_t nOutputOffset = n * numChannels * owidth * oheight;
-    int64_t nInputOffset = n * numChannels * inputWidth * inputHeight;
-    int64_t k = 0;
-    bool has_error = false;
-    int64_t error_index = 0;
-#pragma omp parallel for private(k)
-    for (k = 0; k < numChannels; k++) {
-      int64_t finalOutputOffset = nOutputOffset + k * owidth * oheight;
-      int64_t finalInputOffset = nInputOffset + k * inputWidth * inputHeight;
-      scalar_t* output_p_k = rawOutput + finalOutputOffset;
-      scalar_t* input_p_k = rawInput + finalInputOffset;
-      int64_t* ind_p_k = rawIndices + finalInputOffset;
-
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      int64_t maxp;
-      for (int64_t i = 0; i < inputHeight; i++) {
-        for (int64_t j = 0; j < inputWidth; j++) {
-          maxp = ind_p_k[i * inputWidth + j];
-          if (maxp < 0 || maxp >= owidth * oheight) {
-#pragma omp critical
-            {
-              has_error = true;
-              error_index = maxp;
-            }
-          } else {
-            output_p_k[maxp] = input_p_k[i * inputWidth + j];
-          }
-        }
-      }
-    }
-    if (has_error) {
-      AT_ERROR(
-          "Found an invalid max index: ",
-          error_index,
-          " (output volumes are of size ",
-          oheight,
-          "x",
-          owidth);
-      (void)error_index;
-    }
-  }
-  return output;
-}
-
-Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_,
+Tensor& max_unpooling2d_forward_out_cpu(
+    const Tensor& self_,
     const Tensor& indices_,
     IntArrayRef output_size,
     Tensor& output) {
   auto oheight = output_size[0];
   auto owidth = output_size[1];
-  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
   TORCH_CHECK(
       indices_.scalar_type() == at::ScalarType::Long,
       "elements in indices should be type int64");
@@ -100,8 +27,9 @@ Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_,
   TORCH_CHECK(self_.numel() > 0, "Input must be non-empty");
 
-  auto self = self_.contiguous();
-  auto indices = indices_.contiguous();
+  auto memory_format = self_.suggest_memory_format();
+  auto self = self_.contiguous(memory_format);
+  auto indices = indices_.contiguous(memory_format);
 
   if (self.ndimension() == 3) {
     int64_t numChannels = self.size(0);
@@ -109,15 +37,11 @@ Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_,
   } else {
     int64_t numBatch = self.size(0);
     int64_t numChannels = self.size(1);
-    output.resize_({numBatch, numChannels, oheight, owidth});
+    output.resize_({numBatch, numChannels, oheight, owidth}, memory_format);
   }
   output.zero_();
 
-  AT_DISPATCH_FLOATING_TYPES(
-      self.scalar_type(), "max_unpooling2d_forward_out_cpu_frame", ([&] {
-        max_unpooling2d_forward_out_cpu_frame<scalar_t>(
-            output, self, indices, oheight, owidth);
-      }));
+  max_unpool2d_kernel(kCPU, output, self, indices);
 
   return output;
 };
@@ -130,87 +54,6 @@ Tensor max_unpooling2d_forward_cpu(
   return output;
 }
 
-template <typename scalar_t>
-Tensor max_unpooling3d_forward_out_cpu_frame(
-    Tensor& output,
-    const Tensor& input,
-    const Tensor& indices,
-    int64_t oT,
-    int64_t oH,
-    int64_t oW) {
-  int64_t nBatch = 1;
-  int64_t dimw = 3;
-  int64_t dimh = 2;
-  int64_t dimt = 1;
-
-  if (input.ndimension() == 5) {
-    nBatch = input.size(0);
-    dimw++;
-    dimh++;
-    dimt++;
-  }
-
-  int64_t nSlices = input.size(dimt - 1);
-  int64_t iT = input.size(dimt);
-  int64_t iH = input.size(dimh);
-  int64_t iW = input.size(dimw);
-
-  scalar_t* input_data = input.data_ptr<scalar_t>();
-  scalar_t* output_data = output.data_ptr<scalar_t>();
-  int64_t* indices_data = indices.data_ptr<int64_t>();
-
-  at::internal::lazy_init_num_threads();
-
-  for (int64_t p = 0; p < nBatch; p++) {
-    int64_t inputOffset = p * nSlices * iT * iW * iH;
-    int64_t outputOffset = p * nSlices * oT * oW * oH;
-    int64_t k = 0;
-    bool has_error = false;
-    int error_index = 0;
-#pragma omp parallel for private(k)
-    for (k = 0; k < nSlices; k++) {
-      int64_t finalInputOffset = inputOffset + k * iT * iW * iH;
-      int64_t finalOutputOffset = outputOffset + k * oT * oW * oH;
-
-      scalar_t* output_p_k = output_data + finalOutputOffset;
-      scalar_t* input_p_k = input_data + finalInputOffset;
-      int64_t* ind_p_k = indices_data + finalInputOffset;
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      int maxp;
-      for (int64_t t = 0; t < iT; t++) {
-        for (int64_t i = 0; i < iH; i++) {
-          for (int64_t j = 0; j < iW; j++) {
-            int64_t index = t * iH * iW + i * iW + j;
-            maxp = ind_p_k[index];
-            if (maxp < 0 || maxp >= oT * oW * oH) {
-#pragma omp critical
-              {
-                has_error = true;
-                error_index = maxp;
-              }
-            } else {
-              output_p_k[maxp] = input_p_k[index];
-            }
-          }
-        }
-      }
-      if (has_error) {
-        AT_ERROR(
-            "found an invalid max index ",
-            error_index,
-            " (output volumes are of size ",
-            oT,
-            "x",
-            oH,
-            "x",
-            oW);
-        (void)error_index;
-      }
-    }
-  }
-  return output;
-}
-
 static void max_unpooling3d_shape_check(
     const Tensor& input,
     const Tensor& gradOutput,
@@ -310,16 +153,7 @@ Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_,
   }
   output.zero_();
 
-  AT_DISPATCH_FLOATING_TYPES(
-      self.scalar_type(), "max_unpooling3d_forward_out_cpu_frame", ([&] {
-        max_unpooling3d_forward_out_cpu_frame<scalar_t>(
-            output,
-            self,
-            indices,
-            oT,
-            oH,
-            oW);
-      }));
+  max_unpool3d_kernel(kCPU, output, self, indices);
 
   return output;
 }
@@ -335,59 +169,6 @@ Tensor max_unpooling3d_forward_cpu(
   return output;
 }
 
-template <typename scalar_t>
-static void max_unpooling2d_backward_out_cpu_frame(
-    scalar_t* gradInput_p,
-    scalar_t* gradOutput_p,
-    int64_t* ind_p,
-    int64_t nslices,
-    int64_t iheight,
-    int64_t iwidth,
-    int64_t oheight,
-    int64_t owidth) {
-  bool has_error = false;
-  int64_t error_index = 0;
-  int64_t k = 0;
-
-  at::internal::lazy_init_num_threads();
-#pragma omp parallel for private(k)
-  for (k = 0; k < nslices; k++) {
-    scalar_t* gradInput_p_k = gradInput_p + k * iwidth * iheight;
-    scalar_t* gradOutput_p_k = gradOutput_p + k * owidth * oheight;
-    int64_t* ind_p_k = ind_p + k * iwidth * iheight;
-
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    int64_t i, j;
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    int64_t maxp;
-
-    for (i = 0; i < iheight; i++) {
-      for (j = 0; j < iwidth; j++) {
-        maxp = ind_p_k[i * iwidth + j]; /* retrieve position of max */
-        if (maxp < 0 || maxp >= owidth * oheight) {
-#pragma omp critical
-          {
-            has_error = true;
-            error_index = maxp;
-          }
-        }
-        gradInput_p_k[i * iwidth + j] =
-            gradOutput_p_k[maxp]; /* update gradient */
-      }
-    }
-  }
-  if (has_error) {
-    AT_ERROR(
-        "invalid max index ",
-        error_index,
-        ", owidth= ",
-        owidth,
-        ", oheight= ",
-        oheight);
-    (void)error_index;
-  }
-}
-
 Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
     const Tensor& self,
     const Tensor& indices_,
     IntArrayRef output_size,
     Tensor& grad_input) {
@@ -396,42 +177,24 @@ Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
   TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
   int64_t oheight = output_size[0];
   int64_t owidth = output_size[1];
-  int dimw = 2;
-  int dimh = 1;
-  int nbatch = 1;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int nslices;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iheight;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iwidth;
+  int64_t ndim = self.ndimension();
+  int64_t dimh = ndim == 3 ? 1 : 2;
+  int64_t dimw = ndim == 3 ? 2 : 3;
+
   TORCH_CHECK(
       indices_.scalar_type() == at::ScalarType::Long,
       "elements in indices should be type int64");
   TORCH_CHECK(
       self.sizes() == indices_.sizes(), "Input shape must match indices shape");
-  TORCH_CHECK(output_size.size() == 2, "Output size must be 2");
 
-  /* get contiguous gradOutput and indices */
-  auto grad_output = grad_output_.contiguous();
-  auto indices = indices_.contiguous();
+  auto memory_format = self.suggest_memory_format();
+  auto grad_output = grad_output_.contiguous(memory_format);
+  auto indices = indices_.contiguous(memory_format);
 
-  /* resize */
-  grad_input.resize_as_(self);
+  grad_input.resize_(self.sizes(), memory_format);
   grad_input.zero_();
 
-  if (self.ndimension() == 4) {
-    nbatch = self.size(0);
-    dimw++;
-    dimh++;
-  }
-
-  /* sizes */
-  nslices = self.size(dimh - 1);
-  iheight = self.size(dimh);
-  iwidth = self.size(dimw);
-
   if (owidth != grad_output.size(dimw) || oheight != grad_output.size(dimh)) {
     AT_ERROR(
         "Inconsistent gradOutput size. output height = ",
output height = ", @@ -443,23 +206,8 @@ Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_, "x", grad_output.size(dimw)); } - AT_DISPATCH_FLOATING_TYPES( - self.scalar_type(), "max_unpooling2d_backward_out_cpu_frame", ([&] { - int p; - for (p = 0; p < nbatch; p++) { - auto inputOffset = p * nslices * iheight * iwidth; - auto outputOffset = p * nslices * oheight * owidth; - max_unpooling2d_backward_out_cpu_frame( - grad_input.data_ptr() + inputOffset, - grad_output.data_ptr() + outputOffset, - indices.data_ptr() + inputOffset, - nslices, - iheight, - iwidth, - oheight, - owidth); - } - })); + + max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices); return grad_input; } @@ -468,72 +216,14 @@ Tensor max_unpooling2d_backward_cpu( const Tensor& self, const Tensor& indices, IntArrayRef output_size) { - auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - at::native::max_unpooling2d_backward_out_cpu( + auto grad_input = at::empty({0}, self.options()); + max_unpooling2d_backward_out_cpu( grad_output, self, indices, output_size, grad_input); return grad_input; } -template -static void max_unpooling3d_backward_out_cpu_frame( - scalar_t* gradInput_p, - scalar_t* gradOutput_p, - int64_t* ind_p, - int64_t nslices, - int64_t iT, - int64_t iH, - int64_t iW, - int64_t oT, - int64_t oH, - int64_t oW) { - int64_t k = 0; - bool has_error = false; - int error_index = 0; - - at::internal::lazy_init_num_threads(); - -#pragma omp parallel for private(k) - for (k = 0; k < nslices; k++) { - scalar_t* gradInput_p_k = gradInput_p + k * iT * iH * iW; - scalar_t* gradOutput_p_k = gradOutput_p + k * oT * oH * oW; - int64_t* ind_p_k = ind_p + k * iT * iH * iW; - - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t t, i, j, index; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t maxp; - for (t = 0; t < iT; t++) { - for (i = 0; i < iH; i++) { - for (j = 0; j < iW; j++) { - index = t * iH * iW + i * iW + j; - maxp = ind_p_k[index]; /* retrieve position of max */ - if (maxp < 0 || maxp >= oT * oH * oW) { -#pragma omp critical - { - has_error = true; - error_index = maxp; - } - } - gradInput_p_k[index] = gradOutput_p_k[maxp]; /* update gradient */ - } - } - } - } - if (has_error) { - AT_ERROR( - "invalid max index ", - error_index, - ", oT= ", - oT, - ", oW= ", - oW, - ",oH= ", - oH); - (void)error_index; - } -} - -Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_, +Tensor& max_unpooling3d_backward_out_cpu( + const Tensor& grad_output_, const Tensor& self, const Tensor& indices_, IntArrayRef output_size, @@ -541,26 +231,17 @@ Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_, IntArrayRef padding, Tensor& grad_input) { TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); - auto oT = output_size[0]; - auto oH = output_size[1]; - auto oW = output_size[2]; - int dimw = 3; - int dimh = 2; - int dimt = 1; - int nbatch = 1; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int nslices; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int iT; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int iH; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int iW; + int64_t oT = output_size[0]; + int64_t oH = output_size[1]; + int64_t oW = output_size[2]; + int64_t ndim = self.ndimension(); + int64_t dimt = ndim == 4 ? 1 : 2; + int64_t dimh = ndim == 4 ? 2 : 3; + int64_t dimw = ndim == 4 ? 
@@ -541,26 +231,17 @@ Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_,
     IntArrayRef padding,
     Tensor& grad_input) {
   TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
-  auto oT = output_size[0];
-  auto oH = output_size[1];
-  auto oW = output_size[2];
-  int dimw = 3;
-  int dimh = 2;
-  int dimt = 1;
-  int nbatch = 1;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int nslices;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iT;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iH;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iW;
+  int64_t oT = output_size[0];
+  int64_t oH = output_size[1];
+  int64_t oW = output_size[2];
+  int64_t ndim = self.ndimension();
+  int64_t dimt = ndim == 4 ? 1 : 2;
+  int64_t dimh = ndim == 4 ? 2 : 3;
+  int64_t dimw = ndim == 4 ? 3 : 4;
 
   max_unpooling3d_shape_check(
       self, grad_output_, indices_, output_size, stride, padding);
-  // TODO (from THNN): check gradOutput shape
 
   /* get contiguous gradOutput */
   auto grad_output = grad_output_.contiguous();
   auto indices = indices_.contiguous();
@@ -568,39 +249,24 @@ Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_,
   /* resize */
   grad_input.resize_as_(self);
   grad_input.zero_();
-  if (self.ndimension() == 5) {
-    nbatch = self.size(0);
-    dimt++;
-    dimw++;
-    dimh++;
+
+  if (oW != grad_output.size(dimw) || oH != grad_output.size(dimh) || oT != grad_output.size(dimt)) {
+    AT_ERROR(
+        "Inconsistent gradOutput size. output depth = ",
+        oT,
+        ", output height = ",
+        oH,
+        ", output width = ",
+        oW,
+        ", gradOutput: ",
+        grad_output.size(dimt),
+        "x",
+        grad_output.size(dimh),
+        "x",
+        grad_output.size(dimw));
   }
 
-  /* sizes */
-  nslices = self.size(dimt - 1);
-  iT = self.size(dimt);
-  iH = self.size(dimh);
-  iW = self.size(dimw);
-
-  /* backprop */
-  AT_DISPATCH_FLOATING_TYPES(
-      self.scalar_type(), "max_unpooling3d_backward_out_cpu_frame", ([&] {
-        int p;
-        for (p = 0; p < nbatch; p++) {
-          int inputOffset = p * nslices * iT * iH * iW;
-          int outputOffset = p * nslices * oT * oH * oW;
-          max_unpooling3d_backward_out_cpu_frame<scalar_t>(
-              grad_input.data_ptr<scalar_t>() + inputOffset,
-              grad_output.data_ptr<scalar_t>() + outputOffset,
-              indices.data_ptr<int64_t>() + inputOffset,
-              nslices,
-              iT,
-              iH,
-              iW,
-              oT,
-              oH,
-              oW);
-        }
-      }));
+  max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
 
   return grad_input;
 }
@@ -611,10 +277,16 @@ Tensor max_unpooling3d_backward_cpu(
     IntArrayRef output_size,
     IntArrayRef stride,
     IntArrayRef padding) {
-  auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  auto grad_input = at::empty({0}, self.options());
   at::native::max_unpooling3d_backward_out_cpu(
       grad_output, self, indices, output_size, stride, padding, grad_input);
   return grad_input;
 }
+
+DEFINE_DISPATCH(max_unpool2d_kernel);
+DEFINE_DISPATCH(max_unpool2d_backward_kernel);
+DEFINE_DISPATCH(max_unpool3d_kernel);
+DEFINE_DISPATCH(max_unpool3d_backward_kernel);
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp b/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp
new file mode 100644
index 0000000..5a7b031
--- /dev/null
+++ b/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp
@@ -0,0 +1,385 @@
+#include <ATen/native/cpu/MaxUnpoolKernel.h>
+
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+#include <ATen/native/cpu/utils.h>
+
+namespace at { namespace native {
+
+namespace {
+
+template <typename scalar_t, bool is_3d = false>
+void cpu_max_unpool(
+    Tensor& output_,
+    const Tensor& input,
+    const Tensor& indices) {
+  auto output = output_.contiguous();
+
+  auto input_data = input.data_ptr<scalar_t>();
+  auto indices_data = indices.data_ptr<int64_t>();
+  auto output_data = output.data_ptr<scalar_t>();
+
+  // NB: input tensor dimensions:
+  // MaxUnpool2d:
+  //    dim = 3: CHW
+  //    dim = 4: NCHW
+  // MaxUnpool3d:
+  //    dim = 4: CDHW
+  //    dim = 5: NCDHW
+
+  int64_t numel = input.numel();
+  int64_t ndim = input.ndimension();
+
+  // treat batch size and channels as one dimension
+  // and the feature map as another dimension
+  int64_t channels, output_depth, output_height, output_width;
+  if (is_3d) {
+    TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d: expect input to be 4d or 5d tensor.");
+    channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
+    output_depth = output.size(-3);
+    output_height = output.size(-2);
+    output_width = output.size(-1);
+  } else {
+    TORCH_CHECK(ndim == 3 || ndim == 4, "MaxUnpool2d: expect input to be 3d or 4d tensor.");
+    channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1);
+    output_depth = 1;
+    output_height = output.size(-2);
+    output_width = output.size(-1);
+  }
+  int64_t input_image_size = numel / channels;
+  int64_t output_image_size = output.numel() / channels;
+
+  bool has_error = false;
+  int64_t error_index = 0;
+
+  // parallel on dim N, C, D, H, W: [channels, input_image_size]
+  at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
+    int64_t c = 0;
+    int64_t ip = 0;
+    data_index_init(begin, c, channels, ip, input_image_size);
+
+    for (int64_t i = begin; i < end; i++) {
+      scalar_t* output_ptr = output_data + c * output_image_size;
+
+      int64_t maxp = indices_data[i];
+      if (maxp < 0 || maxp >= output_image_size) {
+        #pragma omp critical
+        {
+          has_error = true;
+          error_index = maxp;
+        }
+      } else {
+        output_ptr[maxp] = input_data[i];
+      }
+
+      // move on to next input index
+      data_index_step(c, channels, ip, input_image_size);
+    }
+  });
+
+  if (has_error) {
+    if (is_3d) {
+      AT_ERROR("Found an invalid max index: ", error_index,
+          " (output volumes are of size ", output_depth,
+          "x", output_height, "x", output_width);
+      (void)error_index;
+    } else {
+      AT_ERROR("Found an invalid max index: ", error_index,
+          " (output volumes are of size ", output_height,
+          "x", output_width);
+      (void)error_index;
+    }
+  }
+
+  if (!output_.is_contiguous()) {
+    output_.copy_(output);
+  }
+}
+
+template <typename scalar_t>
+void cpu_max_unpool_channels_last(
+    Tensor& output_,
+    const Tensor& input,
+    const Tensor& indices) {
+  TORCH_CHECK(input.ndimension() == 4,
+      "max_unpool2d with channels last format supports tensors with 4 dims");
+  auto memory_format = at::MemoryFormat::ChannelsLast;
+  auto output = output_.contiguous(memory_format);
+
+  auto input_data = input.data_ptr<scalar_t>();
+  auto indices_data = indices.data_ptr<int64_t>();
+  auto output_data = output.data_ptr<scalar_t>();
+
+  int64_t nbatch = input.size(0);
+  int64_t channels = input.size(1);
+  int64_t input_height = input.size(2);
+  int64_t input_width = input.size(3);
+  int64_t output_height = output.size(2);
+  int64_t output_width = output.size(3);
+  int64_t input_image_size = input_height * input_width;
+  int64_t output_image_size = output_height * output_width;
+
+  bool has_error = false;
+  int64_t error_index = 0;
+
+  // parallel on dim N, H, W
+  at::parallel_for(0, nbatch * input_image_size, 0, [&](int64_t begin, int64_t end) {
+    int64_t n = 0;
+    int64_t ip = 0;
+    data_index_init(begin, n, nbatch, ip, input_image_size);
+
+    for (int64_t i = begin; i < end; i++) {
+      scalar_t* input_ptr = input_data + i * channels;
+      int64_t* indices_ptr = indices_data + i * channels;
+      scalar_t* output_ptr = output_data + n * output_image_size * channels;
+
+      // can't do scatter on avx2 (only available on avx512)
+      for (int64_t c = 0; c < channels; c++) {
+        int64_t maxp = indices_ptr[c];
+        if (maxp < 0 || maxp >= output_image_size) {
+          #pragma omp critical
+          {
+            has_error = true;
+            error_index = maxp;
+          }
+        } else {
+          output_ptr[maxp * channels + c] = input_ptr[c];
+        }
+      }
+
+      // move on to next input index
+      data_index_step(n, nbatch, ip, input_image_size);
+    }
+  });
+
+  if (has_error) {
+    AT_ERROR("Found an invalid max index: ", error_index,
+        " (output volumes are of size ", output_height,
+        "x", output_width);
+    (void)error_index;
+  }
+
+  if (!output_.is_contiguous(memory_format)) {
+    output_.copy_(output);
+  }
+}
+
+template <typename scalar_t, bool is_3d = false>
+void cpu_max_unpool_backward(
+    Tensor& grad_input_,
+    const Tensor& grad_output,
+    const Tensor& indices) {
+  auto grad_input = grad_input_.contiguous();
+
+  auto grad_output_data = grad_output.data_ptr<scalar_t>();
+  auto indices_data = indices.data_ptr<int64_t>();
+  auto grad_input_data = grad_input.data_ptr<scalar_t>();
+
+  int64_t numel = grad_input.numel();
+  int64_t ndim = grad_output.ndimension();
+
+  // treat batch size and channels as one dimension
+  // and the feature map as another dimension
+  int64_t channels, output_depth, output_height, output_width;
+  if (is_3d) {
+    TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d_backward: expect grad_output to be 4d or 5d tensor.");
+    channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
+    output_depth = grad_output.size(-3);
+    output_height = grad_output.size(-2);
+    output_width = grad_output.size(-1);
+  } else {
+    TORCH_CHECK(ndim == 3 || ndim == 4, "MaxUnpool2d_backward: expect grad_output to be 3d or 4d tensor.");
+    channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
+    output_depth = 1;
+    output_height = grad_output.size(-2);
+    output_width = grad_output.size(-1);
+  }
+  int64_t input_image_size = numel / channels;
+  int64_t output_image_size = grad_output.numel() / channels;
+
+  bool has_error = false;
+  int64_t error_index = 0;
+
+  // parallel on dim N, C, D, H, W
+  at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
+    int64_t c = 0;
+    int64_t ip = 0;
+    data_index_init(begin, c, channels, ip, input_image_size);
+
+    for (int64_t i = begin; i < end; i++) {
+      scalar_t* grad_output_ptr = grad_output_data + c * output_image_size;
+
+      int64_t maxp = indices_data[i];
+      if (maxp < 0 || maxp >= output_image_size) {
+        #pragma omp critical
+        {
+          has_error = true;
+          error_index = maxp;
+        }
+      } else {
+        grad_input_data[i] = grad_output_ptr[maxp];
+      }
+
+      // move on to next input index
+      data_index_step(c, channels, ip, input_image_size);
+    }
+  });
+
+  if (has_error) {
+    if (is_3d) {
+      AT_ERROR("invalid max index ", error_index,
+          ", odepth= ", output_depth,
+          ", owidth= ", output_width,
+          ", oheight= ", output_height);
+      (void)error_index;
+    } else {
+      AT_ERROR("invalid max index ", error_index,
+          ", owidth= ", output_width,
+          ", oheight= ", output_height);
+      (void)error_index;
+    }
+  }
+
+  if (!grad_input_.is_contiguous()) {
+    grad_input_.copy_(grad_input);
+  }
+}
+
+template <typename scalar_t>
+void cpu_max_unpool_backward_channels_last(
+    Tensor& grad_input_,
+    const Tensor& grad_output,
+    const Tensor& indices) {
+  TORCH_CHECK(grad_output.ndimension() == 4,
+      "max_unpool2d backward with channels last format supports tensors with 4 dims.");
+  auto memory_format = at::MemoryFormat::ChannelsLast;
+  auto grad_input = grad_input_.contiguous(memory_format);
+
+  auto grad_input_data = grad_input.data_ptr<scalar_t>();
+  auto grad_output_data = grad_output.data_ptr<scalar_t>();
+  auto indices_data = indices.data_ptr<int64_t>();
+
+  int64_t nbatch = grad_input.size(0);
+  int64_t channels = grad_input.size(1);
+  int64_t input_height = grad_input.size(2);
+  int64_t input_width = grad_input.size(3);
+  int64_t output_height = grad_output.size(2);
+  int64_t output_width = grad_output.size(3);
+  int64_t input_image_size = input_height * input_width;
+  int64_t output_image_size = output_height * output_width;
+
+  bool has_error = false;
+  int64_t error_index = 0;
+
+  // parallel on dim N, H, W
+  at::parallel_for(0, nbatch * input_image_size, 0, [&](int64_t begin, int64_t end) {
+    int64_t n = 0;
+    int64_t ip = 0;
+    data_index_init(begin, n, nbatch, ip, input_image_size);
+
+    for (int64_t i = begin; i < end; i++) {
+      scalar_t* grad_output_ptr = grad_output_data + n * output_image_size * channels;
+      scalar_t* grad_input_ptr = grad_input_data + i * channels;
+      int64_t* indices_ptr = indices_data + i * channels;
+
+      for (int64_t c = 0; c < channels; c++) {
+        int64_t maxp = indices_ptr[c];
+        if (maxp < 0 || maxp >= output_image_size) {
+          #pragma omp critical
+          {
+            has_error = true;
+            error_index = maxp;
+          }
+        } else {
+          grad_input_ptr[c] = grad_output_ptr[maxp * channels + c];
+        }
+      }
+
+      // move on to next input index
+      data_index_step(n, nbatch, ip, input_image_size);
+    }
+  });
+
+  if (has_error) {
+    AT_ERROR("invalid max index ", error_index,
+        ", owidth= ", output_width,
+        ", oheight= ", output_height);
+    (void)error_index;
+  }
+
+  if (!grad_input_.is_contiguous(memory_format)) {
+    grad_input_.copy_(grad_input);
+  }
+}
+
+void max_unpool2d_kernel_impl(
+    Tensor& output,
+    const Tensor& input,
+    const Tensor& indices) {
+  switch(input.suggest_memory_format()) {
+    case at::MemoryFormat::Contiguous: {
+      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool2d", [&] {
+        cpu_max_unpool<scalar_t, /*is_3d*/false>(output, input, indices);
+      });
+      break;
+    }
+    case at::MemoryFormat::ChannelsLast: {
+      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool2d_channels_last", [&] {
+        cpu_max_unpool_channels_last<scalar_t>(output, input, indices);
+      });
+      break;
+    }
+    default:
+      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+  }
+}
+
+void max_unpool3d_kernel_impl(
+    Tensor& output,
+    const Tensor& input,
+    const Tensor& indices) {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool3d", [&] {
+    cpu_max_unpool<scalar_t, /*is_3d*/true>(output, input, indices);
+  });
+}
+
+void max_unpool2d_backward_kernel_impl(
+    Tensor& grad_input,
+    const Tensor& grad_output,
+    const Tensor& indices) {
+  switch(grad_output.suggest_memory_format()) {
+    case at::MemoryFormat::Contiguous: {
+      AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool2d_backward", [&] {
+        cpu_max_unpool_backward<scalar_t, /*is_3d*/false>(grad_input, grad_output, indices);
+      });
+      break;
+    }
+    case at::MemoryFormat::ChannelsLast: {
+      AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool2d_backward_channels_last", [&] {
+        cpu_max_unpool_backward_channels_last<scalar_t>(grad_input, grad_output, indices);
+      });
+      break;
+    }
+    default:
+      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+  }
+}
+
+void max_unpool3d_backward_kernel_impl(
+    Tensor& grad_input,
+    const Tensor& grad_output,
+    const Tensor& indices) {
+  AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool3d_backward", [&] {
+    cpu_max_unpool_backward<scalar_t, /*is_3d*/true>(grad_input, grad_output, indices);
+  });
+}
+
+} // anonymous namespace
+
+REGISTER_DISPATCH(max_unpool2d_kernel, &max_unpool2d_kernel_impl);
+REGISTER_DISPATCH(max_unpool2d_backward_kernel, &max_unpool2d_backward_kernel_impl);
+REGISTER_DISPATCH(max_unpool3d_kernel, &max_unpool3d_kernel_impl);
+REGISTER_DISPATCH(max_unpool3d_backward_kernel, &max_unpool3d_backward_kernel_impl);
+
+}} // at::native
diff --git a/aten/src/ATen/native/cpu/MaxUnpoolKernel.h b/aten/src/ATen/native/cpu/MaxUnpoolKernel.h
new file mode 100644
index 0000000..00fbeb6
--- /dev/null
+++ b/aten/src/ATen/native/cpu/MaxUnpoolKernel.h
@@ -0,0 +1,16 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/native/DispatchStub.h>
+
+#pragma once
+
+namespace at { namespace native {
+
+using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
+
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_backward_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_backward_kernel);
+
+}} // at::native
diff --git a/test/test_nn.py b/test/test_nn.py
index 4e01c94..7d26246 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -6186,6 +6186,37 @@ class TestNN(NNTestCase):
             else:
                 self.assertRaises(ValueError, lambda: mu(output_small, indices_small, (h, w)))
 
+    def test_max_unpool2d_nhwc_cpu(self):
+        input = torch.randn(2, 10, 9, 9).float().cpu()
+        input = input.contiguous(memory_format=torch.channels_last)
+        ref_input = input.clone().contiguous()
+
+        pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
+        ref_pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
+
+        out, ind = pool(input)
+        ref_out, ref_ind = ref_pool(ref_input)
+        out.requires_grad_()
+        ref_out.requires_grad_()
+
+        unpool = nn.MaxUnpool2d(3, stride=2).cpu()
+        ref_unpool = nn.MaxUnpool2d(3, stride=2).cpu()
+
+        upout = unpool(out, ind)
+        ref_upout = ref_unpool(ref_out, ref_ind)
+
+        grad = torch.randn(upout.size()).float().cpu()
+        grad = grad.contiguous(memory_format=torch.channels_last)
+        ref_grad = grad.clone().contiguous()
+
+        upout.backward(grad)
+        ref_upout.backward(ref_grad)
+
+        self.assertTrue(upout.is_contiguous(memory_format=torch.channels_last))
+        self.assertTrue(ref_upout.is_contiguous())
+        self.assertTrue(torch.allclose(upout, ref_upout))
+        self.assertTrue(torch.allclose(out.grad, ref_out.grad))
+
     def test_container_copy(self):
         class Model(nn.Module):
             def __init__(self):
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index b2a1016..34846b5 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -907,6 +907,7 @@ aten_native_source_codegen_list = [
     "aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp",
     "aten/src/ATen/native/cpu/MaxPooling.cpp",
     "aten/src/ATen/native/cpu/MaxPoolKernel.cpp",
+    "aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp",
     "aten/src/ATen/native/cpu/MultinomialKernel.cpp",
     "aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp",
     "aten/src/ATen/native/cpu/PowKernel.cpp",
-- 
2.7.4
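
Reviewer note (appended to this posting, not part of the patch): the new CPU
kernels replace the per-channel OpenMP loops with a single at::parallel_for
over a flattened [channels, input_image_size] range, where `channels` folds
batch and channel into one dimension; data_index_init/data_index_step recover
the (c, ip) pair from the linear index. Below is a minimal Python model of
that scheme, for intuition only — the function and variable names are
illustrative, not taken from the patch:

    import torch

    def unpool2d_reference(inp, indices, output_size):
        # inp, indices: NCHW tensors of the same shape, as the kernel expects
        n, c, ih, iw = inp.shape
        oh, ow = output_size
        channels = n * c                      # batch and channels as one dim
        input_image_size = ih * iw
        output_image_size = oh * ow
        flat_in = inp.reshape(channels, input_image_size)
        flat_ind = indices.reshape(channels, input_image_size)
        out = torch.zeros(channels, output_image_size, dtype=inp.dtype)
        # the kernel splits this flat (ci, ip) iteration space across threads
        for ci in range(channels):
            for ip in range(input_image_size):
                maxp = int(flat_ind[ci, ip])
                if maxp < 0 or maxp >= output_image_size:
                    raise RuntimeError(f"Found an invalid max index: {maxp}")
                # scatter the input value into this channel's output plane
                out[ci, maxp] = flat_in[ci, ip]
        return out.reshape(n, c, oh, ow)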
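
Reviewer note: from the Python side, the visible effect of the patch is that
CPU MaxUnpool2d now keeps a channels-last input in channels-last layout
instead of producing a contiguous result. A minimal sketch along the lines of
the new test_max_unpool2d_nhwc_cpu:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 10, 9, 9).contiguous(memory_format=torch.channels_last)
    pool = nn.MaxPool2d(3, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(3, stride=2)

    out, ind = pool(x)      # MaxPool2d already preserves channels last on CPU
    y = unpool(out, ind)    # with this patch, MaxUnpool2d does too
    assert y.is_contiguous(memory_format=torch.channels_last)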