Porting legacy reflection_pad2d to ATen
authorShen Li <shenli@fb.com>
Thu, 10 Jan 2019 04:53:03 +0000 (20:53 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 10 Jan 2019 04:55:27 +0000 (20:55 -0800)
Summary:
Other changes:
1. Avoided using `THCDeviceTensor` by re-calculating the mapping from cuda (blockIdx, threadIdx) to input/output tensor index.
2. Changed Camelcase naming to underscore naming.

Differential Revision: D13546803

fbshipit-source-id: 1df54f13e64934da3d803d9b6586bd5208d42d6d

13 files changed:
aten/src/ATen/native/LegacyNNDefinitions.cpp
aten/src/ATen/native/ReflectionPad.cpp
aten/src/ATen/native/cuda/ReflectionPad.cu
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/nn.yaml
aten/src/THCUNN/CMakeLists.txt
aten/src/THCUNN/SpatialReflectionPadding.cu [deleted file]
aten/src/THCUNN/generic/SpatialReflectionPadding.cu [deleted file]
aten/src/THCUNN/generic/THCUNN.h
aten/src/THNN/generic/SpatialReflectionPadding.c [deleted file]
aten/src/THNN/generic/THNN.h
aten/src/THNN/init.cpp
torch/nn/_functions/thnn/auto.py

index 5f38bf1..57e8edb 100644 (file)
@@ -492,22 +492,6 @@ Tensor max_unpool3d_backward(const Tensor & grad_output, const Tensor & self, co
   return at::legacy::th::_thnn_max_unpool3d_backward(grad_output, self, indices, output_size, stride, padding);
 }
 
-Tensor & reflection_pad2d_out(Tensor & output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_reflection_pad2d_forward_out(output, self, padding);
-}
-
-Tensor reflection_pad2d(const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_reflection_pad2d_forward(self, padding);
-}
-
-Tensor & reflection_pad2d_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_reflection_pad2d_backward_out(grad_input, grad_output, self, padding);
-}
-
-Tensor reflection_pad2d_backward(const Tensor & grad_output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_reflection_pad2d_backward(grad_output, self, padding);
-}
-
 Tensor & upsample_linear1d_out(Tensor & output, const Tensor & self, IntList output_size, bool align_corners) {
   return at::legacy::th::_thnn_upsample_linear1d_forward_out(output, self, output_size, align_corners);
 }
index 37a7b14..8cf3e66 100644 (file)
@@ -208,6 +208,267 @@ void reflection_pad1d_backward_out_template(
     );
   }
 }
