Port replication_pad2d and replication_pad3d to ATen (#15538)
authorLin Huang <lhuang04@fb.com>
Sat, 5 Jan 2019 00:59:18 +0000 (16:59 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 5 Jan 2019 01:08:14 +0000 (17:08 -0800)
Summary:
port replication padding 2D and 3D from legacy TH API implementation
to ATen implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15538

Differential Revision: D13547567

Pulled By: lhuang04

fbshipit-source-id: decfe100d9edfdcfb62f39ee23f37b6cae0d461f

16 files changed:
aten/src/ATen/native/LegacyNNDefinitions.cpp
aten/src/ATen/native/ReplicationPadding.cpp
aten/src/ATen/native/cuda/ReplicationPadding.cu
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/nn.yaml
aten/src/THCUNN/CMakeLists.txt
aten/src/THCUNN/SpatialReplicationPadding.cu [deleted file]
aten/src/THCUNN/VolumetricReplicationPadding.cu [deleted file]
aten/src/THCUNN/generic/SpatialReplicationPadding.cu [deleted file]
aten/src/THCUNN/generic/THCUNN.h
aten/src/THCUNN/generic/VolumetricReplicationPadding.cu [deleted file]
aten/src/THNN/generic/SpatialReplicationPadding.c [deleted file]
aten/src/THNN/generic/THNN.h
aten/src/THNN/generic/VolumetricReplicationPadding.c [deleted file]
aten/src/THNN/init.cpp
torch/nn/_functions/thnn/auto.py

index 31c5d8b..5f38bf1 100644 (file)
@@ -508,38 +508,6 @@ Tensor reflection_pad2d_backward(const Tensor & grad_output, const Tensor & self
   return at::legacy::th::_thnn_reflection_pad2d_backward(grad_output, self, padding);
 }
 
-Tensor & replication_pad2d_out(Tensor & output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad2d_forward_out(output, self, padding);
-}
-
-Tensor replication_pad2d(const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad2d_forward(self, padding);
-}
-
-Tensor & replication_pad2d_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad2d_backward_out(grad_input, grad_output, self, padding);
-}
-
-Tensor replication_pad2d_backward(const Tensor & grad_output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad2d_backward(grad_output, self, padding);
-}
-
-Tensor & replication_pad3d_out(Tensor & output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad3d_forward_out(output, self, padding);
-}
-
-Tensor replication_pad3d(const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad3d_forward(self, padding);
-}
-
-Tensor & replication_pad3d_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad3d_backward_out(grad_input, grad_output, self, padding);
-}
-
-Tensor replication_pad3d_backward(const Tensor & grad_output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad3d_backward(grad_output, self, padding);
-}
-
 Tensor & upsample_linear1d_out(Tensor & output, const Tensor & self, IntList output_size, bool align_corners) {
   return at::legacy::th::_thnn_upsample_linear1d_forward_out(output, self, output_size, align_corners);
 }
index 5b84238..a2a3d04 100644 (file)
@@ -1,5 +1,6 @@
 #include "ATen/ATen.h"
 #include "ATen/NativeFunctions.h"
