Port FractionalMaxPool2d from TH to ATen (#15531)
author Chandler Zuo <chandlerzuo@fb.com>
Wed, 16 Jan 2019 01:54:20 +0000 (17:54 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 16 Jan 2019 01:57:12 +0000 (17:57 -0800)
Summary:
Tested:

pytest test/test_nn.py -k Fractional
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15531

Differential Revision: D13612833

Pulled By: chandlerzuo

fbshipit-source-id: b919d698d068b97ba7a4f8021367e7f6c8aae39c

13 files changed:
aten/src/ATen/native/FractionalMaxPool2d.cpp [new file with mode: 0644]
aten/src/ATen/native/LegacyNNDefinitions.cpp
aten/src/ATen/native/cuda/FractionalMaxPool2d.cu [new file with mode: 0644]
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/nn.yaml
aten/src/THCUNN/CMakeLists.txt
aten/src/THCUNN/SpatialFractionalMaxPooling.cu [deleted file]
aten/src/THCUNN/generic/SpatialFractionalMaxPooling.cu [deleted file]
aten/src/THCUNN/generic/THCUNN.h
aten/src/THNN/generic/SpatialFractionalMaxPooling.c [deleted file]
aten/src/THNN/generic/THNN.h
aten/src/THNN/init.cpp
test/common_nn.py

diff --git a/aten/src/ATen/native/FractionalMaxPool2d.cpp b/aten/src/ATen/native/FractionalMaxPool2d.cpp
new file mode 100644 (file)
index 0000000..478c932
--- /dev/null
@@ -0,0 +1,391 @@
+#include "ATen/ATen.h"
+#include "ATen/NativeFunctions.h"
+
+#include <cmath>
+#include <limits>
+#include <tuple>
+#include <vector>
+
+namespace at {
+namespace native {
+namespace {
+
+/*
+ * Computes, for one spatial dimension, the input offset at which each of
+ * the outputSize pooling windows starts.
+ *
+ *   sample     - per-plane random value that shifts the window boundaries
+ *   inputSize  - extent of the input along this dimension
+ *   outputSize - number of windows to generate
+ *   poolSize   - extent of one window
+ *
+ * The last window always starts at inputSize - poolSize so that it ends
+ * exactly at the input boundary. Returns an empty vector when outputSize
+ * is not positive.
+ */
+template <typename scalar_t>
+static std::vector<int> fractional_max_pool2d_generate_intervals(
+  scalar_t sample,
+  int inputSize,
+  int outputSize,
+  int poolSize) {
+  /* nothing to generate; also keeps sequence[outputSize - 1] in bounds */
+  if (outputSize <= 0) {
+    return std::vector<int>();
+  }
+
+  std::vector<int> sequence(outputSize);
+
+  /* alpha divides by (outputSize - 1); only evaluate it when more than
+     one window exists, otherwise this is a division by zero */
+  if (outputSize > 1) {
+    scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
+      static_cast<scalar_t>(outputSize - 1);
+
+    for (int i = 0; i < outputSize - 1; ++i) {
+      sequence[i] =
+        static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
+    }
+  }
+  sequence[outputSize - 1] = inputSize - poolSize;
+
+  return sequence;
+}
+
+/*
+ * Forward pass of fractional max pooling for a single batch element.
+ *
+ * input, output and indices point at numPlanes consecutive planes of
+ * inputH x inputW (input) and outputH x outputW (output, indices) elements.
+ * randomSamples holds two values per plane: the first drives the window
+ * offsets along W, the second along H. For every output location the
+ * maximum of the poolSizeH x poolSizeW window is written to output, and
+ * its flat position inside the input plane (h * inputW + w) to indices.
+ * A NaN inside a window is propagated to the output.
+ */
+template <typename scalar_t>
+static void fractional_max_pool2d_out_single_batch_frame(
+  scalar_t* input,
+  scalar_t* output,
+  int64_t* indices,
+  scalar_t* randomSamples,
+  int numPlanes,
+  int inputW, int inputH,
+  int outputW, int outputH,
+  int poolSizeW, int poolSizeH) {
+  /* plane sizes in 64 bits so pointer offsets cannot overflow int */
+  const int64_t inputPlaneSize = static_cast<int64_t>(inputW) * inputH;
+  const int64_t outputPlaneSize = static_cast<int64_t>(outputW) * outputH;
+
+  int plane;
+#pragma omp parallel for private(plane)
+  for (plane = 0; plane < numPlanes; ++plane) {
+    /* each plane contains 2 random samples, one for W and one for H */
+    scalar_t* randomSamplesForPlane =
+      randomSamples + static_cast<int64_t>(plane) * 2;
+
+    /* Generate interval sequence */
+    auto sequenceW = fractional_max_pool2d_generate_intervals<scalar_t>(
+        randomSamplesForPlane[0], inputW, outputW, poolSizeW);
+    auto sequenceH = fractional_max_pool2d_generate_intervals<scalar_t>(
+        randomSamplesForPlane[1], inputH, outputH, poolSizeH);
+
+    scalar_t* inputForPlane = input + plane * inputPlaneSize;
+    scalar_t* outputForPlane = output + plane * outputPlaneSize;
+    int64_t* indicesForPlane = indices + plane * outputPlaneSize;
+
+    /* loop over output */
+    for (int h = 0; h < outputH; ++h) {
+      int inputHStart = sequenceH[h];
+
+      for (int w = 0; w < outputW; ++w) {
+        int inputWStart = sequenceW[w];
+
+        scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
+        int64_t maxIndex = -1;
+
+        for (int h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
+          for (int w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
+            AT_ASSERT(h2 >= 0 && h2 < inputH);
+            AT_ASSERT(w2 >= 0 && w2 < inputW);
+
+            int64_t planeIndex = static_cast<int64_t>(h2) * inputW + w2;
+            scalar_t val = inputForPlane[planeIndex];
+            /* The first element is always taken, so a window that holds
+               only -inf still yields a valid index. Ties keep the first
+               maximum; a NaN always wins so that it reaches the output. */
+            if (maxIndex == -1 || val > maxVal || std::isnan(val)) {
+              maxVal = val;
+              maxIndex = planeIndex;
+            }
+          }
+        }
+
+        /* only an empty window can leave the index unset; -inf is a
+           legitimate maximum and is therefore not asserted against */
+        AT_ASSERT(maxIndex != -1);
+
+        outputForPlane[static_cast<int64_t>(h) * outputW + w] = maxVal;
+        indicesForPlane[static_cast<int64_t>(h) * outputW + w] = maxIndex;
+      }
+    }
+
+  }
+}
+
+/*
+ * Forward pass over all batch elements.
+ *
+ * Calls fractional_max_pool2d_out_single_batch_frame once per batch
+ * element, advancing every pointer by the size of one batch element
+ * (numPlanes planes of input / output / indices and numPlanes * 2 random
+ * samples). A single batch element is handled directly so that the
+ * per-plane OpenMP loop of the callee is the only parallel region.
+ */
+template <typename scalar_t>
+static void fractional_max_pool2d_out_frame(
+  scalar_t* input,
+  scalar_t* output,
+  int64_t* indices,
+  scalar_t* randomSamples,
+  int numBatch, int numPlanes,
+  int inputW, int inputH,
+  int outputW, int outputH,
+  int poolSizeW, int poolSizeH) {
+  if (numBatch == 1) {
+    fractional_max_pool2d_out_single_batch_frame<scalar_t>(
+      input,
+      output,
+      indices,
+      randomSamples,
+      numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH
+    );
+    return;
+  }
+
+  /* per-batch strides in 64 bits: the int products overflow for large
+     tensors and would produce wrong pointer offsets */
+  const int64_t inputBatchStride =
+    static_cast<int64_t>(numPlanes) * inputH * inputW;
+  const int64_t outputBatchStride =
+    static_cast<int64_t>(numPlanes) * outputH * outputW;
+  const int64_t samplesBatchStride = static_cast<int64_t>(numPlanes) * 2;
+
+  int batch;
+#pragma omp parallel for private(batch)
+  for (batch = 0; batch < numBatch; ++batch) {
+    fractional_max_pool2d_out_single_batch_frame<scalar_t>(
+      input + batch * inputBatchStride,
+      output + batch * outputBatchStride,
+      indices + batch * outputBatchStride,
+      randomSamples + batch * samplesBatchStride,
+      numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
+  }
+}
+
+/*
+ * Shared CPU implementation of the forward pass.
+ *
+ *   input_        - non-empty 3D (planes, H, W) or 4D (batch, planes, H, W)
+ *                   tensor; a contiguous copy is taken if needed
+ *   output        - resized to the pooled shape and filled with the maxima
+ *   output_size   - {outputH, outputW}
+ *   pool_size     - {poolSizeH, poolSizeW}
+ *   indices       - resized like output; receives the flat position of each
+ *                   maximum inside its input plane
+ *   randomSamples - two values per (batch, plane), read in that order
+ *
+ * Raises (via AT_CHECK) when the input is empty or not 3D/4D, when the
+ * pooling window does not fit, or when too few random samples are given.
+ */
+void fractional_max_pool2d_out_cpu_template(
+  const at::Tensor& input_,
+  at::Tensor& output,
+  IntList output_size,
+  IntList pool_size,
+  at::Tensor& indices,
+  const at::Tensor& randomSamples) {
+
+  int numBatch = 1;
+  int planeDim = 0;
+  int heightDim = 1;
+  int widthDim = 2;
+  int outputH = output_size[0];
+  int outputW = output_size[1];
+  int poolSizeH = pool_size[0];
+  int poolSizeW = pool_size[1];
+
+  /* get contiguous input */
+  auto input = input_.contiguous();
+
+  int ndims = input.ndimension();
+  AT_CHECK(input.numel() > 0 && (ndims == 3 || ndims == 4),
+    "non-empty 3D or 4D (batch mode) tensor expected for input, but got: ",
+    ndims);
+
+  if (ndims == 4) {
+    numBatch = input.size(0);
+    planeDim++;
+    heightDim++;
+    widthDim++;
+  }
+
+  /* sizes */
+  int numPlanes = input.size(planeDim);
+  int inputH = input.size(heightDim);
+  int inputW = input.size(widthDim);
+
+  AT_CHECK(outputH + poolSizeH - 1 <= inputH,
+    "fractional_max_pool2d(): pool height ", poolSizeH,
+    " too large relative to input height ", inputH);
+  AT_CHECK(outputW + poolSizeW - 1 <= inputW,
+    "fractional_max_pool2d(): pool width ", poolSizeW,
+    " too large relative to input width ", inputW);
+
+  /* The frame code walks the samples through a raw pointer, two values per
+     (batch, plane). Make the layout dense and make sure the buffer is long
+     enough, so a strided or short tensor cannot cause wrong or
+     out-of-bounds reads. */
+  auto samples = randomSamples.contiguous();
+  const int64_t samplesNeeded = static_cast<int64_t>(numBatch) * numPlanes * 2;
+  AT_CHECK(samples.numel() >= samplesNeeded,
+    "fractional_max_pool2d(): expected at least ", samplesNeeded,
+    " random samples, but got: ", samples.numel());
+
+  if (ndims == 3) {
+    /* resize output */
+    output.resize_({numPlanes, outputH, outputW});
+    /* indices will contain the locations for each output point */
+    indices.resize_({numPlanes, outputH, outputW});
+  } else {
+    output.resize_({numBatch, numPlanes, outputH, outputW});
+    /* indices will contain the locations for each output point */
+    indices.resize_({numBatch, numPlanes, outputH, outputW});
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(input.type(),
+  "fractional_max_pool2d_out_frame", [&] {
+    auto input_data = input.data<scalar_t>();
+    auto output_data = output.data<scalar_t>();
+    auto indices_data = indices.data<int64_t>();
+    auto randomSamples_data = samples.data<scalar_t>();
+    fractional_max_pool2d_out_frame<scalar_t>(
+      input_data,
+      output_data,
+      indices_data,
+      randomSamples_data,
+      numBatch, numPlanes,
+      inputW, inputH,
+      outputW, outputH,
+      poolSizeW, poolSizeH);
+    }
+  );
+}
+
+/*
+ * Backward pass for a single batch element.
+ *
+ * For every output location of every plane, adds the incoming gradient to
+ * the input position recorded in indices (a flat offset inside the
+ * inputH x inputW plane). gradInput is accumulated into, so the caller
+ * must have zeroed it beforehand.
+ */
+template <typename scalar_t>
+static void fractional_max_pool2d_backward_out_single_batch_frame(
+  scalar_t* gradInput,
+  scalar_t* gradOutput,
+  int64_t* indices,
+  int numPlanes,
+  int inputW, int inputH,
+  int outputW, int outputH) {
+  /* plane sizes in 64 bits so pointer offsets cannot overflow int */
+  const int64_t inputPlaneSize = static_cast<int64_t>(inputW) * inputH;
+  const int64_t outputPlaneSize = static_cast<int64_t>(outputW) * outputH;
+
+  int plane;
+#pragma omp parallel for private(plane)
+  for (plane = 0; plane < numPlanes; plane++) {
+    scalar_t* gradInputForPlane = gradInput + plane * inputPlaneSize;
+    scalar_t* gradOutputForPlane = gradOutput + plane * outputPlaneSize;
+    int64_t* indicesForPlane = indices + plane * outputPlaneSize;
+
+    /* several output points may map to the same input position, hence += */
+    for (int64_t outputIndex = 0; outputIndex < outputPlaneSize; ++outputIndex) {
+      int64_t index = indicesForPlane[outputIndex];
+      AT_ASSERT(index >= 0 && index < inputPlaneSize);
+
+      gradInputForPlane[index] += gradOutputForPlane[outputIndex];
+    }
+  }
+}
+
+/*
+ * Backward pass over all batch elements.
+ *
+ * Calls fractional_max_pool2d_backward_out_single_batch_frame once per
+ * batch element, advancing the pointers by one batch element each time.
+ * A single batch element is handled directly so that the per-plane OpenMP
+ * loop of the callee is the only parallel region.
+ */
+template <typename scalar_t>
+static void fractional_max_pool2d_backward_out_frame(
+  scalar_t* gradInput,
+  scalar_t* gradOutput,
+  int64_t* indices,
+  int numBatch, int numPlanes,
+  int inputW, int inputH,
+  int outputW, int outputH) {
+  if (numBatch == 1) {
+    fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
+      gradInput, gradOutput, indices,
+      numPlanes,
+      inputW, inputH, outputW, outputH
+    );
+    return;
+  }
+
+  /* per-batch strides in 64 bits: the int products overflow for large
+     tensors and would produce wrong pointer offsets */
+  const int64_t inputBatchStride =
+    static_cast<int64_t>(numPlanes) * inputH * inputW;
+  const int64_t outputBatchStride =
+    static_cast<int64_t>(numPlanes) * outputH * outputW;
+
+  int batch;
+#pragma omp parallel for private(batch)
+  for (batch = 0; batch < numBatch; ++batch) {
+    fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
+      gradInput + batch * inputBatchStride,
+      gradOutput + batch * outputBatchStride,
+      indices + batch * outputBatchStride,
+      numPlanes, inputW, inputH, outputW, outputH);
+  }
+}
+
+/*
+ * Shared CPU implementation of the backward pass.
+ *
+ *   input       - forward input; only its shape and type are used
+ *   gradOutput_ - gradient w.r.t. the forward output; made contiguous
+ *   gradInput   - resized like input, zeroed, then accumulated into
+ *   output_size - {outputH, outputW}; must match gradOutput_
+ *   pool_size   - unused
+ *   indices     - flat per-plane input positions recorded by the forward
+ *
+ * Returns gradInput. Raises (via AT_CHECK) when gradOutput_ does not have
+ * the spatial size given by output_size.
+ */
+Tensor& fractional_max_pool2d_backward_out_cpu_template(
+  const at::Tensor& input,
+  const at::Tensor& gradOutput_,
+  at::Tensor& gradInput,
+  IntList output_size,
+  IntList pool_size /* unused */,
+  const at::Tensor& indices) {
+
+  int numBatch = 1;
+  int planeDim = 0;
+  int heightDim = 1;
+  int widthDim = 2;
+
+  int outputH = output_size[0];
+  int outputW = output_size[1];
+
+  int ndims = input.ndimension();
+  if (ndims == 4) {
+    numBatch = input.size(0);
+    planeDim++;
+    heightDim++;
+    widthDim++;
+  }
+
+  /* sizes */
+  int numPlanes = input.size(planeDim);
+  int inputH = input.size(heightDim);
+  int inputW = input.size(widthDim);
+
+  /* get contiguous gradOutput */
+  auto gradOutput = gradOutput_.contiguous();
+  /* indices is walked through a raw pointer in lock-step with gradOutput,
+     so it needs the same dense layout */
+  auto indicesContig = indices.contiguous();
+
+  AT_CHECK(outputW == gradOutput.size(widthDim),
+    "fractional_max_pool2d_backward(): gradOutput width unexpected");
+  AT_CHECK(outputH == gradOutput.size(heightDim),
+    "fractional_max_pool2d_backward(): gradOutput height unexpected");
+
+  /* resize */
+  gradInput.resize_as_(input);
+  gradInput.zero_();
+
+  /* backprop */
+  AT_DISPATCH_FLOATING_TYPES(
+    input.type(), "fractional_max_pool2d_backward_out_frame", [&] {
+      auto gradInput_data = gradInput.data<scalar_t>();
+      auto gradOutput_data = gradOutput.data<scalar_t>();
+      auto indices_data = indicesContig.data<int64_t>();
+      fractional_max_pool2d_backward_out_frame<scalar_t>(
+        gradInput_data,
+        gradOutput_data,
+        indices_data,
+        numBatch, numPlanes,
+        inputW, inputH,
+        outputW, outputH
+      );
+      }
+    );
+  return gradInput;
+}
+
+} // namespace
+
+/*
+ * CPU entry point of the out= variant: fills the caller-provided output
+ * and indices tensors and returns references to both.
+ */
+std::tuple<Tensor&, Tensor&> fractional_max_pool2d_out_cpu(
+  at::Tensor& output,
+  at::Tensor& indices,
+  const at::Tensor& input,
+  IntList pool_size,
+  IntList output_size,
+  const at::Tensor& randomSamples)
+{
+  /* note the argument order of the shared template:
+     (input, output, output_size, pool_size, indices, samples) */
+  fractional_max_pool2d_out_cpu_template(
+    input, output, output_size, pool_size, indices, randomSamples);
+  return std::tie(output, indices);
+}
+
+/*
+ * CPU entry point of the functional variant: allocates empty result
+ * tensors (indices as kLong), lets the shared template resize and fill
+ * them, and returns them by value.
+ */
+std::tuple<Tensor, Tensor> fractional_max_pool2d_cpu(
+  const at::Tensor& input,
+  IntList pool_size,
+  IntList output_size,
+  const at::Tensor& randomSamples)
+{
+  const auto resultOptions = input.options();
+  Tensor output = at::empty({0}, resultOptions);
+  Tensor indices = at::empty({0}, resultOptions.dtype(kLong));
+
+  fractional_max_pool2d_out_cpu_template(
+    input, output, output_size, pool_size, indices, randomSamples);
+
+  return std::make_tuple(output, indices);
+}
+
+/*
+ * CPU entry point of the backward out= variant: writes the input gradient
+ * into the caller-provided gradInput and returns it.
+ */
+Tensor& fractional_max_pool2d_backward_out_cpu(
+  at::Tensor& gradInput,
+  const at::Tensor& gradOutput_,
+  const at::Tensor& input,
+  IntList pool_size,
+  IntList output_size,
+  const at::Tensor& indices)
+{
+  gradInput.resize_as_(input);
+  /* the template hands back the same gradInput reference it was given */
+  return fractional_max_pool2d_backward_out_cpu_template(
+    input, gradOutput_, gradInput, output_size, pool_size, indices);
+}
+
+/*
+ * CPU entry point of the functional backward variant: allocates an empty
+ * gradient tensor, lets the shared template resize and fill it, and
+ * returns it by value.
+ */
+Tensor fractional_max_pool2d_backward_cpu(
+  const at::Tensor& gradOutput_,
+  const at::Tensor& input,
+  IntList pool_size,
+  IntList output_size,
+  const at::Tensor& indices)
+{
+  auto gradInput = at::empty({0}, input.options());
+
+  fractional_max_pool2d_backward_out_cpu_template(
+    input, gradOutput_, gradInput, output_size, pool_size, indices);
+
+  return gradInput;
+}
+
+} // at::native
+} // at
index 57e8edb..9ee1a82 100644 (file)
@@ -412,22 +412,6 @@ Tensor avg_pool3d_backward(const Tensor & grad_output, const Tensor & self, IntL
   return at::legacy::th::_thnn_avg_pool3d_backward(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad);
 }
 
-std::tuple<Tensor &,Tensor &> fractional_max_pool2d_out(Tensor & output, Tensor & indices, const Tensor & self, IntList kernel_size, IntList output_size, const Tensor & random_samples) {
-  return at::legacy::th::_thnn_fractional_max_pool2d_forward_out(output, indices, self, kernel_size, output_size, random_samples);
-}
-
-std::tuple<Tensor,Tensor> fractional_max_pool2d(const Tensor & self, IntList kernel_size, IntList output_size, const Tensor & random_samples) {
-  return at::legacy::th::_thnn_fractional_max_pool2d_forward(self, kernel_size, output_size, random_samples);
-}
-
-Tensor & fractional_max_pool2d_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, IntList kernel_size, IntList output_size, const Tensor & indices) {
-  return at::legacy::th::_thnn_fractional_max_pool2d_backward_out(grad_input, grad_output, self, kernel_size, output_size, indices);
-}
-
-Tensor fractional_max_pool2d_backward(const Tensor & grad_output, const Tensor & self, IntList kernel_size, IntList output_size, const Tensor & indices) {
-  return at::legacy::th::_thnn_fractional_max_pool2d_backward(grad_output, self, kernel_size, output_size, indices);
-}
-
 std::tuple<Tensor &,Tensor &> max_pool2d_with_indices_out(Tensor & output, Tensor & indices, const Tensor & self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) {
   return at::legacy::th::_thnn_max_pool2d_with_indices_forward_out(output, indices, self, kernel_size, stride, padding, dilation, ceil_mode);
 }
diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu
new file mode 100644 (file)
index 0000000..8dbb22f
--- /dev/null
@@ -0,0 +1,360 @@
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/Utils.h>
+#include <c10/util/Exception.h>
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+
+namespace at {
+namespace native {
+
+using namespace at::cuda::detail;
+
+namespace {
+
+// Returns the input offset at which pooling window `index` starts along one
+// dimension. The last window (index == outputSize - 1) is pinned to
+// inputSize - poolSize so that it ends at the input boundary; every other
+// window is placed from the per-plane random `sample`.
+template <typename scalar_t, typename accscalar_t>
+__device__ inline int get_interval(accscalar_t sample,
+  int index, int inputSize, int outputSize, int poolSize) {
+  if (index == outputSize - 1) {
+    return inputSize - poolSize;
+  }
+
+  // Only reached when outputSize > 1 for a valid index, so the divisor is
+  // non-zero; computing alpha before the check above divided by zero when
+  // there is a single output window.
+  accscalar_t alpha = static_cast<accscalar_t>(inputSize - poolSize) /
+    static_cast<accscalar_t>(outputSize - 1);
+  return static_cast<int>((index + sample) * alpha) -
+    static_cast<int>(sample * alpha);
+}
+
+// Forward kernel. blockIdx.y selects the plane and blockIdx.z the batch
+// element; each thread with a valid linear index computes one output
+// location: it scans the poolSizeH x poolSizeW window whose top-left
+// corner comes from get_interval, then stores the maximum in `output` and
+// its flat position inside the input plane (h * input.size(3) + w) in
+// `indices`. samples[batch][plane][0] drives the W offsets and
+// samples[batch][plane][1] the H offsets.
+template <typename scalar_t>
+__global__ void fractional_max_pool2d_out_cuda_frame(
+  PackedTensorAccessor<scalar_t, 4> output,
+  PackedTensorAccessor<int64_t, 4> indices,
+  PackedTensorAccessor<scalar_t, 4> input,
+  PackedTensorAccessor<scalar_t, 3> samples,
+  int poolSizeH, int poolSizeW) {
+
+  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
+
+  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
+  int plane = blockIdx.y;
+  int batch = blockIdx.z;
+
+  // Each thread generates a specific output point
+  if (ourOutputPoint < output.size(2) * output.size(3)) {
+    int outputW = ourOutputPoint % output.size(3);
+    int outputH = ourOutputPoint / output.size(3);
+
+    int poolW = get_interval<scalar_t, accscalar_t>(
+      static_cast<accscalar_t>(samples[batch][plane][0]),
+        outputW, input.size(3), output.size(3), poolSizeW);
+    int poolH = get_interval<scalar_t, accscalar_t>(
+      static_cast<accscalar_t>(samples[batch][plane][1]),
+        outputH, input.size(2), output.size(2), poolSizeH);
+
+    scalar_t maxVal = at::numeric_limits<scalar_t>::lowest();
+    int maxIndex = -1;
+
+    // A single loop nest: both branches of the former poolSizeW test ran
+    // the same code.
+    for (int h = poolH; h < poolH + poolSizeH; ++h) {
+      for (int w = poolW; w < poolW + poolSizeW; ++w) {
+        scalar_t val = input[batch][plane][h][w];
+        // The first element is always taken, so a window that holds only
+        // lowest() / -inf still yields a valid index. For consistency
+        // with THNN, ties favor the first max. `val != val` is true only
+        // for NaN, which always wins so that it reaches the output.
+        if (maxIndex == -1 || val > maxVal || val != val) {
+          maxIndex = h * input.size(3) + w;
+          maxVal = val;
+        }
+      }
+    }
+
+    // Only an empty window can leave the index unset; lowest() is a
+    // legitimate maximum and is therefore not asserted against.
+    assert(maxIndex != -1);
+
+    indices[batch][plane][outputH][outputW] = maxIndex;
+    output[batch][plane][outputH][outputW] = maxVal;
+  }
+}
+
+// Backward kernel. blockIdx.y selects the plane and blockIdx.z the batch
+// element; each thread with a valid linear index handles one gradOutput
+// location and atomically adds its gradient to the input position that
+// `indices` recorded for it.
+template <typename scalar_t>
+__global__ void fractional_max_pool2d_backward_out_cuda_frame(
+  PackedTensorAccessor<scalar_t, 4> gradInput,
+  PackedTensorAccessor<scalar_t, 4> gradOutput,
+  PackedTensorAccessor<int64_t, 4> indices) {
+  const int outPoint = threadIdx.x + blockIdx.x * blockDim.x;
+  const int plane = blockIdx.y;
+  const int batch = blockIdx.z;
+
+  // Threads past the end of the output plane have nothing to do
+  if (!(outPoint < gradOutput.size(2) * gradOutput.size(3))) {
+    return;
+  }
+
+  const int outCol = outPoint % gradOutput.size(3);
+  const int outRow = outPoint / gradOutput.size(3);
+
+  // Flat position inside the input plane; split it back into (row, column)
+  const int flatIndex = indices[batch][plane][outRow][outCol];
+  assert(flatIndex >= 0);
+  const int inCol = flatIndex % gradInput.size(3);
+  const int inRow = flatIndex / gradInput.size(3);
+  assert(inRow < gradInput.size(2));
+
+  atomicAdd(
+    &gradInput[batch][plane][inRow][inCol],
+    gradOutput[batch][plane][outRow][outCol]
+  );
+}
+
+// Shared CUDA implementation of the forward pass.
+//
+//   output        - resized to the pooled shape and filled with the maxima
+//   indices       - resized like output; receives the flat position of each
+//                   maximum inside its input plane
+//   input         - non-empty 3D (planes, H, W) or 4D (batch, planes, H, W)
+//   pool_size     - {poolSizeH, poolSizeW}
+//   output_size   - {outputH, outputW}
+//   randomSamples - accessed as a 3-D tensor indexed [batch][plane][0 or 1]
+//
+// 3-D inputs are viewed as a batch of one so a single 4-D kernel serves
+// both cases. Raises (via AT_CHECK) on invalid shapes or when the kernel
+// launch reports an error.
+void fractional_max_pool2d_out_cuda_template(
+  Tensor & output,
+  Tensor& indices,
+  const Tensor& input,
+  IntList pool_size,
+  IntList output_size,
+  const Tensor& randomSamples) {
+  int planeDim = 0;
+  int dimh = 1;
+  int dimw = 2;
+  int numBatch = 1;
+
+  int ndims = input.ndimension();
+  AT_CHECK(input.numel() > 0,
+    "fractional_max_pool2d(): expected input to have non-empty ",
+    "spatial dimensions.");
+
+  AT_CHECK((ndims == 3 || ndims == 4),
+     "non-empty 3D or 4D (batch mode) tensor expected for input");
+
+  if (ndims == 4) {
+    numBatch = input.size(0);
+    planeDim++;
+    dimh++;
+    dimw++;
+  }
+
+  /* sizes */
+  int numPlanes = input.size(planeDim);
+  int inputH = input.size(dimh);
+  int inputW = input.size(dimw);
+
+  int outputH = output_size[0];
+  int outputW = output_size[1];
+  int poolSizeH = pool_size[0];
+  int poolSizeW = pool_size[1];
+
+  AT_CHECK(outputH + poolSizeH - 1 <= inputH,
+             "fractional_max_pool2d(): pool_size height ", poolSizeH,
+             " too large relative to input height ", inputH);
+  AT_CHECK(outputW + poolSizeW - 1 <= inputW,
+           "pool_size width ", poolSizeW,
+           " too large relative to input width ", inputW);
+
+  if (ndims == 3) {
+    /* resize output */
+    output.resize_({numPlanes, outputH, outputW});
+    /* indices will contain the locations for each output point */
+    indices.resize_({numPlanes, outputH, outputW});
+  } else {
+    output.resize_({numBatch, numPlanes, outputH, outputW});
+    indices.resize_({numBatch, numPlanes, outputH, outputW});
+  }
+
+  auto output_ = output;
+  auto input_ = input;
+  auto indices_ = indices;
+
+  // view 3-D tensors as a batch of one so the kernel only handles 4-D
+  if(ndims == 3) {
+    output_ = output_.reshape({1, numPlanes, outputH, outputW});
+    indices_ = indices_.reshape({1, numPlanes, outputH, outputW});
+    input_ = input_.reshape({1, input.size(0), input.size(1), input.size(2)});
+  }
+
+  // block is limited to 4 warps
+  // grid handles overflow per each plane
+  int outputPlaneSize = output_.size(2) *
+    output_.size(3);
+  dim3 grid((outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
+            input_.size(1),
+            input_.size(0));
+  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(),
+    "fractional_max_pool2d_out_cuda_frame",
+    [&] {
+      auto devInput = input_.packed_accessor<scalar_t, 4>();
+      auto devOutput = output_.packed_accessor<scalar_t, 4>();
+      auto devIndices = indices_.packed_accessor<int64_t, 4>();
+      auto devSamples = randomSamples.packed_accessor<scalar_t, 3>();
+      fractional_max_pool2d_out_cuda_frame<scalar_t>
+        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+          devOutput, devIndices, devInput, devSamples,
+          poolSizeH, poolSizeW);
+       }
+     );
+
+  // cudaGetLastError() resets the error state, so query it exactly once;
+  // calling it a second time for the message always reported success.
+  cudaError_t launchStatus = cudaGetLastError();
+  AT_CHECK(launchStatus == cudaSuccess,
+     "fractional_max_pool2d_out_cuda_frame failed with error code ",
+     launchStatus, " (", cudaGetErrorString(launchStatus), ")");
+}
+
+// Shared CUDA implementation of the backward pass.
+//
+//   gradInput   - resized like input, zeroed, then accumulated into
+//   gradOutput  - gradient w.r.t. the forward output
+//   input       - forward input; only its shape is used
+//   pool_size   - unused
+//   output_size - {outputH, outputW}; must match gradOutput
+//   indices     - flat per-plane input positions recorded by the forward
+//
+// 3-D tensors are viewed as a batch of one so a single 4-D kernel serves
+// both cases. Raises (via AT_CHECK) when gradOutput has an unexpected
+// spatial size or when the kernel launch reports an error.
+void fractional_max_pool2d_backward_out_cuda_template(
+  Tensor& gradInput,
+  const Tensor& gradOutput,
+  const Tensor& input,
+  IntList pool_size /* unused */,
+  IntList output_size,
+  const Tensor& indices)
+{
+  int dimh = 1;
+  int dimw = 2;
+
+  int ndims = input.ndimension();
+  if (ndims == 4) {
+    dimh++;
+    dimw++;
+  }
+
+  /* sizes */
+  int inputH = input.size(dimh);
+  int inputW = input.size(dimw);
+
+  int outputH = output_size[0];
+  int outputW = output_size[1];
+
+  AT_CHECK(outputH == gradOutput.size(dimh),
+           "fractional_max_pool2d(): gradOutput height unexpected");
+  AT_CHECK(outputW == gradOutput.size(dimw),
+           "fractional_max_pool2d(): gradOutput width unexpected");
+
+  /* resize */
+  gradInput.resize_as_(input);
+  gradInput.zero_();
+
+  auto gradInput_ = gradInput;
+  auto gradOutput_ = gradOutput;
+  auto indices_ = indices;
+
+  // view 3-D tensors as a batch of one so the kernel only handles 4-D
+  if(ndims == 3) {
+    gradInput_ = gradInput_.reshape({1, input.size(0), inputH, inputW});
+    gradOutput_ = gradOutput_.reshape({1, gradOutput.size(0), outputH, outputW});
+    indices_ = indices_.reshape({1, indices_.size(0), outputH, outputW});
+  }
+
+  /* backprop */
+  // block is limited to 4 warps
+  // grid handles overflow per each plane
+  int outputPlaneSize = gradOutput_.size(2) *
+    gradOutput_.size(3);
+  dim3 grid((outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
+            gradInput_.size(1),
+            gradInput_.size(0));
+  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
+
+  // Build the accessor from the reshaped tensor: for a 3-D input the
+  // original `indices` is 3-D and cannot provide a 4-D accessor.
+  auto devIndices = indices_.packed_accessor<int64_t, 4>();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.type(),
+    "fractional_max_pool2d_backward_out_cuda_frame",
+    [&] {
+      auto devGradInput = gradInput_.packed_accessor<scalar_t, 4>();
+      auto devGradOutput = gradOutput_.packed_accessor<scalar_t, 4>();
+      fractional_max_pool2d_backward_out_cuda_frame<scalar_t>
+        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+        devGradInput, devGradOutput, devIndices);
+      }
+    );
+
+  // cudaGetLastError() resets the error state, so query it exactly once;
+  // calling it a second time for the message always reported success.
+  cudaError_t launchStatus = cudaGetLastError();
+  AT_CHECK(launchStatus == cudaSuccess,
+    "fractional_max_pool2d_backward_out_cuda_frame failed with error code ",
+    launchStatus, " (", cudaGetErrorString(launchStatus), ")");
+}
+
+}// namespace
+
+// CUDA entry point of the out= variant: fills the caller-provided output
+// and indices tensors and returns references to both.
+std::tuple<Tensor&, Tensor&> fractional_max_pool2d_out_cuda(
+  at::Tensor& output,
+  at::Tensor& indices,
+  const at::Tensor& input,
+  IntList pool_size,
+  IntList output_size,
+  const at::Tensor& randomSamples)
+{
+  fractional_max_pool2d_out_cuda_template(
+    output, indices, input, pool_size, output_size, randomSamples);
+  return std::tie(output, indices);
+}
+
+// CUDA entry point of the functional variant: allocates empty result
+// tensors (indices as kLong), lets the shared template resize and fill
+// them, and returns them by value.
+std::tuple<Tensor, Tensor> fractional_max_pool2d_cuda(
+  const at::Tensor& input,
+  IntList pool_size,
+  IntList output_size,
+  const at::Tensor& randomSamples)
+{
+  const auto resultOptions = input.options();
+  Tensor output = at::empty({0}, resultOptions);
+  Tensor indices = at::empty({0}, resultOptions.dtype(kLong));
+
+  fractional_max_pool2d_out_cuda_template(
+    output, indices, input, pool_size, output_size, randomSamples);
+
+  return std::make_tuple(output, indices);
+}
+
+// CUDA entry point of the backward out= variant: writes the input gradient
+// into the caller-provided gradInput and returns it.
+Tensor& fractional_max_pool2d_backward_out_cuda(
+  at::Tensor& gradInput,
+  const at::Tensor& gradOutput_,
+  const at::Tensor& input,
+  IntList pool_size,
+  IntList output_size,
+  const at::Tensor& indices)
+{
+  // resizing, zeroing and accumulation all happen inside the template
+  fractional_max_pool2d_backward_out_cuda_template(
+    gradInput, gradOutput_, input, pool_size, output_size, indices);
+
+  return gradInput;
+}
+
+// CUDA entry point of the functional backward variant: allocates an empty
+// gradient tensor, lets the shared template resize and fill it, and
+// returns it by value.
+Tensor fractional_max_pool2d_backward_cuda(
+  const at::Tensor& gradOutput_,
+  const at::Tensor& input,
+  IntList pool_size,
+  IntList output_size,
+  const at::Tensor& indices)
+{
+  auto gradInput = at::empty({0}, input.options());
+
+  fractional_max_pool2d_backward_out_cuda_template(
+    gradInput, gradOutput_, input, pool_size, output_size, indices);
+
+  return gradInput;
+}
+
+}// at::native
+}// at
index 8d520cd..fd6ed67 100644 (file)
 
 - func: fractional_max_pool2d_out(Tensor output, Tensor indices, Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples) -> (Tensor output, Tensor indices)
   python_module: nn
