From b7bc49ad703faf3ea92aaef804351126e63158da Mon Sep 17 00:00:00 2001
From: Lin Huang
Date: Mon, 24 Dec 2018 06:29:34 -0800
Subject: [PATCH] Port replication_pad1d to ATen (#15507)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15507

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15485

port replication_pad1d

Reviewed By: ezyang

Differential Revision: D13531920

fbshipit-source-id: dcd64ebd2c24b7431996231b8d5addfb600b1072
---
 aten/src/ATen/native/LegacyNNDefinitions.cpp       |  16 --
 aten/src/ATen/native/ReplicationPadding.cpp        | 272 +++++++++++++++++++++
 aten/src/ATen/native/cuda/ReplicationPadding.cu    | 254 +++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml         |  12 +
 aten/src/ATen/nn.yaml                              |   6 -
 aten/src/THCUNN/CMakeLists.txt                     |   1 -
 aten/src/THCUNN/TemporalReplicationPadding.cu      |  62 -----
 aten/src/THCUNN/generic/THCUNN.h                   |  13 -
 .../THCUNN/generic/TemporalReplicationPadding.cu   | 114 ---------
 aten/src/THNN/generic/THNN.h                       |  13 -
 aten/src/THNN/generic/TemporalReplicationPadding.c | 211 ----------------
 aten/src/THNN/init.cpp                             |   3 -
 torch/nn/_functions/thnn/auto.py                   |   1 -
 13 files changed, 538 insertions(+), 440 deletions(-)
 create mode 100644 aten/src/ATen/native/ReplicationPadding.cpp
 create mode 100644 aten/src/ATen/native/cuda/ReplicationPadding.cu
 delete mode 100644 aten/src/THCUNN/TemporalReplicationPadding.cu
 delete mode 100644 aten/src/THCUNN/generic/TemporalReplicationPadding.cu
 delete mode 100644 aten/src/THNN/generic/TemporalReplicationPadding.c
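Note (not part of the commit; git am ignores text between the diffstat and the
first diff): the port preserves the THNN semantics of replication_pad1d, i.e.
edge values are replicated into the left/right padding, and the backward pass
accumulates each padded position's gradient back into the single edge element
it was copied from. A minimal sanity check from Python, assuming the usual
functional entry point (F.pad with mode='replicate' on a 3-D input dispatches
to replication_pad1d):

    import torch
    import torch.nn.functional as F

    x = torch.arange(4.).reshape(1, 1, 4)   # [[[0., 1., 2., 3.]]]
    # pad_l=2, pad_r=1: left edge (0.) and right edge (3.) are replicated
    y = F.pad(x, (2, 1), mode='replicate')
    print(y)  # tensor([[[0., 0., 0., 1., 2., 3., 3.]]])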
diff --git a/aten/src/ATen/native/LegacyNNDefinitions.cpp b/aten/src/ATen/native/LegacyNNDefinitions.cpp
index e456663..cbb29d6 100644
--- a/aten/src/ATen/native/LegacyNNDefinitions.cpp
+++ b/aten/src/ATen/native/LegacyNNDefinitions.cpp
@@ -524,22 +524,6 @@ Tensor reflection_pad2d_backward(const Tensor & grad_output, const Tensor & self
   return at::legacy::th::_thnn_reflection_pad2d_backward(grad_output, self, padding);
 }
 
-Tensor & replication_pad1d_out(Tensor & output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad1d_forward_out(output, self, padding);
-}
-
-Tensor replication_pad1d(const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad1d_forward(self, padding);
-}
-
-Tensor & replication_pad1d_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad1d_backward_out(grad_input, grad_output, self, padding);
-}
-
-Tensor replication_pad1d_backward(const Tensor & grad_output, const Tensor & self, IntList padding) {
-  return at::legacy::th::_thnn_replication_pad1d_backward(grad_output, self, padding);
-}
-
 Tensor & replication_pad2d_out(Tensor & output, const Tensor & self, IntList padding) {
   return at::legacy::th::_thnn_replication_pad2d_forward_out(output, self, padding);
 }
diff --git a/aten/src/ATen/native/ReplicationPadding.cpp b/aten/src/ATen/native/ReplicationPadding.cpp
new file mode 100644
index 0000000..5b84238
--- /dev/null
+++ b/aten/src/ATen/native/ReplicationPadding.cpp
@@ -0,0 +1,272 @@
+#include "ATen/ATen.h"
+#include "ATen/NativeFunctions.h"
+
+namespace at {
+namespace native {
+
+namespace {
+template <typename scalar_t>
+static void replication_pad1d_out_frame(
+  scalar_t *input_p, scalar_t *output_p,
+  long nslices,
+  long iwidth,
+  long owidth,
+  int pad_l, int pad_r)
+{
+  int iStartX = fmax(0, -pad_l);
+  int oStartX = fmax(0, pad_l);
+
+  long k, ip_x;
+#pragma omp parallel for private(k, ip_x)
+  for (k = 0; k < nslices; k++)
+  {
+    long j;
+    for (j = 0; j < owidth; j++) {
+      if (j < pad_l) {
+        ip_x = pad_l;
+      } else if (j >= pad_l && j < iwidth + pad_l) {
+        ip_x = j;
+      } else {
+        ip_x = iwidth + pad_l - 1;
+      }
+      ip_x = ip_x - oStartX + iStartX;
+
+      scalar_t *dest_p = output_p + k*owidth + j;
+      scalar_t *src_p = input_p + k*iwidth + ip_x;
+      *dest_p = *src_p;
+    }
+  }
+}
+
+void replication_pad1d_out_cpu_template(
+  at::Tensor& output,
+  at::Tensor const& input,
+  IntList paddingSize)
+{
+  int dimw = 1;
+  int dimslices = 0;
+  long nbatch = 1;
+  long nslices;
+  long iwidth;
+  long owidth;
+  int pad_l = paddingSize[0];
+  int pad_r = paddingSize[1];
+
+  AT_CHECK(input.numel() > 0
+    && (input.ndimension() == 2 || input.ndimension() == 3),
+    "non-empty 2D or 3D (batch mode) tensor expected for input");
+
+  if (input.ndimension() == 3)
+  {
+    nbatch = input.size(0);
+    dimw++;
+    dimslices++;
+  }
+
+  /* sizes */
+  nslices = input.size(dimslices);
+  iwidth = input.size(dimw);
+  owidth = iwidth + pad_l + pad_r;
+
+  AT_CHECK(owidth >= 1,
+    "input (W: ", iwidth, ") is too small."
+    " Calculated output W: ", owidth);
+
+
+  /* get contiguous input */
+  auto input_ = input.contiguous();
+
+  /* resize output */
+  if (input_.ndimension() == 2)
+  {
+    output.resize_({nslices, owidth});
+    AT_DISPATCH_FLOATING_TYPES(input_.type(), "replication_pad1d", [&] {
+      auto input_data = input_.data<scalar_t>();
+      auto output_data = output.data<scalar_t>();
+      replication_pad1d_out_frame<scalar_t> (input_data, output_data,
+        nslices,
+        iwidth,
+        owidth,
+        pad_l, pad_r);
+      }
+    );
+  }
+  else
+  {
+    long p;
+
+    output.resize_({nbatch, nslices, owidth});
+
+#pragma omp parallel for private(p)
+    for (p = 0; p < nbatch; p++)
+    {
+      AT_DISPATCH_FLOATING_TYPES(input_.type(), "replication_pad1d", [&] {
+        auto input_data = input_.data<scalar_t>();
+        auto output_data = output.data<scalar_t>();
+        replication_pad1d_out_frame<scalar_t>(
+          input_data+p*nslices*iwidth,
+          output_data+p*nslices*owidth,
+          nslices,
+          iwidth,
+          owidth,
+          pad_l, pad_r);
+        }
+      );
+    }
+  }
+}
+
+template <typename scalar_t>
+static void replication_pad1d_backward_out_frame(
+  scalar_t *ginput_p, scalar_t *goutput_p,
+  long nslices,
+  long iwidth,
+  long owidth,
+  int pad_l, int pad_r)
+{
+  int iStartX = fmax(0, -pad_l);
+  int oStartX = fmax(0, pad_l);
+
+  long k, ip_x;
+#pragma omp parallel for private(k, ip_x)
+  for (k = 0; k < nslices; k++)
+  {
+    long j;
+    for (j = 0; j < owidth; j++) {
+      if (j < pad_l) {
+        ip_x = pad_l;
+      } else if (j >= pad_l && j < iwidth + pad_l) {
+        ip_x = j;
+      } else {
+        ip_x = iwidth + pad_l - 1;
+      }
+      ip_x = ip_x - oStartX + iStartX;
+
+      scalar_t *src_p = goutput_p + k*owidth + j;
+      scalar_t *dest_p = ginput_p + k*iwidth + ip_x;
+      *dest_p += *src_p;
+    }
+  }
+}
+
+Tensor& replication_pad1d_backward_out_cpu_template(
+  Tensor& gradInput,
+  const Tensor& gradOutput_,
+  const Tensor& input,
+  IntList paddingSize)
+{
+  int dimw = 1;
+  int dimslices = 0;
+  long nbatch = 1;
+  long nslices;
+  long iwidth;
+  long owidth;
+  int pad_l = paddingSize[0];
+  int pad_r = paddingSize[1];
+
+  if (input.ndimension() == 3)
+  {
+    nbatch = input.size(0);
+    dimw++;
+    dimslices++;
+  }
+
+  /* sizes */
+  nslices = input.size(dimslices);
+  iwidth = input.size(dimw);
+  owidth = iwidth + pad_l + pad_r;
+
+  AT_CHECK(owidth == gradOutput_.size(dimw),
+    "gradOutput width unexpected. Expected: ", owidth,
+    " Got: ", gradOutput_.size(dimw));
Expected: ", owidth, + " Got: ", gradOutput_.size(dimw)); + + /* get contiguous gradOutput */ + auto gradOutput = gradOutput_.contiguous(); + gradInput.resize_as_(input); + gradInput.zero_(); + + /* backprop */ + if (input.ndimension() == 2) { + AT_DISPATCH_FLOATING_TYPES( + input.type(), "replication_pad1d_backward", [&] { + scalar_t *gradInput_data = gradInput.data(); + scalar_t *gradOutput_data = gradOutput.data(); + + replication_pad1d_backward_out_frame ( + gradInput_data, + gradOutput_data, + nslices, + iwidth, + owidth, + pad_l, pad_r); + } + ); + } else { + long p; +#pragma omp parallel for private(p) + for (p = 0; p < nbatch; p++) { + AT_DISPATCH_FLOATING_TYPES( + input.type(), "replication_pad1d_backward", [&] { + scalar_t *gradInput_data = gradInput.data(); + scalar_t *gradOutput_data = gradOutput.data(); + + replication_pad1d_backward_out_frame( + gradInput_data + p * nslices * iwidth, + gradOutput_data + p * nslices * owidth, + nslices, + iwidth, + owidth, + pad_l, pad_r); + } + ); + } + } + return gradInput; +} +} // namespace + +Tensor& replication_pad1d_out_cpu( + Tensor& output, + const Tensor& input, + IntList paddingSize) +{ + replication_pad1d_out_cpu_template( + output, input, paddingSize); + return output; +} + +Tensor replication_pad1d_cpu( + at::Tensor const& input, + IntList paddingSize) +{ + auto output = at::empty({0}, input.options()); + replication_pad1d_out_cpu_template( + output, input, paddingSize); + return output; +} + +Tensor& replication_pad1d_backward_out_cpu( + Tensor& gradInput, + const Tensor& gradOutput, + const Tensor& input, + IntList paddingSize) +{ + gradInput.resize_as_(input); + replication_pad1d_backward_out_cpu_template( + gradInput, gradOutput, input, paddingSize); + return gradInput; +} + +Tensor replication_pad1d_backward_cpu( + const Tensor& gradOutput, + const Tensor& input, + IntList paddingSize) +{ + auto gradInput = at::zeros_like(input); + replication_pad1d_backward_out_cpu_template( + gradInput, gradOutput, input, paddingSize); + return gradInput; +} + +} // at::native +} // at diff --git a/aten/src/ATen/native/cuda/ReplicationPadding.cu b/aten/src/ATen/native/cuda/ReplicationPadding.cu new file mode 100644 index 0000000..23dc9ce --- /dev/null +++ b/aten/src/ATen/native/cuda/ReplicationPadding.cu @@ -0,0 +1,254 @@ +#include "ATen/ATen.h" +#include "ATen/cuda/CUDAApplyUtils.cuh" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/NativeFunctions.h" +#include "ATen/TensorUtils.h" +#include "ATen/Utils.h" +#include "c10/util/Exception.h" +#include +#include "THC/THCNumerics.cuh" +#include "THC/THCDeviceUtils.cuh" + +#include +#include +#include + + +namespace at { +namespace native { +__host__ __device__ __forceinline__ int imin(int a, int b) { + return a > b ? b : a; +} + +__host__ __device__ __forceinline__ int imax(int a, int b) { + return a > b ? a : b; +} + +__host__ __device__ __forceinline__ int iabs(int a) { + return a >= 0 ? 
+
+namespace {
+template <typename scalar_t>
+__global__ void replication_pad_forward_kernel(
+  PackedTensorAccessor<scalar_t, 3> input,
+  PackedTensorAccessor<scalar_t, 3> output,
+  int padL, int padR) {
+
+  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
+  int plane = blockIdx.y;
+  int batch = blockIdx.z;
+  if (outputPointId >= output.size(2)) {
+    return;
+  }
+  int outputPointX = outputPointId % output.size(2);
+
+  int iStartX = imax(0, -padL);
+  int oStartX = imax(0, padL);
+
+  int inputPointX = imin(imax(padL, outputPointX), input.size(2) + padL - 1) - oStartX + iStartX;
+
+  scalar_t valueToCopy = input[batch][plane][inputPointX];
+  output[batch][plane][outputPointX] = valueToCopy;
+}
+
+template <typename scalar_t>
+__global__ void replication_pad_backward_kernel(
+  PackedTensorAccessor<scalar_t, 3> gradInput,
+  PackedTensorAccessor<scalar_t, 3> gradOutput,
+  int padL, int padR) {
+
+  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
+  int plane = blockIdx.y;
+  int batch = blockIdx.z;
+  if (outputPointId >= gradOutput.size(2)) {
+    return;
+  }
+  int outputPointX = outputPointId % gradOutput.size(2);
+
+  int iStartX = imax(0, -padL);
+  int oStartX = imax(0, padL);
+
+  int inputPointX = imin(imax(padL, outputPointX), gradInput.size(2) + padL - 1) - oStartX + iStartX;
+
+  scalar_t valueToCopy = gradOutput[batch][plane][outputPointX];
+  atomicAdd(&gradInput[batch][plane][inputPointX], valueToCopy);
+}
+
+void replication_pad1d_out_cuda_template(
+  Tensor& output,
+  const Tensor& input,
+  IntList paddingSize)
+{
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
+    "input tensor must fit into 32-bit index math");
+
+  int padL = paddingSize[0];
+  int padR = paddingSize[1];
+  int planeDim = 0;
+  int dimw = 1;
+  int numBatch = 1;
+
+  int numInputDims = input.ndimension();
+  AT_CHECK(input.numel() > 0 && (numInputDims == 2 || numInputDims == 3),
+    "2D or 3D (batch mode) tensor expected for input")
+
+  if (numInputDims == 3) {
+    numBatch = input.size(0);
+    planeDim++;
+    dimw++;
+  }
+
+  int numPlanes = input.size(planeDim);
+  int inputW = input.size(dimw);
+  int outputW = inputW + padL + padR;
+
+  AT_CHECK(outputW >= 1,
+    "input (W: ", inputW, ") is too small."
+    " Calculated output W: ", outputW);
+
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    input.type(), "replication_pad1d", [&] {
+
+
+      if (numInputDims == 2) {
+        output.resize_({numPlanes, outputW});
+        auto input_ = input.reshape({1, input.size(0), input.size(1)});
+        auto output_ = output.reshape({1, output.size(0), output.size(1)});
+        auto devInput = input_.packed_accessor<scalar_t, 3>();
+        auto devOutput = output_.packed_accessor<scalar_t, 3>();
+
+        int outputPlaneSize = devOutput.size(2);
+        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+          devOutput.size(1),
+          devOutput.size(0));
+        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+        replication_pad_forward_kernel <<<gridSize, blockSize, 0,
+          at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, padR);
+      } else {
+        output.resize_({numBatch, numPlanes, outputW});
+        auto devInput = input.packed_accessor<scalar_t, 3>();
+        auto devOutput = output.packed_accessor<scalar_t, 3>();
+
+        int outputPlaneSize = devOutput.size(2);
+        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+          devOutput.size(1),
+          devOutput.size(0));
+        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+        replication_pad_forward_kernel <<<gridSize, blockSize, 0,
+          at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, padR);
+      }
+    }
+  );
+  THCudaCheck(cudaGetLastError());
+}
+
+void replication_pad1d_backward_out_cuda_template(
+  Tensor& gradInput,
+  const Tensor& gradOutput,
+  const Tensor& input,
+  IntList paddingSize)
+{
+
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
+    "input tensor must fit into 32-bit index math");
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
+    "output gradient tensor must fit into 32-bit index math");
+
+  int padL = paddingSize[0];
+  int padR = paddingSize[1];
+  int planeDim = 0;
+  int dimw = 1;
+
+  int numInputDims = input.ndimension();
+  if (numInputDims == 3) {
+    planeDim++;
+    dimw++;
+  }
+  int iwidth = input.size(dimw);
+  int owidth = iwidth + padL + padR;
+
+  AT_CHECK(owidth == gradOutput.size(dimw),
+    "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
+    gradOutput.size(dimw));
+
+  gradInput.resize_as_(input);
+  gradInput.zero_();
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    input.type(), "replication_pad1d_backward", [&] {
+
+      auto gradInput_ = gradInput;
+      auto gradOutput_ = gradOutput;
+      if (numInputDims == 2) {
+        gradInput_ = gradInput.reshape({1, gradInput.size(0),
+          gradInput.size(1)});
+        gradOutput_ = gradOutput.reshape({1, gradOutput.size(0),
+          gradOutput.size(1)});
+      }
+      auto devGradInput = gradInput_.packed_accessor<scalar_t, 3>();
+      auto devGradOutput = gradOutput_.packed_accessor<scalar_t, 3>();
+
+      int outputPlaneSize = devGradOutput.size(2);
+      dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+        devGradOutput.size(1),
+        devGradOutput.size(0));
+      dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+      replication_pad_backward_kernel <<<gridSize, blockSize, 0,
+        at::cuda::getCurrentCUDAStream()>>>(devGradInput, devGradOutput,
+        padL, padR);
+    }
+  );
+  THCudaCheck(cudaGetLastError());
+}
+} // namespace
+
+Tensor& replication_pad1d_out_cuda(
+  Tensor& output,
+  const Tensor& input,
+  IntList paddingSize)
+{
+  replication_pad1d_out_cuda_template(
+    output, input, paddingSize);
+  return output;
+}
+
+Tensor replication_pad1d_cuda(
+  at::Tensor const& input,
+  IntList paddingSize)
+{
+  auto output = at::empty({0}, input.options());
+  replication_pad1d_out_cuda_template(
+    output, input, paddingSize);
+  return output;
+}
+
+Tensor& replication_pad1d_backward_out_cuda(
+  Tensor& gradInput,
+  const Tensor& gradOutput,
+  const Tensor& input,
+  IntList paddingSize)
+{
+  gradInput.resize_as_(input);
+  replication_pad1d_backward_out_cuda_template(
+    gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
+
+Tensor replication_pad1d_backward_cuda(
+  const Tensor& gradOutput,
+  const Tensor& input,
+  IntList paddingSize)
+{
+  auto gradInput = at::zeros_like(input);
+  replication_pad1d_backward_out_cuda_template(
+    gradInput, gradOutput, input, paddingSize);
+  return gradInput;
+}
+
+} // at::native
+} // at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 400497a..0710bab 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3434,15 +3434,27 @@
 
 - func: replication_pad1d_out(Tensor output, Tensor self, IntList[2] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad1d_out_cpu
+    CUDA: replication_pad1d_out_cuda
 
 - func: replication_pad1d(Tensor self, IntList[2] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad1d_cpu
+    CUDA: replication_pad1d_cuda
 
 - func: replication_pad1d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[2] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad1d_backward_out_cpu
+    CUDA: replication_pad1d_backward_out_cuda
 
 - func: replication_pad1d_backward(Tensor grad_output, Tensor self, IntList[2] padding) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: replication_pad1d_backward_cpu
+    CUDA: replication_pad1d_backward_cuda
 
 - func: replication_pad2d_out(Tensor output, Tensor self, IntList[4] padding) -> Tensor
   python_module: nn
diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index 72f9505..190e371 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -202,12 +202,6 @@
     output: 'false'
     grad_input: 'false'
 
-- name: _thnn_replication_pad1d(Tensor self, IntList[2] padding)
-  cname: TemporalReplicationPadding
-  scalar_check:
-    output: 'false'
-    grad_input: 'false'
-
 - name: _thnn_replication_pad2d(Tensor self, IntList[4] padding)
   cname: SpatialReplicationPadding
   scalar_check:
diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt
index 237a3ff..00c66e3 100644
--- a/aten/src/THCUNN/CMakeLists.txt
+++ b/aten/src/THCUNN/CMakeLists.txt
@@ -53,7 +53,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/Tanh.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/TemporalConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/TemporalMaxPooling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/TemporalReflectionPadding.cu
-${CMAKE_CURRENT_SOURCE_DIR}/TemporalReplicationPadding.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/TemporalRowConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/TemporalUpSamplingLinear.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/TemporalUpSamplingNearest.cu
diff --git a/aten/src/THCUNN/TemporalReplicationPadding.cu b/aten/src/THCUNN/TemporalReplicationPadding.cu
deleted file mode 100644
index 28d7c17..0000000
--- a/aten/src/THCUNN/TemporalReplicationPadding.cu
+++ /dev/null
@@ -1,62 +0,0 @@
-#include <THCUNN/THCUNN.h>
-#include <THC/THCTensor.hpp>
-#include <THCUNN/common.h>
-#include <THC/THCDeviceTensor.cuh>
-#include <THC/THCDeviceTensorUtils.cuh>
-#include <THC/THCDeviceUtils.cuh>
-#include <THC/THCReduceApplyUtils.cuh>
-#include <THC/THCAtomics.cuh>
-
-#include <thrust/functional.h>
-#include <thrust/pair.h>
-#include <thrust/transform.h>
-
-template <typename Dtype>
-__global__ void TemporalReplicationPadding_updateOutput(
-  THCDeviceTensor<Dtype, 3> input,
-  THCDeviceTensor<Dtype, 3> output,
-  int padL, int padR) {
-
-  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-  if (outputPointId >= output.getSize(2)) {
-    return;
-  }
-  int outputPointX = outputPointId % output.getSize(2);
-
-  int iStartX = max(0, -padL);
-  int oStartX = max(0, padL);
-
-  int inputPointX = min(max(padL, outputPointX), input.getSize(2) + padL - 1) - oStartX + iStartX;
-
-  Dtype valueToCopy = input[batch][plane][inputPointX];
-  output[batch][plane][outputPointX] = valueToCopy;
-}
-
-template <typename Dtype>
-__global__ void TemporalReplicationPadding_updateGradInput(
-  THCDeviceTensor<Dtype, 3> gradInput,
-  THCDeviceTensor<Dtype, 3> gradOutput,
-  int padL, int padR) {
-
-  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-  if (outputPointId >= gradOutput.getSize(2)) {
-    return;
-  }
-  int outputPointX = outputPointId % gradOutput.getSize(2);
-
-  int iStartX = max(0, -padL);
-  int oStartX = max(0, padL);
-
-  int inputPointX = min(max(padL, outputPointX), gradInput.getSize(2) + padL - 1) - oStartX + iStartX;
-
-  Dtype valueToCopy = gradOutput[batch][plane][outputPointX];
-  atomicAdd(&gradInput[batch][plane][inputPointX], valueToCopy);
-}
-
-
-#include <THCUNN/generic/TemporalReplicationPadding.cu>
-#include <THC/THCGenerateFloatTypes.h>
diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h
index b91d941..af3388b 100644
--- a/aten/src/THCUNN/generic/THCUNN.h
+++ b/aten/src/THCUNN/generic/THCUNN.h
@@ -1170,19 +1170,6 @@ THC_API void THNN_(TemporalReflectionPadding_updateGradInput)(
                   THCTensor *gradInput,
                   int padL, int padR);
 
-THC_API void THNN_(TemporalReplicationPadding_updateOutput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *output,
-                  int padL, int padR);
-
-THC_API void THNN_(TemporalReplicationPadding_updateGradInput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *gradOutput,
-                  THCTensor *gradInput,
-                  int padL, int padR);
-
 THC_API void THNN_(TemporalUpSamplingLinear_updateOutput)(
                   THCState *state,
                   THCTensor *input,
diff --git a/aten/src/THCUNN/generic/TemporalReplicationPadding.cu b/aten/src/THCUNN/generic/TemporalReplicationPadding.cu
deleted file mode 100644
index b9b0a99..0000000
--- a/aten/src/THCUNN/generic/TemporalReplicationPadding.cu
+++ /dev/null
@@ -1,114 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/TemporalReplicationPadding.cu"
-#else
-
-void THNN_(TemporalReplicationPadding_updateOutput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *output,
-           int padL, int padR) {
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, input), 2,
-             "input tensor must fit into 32-bit index math");
-
-  int planeDim = 0;
-  int dimw = 1;
-  int numBatch = 1;
-
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-  THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 2 || numInputDims == 3), 2, input,
-                  "2D or 3D (batch mode) tensor expected for input, but got: %s")
-
-  if (numInputDims == 3) {
-    numBatch = THCTensor_(size)(state, input, 0);
-    planeDim++;
-    dimw++;
-  }
-
-  int numPlanes = THCTensor_(size)(state, input, planeDim);
-  int inputW = THCTensor_(size)(state, input, dimw);
-  int outputW = inputW + padL + padR;
-
-  THArgCheck(outputW >= 1, 2,
-             "input (W: %d)is too small."
-             " Calculated output W: %d",
-             inputW, outputW);
-
-  THCDeviceTensor<scalar_t, 3> devInput;
-  THCDeviceTensor<scalar_t, 3> devOutput;
-
-  if (numInputDims == 2) {
-    THCTensor_(resize2d)(state, output, numPlanes, outputW);
-
-    devInput = toDeviceTensor<scalar_t, 2>(state, input).upcastOuter<3>();
-    devOutput = toDeviceTensor<scalar_t, 2>(state, output).upcastOuter<3>();
-  } else {
-    THCTensor_(resize3d)(state, output, numBatch, numPlanes, outputW);
-
-    devInput = toDeviceTensor<scalar_t, 3>(state, input);
-    devOutput = toDeviceTensor<scalar_t, 3>(state, output);
-  }
-
-  int outputPlaneSize = devOutput.getSize(2);
-  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-                devOutput.getSize(1),
-                devOutput.getSize(0));
-  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-  TemporalReplicationPadding_updateOutput<<<gridSize, blockSize, 0,
-    THCState_getCurrentStream(state)>>>(devInput, devOutput, padL, padR);
-
-}
-
-void THNN_(TemporalReplicationPadding_updateGradInput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *gradOutput,
-           THCTensor *gradInput,
-           int padL, int padR) {
-
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, input), 2,
-             "input tensor must fit into 32-bit index math");
-  THArgCheck(THCTensor_canUse32BitIndexMath(state, gradOutput), 3,
-             "output gradient tensor must fit into 32-bit index math");
-
-  int planeDim = 0;
-  int dimw = 1;
-
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-  if (numInputDims == 3) {
-    planeDim++;
-    dimw++;
-  }
-  int iwidth = input->size(dimw);
-  int owidth = iwidth + padL + padR;
-
-  THArgCheck(owidth == THCTensor_(size)(state, gradOutput, dimw), 3,
-             "gradOutput width unexpected. Expected: %d, Got: %d",
-             owidth, THCTensor_(size)(state, gradOutput, dimw));
-
-  THCTensor_(resizeAs)(state, gradInput, input);
-  THCTensor_(zero)(state, gradInput);
-
-  THCDeviceTensor<scalar_t, 3> devGradInput;
-  THCDeviceTensor<scalar_t, 3> devGradOutput;
-
-  if (numInputDims == 2) {
-    devGradInput = toDeviceTensor<scalar_t, 2>(state, gradInput).upcastOuter<3>();
-    devGradOutput = toDeviceTensor<scalar_t, 2>(state, gradOutput).upcastOuter<3>();
-  } else {
-    devGradInput = toDeviceTensor<scalar_t, 3>(state, gradInput);
-    devGradOutput = toDeviceTensor<scalar_t, 3>(state, gradOutput);
-  }
-
-  int outputPlaneSize = devGradOutput.getSize(2);
-  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-                devGradOutput.getSize(1),
-                devGradOutput.getSize(0));
-  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-  TemporalReplicationPadding_updateGradInput<<<gridSize, blockSize, 0,
-    THCState_getCurrentStream(state)>>>(devGradInput, devGradOutput, padL, padR);
-
-}
-
-#endif
diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h
index 0b3a459..c0ed467 100644
--- a/aten/src/THNN/generic/THNN.h
+++ b/aten/src/THNN/generic/THNN.h
@@ -1047,19 +1047,6 @@ TH_API void THNN_(TemporalReflectionPadding_updateGradInput)(
           THTensor *gradInput,
          int pad_left, int pad_right);
 
-TH_API void THNN_(TemporalReplicationPadding_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          int pad_left, int pad_right);
-
-TH_API void THNN_(TemporalReplicationPadding_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          int pad_left, int pad_right);
-
 TH_API void THNN_(Tanh_updateOutput)(
           THNNState *state,
           THTensor *input,
diff --git a/aten/src/THNN/generic/TemporalReplicationPadding.c b/aten/src/THNN/generic/TemporalReplicationPadding.c
deleted file mode 100644
index e8ffe21..0000000
--- a/aten/src/THNN/generic/TemporalReplicationPadding.c
+++ /dev/null
@@ -1,211 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/TemporalReplicationPadding.c"
-#else
-
-static void THNN_(TemporalReplicationPadding_updateOutput_frame)(
-  scalar_t *input_p, scalar_t *output_p,
-  long nslices,
-  long iwidth,
-  long owidth,
-  int pad_l, int pad_r)
-{
-  int iStartX = fmax(0, -pad_l);
-  int oStartX = fmax(0, pad_l);
-
-  long k, ip_x;
-#pragma omp parallel for private(k, ip_x)
-  for (k = 0; k < nslices; k++)
-  {
-    long j;
-    for (j = 0; j < owidth; j++) {
-      if (j < pad_l) {
-        ip_x = pad_l;
-      } else if (j >= pad_l && j < iwidth + pad_l) {
-        ip_x = j;
-      } else {
-        ip_x = iwidth + pad_l - 1;
-      }
-      ip_x = ip_x - oStartX + iStartX;
-
-      scalar_t *dest_p = output_p + k*owidth + j;
-      scalar_t *src_p = input_p + k*iwidth + ip_x;
-      *dest_p = *src_p;
-    }
-  }
-}
-
-void THNN_(TemporalReplicationPadding_updateOutput)(THNNState *state,
-                                                    THTensor *input,
-                                                    THTensor *output,
-                                                    int pad_l, int pad_r)
-{
-  int dimw = 1;
-  int dimslices = 0;
-  long nbatch = 1;
-  long nslices;
-  long iwidth;
-  long owidth;
-  scalar_t *input_data;
-  scalar_t *output_data;
-
-  THNN_ARGCHECK(!input->is_empty() && (input->dim() == 2 || input->dim() == 3), 2, input,
-                "non-empty 2D or 3D (batch mode) tensor expected for input, but got: %s");
-
-  if (input->dim() == 3)
-  {
-    nbatch = input->size(0);
-    dimw++;
-    dimslices++;
-  }
-
-  /* sizes */
-  nslices = input->size(dimslices);
-  iwidth = input->size(dimw);
-  owidth = iwidth + pad_l + pad_r;
-
-  THArgCheck(owidth >= 1 , 2,
-             "input (W: %d)is too small."
- " Calculated output W: %d", - iwidth, owidth); - - - /* get contiguous input */ - input = THTensor_(newContiguous)(input); - - /* resize output */ - if (input->dim() == 2) - { - THTensor_(resize2d)(output, nslices, owidth); - - input_data = input->data(); - output_data = output->data(); - - THNN_(TemporalReplicationPadding_updateOutput_frame)(input_data, output_data, - nslices, - iwidth, - owidth, - pad_l, pad_r); - } - else - { - long p; - - THTensor_(resize3d)(output, nbatch, nslices, owidth); - - input_data = input->data(); - output_data = output->data(); - -#pragma omp parallel for private(p) - for (p = 0; p < nbatch; p++) - { - THNN_(TemporalReplicationPadding_updateOutput_frame)( - input_data+p*nslices*iwidth, - output_data+p*nslices*owidth, - nslices, - iwidth, - owidth, - pad_l, pad_r); - } - } - - /* cleanup */ - c10::raw::intrusive_ptr::decref(input); -} - -static void THNN_(TemporalReplicationPadding_updateGradInput_frame)( - scalar_t *ginput_p, scalar_t *goutput_p, - long nslices, - long iwidth, - long owidth, - int pad_l, int pad_r) -{ - int iStartX = fmax(0, -pad_l); - int oStartX = fmax(0, pad_l); - - long k, ip_x; -#pragma omp parallel for private(k, ip_x) - for (k = 0; k < nslices; k++) - { - long j; - for (j = 0; j < owidth; j++) { - if (j < pad_l) { - ip_x = pad_l; - } else if (j >= pad_l && j < iwidth + pad_l) { - ip_x = j; - } else { - ip_x = iwidth + pad_l - 1; - } - ip_x = ip_x - oStartX + iStartX; - - scalar_t *src_p = goutput_p + k*owidth + j; - scalar_t *dest_p = ginput_p + k*iwidth + ip_x; - *dest_p += *src_p; - } - } -} - -void THNN_(TemporalReplicationPadding_updateGradInput)(THNNState *state, - THTensor *input, - THTensor *gradOutput, - THTensor *gradInput, - int pad_l, int pad_r) -{ - int dimw = 1; - int dimslices = 0; - long nbatch = 1; - long nslices; - long iwidth; - long owidth; - - if (input->dim() == 3) - { - nbatch = input->size(0); - dimw++; - dimslices++; - } - - /* sizes */ - nslices = input->size(dimslices); - iwidth = input->size(dimw); - owidth = iwidth + pad_l + pad_r; - - THArgCheck(owidth == THTensor_(size)(gradOutput, dimw), 3, - "gradOutput width unexpected. 
-
-  /* get contiguous gradOutput */
-  gradOutput = THTensor_(newContiguous)(gradOutput);
-
-  /* resize */
-  THTensor_(resizeAs)(gradInput, input);
-  THTensor_(zero)(gradInput);
-
-  /* backprop */
-  if (input->dim() == 2) {
-    THNN_(TemporalReplicationPadding_updateGradInput_frame)(
-      gradInput->data<scalar_t>(),
-      gradOutput->data<scalar_t>(),
-      nslices,
-      iwidth,
-      owidth,
-      pad_l, pad_r);
-  } else {
-    long p;
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++) {
-      THNN_(TemporalReplicationPadding_updateGradInput_frame)(
-        gradInput->data<scalar_t>() + p * nslices * iwidth,
-        gradOutput->data<scalar_t>() + p * nslices * owidth,
-        nslices,
-        iwidth,
-        owidth,
-        pad_l, pad_r);
-    }
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(gradOutput);
-}
-
-
-#endif
diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp
index 0ef123a..0fba04c 100644
--- a/aten/src/THNN/init.cpp
+++ b/aten/src/THNN/init.cpp
@@ -211,9 +211,6 @@
 #include <THNN/generic/TemporalReflectionPadding.c>
 #include <THGenerateFloatTypes.h>
 
-#include <THNN/generic/TemporalReplicationPadding.c>
-#include <THGenerateFloatTypes.h>
-
 #include <THNN/generic/TemporalSubSampling.c>
 #include <THGenerateFloatTypes.h>
 
diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py
index b639cb5..9b5eecb 100644
--- a/torch/nn/_functions/thnn/auto.py
+++ b/torch/nn/_functions/thnn/auto.py
@@ -305,7 +305,6 @@ def _generate_function_classes(scope_dict):
     name_remap = {
         'TemporalConvolution': 'Conv1d',
         'TemporalReflectionPadding': 'ReflectionPad1d',
-        'TemporalReplicationPadding': 'ReplicationPad1d',
         'SpatialDilatedConvolution': 'DilatedConv2d',
         'SpatialMaxUnpooling': 'MaxUnpool2d',
         'SpatialReflectionPadding': 'ReflectionPad2d',
-- 
2.7.4