From 56c16fe26f84148d5664ac21dd3943c7d3dded86 Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Mon, 25 Mar 2019 14:31:43 -0700 Subject: [PATCH] Porting CPU UpSample functions to ATen (#18020) Summary: This PR resolves partially #10482 Pull Request resolved: https://github.com/pytorch/pytorch/pull/18020 Differential Revision: D14598029 Pulled By: ezyang fbshipit-source-id: 513e7c6438ab6d5dc3f43241e7cb724744e9a287 --- aten/src/ATen/native/LegacyNNDefinitions.cpp | 112 ------ aten/src/ATen/native/UpSample.h | 250 +++++++++++++ aten/src/ATen/native/UpSampleBicubic2d.cpp | 315 +++++++++++++++++ aten/src/ATen/native/UpSampleBilinear2d.cpp | 312 +++++++++++++++++ aten/src/ATen/native/UpSampleLinear1d.cpp | 249 +++++++++++++ aten/src/ATen/native/UpSampleNearest1d.cpp | 222 ++++++++++++ aten/src/ATen/native/UpSampleNearest2d.cpp | 259 ++++++++++++++ aten/src/ATen/native/UpSampleNearest3d.cpp | 309 ++++++++++++++++ aten/src/ATen/native/UpSampleTrilinear3d.cpp | 389 +++++++++++++++++++++ aten/src/ATen/native/cuda/UpSampleBicubic2d.cu | 45 +++ aten/src/ATen/native/cuda/UpSampleBilinear2d.cu | 45 +++ aten/src/ATen/native/cuda/UpSampleLinear1d.cu | 45 +++ aten/src/ATen/native/cuda/UpSampleNearest1d.cu | 41 +++ aten/src/ATen/native/cuda/UpSampleNearest2d.cu | 41 +++ aten/src/ATen/native/cuda/UpSampleNearest3d.cu | 41 +++ aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu | 45 +++ aten/src/ATen/native/native_functions.yaml | 84 +++++ aten/src/THNN/generic/SpatialUpSamplingBicubic.c | 227 +----------- aten/src/THNN/generic/SpatialUpSamplingBilinear.c | 171 +-------- aten/src/THNN/generic/SpatialUpSamplingNearest.c | 147 +------- aten/src/THNN/generic/TemporalUpSamplingLinear.c | 138 +------- aten/src/THNN/generic/TemporalUpSamplingNearest.c | 126 +------ .../src/THNN/generic/VolumetricUpSamplingNearest.c | 163 +-------- .../THNN/generic/VolumetricUpSamplingTrilinear.c | 204 +---------- aten/src/THNN/generic/upsampling.h | 111 ------ 25 files changed, 2767 insertions(+), 1324 deletions(-) create mode 100644 aten/src/ATen/native/UpSample.h create mode 100644 aten/src/ATen/native/UpSampleBicubic2d.cpp create mode 100644 aten/src/ATen/native/UpSampleBilinear2d.cpp create mode 100644 aten/src/ATen/native/UpSampleLinear1d.cpp create mode 100644 aten/src/ATen/native/UpSampleNearest1d.cpp create mode 100644 aten/src/ATen/native/UpSampleNearest2d.cpp create mode 100644 aten/src/ATen/native/UpSampleNearest3d.cpp create mode 100644 aten/src/ATen/native/UpSampleTrilinear3d.cpp create mode 100644 aten/src/ATen/native/cuda/UpSampleBicubic2d.cu create mode 100644 aten/src/ATen/native/cuda/UpSampleBilinear2d.cu create mode 100644 aten/src/ATen/native/cuda/UpSampleLinear1d.cu create mode 100644 aten/src/ATen/native/cuda/UpSampleNearest1d.cu create mode 100644 aten/src/ATen/native/cuda/UpSampleNearest2d.cu create mode 100644 aten/src/ATen/native/cuda/UpSampleNearest3d.cu create mode 100644 aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu delete mode 100644 aten/src/THNN/generic/upsampling.h diff --git a/aten/src/ATen/native/LegacyNNDefinitions.cpp b/aten/src/ATen/native/LegacyNNDefinitions.cpp index 59d1a39..2f0c2bd 100644 --- a/aten/src/ATen/native/LegacyNNDefinitions.cpp +++ b/aten/src/ATen/native/LegacyNNDefinitions.cpp @@ -476,118 +476,6 @@ Tensor max_unpool3d_backward(const Tensor & grad_output, const Tensor & self, co return at::legacy::th::_thnn_max_unpool3d_backward(grad_output, self, indices, output_size, stride, padding); } -Tensor & upsample_linear1d_out(Tensor & output, const Tensor & self, IntArrayRef output_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_linear1d_forward_out(output, self, output_size, align_corners); -} - -Tensor upsample_linear1d(const Tensor & self, IntArrayRef output_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_linear1d_forward(self, output_size, align_corners); -} - -Tensor & upsample_linear1d_backward_out(Tensor & grad_input, const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_linear1d_backward_out(grad_input, grad_output, output_size, input_size, align_corners); -} - -Tensor upsample_linear1d_backward(const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_linear1d_backward(grad_output, output_size, input_size, align_corners); -} - -Tensor & upsample_bilinear2d_out(Tensor & output, const Tensor & self, IntArrayRef output_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_bilinear2d_forward_out(output, self, output_size, align_corners); -} - -Tensor upsample_bilinear2d(const Tensor & self, IntArrayRef output_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_bilinear2d_forward(self, output_size, align_corners); -} - -Tensor & upsample_bilinear2d_backward_out(Tensor & grad_input, const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_bilinear2d_backward_out(grad_input, grad_output, output_size, input_size, align_corners); -} - -Tensor upsample_bilinear2d_backward(const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_bilinear2d_backward(grad_output, output_size, input_size, align_corners); -} - -Tensor & upsample_bicubic2d_out(Tensor & output, const Tensor & self, IntArrayRef output_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_bicubic2d_forward_out(output, self, output_size, align_corners); -} - -Tensor upsample_bicubic2d(const Tensor & self, IntArrayRef output_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_bicubic2d_forward(self, output_size, align_corners); -} - -Tensor & upsample_bicubic2d_backward_out(Tensor & grad_input, const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_bicubic2d_backward_out(grad_input, grad_output, output_size, input_size, align_corners); -} - -Tensor upsample_bicubic2d_backward(const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_bicubic2d_backward(grad_output, output_size, input_size, align_corners); -} - -Tensor & upsample_trilinear3d_out(Tensor & output, const Tensor & self, IntArrayRef output_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_trilinear3d_forward_out(output, self, output_size, align_corners); -} - -Tensor upsample_trilinear3d(const Tensor & self, IntArrayRef output_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_trilinear3d_forward(self, output_size, align_corners); -} - -Tensor & upsample_trilinear3d_backward_out(Tensor & grad_input, const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_trilinear3d_backward_out(grad_input, grad_output, output_size, input_size, align_corners); -} - -Tensor upsample_trilinear3d_backward(const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size, bool align_corners) { - return at::legacy::th::_thnn_upsample_trilinear3d_backward(grad_output, output_size, input_size, align_corners); -} - -Tensor & upsample_nearest1d_out(Tensor & output, const Tensor & self, IntArrayRef output_size) { - return at::legacy::th::_thnn_upsample_nearest1d_forward_out(output, self, output_size); -} - -Tensor upsample_nearest1d(const Tensor & self, IntArrayRef output_size) { - return at::legacy::th::_thnn_upsample_nearest1d_forward(self, output_size); -} - -Tensor & upsample_nearest1d_backward_out(Tensor & grad_input, const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size) { - return at::legacy::th::_thnn_upsample_nearest1d_backward_out(grad_input, grad_output, output_size, input_size); -} - -Tensor upsample_nearest1d_backward(const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size) { - return at::legacy::th::_thnn_upsample_nearest1d_backward(grad_output, output_size, input_size); -} - -Tensor & upsample_nearest2d_out(Tensor & output, const Tensor & self, IntArrayRef output_size) { - return at::legacy::th::_thnn_upsample_nearest2d_forward_out(output, self, output_size); -} - -Tensor upsample_nearest2d(const Tensor & self, IntArrayRef output_size) { - return at::legacy::th::_thnn_upsample_nearest2d_forward(self, output_size); -} - -Tensor & upsample_nearest2d_backward_out(Tensor & grad_input, const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size) { - return at::legacy::th::_thnn_upsample_nearest2d_backward_out(grad_input, grad_output, output_size, input_size); -} - -Tensor upsample_nearest2d_backward(const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size) { - return at::legacy::th::_thnn_upsample_nearest2d_backward(grad_output, output_size, input_size); -} - -Tensor & upsample_nearest3d_out(Tensor & output, const Tensor & self, IntArrayRef output_size) { - return at::legacy::th::_thnn_upsample_nearest3d_forward_out(output, self, output_size); -} - -Tensor upsample_nearest3d(const Tensor & self, IntArrayRef output_size) { - return at::legacy::th::_thnn_upsample_nearest3d_forward(self, output_size); -} - -Tensor & upsample_nearest3d_backward_out(Tensor & grad_input, const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size) { - return at::legacy::th::_thnn_upsample_nearest3d_backward_out(grad_input, grad_output, output_size, input_size); -} - -Tensor upsample_nearest3d_backward(const Tensor & grad_output, IntArrayRef output_size, IntArrayRef input_size) { - return at::legacy::th::_thnn_upsample_nearest3d_backward(grad_output, output_size, input_size); -} - Tensor & sigmoid_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & output) { return at::legacy::th::_thnn_sigmoid_backward_out(grad_input, grad_output, output); } diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h new file mode 100644 index 0000000..4e60da2 --- /dev/null +++ b/aten/src/ATen/native/UpSample.h @@ -0,0 +1,250 @@ +#include + +#include + +namespace at { +namespace native { + +// Corresponds to THNN_CHECK_DIM_SIZE +static inline void check_dim_size( + const Tensor& data, + int64_t dim, + int64_t dim_size, + int64_t size) { + /* Check dimension size of a tensor */ + AT_CHECK( + data.dim() == dim && data.size(dim_size) == size, + "Expected tensor of dimension ", + dim, + " and tensor.size[", + dim_size, + "] == ", + size, + " but got: dimension ", + data.dim(), + " and tensor.size[", + dim_size, + "] = ", + data.size(dim_size)); +} + +static inline void upsample_1d_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t nbatch, + int64_t nchannels, + int64_t input_width, + int64_t output_width) { + AT_CHECK( + input_width > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (W: ", + input_width, + ") and output (W: ", + output_width, + ")"); + + if (input.defined()) { + AT_CHECK( + input.numel() != 0 && input.dim() == 3, + "Non-empty 3D data tensor expected but got a tensor with sizes ", + input.sizes()); + } else if (grad_output.defined()) { + check_dim_size(grad_output, 3, 0, nbatch); + check_dim_size(grad_output, 3, 1, nchannels); + check_dim_size(grad_output, 3, 2, output_width); + } +} + +static inline void upsample_2d_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t nbatch, + int64_t nchannels, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width) { + AT_CHECK( + input_height > 0 && input_width > 0 && output_height > 0 && + output_width > 0, + "Input and output sizes should be greater than 0," + " but got input (H: ", + input_height, + ", W: ", + input_width, + ") output (H: ", + output_height, + ", W: ", + output_width, + ")"); + + if (input.defined()) { + AT_CHECK( + input.numel() != 0 && input.dim() == 4, + "Non-empty 4D data tensor expected but got a tensor with sizes ", + input.sizes()); + } else if (grad_output.defined()) { + check_dim_size(grad_output, 4, 0, nbatch); + check_dim_size(grad_output, 4, 1, nchannels); + check_dim_size(grad_output, 4, 2, output_height); + check_dim_size(grad_output, 4, 3, output_width); + } +} + +static inline void upsample_3d_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t nbatch, + int64_t nchannels, + int64_t input_depth, + int64_t input_height, + int64_t input_width, + int64_t output_depth, + int64_t output_height, + int64_t output_width) { + AT_CHECK( + input_depth > 0 && input_height > 0 && input_width > 0 && + output_depth > 0 && output_height > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (D: ", + input_depth, + ", H: ", + input_height, + ", W: ", + input_width, + ") output (D: ", + output_depth, + ", H: ", + output_height, + ", W: ", + output_width, + ")"); + + if (input.defined()) { + AT_CHECK( + input.numel() != 0 && input.dim() == 5, + "Non-empty 5D data tensor expected but got a tensor with sizes ", + input.sizes()); + } else if (grad_output.defined()) { + check_dim_size(grad_output, 5, 0, nbatch); + check_dim_size(grad_output, 5, 1, nchannels); + check_dim_size(grad_output, 5, 2, output_depth); + check_dim_size(grad_output, 5, 3, output_height); + check_dim_size(grad_output, 5, 4, output_width); + } +} + +template +static inline scalar_t linear_upsample_compute_scale( + int64_t input_size, + int64_t output_size, + bool align_corners) { + /* We view each pixel as an area, idx + 0.5 as its center index. + * Here is an example formula in 1D case. + * if align_corners: center of two corner pixel areas are preserved, + * (0.5, 0.5) -> (0.5, 0.5), + * (input_size - 0.5, 0.5) -> (output_size - 0.5) + * scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5) + * src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5) + * if not align_corners: the whole range is scaled accordingly + * scale = input_size / output_size + * src_idx + 0.5 = scale * (dst_index + 0.5) + */ + if (output_size > 1) { + return align_corners + ? static_cast(input_size - 1) / (output_size - 1) + : static_cast(input_size) / output_size; + } else { + return scalar_t(0); + } +} + +template +static inline scalar_t linear_upsample_compute_source_index( + scalar_t scale, + int64_t dst_index, + bool align_corners) { + if (align_corners) { + return scale * dst_index; + } else { + scalar_t src_idx = scale * (dst_index + 0.5) - 0.5; + return src_idx < 0 ? scalar_t(0) : src_idx; + } +} + +static inline int64_t nearest_neighbor_compute_source_index( + const float scale, + int64_t dst_index, + int64_t input_size) { + const int64_t src_index = + std::min(static_cast(floorf(dst_index * scale)), input_size - 1); + return src_index; +} + +template +static scalar_t upsample_get_value_bounded( + scalar_t* data, + int64_t width, + int64_t height, + int64_t x, + int64_t y) { + int64_t access_x = std::max(std::min(x, width - 1), static_cast(0)); + int64_t access_y = std::max(std::min(y, height - 1), static_cast(0)); + return data[access_y * width + access_x]; +} + +template +static void upsample_increment_value_bounded( + scalar_t* data, + int64_t width, + int64_t height, + int64_t x, + int64_t y, + scalar_t value) { + int64_t access_x = std::max(std::min(x, width - 1), static_cast(0)); + int64_t access_y = std::max(std::min(y, height - 1), static_cast(0)); + data[access_y * width + access_x] += value; +} + +// Based on +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +template +static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +static inline void get_cubic_upsample_coefficients( + scalar_t coeffs[4], + scalar_t t) { + scalar_t A = -0.75; + + scalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + // opposite coefficients + scalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); +} + +template +static inline scalar_t cubic_interp1d( + scalar_t x0, + scalar_t x1, + scalar_t x2, + scalar_t x3, + scalar_t t) { + scalar_t coeffs[4]; + get_cubic_upsample_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/UpSampleBicubic2d.cpp b/aten/src/ATen/native/UpSampleBicubic2d.cpp new file mode 100644 index 0000000..03f09d7 --- /dev/null +++ b/aten/src/ATen/native/UpSampleBicubic2d.cpp @@ -0,0 +1,315 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace { + +template +static void upsample_bicubic2d_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners) { + // Special case: input/output same size, just copy + if (input_height == output_height && input_width == output_width) { + for (int64_t output_y = 0; output_y < output_height; output_y++) { + for (int64_t output_x = 0; output_x < output_width; output_x++) { + const scalar_t* in = &idata[output_y * input_width + output_x]; + scalar_t* out = &odata[output_y * output_width + output_x]; + + for (int64_t c = 0; c < channels; ++c) { + out[0] = in[0]; + in += input_width * input_height; + out += output_width * output_height; + } + } + } + return; + } + + // Bicubic interpolation + const scalar_t height_scale = linear_upsample_compute_scale( + input_height, output_height, align_corners); + const scalar_t width_scale = linear_upsample_compute_scale( + input_width, output_width, align_corners); + + for (int64_t output_y = 0; output_y < output_height; output_y++) { + for (int64_t output_x = 0; output_x < output_width; output_x++) { + scalar_t* in = idata; + scalar_t* out = odata; + + const scalar_t real_x = width_scale * output_x; + int64_t input_x = real_x; + const scalar_t t_x = real_x - input_x; + + const scalar_t real_y = height_scale * output_y; + int64_t input_y = real_y; + const scalar_t t_y = real_y - input_y; + + for (int64_t c = 0; c < channels * nbatch; c++) { + scalar_t coefficients[4]; + + // Interpolate 4 times in the x direction + for (int64_t i = 0; i < 4; i++) { + coefficients[i] = cubic_interp1d( + upsample_get_value_bounded( + in, input_width, input_height, input_x - 1, input_y - 1 + i), + upsample_get_value_bounded( + in, input_width, input_height, input_x + 0, input_y - 1 + i), + upsample_get_value_bounded( + in, input_width, input_height, input_x + 1, input_y - 1 + i), + upsample_get_value_bounded( + in, input_width, input_height, input_x + 2, input_y - 1 + i), + t_x); + } + + // Interpolate in the y direction using x interpolations + out[output_y * output_width + output_x] = cubic_interp1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + t_y); + + // Move to next channel + in += input_width * input_height; + out += output_width * output_height; + } + } + } +} + +template +static void upsample_bicubic2d_backward_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners) { + channels = channels * nbatch; + + // Special case: input/output same size, just copy + if (input_height == output_height && input_width == output_width) { + for (int64_t output_y = 0; output_y < output_height; output_y++) { + for (int64_t output_x = 0; output_x < output_width; output_x++) { + scalar_t* in = &idata[output_y * input_width + output_x]; + scalar_t* out = &odata[output_y * output_width + output_x]; + for (int64_t c = 0; c < channels; ++c) { + in[0] = out[0]; + in += input_width * input_height; + out += output_width * output_height; + } + } + } + return; + } + + const scalar_t height_scale = linear_upsample_compute_scale( + input_height, output_height, align_corners); + const scalar_t width_scale = linear_upsample_compute_scale( + input_width, output_width, align_corners); + + for (int64_t output_y = 0; output_y < output_height; output_y++) { + for (int64_t output_x = 0; output_x < output_width; output_x++) { + scalar_t* in = idata; + scalar_t* out = odata; + + scalar_t real_x = width_scale * output_x; + int64_t input_x = real_x; + scalar_t t_x = real_x - input_x; + + scalar_t real_y = height_scale * output_y; + int64_t input_y = real_y; + scalar_t t_y = real_y - input_y; + + scalar_t x_coeffs[4]; + scalar_t y_coeffs[4]; + + get_cubic_upsample_coefficients(x_coeffs, t_x); + get_cubic_upsample_coefficients(y_coeffs, t_y); + + for (int64_t c = 0; c < channels; c++) { + scalar_t out_value = out[output_y * output_width + output_x]; + + for (int64_t i = 0; i < 4; i++) { + for (int64_t j = 0; j < 4; j++) { + upsample_increment_value_bounded( + in, + input_width, + input_height, + input_x - 1 + i, + input_y - 1 + j, + out_value * y_coeffs[j] * x_coeffs[i]); + } + } + + in += input_width * input_height; + out += output_width * output_height; + } + } + } +} + +static void upsample_bicubic2d_out_cpu_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size, + bool align_corners) { + AT_CHECK( + output_size.size() == 2, + "It is expected output_size equals to 2, but got size ", + output_size.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + int64_t nbatch = input_.size(0); + int64_t channels = input_.size(1); + int64_t input_height = input_.size(2); + int64_t input_width = input_.size(3); + + upsample_2d_shape_check( + input_, + Tensor(), + nbatch, + channels, + input_height, + input_width, + output_height, + output_width); + + auto input = input_.contiguous(); + + output.resize_({nbatch, channels, output_height, output_width}); + output.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "upsample_bicubic2d", [&] { + auto* idata = input.data(); + auto* odata = output.data(); + + upsample_bicubic2d_out_frame( + odata, + idata, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + align_corners); + }); +} + +static void upsample_bicubic2d_backward_out_cpu_template( + Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + AT_CHECK( + output_size.size() == 2, + "It is expected output_size equals to 2, but got size ", + output_size.size()); + + AT_CHECK( + input_size.size() == 4, + "It is expected input_size equals to 4, but got size ", + input_size.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_height = input_size[2]; + int64_t input_width = input_size[3]; + + upsample_2d_shape_check( + Tensor(), + grad_output_, + nbatch, + channels, + input_height, + input_width, + output_height, + output_width); + + auto grad_output = grad_output_.contiguous(); + + grad_input.resize_({nbatch, channels, input_height, input_width}); + grad_input.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "upsample_bicubic2d_backward", [&] { + scalar_t* idata = grad_input.data(); + scalar_t* odata = grad_output.data(); + + upsample_bicubic2d_backward_out_frame( + odata, + idata, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + align_corners); + }); +} +} // namespace + +Tensor& upsample_bicubic2d_out_cpu( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + upsample_bicubic2d_out_cpu_template( + output, input, output_size, align_corners); + return output; +} + +Tensor upsample_bicubic2d_cpu( + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + auto output = at::empty({0}, input.options()); + upsample_bicubic2d_out_cpu_template( + output, input, output_size, align_corners); + return output; +} + +Tensor& upsample_bicubic2d_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + upsample_bicubic2d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size, align_corners); + return grad_input; +} + +Tensor upsample_bicubic2d_backward_cpu( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + auto grad_input = at::zeros(input_size, grad_output.options()); + upsample_bicubic2d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size, align_corners); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/UpSampleBilinear2d.cpp b/aten/src/ATen/native/UpSampleBilinear2d.cpp new file mode 100644 index 0000000..e1f43cf --- /dev/null +++ b/aten/src/ATen/native/UpSampleBilinear2d.cpp @@ -0,0 +1,312 @@ +// Adapted from interp.cpp from Caffe util by Pauline Luc +// Originally developed by George Papandreou + +#include +#include +#include + +namespace at { +namespace native { +namespace { + +template +static void upsample_bilinear2d_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners) { + channels = channels * nbatch; + + // special case: just copy + if (input_height == output_height && input_width == output_width) { + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = h2; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + const scalar_t* pos1 = &idata[h1 * input_width + w1]; + scalar_t* pos2 = &odata[h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = pos1[0]; + pos1 += input_width * input_height; + pos2 += output_width * output_height; + } + } + } + return; + } + const scalar_t rheight = linear_upsample_compute_scale( + input_height, output_height, align_corners); + + const scalar_t rwidth = linear_upsample_compute_scale( + input_width, output_width, align_corners); + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const scalar_t h1r = linear_upsample_compute_source_index( + rheight, h2, align_corners); + + const int64_t h1 = h1r; + const int64_t h1p = (h1 < input_height - 1) ? 1 : 0; + + const scalar_t h1lambda = h1r - h1; + const scalar_t h0lambda = static_cast(1.) - h1lambda; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const scalar_t w1r = linear_upsample_compute_source_index( + rwidth, w2, align_corners); + + const int64_t w1 = w1r; + const int64_t w1p = (w1 < input_width - 1) ? 1 : 0; + + const scalar_t w1lambda = w1r - w1; + const scalar_t w0lambda = static_cast(1.) - w1lambda; + const scalar_t* pos1 = &idata[h1 * input_width + w1]; + scalar_t* pos2 = &odata[h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) + + h1lambda * + (w0lambda * pos1[h1p * input_width] + + w1lambda * pos1[h1p * input_width + w1p]); + pos1 += input_width * input_height; + pos2 += output_width * output_height; + } + } + } +} + +template +static void upsample_bilinear2d_backward_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners) { + channels = channels * nbatch; + + // special case: same-size matching grids + if (input_height == output_height && input_width == output_width) { + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = h2; + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + scalar_t* pos1 = &idata[h1 * input_width + w1]; + const scalar_t* pos2 = &odata[h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += pos2[0]; + pos1 += input_width * input_height; + pos2 += output_width * output_height; + } + } + } + return; + } + + const scalar_t rheight = linear_upsample_compute_scale( + input_height, output_height, align_corners); + const scalar_t rwidth = linear_upsample_compute_scale( + input_width, output_width, align_corners); + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const scalar_t h1r = linear_upsample_compute_source_index( + rheight, h2, align_corners); + + const int64_t h1 = h1r; + const int64_t h1p = (h1 < input_height - 1) ? 1 : 0; + + const scalar_t h1lambda = h1r - h1; + const scalar_t h0lambda = static_cast(1.) - h1lambda; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const scalar_t w1r = linear_upsample_compute_source_index( + rwidth, w2, align_corners); + + const int64_t w1 = w1r; + const int64_t w1p = (w1 < input_width - 1) ? 1 : 0; + + const scalar_t w1lambda = w1r - w1; + const scalar_t w0lambda = static_cast(1.) - w1lambda; + + scalar_t* pos1 = &idata[h1 * input_width + w1]; + + const scalar_t* pos2 = &odata[h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += h0lambda * w0lambda * pos2[0]; + pos1[w1p] += h0lambda * w1lambda * pos2[0]; + pos1[h1p * input_width] += h1lambda * w0lambda * pos2[0]; + pos1[h1p * input_width + w1p] += h1lambda * w1lambda * pos2[0]; + pos1 += input_width * input_height; + pos2 += output_width * output_height; + } + } + } +} + +static void upsample_bilinear2d_out_cpu_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size, + bool align_corners) { + AT_CHECK( + output_size.size() == 2, + "It is expected output_size equals to 2, but got size ", + output_size.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + int64_t nbatch = input_.size(0); + int64_t channels = input_.size(1); + int64_t input_height = input_.size(2); + int64_t input_width = input_.size(3); + + upsample_2d_shape_check( + input_, + Tensor(), + nbatch, + channels, + input_height, + input_width, + output_height, + output_width); + + auto input = input_.contiguous(); + + output.resize_({nbatch, channels, output_height, output_width}); + output.zero_(); + + AT_ASSERT( + input_height > 0 && input_width > 0 && output_height > 0 && + output_width > 0); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "upsample_bilinear2d", [&] { + auto* idata = input.data(); + auto* odata = output.data(); + + upsample_bilinear2d_out_frame( + odata, + idata, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + align_corners); + }); +} + +static void upsample_bilinear2d_backward_out_cpu_template( + Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + AT_CHECK( + output_size.size() == 2, + "It is expected output_size equals to 2, but got size ", + output_size.size()); + + AT_CHECK( + input_size.size() == 4, + "It is expected input_size equals to 4, but got size ", + input_size.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_height = input_size[2]; + int64_t input_width = input_size[3]; + + upsample_2d_shape_check( + Tensor(), + grad_output_, + nbatch, + channels, + input_height, + input_width, + output_height, + output_width); + + auto grad_output = grad_output_.contiguous(); + + grad_input.resize_({nbatch, channels, input_height, input_width}); + grad_input.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "upsample_bilinear2d_backward", [&] { + scalar_t* idata = grad_input.data(); + scalar_t* odata = grad_output.data(); + + upsample_bilinear2d_backward_out_frame( + odata, + idata, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + align_corners); + }); +} +} // namespace + +Tensor& upsample_bilinear2d_out_cpu( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + upsample_bilinear2d_out_cpu_template( + output, input, output_size, align_corners); + return output; +} + +Tensor upsample_bilinear2d_cpu( + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + auto output = at::empty({0}, input.options()); + upsample_bilinear2d_out_cpu_template( + output, input, output_size, align_corners); + return output; +} + +Tensor& upsample_bilinear2d_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + upsample_bilinear2d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size, align_corners); + return grad_input; +} + +Tensor upsample_bilinear2d_backward_cpu( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + auto grad_input = at::zeros(input_size, grad_output.options()); + upsample_bilinear2d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size, align_corners); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/UpSampleLinear1d.cpp b/aten/src/ATen/native/UpSampleLinear1d.cpp new file mode 100644 index 0000000..da7a881 --- /dev/null +++ b/aten/src/ATen/native/UpSampleLinear1d.cpp @@ -0,0 +1,249 @@ +// Adapted from interp.cpp from Caffe util by Pauline Luc +// Originally developed by George Papandreou + +#include +#include +#include + +namespace at { +namespace native { +namespace { + +template +static void upsample_linear1d_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_width, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners) { + channels = channels * nbatch; + + // special case: just copy + if (input_width == output_width) { + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + const scalar_t* pos1 = &idata[w1]; + scalar_t* pos2 = &odata[w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = pos1[0]; + pos1 += input_width; + pos2 += output_width; + } + } + return; + } + const scalar_t rwidth = linear_upsample_compute_scale( + input_width, output_width, align_corners); + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const scalar_t w1r = linear_upsample_compute_source_index( + rwidth, w2, align_corners); + + const int64_t w1 = w1r; + const int64_t w1p = (w1 < input_width - 1) ? 1 : 0; + const scalar_t w1lambda = w1r - w1; + const scalar_t w0lambda = static_cast(1.) - w1lambda; + const scalar_t* pos1 = &idata[w1]; + // index w2 is interpolated by idata[w1] and (itself or idata[w1 + 1]) + scalar_t* pos2 = &odata[w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = w0lambda * pos1[0] + w1lambda * pos1[w1p]; + pos1 += input_width; + pos2 += output_width; + } + } +} + +template +static void upsample_linear1d_backward_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_width, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners) { + channels = nbatch * channels; + + // special case: same-size matching grids + if (input_width == output_width) { + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + scalar_t* pos1 = &idata[w1]; + const scalar_t* pos2 = &odata[w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += pos2[0]; + pos1 += input_width; + pos2 += output_width; + } + } + return; + } + const scalar_t rwidth = linear_upsample_compute_scale( + input_width, output_width, align_corners); + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const scalar_t w1r = linear_upsample_compute_source_index( + rwidth, w2, align_corners); + + const int64_t w1 = w1r; + const int64_t w1p = (w1 < input_width - 1) ? 1 : 0; + const scalar_t w1lambda = w1r - w1; + const scalar_t w0lambda = static_cast(1.) - w1lambda; + scalar_t* pos1 = &idata[w1]; + const scalar_t* pos2 = &odata[w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += w0lambda * pos2[0]; + pos1[w1p] += w1lambda * pos2[0]; + pos1 += input_width; + pos2 += output_width; + } + } +} + +static void upsample_linear1d_out_cpu_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size, + bool align_corners) { + AT_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + int64_t output_width = output_size[0]; + + int64_t nbatch = input_.size(0); + int64_t channels = input_.size(1); + int64_t input_width = input_.size(2); + + upsample_1d_shape_check( + input_, + Tensor(), + nbatch, + channels, + input_width, + output_width); + + auto input = input_.contiguous(); + + output.resize_({nbatch, channels, output_width}); + output.zero_(); + + AT_ASSERT(input_width > 0 && output_width > 0); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "upsample_linear1d", [&] { + auto* idata = input.data(); + auto* odata = output.data(); + + upsample_linear1d_out_frame( + odata, + idata, + input_width, + output_width, + nbatch, + channels, + align_corners); + }); +} + +static void upsample_linear1d_backward_out_cpu_template( + Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + AT_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + AT_CHECK( + input_size.size() == 3, + "It is expected input_size equals to 3, but got size ", + input_size.size()); + + int64_t output_width = output_size[0]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_width = input_size[2]; + + upsample_1d_shape_check( + Tensor(), + grad_output_, + nbatch, + channels, + input_width, + output_width); + + auto grad_output = grad_output_.contiguous(); + + grad_input.resize_({nbatch, channels, input_width}); + grad_input.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "upsample_linear1d_backward", [&] { + scalar_t* idata = grad_input.data(); + scalar_t* odata = grad_output.data(); + + upsample_linear1d_backward_out_frame( + odata, + idata, + input_width, + output_width, + nbatch, + channels, + align_corners); + }); +} +} // namespace + +Tensor& upsample_linear1d_out_cpu( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + upsample_linear1d_out_cpu_template(output, input, output_size, align_corners); + return output; +} + +Tensor upsample_linear1d_cpu( + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + auto output = at::empty({0}, input.options()); + upsample_linear1d_out_cpu_template(output, input, output_size, align_corners); + return output; +} + +Tensor& upsample_linear1d_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + upsample_linear1d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size, align_corners); + return grad_input; +} + +Tensor upsample_linear1d_backward_cpu( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + auto grad_input = at::zeros(input_size, grad_output.options()); + upsample_linear1d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size, align_corners); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/UpSampleNearest1d.cpp b/aten/src/ATen/native/UpSampleNearest1d.cpp new file mode 100644 index 0000000..4d99943 --- /dev/null +++ b/aten/src/ATen/native/UpSampleNearest1d.cpp @@ -0,0 +1,222 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace { + +template +static void upsample_nearest1d_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_width, + int64_t output_width, + int64_t nbatch, + int64_t channels) { + const float scale = (float)input_width / (float)output_width; + channels = channels * nbatch; + + // special case: just copy + if (input_width == output_width) { + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + const scalar_t* pos1 = &idata[w1]; + scalar_t* pos2 = &odata[w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = pos1[0]; + pos1 += input_width; + pos2 += output_width; + } + } + return; + } + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const scalar_t src_x = + nearest_neighbor_compute_source_index(scale, w2, input_width); + const int64_t w1 = src_x; + const scalar_t* pos1 = &idata[w1]; + scalar_t* pos2 = &odata[w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = pos1[0]; + pos1 += input_width; + pos2 += output_width; + } + } +} + +template +static void upsample_nearest1d_backward_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_width, + int64_t output_width, + int64_t nbatch, + int64_t channels) { + const float scale = (float)input_width / (float)output_width; + channels = channels * nbatch; + + // special case: same-size matching grids + if (input_width == output_width) { + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + scalar_t* pos1 = &idata[w1]; + const scalar_t* pos2 = &odata[w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += pos2[0]; + pos1 += input_width; + pos2 += output_width; + } + } + return; + } + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = + nearest_neighbor_compute_source_index(scale, w2, input_width); + + scalar_t* pos1 = &idata[w1]; + const scalar_t* pos2 = &odata[w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += pos2[0]; + pos1 += input_width; + pos2 += output_width; + } + } +} + +static void upsample_nearest1d_out_cpu_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size) { + AT_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + int64_t output_width = output_size[0]; + + int64_t nbatch = input_.size(0); + int64_t channels = input_.size(1); + int64_t input_width = input_.size(2); + + upsample_1d_shape_check( + input_, + Tensor(), + nbatch, + channels, + input_width, + output_width); + + auto input = input_.contiguous(); + + output.resize_({nbatch, channels, output_width}); + output.zero_(); + + AT_ASSERT(input_width > 0 && output_width > 0); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "upsample_nearest1d", [&] { + auto* idata = input.data(); + auto* odata = output.data(); + + upsample_nearest1d_out_frame( + odata, + idata, + input_width, + output_width, + nbatch, + channels); + }); +} + +static void upsample_nearest1d_backward_out_cpu_template( + Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size) { + AT_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + AT_CHECK( + input_size.size() == 3, + "It is expected input_size equals to 3, but got size ", + input_size.size()); + + int64_t output_width = output_size[0]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_width = input_size[2]; + + upsample_1d_shape_check( + Tensor(), + grad_output_, + nbatch, + channels, + input_width, + output_width); + + auto grad_output = grad_output_.contiguous(); + + grad_input.resize_({nbatch, channels, input_width}); + grad_input.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "upsample_nearest1d_backward", [&] { + scalar_t* idata = grad_input.data(); + scalar_t* odata = grad_output.data(); + + upsample_nearest1d_backward_out_frame( + odata, + idata, + input_width, + output_width, + nbatch, + channels); + }); +} +} // namespace + +Tensor& upsample_nearest1d_out_cpu( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + upsample_nearest1d_out_cpu_template(output, input, output_size); + return output; +} + +Tensor upsample_nearest1d_cpu(const Tensor& input, IntArrayRef output_size) { + auto output = at::empty({0}, input.options()); + upsample_nearest1d_out_cpu_template(output, input, output_size); + return output; +} + +Tensor& upsample_nearest1d_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + upsample_nearest1d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size); + return grad_input; +} + +Tensor upsample_nearest1d_backward_cpu( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + auto grad_input = at::zeros(input_size, grad_output.options()); + upsample_nearest1d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/UpSampleNearest2d.cpp b/aten/src/ATen/native/UpSampleNearest2d.cpp new file mode 100644 index 0000000..eb9d5fc --- /dev/null +++ b/aten/src/ATen/native/UpSampleNearest2d.cpp @@ -0,0 +1,259 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace { + +template +static void upsample_nearest2d_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels) { + const float height_scale = (float)input_height / (float)output_height; + const float width_scale = (float)input_width / (float)output_width; + + channels = channels * nbatch; + + // special case: just copy + if (input_height == output_height && input_width == output_width) { + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = h2; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + const scalar_t* pos1 = &idata[h1 * input_width + w1]; + scalar_t* pos2 = &odata[h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = pos1[0]; + pos1 += input_height * input_width; + pos2 += output_height * output_width; + } + } + } + return; + } + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = + nearest_neighbor_compute_source_index(height_scale, h2, input_height); + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = + nearest_neighbor_compute_source_index(width_scale, w2, input_width); + + const scalar_t* pos1 = &idata[h1 * input_width + w1]; + scalar_t* pos2 = &odata[h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = pos1[0]; + pos1 += input_height * input_width; + pos2 += output_height * output_width; + } + } + } +} + +template +static void upsample_nearest2d_backward_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels) { + const float height_scale = (float)input_height / (float)output_height; + const float width_scale = (float)input_width / (float)output_width; + + channels = channels * nbatch; + + // special case: just copy + if (input_height == output_height && input_width == output_width) { + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = h2; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + scalar_t* pos1 = &idata[h1 * input_width + w1]; + const scalar_t* pos2 = &odata[h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] = pos2[0]; + pos1 += input_height * input_width; + pos2 += output_height * output_width; + } + } + } + return; + } + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = + nearest_neighbor_compute_source_index(height_scale, h2, input_height); + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = + nearest_neighbor_compute_source_index(width_scale, w2, input_width); + scalar_t* pos1 = &idata[h1 * input_width + w1]; + const scalar_t* pos2 = &odata[h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += pos2[0]; + pos1 += input_height * input_width; + pos2 += output_height * output_width; + } + } + } +} + +static void upsample_nearest2d_out_cpu_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size) { + AT_CHECK( + output_size.size() == 2, + "It is expected output_size equals to 2, but got size ", + output_size.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + int64_t nbatch = input_.size(0); + int64_t channels = input_.size(1); + int64_t input_height = input_.size(2); + int64_t input_width = input_.size(3); + + upsample_2d_shape_check( + input_, + Tensor(), + nbatch, + channels, + input_height, + input_width, + output_height, + output_width); + + auto input = input_.contiguous(); + + output.resize_({nbatch, channels, output_height, output_width}); + output.zero_(); + + AT_ASSERT(input_width > 0 && output_width > 0); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "upsample_nearest2d", [&] { + auto* idata = input.data(); + auto* odata = output.data(); + + upsample_nearest2d_out_frame( + odata, + idata, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels); + }); +} + +static void upsample_nearest2d_backward_out_cpu_template( + Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size) { + AT_CHECK( + output_size.size() == 2, + "It is expected output_size equals to 2, but got size ", + output_size.size()); + + AT_CHECK( + input_size.size() == 4, + "It is expected input_size equals to 4, but got size ", + input_size.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_height = input_size[2]; + int64_t input_width = input_size[3]; + + upsample_2d_shape_check( + Tensor(), + grad_output_, + nbatch, + channels, + input_height, + input_width, + output_height, + output_width); + + grad_input.resize_({nbatch, channels, input_height, input_width}); + grad_input.zero_(); + + auto grad_output = grad_output_.contiguous(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "upsample_nearest2d_backward", [&] { + scalar_t* idata = grad_input.data(); + scalar_t* odata = grad_output.data(); + + upsample_nearest2d_backward_out_frame( + odata, + idata, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels); + }); +} +} // namespace + +Tensor& upsample_nearest2d_out_cpu( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + upsample_nearest2d_out_cpu_template(output, input, output_size); + return output; +} + +Tensor upsample_nearest2d_cpu(const Tensor& input, IntArrayRef output_size) { + auto output = at::empty({0}, input.options()); + upsample_nearest2d_out_cpu_template(output, input, output_size); + return output; +} + +Tensor& upsample_nearest2d_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + upsample_nearest2d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size); + return grad_input; +} + +Tensor upsample_nearest2d_backward_cpu( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + auto grad_input = at::zeros(input_size, grad_output.options()); + upsample_nearest2d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/UpSampleNearest3d.cpp b/aten/src/ATen/native/UpSampleNearest3d.cpp new file mode 100644 index 0000000..fd550fd --- /dev/null +++ b/aten/src/ATen/native/UpSampleNearest3d.cpp @@ -0,0 +1,309 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace { + +template +static void upsample_nearest3d_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_depth, + int64_t input_height, + int64_t input_width, + int64_t output_depth, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels) { + const float depth_scale = (float)input_depth / (float)output_depth; + const float height_scale = (float)input_height / (float)output_height; + const float width_scale = (float)input_width / (float)output_width; + + channels = channels * nbatch; + + // special case: just copy + if (input_depth == output_depth && input_height == output_height && + input_width == output_width) { + for (int64_t d2 = 0; d2 < output_depth; ++d2) { + const int64_t d1 = d2; + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = h2; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + const scalar_t* pos1 = + &idata[d1 * input_height * input_width + h1 * input_width + w1]; + scalar_t* pos2 = + &odata + [d2 * output_height * output_width + h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = pos1[0]; + pos1 += input_depth * input_height * input_width; + pos2 += output_depth * output_height * output_width; + } + } + } + } + return; + } + + for (int64_t d2 = 0; d2 < output_depth; ++d2) { + const int64_t d1 = + nearest_neighbor_compute_source_index(depth_scale, d2, input_depth); + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = + nearest_neighbor_compute_source_index(height_scale, h2, input_height); + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = + nearest_neighbor_compute_source_index(width_scale, w2, input_width); + const scalar_t* pos1 = + &idata[d1 * input_height * input_width + h1 * input_width + w1]; + scalar_t* pos2 = + &odata[d2 * output_height * output_width + h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = pos1[0]; + pos1 += input_depth * input_height * input_width; + pos2 += output_depth * output_height * output_width; + } + } + } + } +} + +template +static void upsample_nearest3d_backward_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_depth, + int64_t input_height, + int64_t input_width, + int64_t output_depth, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels) { + const float depth_scale = (float)input_depth / (float)output_depth; + const float height_scale = (float)input_height / (float)output_height; + const float width_scale = (float)input_width / (float)output_width; + + channels = channels * nbatch; + + // special case: just copy + if (input_depth == output_depth && input_height == output_height && + input_width == output_width) { + for (int64_t d2 = 0; d2 < output_depth; ++d2) { + const int64_t d1 = d2; + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = h2; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + scalar_t* pos1 = + &idata[d1 * input_height * input_width + h1 * input_width + w1]; + const scalar_t* pos2 = + &odata + [d2 * output_height * output_width + h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += pos2[0]; + pos1 += input_depth * input_height * input_width; + pos2 += output_depth * output_height * output_width; + } + } + } + } + return; + } + + for (int64_t d2 = 0; d2 < output_depth; ++d2) { + const int64_t d1 = + nearest_neighbor_compute_source_index(depth_scale, d2, input_depth); + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = + nearest_neighbor_compute_source_index(height_scale, h2, input_height); + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = + nearest_neighbor_compute_source_index(width_scale, w2, input_width); + scalar_t* pos1 = + &idata[d1 * input_height * input_width + h1 * input_width + w1]; + const scalar_t* pos2 = + &odata[d2 * output_height * output_width + h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += pos2[0]; + pos1 += input_depth * input_height * input_width; + pos2 += output_depth * output_height * output_width; + } + } + } + } +} + +static void upsample_nearest3d_out_cpu_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size) { + AT_CHECK( + output_size.size() == 3, + "It is expected output_size equals to 3, but got size ", + output_size.size()); + + int64_t output_depth = output_size[0]; + int64_t output_height = output_size[1]; + int64_t output_width = output_size[2]; + + int64_t nbatch = input_.size(0); + int64_t channels = input_.size(1); + int64_t input_depth = input_.size(2); + int64_t input_height = input_.size(3); + int64_t input_width = input_.size(4); + + upsample_3d_shape_check( + input_, + Tensor(), + nbatch, + channels, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width); + + auto input = input_.contiguous(); + + output.resize_({nbatch, channels, output_depth, output_height, output_width}); + output.zero_(); + + AT_ASSERT( + input_depth > 0 && input_height > 0 && input_width > 0 && + output_depth > 0 && output_height > 0 && output_width > 0); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "upsample_nearest3d", [&] { + auto* idata = input.data(); + auto* odata = output.data(); + + upsample_nearest3d_out_frame( + odata, + idata, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width, + nbatch, + channels); + }); +} + +static void upsample_nearest3d_backward_out_cpu_template( + Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size) { + AT_CHECK( + output_size.size() == 3, + "It is expected output_size equals to 3, but got size ", + output_size.size()); + + AT_CHECK( + input_size.size() == 5, + "It is expected input_size equals to 5, but got size ", + input_size.size()); + + int64_t output_depth = output_size[0]; + int64_t output_height = output_size[1]; + int64_t output_width = output_size[2]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_depth = input_size[2]; + int64_t input_height = input_size[3]; + int64_t input_width = input_size[4]; + + upsample_3d_shape_check( + Tensor(), + grad_output_, + nbatch, + channels, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width); + + grad_input.resize_( + {nbatch, channels, input_depth, input_height, input_width}); + grad_input.zero_(); + + auto grad_output = grad_output_.contiguous(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "upsample_nearest3d_backward", [&] { + scalar_t* idata = grad_input.data(); + scalar_t* odata = grad_output.data(); + + upsample_nearest3d_backward_out_frame( + odata, + idata, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width, + nbatch, + channels); + }); +} +} // namespace + +Tensor& upsample_nearest3d_out_cpu( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + upsample_nearest3d_out_cpu_template(output, input, output_size); + return output; +} + +Tensor upsample_nearest3d_cpu(const Tensor& input, IntArrayRef output_size) { + auto output = at::empty({0}, input.options()); + upsample_nearest3d_out_cpu_template(output, input, output_size); + return output; +} + +Tensor& upsample_nearest3d_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + upsample_nearest3d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size); + return grad_input; +} + +Tensor upsample_nearest3d_backward_cpu( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + auto grad_input = at::zeros(input_size, grad_output.options()); + upsample_nearest3d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/UpSampleTrilinear3d.cpp b/aten/src/ATen/native/UpSampleTrilinear3d.cpp new file mode 100644 index 0000000..cece11e --- /dev/null +++ b/aten/src/ATen/native/UpSampleTrilinear3d.cpp @@ -0,0 +1,389 @@ +// Adapted from interp.cpp from Caffe util by Pauline Luc +// Originally developed by George Papandreou + +#include +#include +#include + +namespace at { +namespace native { +namespace { + +template +static void upsample_trilinear3d_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_depth, + int64_t input_height, + int64_t input_width, + int64_t output_depth, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners) { + channels = channels * nbatch; + + // special case: just copy + if (input_depth == output_depth && input_height == output_height && + input_width == output_width) { + for (int64_t t2 = 0; t2 < output_depth; ++t2) { + const int64_t t1 = t2; + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = h2; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + const scalar_t* pos1 = + &idata[t1 * input_height * input_width + h1 * input_width + w1]; + scalar_t* pos2 = + &odata + [t2 * output_height * output_width + h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = pos1[0]; + pos1 += input_width * input_height * input_depth; + pos2 += output_width * output_height * output_depth; + } + } + } + } + return; + } + const scalar_t rdepth = linear_upsample_compute_scale( + input_depth, output_depth, align_corners); + const scalar_t rheight = linear_upsample_compute_scale( + input_height, output_height, align_corners); + const scalar_t rwidth = linear_upsample_compute_scale( + input_width, output_width, align_corners); + for (int64_t t2 = 0; t2 < output_depth; ++t2) { + const scalar_t t1r = linear_upsample_compute_source_index( + rdepth, t2, align_corners); + + const int64_t t1 = t1r; + const int64_t t1p = (t1 < input_depth - 1) ? 1 : 0; + const scalar_t t1lambda = t1r - t1; + const scalar_t t0lambda = static_cast(1.) - t1lambda; + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const scalar_t h1r = linear_upsample_compute_source_index( + rheight, h2, align_corners); + + const int64_t h1 = h1r; + const int64_t h1p = (h1 < input_height - 1) ? 1 : 0; + const scalar_t h1lambda = h1r - h1; + const scalar_t h0lambda = static_cast(1.) - h1lambda; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const scalar_t w1r = linear_upsample_compute_source_index( + rwidth, w2, align_corners); + + const int64_t w1 = w1r; + const int64_t w1p = (w1 < input_width - 1) ? 1 : 0; + const scalar_t w1lambda = w1r - w1; + const scalar_t w0lambda = static_cast(1.) - w1lambda; + const scalar_t* pos1 = + &idata[t1 * input_height * input_width + h1 * input_width + w1]; + scalar_t* pos2 = + &odata[t2 * output_height * output_width + h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos2[0] = t0lambda * + (h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) + + h1lambda * + (w0lambda * pos1[h1p * input_width] + + w1lambda * pos1[h1p * input_width + w1p])) + + t1lambda * + (h0lambda * + (w0lambda * pos1[t1p * input_height * input_width] + + w1lambda * + pos1[t1p * input_height * input_width + w1p]) + + h1lambda * + (w0lambda * + pos1 + [t1p * input_height * input_width + + h1p * input_width] + + w1lambda * + pos1 + [t1p * input_height * input_width + + h1p * input_width + w1p])); + pos1 += input_width * input_height * input_depth; + pos2 += output_width * output_height * output_depth; + } + } + } + } +} + +template +static void upsample_trilinear3d_backward_out_frame( + scalar_t* odata, + scalar_t* idata, + int64_t input_depth, + int64_t input_height, + int64_t input_width, + int64_t output_depth, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners) { + channels = channels * nbatch; + + // special case: same-size matching grids + if (input_depth == output_depth && input_height == output_height && + input_width == output_width) { + for (int64_t t2 = 0; t2 < output_depth; ++t2) { + const int64_t t1 = t2; + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const int64_t h1 = h2; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const int64_t w1 = w2; + scalar_t* pos1 = + &idata[t1 * input_height * input_width + h1 * input_width + w1]; + const scalar_t* pos2 = + &odata + [t2 * output_height * output_width + h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += pos2[0]; + pos1 += input_width * input_height * input_depth; + pos2 += output_width * output_height * output_depth; + } + } + } + } + return; + } + const scalar_t rdepth = linear_upsample_compute_scale( + input_depth, output_depth, align_corners); + + const scalar_t rheight = linear_upsample_compute_scale( + input_height, output_height, align_corners); + + const scalar_t rwidth = linear_upsample_compute_scale( + input_width, output_width, align_corners); + + for (int64_t t2 = 0; t2 < output_depth; ++t2) { + const scalar_t t1r = linear_upsample_compute_source_index( + rdepth, t2, align_corners); + const int64_t t1 = t1r; + const int64_t t1p = (t1 < input_depth - 1) ? 1 : 0; + const scalar_t t1lambda = t1r - t1; + const scalar_t t0lambda = static_cast(1.) - t1lambda; + + for (int64_t h2 = 0; h2 < output_height; ++h2) { + const scalar_t h1r = linear_upsample_compute_source_index( + rheight, h2, align_corners); + const int64_t h1 = h1r; + const int64_t h1p = (h1 < input_height - 1) ? 1 : 0; + const scalar_t h1lambda = h1r - h1; + const scalar_t h0lambda = static_cast(1.) - h1lambda; + + for (int64_t w2 = 0; w2 < output_width; ++w2) { + const scalar_t w1r = linear_upsample_compute_source_index( + rwidth, w2, align_corners); + const int64_t w1 = w1r; + const int64_t w1p = (w1 < input_width - 1) ? 1 : 0; + const scalar_t w1lambda = w1r - w1; + const scalar_t w0lambda = static_cast(1.) - w1lambda; + scalar_t* pos1 = + &idata[t1 * input_height * input_width + h1 * input_width + w1]; + const scalar_t* pos2 = + &odata[t2 * output_height * output_width + h2 * output_width + w2]; + + for (int64_t c = 0; c < channels; ++c) { + pos1[0] += t0lambda * h0lambda * w0lambda * pos2[0]; + pos1[w1p] += t0lambda * h0lambda * w1lambda * pos2[0]; + pos1[h1p * input_width] += t0lambda * h1lambda * w0lambda * pos2[0]; + pos1[h1p * input_width + w1p] += + t0lambda * h1lambda * w1lambda * pos2[0]; + pos1[t1p * input_height * input_width] += + t1lambda * h0lambda * w0lambda * pos2[0]; + pos1[t1p * input_height * input_width + w1p] += + t1lambda * h0lambda * w1lambda * pos2[0]; + pos1[t1p * input_height * input_width + h1p * input_width] += + t1lambda * h1lambda * w0lambda * pos2[0]; + pos1[t1p * input_height * input_width + h1p * input_width + w1p] += + t1lambda * h1lambda * w1lambda * pos2[0]; + pos1 += input_width * input_height * input_depth; + pos2 += output_width * output_height * output_depth; + } + } + } + } +} + +static void upsample_trilinear3d_out_cpu_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size, + bool align_corners) { + AT_CHECK( + output_size.size() == 3, + "It is expected output_size equals to 3, but got size ", + output_size.size()); + + int64_t output_depth = output_size[0]; + int64_t output_height = output_size[1]; + int64_t output_width = output_size[2]; + + int64_t nbatch = input_.size(0); + int64_t channels = input_.size(1); + int64_t input_depth = input_.size(2); + int64_t input_height = input_.size(3); + int64_t input_width = input_.size(4); + + upsample_3d_shape_check( + input_, + Tensor(), + nbatch, + channels, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width); + + auto input = input_.contiguous(); + + output.resize_({nbatch, channels, output_depth, output_height, output_width}); + output.zero_(); + + AT_ASSERT( + input_depth > 0 && input_height > 0 && input_width > 0 && + output_depth > 0 && output_height > 0 && output_width > 0); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "upsample_trilinear3d", [&] { + auto* idata = input.data(); + auto* odata = output.data(); + + upsample_trilinear3d_out_frame( + odata, + idata, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width, + nbatch, + channels, + align_corners); + }); +} + +static void upsample_trilinear3d_backward_out_cpu_template( + Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + AT_CHECK( + output_size.size() == 3, + "It is expected output_size equals to 3, but got size ", + output_size.size()); + + AT_CHECK( + input_size.size() == 5, + "It is expected input_size equals to 5, but got size ", + input_size.size()); + + int64_t output_depth = output_size[0]; + int64_t output_height = output_size[1]; + int64_t output_width = output_size[2]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_depth = input_size[2]; + int64_t input_height = input_size[3]; + int64_t input_width = input_size[4]; + + upsample_3d_shape_check( + Tensor(), + grad_output_, + nbatch, + channels, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width); + + auto grad_output = grad_output_.contiguous(); + + grad_input.resize_( + {nbatch, channels, input_depth, input_height, input_width}); + grad_input.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "upsample_trilinear3d_backward", [&] { + scalar_t* idata = grad_input.data(); + scalar_t* odata = grad_output.data(); + + upsample_trilinear3d_backward_out_frame( + odata, + idata, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width, + nbatch, + channels, + align_corners); + }); +} +} // namespace + +Tensor& upsample_trilinear3d_out_cpu( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + upsample_trilinear3d_out_cpu_template( + output, input, output_size, align_corners); + return output; +} + +Tensor upsample_trilinear3d_cpu( + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + auto output = at::empty({0}, input.options()); + upsample_trilinear3d_out_cpu_template( + output, input, output_size, align_corners); + return output; +} + +Tensor& upsample_trilinear3d_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + upsample_trilinear3d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size, align_corners); + return grad_input; +} + +Tensor upsample_trilinear3d_backward_cpu( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + auto grad_input = at::zeros(input_size, grad_output.options()); + upsample_trilinear3d_backward_out_cpu_template( + grad_input, grad_output, output_size, input_size, align_corners); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu new file mode 100644 index 0000000..5375e33 --- /dev/null +++ b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu @@ -0,0 +1,45 @@ +#include +#include +#include + +namespace at { +namespace native { + +Tensor& upsample_bicubic2d_out_cuda( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_bicubic2d_forward_out( + output, input, output_size, align_corners); +} + +Tensor upsample_bicubic2d_cuda( + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_bicubic2d_forward( + input, output_size, align_corners); +} + +Tensor& upsample_bicubic2d_backward_out_cuda( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_bicubic2d_backward_out( + grad_input, grad_output, output_size, input_size, align_corners); +} + +Tensor upsample_bicubic2d_backward_cuda( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_bicubic2d_backward( + grad_output, output_size, input_size, align_corners); +} + +} // native +} // at diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu new file mode 100644 index 0000000..7c53443 --- /dev/null +++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu @@ -0,0 +1,45 @@ +#include +#include +#include + +namespace at { +namespace native { + +Tensor& upsample_bilinear2d_out_cuda( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_bilinear2d_forward_out( + output, input, output_size, align_corners); +} + +Tensor upsample_bilinear2d_cuda( + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_bilinear2d_forward( + input, output_size, align_corners); +} + +Tensor& upsample_bilinear2d_backward_out_cuda( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_bilinear2d_backward_out( + grad_input, grad_output, output_size, input_size, align_corners); +} + +Tensor upsample_bilinear2d_backward_cuda( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_bilinear2d_backward( + grad_output, output_size, input_size, align_corners); +} + +} // native +} // at diff --git a/aten/src/ATen/native/cuda/UpSampleLinear1d.cu b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu new file mode 100644 index 0000000..eb491f5 --- /dev/null +++ b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu @@ -0,0 +1,45 @@ +#include +#include +#include + +namespace at { +namespace native { + +Tensor& upsample_linear1d_out_cuda( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_linear1d_forward_out( + output, input, output_size, align_corners); +} + +Tensor upsample_linear1d_cuda( + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_linear1d_forward( + input, output_size, align_corners); +} + +Tensor& upsample_linear1d_backward_out_cuda( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_linear1d_backward_out( + grad_input, grad_output, output_size, input_size, align_corners); +} + +Tensor upsample_linear1d_backward_cuda( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_linear1d_backward( + grad_output, output_size, input_size, align_corners); +} + +} // native +} // at diff --git a/aten/src/ATen/native/cuda/UpSampleNearest1d.cu b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu new file mode 100644 index 0000000..bfba5a1 --- /dev/null +++ b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu @@ -0,0 +1,41 @@ +#include +#include +#include + +namespace at { +namespace native { + +Tensor& upsample_nearest1d_out_cuda( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + return at::legacy::th::_thnn_upsample_nearest1d_forward_out( + output, input, output_size); +} + +Tensor upsample_nearest1d_cuda( + const Tensor& input, + IntArrayRef output_size) { + return at::legacy::th::_thnn_upsample_nearest1d_forward( + input, output_size); +} + +Tensor& upsample_nearest1d_backward_out_cuda( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + return at::legacy::th::_thnn_upsample_nearest1d_backward_out( + grad_input, grad_output, output_size, input_size); +} + +Tensor upsample_nearest1d_backward_cuda( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + return at::legacy::th::_thnn_upsample_nearest1d_backward( + grad_output, output_size, input_size); +} + +} // native +} // at diff --git a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu new file mode 100644 index 0000000..83f40f6 --- /dev/null +++ b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu @@ -0,0 +1,41 @@ +#include +#include +#include + +namespace at { +namespace native { + +Tensor& upsample_nearest2d_out_cuda( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + return at::legacy::th::_thnn_upsample_nearest2d_forward_out( + output, input, output_size); +} + +Tensor upsample_nearest2d_cuda( + const Tensor& input, + IntArrayRef output_size) { + return at::legacy::th::_thnn_upsample_nearest2d_forward( + input, output_size); +} + +Tensor& upsample_nearest2d_backward_out_cuda( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + return at::legacy::th::_thnn_upsample_nearest2d_backward_out( + grad_input, grad_output, output_size, input_size); +} + +Tensor upsample_nearest2d_backward_cuda( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + return at::legacy::th::_thnn_upsample_nearest2d_backward( + grad_output, output_size, input_size); +} + +} // native +} // at diff --git a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu new file mode 100644 index 0000000..bb208a5 --- /dev/null +++ b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu @@ -0,0 +1,41 @@ +#include +#include +#include + +namespace at { +namespace native { + +Tensor& upsample_nearest3d_out_cuda( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + return at::legacy::th::_thnn_upsample_nearest3d_forward_out( + output, input, output_size); +} + +Tensor upsample_nearest3d_cuda( + const Tensor& input, + IntArrayRef output_size) { + return at::legacy::th::_thnn_upsample_nearest3d_forward( + input, output_size); +} + +Tensor& upsample_nearest3d_backward_out_cuda( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + return at::legacy::th::_thnn_upsample_nearest3d_backward_out( + grad_input, grad_output, output_size, input_size); +} + +Tensor upsample_nearest3d_backward_cuda( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size) { + return at::legacy::th::_thnn_upsample_nearest3d_backward( + grad_output, output_size, input_size); +} + +} // native +} // at diff --git a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu new file mode 100644 index 0000000..386887f --- /dev/null +++ b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu @@ -0,0 +1,45 @@ +#include +#include +#include + +namespace at { +namespace native { + +Tensor& upsample_trilinear3d_out_cuda( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_trilinear3d_forward_out( + output, input, output_size, align_corners); +} + +Tensor upsample_trilinear3d_cuda( + const Tensor& input, + IntArrayRef output_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_trilinear3d_forward( + input, output_size, align_corners); +} + +Tensor& upsample_trilinear3d_backward_out_cuda( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_trilinear3d_backward_out( + grad_input, grad_output, output_size, input_size, align_corners); +} + +Tensor upsample_trilinear3d_backward_cuda( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners) { + return at::legacy::th::_thnn_upsample_trilinear3d_backward( + grad_output, output_size, input_size, align_corners); +} + +} // native +} // at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index ca8a5ad..d3dd911 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4788,114 +4788,198 @@ - func: upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_linear1d_out_cpu + CUDA: upsample_linear1d_out_cuda - func: upsample_linear1d(Tensor self, int[1] output_size, bool align_corners) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_linear1d_cpu + CUDA: upsample_linear1d_cuda - func: upsample_linear1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size, bool align_corners, *, Tensor(a!) grad_input) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_linear1d_backward_out_cpu + CUDA: upsample_linear1d_backward_out_cuda - func: upsample_linear1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size, bool align_corners) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_linear1d_backward_cpu + CUDA: upsample_linear1d_backward_cuda - func: upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_bilinear2d_out_cpu + CUDA: upsample_bilinear2d_out_cuda - func: upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_bilinear2d_cpu + CUDA: upsample_bilinear2d_cuda - func: upsample_bilinear2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, *, Tensor(a!) grad_input) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_bilinear2d_backward_out_cpu + CUDA: upsample_bilinear2d_backward_out_cuda - func: upsample_bilinear2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_bilinear2d_backward_cpu + CUDA: upsample_bilinear2d_backward_cuda - func: upsample_bicubic2d(Tensor self, int[2] output_size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_bicubic2d_out_cpu + CUDA: upsample_bicubic2d_out_cuda - func: upsample_bicubic2d(Tensor self, int[2] output_size, bool align_corners) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_bicubic2d_cpu + CUDA: upsample_bicubic2d_cuda - func: upsample_bicubic2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, *, Tensor(a!) grad_input) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_bicubic2d_backward_out_cpu + CUDA: upsample_bicubic2d_backward_out_cuda - func: upsample_bicubic2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_bicubic2d_backward_cpu + CUDA: upsample_bicubic2d_backward_cuda - func: upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_trilinear3d_out_cpu + CUDA: upsample_trilinear3d_out_cuda - func: upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_trilinear3d_cpu + CUDA: upsample_trilinear3d_cuda - func: upsample_trilinear3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size, bool align_corners, *, Tensor(a!) grad_input) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_trilinear3d_backward_out_cpu + CUDA: upsample_trilinear3d_backward_out_cuda - func: upsample_trilinear3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size, bool align_corners) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_trilinear3d_backward_cpu + CUDA: upsample_trilinear3d_backward_cuda - func: upsample_nearest1d(Tensor self, int[1] output_size, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest1d_out_cpu + CUDA: upsample_nearest1d_out_cuda - func: upsample_nearest1d(Tensor self, int[1] output_size) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest1d_cpu + CUDA: upsample_nearest1d_cuda - func: upsample_nearest1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size, *, Tensor(a!) grad_input) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest1d_backward_out_cpu + CUDA: upsample_nearest1d_backward_out_cuda - func: upsample_nearest1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest1d_backward_cpu + CUDA: upsample_nearest1d_backward_cuda - func: upsample_nearest2d(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest2d_out_cpu + CUDA: upsample_nearest2d_out_cuda - func: upsample_nearest2d(Tensor self, int[2] output_size) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest2d_cpu + CUDA: upsample_nearest2d_cuda - func: upsample_nearest2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, *, Tensor(a!) grad_input) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest2d_backward_out_cpu + CUDA: upsample_nearest2d_backward_out_cuda - func: upsample_nearest2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest2d_backward_cpu + CUDA: upsample_nearest2d_backward_cuda - func: upsample_nearest3d(Tensor self, int[3] output_size, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest3d_out_cpu + CUDA: upsample_nearest3d_out_cuda - func: upsample_nearest3d(Tensor self, int[3] output_size) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest3d_cpu + CUDA: upsample_nearest3d_cuda - func: upsample_nearest3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size, *, Tensor(a!) grad_input) -> Tensor(a!) matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest3d_backward_out_cpu + CUDA: upsample_nearest3d_backward_out_cuda - func: upsample_nearest3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size) -> Tensor matches_jit_signature: True python_module: nn + dispatch: + CPU: upsample_nearest3d_backward_cpu + CUDA: upsample_nearest3d_backward_cuda - func: sigmoid_backward(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) matches_jit_signature: True diff --git a/aten/src/THNN/generic/SpatialUpSamplingBicubic.c b/aten/src/THNN/generic/SpatialUpSamplingBicubic.c index 5faf304..a81d6bc 100644 --- a/aten/src/THNN/generic/SpatialUpSamplingBicubic.c +++ b/aten/src/THNN/generic/SpatialUpSamplingBicubic.c @@ -2,225 +2,28 @@ #define TH_GENERIC_FILE "THNN/generic/SpatialUpSamplingBicubic.c" #else -#include - -static inline void THNN_(SpatialUpSamplingBicubic_shapeCheck) - (THTensor *input, THTensor *gradOutput, - int nBatch, int nChannels, - int input_height, int input_width, - int output_height, int output_width) { - THArgCheck(input_height > 0 && input_width > 0 - && output_height > 0 && output_width > 0, 2, - "input and output sizes should be greater than 0," - " but got input (H: %d, W: %d) output (H: %d, W: %d)", - input_height, input_width, output_height, output_width); - if (input != NULL) { - THNN_ARGCHECK(!input->is_empty() && input->dim() == 4, 2, input, - "non-empty 4D input tensor expected but got: %s"); - } - - if (gradOutput != NULL) { - THNN_CHECK_DIM_SIZE(gradOutput, 4, 0, nBatch); - THNN_CHECK_DIM_SIZE(gradOutput, 4, 1, nChannels); - THNN_CHECK_DIM_SIZE(gradOutput, 4, 2, output_height); - THNN_CHECK_DIM_SIZE(gradOutput, 4, 3, output_width); - } -} - void THNN_(SpatialUpSamplingBicubic_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, - int output_height, - int output_width, + THNNState* state, + THTensor* input, + THTensor* output, + int outputHeight, + int outputWidth, bool align_corners) { - - const int nbatch = THTensor_(size)(input, 0); - const int channels = THTensor_(size)(input, 1); - const int input_height = THTensor_(size)(input, 2); - const int input_width = THTensor_(size)(input, 3); - - THNN_(SpatialUpSamplingBicubic_shapeCheck) - (input, NULL, - nbatch, channels, - input_height, input_width, - output_height, output_width); - - input = THTensor_(newContiguous)(input); - THTensor_(resize4d)(output, - THTensor_(size)(input, 0), - THTensor_(size)(input, 1), - output_height, output_width); - THTensor_(zero)(output); - scalar_t *idata = input->data(); - scalar_t *odata = output->data(); - - // Special case: input/output same size, just copy - if (input_height == output_height && input_width == output_width) { - for (int output_y = 0; output_y < output_height; output_y++) { - for (int output_x = 0; output_x < output_width; output_x++) { - const scalar_t* in = &idata[output_y * input_width + output_x]; - scalar_t* out = &odata[output_y * output_width + output_x]; - for (int c = 0; c < channels; ++c) { - out[0] = in[0]; - in += input_width * input_height; - out += output_width * output_height; - } - } - } - c10::raw::intrusive_ptr::decref(input); - return; - } - - // Bicubic interpolation - const accreal height_scale = linear_upsampling_compute_scale( - input_height, - output_height, - align_corners); - const accreal width_scale = linear_upsampling_compute_scale( - input_width, - output_width, - align_corners); - - for (int output_y = 0; output_y < output_height; output_y++) { - for (int output_x = 0; output_x < output_width; output_x++) { - scalar_t* in = idata; - scalar_t* out = odata; - - const scalar_t real_x = width_scale * output_x; - int input_x = real_x; - const scalar_t t_x = real_x - input_x; - - const scalar_t real_y = height_scale * output_y; - int input_y = real_y; - const scalar_t t_y = real_y - input_y; - - for (int c = 0; c < channels * nbatch; c++) { - scalar_t coefficients[4]; - - // Interpolate 4 times in the x direction - for (int i = 0; i < 4; i++) { - coefficients[i] = cubic_interp1d( - upsampling_get_value_bounded( - in, input_width, input_height, input_x - 1, input_y - 1 + i), - upsampling_get_value_bounded( - in, input_width, input_height, input_x + 0, input_y - 1 + i), - upsampling_get_value_bounded( - in, input_width, input_height, input_x + 1, input_y - 1 + i), - upsampling_get_value_bounded( - in, input_width, input_height, input_x + 2, input_y - 1 + i), - t_x - ); - } - - // Interpolate in the y direction using x interpolations - out[output_y * output_width + output_x] = cubic_interp1d( - coefficients[0], - coefficients[1], - coefficients[2], - coefficients[3], - t_y - ); - - // Move to next channel - in += input_width * input_height; - out += output_width * output_height; - } - } - } - - c10::raw::intrusive_ptr::decref(input); + AT_ERROR("This function is deprecated, please use it from ATen."); } void THNN_(SpatialUpSamplingBicubic_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, + THNNState* state, + THTensor* gradOutput, + THTensor* gradInput, int nbatch, int channels, - int input_height, - int input_width, - int output_height, - int output_width, - bool align_corners){ - - THNN_(SpatialUpSamplingBicubic_shapeCheck) - (NULL, gradOutput, - nbatch, channels, - input_height, input_width, - output_height, output_width); - - THTensor_(resize4d)(gradInput, nbatch, channels, input_height, input_width); - THTensor_(zero)(gradInput); - - gradOutput = THTensor_(newContiguous)(gradOutput); - scalar_t *idata = gradInput->data(); - scalar_t *odata = gradOutput->data(); - channels = nbatch * channels; - - // Special case: input/output same size, just copy - if (input_height == output_height && input_width == output_width) { - for (int output_y = 0; output_y < output_height; output_y++) { - for (int output_x = 0; output_x < output_width; output_x++) { - scalar_t* in = &idata[output_y * input_width + output_x]; - scalar_t* out = &odata[output_y * output_width + output_x]; - for (int c = 0; c < channels; ++c) { - in[0] = out[0]; - in += input_width * input_height; - out += output_width * output_height; - } - } - } - c10::raw::intrusive_ptr::decref(gradOutput); - return; - } - - const accreal height_scale = linear_upsampling_compute_scale( - input_height, output_height, align_corners); - const accreal width_scale = linear_upsampling_compute_scale( - input_width, output_width, align_corners); - - for (int output_y = 0; output_y < output_height; output_y++) { - for (int output_x = 0; output_x < output_width; output_x++) { - scalar_t* in = idata; - scalar_t* out = odata; - - scalar_t real_x = width_scale * output_x; - int input_x = real_x; - scalar_t t_x = real_x - input_x; - - scalar_t real_y = height_scale * output_y; - int input_y = real_y; - scalar_t t_y = real_y - input_y; - - scalar_t x_coeffs[4]; - scalar_t y_coeffs[4]; - - get_cubic_upsampling_coefficients(x_coeffs, t_x); - get_cubic_upsampling_coefficients(y_coeffs, t_y); - - - for (int c = 0; c < channels; c++) { - scalar_t out_value = out[output_y * output_width + output_x]; - - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - upsampling_increment_value_bounded(in, - input_width, - input_height, - input_x - 1 + i, - input_y - 1 + j, - out_value * y_coeffs[j] * x_coeffs[i]); - } - } - - in += input_width * input_height; - out += output_width * output_height; - } - } - } - - c10::raw::intrusive_ptr::decref(gradOutput); + int inputHeight, + int inputWidth, + int outputHeight, + int outputWidth, + bool align_corners) { + AT_ERROR("This function is deprecated, please use it from ATen."); } #endif diff --git a/aten/src/THNN/generic/SpatialUpSamplingBilinear.c b/aten/src/THNN/generic/SpatialUpSamplingBilinear.c index 647f52a..0dc6463 100644 --- a/aten/src/THNN/generic/SpatialUpSamplingBilinear.c +++ b/aten/src/THNN/generic/SpatialUpSamplingBilinear.c @@ -1,180 +1,29 @@ -// Adapted from interp.cpp from Caffe util by Pauline Luc -// Originally developed by George Papandreou - #ifndef TH_GENERIC_FILE #define TH_GENERIC_FILE "THNN/generic/SpatialUpSamplingBilinear.c" #else -#include - -static inline void THNN_(SpatialUpSamplingBilinear_shapeCheck) - (THTensor *input, THTensor *gradOutput, - int nBatch, int nChannels, - int inputHeight, int inputWidth, - int outputHeight, int outputWidth) { - THArgCheck(inputHeight > 0 && inputWidth > 0 - && outputHeight > 0 && outputWidth > 0, 2, - "input and output sizes should be greater than 0," - " but got input (H: %d, W: %d) output (H: %d, W: %d)", - inputHeight, inputWidth, outputHeight, outputWidth); - if (input != NULL) { - THNN_ARGCHECK(!input->is_empty() && input->dim() == 4, 2, input, - "non-empty 4D input tensor expected but got: %s"); - } - - if (gradOutput != NULL) { - THNN_CHECK_DIM_SIZE(gradOutput, 4, 0, nBatch); - THNN_CHECK_DIM_SIZE(gradOutput, 4, 1, nChannels); - THNN_CHECK_DIM_SIZE(gradOutput, 4, 2, outputHeight); - THNN_CHECK_DIM_SIZE(gradOutput, 4, 3, outputWidth); - } -} - void THNN_(SpatialUpSamplingBilinear_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, + THNNState* state, + THTensor* input, + THTensor* output, int outputHeight, int outputWidth, - bool align_corners){ - - int nbatch = THTensor_(size)(input, 0); - int channels = THTensor_(size)(input, 1); - int inputHeight = THTensor_(size)(input, 2); - int inputWidth = THTensor_(size)(input, 3); - - THNN_(SpatialUpSamplingBilinear_shapeCheck) - (input, NULL, - nbatch, channels, - inputHeight, inputWidth, - outputHeight, outputWidth); - - input = THTensor_(newContiguous)(input); - THTensor_(resize4d)(output, - THTensor_(size)(input, 0), - THTensor_(size)(input, 1), - outputHeight, outputWidth); - THTensor_(zero)(output); - scalar_t *idata = input->data(); - scalar_t *odata = output->data(); - channels = nbatch * channels; - THAssert(inputHeight > 0 && inputWidth > 0 && outputHeight > 0 && outputWidth > 0); - // special case: just copy - if (inputHeight == outputHeight && inputWidth == outputWidth) { - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = h2; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - const scalar_t* pos1 = &idata[h1 * inputWidth + w1]; - scalar_t* pos2 = &odata[h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = pos1[0]; - pos1 += inputWidth * inputHeight; - pos2 += outputWidth * outputHeight; - } - } - } - c10::raw::intrusive_ptr::decref(input); - return; - } - const accreal rheight = linear_upsampling_compute_scale(inputHeight, outputHeight, align_corners); - const accreal rwidth = linear_upsampling_compute_scale(inputWidth, outputWidth, align_corners); - for (int h2 = 0; h2 < outputHeight; ++h2) { - const accreal h1r = linear_upsampling_compute_source_index(rheight, h2, align_corners); - const int h1 = h1r; - const int h1p = (h1 < inputHeight - 1) ? 1 : 0; - const scalar_t h1lambda = h1r - h1; - const scalar_t h0lambda = (scalar_t)1. - h1lambda; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const accreal w1r = linear_upsampling_compute_source_index(rwidth, w2, align_corners); - const int w1 = w1r; - const int w1p = (w1 < inputWidth - 1) ? 1 : 0; - const scalar_t w1lambda = w1r - w1; - const scalar_t w0lambda = (scalar_t)1. - w1lambda; - const scalar_t* pos1 = &idata[h1 * inputWidth + w1]; - scalar_t* pos2 = &odata[h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = h0lambda * (w0lambda * pos1[0]+ w1lambda * pos1[w1p]) - + h1lambda * (w0lambda * pos1[h1p * inputWidth] - + w1lambda * pos1[h1p * inputWidth + w1p]); - pos1 += inputWidth * inputHeight; - pos2 += outputWidth * outputHeight; - } - } - } - c10::raw::intrusive_ptr::decref(input); + bool align_corners) { + AT_ERROR("This function is deprecated, please use it from ATen."); } void THNN_(SpatialUpSamplingBilinear_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, + THNNState* state, + THTensor* gradOutput, + THTensor* gradInput, int nbatch, int channels, int inputHeight, int inputWidth, int outputHeight, int outputWidth, - bool align_corners){ - - THNN_(SpatialUpSamplingBilinear_shapeCheck) - (NULL, gradOutput, - nbatch, channels, - inputHeight, inputWidth, - outputHeight, outputWidth); - - THTensor_(resize4d)(gradInput, nbatch, channels, inputHeight, inputWidth); - THTensor_(zero)(gradInput); - gradOutput = THTensor_(newContiguous)(gradOutput); - scalar_t *data1 = gradInput->data(); - scalar_t *data2 = gradOutput->data(); - channels = nbatch * channels; - - // special case: same-size matching grids - if (inputHeight == outputHeight && inputWidth == outputWidth) { - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = h2; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - scalar_t* pos1 = &data1[h1 * inputWidth + w1]; - const scalar_t* pos2 = &data2[h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += pos2[0]; - pos1 += inputWidth * inputHeight; - pos2 += outputWidth * outputHeight; - } - } - } - c10::raw::intrusive_ptr::decref(gradOutput); - return; - } - const accreal rheight = linear_upsampling_compute_scale(inputHeight, outputHeight, align_corners); - const accreal rwidth = linear_upsampling_compute_scale(inputWidth, outputWidth, align_corners); - for (int h2 = 0; h2 < outputHeight; ++h2) { - const accreal h1r = linear_upsampling_compute_source_index(rheight, h2, align_corners); - const int h1 = h1r; - const int h1p = (h1 < inputHeight - 1) ? 1 : 0; - const scalar_t h1lambda = h1r - h1; - const scalar_t h0lambda = (scalar_t)1. - h1lambda; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const accreal w1r = linear_upsampling_compute_source_index(rwidth, w2, align_corners); - const int w1 = w1r; - const int w1p = (w1 < inputWidth - 1) ? 1 : 0; - const scalar_t w1lambda = w1r - w1; - const scalar_t w0lambda = (scalar_t)1. - w1lambda; - scalar_t* pos1 = &data1[h1 * inputWidth + w1]; - const scalar_t* pos2 = &data2[h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += h0lambda * w0lambda * pos2[0]; - pos1[w1p] += h0lambda * w1lambda * pos2[0]; - pos1[h1p * inputWidth] += h1lambda * w0lambda * pos2[0]; - pos1[h1p * inputWidth + w1p] += h1lambda * w1lambda * pos2[0]; - pos1 += inputWidth * inputHeight; - pos2 += outputWidth * outputHeight; - } - } - } - c10::raw::intrusive_ptr::decref(gradOutput); + bool align_corners) { + AT_ERROR("This function is deprecated, please use it from ATen."); } #endif diff --git a/aten/src/THNN/generic/SpatialUpSamplingNearest.c b/aten/src/THNN/generic/SpatialUpSamplingNearest.c index 8fd4973..82f8237 100644 --- a/aten/src/THNN/generic/SpatialUpSamplingNearest.c +++ b/aten/src/THNN/generic/SpatialUpSamplingNearest.c @@ -2,153 +2,26 @@ #define TH_GENERIC_FILE "THNN/generic/SpatialUpSamplingNearest.c" #else -#include - -static inline void THNN_(SpatialUpSamplingNearest_shapeCheck) - (THTensor *input, THTensor *gradOutput, - int nBatch, int nChannels, - int inputHeight, int inputWidth, - int outputHeight, int outputWidth) { - THArgCheck(inputHeight > 0 && inputWidth > 0 - && outputHeight > 0 && outputWidth > 0, 2, - "input and output sizes should be greater than 0," - " but got input (H: %d, W: %d) output (H: %d, W: %d)", - inputHeight, inputWidth, outputHeight, outputWidth); - if (input != NULL) { - THNN_ARGCHECK(THTensor_nDimensionLegacyAll(input) == 4, 2, input, - "4D input tensor expected but got: %s"); - } - - if (gradOutput != NULL) { - THNN_CHECK_DIM_SIZE(gradOutput, 4, 0, nBatch); - THNN_CHECK_DIM_SIZE(gradOutput, 4, 1, nChannels); - THNN_CHECK_DIM_SIZE(gradOutput, 4, 2, outputHeight); - THNN_CHECK_DIM_SIZE(gradOutput, 4, 3, outputWidth); - } -} - - void THNN_(SpatialUpSamplingNearest_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, + THNNState* state, + THTensor* input, + THTensor* output, int outputHeight, - int outputWidth) -{ - int nbatch = THTensor_(size)(input, 0); - int channels = THTensor_(size)(input, 1); - int inputHeight = THTensor_(size)(input, 2); - int inputWidth = THTensor_(size)(input, 3); - const float height_scale = (float) inputHeight / (float) outputHeight; - const float width_scale = (float) inputWidth / (float) outputWidth; - - THNN_(SpatialUpSamplingNearest_shapeCheck)(input, NULL, nbatch, channels, - inputHeight, inputWidth, outputHeight, outputWidth); - - THTensor_(resize4d)(output, - THTensor_(size)(input, 0), - THTensor_(size)(input, 1), - outputHeight, - outputWidth); - channels = channels * nbatch; - - THAssert(inputWidth > 0 && outputWidth > 0); - - input = THTensor_(newContiguous)(input); - THTensor_(zero)(output); - scalar_t *idata = input->data(); - scalar_t *odata = output->data(); - - // special case: just copy - if (inputHeight == outputHeight && inputWidth == outputWidth) { - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = h2; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - const scalar_t* pos1 = &idata[h1 * inputWidth + w1]; - scalar_t* pos2 = &odata[h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = pos1[0]; - pos1 += inputHeight * inputWidth; - pos2 += outputHeight * outputWidth; - } - } - } - c10::raw::intrusive_ptr::decref(input); - return; - } - - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = nearest_neighbor_compute_source_index(height_scale, h2, inputHeight); - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = nearest_neighbor_compute_source_index(width_scale, w2, inputWidth); - const scalar_t* pos1 = &idata[h1 * inputWidth + w1]; - scalar_t* pos2 = &odata[h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = pos1[0]; - pos1 += inputHeight * inputWidth; - pos2 += outputHeight * outputWidth; - } - } - } - c10::raw::intrusive_ptr::decref(input); + int outputWidth) { + AT_ERROR("This function is deprecated, please use it from ATen."); } void THNN_(SpatialUpSamplingNearest_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, + THNNState* state, + THTensor* gradOutput, + THTensor* gradInput, int nbatch, int channels, int inputHeight, int inputWidth, int outputHeight, - int outputWidth) -{ - THNN_(SpatialUpSamplingNearest_shapeCheck)(NULL, gradOutput, nbatch, channels, - inputHeight, inputWidth, outputHeight, outputWidth); - THTensor_(resize4d)(gradInput, nbatch, channels, inputHeight, inputWidth); - THTensor_(zero)(gradInput); - gradOutput = THTensor_(newContiguous)(gradOutput); - scalar_t *idata = gradInput->data(); - scalar_t *odata = gradOutput->data(); - channels = nbatch * channels; - const float height_scale = (float) inputHeight / (float)outputHeight; - const float width_scale = (float) inputWidth / (float)outputWidth; - // special case: just copy - if (inputHeight == outputHeight && inputWidth == outputWidth) { - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = h2; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - scalar_t* pos1 = &idata[h1 * inputWidth + w1]; - const scalar_t* pos2 = &odata[h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] = pos2[0]; - pos1 += inputHeight * inputWidth; - pos2 += outputHeight * outputWidth; - } - } - } - c10::raw::intrusive_ptr::decref(gradOutput); - return; - } - - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = nearest_neighbor_compute_source_index(height_scale, h2, inputHeight); - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = nearest_neighbor_compute_source_index(width_scale, w2, inputWidth); - scalar_t* pos1 = &idata[h1 * inputWidth + w1]; - const scalar_t* pos2 = &odata[h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += pos2[0]; - pos1 += inputHeight * inputWidth; - pos2 += outputHeight * outputWidth; - } - } - } - - c10::raw::intrusive_ptr::decref(gradOutput); + int outputWidth) { + AT_ERROR("This function is deprecated, please use it from ATen."); } #endif diff --git a/aten/src/THNN/generic/TemporalUpSamplingLinear.c b/aten/src/THNN/generic/TemporalUpSamplingLinear.c index 56a2e13..6968054 100644 --- a/aten/src/THNN/generic/TemporalUpSamplingLinear.c +++ b/aten/src/THNN/generic/TemporalUpSamplingLinear.c @@ -5,143 +5,25 @@ #define TH_GENERIC_FILE "THNN/generic/TemporalUpSamplingLinear.c" #else -#include - -static inline void THNN_(TemporalUpSamplingLinear_shapeCheck) - (THTensor *input, THTensor *gradOutput, - int nBatch, int nChannels, - int inputWidth, int outputWidth) { - THArgCheck(inputWidth > 0 && outputWidth > 0, 2, - "input and output sizes should be greater than 0," - " but got input (W: %d) output (W: %d)", - inputWidth, outputWidth); - if (input != NULL) { - THNN_ARGCHECK(!input->is_empty() && input->dim() == 3, 2, input, - "non-empty 3D input tensor expected but got: %s"); - } - - if (gradOutput != NULL) { - THNN_CHECK_DIM_SIZE(gradOutput, 3, 0, nBatch); - THNN_CHECK_DIM_SIZE(gradOutput, 3, 1, nChannels); - THNN_CHECK_DIM_SIZE(gradOutput, 3, 2, outputWidth); - } -} - void THNN_(TemporalUpSamplingLinear_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, + THNNState* state, + THTensor* input, + THTensor* output, int outputWidth, - bool align_corners){ - - int nbatch = THTensor_(size)(input, 0); - int channels = THTensor_(size)(input, 1); - int inputWidth = THTensor_(size)(input, 2); - - THNN_(TemporalUpSamplingLinear_shapeCheck) - (input, NULL, - nbatch, channels, - inputWidth, outputWidth); - - input = THTensor_(newContiguous)(input); - THTensor_(resize3d)(output, - THTensor_(size)(input, 0), - THTensor_(size)(input, 1), - outputWidth); - THTensor_(zero)(output); - scalar_t *idata = input->data(); - scalar_t *odata = output->data(); - channels = nbatch * channels; - THAssert(inputWidth > 0 && outputWidth > 0); - // special case: just copy - if (inputWidth == outputWidth) { - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - const scalar_t* pos1 = &idata[w1]; - scalar_t* pos2 = &odata[w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = pos1[0]; - pos1 += inputWidth; - pos2 += outputWidth; - } - } - c10::raw::intrusive_ptr::decref(input); - return; - } - const accreal rwidth = linear_upsampling_compute_scale(inputWidth, outputWidth, align_corners); - for (int w2 = 0; w2 < outputWidth; ++w2) { - const accreal w1r = linear_upsampling_compute_source_index(rwidth, w2, align_corners); - const int w1 = w1r; - const int w1p = (w1 < inputWidth - 1) ? 1 : 0; - const scalar_t w1lambda = w1r - w1; - const scalar_t w0lambda = (scalar_t)1. - w1lambda; - const scalar_t* pos1 = &idata[w1]; - // index w2 is interpolated by idata[w1] and (itself or idata[w1 + 1]) - scalar_t* pos2 = &odata[w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = w0lambda * pos1[0] + w1lambda * pos1[w1p]; - pos1 += inputWidth; - pos2 += outputWidth; - } - } - c10::raw::intrusive_ptr::decref(input); + bool align_corners) { + AT_ERROR("This function is deprecated, please use it from ATen."); } void THNN_(TemporalUpSamplingLinear_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, + THNNState* state, + THTensor* gradOutput, + THTensor* gradInput, int nbatch, int channels, int inputWidth, int outputWidth, - bool align_corners){ - - THNN_(TemporalUpSamplingLinear_shapeCheck) - (NULL, gradOutput, - nbatch, channels, - inputWidth, - outputWidth); - - THTensor_(resize3d)(gradInput, nbatch, channels, inputWidth); - THTensor_(zero)(gradInput); - gradOutput = THTensor_(newContiguous)(gradOutput); - scalar_t *data1 = gradInput->data(); - scalar_t *data2 = gradOutput->data(); - channels = nbatch * channels; - - // special case: same-size matching grids - if (inputWidth == outputWidth) { - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - scalar_t* pos1 = &data1[w1]; - const scalar_t* pos2 = &data2[w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += pos2[0]; - pos1 += inputWidth; - pos2 += outputWidth; - } - } - c10::raw::intrusive_ptr::decref(gradOutput); - return; - } - const accreal rwidth = linear_upsampling_compute_scale(inputWidth, outputWidth, align_corners); - for (int w2 = 0; w2 < outputWidth; ++w2) { - const accreal w1r = linear_upsampling_compute_source_index(rwidth, w2, align_corners); - const int w1 = w1r; - const int w1p = (w1 < inputWidth - 1) ? 1 : 0; - const scalar_t w1lambda = w1r - w1; - const scalar_t w0lambda = (scalar_t)1. - w1lambda; - scalar_t* pos1 = &data1[w1]; - const scalar_t* pos2 = &data2[w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += w0lambda * pos2[0]; - pos1[w1p] += w1lambda * pos2[0]; - pos1 += inputWidth; - pos2 += outputWidth; - } - } - c10::raw::intrusive_ptr::decref(gradOutput); + bool align_corners) { + AT_ERROR("This function is deprecated, please use it from ATen."); } #endif diff --git a/aten/src/THNN/generic/TemporalUpSamplingNearest.c b/aten/src/THNN/generic/TemporalUpSamplingNearest.c index 8251266..0d3fca5 100644 --- a/aten/src/THNN/generic/TemporalUpSamplingNearest.c +++ b/aten/src/THNN/generic/TemporalUpSamplingNearest.c @@ -2,129 +2,23 @@ #define TH_GENERIC_FILE "THNN/generic/TemporalUpSamplingNearest.c" #else -#include - -static inline void THNN_(TemporalUpSamplingNearest_shapeCheck) - (THTensor *input, THTensor *gradOutput, - int nBatch, int nChannels, - int inputWidth, int outputWidth) { - THArgCheck(inputWidth > 0 && outputWidth > 0, 2, - "input and output sizes should be greater than 0," - " but got input (W: %d) output (W: %d)", - inputWidth, outputWidth); - if (input != NULL) { - THNN_ARGCHECK(THTensor_nDimensionLegacyAll(input) == 3, 2, input, - "3D input tensor expected but got: %s"); - } - - if (gradOutput != NULL) { - THNN_CHECK_DIM_SIZE(gradOutput, 3, 0, nBatch); - THNN_CHECK_DIM_SIZE(gradOutput, 3, 1, nChannels); - THNN_CHECK_DIM_SIZE(gradOutput, 3, 2, outputWidth); - } -} - void THNN_(TemporalUpSamplingNearest_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, - int outputWidth) -{ - int nbatch = THTensor_(size)(input, 0); - int channels = THTensor_(size)(input, 1); - int inputWidth = THTensor_(size)(input, 2); - const float scale = (float) inputWidth / (float)outputWidth; - - THNN_(TemporalUpSamplingNearest_shapeCheck)(input, NULL, nbatch, channels, inputWidth, outputWidth); - - THTensor_(resize3d)(output, - THTensor_(size)(input, 0), - THTensor_(size)(input, 1), - outputWidth); - channels = channels * nbatch; - - THAssert(inputWidth > 0 && outputWidth > 0); - - input = THTensor_(newContiguous)(input); - THTensor_(zero)(output); - scalar_t *idata = input->data(); - scalar_t *odata = output->data(); - - // special case: just copy - if (inputWidth == outputWidth) { - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - const scalar_t* pos1 = &idata[w1]; - scalar_t* pos2 = &odata[w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = pos1[0]; - pos1 += inputWidth; - pos2 += outputWidth; - } - } - c10::raw::intrusive_ptr::decref(input); - return; - } - - for (int w2 = 0; w2 < outputWidth; ++w2) { - const accreal src_x = nearest_neighbor_compute_source_index(scale, w2, inputWidth); - const int w1 = src_x; - const scalar_t* pos1 = &idata[w1]; - scalar_t* pos2 = &odata[w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = pos1[0]; - pos1 += inputWidth; - pos2 += outputWidth; - } - } - c10::raw::intrusive_ptr::decref(input); + THNNState* state, + THTensor* input, + THTensor* output, + int outputWidth) { + AT_ERROR("This function is deprecated, please use it from ATen."); } void THNN_(TemporalUpSamplingNearest_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, + THNNState* state, + THTensor* gradOutput, + THTensor* gradInput, int nbatch, int channels, int inputWidth, - int outputWidth) -{ - THNN_(TemporalUpSamplingNearest_shapeCheck)(NULL, gradOutput, nbatch, channels, inputWidth, outputWidth); - THTensor_(resize3d)(gradInput, nbatch, channels, inputWidth); - THTensor_(zero)(gradInput); - gradOutput = THTensor_(newContiguous)(gradOutput); - scalar_t *data1 = gradInput->data(); - scalar_t *data2 = gradOutput->data(); - channels = nbatch * channels; - const float scale = (float) inputWidth / (float)outputWidth; - - // special case: same-size matching grids - if (inputWidth == outputWidth) { - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - scalar_t* pos1 = &data1[w1]; - const scalar_t* pos2 = &data2[w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += pos2[0]; - pos1 += inputWidth; - pos2 += outputWidth; - } - } - c10::raw::intrusive_ptr::decref(gradOutput); - return; - } - - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = nearest_neighbor_compute_source_index(scale, w2, inputWidth); - scalar_t* pos1 = &data1[w1]; - const scalar_t* pos2 = &data2[w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += pos2[0]; - pos1 += inputWidth; - pos2 += outputWidth; - } - } - c10::raw::intrusive_ptr::decref(gradOutput); + int outputWidth) { + AT_ERROR("This function is deprecated, please use it from ATen."); } #endif diff --git a/aten/src/THNN/generic/VolumetricUpSamplingNearest.c b/aten/src/THNN/generic/VolumetricUpSamplingNearest.c index 3e83167..f4ff944 100644 --- a/aten/src/THNN/generic/VolumetricUpSamplingNearest.c +++ b/aten/src/THNN/generic/VolumetricUpSamplingNearest.c @@ -2,112 +2,20 @@ #define TH_GENERIC_FILE "THNN/generic/VolumetricUpSamplingNearest.c" #else -#include - -static inline void THNN_(VolumetricUpSamplingNearest_shapeCheck) - (THTensor *input, THTensor *gradOutput, - int nBatch, int nChannels, - int inputDepth, int inputHeight, int inputWidth, - int outputDepth, int outputHeight, int outputWidth) { - THArgCheck(inputDepth > 0 && inputHeight > 0 && inputWidth > 0 - && outputDepth > 0 && outputHeight > 0 && outputWidth > 0, 2, - "input and output sizes should be greater than 0," - " but got input (D: %d, H: %d, W: %d) output (D: %d, H: %d, W: %d)", - inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth); - if (input != NULL) { - THNN_ARGCHECK(THTensor_nDimensionLegacyAll(input) == 5, 2, input, - "5D input tensor expected but got: %s"); - } - - if (gradOutput != NULL) { - THNN_CHECK_DIM_SIZE(gradOutput, 5, 0, nBatch); - THNN_CHECK_DIM_SIZE(gradOutput, 5, 1, nChannels); - THNN_CHECK_DIM_SIZE(gradOutput, 5, 2, outputDepth); - THNN_CHECK_DIM_SIZE(gradOutput, 5, 3, outputHeight); - THNN_CHECK_DIM_SIZE(gradOutput, 5, 4, outputWidth); - } -} - - void THNN_(VolumetricUpSamplingNearest_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, + THNNState* state, + THTensor* input, + THTensor* output, int outputDepth, int outputHeight, - int outputWidth) -{ - int nbatch = THTensor_(size)(input, 0); - int channels = THTensor_(size)(input, 1); - int inputDepth = THTensor_(size)(input, 2); - int inputHeight = THTensor_(size)(input, 3); - int inputWidth = THTensor_(size)(input, 4); - const float depth_scale = (float) inputDepth / (float) outputDepth; - const float height_scale = (float) inputHeight / (float)outputHeight; - const float width_scale = (float) inputWidth / (float)outputWidth; - - THNN_(VolumetricUpSamplingNearest_shapeCheck)(input, NULL, nbatch, channels, inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth); - - THTensor_(resize5d)(output, - THTensor_(size)(input, 0), - THTensor_(size)(input, 1), - outputDepth, - outputHeight, - outputWidth); - channels = channels * nbatch; - - THAssert(inputDepth > 0 && inputHeight > 0 && inputWidth > 0 && outputDepth > 0 && outputHeight > 0 && outputWidth > 0); - - input = THTensor_(newContiguous)(input); - THTensor_(zero)(output); - scalar_t *idata = input->data(); - scalar_t *odata = output->data(); - - // special case: just copy - if (inputDepth == outputDepth && inputHeight == outputHeight && inputWidth == outputWidth) { - for (int d2 = 0; d2 < outputDepth; ++d2) { - const int d1 = d2; - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = h2; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - const scalar_t* pos1 = &idata[d1 * inputHeight * inputWidth + h1 * inputWidth + w1]; - scalar_t* pos2 = &odata[d2 * outputHeight * outputWidth + h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = pos1[0]; - pos1 += inputDepth * inputHeight * inputWidth; - pos2 += outputDepth * outputHeight * outputWidth; - } - } - } - } - c10::raw::intrusive_ptr::decref(input); - return; - } - - for (int d2 = 0; d2 < outputDepth; ++d2) { - const int d1 = nearest_neighbor_compute_source_index(depth_scale, d2, inputDepth); - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = nearest_neighbor_compute_source_index(height_scale, h2, inputHeight); - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = nearest_neighbor_compute_source_index(width_scale, w2, inputWidth); - const scalar_t* pos1 = &idata[d1 * inputHeight * inputWidth + h1 * inputWidth + w1]; - scalar_t* pos2 = &odata[d2 * outputHeight * outputWidth + h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = pos1[0]; - pos1 += inputDepth * inputHeight * inputWidth; - pos2 += outputDepth * outputHeight * outputWidth; - } - } - } - } - c10::raw::intrusive_ptr::decref(input); + int outputWidth) { + AT_ERROR("This function is deprecated, please use it from ATen."); } void THNN_(VolumetricUpSamplingNearest_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, + THNNState* state, + THTensor* gradOutput, + THTensor* gradInput, int nbatch, int channels, int inputDepth, @@ -115,59 +23,8 @@ void THNN_(VolumetricUpSamplingNearest_updateGradInput)( int inputWidth, int outputDepth, int outputHeight, - int outputWidth) -{ - THNN_(VolumetricUpSamplingNearest_shapeCheck)(NULL, gradOutput, nbatch, channels, inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth); - THTensor_(resize5d)(gradInput, nbatch, channels, inputDepth, inputHeight, inputWidth); - THTensor_(zero)(gradInput); - gradOutput = THTensor_(newContiguous)(gradOutput); - scalar_t *idata = gradInput->data(); - scalar_t *odata = gradOutput->data(); - channels = nbatch * channels; - const float depth_scale = (float) inputDepth / (float) outputDepth; - const float height_scale = (float) inputHeight / (float)outputHeight; - const float width_scale = (float) inputWidth / (float)outputWidth; - - // special case: just copy - if (inputDepth == outputDepth && inputHeight == outputHeight && inputWidth == outputWidth) { - for (int d2 = 0; d2 < outputDepth; ++d2) { - const int d1 = d2; - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = h2; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - scalar_t* pos1 = &idata[d1 * inputHeight * inputWidth + h1 * inputWidth + w1]; - const scalar_t* pos2 = &odata[d2 * outputHeight * outputWidth + h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += pos2[0]; - pos1 += inputDepth * inputHeight * inputWidth; - pos2 += outputDepth * outputHeight * outputWidth; - } - } - } - } - c10::raw::intrusive_ptr::decref(gradOutput); - return; - } - - for (int d2 = 0; d2 < outputDepth; ++d2) { - const int d1 = nearest_neighbor_compute_source_index(depth_scale, d2, inputDepth); - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = nearest_neighbor_compute_source_index(height_scale, h2, inputHeight); - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = nearest_neighbor_compute_source_index(width_scale, w2, inputWidth); - scalar_t* pos1 = &idata[d1 * inputHeight * inputWidth + h1 * inputWidth + w1]; - const scalar_t* pos2 = &odata[d2 * outputHeight * outputWidth + h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += pos2[0]; - pos1 += inputDepth * inputHeight * inputWidth; - pos2 += outputDepth * outputHeight * outputWidth; - } - } - } - } - - c10::raw::intrusive_ptr::decref(gradOutput); + int outputWidth) { + AT_ERROR("This function is deprecated, please use it from ATen."); } #endif diff --git a/aten/src/THNN/generic/VolumetricUpSamplingTrilinear.c b/aten/src/THNN/generic/VolumetricUpSamplingTrilinear.c index 8885cc8..4eaa507 100644 --- a/aten/src/THNN/generic/VolumetricUpSamplingTrilinear.c +++ b/aten/src/THNN/generic/VolumetricUpSamplingTrilinear.c @@ -5,132 +5,21 @@ #define TH_GENERIC_FILE "THNN/generic/VolumetricUpSamplingTrilinear.c" #else -#include - -static inline void THNN_(VolumetricUpSamplingTrilinear_shapeCheck) - (THTensor *input, THTensor *gradOutput, - int nBatch, int nChannels, - int inputDepth, int inputHeight, int inputWidth, - int outputDepth, int outputHeight, int outputWidth) { - THArgCheck(inputDepth > 0 && inputHeight > 0 && inputWidth > 0 - && outputDepth > 0 && outputHeight > 0 && outputWidth > 0, 2, - "input and output sizes should be greater than 0," - " but got input (D: %d, H: %d, W: %d) output (D: %d, H: %d, W: %d)", - inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth); - if (input != NULL) { - THNN_ARGCHECK(!input->is_empty() && input->dim() == 5, 2, input, - "non-empty 5D input tensor expected but got: %s"); - } - - if (gradOutput != NULL) { - THNN_CHECK_DIM_SIZE(gradOutput, 5, 0, nBatch); - THNN_CHECK_DIM_SIZE(gradOutput, 5, 1, nChannels); - THNN_CHECK_DIM_SIZE(gradOutput, 5, 2, outputDepth); - THNN_CHECK_DIM_SIZE(gradOutput, 5, 3, outputHeight); - THNN_CHECK_DIM_SIZE(gradOutput, 5, 4, outputWidth); - } -} - void THNN_(VolumetricUpSamplingTrilinear_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, + THNNState* state, + THTensor* input, + THTensor* output, int outputDepth, int outputHeight, int outputWidth, - bool align_corners){ - - int nbatch = THTensor_(size)(input, 0); - int channels = THTensor_(size)(input, 1); - int inputDepth = THTensor_(size)(input, 2); - int inputHeight = THTensor_(size)(input, 3); - int inputWidth = THTensor_(size)(input, 4); - - THNN_(VolumetricUpSamplingTrilinear_shapeCheck) - (input, NULL, - nbatch, channels, - inputDepth, inputHeight, inputWidth, - outputDepth, outputHeight, outputWidth); - - input = THTensor_(newContiguous)(input); - THTensor_(resize5d)(output, - THTensor_(size)(input, 0), - THTensor_(size)(input, 1), - outputDepth, outputHeight, outputWidth); - THTensor_(zero)(output); - scalar_t *idata = input->data(); - scalar_t *odata = output->data(); - channels = nbatch * channels; - THAssert(inputDepth > 0 && inputHeight > 0 && inputWidth > 0 && - outputDepth > 0 && outputHeight > 0 && outputWidth > 0); - // special case: just copy - if (inputDepth == outputDepth && inputHeight == outputHeight && inputWidth == outputWidth) { - for (int t2 = 0; t2 < outputDepth; ++t2) { - const int t1 = t2; - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = h2; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - const scalar_t* pos1 = &idata[t1 * inputHeight * inputWidth + h1 * inputWidth + w1]; - scalar_t* pos2 = &odata[t2 * outputHeight * outputWidth + h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = pos1[0]; - pos1 += inputWidth * inputHeight * inputDepth; - pos2 += outputWidth * outputHeight * outputDepth; - } - } - } - } - c10::raw::intrusive_ptr::decref(input); - return; - } - const accreal rdepth = linear_upsampling_compute_scale(inputDepth, outputDepth, align_corners); - const accreal rheight = linear_upsampling_compute_scale(inputHeight, outputHeight, align_corners); - const accreal rwidth = linear_upsampling_compute_scale(inputWidth, outputWidth, align_corners); - for (int t2 = 0; t2 < outputDepth; ++t2) { - const accreal t1r = linear_upsampling_compute_source_index(rdepth, t2, align_corners); - const int t1 = t1r; - const int t1p = (t1 < inputDepth - 1) ? 1 : 0; - const scalar_t t1lambda = t1r - t1; - const scalar_t t0lambda = (scalar_t)1. - t1lambda; - for (int h2 = 0; h2 < outputHeight; ++h2) { - const accreal h1r = linear_upsampling_compute_source_index(rheight, h2, align_corners); - const int h1 = h1r; - const int h1p = (h1 < inputHeight - 1) ? 1 : 0; - const scalar_t h1lambda = h1r - h1; - const scalar_t h0lambda = (scalar_t)1. - h1lambda; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const accreal w1r = linear_upsampling_compute_source_index(rwidth, w2, align_corners); - const int w1 = w1r; - const int w1p = (w1 < inputWidth - 1) ? 1 : 0; - const scalar_t w1lambda = w1r - w1; - const scalar_t w0lambda = (scalar_t)1. - w1lambda; - const scalar_t* pos1 = &idata[t1 * inputHeight * inputWidth + h1 * inputWidth + w1]; - scalar_t* pos2 = &odata[t2 * outputHeight * outputWidth + h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos2[0] = t0lambda * (h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) - + h1lambda * (w0lambda * pos1[h1p * inputWidth] - + w1lambda * pos1[h1p * inputWidth + w1p])) - + t1lambda * (h0lambda * (w0lambda * pos1[t1p * inputHeight * inputWidth] - + w1lambda * pos1[t1p * inputHeight * inputWidth - + w1p]) - + h1lambda * (w0lambda * pos1[t1p * inputHeight * inputWidth - + h1p * inputWidth] - + w1lambda * pos1[t1p * inputHeight * inputWidth - + h1p * inputWidth + w1p])); - pos1 += inputWidth * inputHeight * inputDepth; - pos2 += outputWidth * outputHeight * outputDepth; - } - } - } - } - c10::raw::intrusive_ptr::decref(input); + bool align_corners) { + AT_ERROR("This function is deprecated, please use it from ATen."); } void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, + THNNState* state, + THTensor* gradOutput, + THTensor* gradInput, int nbatch, int channels, int inputDepth, @@ -139,81 +28,8 @@ void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)( int outputDepth, int outputHeight, int outputWidth, - bool align_corners){ - - THNN_(VolumetricUpSamplingTrilinear_shapeCheck) - (NULL, gradOutput, - nbatch, channels, - inputDepth, inputHeight, inputWidth, - outputDepth, outputHeight, outputWidth); - - THTensor_(resize5d)(gradInput, nbatch, channels, inputDepth, inputHeight, inputWidth); - THTensor_(zero)(gradInput); - gradOutput = THTensor_(newContiguous)(gradOutput); - scalar_t *data1 = gradInput->data(); - scalar_t *data2 = gradOutput->data(); - channels = nbatch * channels; - - // special case: same-size matching grids - if (inputDepth == outputDepth && inputHeight == outputHeight && inputWidth == outputWidth) { - for (int t2 = 0; t2 < outputDepth; ++t2) { - const int t1 = t2; - for (int h2 = 0; h2 < outputHeight; ++h2) { - const int h1 = h2; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const int w1 = w2; - scalar_t* pos1 = &data1[t1 * inputHeight * inputWidth + h1 * inputWidth + w1]; - const scalar_t* pos2 = &data2[t2 * outputHeight * outputWidth + h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += pos2[0]; - pos1 += inputWidth * inputHeight * inputDepth; - pos2 += outputWidth * outputHeight * outputDepth; - } - } - } - } - c10::raw::intrusive_ptr::decref(gradOutput); - return; - } - const accreal rdepth = linear_upsampling_compute_scale(inputDepth, outputDepth, align_corners); - const accreal rheight = linear_upsampling_compute_scale(inputHeight, outputHeight, align_corners); - const accreal rwidth = linear_upsampling_compute_scale(inputWidth, outputWidth, align_corners); - for (int t2 = 0; t2 < outputDepth; ++t2) { - const accreal t1r = linear_upsampling_compute_source_index(rdepth, t2, align_corners); - const int t1 = t1r; - const int t1p = (t1 < inputDepth - 1) ? 1 : 0; - const scalar_t t1lambda = t1r - t1; - const scalar_t t0lambda = (scalar_t)1. - t1lambda; - for (int h2 = 0; h2 < outputHeight; ++h2) { - const accreal h1r = linear_upsampling_compute_source_index(rheight, h2, align_corners); - const int h1 = h1r; - const int h1p = (h1 < inputHeight - 1) ? 1 : 0; - const scalar_t h1lambda = h1r - h1; - const scalar_t h0lambda = (scalar_t)1. - h1lambda; - for (int w2 = 0; w2 < outputWidth; ++w2) { - const accreal w1r = linear_upsampling_compute_source_index(rwidth, w2, align_corners); - const int w1 = w1r; - const int w1p = (w1 < inputWidth - 1) ? 1 : 0; - const scalar_t w1lambda = w1r - w1; - const scalar_t w0lambda = (scalar_t)1. - w1lambda; - scalar_t* pos1 = &data1[t1 * inputHeight * inputWidth + h1 * inputWidth + w1]; - const scalar_t* pos2 = &data2[t2 * outputHeight * outputWidth + h2 * outputWidth + w2]; - for (int c = 0; c < channels; ++c) { - pos1[0] += t0lambda * h0lambda * w0lambda * pos2[0]; - pos1[w1p] += t0lambda * h0lambda * w1lambda * pos2[0]; - pos1[h1p * inputWidth] += t0lambda * h1lambda * w0lambda * pos2[0]; - pos1[h1p * inputWidth + w1p] += t0lambda * h1lambda * w1lambda * pos2[0]; - pos1[t1p * inputHeight * inputWidth] += t1lambda * h0lambda * w0lambda * pos2[0]; - pos1[t1p * inputHeight * inputWidth + w1p] += t1lambda * h0lambda * w1lambda * pos2[0]; - pos1[t1p * inputHeight * inputWidth + h1p * inputWidth] += t1lambda * h1lambda * w0lambda * pos2[0]; - pos1[t1p * inputHeight * inputWidth + h1p * inputWidth + w1p] += t1lambda * h1lambda * w1lambda * pos2[0]; - pos1 += inputWidth * inputHeight * inputDepth; - pos2 += outputWidth * outputHeight * outputDepth; - } - } - } - } - c10::raw::intrusive_ptr::decref(gradOutput); + bool align_corners) { + AT_ERROR("This function is deprecated, please use it from ATen."); } #endif diff --git a/aten/src/THNN/generic/upsampling.h b/aten/src/THNN/generic/upsampling.h deleted file mode 100644 index 22898c0..0000000 --- a/aten/src/THNN/generic/upsampling.h +++ /dev/null @@ -1,111 +0,0 @@ -#ifndef THNN_UPSAMPLING_H -#define THNN_UPSAMPLING_H - -#undef MIN -#define MIN(a,b) ( ((a)<(b)) ? (a) : (b) ) -#undef MAX -#define MAX(a,b) ( ((a)>(b)) ? (a) : (b) ) - -template -static inline T linear_upsampling_compute_scale( - int inputSize, int outputSize, bool align_corners) { - /* We view each pixel as an area, idx + 0.5 as its center index. - * Here is an example formula in 1D case. - * if align_corners: center of two corner pixel areas are preserved, - * (0.5, 0.5) -> (0.5, 0.5), - * (inputSize - 0.5, 0.5) -> (outputSize - 0.5) - * scale = (inputSize - 0.5 - 0.5) / (outputSize - 0.5 - 0.5) - * src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5) - * if not align_corners: the whole range is scaled accordingly - * scale = inputSize / outputSize - * src_idx + 0.5 = scale * (dst_index + 0.5) - */ - if (outputSize > 1) { - return align_corners ? (T) (inputSize - 1) / (outputSize - 1) - : (T) inputSize / outputSize; - } else { - return T(0); - } -} - -template -static inline T linear_upsampling_compute_source_index( - T scale, int dst_index, bool align_corners) { - if (align_corners) { - return scale * dst_index; - } else { - T src_idx = scale * (dst_index + 0.5) - 0.5; - return src_idx < 0 ? T(0) : src_idx; - } -} - -static inline int nearest_neighbor_compute_source_index( - const float scale, int dst_index, int inputSize) { - const int src_index = MIN(floorf(dst_index * scale), inputSize - 1); - return src_index; -} - -template -static T upsampling_get_value_bounded(T* data, int width, int height, int x, int y) { - int access_x = std::max(std::min(x, width - 1), 0); - int access_y = std::max(std::min(y, height - 1), 0); - return data[access_y * width + access_x]; -} - -template -static void upsampling_increment_value_bounded( - T* data, - int width, - int height, - int x, - int y, - T value -) { - int access_x = std::max(std::min(x, width - 1), 0); - int access_y = std::max(std::min(y, height - 1), 0); - data[access_y * width + access_x] += value; -} - -// Based on https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm -template -static inline T cubic_convolution1(T x, T A) { - return ((A + 2) * x - (A + 3)) * x * x + 1; -} - -template -static inline T cubic_convolution2(T x, T A) { - return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; -} - -template -static inline void get_cubic_upsampling_coefficients(T coeffs[4], T t) { - T A = -0.75; - - T x1 = t; - coeffs[0] = cubic_convolution2(x1 + 1.0, A); - coeffs[1] = cubic_convolution1(x1, A); - - // opposite coefficients - T x2 = 1.0 - t; - coeffs[2] = cubic_convolution1(x2, A); - coeffs[3] = cubic_convolution2(x2 + 1.0, A); -} - -template -static inline T cubic_interp1d( - T x0, - T x1, - T x2, - T x3, - T t -) { - T coeffs[4]; - get_cubic_upsampling_coefficients(coeffs, t); - - return x0 * coeffs[0] - + x1 * coeffs[1] - + x2 * coeffs[2] - + x3 * coeffs[3]; -} - -#endif -- 2.7.4