+  dispatch:
+    CPU: fractional_max_pool2d_out_cpu
+    CUDA: fractional_max_pool2d_out_cuda
 
 - func: fractional_max_pool2d(Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples) -> (Tensor output, Tensor indices)
   python_module: nn
+  dispatch:
+    CPU: fractional_max_pool2d_cpu
+    CUDA: fractional_max_pool2d_cuda
 
 - func: fractional_max_pool2d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor indices) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: fractional_max_pool2d_backward_out_cpu
+    CUDA: fractional_max_pool2d_backward_out_cuda
 
 - func: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor indices) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: fractional_max_pool2d_backward_cpu
+    CUDA: fractional_max_pool2d_backward_cuda
 
 - func: max_pool2d_with_indices_out(Tensor output, Tensor indices, Tensor self, IntList[2] kernel_size, IntList[2] stride={}, IntList[2] padding=0, IntList[2] dilation=1, bool ceil_mode=false) -> (Tensor output, Tensor indices)
   python_module: nn
index 3f7ee96..e0424a9 100644 (file)
     output: 'false'
     grad_input: 'false'
 
-- name: _thnn_fractional_max_pool2d(Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples)
-  cname: SpatialFractionalMaxPooling
-  scalar_check:
-    grad_input: 'false'
-  scalar_check:
-    output: 'false'
-    grad_input: 'false'
-
 - name: _thnn_max_pool2d_with_indices(Tensor self, IntList[2] kernel_size, IntList[2] stride={}, IntList[2] padding=0, IntList[2] dilation=1, bool ceil_mode=false)
   cname: SpatialDilatedMaxPooling
   default_init:
