Bicubic interpolation for nn.functional.interpolate (#9849)
authorDavid Riazati <davidriazati@fb.com>
Mon, 17 Dec 2018 23:22:07 +0000 (15:22 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 17 Dec 2018 23:31:48 +0000 (15:31 -0800)
Summary:
Addresses #918, interpolation results should be similar to tf

* Adds bicubic interpolation operator to `nn.functional.interpolate`
* Corresponding test in `test_nn.py`

The operator is added in legacy `TH` to be aligned with the other upsampling operators; they can be refactored/moved to ATen all at once when #10482 is resolved
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9849

Differential Revision: D9007525

Pulled By: driazati

fbshipit-source-id: 93ef49a34ce4e5ffd4bda94cd9a6ddc939f0a4cc

37 files changed:
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/LegacyNNDefinitions.cpp
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/nn.yaml
aten/src/THCUNN/CMakeLists.txt
aten/src/THCUNN/SpatialUpSamplingBicubic.cu [new file with mode: 0644]
aten/src/THCUNN/SpatialUpSamplingBilinear.cu
aten/src/THCUNN/SpatialUpSamplingNearest.cu
aten/src/THCUNN/TemporalUpSamplingLinear.cu
aten/src/THCUNN/TemporalUpSamplingNearest.cu
aten/src/THCUNN/VolumetricUpSamplingNearest.cu
aten/src/THCUNN/VolumetricUpSamplingTrilinear.cu
aten/src/THCUNN/generic/SpatialUpSamplingBicubic.cu [new file with mode: 0644]
aten/src/THCUNN/generic/SpatialUpSamplingBilinear.cu
aten/src/THCUNN/generic/THCUNN.h
aten/src/THCUNN/generic/TemporalUpSamplingLinear.cu
aten/src/THCUNN/generic/VolumetricUpSamplingTrilinear.cu
aten/src/THCUNN/linear_upsampling.h [deleted file]
aten/src/THCUNN/upsampling.h [new file with mode: 0644]
aten/src/THNN/generic/SpatialUpSamplingBicubic.c [new file with mode: 0644]
aten/src/THNN/generic/SpatialUpSamplingBilinear.c
aten/src/THNN/generic/SpatialUpSamplingNearest.c
aten/src/THNN/generic/THNN.h
aten/src/THNN/generic/TemporalUpSamplingLinear.c
aten/src/THNN/generic/TemporalUpSamplingNearest.c
aten/src/THNN/generic/VolumetricUpSamplingNearest.c
aten/src/THNN/generic/VolumetricUpSamplingTrilinear.c
aten/src/THNN/generic/linear_upsampling.h [deleted file]
aten/src/THNN/generic/upsampling.h [new file with mode: 0644]
aten/src/THNN/init.cpp
test/common_nn.py
test/test_nn.py
tools/autograd/derivatives.yaml
torch/csrc/jit/register_prim_ops.cpp
torch/cuda/__init__.py
torch/nn/functional.py
torch/nn/modules/upsampling.py

index 486d5cf..951867d 100644 (file)
@@ -683,6 +683,9 @@ _(aten, unsqueeze) \
 _(aten, upsample_bilinear2d) \
 _(aten, upsample_bilinear2d_backward) \
 _(aten, upsample_bilinear2d_forward) \
+_(aten, upsample_bicubic2d) \
+_(aten, upsample_bicubic2d_backward) \
+_(aten, upsample_bicubic2d_forward) \
 _(aten, upsample_linear1d) \
 _(aten, upsample_linear1d_backward) \
 _(aten, upsample_linear1d_forward) \
index 2c02746..7637c99 100644 (file)
@@ -604,6 +604,22 @@ Tensor upsample_bilinear2d_backward(const Tensor & grad_output, IntList output_s
   return at::legacy::th::_thnn_upsample_bilinear2d_backward(grad_output, output_size, input_size, align_corners);
 }
 
+Tensor & upsample_bicubic2d_out(Tensor & output, const Tensor & self, IntList output_size, bool align_corners) {
+  return at::legacy::th::_thnn_upsample_bicubic2d_forward_out(output, self, output_size, align_corners);
+}
+
+Tensor upsample_bicubic2d(const Tensor & self, IntList output_size, bool align_corners) {
+  return at::legacy::th::_thnn_upsample_bicubic2d_forward(self, output_size, align_corners);
+}
+
+Tensor & upsample_bicubic2d_backward_out(Tensor & grad_input, const Tensor & grad_output, IntList output_size, IntList input_size, bool align_corners) {
+  return at::legacy::th::_thnn_upsample_bicubic2d_backward_out(grad_input, grad_output, output_size, input_size, align_corners);
+}
+
+Tensor upsample_bicubic2d_backward(const Tensor & grad_output, IntList output_size, IntList input_size, bool align_corners) {
+  return at::legacy::th::_thnn_upsample_bicubic2d_backward(grad_output, output_size, input_size, align_corners);
+}
+
 Tensor & upsample_trilinear3d_out(Tensor & output, const Tensor & self, IntList output_size, bool align_corners) {
   return at::legacy::th::_thnn_upsample_trilinear3d_forward_out(output, self, output_size, align_corners);
 }
index 5729e0d..5979229 100644 (file)
 - func: upsample_bilinear2d_backward(Tensor grad_output, IntList[2] output_size, IntList[4] input_size, bool align_corners) -> Tensor
   python_module: nn
 
+- func: upsample_bicubic2d_out(Tensor output, Tensor self, IntList[2] output_size, bool align_corners) -> Tensor
+  python_module: nn
+
+- func: upsample_bicubic2d(Tensor self, IntList[2] output_size, bool align_corners) -> Tensor
+  python_module: nn
+
+- func: upsample_bicubic2d_backward_out(Tensor grad_input, Tensor grad_output, IntList[2] output_size, IntList[4] input_size, bool align_corners) -> Tensor
+  python_module: nn
+
+- func: upsample_bicubic2d_backward(Tensor grad_output, IntList[2] output_size, IntList[4] input_size, bool align_corners) -> Tensor
+  python_module: nn
+
 - func: upsample_trilinear3d_out(Tensor output, Tensor self, IntList[3] output_size, bool align_corners) -> Tensor
   python_module: nn
 
index 04b72b8..b99d616 100644 (file)
     self: 'false'
     grad_input: 'false'
 
+- name: _thnn_upsample_bicubic2d(Tensor self, IntList[2] output_size, bool align_corners)
+  cname: SpatialUpSamplingBicubic
+  scalar_check:
+    grad_input: 'false'
+
 - name: _thnn_upsample_trilinear3d(Tensor self, IntList[3] output_size, bool align_corners)
   cname: VolumetricUpSamplingTrilinear
   scalar_check:
index 3c09b77..237a3ff 100644 (file)
@@ -44,6 +44,7 @@ ${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxUnpooling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialReflectionPadding.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialReplicationPadding.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialSubSampling.cu
+${CMAKE_CURRENT_SOURCE_DIR}/SpatialUpSamplingBicubic.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialUpSamplingBilinear.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialUpSamplingNearest.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/Sqrt.cu
diff --git a/aten/src/THCUNN/SpatialUpSamplingBicubic.cu b/aten/src/THCUNN/SpatialUpSamplingBicubic.cu
new file mode 100644 (file)
index 0000000..cbd8cbb
--- /dev/null
@@ -0,0 +1,157 @@
+#include <THCUNN/THCUNN.h>
+#include <THC/THCTensor.hpp>
+#include <THCUNN/common.h>
+#include <THCUNN/upsampling.h>
+#include <THC/THCDeviceTensor.cuh>
+#include <THC/THCDeviceTensorUtils.cuh>
+#include <THC/THCDeviceUtils.cuh>
+#include <TH/THHalf.h>
+#include <THCUNN/THCHalfAutoNumerics.cuh>
+#include <THC/THCAtomics.cuh>
+
+template<typename Dtype, typename Acctype>
+__global__ void bicubic_interp2d_kernel(
+  const int num_elements,
+  const Acctype height_scale,
+  const Acctype width_scale,
+  const THCDeviceTensor<Dtype, 4> in_data,
+  THCDeviceTensor<Dtype, 4> out_data
+) {
+
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  const int batchsize = in_data.getSize(0);
+  const int channels = in_data.getSize(1);
+  const int input_height = in_data.getSize(2);
+  const int input_width = in_data.getSize(3);
+  const int output_height = out_data.getSize(2);
+  const int output_width = out_data.getSize(3);
+
+  if (index >= num_elements) {
+    return;
+  }
+
+  // Special case: input and output are the same size, just copy
+  const int output_x = index % output_width;
+  const int output_y = index / output_width;
+  if (input_height == output_height && input_width == output_width) {
+    for (int n = 0; n < batchsize; n++){
+      for (int c = 0; c < channels; c++) {
+        const Dtype val = in_data[n][c][output_y][output_x];
+        out_data[n][c][output_x][output_y] = val;
+      }
+    }
+    return;
+  }
+
+  // Interpolation kernel
+  Acctype real_x = width_scale * output_x;
+  int in_x = real_x;
+  Acctype t_x = real_x - in_x;
+
+  Acctype real_y = height_scale * output_y;
+  int in_y = real_y;
+  Acctype t_y = real_y - in_y;
+
+  for (int n = 0; n < batchsize ; n++) {
+    for (int c = 0; c < channels; c++) {
+      Acctype coefficients[4];
+
+      for (int k = 0; k < 4; k++) {
+        coefficients[k] = cubic_interp1d(
+          upsampling_get_value_bounded<Dtype>(
+            in_data, c, n, input_width, input_height, in_x - 1, in_y - 1 + k),
+          upsampling_get_value_bounded<Dtype>(
+            in_data, c, n, input_width, input_height, in_x + 0, in_y - 1 + k),
+          upsampling_get_value_bounded<Dtype>(
+            in_data, c, n, input_width, input_height, in_x + 1, in_y - 1 + k),
+          upsampling_get_value_bounded<Dtype>(
+            in_data, c, n, input_width, input_height, in_x + 2, in_y - 1 + k),
+          t_x
+        );
+      }
+
+      out_data[n][c][output_y][output_x] = ScalarConvert<Acctype, Dtype>::to(cubic_interp1d(
+        coefficients[0],
+        coefficients[1],
+        coefficients[2],
+        coefficients[3],
+        t_y
+      ));
+    }
+  }
+}
+
+// Backward (adjoint) operation 1 <- 2 (accumulates)
+template <typename Dtype, typename Acctype>
+__global__ void bicubic_interp2d_backward_kernel(
+  const int num_elements,
+  const Acctype height_scale,
+  const Acctype width_scale,
+  const bool align_corners,
+  THCDeviceTensor<Dtype, 4> in_data,
+  const THCDeviceTensor<Dtype, 4> out_data
+){
+
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  const int batchsize = in_data.getSize(0);
+  const int channels = in_data.getSize(1);
+  const int input_height = in_data.getSize(2);
+  const int input_width = in_data.getSize(3);
+  const int output_height = out_data.getSize(2);
+  const int output_width = out_data.getSize(3);
+
+  if (index >= num_elements) {
+    return;
+  }
+
+  const int output_x = index % output_width;
+  const int output_y = index / output_width;
+  // special case: output_xust copy
+  if (input_height == output_height && input_width == output_width) {
+    for (int n = 0; n < batchsize ; n++){
+      for (int c = 0; c < channels; ++c) {
+        const Dtype val = out_data[n][c][output_y][output_x];
+        in_data[n][c][output_y][output_x] += val;
+      }
+    }
+    return;
+  }
+
+  Acctype real_x = width_scale * output_x;
+  int input_x = real_x;
+  Acctype t_x = real_x - input_x;
+
+  Acctype real_y = height_scale * output_y;
+  int input_y = real_y;
+  Acctype t_y = real_y - input_y;
+
+  Acctype x_coeffs[4];
+  Acctype y_coeffs[4];
+
+  get_cubic_upsampling_coefficients(x_coeffs, t_x);
+  get_cubic_upsampling_coefficients(y_coeffs, t_y);
+
+  for (int n = 0; n < batchsize ; n++){
+    for (int c = 0; c < channels; ++c) {
+      Dtype out_value = out_data[n][c][output_y][output_x];
+      for (int i = 0; i < 4; i++) {
+        for (int j = 0; j < 4; j++) {
+          upsampling_increment_value_bounded<Dtype, Acctype>(
+            in_data,
+            c,
+            n,
+            input_width,
+            input_height,
+            input_x - 1 + j,
+            input_y - 1 + i,
+            out_value * y_coeffs[i] * x_coeffs[j]
+          );
+        }
+      }
+    }
+  }
+}
+
+
+#include <THCUNN/generic/SpatialUpSamplingBicubic.cu>
+#include <THC/THCGenerateFloatTypes.h>
index 240b0c5..b88f03c 100644 (file)
@@ -3,7 +3,7 @@
 #include <THCUNN/THCUNN.h>
 #include <THC/THCTensor.hpp>
 #include <THCUNN/common.h>
-#include <THCUNN/linear_upsampling.h>
+#include <THCUNN/upsampling.h>
 #include <THC/THCDeviceTensor.cuh>
 #include <THC/THCDeviceTensorUtils.cuh>
 #include <THC/THCDeviceUtils.cuh>
index a24de9e..e4c4dc2 100644 (file)
@@ -2,7 +2,7 @@
 #include <THCUNN/common.h>
 #include <THC/THCTensor.hpp>
 
-#include <THCUNN/linear_upsampling.h>
+#include <THCUNN/upsampling.h>
 #include <THC/THCDeviceTensor.cuh>
 #include <THC/THCDeviceTensorUtils.cuh>
 #include <THC/THCDeviceUtils.cuh>
index 6c936dc..18a65b1 100644 (file)
@@ -3,7 +3,7 @@
 #include <THCUNN/THCUNN.h>
 #include <THC/THCTensor.hpp>
 #include <THCUNN/common.h>
-#include <THCUNN/linear_upsampling.h>
+#include <THCUNN/upsampling.h>
 #include <THC/THCDeviceTensor.cuh>
 #include <THC/THCDeviceTensorUtils.cuh>
 #include <THC/THCDeviceUtils.cuh>
index 39280b4..e1f8485 100644 (file)
@@ -2,7 +2,7 @@
 #include <THCUNN/common.h>
 #include <THC/THCTensor.hpp>
 
-#include <THCUNN/linear_upsampling.h>
+#include <THCUNN/upsampling.h>
 #include <THC/THCDeviceTensor.cuh>
 #include <THC/THCDeviceTensorUtils.cuh>
 #include <THC/THCDeviceUtils.cuh>
index e82b2aa..f764386 100644 (file)
@@ -2,7 +2,7 @@
 #include <THCUNN/common.h>
 #include <THC/THCTensor.hpp>
 
-#include <THCUNN/linear_upsampling.h>
+#include <THCUNN/upsampling.h>
 #include <THC/THCDeviceTensor.cuh>
 #include <THC/THCDeviceTensorUtils.cuh>
 #include <THC/THCDeviceUtils.cuh>
index 1005e92..da159bd 100644 (file)
@@ -3,7 +3,7 @@
 #include <THCUNN/THCUNN.h>
 #include <THC/THCTensor.hpp>
 #include <THCUNN/common.h>
-#include <THCUNN/linear_upsampling.h>
+#include <THCUNN/upsampling.h>
 #include <THC/THCDeviceTensor.cuh>
 #include <THC/THCDeviceTensorUtils.cuh>
 #include <THC/THCDeviceUtils.cuh>
diff --git a/aten/src/THCUNN/generic/SpatialUpSamplingBicubic.cu b/aten/src/THCUNN/generic/SpatialUpSamplingBicubic.cu
new file mode 100644 (file)
index 0000000..d74e85e
--- /dev/null
@@ -0,0 +1,114 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "THCUNN/generic/SpatialUpSamplingBicubic.cu"
+#else
+
+#include <THCUNN/upsampling.h>
+
+static inline void THNN_(SpatialUpSamplingBicubic_shapeCheck)
+                        (THCState *state,
+                         THCTensor *input, THCTensor *gradOutput,
+                         int nBatch, int nChannels,
+                         int inputHeight, int inputWidth,
+                         int outputHeight, int outputWidth) {
+  THArgCheck(inputHeight > 0 && inputWidth > 0
+             && outputHeight > 0 && outputWidth > 0, 2,
+             "input and output sizes should be greater than 0,"
+             " but got input (H: %d, W: %d) output (H: %d, W: %d)",
+             inputHeight, inputWidth, outputHeight, outputWidth);
+  if (input != NULL) {
+     THCUNN_argCheck(state, !input->is_empty() && input->dim() == 4, 2, input,
+                     "non-empty 4D input tensor expected but got: %s");
+  }
+
+  if (gradOutput != NULL) {
+    THCUNN_check_dim_size(state, gradOutput, 4, 0, nBatch);
+    THCUNN_check_dim_size(state, gradOutput, 4, 1, nChannels);
+    THCUNN_check_dim_size(state, gradOutput, 4, 2, outputHeight);
+    THCUNN_check_dim_size(state, gradOutput, 4, 3, outputWidth);
+  }
+}
+
+void THNN_(SpatialUpSamplingBicubic_updateOutput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *output,
+           int outputHeight,
+           int outputWidth,
+           bool align_corners)
+{
+  int nbatch = THCTensor_(size)(state, input, 0);
+  int channels = THCTensor_(size)(state, input, 1);
+  int inputHeight = THCTensor_(size)(state, input, 2);
+  int inputWidth = THCTensor_(size)(state, input, 3);
+  THNN_(SpatialUpSamplingBicubic_shapeCheck)
+       (state, input, NULL,
+        nbatch, channels,
+        inputHeight, inputWidth,
+        outputHeight, outputWidth);
+
+  THCUNN_assertSameGPU(state, 2, input, output);
+  THCTensor_(resize4d)(state, output,
+                       THCTensor_(size)(state, input, 0),
+                       THCTensor_(size)(state, input, 1),
+                       outputHeight, outputWidth);
+  THCTensor_(zero)(state, output);
+  THCDeviceTensor<scalar_t, 4> idata = toDeviceTensor<scalar_t, 4>(state, input);
+  THCDeviceTensor<scalar_t, 4> odata = toDeviceTensor<scalar_t, 4>(state, output);
+  THAssert(inputHeight > 0 && inputWidth > 0 && outputHeight > 0 && outputWidth > 0);
+
+  // Get scaling factors
+  const accreal rheight = linear_upsampling_compute_scale<accreal>(inputHeight, outputHeight, align_corners);
+  const accreal rwidth = linear_upsampling_compute_scale<accreal>(inputWidth, outputWidth, align_corners);
+
+  const int num_output_elements = outputHeight * outputWidth;
+  const int max_threads =
+    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
+
+  // Launch kernel
+  cudaStream_t stream = THCState_getCurrentStream(state);
+  bicubic_interp2d_kernel<scalar_t, accreal> <<<
+    THCCeilDiv(num_output_elements, max_threads),
+    max_threads,
+    0,
+    stream
+  >>>(num_output_elements, rheight, rwidth, idata, odata);
+  THCudaCheck(cudaGetLastError());
+}
+
+
+void THNN_(SpatialUpSamplingBicubic_updateGradInput)(
+           THCState *state,
+           THCTensor *gradOutput,
+           THCTensor *gradInput,
+           int nbatch,
+           int nchannels,
+           int inputHeight,
+           int inputWidth,
+           int outputHeight,
+           int outputWidth,
+           bool align_corners)
+{
+  THNN_(SpatialUpSamplingBicubic_shapeCheck)
+       (state, NULL, gradOutput,
+        nbatch, nchannels,
+        inputHeight, inputWidth,
+        outputHeight, outputWidth);
+  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+  THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
+  THCTensor_(resize4d)(state, gradInput, nbatch, nchannels, inputHeight, inputWidth);
+  THCTensor_(zero)(state, gradInput);
+  THCDeviceTensor<scalar_t, 4> in_data = toDeviceTensor<scalar_t, 4>(state, gradInput);
+  THCDeviceTensor<scalar_t, 4> out_data = toDeviceTensor<scalar_t, 4>(state, gradOutput);
+  const accreal rheight = linear_upsampling_compute_scale<accreal>(inputHeight, outputHeight, align_corners);
+  const accreal rwidth = linear_upsampling_compute_scale<accreal>(inputWidth, outputWidth, align_corners);
+  const int num_kernels = outputHeight * outputWidth;
+  const int num_threads =
+    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
+  cudaStream_t stream = THCState_getCurrentStream(state);
+  bicubic_interp2d_backward_kernel<scalar_t ,accreal> <<<THCCeilDiv(num_kernels, num_threads),
+  num_threads, 0, stream>>>(num_kernels, rheight, rwidth, align_corners, in_data, out_data);
+  THCudaCheck(cudaGetLastError());
+  THCTensor_(free)(state, gradOutput);
+}
+
+#endif
index 8b8e4e8..52be68f 100644 (file)
@@ -2,7 +2,7 @@
 #define THC_GENERIC_FILE "THCUNN/generic/SpatialUpSamplingBilinear.cu"
 #else
 
-#include <THCUNN/linear_upsampling.h>
+#include <THCUNN/upsampling.h>
 
 static inline void THNN_(SpatialUpSamplingBilinear_shapeCheck)
                         (THCState *state,
index b303a04..b91d941 100644 (file)
@@ -925,6 +925,26 @@ THC_API void THNN_(SpatialUpSamplingBilinear_updateGradInput)(
                   int outputWidth,
                   bool align_corners);
 
+THC_API void THNN_(SpatialUpSamplingBicubic_updateOutput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *output,
+                  int outputHeight,
+                  int outputWidth,
+                  bool align_corners);
+
+THC_API void THNN_(SpatialUpSamplingBicubic_updateGradInput)(
+                  THCState *state,
+                  THCTensor *gradOutput,
+                  THCTensor *gradInput,
+                  int nbatch,
+                  int nchannels,
+                  int inputHeight,
+                  int inputWidth,
+                  int outputHeight,
+                  int outputWidth,
+                  bool align_corners);
+
 THC_API void THNN_(SpatialUpSamplingNearest_updateGradInput)(
                   THCState *state,
                   THCTensor *gradOutput,
index a5c57d6..0365559 100644 (file)
@@ -2,7 +2,7 @@
 #define THC_GENERIC_FILE "THCUNN/generic/TemporalUpSamplingLinear.cu"
 #else
 
-#include <THCUNN/linear_upsampling.h>
+#include <THCUNN/upsampling.h>
 
 static inline void THNN_(TemporalUpSamplingLinear_shapeCheck)
                         (THCState *state,
index d936d74..ef950f6 100644 (file)
@@ -2,7 +2,7 @@
 #define THC_GENERIC_FILE "THCUNN/generic/VolumetricUpSamplingTrilinear.cu"
 #else
 
-#include <THCUNN/linear_upsampling.h>
+#include <THCUNN/upsampling.h>
 
 static inline void THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
                         (THCState *state,
diff --git a/aten/src/THCUNN/linear_upsampling.h b/aten/src/THCUNN/linear_upsampling.h
deleted file mode 100644 (file)
index bd8a601..0000000
+++ /dev/null
@@ -1,41 +0,0 @@
-#ifndef THCUNN_LINEAR_UPSAMPLING_H
-#define THCUNN_LINEAR_UPSAMPLING_H
-
-#undef MIN
-#define MIN(a,b) ( ((a)<(b)) ? (a) : (b) )
-#undef MAX
-#define MAX(a,b) ( ((a)>(b)) ? (a) : (b) )
-
-
-template<typename Acctype>
-__host__ __forceinline__
-static Acctype linear_upsampling_compute_scale(
-                          int inputSize, int outputSize, bool align_corners) {
-  if (outputSize > 1) {
-    return align_corners ? (Acctype) (inputSize - 1) / (outputSize - 1)
-                         : (Acctype) inputSize / outputSize;
-  } else {
-    return Acctype(0);
-  }
-}
-
-template<typename Acctype>
-__device__ __forceinline__
-static Acctype linear_upsampling_compute_source_index(
-                          Acctype scale, int dst_index, bool align_corners) {
-  if (align_corners) {
-    return scale * dst_index;
-  } else {
-    Acctype src_idx = scale * (dst_index + Acctype(0.5)) - Acctype(0.5);
-    return src_idx < Acctype(0) ? Acctype(0) : src_idx;
-  }
-}
-
-__device__ __forceinline__
-static int nearest_neighbor_compute_source_index(
-               const float scale, int dst_index, int inputSize) {
-  const int src_index = MIN(floor(dst_index * scale), inputSize - 1);
-  return src_index;
-}
-#endif
-
diff --git a/aten/src/THCUNN/upsampling.h b/aten/src/THCUNN/upsampling.h
new file mode 100644 (file)
index 0000000..7a7c45d
--- /dev/null
@@ -0,0 +1,129 @@
+#ifndef THCUNN_UPSAMPLING_H
+#define THCUNN_UPSAMPLING_H
+
+#include <THC/THCDeviceTensor.cuh>
+#include <THC/THCAtomics.cuh>
+
+#undef MIN
+#define MIN(a,b) ( ((a)<(b)) ? (a) : (b) )
+#undef MAX
+#define MAX(a,b) ( ((a)>(b)) ? (a) : (b) )
+
+
+template<typename Acctype>
+__host__ __forceinline__
+static Acctype linear_upsampling_compute_scale(
+                          int inputSize, int outputSize, bool align_corners) {
+  if (outputSize > 1) {
+    return align_corners ? (Acctype) (inputSize - 1) / (outputSize - 1)
+                         : (Acctype) inputSize / outputSize;
+  } else {
+    return Acctype(0);
+  }
+}
+
+template<typename Acctype>
+__device__ __forceinline__
+static Acctype linear_upsampling_compute_source_index(
+                          Acctype scale, int dst_index, bool align_corners) {
+  if (align_corners) {
+    return scale * dst_index;
+  } else {
+    Acctype src_idx = scale * (dst_index + Acctype(0.5)) - Acctype(0.5);
+    return src_idx < Acctype(0) ? Acctype(0) : src_idx;
+  }
+}
+
+__device__ __forceinline__
+static int nearest_neighbor_compute_source_index(
+               const float scale, int dst_index, int inputSize) {
+  const int src_index = MIN(floor(dst_index * scale), inputSize - 1);
+  return src_index;
+}
+
+template<typename Dtype>
+__device__ __forceinline__
+static Dtype upsampling_get_value_bounded(
+  const THCDeviceTensor<Dtype, 4> data,
+  int channel,
+  int batch,
+  int width,
+  int height,
+  int x,
+  int y
+) {
+  int access_x = max(min(x, width - 1), 0);
+  int access_y = max(min(y, height - 1), 0);
+  return data[batch][channel][access_y][access_x];
+}
+
+template<typename Dtype, typename Acctype>
+__device__ __forceinline__
+static void upsampling_increment_value_bounded(
+  const THCDeviceTensor<Dtype, 4> data,
+  int channel,
+  int batch,
+  int width,
+  int height,
+  int x,
+  int y,
+  Acctype value
+) {
+  int access_x = max(min(x, width - 1), 0);
+  int access_y = max(min(y, height - 1), 0);
+  atomicAdd(
+    data[batch][channel][access_y][access_x].data(),
+    ScalarConvert<Acctype, Dtype>::to(value)
+  );
+}
+
+// Based on https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+template<typename Acctype>
+__device__ __forceinline__
+static Acctype cubic_convolution1(Acctype x, Acctype A) {
+  return ((A + 2) * x - (A + 3)) * x * x + 1;
+}
+
+template<typename Acctype>
+__device__ __forceinline__
+static Acctype cubic_convolution2(Acctype x, Acctype A) {
+  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
+}
+
+template<typename Acctype>
+__device__ __forceinline__
+static void get_cubic_upsampling_coefficients(
+  Acctype coeffs[4],
+  Acctype t
+) {
+  Acctype A = -0.75;
+
+  Acctype x1 = t;
+  coeffs[0] = cubic_convolution2<Acctype>(x1 + 1.0, A);
+  coeffs[1] = cubic_convolution1<Acctype>(x1, A);
+
+  // opposite coefficients
+  Acctype x2 = 1.0 - t;
+  coeffs[2] = cubic_convolution1<Acctype>(x2, A);
+  coeffs[3] = cubic_convolution2<Acctype>(x2 + 1.0, A);
+}
+
+template<typename Dtype, typename Acctype>
+__device__ __forceinline__
+static Acctype cubic_interp1d(
+  Dtype x0,
+  Dtype x1,
+  Dtype x2,
+  Dtype x3,
+  Acctype t
+) {
+  Acctype coeffs[4];
+  get_cubic_upsampling_coefficients<Acctype>(coeffs, t);
+
+  return x0 * coeffs[0]
+    + x1 * coeffs[1]
+    + x2 * coeffs[2]
+    + x3 * coeffs[3];
+}
+
+#endif
diff --git a/aten/src/THNN/generic/SpatialUpSamplingBicubic.c b/aten/src/THNN/generic/SpatialUpSamplingBicubic.c
new file mode 100644 (file)
index 0000000..5faf304
--- /dev/null
@@ -0,0 +1,226 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "THNN/generic/SpatialUpSamplingBicubic.c"
+#else
+
+#include <THNN/generic/upsampling.h>
+
+static inline void THNN_(SpatialUpSamplingBicubic_shapeCheck)
+     (THTensor *input, THTensor *gradOutput,
+      int nBatch, int nChannels,
+      int input_height, int input_width,
+      int output_height, int output_width) {
+  THArgCheck(input_height > 0 && input_width > 0
+            && output_height > 0 && output_width > 0, 2,
+            "input and output sizes should be greater than 0,"
+            " but got input (H: %d, W: %d) output (H: %d, W: %d)",
+            input_height, input_width, output_height, output_width);
+  if (input != NULL) {
+    THNN_ARGCHECK(!input->is_empty() && input->dim() == 4, 2, input,
+                 "non-empty 4D input tensor expected but got: %s");
+  }
+
+  if (gradOutput != NULL) {
+    THNN_CHECK_DIM_SIZE(gradOutput, 4, 0, nBatch);
+    THNN_CHECK_DIM_SIZE(gradOutput, 4, 1, nChannels);
+    THNN_CHECK_DIM_SIZE(gradOutput, 4, 2, output_height);
+    THNN_CHECK_DIM_SIZE(gradOutput, 4, 3, output_width);
+  }
+}
+
+void THNN_(SpatialUpSamplingBicubic_updateOutput)(
+    THNNState *state,
+    THTensor *input,
+    THTensor *output,
+    int output_height,
+    int output_width,
+    bool align_corners) {
+
+  const int nbatch = THTensor_(size)(input, 0);
+  const int channels = THTensor_(size)(input, 1);
+  const int input_height = THTensor_(size)(input, 2);
+  const int input_width = THTensor_(size)(input, 3);
+
+  THNN_(SpatialUpSamplingBicubic_shapeCheck)
+    (input, NULL,
+     nbatch, channels,
+     input_height, input_width,
+     output_height, output_width);
+
+  input = THTensor_(newContiguous)(input);
+  THTensor_(resize4d)(output,
+          THTensor_(size)(input, 0),
+          THTensor_(size)(input, 1),
+          output_height, output_width);
+  THTensor_(zero)(output);
+  scalar_t *idata = input->data<scalar_t>();
+  scalar_t *odata = output->data<scalar_t>();
+
+  // Special case: input/output same size, just copy
+  if (input_height == output_height && input_width == output_width) {
+    for (int output_y = 0; output_y < output_height; output_y++) {
+      for (int output_x = 0; output_x < output_width; output_x++) {
+        const scalar_t* in = &idata[output_y * input_width + output_x];
+        scalar_t* out = &odata[output_y * output_width + output_x];
+        for (int c = 0; c < channels; ++c) {
+          out[0] = in[0];
+          in += input_width * input_height;
+          out += output_width * output_height;
+        }
+      }
+    }
+    c10::raw::intrusive_ptr::decref(input);
+    return;
+  }
+
+  // Bicubic interpolation
+  const accreal height_scale = linear_upsampling_compute_scale<accreal>(
+    input_height,
+    output_height,
+    align_corners);
+  const accreal width_scale = linear_upsampling_compute_scale<accreal>(
+    input_width,
+    output_width,
+    align_corners);
+
+  for (int output_y = 0; output_y < output_height; output_y++) {
+    for (int output_x = 0; output_x < output_width; output_x++) {
+      scalar_t* in = idata;
+      scalar_t* out = odata;
+
+      const scalar_t real_x = width_scale * output_x;
+      int input_x = real_x;
+      const scalar_t t_x = real_x - input_x;
+
+      const scalar_t real_y = height_scale * output_y;
+      int input_y = real_y;
+      const scalar_t t_y = real_y - input_y;
+
+      for (int c = 0; c < channels * nbatch; c++) {
+        scalar_t coefficients[4];
+
+        // Interpolate 4 times in the x direction
+        for (int i = 0; i < 4; i++) {
+          coefficients[i] = cubic_interp1d<scalar_t>(
+            upsampling_get_value_bounded<scalar_t>(
+              in, input_width, input_height, input_x - 1, input_y - 1 + i),
+            upsampling_get_value_bounded<scalar_t>(
+              in, input_width, input_height, input_x + 0, input_y - 1 + i),
+            upsampling_get_value_bounded<scalar_t>(
+              in, input_width, input_height, input_x + 1, input_y - 1 + i),
+            upsampling_get_value_bounded<scalar_t>(
+              in, input_width, input_height, input_x + 2, input_y - 1 + i),
+            t_x
+          );
+        }
+
+        // Interpolate in the y direction using x interpolations
+        out[output_y * output_width + output_x] = cubic_interp1d<scalar_t>(
+          coefficients[0],
+          coefficients[1],
+          coefficients[2],
+          coefficients[3],
+          t_y
+        );
+
+        // Move to next channel
+        in += input_width * input_height;
+        out += output_width * output_height;
+      }
+    }
+  }
+
+  c10::raw::intrusive_ptr::decref(input);
+}
+
+void THNN_(SpatialUpSamplingBicubic_updateGradInput)(
+    THNNState *state,
+    THTensor *gradOutput,
+    THTensor *gradInput,
+    int nbatch,
+    int channels,
+    int input_height,
+    int input_width,
+    int output_height,
+    int output_width,
+    bool align_corners){
+
+  THNN_(SpatialUpSamplingBicubic_shapeCheck)
+    (NULL, gradOutput,
+     nbatch, channels,
+     input_height, input_width,
+     output_height, output_width);
+
+  THTensor_(resize4d)(gradInput, nbatch, channels, input_height, input_width);
+  THTensor_(zero)(gradInput);
+
+  gradOutput = THTensor_(newContiguous)(gradOutput);
+  scalar_t *idata = gradInput->data<scalar_t>();
+  scalar_t *odata = gradOutput->data<scalar_t>();
+  channels = nbatch * channels;
+
+  // Special case: input/output same size, just copy
+  if (input_height == output_height && input_width == output_width) {
+    for (int output_y = 0; output_y < output_height; output_y++) {
+      for (int output_x = 0; output_x < output_width; output_x++) {
+        scalar_t* in = &idata[output_y * input_width + output_x];
+        scalar_t* out = &odata[output_y * output_width + output_x];
+        for (int c = 0; c < channels; ++c) {
+          in[0] = out[0];
+          in += input_width * input_height;
+          out += output_width * output_height;
+        }
+      }
+    }
+    c10::raw::intrusive_ptr::decref(gradOutput);
+    return;
+  }
+
+  const accreal height_scale = linear_upsampling_compute_scale<accreal>(
+    input_height, output_height, align_corners);
+  const accreal width_scale = linear_upsampling_compute_scale<accreal>(
+    input_width, output_width, align_corners);
+
+  for (int output_y = 0; output_y < output_height; output_y++) {
+    for (int output_x = 0; output_x < output_width; output_x++) {
+      scalar_t* in = idata;
+      scalar_t* out = odata;
+
+      scalar_t real_x = width_scale * output_x;
+      int input_x = real_x;
+      scalar_t t_x = real_x - input_x;
+
+      scalar_t real_y = height_scale * output_y;
+      int input_y = real_y;
+      scalar_t t_y = real_y - input_y;
+
+      scalar_t x_coeffs[4];
+      scalar_t y_coeffs[4];
+
+      get_cubic_upsampling_coefficients<scalar_t>(x_coeffs, t_x);
+      get_cubic_upsampling_coefficients<scalar_t>(y_coeffs, t_y);
+
+
+      for (int c = 0; c < channels; c++) {
+        scalar_t out_value = out[output_y * output_width + output_x];
+
+        for (int i = 0; i < 4; i++) {
+          for (int j = 0; j < 4; j++) {
+            upsampling_increment_value_bounded<scalar_t>(in,
+              input_width,
+              input_height,
+              input_x - 1 + i,
+              input_y - 1 + j,
+              out_value * y_coeffs[j] * x_coeffs[i]);
+          }
+        }
+
+        in += input_width * input_height;
+        out += output_width * output_height;
+      }
+    }
+  }
+
+  c10::raw::intrusive_ptr::decref(gradOutput);
+}
+
+#endif
index 1fc37a2..647f52a 100644 (file)
@@ -5,7 +5,7 @@
 #define TH_GENERIC_FILE "THNN/generic/SpatialUpSamplingBilinear.c"
 #else
 
-#include <THNN/generic/linear_upsampling.h>
+#include <THNN/generic/upsampling.h>
 
 static inline void THNN_(SpatialUpSamplingBilinear_shapeCheck)
      (THTensor *input, THTensor *gradOutput,
index 15271bc..8fd4973 100644 (file)
@@ -2,7 +2,7 @@
 #define TH_GENERIC_FILE "THNN/generic/SpatialUpSamplingNearest.c"
 #else
 
-#include <THNN/generic/linear_upsampling.h>
+#include <THNN/generic/upsampling.h>
 
 static inline void THNN_(SpatialUpSamplingNearest_shapeCheck)
      (THTensor *input, THTensor *gradOutput,
index 006fc57..fd2b364 100644 (file)
@@ -706,6 +706,7 @@ TH_API void THNN_(SpatialUpSamplingBilinear_updateOutput)(
           int osizeH,
           int osizeW,
           bool align_corners);
+
 TH_API void THNN_(SpatialUpSamplingBilinear_updateGradInput)(
           THNNState *state,
           THTensor *gradOutput,
@@ -718,6 +719,26 @@ TH_API void THNN_(SpatialUpSamplingBilinear_updateGradInput)(
           int osizeW,
           bool align_corners);
 
+TH_API void THNN_(SpatialUpSamplingBicubic_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output,
+          int osizeH,
+          int osizeW,
+          bool align_corners);
+
+TH_API void THNN_(SpatialUpSamplingBicubic_updateGradInput)(
+          THNNState *state,
+          THTensor *gradOutput,
+          THTensor *gradInput,
+          int isizeB,
+          int isizeC,
+          int isizeH,
+          int isizeW,
+          int osizeH,
+          int osizeW,
+          bool align_corners);
+
 TH_API void THNN_(unfolded_acc)(
           THTensor *finput,
           THTensor *input,
index 5b1f6ea..56a2e13 100644 (file)
@@ -5,7 +5,7 @@
 #define TH_GENERIC_FILE "THNN/generic/TemporalUpSamplingLinear.c"
 #else
 
-#include <THNN/generic/linear_upsampling.h>
+#include <THNN/generic/upsampling.h>
 
 static inline void THNN_(TemporalUpSamplingLinear_shapeCheck)
      (THTensor *input, THTensor *gradOutput,
index 2718760..8251266 100644 (file)
@@ -2,7 +2,7 @@
 #define TH_GENERIC_FILE "THNN/generic/TemporalUpSamplingNearest.c"
 #else
 
-#include <THNN/generic/linear_upsampling.h>
+#include <THNN/generic/upsampling.h>
 
 static inline void THNN_(TemporalUpSamplingNearest_shapeCheck)
      (THTensor *input, THTensor *gradOutput,
index 7df2766..3e83167 100644 (file)
@@ -2,7 +2,7 @@
 #define TH_GENERIC_FILE "THNN/generic/VolumetricUpSamplingNearest.c"
 #else
 
-#include <THNN/generic/linear_upsampling.h>
+#include <THNN/generic/upsampling.h>
 
 static inline void THNN_(VolumetricUpSamplingNearest_shapeCheck)
      (THTensor *input, THTensor *gradOutput,
index 0ec50f4..8885cc8 100644 (file)
@@ -5,7 +5,7 @@
 #define TH_GENERIC_FILE "THNN/generic/VolumetricUpSamplingTrilinear.c"
 #else
 
-#include <THNN/generic/linear_upsampling.h>
+#include <THNN/generic/upsampling.h>
 
 static inline void THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
      (THTensor *input, THTensor *gradOutput,
diff --git a/aten/src/THNN/generic/linear_upsampling.h b/aten/src/THNN/generic/linear_upsampling.h
deleted file mode 100644 (file)
index 2873506..0000000
+++ /dev/null
@@ -1,51 +0,0 @@
-#ifndef THNN_LINEAR_UPSAMPLING_H
-#define THNN_LINEAR_UPSAMPLING_H
-
-#undef MIN
-#define MIN(a,b) ( ((a)<(b)) ? (a) : (b) )
-#undef MAX
-#define MAX(a,b) ( ((a)>(b)) ? (a) : (b) )
-
-
-template<typename T>
-static inline T linear_upsampling_compute_scale(
-                          int inputSize, int outputSize, bool align_corners) {
-  /* We view each pixel as an area, idx + 0.5 as its center index.
-   * Here is an example formula in 1D case.
-   * if align_corners: center of two corner pixel areas are preserved,
-   *     (0.5, 0.5) -> (0.5, 0.5),
-   *     (inputSize - 0.5, 0.5) -> (outputSize - 0.5)
-   *     scale = (inputSize - 0.5 - 0.5) / (outputSize - 0.5 - 0.5)
-   *     src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
-   * if not align_corners: the whole range is scaled accordingly
-   *     scale = inputSize / outputSize
-   *     src_idx + 0.5 = scale * (dst_index + 0.5)
-   */
-  if (outputSize > 1) {
-    return align_corners ? (T) (inputSize - 1) / (outputSize - 1)
-                         : (T) inputSize / outputSize;
-  } else {
-    return T(0);
-  }
-}
-
-template<typename T>
-static inline T linear_upsampling_compute_source_index(
-                          T scale, int dst_index, bool align_corners) {
-  if (align_corners) {
-    return scale * dst_index;
-  } else {
-    T src_idx = scale * (dst_index + 0.5) - 0.5;
-    return src_idx < 0 ? T(0) : src_idx;
-  }
-}
-
-static inline int nearest_neighbor_compute_source_index(
-               const float scale, int dst_index, int inputSize) {
-  const int src_index = MIN(floorf(dst_index * scale), inputSize - 1);
-  return src_index;
-}
-
-
-#endif
-
diff --git a/aten/src/THNN/generic/upsampling.h b/aten/src/THNN/generic/upsampling.h
new file mode 100644 (file)
index 0000000..22898c0
--- /dev/null
@@ -0,0 +1,111 @@
+#ifndef THNN_UPSAMPLING_H
+#define THNN_UPSAMPLING_H
+
+#undef MIN
+#define MIN(a,b) ( ((a)<(b)) ? (a) : (b) )
+#undef MAX
+#define MAX(a,b) ( ((a)>(b)) ? (a) : (b) )
+
+template<typename T>
+static inline T linear_upsampling_compute_scale(
+                          int inputSize, int outputSize, bool align_corners) {
+  /* We view each pixel as an area, idx + 0.5 as its center index.
+   * Here is an example formula in 1D case.
+   * if align_corners: center of two corner pixel areas are preserved,
+   *     (0.5, 0.5) -> (0.5, 0.5),
+   *     (inputSize - 0.5, 0.5) -> (outputSize - 0.5)
+   *     scale = (inputSize - 0.5 - 0.5) / (outputSize - 0.5 - 0.5)
+   *     src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
+   * if not align_corners: the whole range is scaled accordingly
+   *     scale = inputSize / outputSize
+   *     src_idx + 0.5 = scale * (dst_index + 0.5)
+   */
+  if (outputSize > 1) {
+    return align_corners ? (T) (inputSize - 1) / (outputSize - 1)
+                         : (T) inputSize / outputSize;
+  } else {
+    return T(0);
+  }
+}
+
+template<typename T>
+static inline T linear_upsampling_compute_source_index(
+                          T scale, int dst_index, bool align_corners) {
+  if (align_corners) {
+    return scale * dst_index;
+  } else {
+    T src_idx = scale * (dst_index + 0.5) - 0.5;
+    return src_idx < 0 ? T(0) : src_idx;
+  }
+}
+
+static inline int nearest_neighbor_compute_source_index(
+               const float scale, int dst_index, int inputSize) {
+  const int src_index = MIN(floorf(dst_index * scale), inputSize - 1);
+  return src_index;
+}
+
+template<typename T>
+static T upsampling_get_value_bounded(T* data, int width, int height, int x, int y) {
+  int access_x = std::max(std::min(x, width - 1), 0);
+  int access_y = std::max(std::min(y, height - 1), 0);
+  return data[access_y * width + access_x];
+}
+
+template<typename T>
+static void upsampling_increment_value_bounded(
+  T* data,
+  int width,
+  int height,
+  int x,
+  int y,
+  T value
+) {
+  int access_x = std::max(std::min(x, width - 1), 0);
+  int access_y = std::max(std::min(y, height - 1), 0);
+  data[access_y * width + access_x] += value;
+}
+
+// Based on https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+template<typename T>
+static inline T cubic_convolution1(T x, T A) {
+  return ((A + 2) * x - (A + 3)) * x * x + 1;
+}
+
+template<typename T>
+static inline T cubic_convolution2(T x, T A) {
+  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
+}
+
+template<typename T>
+static inline void get_cubic_upsampling_coefficients(T coeffs[4], T t) {
+  T A = -0.75;
+
+  T x1 = t;
+  coeffs[0] = cubic_convolution2<T>(x1 + 1.0, A);
+  coeffs[1] = cubic_convolution1<T>(x1, A);
+
+  // opposite coefficients
+  T x2 = 1.0 - t;
+  coeffs[2] = cubic_convolution1<T>(x2, A);
+  coeffs[3] = cubic_convolution2<T>(x2 + 1.0, A);
+}
+
+template<typename T>
+static inline T cubic_interp1d(
+  T x0,
+  T x1,
+  T x2,
+  T x3,
+  T t
+) {
+  T coeffs[4];
+  get_cubic_upsampling_coefficients<T>(coeffs, t);
+
+  return x0 * coeffs[0]
+    + x1 * coeffs[1]
+    + x2 * coeffs[2]
+    + x3 * coeffs[3];
+}
+
+#endif
index 441af22..0ef123a 100644 (file)
 #include <THNN/generic/VolumetricAveragePooling.c>
 #include <TH/THGenerateFloatTypes.h>
 
+#include <THNN/generic/SpatialUpSamplingBicubic.c>
+#include <TH/THGenerateFloatTypes.h>
+
 #include <THNN/generic/VolumetricConvolutionMM.c>
 #include <TH/THGenerateFloatTypes.h>
 
index 60a17c6..7373465 100644 (file)
@@ -1779,6 +1779,55 @@ new_module_tests = [
     ),
     dict(
         module_name='Upsample',
+        constructor_args=(12, None, 'bicubic', False),
+        input_size=(1, 2, 4, 4),
+        desc='bicubic_2d',
+        decorator=skipIfRocm
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=((4, 6), None, 'bicubic', False),
+        input_size=(1, 2, 2, 3),
+        desc='bicubic_tuple_2d',
+        decorator=skipIfRocm
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=(None, 4., 'bicubic', False),
+        input_size=(1, 2, 4, 4),
+        desc='bicubic_scale_2d',
+        decorator=skipIfRocm
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=(None, (2., 2.), 'bicubic', False),
+        input_size=(1, 2, 4, 4),
+        desc='bicubic_scale_tuple_shared_2d',
+        decorator=skipIfRocm
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=(None, (2., 1.), 'bicubic', False),
+        input_size=(1, 2, 4, 4),
+        desc='bicubic_scale_tuple_skewed_2d',
+        decorator=skipIfRocm
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=((4, 6), None, 'bicubic', True),
+        input_size=(1, 2, 4, 4),
+        desc='bicubic_tuple_2d_align_corners',
+        decorator=skipIfRocm
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=(None, (2., 1.), 'bicubic', True),
+        input_size=(1, 2, 4, 4),
+        desc='bicubic_scale_tuple_skewed_2d_align_corners',
+        decorator=skipIfRocm
+    ),
+    dict(
+        module_name='Upsample',
         constructor_args=(12, None, 'nearest'),
         input_size=(1, 2, 4, 4, 4),
         desc='nearest_3d',
index a404cb0..0283db2 100644 (file)
@@ -6057,6 +6057,32 @@ class TestNN(NNTestCase):
                 input = torch.randn(1, 1, 2, 2, requires_grad=True)
                 gradcheck(lambda x: F.upsample(x, out_size, **kwargs), [input])
 
+    @skipIfRocm
+    def test_upsamplingBicubic2d(self):
+        # test output against known input
+        in_t = torch.arange(4).view(1, 1, 2, 2).type(torch.FloatTensor)
+        expected_out_t = torch.Tensor(
+            [[[[0.00000, 0.31481, 0.68519, 1.00000],
+               [0.62963, 0.94444, 1.31481, 1.62963],
+               [1.37037, 1.68518, 2.05556, 2.37037],
+               [2.00000, 2.31481, 2.68519, 3.00000]]]])
+        out_t = F.interpolate(in_t, scale_factor=2, mode='bicubic', align_corners=True)
+        torch.set_printoptions(precision=5)
+        self.assertEqual(out_t, expected_out_t)
+
+        for align_corners in [True, False]:
+            kwargs = dict(mode='bicubic', align_corners=align_corners)
+
+            # test float scale factor up & downsampling
+            for scale_factor in [0.5, 1.5, 2]:
+                in_t = torch.ones(2, 2, 2, 2)
+                out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
+                out_size = int(math.floor(in_t.shape[-1] * scale_factor))
+                self.assertEqual(torch.ones(2, 2, out_size, out_size), out_t.data)
+
+                input = torch.randn(2, 2, 2, 2, requires_grad=True)
+                gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
+
     def test_upsamplingBilinear2d_spatial_invariance(self):
         m = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=False)
         in_t_9 = torch.zeros(1, 1, 9, 9)
@@ -6141,6 +6167,12 @@ class TestNN(NNTestCase):
                     m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
                     _test_interpolate_helper(_make_input(2), scale_factor, m)
 
+                    kwargs = dict(mode='bicubic', align_corners=align_corners)
+
+                    def m(t):
+                        return F.interpolate(t, scale_factor=scale_factor, **kwargs).to(device)
+                    _test_interpolate_helper(_make_input(2), scale_factor, m)
+
                     kwargs = dict(mode='trilinear', align_corners=align_corners)
                     m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
                     _test_interpolate_helper(_make_input(3), scale_factor, m)
index be5a0c9..3f2e895 100644 (file)
 - name: upsample_bilinear2d(Tensor self, IntList output_size, bool align_corners)
   self: upsample_bilinear2d_backward(grad, output_size, self.sizes(), align_corners)
 
+- name: upsample_bicubic2d(Tensor self, IntList output_size, bool align_corners)
+  self: upsample_bicubic2d_backward(grad, output_size, self.sizes(), align_corners)
+
 - name: upsample_trilinear3d(Tensor self, IntList output_size, bool align_corners)
   self: upsample_trilinear3d_backward(grad, output_size, self.sizes(), align_corners)
 
 - name: upsample_bilinear2d_backward(Tensor grad_output, IntList output_size, IntList input_size, bool align_corners)
   grad_output: upsample_bilinear2d(grad, output_size, align_corners)
 
+- name: upsample_bicubic2d_backward(Tensor grad_output, IntList output_size, IntList input_size, bool align_corners)
+  grad_output: upsample_bicubic2d(grad, output_size, align_corners)
+
 - name: upsample_trilinear3d_backward(Tensor grad_output, IntList output_size, IntList input_size, bool align_corners)
   grad_output: upsample_trilinear3d(grad, output_size, align_corners)
 
index 128145c..6c9e549 100644 (file)
@@ -1219,7 +1219,7 @@ at::Tensor interpolate(
   if ((mode == "nearest" || mode == "area")) {
     if (align_corners != c10::nullopt) {
       throw std::runtime_error("align_corners option can only be set with the "
-                             "interpolating modes: linear | bilinear | trilinear");
+                             "interpolating modes: linear | bilinear | bicubic | trilinear");
     }
   } else {
     if (align_corners == c10::nullopt) {
@@ -1247,18 +1247,24 @@ at::Tensor interpolate(
     return at::upsample_linear1d(input, _output_size(input, 1, size, scale_factors), *align_corners);
   if (input_dim == 3 && mode == "bilinear")
     throw std::runtime_error("Got 3D input, but bilinear mode needs 4D input");
+  if (input_dim == 3 && mode == "bicubic")
+    throw std::runtime_error("Got 3D input, but bicubic mode needs 4D input");
   if (input_dim == 3 && mode == "trilinear")
     throw std::runtime_error("Got 3D input, but trilinear mode needs 5D input");
   if (input_dim == 4 && mode == "linear")
     throw std::runtime_error("Got 4D input, but linear mode needs 3D input");
   if (input_dim == 4 && mode == "bilinear")
     return at::upsample_bilinear2d(input, _output_size(input, 2, size, scale_factors), *align_corners);
+  if (input_dim == 4 && mode == "bicubic")
+    return at::upsample_bicubic2d(input, _output_size(input, 2, size, scale_factors), *align_corners);
   if (input_dim == 4 && mode == "trilinear")
     throw std::runtime_error("Got 4D input, but trilinear mode needs 5D input");
   if (input_dim == 5 && mode == "linear")
     throw std::runtime_error("Got 5D input, but linear mode needs 3D input");
   if (input_dim == 5 && mode == "bilinear")
     throw std::runtime_error("Got 5D input, but bilinear mode needs 4D input");
+  if (input_dim == 5 && mode == "bicubic")
+    throw std::runtime_error("Got 5D input, but bicubic mode needs 4D input");
   if (input_dim == 5 && mode == "trilinear")
     return at::upsample_trilinear3d(input, _output_size(input, 3, size, scale_factors), *align_corners);
 
index 4591702..c6abfc0 100644 (file)
@@ -64,7 +64,7 @@ def _load_cudart():
         return lib
 
     raise RuntimeError(
-        "couldn't find libcudart. Make sure CUDA libraries are installed in a"
+        "couldn't find libcudart. Make sure CUDA libraries are installed in a "
         "default location, or that they're in {}."
         .format('DYLD_LIBRARY_PATH' if platform.system() == 'Darwin' else
                 'LD_LIBRARY_PATH'))
index f3134da..534bebe 100644 (file)
@@ -2324,7 +2324,7 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=
     `mini-batch x channels x [optional depth] x [optional height] x width`.
 
     The modes available for upsampling are: `nearest`, `linear` (3D-only),
-    `bilinear` (4D-only), `trilinear` (5D-only)
+    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)
 
     Args:
         input (Tensor): the input tensor
@@ -2332,7 +2332,7 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=
             output spatial size.
         scale_factor (int): multiplier for spatial size. Has to be an integer.
         mode (string): algorithm used for upsampling:
-            'nearest' | 'linear' | 'bilinear' | 'trilinear'. Default: 'nearest'
+            'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear'. Default: 'nearest'
         align_corners (bool, optional): Geometrically, we consider the pixels of the
             input and output as squares rather than points.
             If set to True, the input and output tensors are aligned by the
@@ -2340,7 +2340,7 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=
             output tensors are aligned by the corner points of their corner
             pixels, and the interpolation uses edge value padding for out-of-boundary values.
             This only has effect when :attr:`mode` is `linear`,
-            `bilinear`, or `trilinear`. Default: False
+            `bilinear`, `bicubic` or `trilinear`. Default: False
 
     .. warning::
         With ``align_corners = True``, the linearly interpolating modes
@@ -2369,7 +2369,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
     `mini-batch x channels x [optional depth] x [optional height] x width`.
 
     The modes available for resizing are: `nearest`, `linear` (3D-only),
-    `bilinear` (4D-only), `trilinear` (5D-only), `area`
+    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`
 
     Args:
         input (Tensor): the input tensor
@@ -2377,7 +2377,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
             output spatial size.
         scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
         mode (string): algorithm used for upsampling:
-            'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
+            'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'. Default: 'nearest'
         align_corners (bool, optional): Geometrically, we consider the pixels of the
             input and output as squares rather than points.
             If set to True, the input and output tensors are aligned by the
@@ -2385,7 +2385,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
             output tensors are aligned by the corner points of their corner
             pixels, and the interpolation uses edge value padding for out-of-boundary values.
             This only has effect when :attr:`mode` is `linear`,
-            `bilinear`, or `trilinear`. Default: False
+            `bilinear`, `bicubic`, or `trilinear`. Default: False
 
     .. warning::
         With ``align_corners = True``, the linearly interpolating modes
@@ -2422,7 +2422,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
     if mode in ('nearest', 'area'):
         if align_corners is not None:
             raise ValueError("align_corners option can only be set with the "
-                             "interpolating modes: linear | bilinear | trilinear")
+                             "interpolating modes: linear | bilinear | bicubic | trilinear")
     else:
         if align_corners is None:
             warnings.warn("Default upsampling behavior when mode={} is changed "
@@ -2461,9 +2461,11 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
         raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
     elif input.dim() == 5 and mode == 'trilinear':
         return torch._C._nn.upsample_trilinear3d(input, _output_size(3), align_corners)
+    elif input.dim() == 4 and mode == 'bicubic':
+        return torch._C._nn.upsample_bicubic2d(input, _output_size(2), align_corners)
     else:
         raise NotImplementedError("Input Error: Only 3D, 4D and 5D input Tensors supported"
-                                  " (got {}D) for the modes: nearest | linear | bilinear | trilinear"
+                                  " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear"
                                   " (got {})".format(input.dim(), mode))
 
 
index d03e6bc..c359097 100644 (file)
@@ -14,8 +14,9 @@ class Upsample(Module):
     `minibatch x channels x [optional depth] x [optional height] x width`.
     Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
 
-    The algorithms available for upsampling are nearest neighbor and linear, bilinear and trilinear
-    for 3D, 4D and 5D input Tensor, respectively.
+    The algorithms available for upsampling are nearest neighbor and linear,
+    bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor,
+    respectively.
 
     One can either give a :attr:`scale_factor` or the target output :attr:`size` to
     calculate the output size. (You cannot give both, as it is ambiguous)
@@ -23,8 +24,8 @@ class Upsample(Module):
     Args:
         size (tuple, optional): a tuple of ints `([optional D_out], [optional H_out], W_out)` output sizes
         scale_factor (int / tuple of ints, optional): the multiplier for the image height / width / depth
-        mode (string, optional): the upsampling algorithm: one of `nearest`, `linear`, `bilinear` and `trilinear`.
-                                    Default: `nearest`
+        mode (string, optional): the upsampling algorithm: one of `nearest`, `linear`, `bilinear`,
+            `bicubic` and `trilinear`. Default: `nearest`
         align_corners (bool, optional): if True, the corner pixels of the input
             and output tensors are aligned, and thus preserving the values at
             those pixels. This only has effect when :attr:`mode` is `linear`,
@@ -46,11 +47,12 @@ class Upsample(Module):
 
     .. warning::
         With ``align_corners = True``, the linearly interpolating modes
-        (`linear`, `bilinear`, and `trilinear`) don't proportionally align the
-        output and input pixels, and thus the output values can depend on the
-        input size. This was the default behavior for these modes up to version
-        0.3.1. Since then, the default behavior is ``align_corners = False``.
-        See below for concrete examples on how this affects the outputs.
+        (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally
+        align the output and input pixels, and thus the output values can depend
+        on the input size. This was the default behavior for these modes up to
+        version 0.3.1. Since then, the default behavior is
+        ``align_corners = False``. See below for concrete examples on how this
+        affects the outputs.
 
     .. note::
         If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`.