add channel last support for MaxUnpool2d (#49984)
authormingfeima <mingfei.ma@intel.com>
Mon, 30 Aug 2021 01:35:37 +0000 (18:35 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 01:37:10 +0000 (18:37 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49984

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D26007051

Pulled By: VitalyFedyunin

fbshipit-source-id: 6c54751ade4092e03c1651aaa60380f7d6e92f6b

aten/src/ATen/native/MaxUnpooling.cpp
aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp [new file with mode: 0644]
aten/src/ATen/native/cpu/MaxUnpoolKernel.h [new file with mode: 0644]
test/test_nn.py
tools/build_variables.bzl

index b3c0194..9987408 100644 (file)
@@ -1,90 +1,17 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <ATen/Parallel.h>
-#include <tuple>
+#include <ATen/native/cpu/MaxUnpoolKernel.h>
 
 namespace at {
 namespace native {
 
-template <typename scalar_t>
-Tensor max_unpooling2d_forward_out_cpu_frame(
-    Tensor& output,
-    const Tensor& input,
-    const Tensor& indices,
-    int64_t oheight,
-    int64_t owidth) {
-  int64_t numBatch = 1;
-  int64_t dimc = 0;
-  int64_t dimh = 1;
-  int64_t dimw = 2;
-  if (input.ndimension() == 4) {
-    numBatch = input.size(0);
-    dimc++;
-    dimh++;
-    dimw++;
-  }
-  int64_t numChannels = input.size(dimc);
-  int64_t inputHeight = input.size(dimh);
-  int64_t inputWidth = input.size(dimw);
-
-  auto* rawInput = input.data_ptr<scalar_t>();
-  auto* rawIndices = indices.data_ptr<int64_t>();
-  auto* rawOutput = output.data_ptr<scalar_t>();
-
-  at::internal::lazy_init_num_threads();
-
-  for (int64_t n = 0; n < numBatch; n++) {
-    int64_t nOutputOffset = n * numChannels * owidth * oheight;
-    int64_t nInputOffset = n * numChannels * inputWidth * inputHeight;
-    int64_t k = 0;
-    bool has_error = false;
-    int64_t error_index = 0;
-#pragma omp parallel for private(k)
-    for (k = 0; k < numChannels; k++) {
-      int64_t finalOutputOffset = nOutputOffset + k * owidth * oheight;
-      int64_t finalInputOffset = nInputOffset + k * inputWidth * inputHeight;
-      scalar_t* output_p_k = rawOutput + finalOutputOffset;
-      scalar_t* input_p_k = rawInput + finalInputOffset;
-      int64_t* ind_p_k = rawIndices + finalInputOffset;
-
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      int64_t maxp;
-      for (int64_t i = 0; i < inputHeight; i++) {
-        for (int64_t j = 0; j < inputWidth; j++) {
-          maxp = ind_p_k[i * inputWidth + j];
-          if (maxp < 0 || maxp >= owidth * oheight) {
-#pragma omp critical
-            {
-              has_error = true;
-              error_index = maxp;
-            }
-          } else {
-            output_p_k[maxp] = input_p_k[i * inputWidth + j];
-          }
-        }
-      }
-    }
-    if (has_error) {
-      AT_ERROR(
-          "Found an invalid max index: ",
-          error_index,
-          " (output volumes are of size ",
-          oheight,
-          "x",
-          owidth);
-      (void)error_index;
-    }
-  }
-  return output;
-}
-
-Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_,
+Tensor& max_unpooling2d_forward_out_cpu(
+    const Tensor& self_,
     const Tensor& indices_,
     IntArrayRef output_size,
     Tensor& output) {
   auto oheight = output_size[0];
   auto owidth = output_size[1];
-  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
   TORCH_CHECK(
       indices_.scalar_type() == at::ScalarType::Long,
       "elements in indices should be type int64");
@@ -100,8 +27,9 @@ Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_,
 
   TORCH_CHECK(self_.numel() > 0, "Input must be non-empty");
 
-  auto self = self_.contiguous();
-  auto indices = indices_.contiguous();
+  auto memory_format = self_.suggest_memory_format();
+  auto self = self_.contiguous(memory_format);
+  auto indices = indices_.contiguous(memory_format);
 
   if (self.ndimension() == 3) {
     int64_t numChannels = self.size(0);
@@ -109,15 +37,11 @@ Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_,
   } else {
     int64_t numBatch = self.size(0);
     int64_t numChannels = self.size(1);
-    output.resize_({numBatch, numChannels, oheight, owidth});
+    output.resize_({numBatch, numChannels, oheight, owidth}, memory_format);
   }
   output.zero_();
 
-  AT_DISPATCH_FLOATING_TYPES(
-      self.scalar_type(), "max_unpooling2d_forward_out_cpu_frame", ([&] {
-        max_unpooling2d_forward_out_cpu_frame<scalar_t>(
-            output, self, indices, oheight, owidth);
-      }));
+  max_unpool2d_kernel(kCPU, output, self, indices);
   return output;
 };
 
@@ -130,87 +54,6 @@ Tensor max_unpooling2d_forward_cpu(
   return output;
 }
 
-template <typename scalar_t>
-Tensor max_unpooling3d_forward_out_cpu_frame(
-    Tensor& output,
-    const Tensor& input,
-    const Tensor& indices,
-    int64_t oT,
-    int64_t oH,
-    int64_t oW) {
-  int64_t nBatch = 1;
-  int64_t dimw = 3;
-  int64_t dimh = 2;
-  int64_t dimt = 1;
-
-  if (input.ndimension() == 5) {
-    nBatch = input.size(0);
-    dimw++;
-    dimh++;
-    dimt++;
-  }
-
-  int64_t nSlices = input.size(dimt - 1);
-  int64_t iT = input.size(dimt);
-  int64_t iH = input.size(dimh);
-  int64_t iW = input.size(dimw);
-
-  scalar_t* input_data = input.data_ptr<scalar_t>();
-  scalar_t* output_data = output.data_ptr<scalar_t>();
-  int64_t* indices_data = indices.data_ptr<int64_t>();
-
-  at::internal::lazy_init_num_threads();
-
-  for (int64_t p = 0; p < nBatch; p++) {
-    int64_t inputOffset = p * nSlices * iT * iW * iH;
-    int64_t outputOffset = p * nSlices * oT * oW * oH;
-    int64_t k = 0;
-    bool has_error = false;
-    int error_index = 0;
-#pragma omp parallel for private(k)
-    for (k = 0; k < nSlices; k++) {
-      int64_t finalInputOffset = inputOffset + k * iT * iW * iH;
-      int64_t finalOutputOffset = outputOffset + k * oT * oW * oH;
-
-      scalar_t* output_p_k = output_data + finalOutputOffset;
-      scalar_t* input_p_k = input_data + finalInputOffset;
-      int64_t* ind_p_k = indices_data + finalInputOffset;
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      int maxp;
-      for (int64_t t = 0; t < iT; t++) {
-        for (int64_t i = 0; i < iH; i++) {
-          for (int64_t j = 0; j < iW; j++) {
-            int64_t index = t * iH * iW + i * iW + j;
-            maxp = ind_p_k[index];
-            if (maxp < 0 || maxp >= oT * oW * oH) {
-#pragma omp critical
-              {
-                has_error = true;
-                error_index = maxp;
-              }
-            } else {
-              output_p_k[maxp] = input_p_k[index];
-            }
-          }
-        }
-      }
-      if (has_error) {
-        AT_ERROR(
-            "found an invalid max index ",
-            error_index,
-            " (output volumes are of size ",
-            oT,
-            "x",
-            oH,
-            "x",
-            oW);
-        (void)error_index;
-      }
-    }
-  }
-  return output;
-}
-
 static void max_unpooling3d_shape_check(
     const Tensor& input,
     const Tensor& gradOutput,
@@ -310,16 +153,7 @@ Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_,
   }
   output.zero_();
 
-  AT_DISPATCH_FLOATING_TYPES(
-      self.scalar_type(), "max_unpooling3d_forward_out_cpu_frame", ([&] {
-        max_unpooling3d_forward_out_cpu_frame<scalar_t>(
-            output,
-            self,
-            indices,
-            oT,
-            oH,
-            oW);
-      }));
+  max_unpool3d_kernel(kCPU, output, self, indices);
   return output;
 }
 
@@ -335,59 +169,6 @@ Tensor max_unpooling3d_forward_cpu(
   return output;
 }
 
-template <typename scalar_t>
-static void max_unpooling2d_backward_out_cpu_frame(
-    scalar_t* gradInput_p,
-    scalar_t* gradOutput_p,
-    int64_t* ind_p,
-    int64_t nslices,
-    int64_t iheight,
-    int64_t iwidth,
-    int64_t oheight,
-    int64_t owidth) {
-  bool has_error = false;
-  int64_t error_index = 0;
-  int64_t k = 0;
-
-  at::internal::lazy_init_num_threads();
-#pragma omp parallel for private(k)
-  for (k = 0; k < nslices; k++) {
-    scalar_t* gradInput_p_k = gradInput_p + k * iwidth * iheight;
-    scalar_t* gradOutput_p_k = gradOutput_p + k * owidth * oheight;
-    int64_t* ind_p_k = ind_p + k * iwidth * iheight;
-
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    int64_t i, j;
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    int64_t maxp;
-
-    for (i = 0; i < iheight; i++) {
-      for (j = 0; j < iwidth; j++) {
-        maxp = ind_p_k[i * iwidth + j]; /* retrieve position of max */
-        if (maxp < 0 || maxp >= owidth * oheight) {
-#pragma omp critical
-          {
-            has_error = true;
-            error_index = maxp;
-          }
-        }
-        gradInput_p_k[i * iwidth + j] =
-            gradOutput_p_k[maxp]; /* update gradient */
-      }
-    }
-  }
-  if (has_error) {
-    AT_ERROR(
-        "invalid max index ",
-        error_index,
-        ", owidth= ",
-        owidth,
-        ", oheight= ",
-        oheight);
-    (void)error_index;
-  }
-}
-
 Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
     const Tensor& self,
     const Tensor& indices_,
@@ -396,42 +177,24 @@ Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
   TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
   int64_t oheight = output_size[0];
   int64_t owidth = output_size[1];
-  int dimw = 2;
-  int dimh = 1;
-  int nbatch = 1;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int nslices;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iheight;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iwidth;
+  int64_t ndim = self.ndimension();
+  int64_t dimh = ndim == 3 ? 1 : 2;
+  int64_t dimw = ndim == 3 ? 2 : 3;
+
   TORCH_CHECK(
       indices_.scalar_type() == at::ScalarType::Long,
       "elements in indices should be type int64");
   TORCH_CHECK(
       self.sizes() == indices_.sizes(), "Input shape must match indices shape");
-
   TORCH_CHECK(output_size.size() == 2, "Output size must be 2");
 
-  /* get contiguous gradOutput and indices */
-  auto grad_output = grad_output_.contiguous();
-  auto indices = indices_.contiguous();
+  auto memory_format = self.suggest_memory_format();
+  auto grad_output = grad_output_.contiguous(memory_format);
+  auto indices = indices_.contiguous(memory_format);
 
-  /* resize */
-  grad_input.resize_as_(self);
+  grad_input.resize_(self.sizes(), memory_format);
   grad_input.zero_();
 
-  if (self.ndimension() == 4) {
-    nbatch = self.size(0);
-    dimw++;
-    dimh++;
-  }
-
-  /* sizes */
-  nslices = self.size(dimh - 1);
-  iheight = self.size(dimh);
-  iwidth = self.size(dimw);
-
   if (owidth != grad_output.size(dimw) || oheight != grad_output.size(dimh)) {
     AT_ERROR(
         "Inconsistent gradOutput size. output height = ",
@@ -443,23 +206,8 @@ Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
         "x",
         grad_output.size(dimw));
   }
-  AT_DISPATCH_FLOATING_TYPES(
-      self.scalar_type(), "max_unpooling2d_backward_out_cpu_frame", ([&] {
-        int p;
-        for (p = 0; p < nbatch; p++) {
-          auto inputOffset = p * nslices * iheight * iwidth;
-          auto outputOffset = p * nslices * oheight * owidth;
-          max_unpooling2d_backward_out_cpu_frame<scalar_t>(
-              grad_input.data_ptr<scalar_t>() + inputOffset,
-              grad_output.data_ptr<scalar_t>() + outputOffset,
-              indices.data_ptr<int64_t>() + inputOffset,
-              nslices,
-              iheight,
-              iwidth,
-              oheight,
-              owidth);
-        }
-      }));
+
+  max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
   return grad_input;
 }
 
@@ -468,72 +216,14 @@ Tensor max_unpooling2d_backward_cpu(
     const Tensor& self,
     const Tensor& indices,
     IntArrayRef output_size) {
-  auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  at::native::max_unpooling2d_backward_out_cpu(
+  auto grad_input = at::empty({0}, self.options());
+  max_unpooling2d_backward_out_cpu(
       grad_output, self, indices, output_size, grad_input);
   return grad_input;
 }
 
-template <typename scalar_t>
-static void max_unpooling3d_backward_out_cpu_frame(
-    scalar_t* gradInput_p,
-    scalar_t* gradOutput_p,
-    int64_t* ind_p,
-    int64_t nslices,
-    int64_t iT,
-    int64_t iH,
-    int64_t iW,
-    int64_t oT,
-    int64_t oH,
-    int64_t oW) {
-  int64_t k = 0;
-  bool has_error = false;
-  int error_index = 0;
-
-  at::internal::lazy_init_num_threads();
-
-#pragma omp parallel for private(k)
-  for (k = 0; k < nslices; k++) {
-    scalar_t* gradInput_p_k = gradInput_p + k * iT * iH * iW;
-    scalar_t* gradOutput_p_k = gradOutput_p + k * oT * oH * oW;
-    int64_t* ind_p_k = ind_p + k * iT * iH * iW;
-
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    int64_t t, i, j, index;
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    int64_t maxp;
-    for (t = 0; t < iT; t++) {
-      for (i = 0; i < iH; i++) {
-        for (j = 0; j < iW; j++) {
-          index = t * iH * iW + i * iW + j;
-          maxp = ind_p_k[index]; /* retrieve position of max */
-          if (maxp < 0 || maxp >= oT * oH * oW) {
-#pragma omp critical
-            {
-              has_error = true;
-              error_index = maxp;
-            }
-          }
-          gradInput_p_k[index] = gradOutput_p_k[maxp]; /* update gradient */
-        }
-      }
-    }
-  }
-  if (has_error) {
-    AT_ERROR(
-        "invalid max index ",
-        error_index,
-        ", oT= ",
-        oT,
-        ", oW= ",
-        oW,
-        ",oH= ",
-        oH);
-    (void)error_index;
-  }
-}
-
-Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_,
+Tensor& max_unpooling3d_backward_out_cpu(
+    const Tensor& grad_output_,
     const Tensor& self,
     const Tensor& indices_,
     IntArrayRef output_size,
@@ -541,26 +231,17 @@ Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_,
     IntArrayRef padding,
     Tensor& grad_input) {
   TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
-  auto oT = output_size[0];
-  auto oH = output_size[1];
-  auto oW = output_size[2];
-  int dimw = 3;
-  int dimh = 2;
-  int dimt = 1;
-  int nbatch = 1;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int nslices;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iT;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iH;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  int iW;
+  int64_t oT = output_size[0];
+  int64_t oH = output_size[1];
+  int64_t oW = output_size[2];
+  int64_t ndim = self.ndimension();
+  int64_t dimt = ndim == 4 ? 1 : 2;
+  int64_t dimh = ndim == 4 ? 2 : 3;
+  int64_t dimw = ndim == 4 ? 3 : 4;
 
   max_unpooling3d_shape_check(
       self, grad_output_, indices_, output_size, stride, padding);
 
-  // TODO (from THNN): check gradOutput shape
   /* get contiguous gradOutput */
   auto grad_output = grad_output_.contiguous();
   auto indices = indices_.contiguous();
@@ -568,39 +249,24 @@ Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_,
   /* resize */
   grad_input.resize_as_(self);
   grad_input.zero_();
-  if (self.ndimension() == 5) {
-    nbatch = self.size(0);
-    dimt++;
-    dimw++;
-    dimh++;
+
+  if (oW != grad_output.size(dimw) || oH != grad_output.size(dimh) || oT != grad_output.size(dimt)) {
+    AT_ERROR(
+        "Inconsistent gradOutput size. output depth = ",
+        oT,
+        ", output height = ",
+        oH,
+        ", output width = ",
+        oW,
+        ", gradOutput: ",
+        grad_output.size(dimt),
+        "x",
+        grad_output.size(dimh),
+        "x",
+        grad_output.size(dimw));
   }
 
-  /* sizes */
-  nslices = self.size(dimt - 1);
-  iT = self.size(dimt);
-  iH = self.size(dimh);
-  iW = self.size(dimw);
-
-  /* backprop */
-  AT_DISPATCH_FLOATING_TYPES(
-      self.scalar_type(), "max_unpooling3d_backward_out_cpu_frame", ([&] {
-        int p;
-        for (p = 0; p < nbatch; p++) {
-          int inputOffset = p * nslices * iT * iH * iW;
-          int outputOffset = p * nslices * oT * oH * oW;
-          max_unpooling3d_backward_out_cpu_frame<scalar_t>(
-              grad_input.data_ptr<scalar_t>() + inputOffset,
-              grad_output.data_ptr<scalar_t>() + outputOffset,
-              indices.data_ptr<int64_t>() + inputOffset,
-              nslices,
-              iT,
-              iH,
-              iW,
-              oT,
-              oH,
-              oW);
-        }
-      }));
+  max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
   return grad_input;
 }
 
@@ -611,10 +277,16 @@ Tensor max_unpooling3d_backward_cpu(
     IntArrayRef output_size,
     IntArrayRef stride,
     IntArrayRef padding) {
-  auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  auto grad_input = at::empty({0}, self.options());
   at::native::max_unpooling3d_backward_out_cpu(
       grad_output, self, indices, output_size, stride, padding, grad_input);
   return grad_input;
 }
+
+DEFINE_DISPATCH(max_unpool2d_kernel);
+DEFINE_DISPATCH(max_unpool2d_backward_kernel);
+DEFINE_DISPATCH(max_unpool3d_kernel);
+DEFINE_DISPATCH(max_unpool3d_backward_kernel);
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp b/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp
new file mode 100644 (file)
index 0000000..5a7b031
--- /dev/null
@@ -0,0 +1,385 @@
+#include <ATen/ATen.h>
+
+#include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+#include <ATen/native/Pool.h>
+#include <ATen/native/cpu/utils.h>
+
+namespace at { namespace native {
+
+namespace {
+
+template <typename scalar_t, bool is_3d = false>
+void cpu_max_unpool(
+    Tensor& output_,
+    const Tensor& input,
+    const Tensor& indices) {
+  auto output = output_.contiguous();
+
+  auto input_data = input.data_ptr<scalar_t>();
+  auto indices_data = indices.data_ptr<int64_t>();
+  auto output_data = output.data_ptr<scalar_t>();
+
+  // NB: input tensor dimensions:
+  // MaxUnpool2d:
+  //    dim = 3: CHW
+  //    dim = 4: NCHW
+  // MaxUnpool3d:
+  //    dim = 4: CDHW
+  //    dim = 5: NCDHW
+
+  int64_t numel = input.numel();
+  int64_t ndim = input.ndimension();
+
+  // treat batch size and channels as one dimension
+  // and the feature map as another dimension
+  int64_t channels, output_depth, output_height, output_width;
+  if (is_3d) {
+    TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d: expect input to be 4d or 5d tensor.");
+    channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
+    output_depth = output.size(-3);
+    output_height = output.size(-2);
+    output_width = output.size(-1);
+  } else {
+    TORCH_CHECK(ndim == 3 || ndim == 4, "MaxUnpool2d: expect input to be 3d or 4d tensor.");
+    channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1);
+    output_depth = 1;
+    output_height = output.size(-2);
+    output_width = output.size(-1);
+  }
+  int64_t input_image_size = numel / channels;
+  int64_t output_image_size = output.numel() / channels;
+
+  bool has_error = false;
+  int64_t error_index = 0;
+
+  // parallel on dim N, C, D, H, W: [channels, input_image_size]
+  at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
+    int64_t c = 0;
+    int64_t ip = 0;
+    data_index_init(begin, c, channels, ip, input_image_size);
+
+    for (int64_t i = begin; i < end; i++) {
+      scalar_t* output_ptr = output_data + c * output_image_size;
+
+      int64_t maxp = indices_data[i];
+      if (maxp < 0 || maxp >= output_image_size) {
+        #pragma omp critical
+        {
+          has_error = true;
+          error_index = maxp;
+        }
+      } else {
+        output_ptr[maxp] = input_data[i];
+      }
+
+      // move on to next input index
+      data_index_step(c, channels, ip, input_image_size);
+    }
+  });
+
+  if (has_error) {
+    if (is_3d) {
+      AT_ERROR("Found an invalid max index: ", error_index,
+          " (output volumes are of size ", output_depth,
+          "x", output_height, "x", output_width);
+      (void)error_index;
+    } else {
+      AT_ERROR("Found an invalid max index: ", error_index,
+          " (output volumes are of size ", output_height,
+          "x", output_width);
+      (void)error_index;
+    }
+  }
+
+  if (!output_.is_contiguous()) {
+    output_.copy_(output);
+  }
+}
+
+template <typename scalar_t>
+void cpu_max_unpool_channels_last(
+    Tensor& output_,
+    const Tensor& input,
+    const Tensor& indices) {
+  TORCH_CHECK(input.ndimension() == 4,
+              "max_unpool2d with channels last format supports tensors with 4 dims");
+  auto memory_format = at::MemoryFormat::ChannelsLast;
+  auto output = output_.contiguous(memory_format);
+
+  auto input_data = input.data_ptr<scalar_t>();
+  auto indices_data = indices.data_ptr<int64_t>();
+  auto output_data = output.data_ptr<scalar_t>();
+
+  int64_t nbatch = input.size(0);
+  int64_t channels = input.size(1);
+  int64_t input_height = input.size(2);
+  int64_t input_width = input.size(3);
+  int64_t output_height = output.size(2);
+  int64_t output_width = output.size(3);
+  int64_t input_image_size = input_height * input_width;
+  int64_t output_image_size = output_height * output_width;
+
+  bool has_error = false;
+  int64_t error_index = 0;
+
+  // parallel on dim N, H, W
+  at::parallel_for(0, nbatch * input_image_size, 0, [&](int64_t begin, int64_t end) {
+    int64_t n = 0;
+    int64_t ip = 0;
+    data_index_init(begin, n, nbatch, ip, input_image_size);
+
+    for (int64_t i = begin; i < end; i++) {
+      scalar_t* input_ptr = input_data + i * channels;
+      int64_t* indices_ptr = indices_data + i * channels;
+      scalar_t* output_ptr = output_data + n * output_image_size * channels;
+
+      // can't do scatter on avx2 (only available on avx512)
+      for (int64_t c = 0; c < channels; c++) {
+        int64_t maxp = indices_ptr[c];
+        if (maxp < 0 || maxp >= output_image_size) {
+          #pragma omp critical
+          {
+            has_error = true;
+            error_index = maxp;
+          }
+        } else {
+          output_ptr[maxp * channels + c] = input_ptr[c];
+        }
+      }
+
+      // move on to next input index
+      data_index_step(n, nbatch, ip, input_image_size);
+    }
+  });
+
+  if (has_error) {
+    AT_ERROR("Found an invalid max index: ", error_index,
+        " (output volumes are of size ", output_height,
+        "x", output_width);
+    (void)error_index;
+  }
+
+  if (!output_.is_contiguous(memory_format)) {
+    output_.copy_(output);
+  }
+}
+
+template <typename scalar_t, bool is_3d = false>
+void cpu_max_unpool_backward(
+    Tensor& grad_input_,
+    const Tensor& grad_output,
+    const Tensor& indices) {
+  auto grad_input = grad_input_.contiguous();
+
+  auto grad_output_data = grad_output.data_ptr<scalar_t>();
+  auto indices_data = indices.data_ptr<int64_t>();
+  auto grad_input_data = grad_input.data_ptr<scalar_t>();
+
+  int64_t numel = grad_input.numel();
+  int64_t ndim = grad_output.ndimension();
+
+  // treat batch size and channels as one dimension
+  // and the feature map as another dimension
+  int64_t channels, output_depth, output_height, output_width;
+  if (is_3d) {
+    TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d_backward: expect grad_output to be 4d or 5d tensor.");
+    channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
+    output_depth = grad_output.size(-3);
+    output_height = grad_output.size(-2);
+    output_width = grad_output.size(-1);
+  } else {
+    TORCH_CHECK(ndim == 3 || ndim == 4, "MaxUnpool2d_backward: expect grad_output to be 3d or 4d tensor.");
+    channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
+    output_depth = 1;
+    output_height = grad_output.size(-2);
+    output_width = grad_output.size(-1);
+  }
+  int64_t input_image_size = numel / channels;
+  int64_t output_image_size = grad_output.numel() / channels;
+
+  bool has_error = false;
+  int64_t error_index = 0;
+
+  // parallel on dim N, C, D, H, W
+  at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
+    int64_t c = 0;
+    int64_t ip = 0;
+    data_index_init(begin, c, channels, ip, input_image_size);
+
+    for (int64_t i = begin; i < end; i++) {
+      scalar_t* grad_output_ptr = grad_output_data + c * output_image_size;
+
+      int64_t maxp = indices_data[i];
+      if (maxp < 0 || maxp >= output_image_size) {
+        #pragma omp critical
+        {
+          has_error = true;
+          error_index = maxp;
+        }
+      } else {
+        grad_input_data[i] = grad_output_ptr[maxp];
+      }
+
+      // move on to next input index
+      data_index_step(c, channels, ip, input_image_size);
+    }
+  });
+
+  if (has_error) {
+    if (is_3d) {
+      AT_ERROR("invalid max index ", error_index,
+          ", odepth= ", output_depth,
+          ", owidth= ", output_width,
+          ", oheight= ", output_height);
+      (void)error_index;
+    } else {
+      AT_ERROR("invalid max index ", error_index,
+          ", owidth= ", output_width,
+          ", oheight= ", output_height);
+      (void)error_index;
+    }
+  }
+
+  if (!grad_input_.is_contiguous()) {
+    grad_input_.copy_(grad_input);
+  }
+}
+
+template <typename scalar_t>
+void cpu_max_unpool_backward_channels_last(
+    Tensor& grad_input_,
+    const Tensor& grad_output,
+    const Tensor& indices) {
+  TORCH_CHECK(grad_output.ndimension() == 4,
+      "max_unpool2d  backward with channels last format supports tensors with 4 dims.");
+  auto memory_format = at::MemoryFormat::ChannelsLast;
+  auto grad_input = grad_input_.contiguous(memory_format);
+
+  auto grad_input_data = grad_input.data_ptr<scalar_t>();
+  auto grad_output_data = grad_output.data_ptr<scalar_t>();
+  auto indices_data = indices.data_ptr<int64_t>();
+
+  int64_t nbatch = grad_input.size(0);
+  int64_t channels = grad_input.size(1);
+  int64_t input_height = grad_input.size(2);
+  int64_t input_width = grad_input.size(3);
+  int64_t output_height = grad_output.size(2);
+  int64_t output_width = grad_output.size(3);
+  int64_t input_image_size = input_height * input_width;
+  int64_t output_image_size = output_height * output_width;
+
+  bool has_error = false;
+  int64_t error_index = 0;
+
+  // parallel on dim N, H, W
+  at::parallel_for(0, nbatch * input_image_size, 0, [&](int64_t begin, int64_t end) {
+    int64_t n = 0;
+    int64_t ip = 0;
+    data_index_init(begin, n, nbatch, ip, input_image_size);
+
+    for (int64_t i = begin; i < end; i++) {
+      scalar_t* grad_output_ptr = grad_output_data + n * output_image_size * channels;
+      scalar_t* grad_input_ptr = grad_input_data + i * channels;
+      int64_t* indices_ptr = indices_data + i * channels;
+
+      for (int64_t c = 0; c < channels; c++) {
+        int64_t maxp = indices_ptr[c];
+        if (maxp < 0 || maxp >= output_image_size) {
+          #pragma omp critical
+          {
+            has_error = true;
+            error_index = maxp;
+          }
+        } else {
+          grad_input_ptr[c] = grad_output_ptr[maxp * channels + c];
+        }
+      }
+
+      // move on to next input index
+      data_index_step(n, nbatch, ip, input_image_size);
+    }
+  });
+
+  if (has_error) {
+    AT_ERROR("invalid max index ", error_index,
+        ", owidth= ", output_width,
+        ", oheight= ", output_height);
+    (void)error_index;
+  }
+
+  if (!grad_input_.is_contiguous(memory_format)) {
+    grad_input_.copy_(grad_input);
+  }
+}
+
+void max_unpool2d_kernel_impl(
+    Tensor& output,
+    const Tensor& input,
+    const Tensor& indices) {
+  switch(input.suggest_memory_format()) {
+    case at::MemoryFormat::Contiguous: {
+      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool2d", [&] {
+        cpu_max_unpool<scalar_t, /*is_3d*/false>(output, input, indices);
+      });
+      break;
+    }
+    case at::MemoryFormat::ChannelsLast: {
+      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool2d_channels_last", [&] {
+        cpu_max_unpool_channels_last<scalar_t>(output, input, indices);
+      });
+      break;
+    }
+    default:
+      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+  }
+}
+
+void max_unpool3d_kernel_impl(
+    Tensor& output,
+    const Tensor& input,
+    const Tensor& indices) {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool3d", [&] {
+    cpu_max_unpool<scalar_t, /*is_3d*/true>(output, input, indices);
+  });
+}
+
+void max_unpool2d_backward_kernel_impl(
+    Tensor& grad_input,
+    const Tensor& grad_output,
+    const Tensor& indices) {
+  switch(grad_output.suggest_memory_format()) {
+    case at::MemoryFormat::Contiguous: {
+      AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool2d_backward", [&] {
+        cpu_max_unpool_backward<scalar_t, /*is_3d*/false>(grad_input, grad_output, indices);
+      });
+      break;
+    }
+    case at::MemoryFormat::ChannelsLast: {
+      AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool2d_backward_channels_last", [&] {
+        cpu_max_unpool_backward_channels_last<scalar_t>(grad_input, grad_output, indices);
+      });
+      break;
+    }
+    default:
+      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+  }
+}
+
+void max_unpool3d_backward_kernel_impl(
+    Tensor& grad_input,
+    const Tensor& grad_output,
+    const Tensor& indices) {
+  AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool3d_backward", [&] {
+    cpu_max_unpool_backward<scalar_t, /*is_3d*/true>(grad_input, grad_output, indices);
+  });
+}
+
+} // anonymous namespace
+
+REGISTER_DISPATCH(max_unpool2d_kernel, &max_unpool2d_kernel_impl);
+REGISTER_DISPATCH(max_unpool2d_backward_kernel, &max_unpool2d_backward_kernel_impl);
+REGISTER_DISPATCH(max_unpool3d_kernel, &max_unpool3d_kernel_impl);
+REGISTER_DISPATCH(max_unpool3d_backward_kernel, &max_unpool3d_backward_kernel_impl);
+
+}} // at::native
diff --git a/aten/src/ATen/native/cpu/MaxUnpoolKernel.h b/aten/src/ATen/native/cpu/MaxUnpoolKernel.h
new file mode 100644 (file)
index 0000000..00fbeb6
--- /dev/null
@@ -0,0 +1,16 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/native/DispatchStub.h>
+
+#pragma once
+
+namespace at { namespace native {
+
+using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
+
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_backward_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_backward_kernel);
+
+}} // at::native
index 4e01c94..7d26246 100644 (file)
@@ -6186,6 +6186,37 @@ class TestNN(NNTestCase):
                 else:
                     self.assertRaises(ValueError, lambda: mu(output_small, indices_small, (h, w)))
 