index d7f35a7..39b1ec1 100644 (file)
@@ -36,7 +36,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/SpatialCrossMapLRN.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialDepthwiseConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialDilatedConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialDilatedMaxPooling.cu
-${CMAKE_CURRENT_SOURCE_DIR}/SpatialFractionalMaxPooling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullDilatedConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxPooling.cu
diff --git a/aten/src/THCUNN/SpatialFractionalMaxPooling.cu b/aten/src/THCUNN/SpatialFractionalMaxPooling.cu
deleted file mode 100644 (file)
index f2ffe30..0000000
+++ /dev/null
@@ -1,113 +0,0 @@
-#include <THCUNN/THCUNN.h>
-#include <THCUNN/common.h>
-#include <THC/THCDeviceTensor.cuh>
-#include <THC/THCDeviceTensorUtils.cuh>
-#include <THC/THCDeviceUtils.cuh>
-#include <TH/THHalf.h>
-#include <THCUNN/THCHalfAutoNumerics.cuh>
-#include <THC/THCAtomics.cuh>
-
-#include <cfloat>
-
-template <typename Dtype, typename Acctype>
-__device__ inline int getInterval(Acctype sample,
-                                  int index,
-                                  int inputSize,
-                                  int outputSize,
-                                  int poolSize) {
-  Acctype alpha = (Acctype)(inputSize - poolSize) / (Acctype) (outputSize - 1);
-  if (index == outputSize - 1) {
-    return inputSize - poolSize;
-  } else {
-    return (int) ((index + sample) * alpha) - (int) (sample * alpha);
-  }
-}
-
-// We template on poolSizeW to allow the innermost loop to be unrolled
-template <int PoolSizeWStatic, typename Dtype, typename Acctype>
-__global__ void SpatialFractionalMaxPooling_updateOutput(
-  THCDeviceTensor<Dtype, 4> input,
-  THCDeviceTensor<Dtype, 4> output,
-  THCDeviceTensor<THCIndex_t, 4> indices,
-  THCDeviceTensor<Dtype, 3> samples,
-  int poolSizeW, int poolSizeH) {
-
-  // Output (h, w) point that this thread is responsible for
-  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-
-  // Each thread generates a specific output point
-  if (ourOutputPoint < output.getSize(2) * output.getSize(3)) {
-    int outputW = ourOutputPoint % output.getSize(3);
-    int outputH = ourOutputPoint / output.getSize(3);
-
-    int poolW = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][0]), outputW,
-                            input.getSize(3), output.getSize(3), poolSizeW);
-    int poolH = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][1]), outputH,
-                            input.getSize(2), output.getSize(2), poolSizeH);
-
-    Dtype maxVal = THCNumerics<Dtype>::min();
-    int maxIndex = -1;
-
-    for (int h = poolH; h < poolH + poolSizeH; ++h) {
-      if (PoolSizeWStatic == -1) {
-        for (int w = poolW; w < poolW + poolSizeW; ++w) {
-          Dtype val = input[batch][plane][h][w];
-          // for consistency with THNN, favor the first max
-          if (val > maxVal) {
-            maxIndex = h * input.getSize(3) + w;
-            maxVal = val;
-          }
-        }
-      } else {
-#pragma unroll
-        for (int i = 0; i < PoolSizeWStatic; ++i) {
-          int w = i + poolW;
-          Dtype val = input[batch][plane][h][w];
-          // for consistency with THNN, favor the first max
-          if (val > maxVal) {
-            maxIndex = h * input.getSize(3) + w;
-            maxVal = val;
-          }
-        }
-      }
-    }
-
-    assert(THCNumerics<Dtype>::ne(maxVal, THCNumerics<Dtype>::min()));
-    assert(maxIndex != -1);
-
-    // +1 for Lua index
-    indices[batch][plane][outputH][outputW] = maxIndex + TH_INDEX_BASE;
-    output[batch][plane][outputH][outputW] = maxVal;
-  }
-}
-
-template <typename Dtype>
-__global__ void SpatialFractionalMaxPooling_updateGradInput(
-  THCDeviceTensor<Dtype, 4> gradInput,
-  THCDeviceTensor<Dtype, 4> gradOutput,
-  THCDeviceTensor<THCIndex_t, 4> indices) {
-  // Output (h, w) point that this thread is responsible for
-  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
-
-  // Each thread generates a specific output point
-  if (ourOutputPoint < gradOutput.getSize(2) * gradOutput.getSize(3)) {
-    int outputW = ourOutputPoint % gradOutput.getSize(3);
-    int outputH = ourOutputPoint / gradOutput.getSize(3);
-
-    int index = indices[batch][plane][outputH][outputW] - TH_INDEX_BASE;
-    assert(index >= 0);
-    int inputW = index % gradInput.getSize(3);
-    int inputH = index / gradInput.getSize(3);
-    assert(inputH < gradInput.getSize(2));
-
-    atomicAdd(gradInput[batch][plane][inputH][inputW].data(),
-              gradOutput[batch][plane][outputH][outputW]);
-  }
-}
-
-#include <THCUNN/generic/SpatialFractionalMaxPooling.cu>
-#include <THC/THCGenerateFloatTypes.h>
diff --git a/aten/src/THCUNN/generic/SpatialFractionalMaxPooling.cu b/aten/src/THCUNN/generic/SpatialFractionalMaxPooling.cu
deleted file mode 100644 (file)
index d12bcda..0000000
+++ /dev/null
@@ -1,157 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/SpatialFractionalMaxPooling.cu"
-#else
-
-void THNN_(SpatialFractionalMaxPooling_updateOutput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *output,
-           int outputW, int outputH,
-           int poolSizeW, int poolSizeH,
-           THCIndexTensor *indices,
-           THCTensor *randomSamples)
-{
-  int planeDim = 0;
-  int dimh = 1;
-  int dimw = 2;
-  int64_t numBatch = 1;
-
-  int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-  THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
-                  "non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s");
-
-  if (numInputDims == 4) {
-    numBatch = THCTensor_(size)(state, input, 0);
-    planeDim++;
-    dimh++;
-    dimw++;
-  }
-
-  /* sizes */
-  int64_t numPlanes = THCTensor_(size)(state, input, planeDim);
-  int64_t inputH = THCTensor_(size)(state, input, dimh);
-  int64_t inputW = THCTensor_(size)(state, input, dimw);
-
-  THArgCheck(outputH + poolSizeH - 1 <= inputH, 6,
-             "poolSizeH (%d) too large relative to input height (%d)",
-             poolSizeH, inputH);
-  THArgCheck(outputW + poolSizeW - 1 <= inputW, 5,
-             "poolSizeW (%d) too large relative to input width (%d)",
-             poolSizeW, inputW);
-
-  THCDeviceTensor<scalar_t, 4> devInput;
-  THCDeviceTensor<scalar_t, 4> devOutput;
-  THCDeviceTensor<THCIndex_t, 4> devIndices;
-  THCDeviceTensor<scalar_t, 3> devSamples =
-    toDeviceTensor<scalar_t, 3>(state, randomSamples);
-
-  if (numInputDims == 3) {
-    /* resize output */
-    THCTensor_(resize3d)(state, output, numPlanes, outputH, outputW);
-    /* indices will contain the locations for each output point */
-    THCIndexTensor_(resize3d)(state, indices, numPlanes, outputH, outputW);
-
-    devInput = toDeviceTensor<scalar_t, 3>(state, input).upcastOuter<4>();
-    devOutput = toDeviceTensor<scalar_t, 3>(state, output).upcastOuter<4>();
-    devIndices = toDeviceTensor<THCIndex_t, 3>(state, indices).upcastOuter<4>();
-  } else {
-    THCTensor_(resize4d)(state, output, numBatch, numPlanes, outputH, outputW);
-    /* indices will contain the locations for each output point */
-    THCIndexTensor_(resize4d)(state, indices, numBatch, numPlanes, outputH, outputW);
-
-    devInput = toDeviceTensor<scalar_t, 4>(state, input);
-    devOutput = toDeviceTensor<scalar_t, 4>(state, output);
-    devIndices = toDeviceTensor<THCIndex_t, 4>(state, indices);
-  }
-
-  // block is limited to 4 warps
-  // grid handles overflow per each plane
-  int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3);
-  dim3 grid(THCCeilDiv(outputPlaneSize, 128),
-            devInput.getSize(1),
-            devInput.getSize(0));
-  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
-
-#define SFMP_UPDATE_OUTPUT(POOL_W)                                      \
-  SpatialFractionalMaxPooling_updateOutput<POOL_W, scalar_t, accreal>       \
-    <<<grid, block, 0, THCState_getCurrentStream(state)>>>(             \
-      devInput, devOutput, devIndices, devSamples, poolSizeW, poolSizeH);
-
-#define SFMP_UPDATE_OUTPUT_CASE(POOL_W)                 \
-  case POOL_W: SFMP_UPDATE_OUTPUT(POOL_W); break
-
-  switch (poolSizeW) {
-    SFMP_UPDATE_OUTPUT_CASE(2);
-    SFMP_UPDATE_OUTPUT_CASE(3);
-    SFMP_UPDATE_OUTPUT_CASE(4);
-    SFMP_UPDATE_OUTPUT_CASE(5);
-    SFMP_UPDATE_OUTPUT_CASE(6);
-    SFMP_UPDATE_OUTPUT_CASE(7);
-    default:
-      // dynamic pool width
-      SFMP_UPDATE_OUTPUT_CASE(-1);
-  }
-  THCudaCheck(cudaGetLastError());
-}
-
-void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *gradOutput,
-           THCTensor *gradInput,
-           int outputW, int outputH,
-           int poolSizeW, int poolSizeH,
-           THCIndexTensor *indices)
-{
-  int dimh = 1;
-  int dimw = 2;
-
-  int64_t numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
-  if (numInputDims == 4) {
-    dimh++;
-    dimw++;
-  }
-
-  /* sizes */
-  int64_t inputH = THCTensor_(size)(state, input, dimh);
-  int64_t inputW = THCTensor_(size)(state, input, dimw);
-
-  THArgCheck(outputH == THCTensor_(size)(state, gradOutput, dimh), 3,
-                "gradOutput height unexpected");
-  THArgCheck(outputW == THCTensor_(size)(state, gradOutput, dimw), 3,
-                "gradOutput width unexpected");
-
-  /* resize */
-  THCTensor_(resizeAs)(state, gradInput, input);
-  THCTensor_(zero)(state, gradInput);
-
-  THCDeviceTensor<scalar_t, 4> devGradInput;
-  THCDeviceTensor<scalar_t, 4> devGradOutput;
-  THCDeviceTensor<THCIndex_t, 4> devIndices;
-
-  /* backprop */
-  if (numInputDims == 3) {
-    devGradInput = toDeviceTensor<scalar_t, 3>(state, gradInput).upcastOuter<4>();
-    devGradOutput = toDeviceTensor<scalar_t, 3>(state, gradOutput).upcastOuter<4>();
-    devIndices = toDeviceTensor<THCIndex_t, 3>(state, indices).upcastOuter<4>();
-  } else {
-    devGradInput = toDeviceTensor<scalar_t, 4>(state, gradInput);
-    devGradOutput = toDeviceTensor<scalar_t, 4>(state, gradOutput);
-    devIndices = toDeviceTensor<THCIndex_t, 4>(state, indices);
-  }
-
-  // block is limited to 4 warps
-  // grid handles overflow per each plane
-  int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3);
-  dim3 grid(THCCeilDiv(outputPlaneSize, 128),
-            devGradInput.getSize(1),
-            devGradInput.getSize(0));
-  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
-
-  SpatialFractionalMaxPooling_updateGradInput
-    <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-      devGradInput, devGradOutput, devIndices);
-  THCudaCheck(cudaGetLastError());
-}
-
-#endif
index 08fefca..ecbd9d9 100644 (file)
@@ -754,24 +754,6 @@ THC_API void THNN_(SpatialDilatedMaxPooling_updateGradInput)(
                   int dilationW, int dilationH,
                   bool ceil_mode);
 