+#include <algorithm>
 
 namespace at {
 namespace native {
@@ -13,15 +14,14 @@ static void replication_pad1d_out_frame(
     long owidth,
     int pad_l, int pad_r)
 {
-  int iStartX = fmax(0, -pad_l);
-  int oStartX = fmax(0, pad_l);
+  int iStartX = std::max(0, -pad_l);
+  int oStartX = std::max(0, pad_l);
 
   long k, ip_x;
 #pragma omp parallel for private(k, ip_x)
   for (k = 0; k < nslices; k++)
   {
-    long j;
-    for (j = 0; j < owidth; j++) {
+    for (long j = 0; j < owidth; j++) {
       if (j < pad_l) {
         ip_x = pad_l;
       } else if (j >= pad_l && j < iwidth + pad_l) {
@@ -38,35 +38,52 @@ static void replication_pad1d_out_frame(
   }
 }
 
+template <typename scalar_t>
+static void replication_pad1d_out_batch(
+    scalar_t *input_data, scalar_t *output_data,
+    long nslices,
+    long iwidth,
+    long owidth,
+    int pad_l, int pad_r,
+    int nbatch)
+{
+  long p;
+#pragma omp parallel for private(p)
+  for (p = 0; p < nbatch; p++)
+  {
+    scalar_t *input_p = input_data+p*nslices*iwidth;
+    scalar_t *output_p = output_data+p*nslices*owidth;
+    replication_pad1d_out_frame(input_p, output_p, nslices, iwidth, owidth, pad_l, pad_r);
+  }
+}
+
 void replication_pad1d_out_cpu_template(
-    at::Tensor& output,
-    at::Tensor const& input,
+    Tensor& output,
+    const Tensor& input_,
     IntList paddingSize)
 {
   int dimw = 1;
   int dimslices = 0;
   long nbatch = 1;
-  long nslices;
-  long iwidth;
-  long owidth;
+  AT_CHECK(paddingSize.size() == 2, "padding size is expected to be 2");
   int pad_l = paddingSize[0];
   int pad_r = paddingSize[1];
 
-  AT_CHECK(input.numel() > 0
-      && (input.ndimension() == 2 || input.ndimension() == 3),
+  AT_CHECK(input_.numel() > 0
+      && (input_.ndimension() == 2 || input_.ndimension() == 3),
       "non-empty 2D or 3D (batch mode) tensor expected for input");
 
-  if (input.ndimension() == 3)
+  if (input_.ndimension() == 3)
   {
-    nbatch = input.size(0);
+    nbatch = input_.size(0);
     dimw++;
     dimslices++;
   }
 
   /* sizes */
-  nslices = input.size(dimslices);
-  iwidth = input.size(dimw);
-  owidth  = iwidth + pad_l + pad_r;
+  long nslices = input_.size(dimslices);
+  long iwidth = input_.size(dimw);
+  long owidth  = iwidth + pad_l + pad_r;
 
   AT_CHECK(owidth >= 1,
       "input (W: ", iwidth, ") is too small."
@@ -74,45 +91,41 @@ void replication_pad1d_out_cpu_template(
 
 
   /* get contiguous input */
-  auto input_ = input.contiguous();
+  auto input = input_.contiguous();
 
   /* resize output */
-  if (input_.ndimension() == 2)
+  if (input.ndimension() == 2)
   {
     output.resize_({nslices, owidth});
-    AT_DISPATCH_FLOATING_TYPES(input_.type(), "replication_pad1d", [&] {
-        auto input_data = input_.data<scalar_t>();
-        auto output_data = output.data<scalar_t>();
-        replication_pad1d_out_frame<scalar_t> (input_data, output_data,
-          nslices,
-          iwidth,
-          owidth,
-          pad_l, pad_r);
-        }
-        );
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad1d", [&] {
+      auto input_data = input.data<scalar_t>();
+      auto output_data = output.data<scalar_t>();
+      replication_pad1d_out_frame<scalar_t>(
+        input_data,
+        output_data,
+        nslices,
+        iwidth,
+        owidth,
+        pad_l, pad_r);
+      }
+    );
   }
   else
   {
-    long p;
-
     output.resize_({nbatch, nslices, owidth});
-
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++)
-    {
-      AT_DISPATCH_FLOATING_TYPES(input_.type(), "replication_pad1d", [&] {
-          auto input_data = input_.data<scalar_t>();
-          auto output_data = output.data<scalar_t>();
-          replication_pad1d_out_frame<scalar_t>(
-            input_data+p*nslices*iwidth,
-            output_data+p*nslices*owidth,
-            nslices,
-            iwidth,
-            owidth,
-            pad_l, pad_r);
-          }
-          );
-    }
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad1d", [&] {
+      auto input_data = input.data<scalar_t>();
+      auto output_data = output.data<scalar_t>();
+      replication_pad1d_out_batch<scalar_t>(
+        input_data,
+        output_data,
+        nslices,
+        iwidth,
+        owidth,
+        pad_l, pad_r,
+        nbatch);
+      }
+    );
   }
 }
 
@@ -124,15 +137,14 @@ static void replication_pad1d_backward_out_frame(
     long owidth,
     int pad_l, int pad_r)
 {
-  int iStartX = fmax(0, -pad_l);
-  int oStartX = fmax(0, pad_l);
+  int iStartX = std::max(0, -pad_l);
+  int oStartX = std::max(0, pad_l);
 
   long k, ip_x;
 #pragma omp parallel for private(k, ip_x)
   for (k = 0; k < nslices; k++)
   {
-    long j;
-    for (j = 0; j < owidth; j++) {
+    for (long j = 0; j < owidth; j++) {
       if (j < pad_l) {
         ip_x = pad_l;
       } else if (j >= pad_l && j < iwidth + pad_l) {
@@ -149,6 +161,26 @@ static void replication_pad1d_backward_out_frame(
   }
 }
 
+template <typename scalar_t>
+static void replication_pad1d_backward_out_batch(
+    scalar_t *ginput_data, scalar_t *goutput_data,
+    long nslices,
+    long iwidth,
+    long owidth,
+    int pad_l, int pad_r,
+    int nbatch)
+{
+  long p;
+#pragma omp parallel for private(p)
+  for (p = 0; p < nbatch; p++)
+  {
+    scalar_t *ginput_p = ginput_data + p * nslices * iwidth;
+    scalar_t *goutput_p = goutput_data + p * nslices * owidth;
+    replication_pad1d_backward_out_frame(ginput_p, goutput_p,
+      nslices, iwidth, owidth, pad_l, pad_r);
+  }
+}
+
 Tensor& replication_pad1d_backward_out_cpu_template(
     Tensor& gradInput,
     const Tensor& gradOutput_,
@@ -158,9 +190,7 @@ Tensor& replication_pad1d_backward_out_cpu_template(
   int dimw = 1;
   int dimslices = 0;
   long nbatch = 1;
-  long nslices;
-  long iwidth;
-  long owidth;
+  AT_CHECK(paddingSize.size() == 2, "padding size is expected to be 2");
   int pad_l = paddingSize[0];
   int pad_r = paddingSize[1];
 
@@ -172,9 +202,9 @@ Tensor& replication_pad1d_backward_out_cpu_template(
   }
 
   /* sizes */
-  nslices = input.size(dimslices);
-  iwidth = input.size(dimw);
-  owidth  = iwidth + pad_l + pad_r;
+  long nslices = input.size(dimslices);
+  long iwidth = input.size(dimw);
+  long owidth  = iwidth + pad_l + pad_r;
 
   AT_CHECK(owidth == gradOutput_.size(dimw),
       "gradOutput width unexpected. Expected: ", owidth,
@@ -186,41 +216,691 @@ Tensor& replication_pad1d_backward_out_cpu_template(
   gradInput.zero_();
 
   /* backprop */
-  if (input.ndimension() == 2) {
+  if (input.ndimension() == 2)
+  {
     AT_DISPATCH_FLOATING_TYPES(
-        input.type(), "replication_pad1d_backward", [&] {
-        scalar_t *gradInput_data = gradInput.data<scalar_t>();
-        scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
-
-        replication_pad1d_backward_out_frame<scalar_t> (
-          gradInput_data,
-          gradOutput_data,
-          nslices,
-          iwidth,
-          owidth,
-          pad_l, pad_r);
+      input.type(), "replication_pad1d_backward", [&] {
+      scalar_t *gradInput_data = gradInput.data<scalar_t>();
+      scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+
+      replication_pad1d_backward_out_frame<scalar_t> (
+        gradInput_data,
+        gradOutput_data,
+        nslices,
+        iwidth,
+        owidth,
+        pad_l, pad_r);
+      }
+    );
+  }
+  else
+  {
+    AT_DISPATCH_FLOATING_TYPES(
+      input.type(), "replication_pad1d_backward", [&] {
+      scalar_t *gradInput_data = gradInput.data<scalar_t>();
+      scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+
+      replication_pad1d_backward_out_batch<scalar_t> (
+        gradInput_data,
+        gradOutput_data,
+        nslices,
+        iwidth,
+        owidth,
+        pad_l, pad_r,
+        nbatch);
+      }
+    );
+  }
+  return gradInput;
+}
+
+template <typename scalar_t>
+static void replication_pad2d_out_frame(
+    scalar_t *input_p, scalar_t *output_p,
+    int64_t nslices,
+    int64_t iwidth, int64_t iheight,
+    int64_t owidth, int64_t oheight,
+    int pad_l, int pad_r,
+    int pad_t, int pad_b)
+{
+  int iStartX = std::max(0, -pad_l);
+  int iStartY = std::max(0, -pad_t);
+  int oStartX = std::max(0, pad_l);
+  int oStartY = std::max(0, pad_t);
+
+  int64_t k, ip_x, ip_y;
+#pragma omp parallel for private(k, ip_x, ip_y)
+  for (k = 0; k < nslices; k++)
+  {
+    for (int64_t i = 0; i < oheight; i++) {
+      for (int64_t j = 0; j < owidth; j++) {
+        if (j < pad_l) {
+          ip_x = pad_l;
+        } else if (j >= pad_l && j < iwidth + pad_l) {
+          ip_x = j;
+        } else {
+          ip_x = iwidth + pad_l - 1;
+        }
+        ip_x = ip_x - oStartX + iStartX;
+
+        if (i < pad_t) {
+          ip_y = pad_t;
+        } else if (i >= pad_t && i < iheight + pad_t) {
+          ip_y = i;
+        } else {
+          ip_y = iheight + pad_t - 1;
+        }
+        ip_y = ip_y - oStartY + iStartY;
+
+        scalar_t *dest_p = output_p + k*owidth*oheight + i * owidth + j;
+        scalar_t *src_p = input_p + k*iwidth*iheight + ip_y * iwidth + ip_x;
+        *dest_p = *src_p;
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+static void replication_pad2d_out_batch(
+    scalar_t *input_data, scalar_t *output_data,
+    int64_t nslices,
+    int64_t iwidth, int64_t iheight,
+    int64_t owidth, int64_t oheight,
+    int pad_l, int pad_r,
+    int pad_t, int pad_b,
+    int nbatch)
+{
+  int64_t p;
+#pragma omp parallel for private(p)
+  for (p = 0; p < nbatch; p++)
+  {
+    scalar_t *input_p = input_data+p*nslices*iwidth*iheight;
+    scalar_t *output_p = output_data+p*nslices*owidth*oheight;
+    replication_pad2d_out_frame(input_p, output_p, nslices,
+        iwidth, iheight, owidth, oheight, pad_l, pad_r, pad_t, pad_b);
+  }
+}
+
+void replication_pad2d_out_cpu_template(Tensor& output,
+    const Tensor& input_,
+    IntList paddingSize)
+{
+  AT_CHECK(paddingSize.size() == 4, "padding size is expected to be 4");
+  int pad_l = paddingSize[0];
+  int pad_r = paddingSize[1];
+  int pad_t = paddingSize[2];
+  int pad_b = paddingSize[3];
+  int dimw = 2;
+  int dimh = 1;
+  int dimslices = 0;
+  int64_t nbatch = 1;
+
+  AT_CHECK(input_.numel() > 0 && (input_.dim() == 3 || input_.dim() == 4),
+      "3D or 4D (batch mode) tensor expected for input, but got: ", input_);
+
+  if (input_.dim() == 4)
+  {
+    nbatch = input_.size(0);
+    dimw++;
+    dimh++;
+    dimslices++;
+  }
+
+  /* sizes */
+  int64_t nslices = input_.size(dimslices);
+  int64_t iheight = input_.size(dimh);
+  int64_t iwidth = input_.size(dimw);
+  int64_t oheight = iheight + pad_t + pad_b;
+  int64_t owidth  = iwidth + pad_l + pad_r;
+
+  AT_CHECK(owidth >= 1 || oheight >= 1,
+      "input (H: ", iheight, ", W: ", iwidth, " ) is too small."
+      " Calculated output H: ", oheight, " W: ", owidth);
+
+
+  /* get contiguous input */
+  auto input = input_.contiguous();
+
+  /* resize output */
+  if (input.dim() == 3)
+  {
+    output.resize_({nslices, oheight, owidth});
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad2d", [&] {
+      auto input_data = input.data<scalar_t>();
+      auto output_data = output.data<scalar_t>();
+      replication_pad2d_out_frame<scalar_t> (input_data, output_data,
+        nslices,
+        iwidth, iheight,
+        owidth, oheight,
+        pad_l, pad_r,
+        pad_t, pad_b);
+      }
+    );
+  }
+  else
+  {
+    output.resize_({nbatch, nslices, oheight, owidth});
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad2d", [&] {
+      auto input_data = input.data<scalar_t>();
+      auto output_data = output.data<scalar_t>();
+      replication_pad2d_out_batch<scalar_t> (input_data, output_data,
+        nslices,
+        iwidth, iheight,
+        owidth, oheight,
+        pad_l, pad_r,
+        pad_t, pad_b,
+        nbatch);
+      }
+    );
+  }
+}
+
+template <typename scalar_t>
+static void replication_pad2d_backward_out_frame(
+    scalar_t *ginput_p, scalar_t *goutput_p,
+    int64_t nslices,
+    int64_t iwidth, int64_t iheight,
+    int64_t owidth, int64_t oheight,
+    int pad_l, int pad_r,
+    int pad_t, int pad_b)
+{
+  int iStartX = std::max(0, -pad_l);
+  int iStartY = std::max(0, -pad_t);
+  int oStartX = std::max(0, pad_l);
+  int oStartY = std::max(0, pad_t);
+
+  int64_t k, ip_x, ip_y;
+#pragma omp parallel for private(k, ip_x, ip_y)
+  for (k = 0; k < nslices; k++)
+  {
+    for (int64_t i = 0; i < oheight; i++) {
+      for (int64_t j = 0; j < owidth; j++) {
+        if (j < pad_l) {
+          ip_x = pad_l;
+        } else if (j >= pad_l && j < iwidth + pad_l) {
+          ip_x = j;
+        } else {
+          ip_x = iwidth + pad_l - 1;
         }
-        );
-  } else {
-    long p;
+        ip_x = ip_x - oStartX + iStartX;
+
+        if (i < pad_t) {
+          ip_y = pad_t;
+        } else if (i >= pad_t && i < iheight + pad_t) {
+          ip_y = i;
+        } else {
+          ip_y = iheight + pad_t - 1;
+        }
+        ip_y = ip_y - oStartY + iStartY;
+
+        scalar_t *src_p = goutput_p + k*owidth*oheight + i * owidth + j;
+        scalar_t *dest_p = ginput_p + k*iwidth*iheight + ip_y * iwidth + ip_x;
+        *dest_p += *src_p;
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+static void replication_pad2d_backward_out_batch(
+    scalar_t *ginput_data, scalar_t *goutput_data,
+    int64_t nslices,
+    int64_t iwidth, int64_t iheight,
+    int64_t owidth, int64_t oheight,
+    int pad_l, int pad_r,
+    int pad_t, int pad_b,
+    int nbatch)
+{
+  int64_t p;
 #pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++) {
-      AT_DISPATCH_FLOATING_TYPES(
-          input.type(), "replication_pad1d_backward", [&] {
-          scalar_t *gradInput_data = gradInput.data<scalar_t>();
-          scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
-
-          replication_pad1d_backward_out_frame<scalar_t>(
-            gradInput_data + p * nslices * iwidth,
-            gradOutput_data + p * nslices * owidth,
-            nslices,
-            iwidth,
-            owidth,
-            pad_l, pad_r);
+  for (p = 0; p < nbatch; p++)
+  {
+    scalar_t *ginput_p = ginput_data + p * nslices * iheight * iwidth;
+    scalar_t *goutput_p = goutput_data + p * nslices * oheight * owidth;
+    replication_pad2d_backward_out_frame(ginput_p, goutput_p, nslices,
+        iwidth, iheight, owidth, oheight, pad_l, pad_r, pad_t, pad_b);
+  }
+}
+
+Tensor& replication_pad2d_backward_out_cpu_template(
+    Tensor& gradInput,
+    const Tensor& gradOutput_,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  AT_CHECK(paddingSize.size() == 4, "padding size is expected to be 4");
+  int pad_l = paddingSize[0];
+  int pad_r = paddingSize[1];
+  int pad_t = paddingSize[2];
+  int pad_b = paddingSize[3];
+  int dimw = 2;
+  int dimh = 1;
+  int dimslices = 0;
+  int64_t nbatch = 1;
+
+  if (input.dim() == 4)
+  {
+    nbatch = input.size(0);
+    dimw++;
+    dimh++;
+    dimslices++;
+  }
+
+  /* sizes */
+  int64_t nslices = input.size(dimslices);
+  int64_t iheight = input.size(dimh);
+  int64_t iwidth = input.size(dimw);
+  int64_t oheight = iheight + pad_t + pad_b;
+  int64_t owidth  = iwidth + pad_l + pad_r;
+
+  AT_CHECK(owidth == gradOutput_.size(dimw),
+      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
+      gradOutput_.size(dimw));
+  AT_CHECK(oheight == gradOutput_.size(dimh),
+      "gradOutput height unexpected. Expected: ", oheight, ", Got: ",
+      gradOutput_.size(dimh));
+
+  /* get contiguous gradOutput */
+  auto gradOutput = gradOutput_.contiguous();
+
+  /* resize */
+  gradInput.resize_as_(input);
+  gradInput.zero_();
+
+  /* backprop */
+  if (input.dim() == 3)
+  {
+    AT_DISPATCH_FLOATING_TYPES(
+      input.type(), "replication_pad2d_backward", [&] {
+      replication_pad2d_backward_out_frame<scalar_t>(
+        gradInput.data<scalar_t>(),
+        gradOutput.data<scalar_t>(),
+        nslices,
+        iwidth, iheight,
+        owidth, oheight,
+        pad_l, pad_r,
+        pad_t, pad_b);
+      }
+    );
+  }
+  else
+  {
+    AT_DISPATCH_FLOATING_TYPES(
+      input.type(), "replication_pad2d_backward", [&] {
+      replication_pad2d_backward_out_batch<scalar_t>(
+        gradInput.data<scalar_t>(),
+        gradOutput.data<scalar_t>(),
+        nslices,
+        iwidth, iheight,
+        owidth, oheight,
+        pad_l, pad_r,
+        pad_t, pad_b,
+        nbatch);
+      }
+    );
+  }
+  return gradInput;
+}
+
+static inline void shapeCheck3d(
+    const Tensor& input,
+    int pleft, int pright,
+    int ptop, int pbottom,
+    int pfront, int pback) {
+  int dimw = 3;
+  int dimh = 2;
+  int dimd = 1;
+  int dimslices = 0;
+
+  AT_CHECK(input.numel() > 0 && (input.dim() == 4 || input.dim() == 5),
+      "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ", input);
+
+  if (input.dim() == 5)
+  {
+    dimw++;
+    dimh++;
+    dimd++;
+    dimslices++;
+  }
+
+  /* sizes */
+  int64_t nslices = input.size(dimslices);
+  int64_t idepth = input.size(dimd);
+  int64_t iheight = input.size(dimh);
+  int64_t iwidth = input.size(dimw);
+  int64_t odepth = idepth + pfront + pback;
+  int64_t oheight = iheight + ptop + pbottom;
+  int64_t owidth  = iwidth + pleft + pright;
+
+  AT_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
+      "input (D: ", idepth, " H: ", iheight, ", W: ", iwidth,
+      ") is too small."
+      " Calculated output D: ", odepth, " H: ", oheight, " W: ", owidth);
+
+}
+
+template <typename scalar_t>
+static void replication_pad3d_out_frame(
+    scalar_t *input_p, scalar_t *output_p,
+    int64_t nslices,
+    int64_t iwidth, int64_t iheight, int64_t idepth,
+    int64_t owidth, int64_t oheight, int64_t odepth,
+    int pleft, int pright,
+    int ptop, int pbottom,
+    int pfront, int pback)
+{
+  int iStartX = std::max(0, -pleft);
+  int iStartY = std::max(0, -ptop);
+  int iStartZ = std::max(0, -pfront);
+  int oStartX = std::max(0, pleft);
+  int oStartY = std::max(0, ptop);
+  int oStartZ = std::max(0, pfront);
+
+  int64_t k, ip_x, ip_y, ip_z;
+#pragma omp parallel for private(k, ip_x, ip_y, ip_z)
+  for (k = 0; k < nslices; k++) {
+    for (int64_t z = 0; z < odepth; z++) {
+      for (int64_t i = 0; i < oheight; i++) {
+        for (int64_t j = 0; j < owidth; j++) {
+          if (j < pleft) {
+            ip_x = pleft;
+          } else if (j >= pleft && j < iwidth + pleft) {
+            ip_x = j;
+          } else {
+            ip_x = iwidth + pleft - 1;
+          }
+          ip_x = ip_x - oStartX + iStartX;
+
+          if (i < ptop) {
+            ip_y = ptop;
+          } else if (i >= ptop && i < iheight + ptop) {
+            ip_y = i;
+          } else {
+            ip_y = iheight + ptop - 1;
+          }
+          ip_y = ip_y - oStartY + iStartY;
+
+          if (z < pfront) {
+            ip_z = pfront;
+          } else if (z >= pfront && z < idepth + pfront) {
+            ip_z = z;
+          } else {
+            ip_z = idepth + pfront - 1;
+          }
+          ip_z = ip_z - oStartZ + iStartZ;
+
+          scalar_t *dest_p = output_p + k * owidth * oheight * odepth +
+            z * owidth * oheight + i * owidth + j;
+          scalar_t *src_p = input_p + k * iwidth * iheight * idepth +
+            ip_z * iwidth * iheight + ip_y * iwidth + ip_x;
+          *dest_p = *src_p;
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+static void replication_pad3d_out_batch(
+    scalar_t *input_data, scalar_t *output_data,
+    int64_t nslices,
+    int64_t iwidth, int64_t iheight, int64_t idepth,
+    int64_t owidth, int64_t oheight, int64_t odepth,
+    int pleft, int pright,
+    int ptop, int pbottom,
+    int pfront, int pback,
+    int nbatch)
+{
+  int64_t p;
+#pragma omp parallel for private(p)
+  for (p = 0; p < nbatch; p++)
+  {
+    scalar_t *input_p = input_data + p * nslices * iwidth * iheight * idepth;
+    scalar_t *output_p = output_data + p * nslices * owidth * oheight * odepth;
+    replication_pad3d_out_frame(input_p, output_p, nslices,
+        iwidth, iheight, idepth, owidth, oheight, odepth,
+        pleft, pright, ptop, pbottom, pfront, pback);
+  }
+}
+
+void replication_pad3d_out_cpu_template(
+    Tensor& output,
+    const Tensor& input_,
+    IntList paddingSize)
+{
+  AT_CHECK(paddingSize.size() == 6, "padding size is expected to be 6");
+  int pleft = paddingSize[0];
+  int pright = paddingSize[1];
+  int ptop = paddingSize[2];
+  int pbottom = paddingSize[3];
+  int pfront = paddingSize[4];
+  int pback = paddingSize[5];
+  int dimw = 3;
+  int dimh = 2;
+  int dimd = 1;
+  int dimslices = 0;
+  int64_t nbatch = 1;
+
+  shapeCheck3d(input_, pleft, pright,
+      ptop, pbottom, pfront, pback);
+
+  if (input_.dim() == 5)
+  {
+    nbatch = input_.size(0);
+    dimw++;
+    dimh++;
+    dimd++;
+    dimslices++;
+  }
+
+  /* sizes */
+  int64_t nslices = input_.size(dimslices);
+  int64_t idepth = input_.size(dimd);
+  int64_t iheight = input_.size(dimh);
+  int64_t iwidth = input_.size(dimw);
+  int64_t odepth = idepth + pfront + pback;
+  int64_t oheight = iheight + ptop + pbottom;
+  int64_t owidth  = iwidth + pleft + pright;
+
+  /* get contiguous input */
+  auto input = input_.contiguous();
+
+  /* resize output */
+  if (input.dim() == 4)
+  {
+    output.resize_({nslices, odepth, oheight, owidth});
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad3d", [&] {
+      auto input_data = input.data<scalar_t>();
+      auto output_data = output.data<scalar_t>();
+      replication_pad3d_out_frame<scalar_t>(
+        input_data, output_data, nslices, iwidth, iheight, idepth,
+        owidth, oheight, odepth, pleft, pright, ptop, pbottom, pfront,
+        pback);
+      }
+    );
+  }
+  else
+  {
+    output.resize_({nbatch, nslices, odepth, oheight, owidth});
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad3d", [&] {
+      auto input_data = input.data<scalar_t>();
+      auto output_data = output.data<scalar_t>();
+      replication_pad3d_out_batch<scalar_t>(
+        input_data, output_data, nslices, iwidth, iheight, idepth,
+        owidth, oheight, odepth, pleft, pright, ptop, pbottom, pfront,
+        pback,
+        nbatch);
+      }
+    );
+  }
+}
+
+template <typename scalar_t>
+static void replication_pad3d_backward_out_frame(
+    scalar_t *ginput_p, scalar_t *goutput_p,
+    int64_t nslices,
+    int64_t iwidth, int64_t iheight, int64_t idepth,
+    int64_t owidth, int64_t oheight, int64_t odepth,
+    int pleft, int pright,
+    int ptop, int pbottom,
+    int pfront, int pback)
+{
+  int iStartX = std::max(0, -pleft);
+  int iStartY = std::max(0, -ptop);
+  int iStartZ = std::max(0, -pfront);
+  int oStartX = std::max(0, pleft);
+  int oStartY = std::max(0, ptop);
+  int oStartZ = std::max(0, pfront);
+
+  int64_t k, ip_x, ip_y, ip_z;
+#pragma omp parallel for private(k, ip_x, ip_y, ip_z)
+  for (k = 0; k < nslices; k++) {
+    for (int64_t z = 0; z < odepth; z++) {
+      for (int64_t i = 0; i < oheight; i++) {
+        for (int64_t j = 0; j < owidth; j++) {
+          if (j < pleft) {
+            ip_x = pleft;
+          } else if (j >= pleft && j < iwidth + pleft) {
+            ip_x = j;
+          } else {
+            ip_x = iwidth + pleft - 1;
           }
-          );
+          ip_x = ip_x - oStartX + iStartX;
+
+          if (i < ptop) {
+            ip_y = ptop;
+          } else if (i >= ptop && i < iheight + ptop) {
+            ip_y = i;
+          } else {
+            ip_y = iheight + ptop - 1;
+          }
+          ip_y = ip_y - oStartY + iStartY;
+
+          if (z < pfront) {
+            ip_z = pfront;
+          } else if (z >= pfront && z < idepth + pfront) {
+            ip_z = z;
+          } else {
+            ip_z = idepth + pfront - 1;
+          }
+          ip_z = ip_z - oStartZ + iStartZ;
+
+          scalar_t *src_p = goutput_p + k * owidth * oheight * odepth +
+            z * owidth * oheight + i * owidth + j;
+          scalar_t *dest_p = ginput_p + k * iwidth * iheight * idepth +
+            ip_z * iwidth * iheight + ip_y * iwidth + ip_x;
+          *dest_p += *src_p;
+        }
+      }
     }
   }
+}
+
+template <typename scalar_t>
+static void replication_pad3d_backward_out_batch(
+    scalar_t *ginput_data, scalar_t *goutput_data,
+    int64_t nslices,
+    int64_t iwidth, int64_t iheight, int64_t idepth,
+    int64_t owidth, int64_t oheight, int64_t odepth,
+    int pleft, int pright,
+    int ptop, int pbottom,
+    int pfront, int pback,
+    int nbatch)
+{
+  int64_t p;
+#pragma omp parallel for private(p)
+  for (p = 0; p < nbatch; p++)
+  {
+    scalar_t *ginput_p = ginput_data + p * nslices * idepth * iheight * iwidth;
+    scalar_t *goutput_p = goutput_data + p * nslices * odepth * oheight * owidth;
+    replication_pad3d_backward_out_frame(ginput_p, goutput_p, nslices,
+        iwidth, iheight, idepth, owidth, oheight, odepth,
+        pleft, pright, ptop, pbottom, pfront, pback);
+  }
+}
+
+Tensor& replication_pad3d_backward_out_cpu_template(
+    Tensor& gradInput,
+    const Tensor& gradOutput_,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  AT_CHECK(paddingSize.size() == 6, "padding size is expected to be 6");
+  int pleft = paddingSize[0];
+  int pright = paddingSize[1];
+  int ptop = paddingSize[2];
+  int pbottom = paddingSize[3];
+  int pfront = paddingSize[4];
+  int pback = paddingSize[5];
+  int dimw = 3;
+  int dimh = 2;
+  int dimd = 1;
+  int dimslices = 0;
+  int64_t nbatch = 1;
+
+  if (input.dim() == 5)
+  {
+    nbatch = input.size(0);
+    dimw++;
+    dimh++;
+    dimd++;
+    dimslices++;
+  }
+
+  /* sizes */
+  int64_t nslices = input.size(dimslices);
+  int64_t idepth = input.size(dimd);
+  int64_t iheight = input.size(dimh);
+  int64_t iwidth = input.size(dimw);
+  int64_t odepth = idepth + pfront + pback;
+  int64_t oheight = iheight + ptop + pbottom;
+  int64_t owidth  = iwidth + pleft + pright;
+
+
+  shapeCheck3d(input, pleft, pright,
+      ptop, pbottom, pfront, pback);
+
+  /* get contiguous gradOutput */
+  auto gradOutput = gradOutput_.contiguous();
+
+  /* resize */
+  gradInput.resize_as_(input);
+  gradInput.zero_();
+
+  /* backprop */
+  if (input.dim() == 4)
+  {
+    AT_DISPATCH_FLOATING_TYPES(
+      input.type(), "replication_pad3d_backward", [&] {
+      replication_pad3d_backward_out_frame<scalar_t> (
+        gradInput.data<scalar_t>(),
+        gradOutput.data<scalar_t>(),
+        nslices,
+        iwidth, iheight, idepth,
+        owidth, oheight, odepth,
+        pleft, pright,
+        ptop, pbottom,
+        pfront, pback);
+      }
+    );
+  }
+  else
+  {
+    AT_DISPATCH_FLOATING_TYPES(
+      input.type(), "replication_pad3d_backward", [&] {
+      replication_pad3d_backward_out_batch<scalar_t> (
+        gradInput.data<scalar_t>(),
+        gradOutput.data<scalar_t>(),
+        nslices,
+        iwidth, iheight, idepth,
+        owidth, oheight, odepth,
+        pleft, pright,
+        ptop, pbottom,
+        pfront, pback,
+        nbatch);
+      }
+    );
+  }
   return gradInput;
 }
 } // namespace
@@ -236,7 +916,7 @@ Tensor& replication_pad1d_out_cpu(
 }
 
 Tensor replication_pad1d_cpu(
-    at::Tensor const& input,
+    const Tensor& input,
     IntList paddingSize)
 {
   auto output = at::empty({0}, input.options());
@@ -268,5 +948,88 @@ Tensor replication_pad1d_backward_cpu(
   return gradInput;
 }
 
+Tensor& replication_pad2d_out_cpu(
+    Tensor& output,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  replication_pad2d_out_cpu_template(
+      output, input, paddingSize);
+  return output;
+}
+
+Tensor replication_pad2d_cpu(
+    const Tensor& input,
+    IntList paddingSize)
+{
+  auto output = at::empty({0}, input.options());
+  replication_pad2d_out_cpu_template(
+      output, input, paddingSize);
+  return output;
+}
+
+Tensor& replication_pad2d_backward_out_cpu(
+    Tensor& gradInput,
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  replication_pad2d_backward_out_cpu_template(
+      gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
+
+Tensor replication_pad2d_backward_cpu(
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  auto gradInput = at::zeros_like(input);
+  replication_pad2d_backward_out_cpu_template(
+      gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
+
+Tensor& replication_pad3d_out_cpu(
+    Tensor& output,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  replication_pad3d_out_cpu_template(
+      output, input, paddingSize);
+  return output;
+}
+
+Tensor replication_pad3d_cpu(
+    const Tensor& input,
+    IntList paddingSize)
+{
+  auto output = at::empty({0}, input.options());
+  replication_pad3d_out_cpu_template(
+      output, input, paddingSize);
+  return output;
+}
+
+Tensor& replication_pad3d_backward_out_cpu(
+    Tensor& gradInput,
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  replication_pad3d_backward_out_cpu_template(
+      gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
+
+Tensor replication_pad3d_backward_cpu(
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  auto gradInput = at::zeros_like(input);
+  replication_pad3d_backward_out_cpu_template(
+      gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
 } // at::native
 } // at
index 23dc9ce..f8b1be2 100644 (file)
@@ -24,13 +24,9 @@ __host__ __device__ __forceinline__ int imax(int a, int b) {
   return a > b ? a : b;
 }
 
-__host__ __device__ __forceinline__ int iabs(int a) {
-  return a >= 0 ? a : -a;
-}
-
 namespace {
 template <typename scalar_t>
-__global__ void replication_pad_forward_kernel(
+__global__ void replication_pad_forward_kernel1d(
     PackedTensorAccessor<scalar_t, 3> input,
     PackedTensorAccessor<scalar_t, 3> output,
     int padL, int padR) {
@@ -75,6 +71,135 @@ __global__ void replication_pad_backward_kernel(
   atomicAdd(&gradInput[batch][plane][inputPointX], valueToCopy);
 }
 
+template <typename scalar_t>
+__global__ void replication_pad_forward_kernel2d(
+    PackedTensorAccessor<scalar_t, 4> input,
+    PackedTensorAccessor<scalar_t, 4> output,
+    int padT, int padB, int padL, int padR) {
+
+  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
+  int plane = blockIdx.y;
+  int batch = blockIdx.z;
+  if (outputPointId >= output.size(2) * output.size(3)) {
+    return;
+  }
+  int outputPointX = outputPointId % output.size(3);
+  int outputPointY = outputPointId / output.size(3);
+
+  int iStartX = imax(0, -padL);
+  int iStartY = imax(0, -padT);
+  int oStartX = imax(0, padL);
+  int oStartY = imax(0, padT);
+
+  int inputPointX = imin(imax(padL, outputPointX), input.size(3) + padL - 1) - oStartX + iStartX;
+  int inputPointY = imin(imax(padT, outputPointY), input.size(2) + padT - 1) - oStartY + iStartY;
+
+  scalar_t valueToCopy = input[batch][plane][inputPointY][inputPointX];
+  output[batch][plane][outputPointY][outputPointX] = valueToCopy;
+}
+
+template <typename scalar_t>
+__global__ void replication_pad_backward_kernel(
+    PackedTensorAccessor<scalar_t, 4> gradInput,
+    PackedTensorAccessor<scalar_t, 4> gradOutput,
+    int padT, int padB, int padL, int padR) {
+
+  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
+  int plane = blockIdx.y;
+  int batch = blockIdx.z;
+  if (outputPointId >= gradOutput.size(2) * gradOutput.size(3)) {
+    return;
+  }
+  int outputPointX = outputPointId % gradOutput.size(3);
+  int outputPointY = outputPointId / gradOutput.size(3);
+
+  int iStartX = imax(0, -padL);
+  int iStartY = imax(0, -padT);
+  int oStartX = imax(0, padL);
+  int oStartY = imax(0, padT);
+
+  int inputPointX = imin(imax(padL, outputPointX), gradInput.size(3) + padL - 1) - oStartX + iStartX;
+  int inputPointY = imin(imax(padT, outputPointY), gradInput.size(2) + padT - 1) - oStartY + iStartY;
+
+  scalar_t valueToCopy = gradOutput[batch][plane][outputPointY][outputPointX];
+  atomicAdd(&gradInput[batch][plane][inputPointY][inputPointX], valueToCopy);
+}
+
+template <typename scalar_t>
+__global__ void replication_pad_forward_kernel3d(
+    PackedTensorAccessor<scalar_t, 5> input,
+    PackedTensorAccessor<scalar_t, 5> output,
+    int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
+
+  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
+  int plane = blockIdx.y;
+  int batch = blockIdx.z;
+  if (outputPointId >= (output.size(2) * output.size(3) *
+        output.size(4))) {
+    return;
+  }
+  int outputPointX = outputPointId % output.size(4);
+  int outputPointY = (outputPointId / output.size(4)) % output.size(3);
+  int outputPointZ = outputPointId / (output.size(3) * output.size(4));
+
+  int iStartX = imax(0, -pleft);
+  int iStartY = imax(0, -ptop);
+  int iStartZ = imax(0, -pfront);
+  int oStartX = imax(0, pleft);
+  int oStartY = imax(0, ptop);
+  int oStartZ = imax(0, pfront);
+
+  int inputPointX = imin(imax(pleft, outputPointX),
+      input.size(4) + pleft - 1) - oStartX + iStartX;
+  int inputPointY = imin(imax(ptop, outputPointY),
+      input.size(3) + ptop - 1) - oStartY + iStartY;
+  int inputPointZ = imin(imax(pfront, outputPointZ),
+      input.size(2) + pfront - 1) - oStartZ + iStartZ;
+
+  scalar_t valueToCopy =
+    input[batch][plane][inputPointZ][inputPointY][inputPointX];
+  output[batch][plane][outputPointZ][outputPointY][outputPointX] = valueToCopy;
+}
+
+template <typename scalar_t>
+__global__ void replication_pad_backward_kernel(
+    PackedTensorAccessor<scalar_t, 5> gradInput,
+    PackedTensorAccessor<scalar_t, 5> gradOutput,
+    int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
+  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
+  int plane = blockIdx.y;
+  int batch = blockIdx.z;
+
+  if (outputPointId >= (gradOutput.size(2) * gradOutput.size(3) *
+        gradOutput.size(4))) {
+    return;
+  }
+  int outputPointX = outputPointId % gradOutput.size(4);
+  int outputPointY = (outputPointId / gradOutput.size(4)) %
+    gradOutput.size(3);
+  int outputPointZ = outputPointId / (gradOutput.size(3) *
+      gradOutput.size(4));
+
+  int iStartX = imax(0, -pleft);
+  int iStartY = imax(0, -ptop);
+  int iStartZ = imax(0, -pfront);
+  int oStartX = imax(0, pleft);
+  int oStartY = imax(0, ptop);
+  int oStartZ = imax(0, pfront);
+
+  int inputPointX = imin(imax(pleft, outputPointX),
+      gradInput.size(4) + pleft - 1) - oStartX + iStartX;
+  int inputPointY = imin(imax(ptop, outputPointY),
+      gradInput.size(3) + ptop - 1) - oStartY + iStartY;
+  int inputPointZ = imin(imax(pfront, outputPointZ),
+      gradInput.size(2) + pfront - 1) - oStartZ + iStartZ;
+
+  scalar_t valueToCopy =
+    gradOutput[batch][plane][outputPointZ][outputPointY][outputPointX];
+  atomicAdd(&gradInput[batch][plane][inputPointZ][inputPointY][inputPointX],
+      valueToCopy);
+}
+
 void replication_pad1d_out_cuda_template(
     Tensor& output,
     const Tensor& input,
@@ -82,6 +207,7 @@ void replication_pad1d_out_cuda_template(
 {
   AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
       "input tensor must fit into 32-bit index math");
+  AT_CHECK(paddingSize.size() == 2, "padding Size is expected to be 2");
 
   int padL = paddingSize[0];
   int padR = paddingSize[1];
@@ -114,18 +240,18 @@ void replication_pad1d_out_cuda_template(
 
       if (numInputDims == 2) {
         output.resize_({numPlanes, outputW});
-        auto input_ = input.reshape({1, input.size(0), input.size(1)});
-        auto output_ = output.reshape({1, output.size(0), output.size(1)});
+        auto input_ = input.unsqueeze(0);
+        auto output_ = output.unsqueeze(0);
         auto devInput = input_.packed_accessor<scalar_t, 3>();
         auto devOutput = output_.packed_accessor<scalar_t, 3>();
 
         int outputPlaneSize = devOutput.size(2);
         dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-          devOutput.size(1),
-          devOutput.size(0));
+            devOutput.size(1),
+            devOutput.size(0));
         dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
 
-        replication_pad_forward_kernel <<<gridSize, blockSize, 0,
+        replication_pad_forward_kernel1d <<<gridSize, blockSize, 0,
           at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, padR);
       } else {
         output.resize_({numBatch, numPlanes, outputW});
@@ -138,12 +264,12 @@ void replication_pad1d_out_cuda_template(
             devOutput.size(0));
         dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
 
-        replication_pad_forward_kernel <<<gridSize, blockSize, 0,
+        replication_pad_forward_kernel1d <<<gridSize, blockSize, 0,
            at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, padR);
       }
-    }
+      }
   );
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
 }
 
 void replication_pad1d_backward_out_cuda_template(
@@ -157,6 +283,7 @@ void replication_pad1d_backward_out_cuda_template(
       "input tensor must fit into 32-bit index math");
   AT_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
       "output gradient tensor must fit into 32-bit index math");
+  AT_CHECK(paddingSize.size() == 2, "padding Size is expected to be 2");
 
   int padL = paddingSize[0];
   int padR = paddingSize[1];
@@ -184,26 +311,400 @@ void replication_pad1d_backward_out_cuda_template(
       auto gradInput_ = gradInput;
       auto gradOutput_ = gradOutput;
       if (numInputDims == 2) {
-        gradInput_ = gradInput.reshape({1, gradInput.size(0),
-          gradInput.size(1)});
-        gradOutput_ = gradOutput.reshape({1, gradOutput.size(0),
-          gradOutput.size(1)});
+      gradInput_ = gradInput.unsqueeze(0);
+      gradOutput_ = gradOutput.unsqueeze(0);
       }
       auto devGradInput = gradInput_.packed_accessor<scalar_t, 3>();
       auto devGradOutput = gradOutput_.packed_accessor<scalar_t, 3>();
 
       int outputPlaneSize = devGradOutput.size(2);
       dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-        devGradOutput.size(1),
-        devGradOutput.size(0));
+          devGradOutput.size(1),
+          devGradOutput.size(0));
       dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
 
       replication_pad_backward_kernel <<<gridSize, blockSize, 0,
+                                      at::cuda::getCurrentCUDAStream()>>>(devGradInput, devGradOutput,
+                                          padL, padR);
+      }
+  );
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void replication_pad2d_out_cuda_template(
+    Tensor& output,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
+      "input tensor must fit into 32-bit index math");
+  AT_CHECK(paddingSize.size() == 4, "padding Size is expected to be 4");
+
+  int padL = paddingSize[0];
+  int padR = paddingSize[1];
+  int padT = paddingSize[2];
+  int padB = paddingSize[3];
+  int planeDim = 0;
+  int dimh = 1;
+  int dimw = 2;
+  int numBatch = 1;
+
+  int numInputDims = input.dim();
+  AT_CHECK(input.numel() && (numInputDims == 3 || numInputDims == 4),
+      "non-empty 3D or 4D (batch mode) tensor expected for input, but got: ",
+      input)
+
+  if (numInputDims == 4) {
+    numBatch = input.size(0);
+    planeDim++;
+    dimh++;
+    dimw++;
+  }
+
+  int numPlanes = input.size(planeDim);
+  int inputH = input.size(dimh);
+  int inputW = input.size(dimw);
+  int outputH = inputH + padT + padB;
+  int outputW  = inputW + padL + padR;
+
+  AT_CHECK(outputW >= 1 || outputH >= 1,
+      "input (H: ", inputH, ", W: ", inputW, ") is too small."
+      " Calculated output H: ", outputH, " W: ", outputW);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.type(), "replication_pad2d", [&] {
+
+
+      if (numInputDims == 3) {
+        output.resize_({numPlanes, outputH, outputW});
+        auto input_ = input.unsqueeze(0);
+        auto output_ = output.unsqueeze(0);
+        auto devInput = input_.packed_accessor<scalar_t, 4>();
+        auto devOutput = output_.packed_accessor<scalar_t, 4>();
+
+        int outputPlaneSize = devOutput.size(2) * devOutput.size(3);
+        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+            devOutput.size(1),
+            devOutput.size(0));
+        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+        replication_pad_forward_kernel2d <<<gridSize, blockSize, 0,
+        at::cuda::getCurrentCUDAStream()>>>(
+            devInput, devOutput, padT, padB, padL, padR);
+      } else {
+        output.resize_({numBatch, numPlanes, outputH, outputW});
+        auto devInput = input.packed_accessor<scalar_t, 4>();
+        auto devOutput = output.packed_accessor<scalar_t, 4>();
+
+        int outputPlaneSize = devOutput.size(2) * devOutput.size(3);
+        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+            devOutput.size(1),
+            devOutput.size(0));
+        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+        replication_pad_forward_kernel2d <<<gridSize, blockSize, 0,
+                                       at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, 
+                                           padT, padB, padL, padR);
+      }
+      }
+  );
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void replication_pad2d_backward_out_cuda_template(
+    Tensor& gradInput,
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
+      "input tensor must fit into 32-bit index math");
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
+      "output gradient tensor must fit into 32-bit index math");
+  AT_CHECK(paddingSize.size() == 4, "padding Size is expected to be 4");
+
+  int padL = paddingSize[0];
+  int padR = paddingSize[1];
+  int padT = paddingSize[2];
+  int padB = paddingSize[3]; 
+  int planeDim = 0;
+  int dimh = 1;
+  int dimw = 2;
+
+  int numInputDims = input.dim();
+  if (numInputDims == 4) {
+    planeDim++;
+    dimh++;
+    dimw++;
+  }
+  int iheight = input.size(dimh);
+  int iwidth = input.size(dimw);
+  int oheight = iheight + padT + padB;
+  int owidth  = iwidth + padL + padR;
+
+  AT_CHECK(owidth == gradOutput.size(dimw),
+      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
+      gradOutput.size(dimw));
+  AT_CHECK(oheight == gradOutput.size(dimh),
+      "gradOutput height unexpected. Expected: ", oheight, ", Got: ",
+      gradOutput.size(dimh));
+
+  gradInput.resize_as_(input);
+  gradInput.zero_();
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.type(), "replication_pad2d_backward", [&] {
+
+        auto gradInput_ = gradInput;
+        auto gradOutput_ = gradOutput;
+        if (numInputDims == 3) {
+          gradInput_ = gradInput.unsqueeze(0);
+          gradOutput_ = gradOutput.unsqueeze(0);
+        }
+        auto devGradInput = gradInput_.packed_accessor<scalar_t, 4>();
+        auto devGradOutput = gradOutput_.packed_accessor<scalar_t, 4>();
+
+        int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3);
+        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+          devGradOutput.size(1),
+          devGradOutput.size(0));
+        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+        replication_pad_backward_kernel <<<gridSize, blockSize, 0,
         at::cuda::getCurrentCUDAStream()>>>(devGradInput, devGradOutput,
-            padL, padR);
-    }
+          padT, padB, padL, padR);
+      }
+  );
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+static inline void shapeCheck3d(
+    const Tensor& input,
+    int pleft, int pright,
+    int ptop, int pbottom,
+    int pfront, int pback) {
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input), 
+      "input tensor must fit into 32-bit index math");
+  int numInputDims = input.dim();
+
+  AT_CHECK(input.numel() && (numInputDims == 4 || numInputDims == 5),
+      "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ", input);
+
+  int planeDim = 0;
+  int dimd = 1;
+  int dimh = 2;
+  int dimw = 3;
+  if (numInputDims == 5) {
+    planeDim++;
+    dimd++;
+    dimh++;
+    dimw++;
+  }
+
+  int numPlanes = input.size(planeDim);
+  int idepth = input.size(dimd);
+  int iheight = input.size(dimh);
+  int iwidth = input.size(dimw);
+  int odepth = idepth + pfront + pback;
+  int oheight = iheight + ptop + pbottom;
+  int owidth  = iwidth + pleft + pright;
+  AT_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
+      "input (D: ", idepth, " H: ", iheight, ", W: ", iwidth,
+      ") is too small."
+      " Calculated output D: ", odepth, " H: ", oheight, " W: ", owidth);
+
+}
+
+static inline void shapeAndGradOutputCheck3d(
+    const Tensor& input,
+    const Tensor& gradOutput,
+    int pleft, int pright,
+    int ptop, int pbottom,
+    int pfront, int pback) {
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input), 
+      "input tensor must fit into 32-bit index math");
+  int numInputDims = input.dim();
+
+  AT_CHECK(input.numel() && (numInputDims == 4 || numInputDims == 5),
+      "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ", input);
+
+  int planeDim = 0;
+  int dimd = 1;
+  int dimh = 2;
+  int dimw = 3;
+  if (numInputDims == 5) {
+    planeDim++;
+    dimd++;
+    dimh++;
+    dimw++;
+  }
+
+  int numPlanes = input.size(planeDim);
+  int idepth = input.size(dimd);
+  int iheight = input.size(dimh);
+  int iwidth = input.size(dimw);
+  int odepth = idepth + pfront + pback;
+  int oheight = iheight + ptop + pbottom;
+  int owidth  = iwidth + pleft + pright;
+  AT_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
+      "input (D: ", idepth, " H: ", iheight, ", W: ", iwidth,
+      ") is too small."
+      " Calculated output D: ", odepth, " H: ", oheight, " W: ", owidth);
+
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
+      "output gradient tensor must fit into 32-bit index math");
+
+  AT_CHECK(numPlanes == gradOutput.size(planeDim),
+      "gradOutput width unexpected. Expected: ", numPlanes, ", Got: ",
+      gradOutput.size(planeDim));
+  AT_CHECK(owidth == gradOutput.size(dimw),
+      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
+      gradOutput.size(dimw));
+  AT_CHECK(oheight == gradOutput.size(dimh),
+      "gradOutput height unexpected. Expected: ", oheight, ", Got: ",
+      gradOutput.size(dimh));
+  AT_CHECK(odepth == gradOutput.size(dimd),
+      "gradOutput depth unexpected. Expected: ", odepth, ", Got: ",
+      gradOutput.size(dimd));
+}
+
+void replication_pad3d_out_cuda_template(
+    Tensor& output,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  AT_CHECK(paddingSize.size() == 6, "padding Size is expected to be 6");
+  int pleft = paddingSize[0];
+  int pright = paddingSize[1];
+  int ptop = paddingSize[2];
+  int pbottom = paddingSize[3];
+  int pfront = paddingSize[4];
+  int pback = paddingSize[5]; 
+  shapeCheck3d(input, pleft, pright, ptop,
+      pbottom, pfront, pback);
+
+  int planeDim = 0;
+  int dimd = 1;
+  int dimh = 2;
+  int dimw = 3;
+  int numBatch = 1;
+
+  int numInputDims = input.dim();
+
+  if (numInputDims == 5) {
+    numBatch = input.size(0);
+    planeDim++;
+    dimd++;
+    dimh++;
+    dimw++;
+  }
+
+  int numPlanes = input.size(planeDim);
+  int inputD = input.size(dimd);
+  int inputH = input.size(dimh);
+  int inputW = input.size(dimw);
+  int outputD = inputD + pfront + pback;
+  int outputH = inputH + ptop + pbottom;
+  int outputW  = inputW + pleft + pright;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.type(), "replication_pad3d", [&] {
+
+      if (numInputDims == 4) {
+        output.resize_({numPlanes, outputD, outputH, outputW});
+        auto input_ = input.unsqueeze(0);
+        auto output_ = output.unsqueeze(0);
+        auto devInput = input_.packed_accessor<scalar_t, 5>();
+        auto devOutput = output_.packed_accessor<scalar_t, 5>();
+
+        int outputPlaneSize = devOutput.size(2) * devOutput.size(3) *
+        devOutput.size(4);
+        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+            devOutput.size(1),
+            devOutput.size(0));
+        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+        replication_pad_forward_kernel3d <<<gridSize, blockSize, 0,
+        at::cuda::getCurrentCUDAStream()>>>(
+            devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright);
+      } else {
+        output.resize_({numBatch, numPlanes, outputD, outputH, outputW});
+        auto devInput = input.packed_accessor<scalar_t, 5>();
+        auto devOutput = output.packed_accessor<scalar_t, 5>();
+
+        int outputPlaneSize = devOutput.size(2) * devOutput.size(3) *
+          devOutput.size(4);
+        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+            devOutput.size(1),
+            devOutput.size(0));
+        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+        replication_pad_forward_kernel3d <<<gridSize, blockSize, 0,
+                                       at::cuda::getCurrentCUDAStream()>>>(
+                                           devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright);
+      }
+      }
+  );
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void replication_pad3d_backward_out_cuda_template(
+    Tensor& gradInput,
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  AT_CHECK(paddingSize.size() == 6, "padding Size is expected to be 6");
+  int pleft = paddingSize[0];
+  int pright = paddingSize[1];
+  int ptop = paddingSize[2];
+  int pbottom = paddingSize[3];
+  int pfront = paddingSize[4];
+  int pback = paddingSize[5]; 
+  shapeAndGradOutputCheck3d(input, gradOutput, pleft, pright, ptop,
+      pbottom, pfront, pback);
+
+  int planeDim = 0;
+  int dimd = 1;
+  int dimh = 2;
+  int dimw = 3;
+
+  int numInputDims = input.dim();
+  if (numInputDims == 5) {
+    planeDim++;
+    dimd++;
+    dimh++;
+    dimw++;
+  }
+
+  gradInput.resize_as_(input);
+  gradInput.zero_();
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.type(), "replication_pad3d_backward", [&] {
+
+      auto gradInput_ = gradInput;
+      auto gradOutput_ = gradOutput;
+      if (numInputDims == 4) {
+        gradInput_ = gradInput.unsqueeze(0);
+        gradOutput_ = gradOutput.unsqueeze(0);
+      }
+      auto devGradInput = gradInput_.packed_accessor<scalar_t, 5>();
+      auto devGradOutput = gradOutput_.packed_accessor<scalar_t, 5>();
+
+      int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3) *
+      devGradOutput.size(4);
+      dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+          devGradOutput.size(1),
+          devGradOutput.size(0));
+      dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+      replication_pad_backward_kernel <<<gridSize, blockSize, 0,
+                                      at::cuda::getCurrentCUDAStream()>>>(
+                                          devGradInput, devGradOutput, pfront, pback, ptop, pbottom, pleft, pright);
+      }
   );
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
 }
 } // namespace
 
