From 23e28efed43d84366894fb6c6c963e64097730d8 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Wed, 9 Jan 2019 20:53:03 -0800 Subject: [PATCH] Porting legacy reflection_pad2d to ATen Summary: Other changes: 1. Avoided using `THCDeviceTensor` by re-calculating the mapping from cuda (blockIdx, threadIdx) to input/output tensor index. 2. Changed Camelcase naming to underscore naming. Differential Revision: D13546803 fbshipit-source-id: 1df54f13e64934da3d803d9b6586bd5208d42d6d --- aten/src/ATen/native/LegacyNNDefinitions.cpp | 16 -- aten/src/ATen/native/ReflectionPad.cpp | 295 +++++++++++++++++++++ aten/src/ATen/native/cuda/ReflectionPad.cu | 248 ++++++++++++++++- aten/src/ATen/native/native_functions.yaml | 12 + aten/src/ATen/nn.yaml | 8 - aten/src/THCUNN/CMakeLists.txt | 1 - aten/src/THCUNN/SpatialReflectionPadding.cu | 87 ------ .../src/THCUNN/generic/SpatialReflectionPadding.cu | 137 ---------- aten/src/THCUNN/generic/THCUNN.h | 15 -- aten/src/THNN/generic/SpatialReflectionPadding.c | 270 ------------------- aten/src/THNN/generic/THNN.h | 15 -- aten/src/THNN/init.cpp | 3 - torch/nn/_functions/thnn/auto.py | 1 - 13 files changed, 552 insertions(+), 556 deletions(-) delete mode 100644 aten/src/THCUNN/SpatialReflectionPadding.cu delete mode 100644 aten/src/THCUNN/generic/SpatialReflectionPadding.cu delete mode 100644 aten/src/THNN/generic/SpatialReflectionPadding.c diff --git a/aten/src/ATen/native/LegacyNNDefinitions.cpp b/aten/src/ATen/native/LegacyNNDefinitions.cpp index 5f38bf1..57e8edb 100644 --- a/aten/src/ATen/native/LegacyNNDefinitions.cpp +++ b/aten/src/ATen/native/LegacyNNDefinitions.cpp @@ -492,22 +492,6 @@ Tensor max_unpool3d_backward(const Tensor & grad_output, const Tensor & self, co return at::legacy::th::_thnn_max_unpool3d_backward(grad_output, self, indices, output_size, stride, padding); } -Tensor & reflection_pad2d_out(Tensor & output, const Tensor & self, IntList padding) { - return at::legacy::th::_thnn_reflection_pad2d_forward_out(output, self, padding); -} - -Tensor reflection_pad2d(const Tensor & self, IntList padding) { - return at::legacy::th::_thnn_reflection_pad2d_forward(self, padding); -} - -Tensor & reflection_pad2d_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, IntList padding) { - return at::legacy::th::_thnn_reflection_pad2d_backward_out(grad_input, grad_output, self, padding); -} - -Tensor reflection_pad2d_backward(const Tensor & grad_output, const Tensor & self, IntList padding) { - return at::legacy::th::_thnn_reflection_pad2d_backward(grad_output, self, padding); -} - Tensor & upsample_linear1d_out(Tensor & output, const Tensor & self, IntList output_size, bool align_corners) { return at::legacy::th::_thnn_upsample_linear1d_forward_out(output, self, output_size, align_corners); } diff --git a/aten/src/ATen/native/ReflectionPad.cpp b/aten/src/ATen/native/ReflectionPad.cpp index 37a7b14..8cf3e66 100644 --- a/aten/src/ATen/native/ReflectionPad.cpp +++ b/aten/src/ATen/native/ReflectionPad.cpp @@ -208,6 +208,267 @@ void reflection_pad1d_backward_out_template( ); } } + +template +static void reflection_pad2d_out_frame( + scalar_t * input_p, scalar_t * output_p, + int64_t nplane, + int64_t input_w, int64_t input_h, + int64_t output_w, int64_t output_h, + int64_t pad_l, int64_t pad_t) { + auto i_start_x = std::max(int64_t(0), -pad_l); + auto i_start_y = std::max(int64_t(0), -pad_t); + auto o_start_x = std::max(int64_t(0), pad_l); + auto o_start_y = std::max(int64_t(0), pad_t); + + int64_t k, ip_x, ip_y; +#pragma omp parallel for 
private(k, ip_x, ip_y) + + for (k = 0; k < nplane; k++) { + for (int64_t i = 0; i < output_h; i++) { + for (int64_t j = 0; j < output_w; j++) { + if (j < pad_l) { + ip_x = pad_l * 2 - j; + } else if (j >= pad_l && j < input_w + pad_l) { + ip_x = j; + } else { + ip_x = (input_w + pad_l - 1) * 2 - j; + } + ip_x = ip_x - o_start_x + i_start_x; + + if (i < pad_t) { + ip_y = pad_t * 2 - i; + } else if (i >= pad_t && i < input_h + pad_t) { + ip_y = i; + } else { + ip_y = (input_h + pad_t - 1) * 2 - i; + } + ip_y = ip_y - o_start_y + i_start_y; + + scalar_t *dest_p = output_p + k*output_w*output_h + i * output_w + j; + scalar_t *src_p = input_p + k*input_w*input_h + ip_y * input_w + ip_x; + *dest_p = *src_p; + } + } + } +} + +template +inline void reflection_pad2d_out_loop( + scalar_t * input_p, scalar_t * output_p, + int64_t nbatch, int64_t nplane, + int64_t input_w, int64_t input_h, + int64_t output_w, int64_t output_h, + int64_t pad_l, int64_t pad_t) { + int64_t p; +#pragma omp parallel for private(p) + for (p = 0; p < nbatch; p++) { + reflection_pad2d_out_frame( + input_p + p * nplane * input_w * input_h, + output_p + p * nplane * output_w * output_h, + nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + } +} + +void reflection_pad2d_out_template( + Tensor &output, const Tensor &input_, IntList padding) { + int dim_w = 2; + int dim_h = 1; + int dim_slices = 0; + int64_t nbatch = 1; + + AT_CHECK(input_.numel() > 0 && + (input_.ndimension() == 3 || input_.ndimension() == 4), "non-empty 3D or " + "4D (batch mode) tensor expected for input, but got: ", input_); + + if (input_.ndimension() == 4) { + nbatch = input_.size(0); + dim_w++; + dim_h++; + dim_slices++; + } + + /* sizes */ + int64_t pad_l = padding[0]; + int64_t pad_r = padding[1]; + int64_t pad_t = padding[2]; + int64_t pad_b = padding[3]; + + int64_t nplane = input_.size(dim_slices); + int64_t input_h = input_.size(dim_h); + int64_t input_w = input_.size(dim_w); + int64_t output_h = input_h + pad_t + pad_b; + int64_t output_w = input_w + pad_l + pad_r; + + AT_CHECK(pad_l < input_w && pad_r < input_w, + "Argument #4: Padding size should be less than the corresponding " + "input dimension, but got: padding (", pad_l, ", ", pad_r, + ") at dimension ", dim_w, " of input ", input_.ndimension()); + + AT_CHECK(pad_t < input_h && pad_b < input_h, + "Argument #6: Padding size should be less than the corresponding " + "input dimension, but got: padding (", pad_t, ", ", pad_b, + ") at dimension ", dim_h, " of input ", input_.ndimension()); + + AT_CHECK(output_w >= 1 || output_h >= 1, + "input (H: ", input_h, ", W: ", input_w, ")is too small. 
Calculated " + "output H: ", output_h, " W: ", output_w); + + /* get contiguous input */ + Tensor input = input_.contiguous(); + + if (input.ndimension() == 3) { + /* resize output */ + output.resize_({nplane, output_h, output_w}); + AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad2d", [&] { + reflection_pad2d_out_frame( + input.data(), output.data(), + nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + }); + } else { + /* resize output */ + output.resize_({nbatch, nplane, output_h, output_w}); + AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad2d", [&] { + reflection_pad2d_out_loop( + input.data(), output.data(), + nbatch, nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + }); + } +} + +template +static void reflection_pad2d_backward_out_frame( + scalar_t *grad_input, scalar_t *grad_output, + int64_t nplane, + int64_t input_w, int64_t input_h, + int64_t output_w, int64_t output_h, + int64_t pad_l, int64_t pad_t) { + auto i_start_x = std::max(int64_t(0), -pad_l); + auto i_start_y = std::max(int64_t(0), -pad_t); + auto o_start_x = std::max(int64_t(0), pad_l); + auto o_start_y = std::max(int64_t(0), pad_t); + + int64_t k, ip_x, ip_y; +#pragma omp parallel for private(k, ip_x, ip_y) + + for (k = 0; k < nplane; k++) { + for (int64_t i = 0; i < output_h; i++) { + for (int64_t j = 0; j < output_w; j++) { + if (j < pad_l) { + ip_x = pad_l * 2 - j; + } else if (j >= pad_l && j < input_w + pad_l) { + ip_x = j; + } else { + ip_x = (input_w + pad_l - 1) * 2 - j; + } + ip_x = ip_x - o_start_x + i_start_x; + + if (i < pad_t) { + ip_y = pad_t * 2 - i; + } else if (i >= pad_t && i < input_h + pad_t) { + ip_y = i; + } else { + ip_y = (input_h + pad_t - 1) * 2 - i; + } + ip_y = ip_y - o_start_y + i_start_y; + + scalar_t *src_p = + grad_output + k * output_w * output_h + i * output_w + j; + scalar_t *dest_p = + grad_input + k * input_w * input_h + ip_y * input_w + ip_x; + *dest_p += *src_p; + } + } + } +} + +template +inline void reflection_pad2d_backward_out_loop( + scalar_t *grad_input, scalar_t *grad_output, + int64_t nbatch, int64_t nplane, + int64_t input_w, int64_t input_h, + int64_t output_w, int64_t output_h, + int64_t pad_l, int64_t pad_t) { + int64_t p; +#pragma omp parallel for private(p) + for (p = 0; p < nbatch; p++) { + reflection_pad2d_backward_out_frame( + grad_input + p * nplane * input_h * input_w, + grad_output + p * nplane * output_h * output_w, + nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + } +} + +void reflection_pad2d_backward_out_template( + Tensor &grad_input, const Tensor &grad_output_, + const Tensor &input, IntList padding) { + int dim_w = 2; + int dim_h = 1; + int dim_plane = 0; + int64_t nbatch = 1; + + if (input.ndimension() == 4) { + nbatch = input.size(0); + dim_w++; + dim_h++; + dim_plane++; + } + + /* sizes */ + int64_t pad_l = padding[0]; + int64_t pad_r = padding[1]; + int64_t pad_t = padding[2]; + int64_t pad_b = padding[3]; + + int64_t nplane = input.size(dim_plane); + int64_t input_h = input.size(dim_h); + int64_t input_w = input.size(dim_w); + int64_t output_h = input_h + pad_t + pad_b; + int64_t output_w = input_w + pad_l + pad_r; + + AT_CHECK(output_w == grad_output_.size(dim_w), + "gradOutput width unexpected. Expected: ", output_w, ", Got: ", + grad_output_.size(dim_w)); + + AT_CHECK(output_h == grad_output_.size(dim_h), + "gradOutput height unexpected. 
Expected: ", output_h, ", Got: ", + grad_output_.size(dim_h)); + + /* get contiguous gradOutput */ + Tensor grad_output = grad_output_.contiguous(); + + /* backprop */ + if (input.ndimension() == 3) { + AT_DISPATCH_FLOATING_TYPES( + grad_output.type(), "reflection_pad2d_backward", [&] { + reflection_pad2d_backward_out_frame( + grad_input.data(), grad_output.data(), + nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + } + ); + } else { + AT_DISPATCH_FLOATING_TYPES( + grad_output.type(), "reflection_pad2d_backward", [&] { + reflection_pad2d_backward_out_loop( + grad_input.data(), grad_output.data(), + nbatch, nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + } + ); + } +} + } // namespace Tensor& reflection_pad1d_out_cpu( @@ -244,5 +505,39 @@ Tensor reflection_pad1d_backward_cpu( return grad_input; } +Tensor& reflection_pad2d_out_cpu( + Tensor& output, const Tensor& input, IntList padding) { + reflection_pad2d_out_template(output, input, padding); + return output; +} + +Tensor reflection_pad2d_cpu(const Tensor& input, IntList padding) { + auto output = at::empty({0}, input.options()); + reflection_pad2d_out_template(output, input, padding); + return output; +} + +Tensor& reflection_pad2d_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + const Tensor& input, + IntList padding) { + grad_input.resize_as_(input); + grad_input.zero_(); + reflection_pad2d_backward_out_template( + grad_input, grad_output, input, padding); + return grad_input; +} + +Tensor reflection_pad2d_backward_cpu( + const Tensor& grad_output, + const Tensor& input, + IntList padding) { + auto grad_input = at::zeros_like(input); + reflection_pad2d_backward_out_template( + grad_input, grad_output, input, padding); + return grad_input; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cuda/ReflectionPad.cu b/aten/src/ATen/native/cuda/ReflectionPad.cu index b142503..6f1d5c7 100644 --- a/aten/src/ATen/native/cuda/ReflectionPad.cu +++ b/aten/src/ATen/native/cuda/ReflectionPad.cu @@ -16,7 +16,7 @@ namespace { using at::cuda::detail::canUse32BitIndexMath; __device__ -inline thrust::pair get_index_mapping( +inline thrust::pair get_index_mapping1d( int64_t input_w, int64_t output_w, int64_t output_x, int64_t pad_l) { @@ -39,6 +39,44 @@ inline thrust::pair get_index_mapping( input_offset + input_x, output_offset + output_x); } + +__device__ +inline thrust::pair get_index_mapping2d( + int64_t input_dim_x, int64_t input_dim_y, + int64_t output_dim_x, int64_t output_dim_y, + int64_t pad_l, int64_t pad_t, + int64_t output_xy) { + // 3D grid of 1D blocks + auto input_offset = + (blockIdx.y + blockIdx.z * gridDim.y) * input_dim_x * input_dim_y; + auto output_offset = + (blockIdx.y + blockIdx.z * gridDim.y) * output_dim_x * output_dim_y; + + auto output_x = output_xy % output_dim_x; + auto output_y = output_xy / output_dim_x; + + auto i_start_x = ::max(int64_t(0), -pad_l); + auto i_start_y = ::max(int64_t(0), -pad_t); + auto o_start_x = ::max(int64_t(0), pad_l); + auto o_start_y = ::max(int64_t(0), pad_t); + + auto input_x = ::abs(output_x - pad_l) + - ::abs(output_x - (input_dim_x + pad_l - 1)) + - output_x + + 2 * pad_l + input_dim_x - 1 + - o_start_x + i_start_x; + + auto input_y = ::abs(output_y - pad_t) + - ::abs(output_y - (input_dim_y + pad_t - 1)) + - output_y + + 2 * pad_t + input_dim_y - 1 + - o_start_y + i_start_y; + + return thrust::make_pair( + input_offset + input_y * input_dim_x + input_x, + output_offset + output_y * output_dim_x + output_x); +} + 
template __global__ void reflection_pad1d_out_kernel( scalar_t * input, scalar_t * output, @@ -48,7 +86,7 @@ __global__ void reflection_pad1d_out_kernel( auto output_w = input_w + pad_l + pad_r; if (output_x < output_w) { - auto index_pair = get_index_mapping(input_w, output_w, output_x, pad_l); + auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l); output[index_pair.second] = input[index_pair.first]; } } @@ -62,12 +100,52 @@ __global__ void reflection_pad1d_backward_out_kernel( auto output_w = input_w + pad_l + pad_r; if (output_x < output_w) { - auto index_pair = get_index_mapping(input_w, output_w, output_x, pad_l); + auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l); atomicAdd( &grad_input[index_pair.first], grad_output[index_pair.second]); } } +template +__global__ void reflection_pad2d_out_kernel( + scalar_t * input, scalar_t * output, + int64_t input_dim_x, int64_t input_dim_y, + int pad_t, int pad_b, int pad_l, int pad_r) { + auto output_xy = threadIdx.x + blockIdx.x * blockDim.x; + auto output_dim_x = input_dim_x + pad_l + pad_r; + auto output_dim_y = input_dim_y + pad_t + pad_b; + + if (output_xy < output_dim_x * output_dim_y) { + auto index_pair = get_index_mapping2d( + input_dim_x, input_dim_y, + output_dim_x, output_dim_y, + pad_l, pad_t, + output_xy); + + output[index_pair.second] = input[index_pair.first]; + } +} + +template +__global__ void reflection_pad2d_backward_out_kernel( + scalar_t * grad_input, scalar_t * grad_output, + int64_t input_dim_x, int64_t input_dim_y, + int pad_t, int pad_b, int pad_l, int pad_r) { + auto output_xy = threadIdx.x + blockIdx.x * blockDim.x; + auto output_dim_x = input_dim_x + pad_l + pad_r; + auto output_dim_y = input_dim_y + pad_t + pad_b; + + if (output_xy < output_dim_x * output_dim_y) { + auto index_pair = get_index_mapping2d( + input_dim_x, input_dim_y, + output_dim_x, output_dim_y, + pad_l, pad_t, + output_xy); + + atomicAdd(&grad_input[index_pair.first], grad_output[index_pair.second]); + } +} + void reflection_pad1d_out_template( Tensor &output, const Tensor &input_, IntList padding) { AT_CHECK(canUse32BitIndexMath(input_), @@ -172,8 +250,139 @@ void reflection_pad1d_backward_out_template( AT_CUDA_CHECK(cudaGetLastError()); } +void reflection_pad2d_out_template( + Tensor &output, const Tensor &input_, IntList padding) { + AT_CHECK(canUse32BitIndexMath(input_), + "input tensor must fit into 32-bit index math"); + + int plane_dim = 0; + int dim_h = 1; + int dim_w = 2; + int nbatch = 1; + + AT_CHECK(input_.numel() > 0 && + (input_.ndimension() == 3 || input_.ndimension() == 4), "non-empty 3D or " + "4D (batch mode) tensor expected for input, but got: ", input_); + + if (input_.ndimension() == 4) { + nbatch = input_.size(0); + plane_dim++; + dim_h++; + dim_w++; + } + + int64_t pad_l = padding[0]; + int64_t pad_r = padding[1]; + int64_t pad_t = padding[2]; + int64_t pad_b = padding[3]; + + int nplane = input_.size(plane_dim); + int input_h = input_.size(dim_h); + int input_w = input_.size(dim_w); + + AT_CHECK(pad_l < input_w && pad_r < input_w, + "Padding size should be less than the corresponding input dimension, but " + "got: padding (", pad_l, ", ", pad_r, ") at dimension ", dim_w, + " of input ", input_.sizes()); + + AT_CHECK(pad_t < input_h && pad_b < input_h, + "Padding size should be less than the corresponding input dimension, but " + "got: padding (", pad_t, ", ", pad_b, ") at dimension ", dim_h, + " of input ", input_.sizes()); + + int output_h = input_h + pad_t + pad_b; + int 
output_w = input_w + pad_l + pad_r; + + AT_CHECK(output_w >= 1 || output_h >= 1, + "input (H: ", input_h, ", W: ", input_w, ")is too small. Calculated " + "output H: ", output_h, " W: ", output_w); + + if (input_.ndimension() == 3) { + output.resize_({nplane, output_h, output_w}); + } else { + output.resize_({nbatch, nplane, output_h, output_w}); + } + + Tensor input = input_.contiguous(); + + int output_plane_size = output_h * output_w; + dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size); + dim3 grid_size( + (int) std::ceil(output_plane_size/256.0), nplane, nbatch); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.type(), "reflection_pad2d_out_template", [&] { + reflection_pad2d_out_kernel<<< + grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( + input.data(), output.data(), + input_w, input_h, + pad_t, pad_b, pad_l, pad_r); + } + ); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void reflection_pad2d_backward_out_template( + Tensor &grad_input, const Tensor &grad_output_, + const Tensor &input, IntList padding) { + AT_CHECK(canUse32BitIndexMath(input), + "input tensor must fit into 32-bit index math"); + AT_CHECK(canUse32BitIndexMath(grad_output_), + "output gradient tensor must fit into 32-bit index math"); + + int plane_dim = 0; + int dim_h = 1; + int dim_w = 2; + int nbatch = 1; + + if (input.ndimension() == 4) { + nbatch = input.size(0); + plane_dim++; + dim_h++; + dim_w++; + } + + int64_t pad_l = padding[0]; + int64_t pad_r = padding[1]; + int64_t pad_t = padding[2]; + int64_t pad_b = padding[3]; + + int nplane = input.size(plane_dim); + int input_h = input.size(dim_h); + int input_w = input.size(dim_w); + + int output_h = input_h + pad_t + pad_b; + int output_w = input_w + pad_l + pad_r; + + AT_CHECK(output_w == grad_output_.size(dim_w), "grad_output width " + "unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w)); + AT_CHECK(output_h == grad_output_.size(dim_h), "grad_output height " + "unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h)); + + Tensor grad_output = grad_output_.contiguous(); + + int output_plane_size = output_h * output_w; + dim3 block_size(output_plane_size > 256 ? 
256 : output_plane_size); + dim3 grid_size( + (int) std::ceil(output_plane_size/256.0), nplane, nbatch); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.type(), "reflection_pad2d_backward_out_template", [&] { + reflection_pad2d_backward_out_kernel<<< + grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( + grad_input.data(), grad_output.data(), + input_w, input_h, + pad_t, pad_b, pad_l, pad_r); + } + ); + + AT_CUDA_CHECK(cudaGetLastError()); +} + } // namespace + Tensor& reflection_pad1d_out_cuda( Tensor& output, const Tensor& input, IntList padding) { reflection_pad1d_out_template(output, input, padding); @@ -207,5 +416,38 @@ Tensor reflection_pad1d_backward_cuda( return grad_input; } +Tensor& reflection_pad2d_out_cuda( + Tensor& output, const Tensor& input, IntList padding) { + reflection_pad2d_out_template(output, input, padding); + return output; +} + +Tensor reflection_pad2d_cuda(const Tensor& input, IntList padding) { + auto output = at::empty({0}, input.options()); + reflection_pad2d_out_template(output, input, padding); + return output; +} + +Tensor& reflection_pad2d_backward_out_cuda( + Tensor& grad_input, const Tensor& grad_output, + const Tensor& input, + IntList padding) { + grad_input.resize_as_(input); + grad_input.zero_(); + reflection_pad2d_backward_out_template( + grad_input, grad_output, input, padding); + return grad_input; +} + +Tensor reflection_pad2d_backward_cuda( + const Tensor& grad_output, + const Tensor& input, + IntList padding) { + auto grad_input = at::zeros_like(input); + reflection_pad2d_backward_out_template( + grad_input, grad_output, input, padding); + return grad_input; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index d7a5b8b..6aaf0468 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3452,15 +3452,27 @@ - func: reflection_pad2d_out(Tensor output, Tensor self, IntList[4] padding) -> Tensor python_module: nn + dispatch: + CPU: reflection_pad2d_out_cpu + CUDA: reflection_pad2d_out_cuda - func: reflection_pad2d(Tensor self, IntList[4] padding) -> Tensor python_module: nn + dispatch: + CPU: reflection_pad2d_cpu + CUDA: reflection_pad2d_cuda - func: reflection_pad2d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[4] padding) -> Tensor python_module: nn + dispatch: + CPU: reflection_pad2d_backward_out_cpu + CUDA: reflection_pad2d_backward_out_cuda - func: reflection_pad2d_backward(Tensor grad_output, Tensor self, IntList[4] padding) -> Tensor python_module: nn + dispatch: + CPU: reflection_pad2d_backward_cpu + CUDA: reflection_pad2d_backward_cuda - func: replication_pad1d_out(Tensor output, Tensor self, IntList[2] padding) -> Tensor python_module: nn diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml index a24a032..3f7ee96 100644 --- a/aten/src/ATen/nn.yaml +++ b/aten/src/ATen/nn.yaml @@ -188,14 +188,6 @@ output: 'false' grad_input: 'false' -# Padding - -- name: _thnn_reflection_pad2d(Tensor self, IntList[4] padding) - cname: SpatialReflectionPadding - scalar_check: - output: 'false' - grad_input: 'false' - # Upsampling # Note: The upsampling backwards functions also include an IntList input_size diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt index e0b9ce8..d7f35a7 100644 --- a/aten/src/THCUNN/CMakeLists.txt +++ b/aten/src/THCUNN/CMakeLists.txt @@ -41,7 +41,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullConvolution.cu 
${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullDilatedConvolution.cu ${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxPooling.cu ${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxUnpooling.cu -${CMAKE_CURRENT_SOURCE_DIR}/SpatialReflectionPadding.cu ${CMAKE_CURRENT_SOURCE_DIR}/SpatialSubSampling.cu ${CMAKE_CURRENT_SOURCE_DIR}/SpatialUpSamplingBicubic.cu ${CMAKE_CURRENT_SOURCE_DIR}/SpatialUpSamplingBilinear.cu diff --git a/aten/src/THCUNN/SpatialReflectionPadding.cu b/aten/src/THCUNN/SpatialReflectionPadding.cu deleted file mode 100644 index 45d9dba..0000000 --- a/aten/src/THCUNN/SpatialReflectionPadding.cu +++ /dev/null @@ -1,87 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -template -__global__ void SpatialReflectionPadding_updateOutput( - THCDeviceTensor input, - THCDeviceTensor output, - int padT, int padB, int padL, int padR) { - - int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; - int plane = blockIdx.y; - int batch = blockIdx.z; - if (outputPointId >= output.getSize(2) * output.getSize(3)) { - return; - } - int outputPointX = outputPointId % output.getSize(3); - int outputPointY = outputPointId / output.getSize(3); - - int iStartX = max(0, -padL); - int iStartY = max(0, -padT); - int oStartX = max(0, padL); - int oStartY = max(0, padT); - - int inputPointX = abs(outputPointX - padL) - - abs(outputPointX - (input.getSize(3) + padL - 1)) - - outputPointX - + 2 * padL + input.getSize(3) - 1 - - oStartX + iStartX; - - int inputPointY = abs(outputPointY - padT) - - abs(outputPointY - (input.getSize(2) + padT - 1)) - - outputPointY - + 2 * padT + input.getSize(2) - 1 - - oStartY + iStartY; - - Dtype valueToCopy = input[batch][plane][inputPointY][inputPointX]; - output[batch][plane][outputPointY][outputPointX] = valueToCopy; -} - -template -__global__ void SpatialReflectionPadding_updateGradInput( - THCDeviceTensor gradInput, - THCDeviceTensor gradOutput, - int padT, int padB, int padL, int padR) { - - int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; - int plane = blockIdx.y; - int batch = blockIdx.z; - if (outputPointId >= gradOutput.getSize(2) * gradOutput.getSize(3)) { - return; - } - int outputPointX = outputPointId % gradOutput.getSize(3); - int outputPointY = outputPointId / gradOutput.getSize(3); - - int iStartX = max(0, -padL); - int iStartY = max(0, -padT); - int oStartX = max(0, padL); - int oStartY = max(0, padT); - - int inputPointX = abs(outputPointX - padL) - - abs(outputPointX - (gradInput.getSize(3) + padL - 1)) - - outputPointX - + 2 * padL + gradInput.getSize(3) - 1 - - oStartX + iStartX; - - int inputPointY = abs(outputPointY - padT) - - abs(outputPointY - (gradInput.getSize(2) + padT - 1)) - - outputPointY - + 2 * padT + gradInput.getSize(2) - 1 - - oStartY + iStartY; - - Dtype valueToCopy = gradOutput[batch][plane][outputPointY][outputPointX]; - atomicAdd(&gradInput[batch][plane][inputPointY][inputPointX], valueToCopy); -} - -#include -#include diff --git a/aten/src/THCUNN/generic/SpatialReflectionPadding.cu b/aten/src/THCUNN/generic/SpatialReflectionPadding.cu deleted file mode 100644 index a6d6663..0000000 --- a/aten/src/THCUNN/generic/SpatialReflectionPadding.cu +++ /dev/null @@ -1,137 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THCUNN/generic/SpatialReflectionPadding.cu" -#else - -void THNN_(SpatialReflectionPadding_updateOutput)(THCState *state, - THCTensor *input, - THCTensor *output, - int padL, int padR, - int padT, int padB) { - THArgCheck(THCTensor_canUse32BitIndexMath(state, 
input), 2, - "input tensor must fit into 32-bit index math"); - - int planeDim = 0; - int dimh = 1; - int dimw = 2; - int numBatch = 1; - - int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input); - THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input, - "non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s") - - if (numInputDims == 4) { - numBatch = THCTensor_(size)(state, input, 0); - planeDim++; - dimh++; - dimw++; - } - - int numPlanes = THCTensor_(size)(state, input, planeDim); - int inputH = THCTensor_(size)(state, input, dimh); - int inputW = THCTensor_(size)(state, input, dimw); - - THArgCheck(padL < inputW && padR < inputW, 4, - "Padding size should be less than the corresponding input dimension, " - "but got: padding (%d, %d) at dimension %d of input %s", - padL, padR, dimw, THCTensor_(sizeDesc)(state, input).str); - - THArgCheck(padT < inputH && padB < inputH, 6, - "Padding size should be less than the corresponding input dimension, " - "but got: padding (%d, %d) at dimension %d of input %s", - padT, padB, dimh, THCTensor_(sizeDesc)(state, input).str); - - int outputH = inputH + padT + padB; - int outputW = inputW + padL + padR; - - THArgCheck(outputW >= 1 || outputH >= 1, 2, - "input (H: %d, W: %d)is too small." - " Calculated output H: %d W: %d", - inputH, inputW, outputH, outputW); - - THCDeviceTensor devInput; - THCDeviceTensor devOutput; - - if (numInputDims == 3) { - THCTensor_(resize3d)(state, output, numPlanes, outputH, outputW); - - devInput = toDeviceTensor(state, input).upcastOuter<4>(); - devOutput = toDeviceTensor(state, output).upcastOuter<4>(); - } else { - THCTensor_(resize4d)(state, output, numBatch, numPlanes, outputH, outputW); - - devInput = toDeviceTensor(state, input); - devOutput = toDeviceTensor(state, output); - } - - int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3); - dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), - devOutput.getSize(1), - devOutput.getSize(0)); - dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize); - - SpatialReflectionPadding_updateOutput<<>>( - devInput, devOutput, padT, padB, padL, padR); - THCudaCheck(cudaGetLastError()); -} - -void THNN_(SpatialReflectionPadding_updateGradInput)( - THCState *state, - THCTensor *input, - THCTensor *gradOutput, - THCTensor *gradInput, - int padL, int padR, - int padT, int padB) { - - THArgCheck(THCTensor_canUse32BitIndexMath(state, input), 2, - "input tensor must fit into 32-bit index math"); - THArgCheck(THCTensor_canUse32BitIndexMath(state, gradOutput), 3, - "output gradient tensor must fit into 32-bit index math"); - - int planeDim = 0; - int dimh = 1; - int dimw = 2; - - int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input); - if (numInputDims == 4) { - planeDim++; - dimh++; - dimw++; - } - int iheight = input->size(dimh); - int iwidth = input->size(dimw); - int oheight = iheight + padT + padB; - int owidth = iwidth + padL + padR; - - THArgCheck(owidth == THCTensor_(size)(state, gradOutput, dimw), 3, - "gradOutput width unexpected. Expected: %d, Got: %d", - owidth, THCTensor_(size)(state, gradOutput, dimw)); - THArgCheck(oheight == THCTensor_(size)(state, gradOutput, dimh), 3, - "gradOutput height unexpected. 
Expected: %d, Got: %d", - oheight, THCTensor_(size)(state, gradOutput, dimh)); - - THCTensor_(resizeAs)(state, gradInput, input); - THCTensor_(zero)(state, gradInput); - - THCDeviceTensor devGradInput; - THCDeviceTensor devGradOutput; - - if (numInputDims == 3) { - devGradInput = toDeviceTensor(state, gradInput).upcastOuter<4>(); - devGradOutput = toDeviceTensor(state, gradOutput).upcastOuter<4>(); - } else { - devGradInput = toDeviceTensor(state, gradInput); - devGradOutput = toDeviceTensor(state, gradOutput); - } - - int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3); - dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), - devGradOutput.getSize(1), - devGradOutput.getSize(0)); - dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize); - - SpatialReflectionPadding_updateGradInput<<>>( - devGradInput, devGradOutput, padT, padB, padL, padR); - THCudaCheck(cudaGetLastError()); -} - -#endif diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h index fe3ef53..08fefca 100644 --- a/aten/src/THCUNN/generic/THCUNN.h +++ b/aten/src/THCUNN/generic/THCUNN.h @@ -847,21 +847,6 @@ THC_API void THNN_(SpatialMaxUnpooling_updateGradInput)( THCIndexTensor *indices, int owidth, int oheight); -THC_API void THNN_(SpatialReflectionPadding_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - int padL, int padR, - int padT, int padB); - -THC_API void THNN_(SpatialReflectionPadding_updateGradInput)( - THCState *state, - THCTensor *input, - THCTensor *gradOutput, - THCTensor *gradInput, - int padL, int padR, - int padT, int padB); - THC_API void THNN_(SpatialSubSampling_updateOutput)( THCState *state, THCTensor *input, diff --git a/aten/src/THNN/generic/SpatialReflectionPadding.c b/aten/src/THNN/generic/SpatialReflectionPadding.c deleted file mode 100644 index f7c240b..0000000 --- a/aten/src/THNN/generic/SpatialReflectionPadding.c +++ /dev/null @@ -1,270 +0,0 @@ -#ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "THNN/generic/SpatialReflectionPadding.c" -#else - -static void THNN_(SpatialReflectionPadding_updateOutput_frame)( - scalar_t *input_p, scalar_t *output_p, - int64_t nslices, - int64_t iwidth, int64_t iheight, - int64_t owidth, int64_t oheight, - int pad_l, int pad_r, - int pad_t, int pad_b) -{ - int iStartX = fmax(0, -pad_l); - int iStartY = fmax(0, -pad_t); - int oStartX = fmax(0, pad_l); - int oStartY = fmax(0, pad_t); - - int64_t k, ip_x, ip_y; -#pragma omp parallel for private(k, ip_x, ip_y) - - for (k = 0; k < nslices; k++) - { - int64_t i, j; - for (i = 0; i < oheight; i++) { - for (j = 0; j < owidth; j++) { - if (j < pad_l) { - ip_x = pad_l * 2 - j; - } else if (j >= pad_l && j < iwidth + pad_l) { - ip_x = j; - } else { - ip_x = (iwidth + pad_l - 1) * 2 - j; - } - ip_x = ip_x - oStartX + iStartX; - - if (i < pad_t) { - ip_y = pad_t * 2 - i; - } else if (i >= pad_t && i < iheight + pad_t) { - ip_y = i; - } else { - ip_y = (iheight + pad_t - 1) * 2 - i; - } - ip_y = ip_y - oStartY + iStartY; - - scalar_t *dest_p = output_p + k*owidth*oheight + i * owidth + j; - scalar_t *src_p = input_p + k*iwidth*iheight + ip_y * iwidth + ip_x; - *dest_p = *src_p; - } - } - } -} - -void THNN_(SpatialReflectionPadding_updateOutput)(THNNState *state, - THTensor *input, - THTensor *output, - int pad_l, int pad_r, - int pad_t, int pad_b) -{ - int dimw = 2; - int dimh = 1; - int dimslices = 0; - int64_t nbatch = 1; - int64_t nslices; - int64_t iheight; - int64_t iwidth; - int64_t oheight; - int64_t owidth; - scalar_t *input_data; - scalar_t 
*output_data; - - THNN_ARGCHECK(!input->is_empty() && (input->dim() == 3 || input->dim() == 4), 2, input, - "non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s"); - - if (input->dim() == 4) - { - nbatch = input->size(0); - dimw++; - dimh++; - dimslices++; - } - - /* input sizes */ - nslices = input->size(dimslices); - iheight = input->size(dimh); - iwidth = input->size(dimw); - - AT_CHECK(pad_l < iwidth && pad_r < iwidth, - "Argument #4: Padding size should be less than the corresponding input dimension, " - "but got: padding (", pad_l, ", ", pad_r, ") at dimension ", dimw, " of input ", input->sizes()); - - AT_CHECK(pad_t < iheight && pad_b < iheight, - "Argument #6: Padding size should be less than the corresponding input dimension, " - "but got: padding (", pad_t, ", ", pad_b, ") at dimension ", dimh, " of input ", input->sizes()); - - /* output sizes */ - oheight = iheight + pad_t + pad_b; - owidth = iwidth + pad_l + pad_r; - - THArgCheck(owidth >= 1 || oheight >= 1 , 2, - "input (H: %d, W: %d)is too small." - " Calculated output H: %d W: %d", - iheight, iwidth, oheight, owidth); - - /* get contiguous input */ - input = THTensor_(newContiguous)(input); - - /* resize output */ - if (input->dim() == 3) - { - THTensor_(resize3d)(output, nslices, oheight, owidth); - - input_data = input->data(); - output_data = output->data(); - - THNN_(SpatialReflectionPadding_updateOutput_frame)(input_data, output_data, - nslices, - iwidth, iheight, - owidth, oheight, - pad_l, pad_r, - pad_t, pad_b); - } - else - { - int64_t p; - - THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth); - - input_data = input->data(); - output_data = output->data(); - -#pragma omp parallel for private(p) - for (p = 0; p < nbatch; p++) - { - THNN_(SpatialReflectionPadding_updateOutput_frame)( - input_data+p*nslices*iwidth*iheight, - output_data+p*nslices*owidth*oheight, - nslices, - iwidth, iheight, - owidth, oheight, - pad_l, pad_r, - pad_t, pad_b); - } - } - - /* cleanup */ - c10::raw::intrusive_ptr::decref(input); -} - -static void THNN_(SpatialReflectionPadding_updateGradInput_frame)( - scalar_t *ginput_p, scalar_t *goutput_p, - int64_t nslices, - int64_t iwidth, int64_t iheight, - int64_t owidth, int64_t oheight, - int pad_l, int pad_r, - int pad_t, int pad_b) -{ - int iStartX = fmax(0, -pad_l); - int iStartY = fmax(0, -pad_t); - int oStartX = fmax(0, pad_l); - int oStartY = fmax(0, pad_t); - - int64_t k, ip_x, ip_y; -#pragma omp parallel for private(k, ip_x, ip_y) - - for (k = 0; k < nslices; k++) - { - int64_t i, j; - for (i = 0; i < oheight; i++) { - for (j = 0; j < owidth; j++) { - if (j < pad_l) { - ip_x = pad_l * 2 - j; - } else if (j >= pad_l && j < iwidth + pad_l) { - ip_x = j; - } else { - ip_x = (iwidth + pad_l - 1) * 2 - j; - } - ip_x = ip_x - oStartX + iStartX; - - if (i < pad_t) { - ip_y = pad_t * 2 - i; - } else if (i >= pad_t && i < iheight + pad_t) { - ip_y = i; - } else { - ip_y = (iheight + pad_t - 1) * 2 - i; - } - ip_y = ip_y - oStartY + iStartY; - - scalar_t *src_p = goutput_p + k*owidth*oheight + i * owidth + j; - scalar_t *dest_p = ginput_p + k*iwidth*iheight + ip_y * iwidth + ip_x; - *dest_p += *src_p; - } - } - } -} - -void THNN_(SpatialReflectionPadding_updateGradInput)(THNNState *state, - THTensor *input, - THTensor *gradOutput, - THTensor *gradInput, - int pad_l, int pad_r, - int pad_t, int pad_b) -{ - int dimw = 2; - int dimh = 1; - int dimslices = 0; - int64_t nbatch = 1; - int64_t nslices; - int64_t iheight; - int64_t iwidth; - int64_t oheight; - int64_t 
owidth; - - if (input->dim() == 4) - { - nbatch = input->size(0); - dimw++; - dimh++; - dimslices++; - } - - /* sizes */ - nslices = input->size(dimslices); - iheight = input->size(dimh); - iwidth = input->size(dimw); - oheight = iheight + pad_t + pad_b; - owidth = iwidth + pad_l + pad_r; - - THArgCheck(owidth == THTensor_(size)(gradOutput, dimw), 3, - "gradOutput width unexpected. Expected: %d, Got: %d", - owidth, THTensor_(size)(gradOutput, dimw)); - THArgCheck(oheight == THTensor_(size)(gradOutput, dimh), 3, - "gradOutput height unexpected. Expected: %d, Got: %d", - oheight, THTensor_(size)(gradOutput, dimh)); - - /* get contiguous gradOutput */ - gradOutput = THTensor_(newContiguous)(gradOutput); - - /* resize */ - THTensor_(resizeAs)(gradInput, input); - THTensor_(zero)(gradInput); - - /* backprop */ - if (input->dim() == 3) { - THNN_(SpatialReflectionPadding_updateGradInput_frame)( - gradInput->data(), - gradOutput->data(), - nslices, - iwidth, iheight, - owidth, oheight, - pad_l, pad_r, - pad_t, pad_b); - } else { - int64_t p; -#pragma omp parallel for private(p) - for (p = 0; p < nbatch; p++) { - THNN_(SpatialReflectionPadding_updateGradInput_frame)( - gradInput->data() + p * nslices * iheight * iwidth, - gradOutput->data() + p * nslices * oheight * owidth, - nslices, - iwidth, iheight, - owidth, oheight, - pad_l, pad_r, - pad_t, pad_b); - } - } - - /* cleanup */ - c10::raw::intrusive_ptr::decref(gradOutput); -} - -#endif diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h index 355c819..f98077c 100644 --- a/aten/src/THNN/generic/THNN.h +++ b/aten/src/THNN/generic/THNN.h @@ -923,21 +923,6 @@ TH_API void THNN_(VolumetricAdaptiveMaxPooling_updateGradInput)( THTensor *gradInput, THIndexTensor *indices); -TH_API void THNN_(SpatialReflectionPadding_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, - int pad_left, int pad_right, - int pad_top, int pad_bottom); - -TH_API void THNN_(SpatialReflectionPadding_updateGradInput)( - THNNState *state, - THTensor *input, - THTensor *gradOutput, - THTensor *gradInput, - int pad_left, int pad_right, - int pad_top, int pad_bottom); - TH_API void THNN_(FeatureLPPooling_updateOutput)( THNNState *state, THTensor *input, diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp index 845374e..9120420 100644 --- a/aten/src/THNN/init.cpp +++ b/aten/src/THNN/init.cpp @@ -202,9 +202,6 @@ #include #include -#include -#include - #include #include diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py index f18f60e..2b12ffc 100644 --- a/torch/nn/_functions/thnn/auto.py +++ b/torch/nn/_functions/thnn/auto.py @@ -306,7 +306,6 @@ def _generate_function_classes(scope_dict): 'TemporalConvolution': 'Conv1d', 'SpatialDilatedConvolution': 'DilatedConv2d', 'SpatialMaxUnpooling': 'MaxUnpool2d', - 'SpatialReflectionPadding': 'ReflectionPad2d', 'VolumetricMaxUnpooling': 'MaxUnpool3d', 'HardTanh': 'Hardtanh', 'HardShrink': 'Hardshrink', -- 2.7.4
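
For reference, a standalone sketch (not part of the patch) that cross-checks
the closed-form CUDA index mapping in get_index_mapping2d against the branched
logic in reflection_pad2d_out_frame, one axis at a time, over every padding the
new AT_CHECKs allow:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Closed-form variant, mirroring get_index_mapping2d (one axis shown).
int64_t reflect_closed_form(int64_t o, int64_t pad, int64_t dim) {
  int64_t i_start = std::max<int64_t>(0, -pad);
  int64_t o_start = std::max<int64_t>(0, pad);
  return std::abs(o - pad) - std::abs(o - (dim + pad - 1)) - o
      + 2 * pad + dim - 1 - o_start + i_start;
}

// Branched variant, mirroring reflection_pad2d_out_frame.
int64_t reflect_branched(int64_t o, int64_t pad, int64_t dim) {
  int64_t i_start = std::max<int64_t>(0, -pad);
  int64_t o_start = std::max<int64_t>(0, pad);
  int64_t ip;
  if (o < pad) {
    ip = pad * 2 - o;
  } else if (o < dim + pad) {
    ip = o;
  } else {
    ip = (dim + pad - 1) * 2 - o;
  }
  return ip - o_start + i_start;
}

int main() {
  // Both templates require pad < dim, so only那 range is exercised.
  for (int64_t dim = 1; dim <= 32; ++dim) {
    for (int64_t pad = 0; pad < dim; ++pad) {
      for (int64_t o = 0; o < dim + 2 * pad; ++o) {
        assert(reflect_closed_form(o, pad, dim) ==
               reflect_branched(o, pad, dim));
      }
    }
  }
  return 0;
}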
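
And a minimal end-to-end check of the newly dispatched operator, assuming a
libtorch/ATen build that contains this patch (at::reflection_pad2d is the
binding generated from the native_functions.yaml entries above):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Non-batch 3D input: (planes, H, W) = (1, 3, 3), values 0..8.
  auto x = at::arange(9, at::kFloat).reshape({1, 3, 3});
  // padding = {left, right, top, bottom}; each entry must stay below the
  // matching input size, per the AT_CHECKs in reflection_pad2d_out_template.
  auto y = at::reflection_pad2d(x, {1, 1, 1, 1});  // shape (1, 5, 5)
  // The top-left output corner reflects across both edges onto input (1, 1):
  std::cout << y.sizes() << " corner=" << y[0][0][0].item<float>() << "\n";
  // expected: [1, 5, 5] corner=4
  return 0;
}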