-THC_API void THNN_(SpatialFractionalMaxPooling_updateOutput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *output,
-                  int outputW, int outputH,
-                  int poolSizeW, int poolSizeH,
-                  THCIndexTensor *indices,
-                  THCTensor *randomSamples);
-
-THC_API void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *gradOutput,
-                  THCTensor *gradInput,
-                  int outputW, int outputH,
-                  int poolSizeW, int poolSizeH,
-                  THCIndexTensor *indices);
-
 THC_API void THNN_(SpatialFullConvolution_updateOutput)(
                   THCState *state,
                   THCTensor *input,
diff --git a/aten/src/THNN/generic/SpatialFractionalMaxPooling.c b/aten/src/THNN/generic/SpatialFractionalMaxPooling.c
deleted file mode 100644 (file)
index e7ea059..0000000
+++ /dev/null
@@ -1,253 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/SpatialFractionalMaxPooling.c"
-#else
-
-static int64_t* THNN_(SpatialFractionalMaxPooling_generateIntervals)(
-  scalar_t sample,
-  int64_t inputSize,
-  int64_t outputSize,
-  int poolSize) {
-  scalar_t alpha = (scalar_t) (inputSize - poolSize) / (scalar_t) (outputSize - 1);
-  int64_t* sequence = (int64_t*) THAlloc(sizeof(int64_t) * outputSize);
-
-  int64_t i;
-  for (i = 0; i < outputSize - 1; ++i) {
-    sequence[i] =
-      (int64_t) ((i + sample) * alpha) - (int64_t) (sample * alpha);
-  }
-  sequence[outputSize - 1] = inputSize - poolSize;
-
-  return sequence;
-}
-
-static void THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
-  scalar_t* input,
-  scalar_t* output,
-  THIndex_t* indices,
-  scalar_t* randomSamples,
-  int64_t numPlanes,
-  int64_t inputW, int64_t inputH,
-  int64_t outputW, int64_t outputH,
-  int poolSizeW, int poolSizeH) {
-  int64_t plane;
-#pragma omp parallel for private(plane)
-  for (plane = 0; plane < numPlanes; ++plane) {
-    /* each plane contains 2 random samples, one for W and one for H */
-    scalar_t* randomSamplesForPlane = randomSamples + plane * 2;
-
-    /* Generate interval sequence */
-    int64_t* sequenceW =
-      THNN_(SpatialFractionalMaxPooling_generateIntervals)(
-        randomSamplesForPlane[0], inputW, outputW, poolSizeW);
-    int64_t* sequenceH =
-      THNN_(SpatialFractionalMaxPooling_generateIntervals)(
-        randomSamplesForPlane[1], inputH, outputH, poolSizeH);
-
-    /* loop over output */
-    int64_t h, w;
-
-    scalar_t* inputForPlane = input + plane * inputW * inputH;
-    scalar_t* outputForPlane = output + plane * outputW * outputH;
-    THIndex_t* indicesForPlane = indices + plane * outputW * outputH;
-
-    for (h = 0; h < outputH; ++h) {
-      int64_t inputHStart = sequenceH[h];
-
-      for (w = 0; w < outputW; ++w) {
-        int64_t inputWStart = sequenceW[w];
-
-        scalar_t maxVal = -THInf;
-        int64_t maxIndex = -1;
-
-        int64_t h2, w2;
-        for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
-          for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
-            THAssert(h2 >= 0 && h2 < inputH);
-            THAssert(w2 >= 0 && w2 < inputW);
-
-            int64_t planeIndex = h2 * inputW + w2;
-            scalar_t val = inputForPlane[planeIndex];
-            if (val > maxVal) {
-              maxVal = val;
-              maxIndex = planeIndex;
-            }
-          }
-        }
-
-        THAssert(maxVal != -THInf);
-        THAssert(maxIndex != -1);
-
-        outputForPlane[h * outputW + w] = maxVal;
-        /* +1 to lua index */
-        indicesForPlane[h * outputW + w] = maxIndex + TH_INDEX_BASE;
-      }
-    }
-
-    THFree(sequenceW);
-    THFree(sequenceH);
-  }
-}
-
-void THNN_(SpatialFractionalMaxPooling_updateOutput)(
-    THNNState *state,
-    THTensor *input,
-    THTensor *output,
-    int outputW, int outputH,
-    int poolSizeW, int poolSizeH,
-    THIndexTensor *indices,
-    THTensor *randomSamples) {
-
-  int64_t numBatch = 1;
-  int planeDim = 0;
-  int heightDim = 1;
-  int widthDim = 2;
-
-  int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input);
-  THNN_ARGCHECK(!input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
-               "non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s");
-
-  if (numInputDims == 4) {
-    numBatch = THTensor_(size)(input, 0);
-    planeDim++;
-    heightDim++;
-    widthDim++;
-  }
-
-  /* sizes */
-  int64_t numPlanes = THTensor_(size)(input, planeDim);
-  int64_t inputH = THTensor_(size)(input, heightDim);
-  int64_t inputW = THTensor_(size)(input, widthDim);
-
-  THArgCheck(outputH + poolSizeH - 1 <= inputH, 7,
-             "poolSizeH (%d) too large relative to input height (%d)",
-            poolSizeH, inputH);
-  THArgCheck(outputW + poolSizeW - 1 <= inputW, 6,
-             "poolSizeW (%d) too large relative to input width (%d)",
-            poolSizeW, inputW);
-
-  /* get contiguous input */
-  input = THTensor_(newContiguous)(input);
-
-  if (numInputDims == 3) {
-    /* resize output */
-    THTensor_(resize3d)(output, numPlanes, outputH, outputW);
-    /* indices will contain the locations for each output point */
-    THIndexTensor_(resize3d)(indices, numPlanes, outputH, outputW);
-
-    THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
-      input->data<scalar_t>(),
-      output->data<scalar_t>(),
-      THIndexTensor_(data)(indices),
-      randomSamples->data<scalar_t>(),
-      numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
-  } else {
-    THTensor_(resize4d)(output, numBatch, numPlanes, outputH, outputW);
-    /* indices will contain the locations for each output point */
-    THIndexTensor_(resize4d)(indices, numBatch, numPlanes, outputH, outputW);
-
-    int64_t batch;
-#pragma omp parallel for private(batch)
-    for (batch = 0; batch < numBatch; ++batch) {
-      THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
-        input->data<scalar_t>() + batch * numPlanes * inputH * inputW,
-        output->data<scalar_t>() + batch * numPlanes * outputH * outputW,
-        THIndexTensor_(data)(indices) + batch * numPlanes * outputH * outputW,
-        randomSamples->data<scalar_t>() + batch * numPlanes * 2,
-        numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
-    }
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(input);
-}
-
-static void THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
-  scalar_t* gradInput,
-  scalar_t* gradOutput,
-  THIndex_t* indices,
-  int64_t numPlanes,
-  int64_t inputW, int64_t inputH,
-  int64_t outputW, int64_t outputH) {
-  int64_t plane;
-#pragma omp parallel for private(plane)
-  for (plane = 0; plane < numPlanes; plane++) {
-    scalar_t* gradInputForPlane = gradInput + plane * inputW * inputH;
-    scalar_t* gradOutputForPlane = gradOutput + plane * outputW * outputH;
-    THIndex_t* indicesForPlane = indices + plane * outputW * outputH;
-
-    int64_t h, w;
-    for (h = 0; h < outputH; ++h) {
-      for (w = 0; w < outputW; ++w) {
-        int64_t outputIndex = h * outputW + w;
-        int64_t index = indicesForPlane[outputIndex] - TH_INDEX_BASE;
-        THAssert(index >= 0 && index < inputW * inputH);
-
-        gradInputForPlane[index] += gradOutputForPlane[outputIndex];
-      }
-    }
-  }
-}
-
-void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
-    THNNState *state,
-    THTensor *input,
-    THTensor *gradOutput,
-    THTensor *gradInput,
-    int outputW, int outputH,
-    int poolSizeW, int poolSizeH,
-    THIndexTensor *indices) {
-
-  int64_t numBatch = 1;
-  int planeDim = 0;
-  int heightDim = 1;
-  int widthDim = 2;
-
-  int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input);
-  if (numInputDims == 4) {
-    numBatch = THTensor_(size)(input, 0);
-    planeDim = 1;
-    heightDim++;
-    widthDim++;
-  }
-
-  /* sizes */
-  int64_t numPlanes = THTensor_(size)(input, planeDim);
-  int64_t inputH = THTensor_(size)(input, heightDim);
-  int64_t inputW = THTensor_(size)(input, widthDim);
-
-  THArgCheck(outputW == THTensor_(size)(gradOutput, widthDim), 3,
-             "gradOutput width unexpected");
-  THArgCheck(outputH == THTensor_(size)(gradOutput, heightDim), 3,
-             "gradOutput height unexpected");
-
-  /* get contiguous gradOutput */
-  gradOutput = THTensor_(newContiguous)(gradOutput);
-
-  /* resize */
-  THTensor_(resizeAs)(gradInput, input);
-  THTensor_(zero)(gradInput);
-
-  /* backprop */
-  if (numInputDims == 3) {
-    THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
-      gradInput->data<scalar_t>(),
-      gradOutput->data<scalar_t>(),
-      THIndexTensor_(data)(indices),
-      numPlanes, inputW, inputH, outputW, outputH);
-  } else {
-    int64_t batch;
-#pragma omp parallel for private(batch)
-    for (batch = 0; batch < numBatch; ++batch) {
-      THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
-        gradInput->data<scalar_t>() + batch * numPlanes * inputH * inputW,
-        gradOutput->data<scalar_t>() + batch * numPlanes * outputH * outputW,
-        THIndexTensor_(data)(indices) + batch * numPlanes * outputH * outputW,
-        numPlanes, inputW, inputH, outputW, outputH);
-    }
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(gradOutput);
-}
-
-#endif
index f98077c..4ea94ad 100644 (file)
@@ -547,23 +547,6 @@ TH_API void THNN_(SpatialAveragePooling_updateGradInput)(
           bool ceil_mode,
           bool count_include_pad);
 