@@ -218,7 +719,7 @@ Tensor& replication_pad1d_out_cuda(
 }
 
 Tensor replication_pad1d_cuda(
-    at::Tensor const& input,
+    const Tensor& input,
     IntList paddingSize)
 {
   auto output = at::empty({0}, input.options());
@@ -233,7 +734,6 @@ Tensor& replication_pad1d_backward_out_cuda(
     const Tensor& input,
     IntList paddingSize)
 {
-  gradInput.resize_as_(input);
   replication_pad1d_backward_out_cuda_template(
       gradInput, gradOutput, input, paddingSize);
   return gradInput;
@@ -250,5 +750,89 @@ Tensor replication_pad1d_backward_cuda(
   return gradInput;
 }
 
+Tensor& replication_pad2d_out_cuda(
+    Tensor& output,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  replication_pad2d_out_cuda_template(
+      output, input, paddingSize);
+  return output;
+}
+
+Tensor replication_pad2d_cuda(
+    const Tensor& input,
+    IntList paddingSize)
+{
+  auto output = at::empty({0}, input.options());
+  replication_pad2d_out_cuda_template(
+      output, input, paddingSize);
+  return output;
+}
+
+Tensor& replication_pad2d_backward_out_cuda(
+    Tensor& gradInput,
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  replication_pad2d_backward_out_cuda_template(
+      gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
+
+Tensor replication_pad2d_backward_cuda(
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  auto gradInput = at::zeros_like(input);
+  replication_pad2d_backward_out_cuda_template(
+      gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
+
+Tensor& replication_pad3d_out_cuda(
+    Tensor& output,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  replication_pad3d_out_cuda_template(
+      output, input, paddingSize);
+  return output;
+}
+
+Tensor replication_pad3d_cuda(
+    const Tensor& input,
+    IntList paddingSize)
+{
+  auto output = at::empty({0}, input.options());
+  replication_pad3d_out_cuda_template(
+      output, input, paddingSize);
+  return output;
+}
+
+Tensor& replication_pad3d_backward_out_cuda(
+    Tensor& gradInput,
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  replication_pad3d_backward_out_cuda_template(
+      gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
+
+Tensor replication_pad3d_backward_cuda(
+    const Tensor& gradOutput,
+    const Tensor& input,
+    IntList paddingSize)
+{
+  auto gradInput = at::zeros_like(input);
+  replication_pad3d_backward_out_cuda_template(
+      gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
+
 } // at::native
 } // at
index c1cbd81..4f4eaf8 100644 (file)
 
 - func: replication_pad2d_out(Tensor output, Tensor self, IntList[4] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad2d_out_cpu
+    CUDA: replication_pad2d_out_cuda
 
 - func: replication_pad2d(Tensor self, IntList[4] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad2d_cpu
+    CUDA: replication_pad2d_cuda
 
 - func: replication_pad2d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[4] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad2d_backward_out_cpu
+    CUDA: replication_pad2d_backward_out_cuda
 
 - func: replication_pad2d_backward(Tensor grad_output, Tensor self, IntList[4] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad2d_backward_cpu
+    CUDA: replication_pad2d_backward_cuda
 
 - func: replication_pad3d_out(Tensor output, Tensor self, IntList[6] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad3d_out_cpu
+    CUDA: replication_pad3d_out_cuda
 
 - func: replication_pad3d(Tensor self, IntList[6] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad3d_cpu
+    CUDA: replication_pad3d_cuda
 
 - func: replication_pad3d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[6] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad3d_backward_out_cpu
+    CUDA: replication_pad3d_backward_out_cuda
 
 - func: replication_pad3d_backward(Tensor grad_output, Tensor self, IntList[6] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad3d_backward_cpu
+    CUDA: replication_pad3d_backward_cuda
 
 - func: upsample_linear1d_out(Tensor output, Tensor self, IntList[1] output_size, bool align_corners) -> Tensor
   python_module: nn
index 3f23a18..a24a032 100644 (file)
     output: 'false'
     grad_input: 'false'
 
-- name: _thnn_replication_pad2d(Tensor self, IntList[4] padding)
-  cname: SpatialReplicationPadding
-  scalar_check:
-    output: 'false'
-    grad_input: 'false'
-
-- name: _thnn_replication_pad3d(Tensor self, IntList[6] padding)
-  cname: VolumetricReplicationPadding
-  scalar_check:
-    output: 'false'
-    grad_input: 'false'
-
 # Upsampling
 
 # Note: The upsampling backwards functions also include an IntList input_size
index 95d1a33..e0b9ce8 100644 (file)
@@ -42,7 +42,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullDilatedConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxPooling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxUnpooling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialReflectionPadding.cu
-${CMAKE_CURRENT_SOURCE_DIR}/SpatialReplicationPadding.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialSubSampling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialUpSamplingBicubic.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialUpSamplingBilinear.cu
@@ -66,7 +65,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricFullConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricFullDilatedConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricMaxPooling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricMaxUnpooling.cu
-${CMAKE_CURRENT_SOURCE_DIR}/VolumetricReplicationPadding.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricUpSamplingNearest.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricUpSamplingTrilinear.cu
 PARENT_SCOPE)
diff --git a/aten/src/THCUNN/SpatialReplicationPadding.cu b/aten/src/THCUNN/SpatialReplicationPadding.cu
deleted file mode 100644 (file)
index 39f63c8..0000000
+++ /dev/null
@@ -1,70 +0,0 @@
-#include <THCUNN/THCUNN.h>
-#include <THC/THCTensor.hpp>
-#include <THCUNN/common.h>
-#include <THC/THCDeviceTensor.cuh>
-#include <THC/THCDeviceTensorUtils.cuh>
-#include <THC/THCDeviceUtils.cuh>
-#include <THC/THCReduceApplyUtils.cuh>
-#include <THC/THCApply.cuh>
-
-#include <TH/THHalf.h>
-#include <THCUNN/THCHalfAutoNumerics.cuh>
-#include <THC/THCAtomics.cuh>
-
-template <typename Dtype>
-__global__ void SpatialReplicationPadding_updateOutput(
-  THCDeviceTensor<Dtype, 4> input,
-  THCDeviceTensor<Dtype, 4> output,
-  int padT, int padB, int padL, int padR) {
-
-  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-  if (outputPointId >= output.getSize(2) * output.getSize(3)) {
-    return;
-  }
-  int outputPointX = outputPointId % output.getSize(3);
-  int outputPointY = outputPointId / output.getSize(3);
-
-  int iStartX = max(0, -padL);
-  int iStartY = max(0, -padT);
-  int oStartX = max(0, padL);
-  int oStartY = max(0, padT);
-
-  int inputPointX = min(max(padL, outputPointX), input.getSize(3) + padL - 1) - oStartX + iStartX;
-  int inputPointY = min(max(padT, outputPointY), input.getSize(2) + padT - 1) - oStartY + iStartY;
-
-  Dtype valueToCopy = input[batch][plane][inputPointY][inputPointX];
-  output[batch][plane][outputPointY][outputPointX] = valueToCopy;
-}
-
-template <typename Dtype>
-__global__ void SpatialReplicationPadding_updateGradInput(
-  THCDeviceTensor<Dtype, 4> gradInput,
-  THCDeviceTensor<Dtype, 4> gradOutput,
-  int padT, int padB, int padL, int padR) {
-
-  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-  if (outputPointId >= gradOutput.getSize(2) * gradOutput.getSize(3)) {
-    return;
-  }
-  int outputPointX = outputPointId % gradOutput.getSize(3);
-  int outputPointY = outputPointId / gradOutput.getSize(3);
-
-  int iStartX = max(0, -padL);
-  int iStartY = max(0, -padT);
-  int oStartX = max(0, padL);
-  int oStartY = max(0, padT);
-
-  int inputPointX = min(max(padL, outputPointX), gradInput.getSize(3) + padL - 1) - oStartX + iStartX;
-  int inputPointY = min(max(padT, outputPointY), gradInput.getSize(2) + padT - 1) - oStartY + iStartY;
-
-  Dtype valueToCopy = gradOutput[batch][plane][outputPointY][outputPointX];
-  atomicAdd(&gradInput[batch][plane][inputPointY][inputPointX], valueToCopy);
-}
-
-
-#include <THCUNN/generic/SpatialReplicationPadding.cu>
-#include <THC/THCGenerateFloatTypes.h>
diff --git a/aten/src/THCUNN/VolumetricReplicationPadding.cu b/aten/src/THCUNN/VolumetricReplicationPadding.cu
deleted file mode 100644 (file)
index d3859e4..0000000
+++ /dev/null
@@ -1,90 +0,0 @@
-#include <THCUNN/THCUNN.h>
-#include <THC/THCTensor.hpp>
-#include <THCUNN/common.h>
-#include <THC/THCDeviceTensor.cuh>
-#include <THC/THCDeviceTensorUtils.cuh>
-#include <THC/THCDeviceUtils.cuh>
-#include <THC/THCReduceApplyUtils.cuh>
-#include <TH/THHalf.h>
-#include <THCUNN/THCHalfAutoNumerics.cuh>
-#include <THC/THCAtomics.cuh>
-#include <THC/THCApply.cuh>
-
-template <typename Dtype>
-__global__ void VolumetricReplicationPadding_updateOutput(
-  THCDeviceTensor<Dtype, 5> input,
-  THCDeviceTensor<Dtype, 5> output,
-  int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
-
-  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-  if (outputPointId >= (output.getSize(2) * output.getSize(3) *
-                        output.getSize(4))) {
-    return;
-  }
-  int outputPointX = outputPointId % output.getSize(4);
-  int outputPointY = (outputPointId / output.getSize(4)) % output.getSize(3);
-  int outputPointZ = outputPointId / (output.getSize(3) * output.getSize(4));
-
-  int iStartX = max(0, -pleft);
-  int iStartY = max(0, -ptop);
-  int iStartZ = max(0, -pfront);
-  int oStartX = max(0, pleft);
-  int oStartY = max(0, ptop);
-  int oStartZ = max(0, pfront);
-
-  int inputPointX = min(max(pleft, outputPointX),
-                        input.getSize(4) + pleft - 1) - oStartX + iStartX;
-  int inputPointY = min(max(ptop, outputPointY),
-                        input.getSize(3) + ptop - 1) - oStartY + iStartY;
-  int inputPointZ = min(max(pfront, outputPointZ),
-                        input.getSize(2) + pfront - 1) - oStartZ + iStartZ;
-
-  Dtype valueToCopy =
-      input[batch][plane][inputPointZ][inputPointY][inputPointX];
-  output[batch][plane][outputPointZ][outputPointY][outputPointX] = valueToCopy;
-}
-
-template <typename Dtype>
-__global__ void VolumetricReplicationPadding_updateGradInput(
-  THCDeviceTensor<Dtype, 5> gradInput,
-  THCDeviceTensor<Dtype, 5> gradOutput,
-  int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
-  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-
-  if (outputPointId >= (gradOutput.getSize(2) * gradOutput.getSize(3) *
-                        gradOutput.getSize(4))) {
-    return;
-  }
-  int outputPointX = outputPointId % gradOutput.getSize(4);
-  int outputPointY = (outputPointId / gradOutput.getSize(4)) %
-      gradOutput.getSize(3);
-  int outputPointZ = outputPointId / (gradOutput.getSize(3) *
-      gradOutput.getSize(4));
-
-  int iStartX = max(0, -pleft);
-  int iStartY = max(0, -ptop);
-  int iStartZ = max(0, -pfront);
-  int oStartX = max(0, pleft);
-  int oStartY = max(0, ptop);
-  int oStartZ = max(0, pfront);
-
-  int inputPointX = min(max(pleft, outputPointX),
-                        gradInput.getSize(4) + pleft - 1) - oStartX + iStartX;
-  int inputPointY = min(max(ptop, outputPointY),
-                        gradInput.getSize(3) + ptop - 1) - oStartY + iStartY;
-  int inputPointZ = min(max(pfront, outputPointZ),
-                        gradInput.getSize(2) + pfront - 1) - oStartZ + iStartZ;
-
-  Dtype valueToCopy =
-      gradOutput[batch][plane][outputPointZ][outputPointY][outputPointX];
-  atomicAdd(&gradInput[batch][plane][inputPointZ][inputPointY][inputPointX],
-            valueToCopy);
-}
-
-
-#include <THCUNN/generic/VolumetricReplicationPadding.cu>
-#include <THC/THCGenerateFloatTypes.h>
diff --git a/aten/src/THCUNN/generic/SpatialReplicationPadding.cu b/aten/src/THCUNN/generic/SpatialReplicationPadding.cu
deleted file mode 100644 (file)
index 53a4a6e..0000000
+++ /dev/null
@@ -1,127 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/SpatialReplicationPadding.cu"
-#else
-
-void THNN_(SpatialReplicationPadding_updateOutput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *output,
-           int padL, int padR,
-           int padT, int padB) {
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, input), 2,
-             "input tensor must fit into 32-bit index math");
-
-  int planeDim = 0;
-  int dimh = 1;
-  int dimw = 2;
-  int numBatch = 1;
-
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-  THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
-                  "non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s")
-
-  if (numInputDims == 4) {
-    numBatch = THCTensor_(size)(state, input, 0);
-    planeDim++;
-    dimh++;
-    dimw++;
-  }
-
-  int numPlanes = THCTensor_(size)(state, input, planeDim);
-  int inputH = THCTensor_(size)(state, input, dimh);
-  int inputW = THCTensor_(size)(state, input, dimw);
-  int outputH = inputH + padT + padB;
-  int outputW  = inputW + padL + padR;
-
-  THArgCheck(outputW >= 1 || outputH >= 1 , 2,
-             "input (H: %d, W: %d)is too small."
-             " Calculated output H: %d W: %d",
-             inputH, inputW, outputH, outputW);
-
-  THCDeviceTensor<scalar_t, 4> devInput;
-  THCDeviceTensor<scalar_t, 4> devOutput;
-
-  if (numInputDims == 3) {
-    THCTensor_(resize3d)(state, output, numPlanes, outputH, outputW);
-
-    devInput = toDeviceTensor<scalar_t, 3>(state, input).upcastOuter<4>();
-    devOutput = toDeviceTensor<scalar_t, 3>(state, output).upcastOuter<4>();
-  } else {
-    THCTensor_(resize4d)(state, output, numBatch, numPlanes, outputH, outputW);
-
-    devInput = toDeviceTensor<scalar_t, 4>(state, input);
-    devOutput = toDeviceTensor<scalar_t, 4>(state, output);
-  }
-
-  int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3);
-  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.getSize(1),
-            devOutput.getSize(0));
-  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-  SpatialReplicationPadding_updateOutput<<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
-    devInput, devOutput, padT, padB, padL, padR);
-
-}
-
-void THNN_(SpatialReplicationPadding_updateGradInput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *gradOutput,
-           THCTensor *gradInput,
-           int padL, int padR,
-           int padT, int padB) {
-
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, input), 2,
-                "input tensor must fit into 32-bit index math");
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, gradOutput), 3,
-                "output gradient tensor must fit into 32-bit index math");
-
-  int planeDim = 0;
-  int dimh = 1;
-  int dimw = 2;
-
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-  if (numInputDims == 4) {
-    planeDim++;
-    dimh++;
-    dimw++;
-  }
-  int iheight = input->size(dimh);
-  int iwidth = input->size(dimw);
-  int oheight = iheight + padT + padB;
-  int owidth  = iwidth + padL + padR;
-
-  THArgCheck(owidth == THCTensor_(size)(state, gradOutput, dimw), 3,
-             "gradOutput width unexpected. Expected: %d, Got: %d",
-             owidth, THCTensor_(size)(state, gradOutput, dimw));
-  THArgCheck(oheight == THCTensor_(size)(state, gradOutput, dimh), 3,
-             "gradOutput height unexpected. Expected: %d, Got: %d",
-             oheight, THCTensor_(size)(state, gradOutput, dimh));
-
-  THCTensor_(resizeAs)(state, gradInput, input);
-  THCTensor_(zero)(state, gradInput);
-
-  THCDeviceTensor<scalar_t, 4> devGradInput;
-  THCDeviceTensor<scalar_t, 4> devGradOutput;
-
-  if (numInputDims == 3) {
-    devGradInput = toDeviceTensor<scalar_t, 3>(state, gradInput).upcastOuter<4>();
-    devGradOutput = toDeviceTensor<scalar_t, 3>(state, gradOutput).upcastOuter<4>();
-  } else {
-    devGradInput = toDeviceTensor<scalar_t, 4>(state, gradInput);
-    devGradOutput = toDeviceTensor<scalar_t, 4>(state, gradOutput);
-  }
-
-  int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3);
-  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devGradOutput.getSize(1),
-            devGradOutput.getSize(0));
-  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-  SpatialReplicationPadding_updateGradInput<<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
-    devGradInput, devGradOutput, padT, padB, padL, padR);
-
-}
-
-#endif
index e8d2ec0..fe3ef53 100644 (file)
@@ -862,21 +862,6 @@ THC_API void THNN_(SpatialReflectionPadding_updateGradInput)(
                   int padL, int padR,
                   int padT, int padB);
 