+    def test_max_unpool2d_nhwc_cpu(self):
+        input = torch.randn(2, 10, 9, 9).float().cpu()
+        input = input.contiguous(memory_format=torch.channels_last)
+        ref_input = input.clone().contiguous()
+
+        pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
+        ref_pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
+
+        out, ind = pool(input)
+        ref_out, ref_ind = ref_pool(ref_input)
+        out.requires_grad_()
+        ref_out.requires_grad_()
+
+        unpool = nn.MaxUnpool2d(3, stride=2).cpu()
+        ref_unpool = nn.MaxUnpool2d(3, stride=2).cpu()
+
+        upout = unpool(out, ind)
+        ref_upout = ref_unpool(ref_out, ref_ind)
+
+        grad = torch.randn(upout.size()).float().cpu()
+        grad = grad.contiguous(memory_format=torch.channels_last)
+        ref_grad = grad.clone().contiguous()
+
+        upout.backward(grad)
+        ref_upout.backward(ref_grad)
+
+        self.assertTrue(upout.is_contiguous(memory_format=torch.channels_last))
+        self.assertTrue(ref_upout.is_contiguous())
+        self.assertTrue(torch.allclose(upout, ref_upout))
+        self.assertTrue(torch.allclose(out.grad, ref_out.grad))
+
     def test_container_copy(self):
         class Model(nn.Module):
             def __init__(self):
index b2a1016..34846b5 100644 (file)
@@ -907,6 +907,7 @@ aten_native_source_codegen_list = [
     "aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp",
     "aten/src/ATen/native/cpu/MaxPooling.cpp",
     "aten/src/ATen/native/cpu/MaxPoolKernel.cpp",
+    "aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp",
     "aten/src/ATen/native/cpu/MultinomialKernel.cpp",
     "aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp",
     "aten/src/ATen/native/cpu/PowKernel.cpp",