-TH_API void THNN_(SpatialFractionalMaxPooling_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          int outputW, int outputH,
-          int kW, int kH,
-          THIndexTensor *indices,
-          THTensor *randomSamples);
-TH_API void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          int outputW, int outputH,
-          int kW, int kH,
-          THIndexTensor *indices);
-
 TH_API void THNN_(SpatialDilatedConvolution_updateOutput)(
           THNNState *state,
           THTensor *input,
index 9120420..45b073a 100644 (file)
 #include <THNN/generic/SpatialAveragePooling.c>
 #include <TH/THGenerateFloatTypes.h>
 
-#include <THNN/generic/SpatialFractionalMaxPooling.c>
-#include <TH/THGenerateFloatTypes.h>
-
 #include <THNN/generic/SpatialDilatedMaxPooling.c>
 #include <TH/THGenerateFloatTypes.h>
 
index 7373465..7372468 100644 (file)
@@ -832,6 +832,22 @@ def multimarginloss_weights_no_reduce_test():
         pickle=False)
 
 
+def fractional_max_pool2d_test(test_case):
+    random_samples = torch.DoubleTensor(1, 3, 2).uniform_()
+    if test_case == 'ratio':
+        return dict(
+            constructor=lambda: nn.FractionalMaxPool2d(
+                2, output_ratio=0.5, _random_samples=random_samples),
+            input_size=(1, 3, 5, 7),
+            fullname='FractionalMaxPool2d_ratio')
+    elif test_case == 'size':
+        return dict(
+            constructor=lambda: nn.FractionalMaxPool2d((2, 3), output_size=(
+                4, 3), _random_samples=random_samples),
+            input_size=(1, 3, 7, 6),
+            fullname='FractionalMaxPool2d_size')
+
+
 new_module_tests = [
     poissonnllloss_no_reduce_test(),
     bceloss_no_reduce_test(),
@@ -874,6 +890,8 @@ new_module_tests = [
     multimarginloss_p_no_reduce_test(),
     multimarginloss_margin_no_reduce_test(),
     multimarginloss_weights_no_reduce_test(),
+    fractional_max_pool2d_test('ratio'),
+    fractional_max_pool2d_test('size'),
     dict(
         module_name='BatchNorm1d',
         constructor_args=(10,),
@@ -1652,19 +1670,6 @@ new_module_tests = [
         test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
-        constructor=lambda: nn.FractionalMaxPool2d(
-            2, output_ratio=0.5, _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
-        input_size=(1, 3, 5, 5),
-        fullname='FractionalMaxPool2d_ratio',
-    ),
-    dict(
-        constructor=lambda: nn.FractionalMaxPool2d((2, 2), output_size=(
-            4, 4), _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
-        input_size=(1, 3, 7, 7),
-        fullname='FractionalMaxPool2d_size',
-        test_cuda=False,
-    ),
-    dict(
         module_name='PixelShuffle',
         constructor_args=(3,),
         input_size=(1, 9, 4, 4),
@@ -3064,7 +3069,7 @@ class ModuleTest(TestBase):
             test_case.assertEqual(cpu_output, gpu_output, self.precision)
 
             # Run backwards on CPU and GPU and compare results
-            for i in range(5):
+            for _ in range(5):
                 cpu_gradOutput = cpu_output.clone().normal_()
                 gpu_gradOutput = cpu_gradOutput.type('torch.cuda.FloatTensor')
                 cpu_gradInput = test_case._backward(cpu_module, cpu_input, cpu_output, cpu_gradOutput)