-THC_API void THNN_(SpatialReplicationPadding_updateOutput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *output,
-                  int padL, int padR,
-                  int padT, int padB);
-
-THC_API void THNN_(SpatialReplicationPadding_updateGradInput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *gradOutput,
-                  THCTensor *gradInput,
-                  int padL, int padR,
-                  int padT, int padB);
-
 THC_API void THNN_(SpatialSubSampling_updateOutput)(
                   THCState *state,
                   THCTensor *input,
@@ -1477,23 +1462,6 @@ THC_API void THNN_(VolumetricAdaptiveAveragePooling_updateGradInput)(
                   THCTensor *gradOutput,
                   THCTensor *gradInput);
 
-THC_API void THNN_(VolumetricReplicationPadding_updateOutput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *output,
-                  int pleft, int pright,
-                  int ptop, int pbottom,
-                  int pfront, int pback);
-
-THC_API void THNN_(VolumetricReplicationPadding_updateGradInput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *gradOutput,
-                  THCTensor *gradInput,
-                  int pleft, int pright,
-                  int ptop, int pbottom,
-                  int pfront, int pback);
-
 THC_API void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
                   THCState *state,
                   THCTensor *gradOutput,
diff --git a/aten/src/THCUNN/generic/VolumetricReplicationPadding.cu b/aten/src/THCUNN/generic/VolumetricReplicationPadding.cu
deleted file mode 100644 (file)
index e4d7b4a..0000000
+++ /dev/null
@@ -1,174 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/VolumetricReplicationPadding.cu"
-#else
-
-static inline void THNN_(VolumetricReplicationPadding_shapeCheck)(
-                         THCState *state,
-                         THCTensor *input,
-                         THCTensor *gradOutput,
-                         int pleft, int pright,
-                         int ptop, int pbottom,
-                         int pfront, int pback) {
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, input), 2,
-             "input tensor must fit into 32-bit index math");
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-
-  THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 4 || numInputDims == 5), 2, input,
-    "non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s");
-
-  int planeDim = 0;
-  int dimd = 1;
-  int dimh = 2;
-  int dimw = 3;
-  if (numInputDims == 5) {
-    planeDim++;
-    dimd++;
-    dimh++;
-    dimw++;
-    }
-
-  int numPlanes = THCTensor_(size)(state, input, planeDim);
-  int idepth = input->size(dimd);
-  int iheight = input->size(dimh);
-  int iwidth = input->size(dimw);
-  int odepth = idepth + pfront + pback;
-  int oheight = iheight + ptop + pbottom;
-  int owidth  = iwidth + pleft + pright;
-  THArgCheck(owidth >= 1 || oheight >= 1 || odepth >= 1, 2,
-             "input (D: %d H: %d, W: %d) is too small."
-             " Calculated output D: %d H: %d W: %d",
-             idepth, iheight, iwidth, odepth, oheight, owidth);
-
-  if (gradOutput != NULL) {
-    THArgCheck(THCTensor_canUse32BitIndexMath(state, gradOutput),
-               3, "output gradient tensor must fit into 32-bit index math");
-
-    THArgCheck(numPlanes == THCTensor_(size)(state, gradOutput, planeDim), 3,
-               "gradOutput width unexpected. Expected: %d, Got: %d",
-               numPlanes, THCTensor_(size)(state, gradOutput, planeDim));
-    THArgCheck(owidth == THCTensor_(size)(state, gradOutput, dimw), 3,
-               "gradOutput width unexpected. Expected: %d, Got: %d",
-               owidth, THCTensor_(size)(state, gradOutput, dimw));
-    THArgCheck(oheight == THCTensor_(size)(state, gradOutput, dimh), 3,
-               "gradOutput height unexpected. Expected: %d, Got: %d",
-               oheight, THCTensor_(size)(state, gradOutput, dimh));
-    THArgCheck(odepth == THCTensor_(size)(state, gradOutput, dimd), 3,
-               "gradOutput depth unexpected. Expected: %d, Got: %d",
-               odepth, THCTensor_(size)(state, gradOutput, dimd));
-  }
-}
-
-void THNN_(VolumetricReplicationPadding_updateOutput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *output,
-           int pleft, int pright,
-           int ptop, int pbottom,
-           int pfront, int pback) {
-  THNN_(VolumetricReplicationPadding_shapeCheck)(
-        state, input, NULL, pleft, pright, ptop,
-        pbottom, pfront, pback);
-
-  int planeDim = 0;
-  int dimd = 1;
-  int dimh = 2;
-  int dimw = 3;
-  int numBatch = 1;
-
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-
-  if (numInputDims == 5) {
-    numBatch = THCTensor_(size)(state, input, 0);
-    planeDim++;
-    dimd++;
-    dimh++;
-    dimw++;
-  }
-
-  int numPlanes = THCTensor_(size)(state, input, planeDim);
-  int inputD = THCTensor_(size)(state, input, dimd);
-  int inputH = THCTensor_(size)(state, input, dimh);
-  int inputW = THCTensor_(size)(state, input, dimw);
-  int outputD = inputD + pfront + pback;
-  int outputH = inputH + ptop + pbottom;
-  int outputW  = inputW + pleft + pright;
-
-  THCDeviceTensor<scalar_t, 5> devInput;
-  THCDeviceTensor<scalar_t, 5> devOutput;
-
-  if (numInputDims == 4) {
-    THCTensor_(resize4d)(state, output, numPlanes, outputD, outputH, outputW);
-
-    devInput = toDeviceTensor<scalar_t, 4>(state, input).upcastOuter<5>();
-    devOutput = toDeviceTensor<scalar_t, 4>(state, output).upcastOuter<5>();
-  } else {
-    THCTensor_(resize5d)(state, output, numBatch, numPlanes, outputD, outputH,
-                          outputW);
-
-    devInput = toDeviceTensor<scalar_t, 5>(state, input);
-    devOutput = toDeviceTensor<scalar_t, 5>(state, output);
-  }
-
-  int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3) *
-      devOutput.getSize(4);
-  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.getSize(1),
-            devOutput.getSize(0));
-  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-  VolumetricReplicationPadding_updateOutput<scalar_t><<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
-    devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright);
-}
-
-void THNN_(VolumetricReplicationPadding_updateGradInput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *gradOutput,
-           THCTensor *gradInput,
-           int pleft, int pright,
-           int ptop, int pbottom,
-           int pfront, int pback) {
-  THNN_(VolumetricReplicationPadding_shapeCheck)(
-        state, input, gradOutput, pleft, pright, ptop,
-        pbottom, pfront, pback);
-
-  int planeDim = 0;
-  int dimd = 1;
-  int dimh = 2;
-  int dimw = 3;
-
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-  if (numInputDims == 5) {
-    planeDim++;
-    dimd++;
-    dimh++;
-    dimw++;
-  }
-
-  THCTensor_(resizeAs)(state, gradInput, input);
-  THCTensor_(zero)(state, gradInput);
-
-  THCDeviceTensor<scalar_t, 5> devGradInput;
-  THCDeviceTensor<scalar_t, 5> devGradOutput;
-
-  if (numInputDims == 4) {
-    devGradInput = toDeviceTensor<scalar_t, 4>(state, gradInput).upcastOuter<5>();
-    devGradOutput =
-        toDeviceTensor<scalar_t, 4>(state, gradOutput).upcastOuter<5>();
-  } else {
-    devGradInput = toDeviceTensor<scalar_t, 5>(state, gradInput);
-    devGradOutput = toDeviceTensor<scalar_t, 5>(state, gradOutput);
-  }
-
-  int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3) *
-      devGradOutput.getSize(4);
-  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devGradOutput.getSize(1),
-            devGradOutput.getSize(0));
-  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-  VolumetricReplicationPadding_updateGradInput<<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
-    devGradInput, devGradOutput, pfront, pback, ptop, pbottom, pleft, pright);
-}
-
-#endif
diff --git a/aten/src/THNN/generic/SpatialReplicationPadding.c b/aten/src/THNN/generic/SpatialReplicationPadding.c
deleted file mode 100644 (file)
index 8d3e1fc..0000000
+++ /dev/null
@@ -1,260 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/SpatialReplicationPadding.c"
-#else
-
-static void THNN_(SpatialReplicationPadding_updateOutput_frame)(
-  scalar_t *input_p, scalar_t *output_p,
-  int64_t nslices,
-  int64_t iwidth, int64_t iheight,
-  int64_t owidth, int64_t oheight,
-  int pad_l, int pad_r,
-  int pad_t, int pad_b)
-{
-  int iStartX = fmax(0, -pad_l);
-  int iStartY = fmax(0, -pad_t);
-  int oStartX = fmax(0, pad_l);
-  int oStartY = fmax(0, pad_t);
-
-  int64_t k, ip_x, ip_y;
-#pragma omp parallel for private(k, ip_x, ip_y)
-  for (k = 0; k < nslices; k++)
-  {
-    int64_t i, j;
-    for (i = 0; i < oheight; i++) {
-      for (j = 0; j < owidth; j++) {
-        if (j < pad_l) {
-          ip_x = pad_l;
-        } else if (j >= pad_l && j < iwidth + pad_l) {
-          ip_x = j;
-        } else {
-          ip_x = iwidth + pad_l - 1;
-        }
-        ip_x = ip_x - oStartX + iStartX;
-
-        if (i < pad_t) {
-          ip_y = pad_t;
-        } else if (i >= pad_t && i < iheight + pad_t) {
-          ip_y = i;
-        } else {
-          ip_y = iheight + pad_t - 1;
-        }
-        ip_y = ip_y - oStartY + iStartY;
-
-        scalar_t *dest_p = output_p + k*owidth*oheight + i * owidth + j;
-        scalar_t *src_p = input_p + k*iwidth*iheight + ip_y * iwidth + ip_x;
-        *dest_p = *src_p;
-      }
-    }
-  }
-}
-
-void THNN_(SpatialReplicationPadding_updateOutput)(THNNState *state,
-                                                         THTensor *input,
-                                                         THTensor *output,
-                                                         int pad_l, int pad_r,
-                                                         int pad_t, int pad_b)
-{
-  int dimw = 2;
-  int dimh = 1;
-  int dimslices = 0;
-  int64_t nbatch = 1;
-  int64_t nslices;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t oheight;
-  int64_t owidth;
-  scalar_t *input_data;
-  scalar_t *output_data;
-
-  THNN_ARGCHECK(!input->is_empty() && (input->dim() == 3 || input->dim() == 4), 2, input,
-               "3D or 4D (batch mode) tensor expected for input, but got: %s");
-
-  if (input->dim() == 4)
-  {
-    nbatch = input->size(0);
-    dimw++;
-    dimh++;
-    dimslices++;
-  }
-
-  /* sizes */
-  nslices = input->size(dimslices);
-  iheight = input->size(dimh);
-  iwidth = input->size(dimw);
-  oheight = iheight + pad_t + pad_b;
-  owidth  = iwidth + pad_l + pad_r;
-
-  THArgCheck(owidth >= 1 || oheight >= 1 , 2,
-            "input (H: %d, W: %d)is too small."
-            " Calculated output H: %d W: %d",
-            iheight, iwidth, oheight, owidth);
-
-
-  /* get contiguous input */
-  input = THTensor_(newContiguous)(input);
-
-  /* resize output */
-  if (input->dim() == 3)
-  {
-    THTensor_(resize3d)(output, nslices, oheight, owidth);
-
-    input_data = input->data<scalar_t>();
-    output_data = output->data<scalar_t>();
-
-    THNN_(SpatialReplicationPadding_updateOutput_frame)(input_data, output_data,
-                                                    nslices,
-                                                    iwidth, iheight,
-                                                    owidth, oheight,
-                                                    pad_l, pad_r,
-                                                    pad_t, pad_b);
-  }
-  else
-  {
-    int64_t p;
-
-    THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth);
-
-    input_data = input->data<scalar_t>();
-    output_data = output->data<scalar_t>();
-
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++)
-    {
-      THNN_(SpatialReplicationPadding_updateOutput_frame)(
-        input_data+p*nslices*iwidth*iheight,
-        output_data+p*nslices*owidth*oheight,
-        nslices,
-        iwidth, iheight,
-        owidth, oheight,
-        pad_l, pad_r,
-        pad_t, pad_b);
-    }
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(input);
-}
-
-static void THNN_(SpatialReplicationPadding_updateGradInput_frame)(
-  scalar_t *ginput_p, scalar_t *goutput_p,
-  int64_t nslices,
-  int64_t iwidth, int64_t iheight,
-  int64_t owidth, int64_t oheight,
-  int pad_l, int pad_r,
-  int pad_t, int pad_b)
-{
-  int iStartX = fmax(0, -pad_l);
-  int iStartY = fmax(0, -pad_t);
-  int oStartX = fmax(0, pad_l);
-  int oStartY = fmax(0, pad_t);
-
-  int64_t k, ip_x, ip_y;
-#pragma omp parallel for private(k, ip_x, ip_y)
-  for (k = 0; k < nslices; k++)
-  {
-    int64_t i, j;
-    for (i = 0; i < oheight; i++) {
-      for (j = 0; j < owidth; j++) {
-        if (j < pad_l) {
-          ip_x = pad_l;
-        } else if (j >= pad_l && j < iwidth + pad_l) {
-          ip_x = j;
-        } else {
-          ip_x = iwidth + pad_l - 1;
-        }
-        ip_x = ip_x - oStartX + iStartX;
-
-        if (i < pad_t) {
-          ip_y = pad_t;
-        } else if (i >= pad_t && i < iheight + pad_t) {
-          ip_y = i;
-        } else {
-          ip_y = iheight + pad_t - 1;
-        }
-        ip_y = ip_y - oStartY + iStartY;
-
-        scalar_t *src_p = goutput_p + k*owidth*oheight + i * owidth + j;
-        scalar_t *dest_p = ginput_p + k*iwidth*iheight + ip_y * iwidth + ip_x;
-        *dest_p += *src_p;
-      }
-    }
-  }
-}
-
-void THNN_(SpatialReplicationPadding_updateGradInput)(THNNState *state,
-                                                      THTensor *input,
-                                                      THTensor *gradOutput,
-                                                      THTensor *gradInput,
-                                                      int pad_l, int pad_r,
-                                                      int pad_t, int pad_b)
-{
-  int dimw = 2;
-  int dimh = 1;
-  int dimslices = 0;
-  int64_t nbatch = 1;
-  int64_t nslices;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t oheight;
-  int64_t owidth;
-
-  if (input->dim() == 4)
-  {
-    nbatch = input->size(0);
-    dimw++;
-    dimh++;
-    dimslices++;
-  }
-
-  /* sizes */
-  nslices = input->size(dimslices);
-  iheight = input->size(dimh);
-  iwidth = input->size(dimw);
-  oheight = iheight + pad_t + pad_b;
-  owidth  = iwidth + pad_l + pad_r;
-
-  THArgCheck(owidth == THTensor_(size)(gradOutput, dimw), 3,
-            "gradOutput width unexpected. Expected: %d, Got: %d",
-            owidth, THTensor_(size)(gradOutput, dimw));
-  THArgCheck(oheight == THTensor_(size)(gradOutput, dimh), 3,
-                "gradOutput height unexpected. Expected: %d, Got: %d",
-            oheight, THTensor_(size)(gradOutput, dimh));
-
-  /* get contiguous gradOutput */
-  gradOutput = THTensor_(newContiguous)(gradOutput);
-
-  /* resize */
-  THTensor_(resizeAs)(gradInput, input);
-  THTensor_(zero)(gradInput);
-
-  /* backprop */
-  if (input->dim() == 3) {
-    THNN_(SpatialReplicationPadding_updateGradInput_frame)(
-      gradInput->data<scalar_t>(),
-      gradOutput->data<scalar_t>(),
-      nslices,
-      iwidth, iheight,
-      owidth, oheight,
-      pad_l, pad_r,
-      pad_t, pad_b);
-  } else {
-    int64_t p;
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++) {
-      THNN_(SpatialReplicationPadding_updateGradInput_frame)(
-        gradInput->data<scalar_t>() + p * nslices * iheight * iwidth,
-        gradOutput->data<scalar_t>() + p * nslices * oheight * owidth,
-        nslices,
-        iwidth, iheight,
-        owidth, oheight,
-        pad_l, pad_r,
-        pad_t, pad_b);
-    }
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(gradOutput);
-}
-
-
-#endif
index a8f4e69..355c819 100644 (file)
@@ -938,21 +938,6 @@ TH_API void THNN_(SpatialReflectionPadding_updateGradInput)(
           int pad_left, int pad_right,
           int pad_top, int pad_bottom);
 