+
+template <typename scalar_t>
+static void reflection_pad2d_out_frame(
+    scalar_t * input_p, scalar_t * output_p,
+    int64_t nplane,
+    int64_t input_w, int64_t input_h,
+    int64_t output_w, int64_t output_h,
+    int64_t pad_l, int64_t pad_t) {
+  auto i_start_x = std::max(int64_t(0), -pad_l);
+  auto i_start_y = std::max(int64_t(0), -pad_t);
+  auto o_start_x = std::max(int64_t(0), pad_l);
+  auto o_start_y = std::max(int64_t(0), pad_t);
+
+  int64_t k, ip_x, ip_y;
+#pragma omp parallel for private(k, ip_x, ip_y)
+
+  for (k = 0; k < nplane; k++) {
+    for (int64_t i = 0; i < output_h; i++) {
+      for (int64_t j = 0; j < output_w; j++) {
+        if (j < pad_l) {
+          ip_x = pad_l * 2 - j;
+        } else if (j >= pad_l && j < input_w + pad_l) {
+          ip_x = j;
+        } else {
+          ip_x = (input_w + pad_l - 1) * 2 - j;
+        }
+        ip_x = ip_x - o_start_x + i_start_x;
+
+        if (i < pad_t) {
+          ip_y = pad_t * 2 - i;
+        } else if (i >= pad_t && i < input_h + pad_t) {
+          ip_y = i;
+        } else {
+          ip_y = (input_h + pad_t - 1) * 2 - i;
+        }
+        ip_y = ip_y - o_start_y + i_start_y;
+
+        scalar_t *dest_p = output_p + k*output_w*output_h + i * output_w + j;
+        scalar_t *src_p = input_p + k*input_w*input_h + ip_y * input_w + ip_x;
+        *dest_p = *src_p;
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+inline void reflection_pad2d_out_loop(
+    scalar_t * input_p, scalar_t * output_p,
+    int64_t nbatch, int64_t nplane,
+    int64_t input_w, int64_t input_h,
+    int64_t output_w, int64_t output_h,
+    int64_t pad_l, int64_t pad_t) {
+  int64_t p;
+#pragma omp parallel for private(p)
+  for (p = 0; p < nbatch; p++) {
+    reflection_pad2d_out_frame(
+      input_p + p * nplane * input_w * input_h,
+      output_p + p * nplane * output_w * output_h,
+      nplane,
+      input_w, input_h, output_w, output_h,
+      pad_l, pad_t);
+  }
+}
+
+void reflection_pad2d_out_template(
+    Tensor &output, const Tensor &input_, IntList padding) {
+  int dim_w = 2;
+  int dim_h = 1;
+  int dim_slices = 0;
+  int64_t nbatch = 1;
+
+  AT_CHECK(input_.numel() > 0 &&
+    (input_.ndimension() == 3 || input_.ndimension() == 4), "non-empty 3D or "
+    "4D (batch mode) tensor expected for input, but got: ", input_);
+
+  if (input_.ndimension() == 4) {
+    nbatch = input_.size(0);
+    dim_w++;
+    dim_h++;
+    dim_slices++;
+  }
+
+  /* sizes */
+  int64_t pad_l = padding[0];
+  int64_t pad_r = padding[1];
+  int64_t pad_t = padding[2];
+  int64_t pad_b = padding[3];
+
+  int64_t nplane = input_.size(dim_slices);
+  int64_t input_h = input_.size(dim_h);
+  int64_t input_w = input_.size(dim_w);
+  int64_t output_h = input_h + pad_t + pad_b;
+  int64_t output_w  = input_w + pad_l + pad_r;
+
+  AT_CHECK(pad_l < input_w && pad_r < input_w,
+    "Argument #4: Padding size should be less than the corresponding "
+    "input dimension, but got: padding (", pad_l, ", ", pad_r,
+    ") at dimension ", dim_w, " of input ", input_.ndimension());
+
+  AT_CHECK(pad_t < input_h && pad_b < input_h,
+    "Argument #6: Padding size should be less than the corresponding "
+    "input dimension, but got: padding (", pad_t, ", ", pad_b,
+    ") at dimension ", dim_h, " of input ", input_.ndimension());
+
+  AT_CHECK(output_w >= 1 || output_h >= 1,
+    "input (H: ", input_h, ", W: ", input_w, ")is too small. Calculated "
+    "output H: ", output_h, " W: ", output_w);
+
+  /* get contiguous input */
+  Tensor input = input_.contiguous();
+
+  if (input.ndimension() == 3) {
+    /* resize output */
+    output.resize_({nplane, output_h, output_w});
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad2d", [&] {
+      reflection_pad2d_out_frame(
+        input.data<scalar_t>(), output.data<scalar_t>(),
+        nplane,
+        input_w, input_h, output_w, output_h,
+        pad_l, pad_t);
+    });
+  } else {
+    /* resize output */
+    output.resize_({nbatch, nplane, output_h, output_w});
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad2d", [&] {
+      reflection_pad2d_out_loop(
+        input.data<scalar_t>(), output.data<scalar_t>(),
+        nbatch, nplane,
+        input_w, input_h, output_w, output_h,
+        pad_l, pad_t);
+    });
+  }
+}
+
+template <typename scalar_t>
+static void reflection_pad2d_backward_out_frame(
+    scalar_t *grad_input, scalar_t *grad_output,
+    int64_t nplane,
+    int64_t input_w, int64_t input_h,
+    int64_t output_w, int64_t output_h,
+    int64_t pad_l, int64_t pad_t) {
+  auto i_start_x = std::max(int64_t(0), -pad_l);
+  auto i_start_y = std::max(int64_t(0), -pad_t);
+  auto o_start_x = std::max(int64_t(0), pad_l);
+  auto o_start_y = std::max(int64_t(0), pad_t);
+
+  int64_t k, ip_x, ip_y;
+#pragma omp parallel for private(k, ip_x, ip_y)
+
+  for (k = 0; k < nplane; k++) {
+    for (int64_t i = 0; i < output_h; i++) {
+      for (int64_t j = 0; j < output_w; j++) {
+        if (j < pad_l) {
+          ip_x = pad_l * 2 - j;
+        } else if (j >= pad_l && j < input_w + pad_l) {
+          ip_x = j;
+        } else {
+          ip_x = (input_w + pad_l - 1) * 2 - j;
+        }
+        ip_x = ip_x - o_start_x + i_start_x;
+
+        if (i < pad_t) {
+          ip_y = pad_t * 2 - i;
+        } else if (i >= pad_t && i < input_h + pad_t) {
+          ip_y = i;
+        } else {
+          ip_y = (input_h + pad_t - 1) * 2 - i;
+        }
+        ip_y = ip_y - o_start_y + i_start_y;
+
+        scalar_t *src_p =
+          grad_output + k * output_w * output_h + i * output_w + j;
+        scalar_t *dest_p =
+          grad_input + k * input_w * input_h + ip_y * input_w + ip_x;
+        *dest_p += *src_p;
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+inline void reflection_pad2d_backward_out_loop(
+    scalar_t *grad_input, scalar_t *grad_output,
+    int64_t nbatch, int64_t nplane,
+    int64_t input_w, int64_t input_h,
+    int64_t output_w, int64_t output_h,
+    int64_t pad_l, int64_t pad_t) {
+  int64_t p;
+#pragma omp parallel for private(p)
+  for (p = 0; p < nbatch; p++) {
+    reflection_pad2d_backward_out_frame(
+      grad_input + p * nplane * input_h * input_w,
+      grad_output + p * nplane * output_h * output_w,
+      nplane,
+      input_w, input_h, output_w, output_h,
+      pad_l, pad_t);
+  }
+}
+
+void reflection_pad2d_backward_out_template(
+    Tensor &grad_input, const Tensor &grad_output_,
+    const Tensor &input, IntList padding) {
+  int dim_w = 2;
+  int dim_h = 1;
+  int dim_plane = 0;
+  int64_t nbatch = 1;
+
+  if (input.ndimension() == 4) {
+    nbatch = input.size(0);
+    dim_w++;
+    dim_h++;
+    dim_plane++;
+  }
+
+  /* sizes */
+  int64_t pad_l = padding[0];
+  int64_t pad_r = padding[1];
+  int64_t pad_t = padding[2];
+  int64_t pad_b = padding[3];
+
+  int64_t nplane = input.size(dim_plane);
+  int64_t input_h = input.size(dim_h);
+  int64_t input_w = input.size(dim_w);
+  int64_t output_h = input_h + pad_t + pad_b;
+  int64_t output_w  = input_w + pad_l + pad_r;
+
+  AT_CHECK(output_w == grad_output_.size(dim_w),
+    "gradOutput width unexpected. Expected: ", output_w, ", Got: ",
+    grad_output_.size(dim_w));
+
+  AT_CHECK(output_h == grad_output_.size(dim_h),
+    "gradOutput height unexpected. Expected: ", output_h, ", Got: ",
+    grad_output_.size(dim_h));
+
+  /* get contiguous gradOutput */
+  Tensor grad_output = grad_output_.contiguous();
+
+  /* backprop */
+  if (input.ndimension() == 3) {
+    AT_DISPATCH_FLOATING_TYPES(
+      grad_output.type(), "reflection_pad2d_backward", [&] {
+        reflection_pad2d_backward_out_frame(
+          grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
+          nplane,
+          input_w, input_h, output_w, output_h,
+          pad_l, pad_t);
+      }
+    );
+  } else {
+    AT_DISPATCH_FLOATING_TYPES(
+      grad_output.type(), "reflection_pad2d_backward", [&] {
+        reflection_pad2d_backward_out_loop(
+          grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
+          nbatch, nplane,
+          input_w, input_h, output_w, output_h,
+          pad_l, pad_t);
+      }
+    );
+  }
+}
+
 } // namespace
 
 Tensor& reflection_pad1d_out_cpu(
@@ -244,5 +505,39 @@ Tensor reflection_pad1d_backward_cpu(
   return grad_input;
 }
 
+Tensor& reflection_pad2d_out_cpu(
+    Tensor& output, const Tensor& input, IntList padding) {
+  reflection_pad2d_out_template(output, input, padding);
+  return output;
+}
+
+Tensor reflection_pad2d_cpu(const Tensor& input, IntList padding) {
+  auto output = at::empty({0}, input.options());
+  reflection_pad2d_out_template(output, input, padding);
+  return output;
+}
+
+Tensor& reflection_pad2d_backward_out_cpu(
+    Tensor& grad_input,
+    const Tensor& grad_output,
+    const Tensor& input,
+    IntList padding) {
+  grad_input.resize_as_(input);
+  grad_input.zero_();
+  reflection_pad2d_backward_out_template(
+    grad_input, grad_output, input, padding);
+  return grad_input;
+}
+
+Tensor reflection_pad2d_backward_cpu(
+    const Tensor& grad_output,
+    const Tensor& input,
+    IntList padding) {
+  auto grad_input = at::zeros_like(input);
+  reflection_pad2d_backward_out_template(
+    grad_input, grad_output, input, padding);
+  return grad_input;
+}
+
 } // namespace native
 } // namespace at
index b142503..6f1d5c7 100644 (file)
@@ -16,7 +16,7 @@ namespace {
 using at::cuda::detail::canUse32BitIndexMath;
 
 __device__
-inline thrust::pair<int64_t, int64_t> get_index_mapping(
+inline thrust::pair<int64_t, int64_t> get_index_mapping1d(
     int64_t input_w, int64_t output_w,
     int64_t output_x,
     int64_t pad_l) {
@@ -39,6 +39,44 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping(
     input_offset + input_x, output_offset + output_x);
 }
 
+
+__device__
+inline thrust::pair<int64_t, int64_t>  get_index_mapping2d(
+    int64_t input_dim_x, int64_t input_dim_y,
+    int64_t output_dim_x, int64_t output_dim_y,
+    int64_t pad_l, int64_t pad_t,
+    int64_t output_xy) {
+  // 3D grid of 1D blocks
+  auto input_offset =
+    (blockIdx.y + blockIdx.z * gridDim.y) * input_dim_x * input_dim_y;
+  auto output_offset =
+    (blockIdx.y + blockIdx.z * gridDim.y) * output_dim_x * output_dim_y;
+
+  auto output_x = output_xy % output_dim_x;
+  auto output_y = output_xy / output_dim_x;
+
+  auto i_start_x = ::max(int64_t(0), -pad_l);
+  auto i_start_y = ::max(int64_t(0), -pad_t);
+  auto o_start_x = ::max(int64_t(0), pad_l);
+  auto o_start_y = ::max(int64_t(0), pad_t);
+
+  auto input_x = ::abs(output_x - pad_l)
+                 - ::abs(output_x - (input_dim_x + pad_l - 1))
+                 - output_x
+                 + 2 * pad_l + input_dim_x - 1
+                 - o_start_x + i_start_x;
+
+  auto input_y = ::abs(output_y - pad_t)
+                 - ::abs(output_y - (input_dim_y + pad_t - 1))
+                 - output_y
+                 + 2 * pad_t + input_dim_y - 1
+                 - o_start_y + i_start_y;
+
+  return thrust::make_pair<int64_t, int64_t>(
+    input_offset + input_y * input_dim_x + input_x,
+    output_offset + output_y * output_dim_x + output_x);
+}
+
 template<typename scalar_t>
 __global__ void reflection_pad1d_out_kernel(
     scalar_t * input, scalar_t * output,
@@ -48,7 +86,7 @@ __global__ void reflection_pad1d_out_kernel(
   auto output_w = input_w + pad_l + pad_r;
 
   if (output_x < output_w) {
-    auto index_pair = get_index_mapping(input_w, output_w, output_x, pad_l);
+    auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l);
     output[index_pair.second] = input[index_pair.first];
   }
 }
@@ -62,12 +100,52 @@ __global__ void reflection_pad1d_backward_out_kernel(
   auto output_w = input_w + pad_l + pad_r;
 
   if (output_x < output_w) {
-    auto index_pair = get_index_mapping(input_w, output_w, output_x, pad_l);
+    auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l);
     atomicAdd(
       &grad_input[index_pair.first], grad_output[index_pair.second]);
   }
 }
 
+template<typename scalar_t>
+__global__ void reflection_pad2d_out_kernel(
+    scalar_t * input, scalar_t * output,
+    int64_t input_dim_x, int64_t input_dim_y,
+    int pad_t, int pad_b, int pad_l, int pad_r) {
+  auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
+  auto output_dim_x = input_dim_x + pad_l + pad_r;
+  auto output_dim_y = input_dim_y + pad_t + pad_b;
+
+  if (output_xy < output_dim_x * output_dim_y) {
+    auto index_pair = get_index_mapping2d(
+      input_dim_x, input_dim_y,
+      output_dim_x, output_dim_y,
+      pad_l, pad_t,
+      output_xy);
+
+    output[index_pair.second] = input[index_pair.first];
+  }
+}
+
+template <typename scalar_t>
+__global__ void reflection_pad2d_backward_out_kernel(
+    scalar_t * grad_input, scalar_t * grad_output,
+    int64_t input_dim_x, int64_t input_dim_y,
+    int pad_t, int pad_b, int pad_l, int pad_r) {
+  auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
+  auto output_dim_x = input_dim_x + pad_l + pad_r;
+  auto output_dim_y = input_dim_y + pad_t + pad_b;
+
+  if (output_xy < output_dim_x * output_dim_y) {
+    auto index_pair = get_index_mapping2d(
+      input_dim_x, input_dim_y,
+      output_dim_x, output_dim_y,
+      pad_l, pad_t,
+      output_xy);
+
+    atomicAdd(&grad_input[index_pair.first], grad_output[index_pair.second]);
+  }
+}
+
 void reflection_pad1d_out_template(
     Tensor &output, const Tensor &input_, IntList padding) {
   AT_CHECK(canUse32BitIndexMath(input_),
@@ -172,8 +250,139 @@ void reflection_pad1d_backward_out_template(
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
+void reflection_pad2d_out_template(
+    Tensor &output, const Tensor &input_, IntList padding) {
+  AT_CHECK(canUse32BitIndexMath(input_),
+    "input tensor must fit into 32-bit index math");
+
+  int plane_dim = 0;
+  int dim_h = 1;
+  int dim_w = 2;
+  int nbatch = 1;
+
+  AT_CHECK(input_.numel() > 0 &&
+    (input_.ndimension() == 3 || input_.ndimension() == 4), "non-empty 3D or "
+    "4D (batch mode) tensor expected for input, but got: ", input_);
+
+  if (input_.ndimension() == 4) {
+    nbatch = input_.size(0);
+    plane_dim++;
+    dim_h++;
+    dim_w++;
+  }
+
+  int64_t pad_l = padding[0];
+  int64_t pad_r = padding[1];
+  int64_t pad_t = padding[2];
+  int64_t pad_b = padding[3];
+
+  int nplane = input_.size(plane_dim);
+  int input_h = input_.size(dim_h);
+  int input_w = input_.size(dim_w);
+
+  AT_CHECK(pad_l < input_w && pad_r < input_w,
+    "Padding size should be less than the corresponding input dimension, but "
+    "got: padding (", pad_l, ", ", pad_r, ") at dimension ", dim_w,
+    " of input ", input_.sizes());
+
+  AT_CHECK(pad_t < input_h && pad_b < input_h,
+    "Padding size should be less than the corresponding input dimension, but "
+    "got: padding (", pad_t, ", ", pad_b, ") at dimension ", dim_h,
+    " of input ", input_.sizes());
+
+  int output_h = input_h + pad_t + pad_b;
+  int output_w  = input_w + pad_l + pad_r;
+
+  AT_CHECK(output_w >= 1 || output_h >= 1,
+    "input (H: ", input_h, ", W: ", input_w, ")is too small.  Calculated "
+    "output H: ", output_h, " W: ", output_w);
+
+  if (input_.ndimension() == 3) {
+    output.resize_({nplane, output_h, output_w});
+  } else {
+    output.resize_({nbatch, nplane, output_h, output_w});
+  }
+
+  Tensor input = input_.contiguous();
+
+  int output_plane_size = output_h * output_w;
+  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);
+  dim3 grid_size(
+    (int) std::ceil(output_plane_size/256.0), nplane, nbatch);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    input.type(), "reflection_pad2d_out_template", [&] {
+      reflection_pad2d_out_kernel<<<
+        grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
+          input.data<scalar_t>(), output.data<scalar_t>(),
+          input_w, input_h,
+          pad_t, pad_b, pad_l, pad_r);
+    }
+  );
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void reflection_pad2d_backward_out_template(
+    Tensor &grad_input, const Tensor &grad_output_,
+    const Tensor &input, IntList padding) {
+  AT_CHECK(canUse32BitIndexMath(input),
+    "input tensor must fit into 32-bit index math");
+  AT_CHECK(canUse32BitIndexMath(grad_output_),
+    "output gradient tensor must fit into 32-bit index math");
+
+  int plane_dim = 0;
+  int dim_h = 1;
+  int dim_w = 2;
+  int nbatch = 1;
+
+  if (input.ndimension() == 4) {
+    nbatch = input.size(0);
+    plane_dim++;
+    dim_h++;
+    dim_w++;
+  }
+
+  int64_t pad_l = padding[0];
+  int64_t pad_r = padding[1];
+  int64_t pad_t = padding[2];
+  int64_t pad_b = padding[3];
+
+  int nplane = input.size(plane_dim);
+  int input_h = input.size(dim_h);
+  int input_w = input.size(dim_w);
+
+  int output_h = input_h + pad_t + pad_b;
+  int output_w  = input_w + pad_l + pad_r;
+
+  AT_CHECK(output_w == grad_output_.size(dim_w), "grad_output width "
+    "unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w));
+  AT_CHECK(output_h == grad_output_.size(dim_h), "grad_output height "
+    "unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h));
+
+  Tensor grad_output = grad_output_.contiguous();
+
+  int output_plane_size = output_h * output_w;
+  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);
+  dim3 grid_size(
+    (int) std::ceil(output_plane_size/256.0), nplane, nbatch);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    input.type(), "reflection_pad2d_backward_out_template", [&] {
+      reflection_pad2d_backward_out_kernel<<<
+        grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
+          grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
+          input_w, input_h,
+          pad_t, pad_b, pad_l, pad_r);
+    }
+  );
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
 } // namespace
 
+
 Tensor& reflection_pad1d_out_cuda(
     Tensor& output, const Tensor& input, IntList padding) {
   reflection_pad1d_out_template(output, input, padding);
@@ -207,5 +416,38 @@ Tensor reflection_pad1d_backward_cuda(
   return grad_input;
 }
 
+Tensor& reflection_pad2d_out_cuda(
+    Tensor& output, const Tensor& input, IntList padding) {
+  reflection_pad2d_out_template(output, input, padding);
+  return output;
+}
+
+Tensor reflection_pad2d_cuda(const Tensor& input, IntList padding) {
+  auto output = at::empty({0}, input.options());
+  reflection_pad2d_out_template(output, input, padding);
+  return output;
+}
+
+Tensor& reflection_pad2d_backward_out_cuda(
+    Tensor& grad_input, const Tensor& grad_output,
+    const Tensor& input,
+    IntList padding) {
+  grad_input.resize_as_(input);
+  grad_input.zero_();
+  reflection_pad2d_backward_out_template(
+    grad_input, grad_output, input, padding);
+  return grad_input;
+}
+
+Tensor reflection_pad2d_backward_cuda(
+    const Tensor& grad_output,
+    const Tensor& input,
+    IntList padding) {
+  auto grad_input = at::zeros_like(input);
+  reflection_pad2d_backward_out_template(
+    grad_input, grad_output, input, padding);
+  return grad_input;
+}
+
 } // namespace native
 } // namespace at
index d7a5b8b..6aaf046 100644 (file)
 
 - func: reflection_pad2d_out(Tensor output, Tensor self, IntList[4] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: reflection_pad2d_out_cpu
+    CUDA: reflection_pad2d_out_cuda
 
 - func: reflection_pad2d(Tensor self, IntList[4] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: reflection_pad2d_cpu
+    CUDA: reflection_pad2d_cuda
 
 - func: reflection_pad2d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[4] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: reflection_pad2d_backward_out_cpu
+    CUDA: reflection_pad2d_backward_out_cuda
 
 - func: reflection_pad2d_backward(Tensor grad_output, Tensor self, IntList[4] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: reflection_pad2d_backward_cpu
+    CUDA: reflection_pad2d_backward_cuda
 
 - func: replication_pad1d_out(Tensor output, Tensor self, IntList[2] padding) -> Tensor
   python_module: nn
index a24a032..3f7ee96 100644 (file)
     output: 'false'
     grad_input: 'false'
 
-# Padding
-
-- name: _thnn_reflection_pad2d(Tensor self, IntList[4] padding)
-  cname: SpatialReflectionPadding
-  scalar_check:
-    output: 'false'
-    grad_input: 'false'
-
 # Upsampling
 
 # Note: The upsampling backwards functions also include an IntList input_size
index e0b9ce8..d7f35a7 100644 (file)
@@ -41,7 +41,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullDilatedConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxPooling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxUnpooling.cu
-${CMAKE_CURRENT_SOURCE_DIR}/SpatialReflectionPadding.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialSubSampling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialUpSamplingBicubic.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialUpSamplingBilinear.cu
diff --git a/aten/src/THCUNN/SpatialReflectionPadding.cu b/aten/src/THCUNN/SpatialReflectionPadding.cu
deleted file mode 100644 (file)
index 45d9dba..0000000
+++ /dev/null
@@ -1,87 +0,0 @@
-#include <THCUNN/THCUNN.h>
-#include <THC/THCTensor.hpp>
-#include <THCUNN/common.h>
-#include <THC/THCDeviceTensor.cuh>
-#include <THC/THCDeviceTensorUtils.cuh>
-#include <THC/THCDeviceUtils.cuh>
-#include <THC/THCReduceApplyUtils.cuh>
-#include <THC/THCApply.cuh>
-
-#include <TH/THHalf.h>
-#include <THCUNN/THCHalfAutoNumerics.cuh>
-#include <THC/THCAtomics.cuh>
-
-template<typename Dtype>
-__global__ void SpatialReflectionPadding_updateOutput(
-  THCDeviceTensor<Dtype, 4> input,
-  THCDeviceTensor<Dtype, 4> output,
-  int padT, int padB, int padL, int padR) {
-
-  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-  if (outputPointId >= output.getSize(2) * output.getSize(3)) {
-    return;
-  }
-  int outputPointX = outputPointId % output.getSize(3);
-  int outputPointY = outputPointId / output.getSize(3);
-
-  int iStartX = max(0, -padL);
-  int iStartY = max(0, -padT);
-  int oStartX = max(0, padL);
-  int oStartY = max(0, padT);
-
-  int inputPointX = abs(outputPointX - padL)
-                  - abs(outputPointX - (input.getSize(3) + padL - 1))
-                  - outputPointX
-                  + 2 * padL + input.getSize(3) - 1
-                  - oStartX + iStartX;
-
-  int inputPointY = abs(outputPointY - padT)
-                  - abs(outputPointY - (input.getSize(2) + padT - 1))
-                  - outputPointY
-                  + 2 * padT + input.getSize(2) - 1
-                  - oStartY + iStartY;
-
-  Dtype valueToCopy = input[batch][plane][inputPointY][inputPointX];
-  output[batch][plane][outputPointY][outputPointX] = valueToCopy;
-}
-
-template <typename Dtype>
-__global__ void SpatialReflectionPadding_updateGradInput(
-  THCDeviceTensor<Dtype, 4> gradInput,
-  THCDeviceTensor<Dtype, 4> gradOutput,
-  int padT, int padB, int padL, int padR) {
-
-  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-  if (outputPointId >= gradOutput.getSize(2) * gradOutput.getSize(3)) {
-    return;
-  }
-  int outputPointX = outputPointId % gradOutput.getSize(3);
-  int outputPointY = outputPointId / gradOutput.getSize(3);
-
-  int iStartX = max(0, -padL);
-  int iStartY = max(0, -padT);
-  int oStartX = max(0, padL);
-  int oStartY = max(0, padT);
-
-  int inputPointX = abs(outputPointX - padL)
-                  - abs(outputPointX - (gradInput.getSize(3) + padL - 1))
-                  - outputPointX
-                  + 2 * padL + gradInput.getSize(3) - 1
-                  - oStartX + iStartX;
-
-  int inputPointY = abs(outputPointY - padT)
-                  - abs(outputPointY - (gradInput.getSize(2) + padT - 1))
-                  - outputPointY
-                  + 2 * padT + gradInput.getSize(2) - 1
-                  - oStartY + iStartY;
-
-  Dtype valueToCopy = gradOutput[batch][plane][outputPointY][outputPointX];
-  atomicAdd(&gradInput[batch][plane][inputPointY][inputPointX], valueToCopy);
-}
-
-#include <THCUNN/generic/SpatialReflectionPadding.cu>
-#include <THC/THCGenerateFloatTypes.h>
diff --git a/aten/src/THCUNN/generic/SpatialReflectionPadding.cu b/aten/src/THCUNN/generic/SpatialReflectionPadding.cu
deleted file mode 100644 (file)
index a6d6663..0000000
+++ /dev/null
@@ -1,137 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/SpatialReflectionPadding.cu"
-#else
-
-void THNN_(SpatialReflectionPadding_updateOutput)(THCState *state,
-           THCTensor *input,
-           THCTensor *output,
-           int padL, int padR,
-           int padT, int padB) {
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, input), 2,
-             "input tensor must fit into 32-bit index math");
-
-  int planeDim = 0;
-  int dimh = 1;
-  int dimw = 2;
-  int numBatch = 1;
-
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-  THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
-                  "non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s")
-
-  if (numInputDims == 4) {
-    numBatch = THCTensor_(size)(state, input, 0);
-    planeDim++;
-    dimh++;
-    dimw++;
-  }
-
-  int numPlanes = THCTensor_(size)(state, input, planeDim);
-  int inputH = THCTensor_(size)(state, input, dimh);
-  int inputW = THCTensor_(size)(state, input, dimw);
-
-  THArgCheck(padL < inputW && padR < inputW, 4,
-             "Padding size should be less than the corresponding input dimension, "
-             "but got: padding (%d, %d) at dimension %d of input %s",
-             padL, padR, dimw, THCTensor_(sizeDesc)(state, input).str);
-
-  THArgCheck(padT < inputH && padB < inputH, 6,
-             "Padding size should be less than the corresponding input dimension, "
-             "but got: padding (%d, %d) at dimension %d of input %s",
-             padT, padB, dimh, THCTensor_(sizeDesc)(state, input).str);
-
-  int outputH = inputH + padT + padB;
-  int outputW  = inputW + padL + padR;
-
-  THArgCheck(outputW >= 1 || outputH >= 1, 2,
-             "input (H: %d, W: %d)is too small."
-             " Calculated output H: %d W: %d",
-             inputH, inputW, outputH, outputW);
-
-  THCDeviceTensor<scalar_t, 4> devInput;
-  THCDeviceTensor<scalar_t, 4> devOutput;
-
-  if (numInputDims == 3) {
-    THCTensor_(resize3d)(state, output, numPlanes, outputH, outputW);
-
-    devInput = toDeviceTensor<scalar_t, 3>(state, input).upcastOuter<4>();
-    devOutput = toDeviceTensor<scalar_t, 3>(state, output).upcastOuter<4>();
-  } else {
-    THCTensor_(resize4d)(state, output, numBatch, numPlanes, outputH, outputW);
-
-    devInput = toDeviceTensor<scalar_t, 4>(state, input);
-    devOutput = toDeviceTensor<scalar_t, 4>(state, output);
-  }
-
-  int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3);
-  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.getSize(1),
-            devOutput.getSize(0));
-  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-  SpatialReflectionPadding_updateOutput<<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
-    devInput, devOutput, padT, padB, padL, padR);
-  THCudaCheck(cudaGetLastError());
-}
-
-void THNN_(SpatialReflectionPadding_updateGradInput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *gradOutput,
-           THCTensor *gradInput,
-           int padL, int padR,
-           int padT, int padB) {
-
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, input), 2,
-                "input tensor must fit into 32-bit index math");
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, gradOutput), 3,
-                "output gradient tensor must fit into 32-bit index math");
-
-  int planeDim = 0;
-  int dimh = 1;
-  int dimw = 2;
-
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-  if (numInputDims == 4) {
-    planeDim++;
-    dimh++;
-    dimw++;
-  }
-  int iheight = input->size(dimh);
-  int iwidth = input->size(dimw);
-  int oheight = iheight + padT + padB;
-  int owidth  = iwidth + padL + padR;
-
-  THArgCheck(owidth == THCTensor_(size)(state, gradOutput, dimw), 3,
-             "gradOutput width unexpected. Expected: %d, Got: %d",
-             owidth, THCTensor_(size)(state, gradOutput, dimw));
-  THArgCheck(oheight == THCTensor_(size)(state, gradOutput, dimh), 3,
-             "gradOutput height unexpected. Expected: %d, Got: %d",
-             oheight, THCTensor_(size)(state, gradOutput, dimh));
-
-  THCTensor_(resizeAs)(state, gradInput, input);
-  THCTensor_(zero)(state, gradInput);
-
-  THCDeviceTensor<scalar_t, 4> devGradInput;
-  THCDeviceTensor<scalar_t, 4> devGradOutput;
-
-  if (numInputDims == 3) {
-    devGradInput = toDeviceTensor<scalar_t, 3>(state, gradInput).upcastOuter<4>();
-    devGradOutput = toDeviceTensor<scalar_t, 3>(state, gradOutput).upcastOuter<4>();
-  } else {
-    devGradInput = toDeviceTensor<scalar_t, 4>(state, gradInput);
-    devGradOutput = toDeviceTensor<scalar_t, 4>(state, gradOutput);
-  }
-
-  int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3);
-  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devGradOutput.getSize(1),
-            devGradOutput.getSize(0));
-  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-  SpatialReflectionPadding_updateGradInput<<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
-    devGradInput, devGradOutput, padT, padB, padL, padR);
-  THCudaCheck(cudaGetLastError());
-}
-
-#endif
index fe3ef53..08fefca 100644 (file)
@@ -847,21 +847,6 @@ THC_API void THNN_(SpatialMaxUnpooling_updateGradInput)(
                   THCIndexTensor *indices,
                   int owidth, int oheight);
 
