#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
-#include <ATen/Parallel.h>
-#include <tuple>
+#include <ATen/native/cpu/MaxUnpoolKernel.h>
namespace at {
namespace native {
-template <typename scalar_t>
-Tensor max_unpooling2d_forward_out_cpu_frame(
- Tensor& output,
- const Tensor& input,
- const Tensor& indices,
- int64_t oheight,
- int64_t owidth) {
- int64_t numBatch = 1;
- int64_t dimc = 0;
- int64_t dimh = 1;
- int64_t dimw = 2;
- if (input.ndimension() == 4) {
- numBatch = input.size(0);
- dimc++;
- dimh++;
- dimw++;
- }
- int64_t numChannels = input.size(dimc);
- int64_t inputHeight = input.size(dimh);
- int64_t inputWidth = input.size(dimw);
-
- auto* rawInput = input.data_ptr<scalar_t>();
- auto* rawIndices = indices.data_ptr<int64_t>();
- auto* rawOutput = output.data_ptr<scalar_t>();
-
- at::internal::lazy_init_num_threads();
-
- for (int64_t n = 0; n < numBatch; n++) {
- int64_t nOutputOffset = n * numChannels * owidth * oheight;
- int64_t nInputOffset = n * numChannels * inputWidth * inputHeight;
- int64_t k = 0;
- bool has_error = false;
- int64_t error_index = 0;
-#pragma omp parallel for private(k)
- for (k = 0; k < numChannels; k++) {
- int64_t finalOutputOffset = nOutputOffset + k * owidth * oheight;
- int64_t finalInputOffset = nInputOffset + k * inputWidth * inputHeight;
- scalar_t* output_p_k = rawOutput + finalOutputOffset;
- scalar_t* input_p_k = rawInput + finalInputOffset;
- int64_t* ind_p_k = rawIndices + finalInputOffset;
-
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int64_t maxp;
- for (int64_t i = 0; i < inputHeight; i++) {
- for (int64_t j = 0; j < inputWidth; j++) {
- maxp = ind_p_k[i * inputWidth + j];
- if (maxp < 0 || maxp >= owidth * oheight) {
-#pragma omp critical
- {
- has_error = true;
- error_index = maxp;
- }
- } else {
- output_p_k[maxp] = input_p_k[i * inputWidth + j];
- }
- }
- }
- }
- if (has_error) {
- AT_ERROR(
- "Found an invalid max index: ",
- error_index,
- " (output volumes are of size ",
- oheight,
- "x",
- owidth);
- (void)error_index;
- }
- }
- return output;
-}
-
-Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_,
+Tensor& max_unpooling2d_forward_out_cpu(
+ const Tensor& self_,
const Tensor& indices_,
IntArrayRef output_size,
Tensor& output) {
auto oheight = output_size[0];
auto owidth = output_size[1];
- TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
TORCH_CHECK(
indices_.scalar_type() == at::ScalarType::Long,
"elements in indices should be type int64");
TORCH_CHECK(self_.numel() > 0, "Input must be non-empty");
- auto self = self_.contiguous();
- auto indices = indices_.contiguous();
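+ // follow the input's suggested memory format so the channels-last
+ // kernel can be selected when the input is NHWC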
+ auto memory_format = self_.suggest_memory_format();
+ auto self = self_.contiguous(memory_format);
+ auto indices = indices_.contiguous(memory_format);
if (self.ndimension() == 3) {
int64_t numChannels = self.size(0);
output.resize_({numChannels, oheight, owidth});
} else {
int64_t numBatch = self.size(0);
int64_t numChannels = self.size(1);
- output.resize_({numBatch, numChannels, oheight, owidth});
+ output.resize_({numBatch, numChannels, oheight, owidth}, memory_format);
}
output.zero_();
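+ // unpooling scatters input values into the zeroed output; locations not
+ // referenced by any index stay zero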
- AT_DISPATCH_FLOATING_TYPES(
- self.scalar_type(), "max_unpooling2d_forward_out_cpu_frame", ([&] {
- max_unpooling2d_forward_out_cpu_frame<scalar_t>(
- output, self, indices, oheight, owidth);
- }));
+ max_unpool2d_kernel(kCPU, output, self, indices);
return output;
};
return output;
}
-template <typename scalar_t>
-Tensor max_unpooling3d_forward_out_cpu_frame(
- Tensor& output,
- const Tensor& input,
- const Tensor& indices,
- int64_t oT,
- int64_t oH,
- int64_t oW) {
- int64_t nBatch = 1;
- int64_t dimw = 3;
- int64_t dimh = 2;
- int64_t dimt = 1;
-
- if (input.ndimension() == 5) {
- nBatch = input.size(0);
- dimw++;
- dimh++;
- dimt++;
- }
-
- int64_t nSlices = input.size(dimt - 1);
- int64_t iT = input.size(dimt);
- int64_t iH = input.size(dimh);
- int64_t iW = input.size(dimw);
-
- scalar_t* input_data = input.data_ptr<scalar_t>();
- scalar_t* output_data = output.data_ptr<scalar_t>();
- int64_t* indices_data = indices.data_ptr<int64_t>();
-
- at::internal::lazy_init_num_threads();
-
- for (int64_t p = 0; p < nBatch; p++) {
- int64_t inputOffset = p * nSlices * iT * iW * iH;
- int64_t outputOffset = p * nSlices * oT * oW * oH;
- int64_t k = 0;
- bool has_error = false;
- int error_index = 0;
-#pragma omp parallel for private(k)
- for (k = 0; k < nSlices; k++) {
- int64_t finalInputOffset = inputOffset + k * iT * iW * iH;
- int64_t finalOutputOffset = outputOffset + k * oT * oW * oH;
-
- scalar_t* output_p_k = output_data + finalOutputOffset;
- scalar_t* input_p_k = input_data + finalInputOffset;
- int64_t* ind_p_k = indices_data + finalInputOffset;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int maxp;
- for (int64_t t = 0; t < iT; t++) {
- for (int64_t i = 0; i < iH; i++) {
- for (int64_t j = 0; j < iW; j++) {
- int64_t index = t * iH * iW + i * iW + j;
- maxp = ind_p_k[index];
- if (maxp < 0 || maxp >= oT * oW * oH) {
-#pragma omp critical
- {
- has_error = true;
- error_index = maxp;
- }
- } else {
- output_p_k[maxp] = input_p_k[index];
- }
- }
- }
- }
- if (has_error) {
- AT_ERROR(
- "found an invalid max index ",
- error_index,
- " (output volumes are of size ",
- oT,
- "x",
- oH,
- "x",
- oW);
- (void)error_index;
- }
- }
- }
- return output;
-}
-
static void max_unpooling3d_shape_check(
const Tensor& input,
const Tensor& gradOutput,
}
output.zero_();
- AT_DISPATCH_FLOATING_TYPES(
- self.scalar_type(), "max_unpooling3d_forward_out_cpu_frame", ([&] {
- max_unpooling3d_forward_out_cpu_frame<scalar_t>(
- output,
- self,
- indices,
- oT,
- oH,
- oW);
- }));
+ max_unpool3d_kernel(kCPU, output, self, indices);
return output;
}
return output;
}
-template <typename scalar_t>
-static void max_unpooling2d_backward_out_cpu_frame(
- scalar_t* gradInput_p,
- scalar_t* gradOutput_p,
- int64_t* ind_p,
- int64_t nslices,
- int64_t iheight,
- int64_t iwidth,
- int64_t oheight,
- int64_t owidth) {
- bool has_error = false;
- int64_t error_index = 0;
- int64_t k = 0;
-
- at::internal::lazy_init_num_threads();
-#pragma omp parallel for private(k)
- for (k = 0; k < nslices; k++) {
- scalar_t* gradInput_p_k = gradInput_p + k * iwidth * iheight;
- scalar_t* gradOutput_p_k = gradOutput_p + k * owidth * oheight;
- int64_t* ind_p_k = ind_p + k * iwidth * iheight;
-
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int64_t i, j;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int64_t maxp;
-
- for (i = 0; i < iheight; i++) {
- for (j = 0; j < iwidth; j++) {
- maxp = ind_p_k[i * iwidth + j]; /* retrieve position of max */
- if (maxp < 0 || maxp >= owidth * oheight) {
-#pragma omp critical
- {
- has_error = true;
- error_index = maxp;
- }
- }
- gradInput_p_k[i * iwidth + j] =
- gradOutput_p_k[maxp]; /* update gradient */
- }
- }
- }
- if (has_error) {
- AT_ERROR(
- "invalid max index ",
- error_index,
- ", owidth= ",
- owidth,
- ", oheight= ",
- oheight);
- (void)error_index;
- }
-}
-
-Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
+Tensor& max_unpooling2d_backward_out_cpu(
+ const Tensor& grad_output_,
const Tensor& self,
const Tensor& indices_,
TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
int64_t oheight = output_size[0];
int64_t owidth = output_size[1];
- int dimw = 2;
- int dimh = 1;
- int nbatch = 1;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int nslices;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int iheight;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int iwidth;
+ int64_t ndim = self.ndimension();
+ int64_t dimh = ndim == 3 ? 1 : 2;
+ int64_t dimw = ndim == 3 ? 2 : 3;
+
TORCH_CHECK(
indices_.scalar_type() == at::ScalarType::Long,
"elements in indices should be type int64");
TORCH_CHECK(
self.sizes() == indices_.sizes(), "Input shape must match indices shape");
-
TORCH_CHECK(output_size.size() == 2, "Output size must be 2");
- /* get contiguous gradOutput and indices */
- auto grad_output = grad_output_.contiguous();
- auto indices = indices_.contiguous();
+ auto memory_format = self.suggest_memory_format();
+ auto grad_output = grad_output_.contiguous(memory_format);
+ auto indices = indices_.contiguous(memory_format);
- /* resize */
- grad_input.resize_as_(self);
+ grad_input.resize_(self.sizes(), memory_format);
grad_input.zero_();
- if (self.ndimension() == 4) {
- nbatch = self.size(0);
- dimw++;
- dimh++;
- }
-
- /* sizes */
- nslices = self.size(dimh - 1);
- iheight = self.size(dimh);
- iwidth = self.size(dimw);
-
if (owidth != grad_output.size(dimw) || oheight != grad_output.size(dimh)) {
AT_ERROR(
"Inconsistent gradOutput size. output height = ",
"x",
grad_output.size(dimw));
}
- AT_DISPATCH_FLOATING_TYPES(
- self.scalar_type(), "max_unpooling2d_backward_out_cpu_frame", ([&] {
- int p;
- for (p = 0; p < nbatch; p++) {
- auto inputOffset = p * nslices * iheight * iwidth;
- auto outputOffset = p * nslices * oheight * owidth;
- max_unpooling2d_backward_out_cpu_frame<scalar_t>(
- grad_input.data_ptr<scalar_t>() + inputOffset,
- grad_output.data_ptr<scalar_t>() + outputOffset,
- indices.data_ptr<int64_t>() + inputOffset,
- nslices,
- iheight,
- iwidth,
- oheight,
- owidth);
- }
- }));
+
+ max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
return grad_input;
}
const Tensor& self,
const Tensor& indices,
IntArrayRef output_size) {
- auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- at::native::max_unpooling2d_backward_out_cpu(
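+ // start from an empty tensor; the out variant below resizes it to
+ // self's shape in the suggested memory format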
+ auto grad_input = at::empty({0}, self.options());
+ max_unpooling2d_backward_out_cpu(
grad_output, self, indices, output_size, grad_input);
return grad_input;
}
-template <typename scalar_t>
-static void max_unpooling3d_backward_out_cpu_frame(
- scalar_t* gradInput_p,
- scalar_t* gradOutput_p,
- int64_t* ind_p,
- int64_t nslices,
- int64_t iT,
- int64_t iH,
- int64_t iW,
- int64_t oT,
- int64_t oH,
- int64_t oW) {
- int64_t k = 0;
- bool has_error = false;
- int error_index = 0;
-
- at::internal::lazy_init_num_threads();
-
-#pragma omp parallel for private(k)
- for (k = 0; k < nslices; k++) {
- scalar_t* gradInput_p_k = gradInput_p + k * iT * iH * iW;
- scalar_t* gradOutput_p_k = gradOutput_p + k * oT * oH * oW;
- int64_t* ind_p_k = ind_p + k * iT * iH * iW;
-
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int64_t t, i, j, index;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int64_t maxp;
- for (t = 0; t < iT; t++) {
- for (i = 0; i < iH; i++) {
- for (j = 0; j < iW; j++) {
- index = t * iH * iW + i * iW + j;
- maxp = ind_p_k[index]; /* retrieve position of max */
- if (maxp < 0 || maxp >= oT * oH * oW) {
-#pragma omp critical
- {
- has_error = true;
- error_index = maxp;
- }
- }
- gradInput_p_k[index] = gradOutput_p_k[maxp]; /* update gradient */
- }
- }
- }
- }
- if (has_error) {
- AT_ERROR(
- "invalid max index ",
- error_index,
- ", oT= ",
- oT,
- ", oW= ",
- oW,
- ",oH= ",
- oH);
- (void)error_index;
- }
-}
-
-Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_,
+Tensor& max_unpooling3d_backward_out_cpu(
+ const Tensor& grad_output_,
const Tensor& self,
const Tensor& indices_,
IntArrayRef output_size,
IntArrayRef padding,
Tensor& grad_input) {
TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
- auto oT = output_size[0];
- auto oH = output_size[1];
- auto oW = output_size[2];
- int dimw = 3;
- int dimh = 2;
- int dimt = 1;
- int nbatch = 1;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int nslices;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int iT;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int iH;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int iW;
+ int64_t oT = output_size[0];
+ int64_t oH = output_size[1];
+ int64_t oW = output_size[2];
+ int64_t ndim = self.ndimension();
+ int64_t dimt = ndim == 4 ? 1 : 2;
+ int64_t dimh = ndim == 4 ? 2 : 3;
+ int64_t dimw = ndim == 4 ? 3 : 4;
max_unpooling3d_shape_check(
self, grad_output_, indices_, output_size, stride, padding);
- // TODO (from THNN): check gradOutput shape
/* get contiguous gradOutput */
auto grad_output = grad_output_.contiguous();
auto indices = indices_.contiguous();
/* resize */
grad_input.resize_as_(self);
grad_input.zero_();
- if (self.ndimension() == 5) {
- nbatch = self.size(0);
- dimt++;
- dimw++;
- dimh++;
+
+ if (oW != grad_output.size(dimw) || oH != grad_output.size(dimh) || oT != grad_output.size(dimt)) {
+ AT_ERROR(
+ "Inconsistent gradOutput size. output depth = ",
+ oT,
+ ", output height = ",
+ oH,
+ ", output width = ",
+ oW,
+ ", gradOutput: ",
+ grad_output.size(dimt),
+ "x",
+ grad_output.size(dimh),
+ "x",
+ grad_output.size(dimw));
}
- /* sizes */
- nslices = self.size(dimt - 1);
- iT = self.size(dimt);
- iH = self.size(dimh);
- iW = self.size(dimw);
-
- /* backprop */
- AT_DISPATCH_FLOATING_TYPES(
- self.scalar_type(), "max_unpooling3d_backward_out_cpu_frame", ([&] {
- int p;
- for (p = 0; p < nbatch; p++) {
- int inputOffset = p * nslices * iT * iH * iW;
- int outputOffset = p * nslices * oT * oH * oW;
- max_unpooling3d_backward_out_cpu_frame<scalar_t>(
- grad_input.data_ptr<scalar_t>() + inputOffset,
- grad_output.data_ptr<scalar_t>() + outputOffset,
- indices.data_ptr<int64_t>() + inputOffset,
- nslices,
- iT,
- iH,
- iW,
- oT,
- oH,
- oW);
- }
- }));
+ max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
return grad_input;
}
IntArrayRef output_size,
IntArrayRef stride,
IntArrayRef padding) {
- auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto grad_input = at::empty({0}, self.options());
at::native::max_unpooling3d_backward_out_cpu(
grad_output, self, indices, output_size, stride, padding, grad_input);
return grad_input;
}
+
+DEFINE_DISPATCH(max_unpool2d_kernel);
+DEFINE_DISPATCH(max_unpool2d_backward_kernel);
+DEFINE_DISPATCH(max_unpool3d_kernel);
+DEFINE_DISPATCH(max_unpool3d_backward_kernel);
+
} // namespace native
} // namespace at
--- /dev/null
+++ b/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp
+#include <ATen/native/cpu/MaxUnpoolKernel.h>
+
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+#include <ATen/native/Pool.h>
+#include <ATen/native/cpu/utils.h>
+
+namespace at { namespace native {
+
+namespace {
+
+template <typename scalar_t, bool is_3d = false>
+void cpu_max_unpool(
+ Tensor& output_,
+ const Tensor& input,
+ const Tensor& indices) {
+ auto output = output_.contiguous();
+
+ auto input_data = input.data_ptr<scalar_t>();
+ auto indices_data = indices.data_ptr<int64_t>();
+ auto output_data = output.data_ptr<scalar_t>();
+
+ // NB: input tensor dimensions:
+ // MaxUnpool2d:
+ // dim = 3: CHW
+ // dim = 4: NCHW
+ // MaxUnpool3d:
+ // dim = 4: CDHW
+ // dim = 5: NCDHW
+
+ int64_t numel = input.numel();
+ int64_t ndim = input.ndimension();
+
+ // treat batch size and channels as one dimension
+ // and the feature map as another dimension
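+ // e.g. an NCHW input of shape [N, C, H, W] is viewed as [N * C, H * W]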
+ int64_t channels, output_depth, output_height, output_width;
+ if (is_3d) {
+ TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d: expect input to be 4d or 5d tensor.");
+ channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
+ output_depth = output.size(-3);
+ output_height = output.size(-2);
+ output_width = output.size(-1);
+ } else {
+ TORCH_CHECK(ndim == 3 || ndim == 4, "MaxUnpool2d: expect input to be 3d or 4d tensor.");
+ channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1);
+ output_depth = 1;
+ output_height = output.size(-2);
+ output_width = output.size(-1);
+ }
+ int64_t input_image_size = numel / channels;
+ int64_t output_image_size = output.numel() / channels;
+
+ bool has_error = false;
+ int64_t error_index = 0;
+
+ // parallel on dim N, C, D, H, W: [channels, input_image_size]
+ at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
+ int64_t c = 0;
+ int64_t ip = 0;
+ data_index_init(begin, c, channels, ip, input_image_size);
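+ // (c, ip) now index into [channels, input_image_size]; data_index_step
+ // at the end of each iteration advances the pair with carry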
+
+ for (int64_t i = begin; i < end; i++) {
+ scalar_t* output_ptr = output_data + c * output_image_size;
+
+ int64_t maxp = indices_data[i];
+ if (maxp < 0 || maxp >= output_image_size) {
+ #pragma omp critical
+ {
+ has_error = true;
+ error_index = maxp;
+ }
+ } else {
+ output_ptr[maxp] = input_data[i];
+ }
+
+ // move on to next input index
+ data_index_step(c, channels, ip, input_image_size);
+ }
+ });
+
+ if (has_error) {
+ if (is_3d) {
+ AT_ERROR("Found an invalid max index: ", error_index,
+ " (output volumes are of size ", output_depth,
+ "x", output_height, "x", output_width);
+ (void)error_index;
+ } else {
+ AT_ERROR("Found an invalid max index: ", error_index,
+ " (output volumes are of size ", output_height,
+ "x", output_width);
+ (void)error_index;
+ }
+ }
+
+ if (!output_.is_contiguous()) {
+ output_.copy_(output);
+ }
+}
+
+template <typename scalar_t>
+void cpu_max_unpool_channels_last(
+ Tensor& output_,
+ const Tensor& input,
+ const Tensor& indices) {
+ TORCH_CHECK(input.ndimension() == 4,
+ "max_unpool2d with channels last format supports tensors with 4 dims");
+ auto memory_format = at::MemoryFormat::ChannelsLast;
+ auto output = output_.contiguous(memory_format);
+
+ auto input_data = input.data_ptr<scalar_t>();
+ auto indices_data = indices.data_ptr<int64_t>();
+ auto output_data = output.data_ptr<scalar_t>();
+
+ int64_t nbatch = input.size(0);
+ int64_t channels = input.size(1);
+ int64_t input_height = input.size(2);
+ int64_t input_width = input.size(3);
+ int64_t output_height = output.size(2);
+ int64_t output_width = output.size(3);
+ int64_t input_image_size = input_height * input_width;
+ int64_t output_image_size = output_height * output_width;
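+ // in channels last layout each spatial location holds `channels`
+ // consecutive values, so index maxp of channel c maps to offset
+ // maxp * channels + c in the output plane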
+
+ bool has_error = false;
+ int64_t error_index = 0;
+
+ // parallel on dim N, H, W
+ at::parallel_for(0, nbatch * input_image_size, 0, [&](int64_t begin, int64_t end) {
+ int64_t n = 0;
+ int64_t ip = 0;
+ data_index_init(begin, n, nbatch, ip, input_image_size);
+
+ for (int64_t i = begin; i < end; i++) {
+ scalar_t* input_ptr = input_data + i * channels;
+ int64_t* indices_ptr = indices_data + i * channels;
+ scalar_t* output_ptr = output_data + n * output_image_size * channels;
+
+ // no vectorized scatter on avx2 (scatter instructions require avx512),
+ // so write each channel in a scalar loop
+ for (int64_t c = 0; c < channels; c++) {
+ int64_t maxp = indices_ptr[c];
+ if (maxp < 0 || maxp >= output_image_size) {
+ #pragma omp critical
+ {
+ has_error = true;
+ error_index = maxp;
+ }
+ } else {
+ output_ptr[maxp * channels + c] = input_ptr[c];
+ }
+ }
+
+ // move on to next input index
+ data_index_step(n, nbatch, ip, input_image_size);
+ }
+ });
+
+ if (has_error) {
+ AT_ERROR("Found an invalid max index: ", error_index,
+ " (output volumes are of size ", output_height,
+ "x", output_width);
+ (void)error_index;
+ }
+
+ if (!output_.is_contiguous(memory_format)) {
+ output_.copy_(output);
+ }
+}
+
+template <typename scalar_t, bool is_3d = false>
+void cpu_max_unpool_backward(
+ Tensor& grad_input_,
+ const Tensor& grad_output,
+ const Tensor& indices) {
+ auto grad_input = grad_input_.contiguous();
+
+ auto grad_output_data = grad_output.data_ptr<scalar_t>();
+ auto indices_data = indices.data_ptr<int64_t>();
+ auto grad_input_data = grad_input.data_ptr<scalar_t>();
+
+ int64_t numel = grad_input.numel();
+ int64_t ndim = grad_output.ndimension();
+
+ // treat batch size and channels as one dimension
+ // and the feature map as another dimension
+ int64_t channels, output_depth, output_height, output_width;
+ if (is_3d) {
+ TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d_backward: expect grad_output to be 4d or 5d tensor.");
+ channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
+ output_depth = grad_output.size(-3);
+ output_height = grad_output.size(-2);
+ output_width = grad_output.size(-1);
+ } else {
+ TORCH_CHECK(ndim == 3 || ndim == 4, "MaxUnpool2d_backward: expect grad_output to be 3d or 4d tensor.");
+ channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
+ output_depth = 1;
+ output_height = grad_output.size(-2);
+ output_width = grad_output.size(-1);
+ }
+ int64_t input_image_size = numel / channels;
+ int64_t output_image_size = grad_output.numel() / channels;
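+ // backward is the gather mirror of the forward scatter: each grad_input
+ // element reads grad_output at its stored max index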
+
+ bool has_error = false;
+ int64_t error_index = 0;
+
+ // parallel on dim N, C, D, H, W
+ at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
+ int64_t c = 0;
+ int64_t ip = 0;
+ data_index_init(begin, c, channels, ip, input_image_size);
+
+ for (int64_t i = begin; i < end; i++) {
+ scalar_t* grad_output_ptr = grad_output_data + c * output_image_size;
+
+ int64_t maxp = indices_data[i];
+ if (maxp < 0 || maxp >= output_image_size) {
+ #pragma omp critical
+ {
+ has_error = true;
+ error_index = maxp;
+ }
+ } else {
+ grad_input_data[i] = grad_output_ptr[maxp];
+ }
+
+ // move on to next input index
+ data_index_step(c, channels, ip, input_image_size);
+ }
+ });
+
+ if (has_error) {
+ if (is_3d) {
+ AT_ERROR("invalid max index ", error_index,
+ ", odepth= ", output_depth,
+ ", owidth= ", output_width,
+ ", oheight= ", output_height);
+ (void)error_index;
+ } else {
+ AT_ERROR("invalid max index ", error_index,
+ ", owidth= ", output_width,
+ ", oheight= ", output_height);
+ (void)error_index;
+ }
+ }
+
+ if (!grad_input_.is_contiguous()) {
+ grad_input_.copy_(grad_input);
+ }
+}
+
+template <typename scalar_t>
+void cpu_max_unpool_backward_channels_last(
+ Tensor& grad_input_,
+ const Tensor& grad_output,
+ const Tensor& indices) {
+ TORCH_CHECK(grad_output.ndimension() == 4,
+ "max_unpool2d backward with channels last format supports tensors with 4 dims.");
+ auto memory_format = at::MemoryFormat::ChannelsLast;
+ auto grad_input = grad_input_.contiguous(memory_format);
+
+ auto grad_input_data = grad_input.data_ptr<scalar_t>();
+ auto grad_output_data = grad_output.data_ptr<scalar_t>();
+ auto indices_data = indices.data_ptr<int64_t>();
+
+ int64_t nbatch = grad_input.size(0);
+ int64_t channels = grad_input.size(1);
+ int64_t input_height = grad_input.size(2);
+ int64_t input_width = grad_input.size(3);
+ int64_t output_height = grad_output.size(2);
+ int64_t output_width = grad_output.size(3);
+ int64_t input_image_size = input_height * input_width;
+ int64_t output_image_size = output_height * output_width;
+
+ bool has_error = false;
+ int64_t error_index = 0;
+
+ // parallel on dim N, H, W
+ at::parallel_for(0, nbatch * input_image_size, 0, [&](int64_t begin, int64_t end) {
+ int64_t n = 0;
+ int64_t ip = 0;
+ data_index_init(begin, n, nbatch, ip, input_image_size);
+
+ for (int64_t i = begin; i < end; i++) {
+ scalar_t* grad_output_ptr = grad_output_data + n * output_image_size * channels;
+ scalar_t* grad_input_ptr = grad_input_data + i * channels;
+ int64_t* indices_ptr = indices_data + i * channels;
+
+ for (int64_t c = 0; c < channels; c++) {
+ int64_t maxp = indices_ptr[c];
+ if (maxp < 0 || maxp >= output_image_size) {
+ #pragma omp critical
+ {
+ has_error = true;
+ error_index = maxp;
+ }
+ } else {
+ grad_input_ptr[c] = grad_output_ptr[maxp * channels + c];
+ }
+ }
+
+ // move on to next input index
+ data_index_step(n, nbatch, ip, input_image_size);
+ }
+ });
+
+ if (has_error) {
+ AT_ERROR("invalid max index ", error_index,
+ ", owidth= ", output_width,
+ ", oheight= ", output_height);
+ (void)error_index;
+ }
+
+ if (!grad_input_.is_contiguous(memory_format)) {
+ grad_input_.copy_(grad_input);
+ }
+}
+
+void max_unpool2d_kernel_impl(
+ Tensor& output,
+ const Tensor& input,
+ const Tensor& indices) {
+ switch(input.suggest_memory_format()) {
+ case at::MemoryFormat::Contiguous: {
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool2d", [&] {
+ cpu_max_unpool<scalar_t, /*is_3d*/false>(output, input, indices);
+ });
+ break;
+ }
+ case at::MemoryFormat::ChannelsLast: {
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool2d_channels_last", [&] {
+ cpu_max_unpool_channels_last<scalar_t>(output, input, indices);
+ });
+ break;
+ }
+ default:
+ TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+ }
+}
+
+void max_unpool3d_kernel_impl(
+ Tensor& output,
+ const Tensor& input,
+ const Tensor& indices) {
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool3d", [&] {
+ cpu_max_unpool<scalar_t, /*is_3d*/true>(output, input, indices);
+ });
+}
+
+void max_unpool2d_backward_kernel_impl(
+ Tensor& grad_input,
+ const Tensor& grad_output,
+ const Tensor& indices) {
+ switch(grad_output.suggest_memory_format()) {
+ case at::MemoryFormat::Contiguous: {
+ AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool2d_backward", [&] {
+ cpu_max_unpool_backward<scalar_t, /*is_3d*/false>(grad_input, grad_output, indices);
+ });
+ break;
+ }
+ case at::MemoryFormat::ChannelsLast: {
+ AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool2d_backward_channels_last", [&] {
+ cpu_max_unpool_backward_channels_last<scalar_t>(grad_input, grad_output, indices);
+ });
+ break;
+ }
+ default:
+ TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+ }
+}
+
+void max_unpool3d_backward_kernel_impl(
+ Tensor& grad_input,
+ const Tensor& grad_output,
+ const Tensor& indices) {
+ AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool3d_backward", [&] {
+ cpu_max_unpool_backward<scalar_t, /*is_3d*/true>(grad_input, grad_output, indices);
+ });
+}
+
+} // anonymous namespace
+
+REGISTER_DISPATCH(max_unpool2d_kernel, &max_unpool2d_kernel_impl);
+REGISTER_DISPATCH(max_unpool2d_backward_kernel, &max_unpool2d_backward_kernel_impl);
+REGISTER_DISPATCH(max_unpool3d_kernel, &max_unpool3d_kernel_impl);
+REGISTER_DISPATCH(max_unpool3d_backward_kernel, &max_unpool3d_backward_kernel_impl);
+
+}} // at::native
--- /dev/null
+++ b/aten/src/ATen/native/cpu/MaxUnpoolKernel.h
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at { namespace native {
+
+using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
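+// these stubs are filled in per CPU capability by REGISTER_DISPATCH in
+// native/cpu/MaxUnpoolKernel.cpp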
+
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_backward_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_backward_kernel);
+
+}} // at::native
else:
self.assertRaises(ValueError, lambda: mu(output_small, indices_small, (h, w)))
+
+ def test_max_unpool2d_nhwc_cpu(self):
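+ # exercise the channels-last (NHWC) CPU path against a contiguous
+ # (NCHW) reference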
+ input = torch.randn(2, 10, 9, 9).float().cpu()
+ input = input.contiguous(memory_format=torch.channels_last)
+ ref_input = input.clone().contiguous()
+
+ pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
+ ref_pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
+
+ out, ind = pool(input)
+ ref_out, ref_ind = ref_pool(ref_input)
+ out.requires_grad_()
+ ref_out.requires_grad_()
+
+ unpool = nn.MaxUnpool2d(3, stride=2).cpu()
+ ref_unpool = nn.MaxUnpool2d(3, stride=2).cpu()
+
+ upout = unpool(out, ind)
+ ref_upout = ref_unpool(ref_out, ref_ind)
+
+ grad = torch.randn(upout.size()).float().cpu()
+ grad = grad.contiguous(memory_format=torch.channels_last)
+ ref_grad = grad.clone().contiguous()
+
+ upout.backward(grad)
+ ref_upout.backward(ref_grad)
+
+ self.assertTrue(upout.is_contiguous(memory_format=torch.channels_last))
+ self.assertTrue(ref_upout.is_contiguous())
+ self.assertTrue(torch.allclose(upout, ref_upout))
+ self.assertTrue(torch.allclose(out.grad, ref_out.grad))
+
def test_container_copy(self):
class Model(nn.Module):
def __init__(self):
"aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp",
"aten/src/ATen/native/cpu/MaxPooling.cpp",
"aten/src/ATen/native/cpu/MaxPoolKernel.cpp",
+ "aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp",
"aten/src/ATen/native/cpu/MultinomialKernel.cpp",
"aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp",
"aten/src/ATen/native/cpu/PowKernel.cpp",