-TH_API void THNN_(SpatialReplicationPadding_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          int pad_left, int pad_right,
-          int pad_top, int pad_bottom);
-
-TH_API void THNN_(SpatialReplicationPadding_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          int pad_left, int pad_right,
-          int pad_top, int pad_bottom);
-
 TH_API void THNN_(FeatureLPPooling_updateOutput)(
           THNNState *state,
           THTensor *input,
@@ -973,23 +958,6 @@ TH_API void THNN_(FeatureLPPooling_updateGradInput)(
           int stride,
           bool batchMode);
 
-TH_API void THNN_(VolumetricReplicationPadding_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          int pad_left, int pad_right,
-          int pad_top, int pad_bottom,
-          int pad_front, int pad_back);
-
-TH_API void THNN_(VolumetricReplicationPadding_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          int pad_left, int pad_right,
-          int pad_top, int pad_bottom,
-          int pad_front, int pad_back);
-
 TH_API void THNN_(VolumetricUpSamplingNearest_updateOutput)(
           THNNState *state,
           THTensor *input,
diff --git a/aten/src/THNN/generic/VolumetricReplicationPadding.c b/aten/src/THNN/generic/VolumetricReplicationPadding.c
deleted file mode 100644 (file)
index 2b948a1..0000000
+++ /dev/null
@@ -1,357 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/VolumetricReplicationPadding.c"
-#else
-
-static inline void THNN_(VolumetricReplicationPadding_shapeCheck)(
-                         THNNState *state,
-                         THTensor *input,
-                         THTensor *gradOutput,
-                         int pleft, int pright,
-                         int ptop, int pbottom,
-                         int pfront, int pback) {
-  int dimw = 3;
-  int dimh = 2;
-  int dimd = 1;
-  int dimslices = 0;
-  int64_t nslices;
-  int64_t idepth;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t odepth;
-  int64_t oheight;
-  int64_t owidth;
-
-  THNN_ARGCHECK(!input->is_empty() && (input->dim() == 4 || input->dim() == 5), 2, input,
-               "non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s");
-
-  if (input->dim() == 5)
-  {
-    dimw++;
-    dimh++;
-    dimd++;
-    dimslices++;
-  }
-
-  /* sizes */
-  nslices = input->size(dimslices);
-  idepth = input->size(dimd);
-  iheight = input->size(dimh);
-  iwidth = input->size(dimw);
-  odepth = idepth + pfront + pback;
-  oheight = iheight + ptop + pbottom;
-  owidth  = iwidth + pleft + pright;
-
-  THArgCheck(owidth >= 1 || oheight >= 1 || odepth >= 1, 2,
-             "input (D: %d H: %d, W: %d)is too small."
-             " Calculated output D: %d H: %d W: %d",
-             idepth, iheight, iwidth, odepth, oheight, owidth);
-
-  if (gradOutput != NULL) {
-    THArgCheck(nslices == THTensor_(size)(gradOutput, dimslices), 3,
-               "gradOutput width unexpected. Expected: %d, Got: %d",
-               nslices, THTensor_(size)(gradOutput, dimslices));
-    THArgCheck(owidth == THTensor_(size)(gradOutput, dimw), 3,
-               "gradOutput width unexpected. Expected: %d, Got: %d",
-               owidth, THTensor_(size)(gradOutput, dimw));
-    THArgCheck(oheight == THTensor_(size)(gradOutput, dimh), 3,
-               "gradOutput height unexpected. Expected: %d, Got: %d",
-               oheight, THTensor_(size)(gradOutput, dimh));
-    THArgCheck(odepth == THTensor_(size)(gradOutput, dimd), 3,
-               "gradOutput depth unexpected. Expected: %d, Got: %d",
-               odepth, THTensor_(size)(gradOutput, dimd));
-  }
-}
-
-static void THNN_(VolumetricReplicationPadding_updateOutput_frame)(
-  scalar_t *input_p, scalar_t *output_p,
-  int64_t nslices,
-  int64_t iwidth, int64_t iheight, int64_t idepth,
-  int64_t owidth, int64_t oheight, int64_t odepth,
-  int pleft, int pright,
-  int ptop, int pbottom,
-  int pfront, int pback)
-{
-  int iStartX = fmax(0, -pleft);
-  int iStartY = fmax(0, -ptop);
-  int iStartZ = fmax(0, -pfront);
-  int oStartX = fmax(0, pleft);
-  int oStartY = fmax(0, ptop);
-  int oStartZ = fmax(0, pfront);
-
-  int64_t k, ip_x, ip_y, ip_z;
-#pragma omp parallel for private(k, ip_x, ip_y, ip_z)
-  for (k = 0; k < nslices; k++) {
-    int64_t i, j, z;
-    for (z = 0; z < odepth; z++) {
-      for (i = 0; i < oheight; i++) {
-        for (j = 0; j < owidth; j++) {
-          if (j < pleft) {
-            ip_x = pleft;
-          } else if (j >= pleft && j < iwidth + pleft) {
-            ip_x = j;
-          } else {
-            ip_x = iwidth + pleft - 1;
-          }
-          ip_x = ip_x - oStartX + iStartX;
-
-          if (i < ptop) {
-            ip_y = ptop;
-          } else if (i >= ptop && i < iheight + ptop) {
-            ip_y = i;
-          } else {
-            ip_y = iheight + ptop - 1;
-          }
-          ip_y = ip_y - oStartY + iStartY;
-
-          if (z < pfront) {
-            ip_z = pfront;
-          } else if (z >= pfront && z < idepth + pfront) {
-            ip_z = z;
-          } else {
-            ip_z = idepth + pfront - 1;
-          }
-          ip_z = ip_z - oStartZ + iStartZ;
-
-          scalar_t *dest_p = output_p + k * owidth * oheight * odepth +
-              z * owidth * oheight + i * owidth + j;
-          scalar_t *src_p = input_p + k * iwidth * iheight * idepth +
-              ip_z * iwidth * iheight + ip_y * iwidth + ip_x;
-          *dest_p = *src_p;
-        }
-      }
-    }
-  }
-}
-
-void THNN_(VolumetricReplicationPadding_updateOutput)(THNNState *state,
-                                                      THTensor *input,
-                                                      THTensor *output,
-                                                      int pleft, int pright,
-                                                      int ptop, int pbottom,
-                                                      int pfront, int pback)
-{
-  int dimw = 3;
-  int dimh = 2;
-  int dimd = 1;
-  int dimslices = 0;
-  int64_t nbatch = 1;
-  int64_t nslices;
-  int64_t idepth;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t odepth;
-  int64_t oheight;
-  int64_t owidth;
-  scalar_t *input_data;
-  scalar_t *output_data;
-
-THNN_(VolumetricReplicationPadding_shapeCheck)(
-      state, input, NULL, pleft, pright,
-      ptop, pbottom, pfront, pback);
-
-  if (input->dim() == 5)
-  {
-    nbatch = input->size(0);
-    dimw++;
-    dimh++;
-    dimd++;
-    dimslices++;
-  }
-
-  /* sizes */
-  nslices = input->size(dimslices);
-  idepth = input->size(dimd);
-  iheight = input->size(dimh);
-  iwidth = input->size(dimw);
-  odepth = idepth + pfront + pback;
-  oheight = iheight + ptop + pbottom;
-  owidth  = iwidth + pleft + pright;
-
-  /* get contiguous input */
-  input = THTensor_(newContiguous)(input);
-
-  /* resize output */
-  if (input->dim() == 4)
-  {
-    THTensor_(resize4d)(output, nslices, odepth, oheight, owidth);
-
-    input_data = input->data<scalar_t>();
-    output_data = output->data<scalar_t>();
-
-    THNN_(VolumetricReplicationPadding_updateOutput_frame)(
-         input_data, output_data, nslices, iwidth, iheight, idepth,
-         owidth, oheight, odepth, pleft, pright, ptop, pbottom, pfront,
-         pback);
-  }
-  else
-  {
-    int64_t p;
-
-    THTensor_(resize5d)(output, nbatch, nslices, odepth, oheight, owidth);
-
-    input_data = input->data<scalar_t>();
-    output_data = output->data<scalar_t>();
-
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++)
-    {
-      THNN_(VolumetricReplicationPadding_updateOutput_frame)(
-        input_data + p * nslices * iwidth * iheight * idepth,
-        output_data + p * nslices * owidth * oheight * odepth,
-        nslices,
-        iwidth, iheight, idepth,
-        owidth, oheight, odepth,
-        pleft, pright,
-        ptop, pbottom,
-        pfront, pback);
-    }
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(input);
-}
-
-static void THNN_(VolumetricReplicationPadding_updateGradInput_frame)(
-  scalar_t *ginput_p, scalar_t *goutput_p,
-  int64_t nslices,
-  int64_t iwidth, int64_t iheight, int64_t idepth,
-  int64_t owidth, int64_t oheight, int64_t odepth,
-  int pleft, int pright,
-  int ptop, int pbottom,
-  int pfront, int pback)
-{
-  int iStartX = fmax(0, -pleft);
-  int iStartY = fmax(0, -ptop);
-  int iStartZ = fmax(0, -pfront);
-  int oStartX = fmax(0, pleft);
-  int oStartY = fmax(0, ptop);
-  int oStartZ = fmax(0, pfront);
-
-  int64_t k, ip_x, ip_y, ip_z;
-#pragma omp parallel for private(k, ip_x, ip_y, ip_z)
-  for (k = 0; k < nslices; k++) {
-    int64_t i, j, z;
-    for (z = 0; z < odepth; z++) {
-      for (i = 0; i < oheight; i++) {
-        for (j = 0; j < owidth; j++) {
-          if (j < pleft) {
-            ip_x = pleft;
-          } else if (j >= pleft && j < iwidth + pleft) {
-            ip_x = j;
-          } else {
-            ip_x = iwidth + pleft - 1;
-          }
-          ip_x = ip_x - oStartX + iStartX;
-
-          if (i < ptop) {
-            ip_y = ptop;
-          } else if (i >= ptop && i < iheight + ptop) {
-            ip_y = i;
-          } else {
-            ip_y = iheight + ptop - 1;
-          }
-          ip_y = ip_y - oStartY + iStartY;
-
-          if (z < pfront) {
-            ip_z = pfront;
-          } else if (z >= pfront && z < idepth + pfront) {
-            ip_z = z;
-          } else {
-            ip_z = idepth + pfront - 1;
-          }
-          ip_z = ip_z - oStartZ + iStartZ;
-
-          scalar_t *src_p = goutput_p + k * owidth * oheight * odepth +
-              z * owidth * oheight + i * owidth + j;
-          scalar_t *dest_p = ginput_p + k * iwidth * iheight * idepth +
-              ip_z * iwidth * iheight + ip_y * iwidth + ip_x;
-          *dest_p += *src_p;
-        }
-      }
-    }
-  }
-}
-
-void THNN_(VolumetricReplicationPadding_updateGradInput)(THNNState *state,
-                                                         THTensor *input,
-                                                         THTensor *gradOutput,
-                                                         THTensor *gradInput,
-                                                         int pleft, int pright,
-                                                         int ptop, int pbottom,
-                                                         int pfront, int pback)
-{
-  int dimw = 3;
-  int dimh = 2;
-  int dimd = 1;
-  int dimslices = 0;
-  int64_t nbatch = 1;
-  int64_t nslices;
-  int64_t idepth;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t odepth;
-  int64_t oheight;
-  int64_t owidth;
-
-  if (input->dim() == 5)
-  {
-    nbatch = input->size(0);
-    dimw++;
-    dimh++;
-    dimd++;
-    dimslices++;
-  }
-
-  /* sizes */
-  nslices = input->size(dimslices);
-  idepth = input->size(dimd);
-  iheight = input->size(dimh);
-  iwidth = input->size(dimw);
-  odepth = idepth + pfront + pback;
-  oheight = iheight + ptop + pbottom;
-  owidth  = iwidth + pleft + pright;
-
-
-THNN_(VolumetricReplicationPadding_shapeCheck)(
-      state, input, NULL, pleft, pright,
-      ptop, pbottom, pfront, pback);
-
-  /* get contiguous gradOutput */
-  gradOutput = THTensor_(newContiguous)(gradOutput);
-
-  /* resize */
-  THTensor_(resizeAs)(gradInput, input);
-  THTensor_(zero)(gradInput);
-
-  /* backprop */
-  if (input->dim() == 4) {
-    THNN_(VolumetricReplicationPadding_updateGradInput_frame)(
-      gradInput->data<scalar_t>(),
-      gradOutput->data<scalar_t>(),
-      nslices,
-      iwidth, iheight, idepth,
-      owidth, oheight, odepth,
-      pleft, pright,
-      ptop, pbottom,
-      pfront, pback);
-  } else {
-    int64_t p;
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++) {
-      THNN_(VolumetricReplicationPadding_updateGradInput_frame)(
-        gradInput->data<scalar_t>() + p * nslices * idepth * iheight * iwidth,
-        gradOutput->data<scalar_t>() + p * nslices * odepth * oheight * owidth,
-        nslices,
-        iwidth, iheight, idepth,
-        owidth, oheight, odepth,
-        pleft, pright,
-        ptop, pbottom,
-        pfront, pback);
-    }
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(gradOutput);
-}
-
-#endif
index 206814b..845374e 100644 (file)
 #include <THNN/generic/SpatialReflectionPadding.c>
 #include <TH/THGenerateFloatTypes.h>
 
-#include <THNN/generic/SpatialReplicationPadding.c>
-#include <TH/THGenerateFloatTypes.h>
-
-#include <THNN/generic/VolumetricReplicationPadding.c>
-#include <TH/THGenerateFloatTypes.h>
-
 #include <THNN/generic/VolumetricUpSamplingNearest.c>
 #include <TH/THGenerateFloatTypes.h>
 
index ef9db68..f18f60e 100644 (file)
@@ -307,8 +307,6 @@ def _generate_function_classes(scope_dict):
         'SpatialDilatedConvolution': 'DilatedConv2d',
         'SpatialMaxUnpooling': 'MaxUnpool2d',
         'SpatialReflectionPadding': 'ReflectionPad2d',
-        'SpatialReplicationPadding': 'ReplicationPad2d',
-        'VolumetricReplicationPadding': 'ReplicationPad3d',
         'VolumetricMaxUnpooling': 'MaxUnpool3d',
         'HardTanh': 'Hardtanh',
         'HardShrink': 'Hardshrink',