-THC_API void THNN_(SpatialReflectionPadding_updateOutput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *output,
-                  int padL, int padR,
-                  int padT, int padB);
-
-THC_API void THNN_(SpatialReflectionPadding_updateGradInput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *gradOutput,
-                  THCTensor *gradInput,
-                  int padL, int padR,
-                  int padT, int padB);
-
 THC_API void THNN_(SpatialSubSampling_updateOutput)(
                   THCState *state,
                   THCTensor *input,
diff --git a/aten/src/THNN/generic/SpatialReflectionPadding.c b/aten/src/THNN/generic/SpatialReflectionPadding.c
deleted file mode 100644 (file)
index f7c240b..0000000
+++ /dev/null
@@ -1,270 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/SpatialReflectionPadding.c"
-#else
-
-static void THNN_(SpatialReflectionPadding_updateOutput_frame)(
-  scalar_t *input_p, scalar_t *output_p,
-  int64_t nslices,
-  int64_t iwidth, int64_t iheight,
-  int64_t owidth, int64_t oheight,
-  int pad_l, int pad_r,
-  int pad_t, int pad_b)
-{
-  int iStartX = fmax(0, -pad_l);
-  int iStartY = fmax(0, -pad_t);
-  int oStartX = fmax(0, pad_l);
-  int oStartY = fmax(0, pad_t);
-
-  int64_t k, ip_x, ip_y;
-#pragma omp parallel for private(k, ip_x, ip_y)
-
-  for (k = 0; k < nslices; k++)
-  {
-    int64_t i, j;
-    for (i = 0; i < oheight; i++) {
-      for (j = 0; j < owidth; j++) {
-        if (j < pad_l) {
-          ip_x = pad_l * 2 - j;
-        } else if (j >= pad_l && j < iwidth + pad_l) {
-          ip_x = j;
-        } else {
-          ip_x = (iwidth + pad_l - 1) * 2 - j;
-        }
-        ip_x = ip_x - oStartX + iStartX;
-
-        if (i < pad_t) {
-          ip_y = pad_t * 2 - i;
-        } else if (i >= pad_t && i < iheight + pad_t) {
-          ip_y = i;
-        } else {
-          ip_y = (iheight + pad_t - 1) * 2 - i;
-        }
-        ip_y = ip_y - oStartY + iStartY;
-
-        scalar_t *dest_p = output_p + k*owidth*oheight + i * owidth + j;
-        scalar_t *src_p = input_p + k*iwidth*iheight + ip_y * iwidth + ip_x;
-        *dest_p = *src_p;
-      }
-    }
-  }
-}
-
-void THNN_(SpatialReflectionPadding_updateOutput)(THNNState *state,
-                                                  THTensor *input,
-                                                  THTensor *output,
-                                                  int pad_l, int pad_r,
-                                                  int pad_t, int pad_b)
-{
-  int dimw = 2;
-  int dimh = 1;
-  int dimslices = 0;
-  int64_t nbatch = 1;
-  int64_t nslices;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t oheight;
-  int64_t owidth;
-  scalar_t *input_data;
-  scalar_t *output_data;
-
-  THNN_ARGCHECK(!input->is_empty() && (input->dim() == 3 || input->dim() == 4), 2, input,
-               "non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s");
-
-  if (input->dim() == 4)
-  {
-    nbatch = input->size(0);
-    dimw++;
-    dimh++;
-    dimslices++;
-  }
-
-  /* input sizes */
-  nslices = input->size(dimslices);
-  iheight = input->size(dimh);
-  iwidth = input->size(dimw);
-
-  AT_CHECK(pad_l < iwidth && pad_r < iwidth,
-           "Argument #4: Padding size should be less than the corresponding input dimension, "
-           "but got: padding (", pad_l, ", ", pad_r, ") at dimension ", dimw, " of input ", input->sizes());
-
-  AT_CHECK(pad_t < iheight && pad_b < iheight,
-           "Argument #6: Padding size should be less than the corresponding input dimension, "
-           "but got: padding (", pad_t, ", ", pad_b, ") at dimension ", dimh, " of input ", input->sizes());
-
-  /* output sizes */
-  oheight = iheight + pad_t + pad_b;
-  owidth  = iwidth + pad_l + pad_r;
-
-  THArgCheck(owidth >= 1 || oheight >= 1 , 2,
-            "input (H: %d, W: %d)is too small."
-            " Calculated output H: %d W: %d",
-            iheight, iwidth, oheight, owidth);
-
-  /* get contiguous input */
-  input = THTensor_(newContiguous)(input);
-
-  /* resize output */
-  if (input->dim() == 3)
-  {
-    THTensor_(resize3d)(output, nslices, oheight, owidth);
-
-    input_data = input->data<scalar_t>();
-    output_data = output->data<scalar_t>();
-
-    THNN_(SpatialReflectionPadding_updateOutput_frame)(input_data, output_data,
-                                                    nslices,
-                                                    iwidth, iheight,
-                                                    owidth, oheight,
-                                                    pad_l, pad_r,
-                                                    pad_t, pad_b);
-  }
-  else
-  {
-    int64_t p;
-
-    THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth);
-
-    input_data = input->data<scalar_t>();
-    output_data = output->data<scalar_t>();
-
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++)
-    {
-      THNN_(SpatialReflectionPadding_updateOutput_frame)(
-        input_data+p*nslices*iwidth*iheight,
-        output_data+p*nslices*owidth*oheight,
-        nslices,
-        iwidth, iheight,
-        owidth, oheight,
-        pad_l, pad_r,
-        pad_t, pad_b);
-    }
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(input);
-}
-
-static void THNN_(SpatialReflectionPadding_updateGradInput_frame)(
-  scalar_t *ginput_p, scalar_t *goutput_p,
-  int64_t nslices,
-  int64_t iwidth, int64_t iheight,
-  int64_t owidth, int64_t oheight,
-  int pad_l, int pad_r,
-  int pad_t, int pad_b)
-{
-  int iStartX = fmax(0, -pad_l);
-  int iStartY = fmax(0, -pad_t);
-  int oStartX = fmax(0, pad_l);
-  int oStartY = fmax(0, pad_t);
-
-  int64_t k, ip_x, ip_y;
-#pragma omp parallel for private(k, ip_x, ip_y)
-
-  for (k = 0; k < nslices; k++)
-  {
-    int64_t i, j;
-    for (i = 0; i < oheight; i++) {
-      for (j = 0; j < owidth; j++) {
-        if (j < pad_l) {
-          ip_x = pad_l * 2 - j;
-        } else if (j >= pad_l && j < iwidth + pad_l) {
-          ip_x = j;
-        } else {
-          ip_x = (iwidth + pad_l - 1) * 2 - j;
-        }
-        ip_x = ip_x - oStartX + iStartX;
-
-        if (i < pad_t) {
-          ip_y = pad_t * 2 - i;
-        } else if (i >= pad_t && i < iheight + pad_t) {
-          ip_y = i;
-        } else {
-          ip_y = (iheight + pad_t - 1) * 2 - i;
-        }
-        ip_y = ip_y - oStartY + iStartY;
-
-        scalar_t *src_p = goutput_p + k*owidth*oheight + i * owidth + j;
-        scalar_t *dest_p = ginput_p + k*iwidth*iheight + ip_y * iwidth + ip_x;
-        *dest_p += *src_p;
-      }
-    }
-  }
-}
-
-void THNN_(SpatialReflectionPadding_updateGradInput)(THNNState *state,
-                                                      THTensor *input,
-                                                      THTensor *gradOutput,
-                                                      THTensor *gradInput,
-                                                      int pad_l, int pad_r,
-                                                      int pad_t, int pad_b)
-{
-  int dimw = 2;
-  int dimh = 1;
-  int dimslices = 0;
-  int64_t nbatch = 1;
-  int64_t nslices;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t oheight;
-  int64_t owidth;
-
-  if (input->dim() == 4)
-  {
-    nbatch = input->size(0);
-    dimw++;
-    dimh++;
-    dimslices++;
-  }
-
-  /* sizes */
-  nslices = input->size(dimslices);
-  iheight = input->size(dimh);
-  iwidth = input->size(dimw);
-  oheight = iheight + pad_t + pad_b;
-  owidth  = iwidth + pad_l + pad_r;
-
-  THArgCheck(owidth == THTensor_(size)(gradOutput, dimw), 3,
-            "gradOutput width unexpected. Expected: %d, Got: %d",
-            owidth, THTensor_(size)(gradOutput, dimw));
-  THArgCheck(oheight == THTensor_(size)(gradOutput, dimh), 3,
-                "gradOutput height unexpected. Expected: %d, Got: %d",
-            oheight, THTensor_(size)(gradOutput, dimh));
-
-  /* get contiguous gradOutput */
-  gradOutput = THTensor_(newContiguous)(gradOutput);
-
-  /* resize */
-  THTensor_(resizeAs)(gradInput, input);
-  THTensor_(zero)(gradInput);
-
-  /* backprop */
-  if (input->dim() == 3) {
-    THNN_(SpatialReflectionPadding_updateGradInput_frame)(
-      gradInput->data<scalar_t>(),
-      gradOutput->data<scalar_t>(),
-      nslices,
-      iwidth, iheight,
-      owidth, oheight,
-      pad_l, pad_r,
-      pad_t, pad_b);
-  } else {
-    int64_t p;
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++) {
-      THNN_(SpatialReflectionPadding_updateGradInput_frame)(
-        gradInput->data<scalar_t>() + p * nslices * iheight * iwidth,
-        gradOutput->data<scalar_t>() + p * nslices * oheight * owidth,
-        nslices,
-        iwidth, iheight,
-        owidth, oheight,
-        pad_l, pad_r,
-        pad_t, pad_b);
-    }
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(gradOutput);
-}
-
-#endif
index 355c819..f98077c 100644 (file)
@@ -923,21 +923,6 @@ TH_API void THNN_(VolumetricAdaptiveMaxPooling_updateGradInput)(
           THTensor *gradInput,
           THIndexTensor *indices);
 
-TH_API void THNN_(SpatialReflectionPadding_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          int pad_left, int pad_right,
-          int pad_top, int pad_bottom);
-
-TH_API void THNN_(SpatialReflectionPadding_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          int pad_left, int pad_right,
-          int pad_top, int pad_bottom);
-
 TH_API void THNN_(FeatureLPPooling_updateOutput)(
           THNNState *state,
           THTensor *input,
index 845374e..9120420 100644 (file)
 #include <THNN/generic/VolumetricMaxUnpooling.c>
 #include <TH/THGenerateFloatTypes.h>
 
-#include <THNN/generic/SpatialReflectionPadding.c>
-#include <TH/THGenerateFloatTypes.h>
-
 #include <THNN/generic/VolumetricUpSamplingNearest.c>
 #include <TH/THGenerateFloatTypes.h>
 
index f18f60e..2b12ffc 100644 (file)
@@ -306,7 +306,6 @@ def _generate_function_classes(scope_dict):
         'TemporalConvolution': 'Conv1d',
         'SpatialDilatedConvolution': 'DilatedConv2d',
         'SpatialMaxUnpooling': 'MaxUnpool2d',
-        'SpatialReflectionPadding': 'ReflectionPad2d',
         'VolumetricMaxUnpooling': 'MaxUnpool3d',
         'HardTanh': 'Hardtanh',
         'HardShrink': 'Hardshrink',