--- /dev/null
+#include "ATen/ATen.h"
+#include "ATen/NativeFunctions.h"
+
+#include <limits>
+#include <tuple>
+#include <vector>
+
+namespace at {
+namespace native {
+namespace {
+
+template <typename scalar_t>
+static std::vector<int> fractional_max_pool2d_generate_intervals(
+ scalar_t sample,
+ int inputSize,
+ int outputSize,
+ int poolSize) {
+ scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
+ static_cast<scalar_t>(outputSize - 1);
+ std::vector<int> sequence(outputSize);
+
+ for (int i = 0; i < outputSize - 1; ++i) {
+ sequence[i] =
+ static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
+ }
+ sequence[outputSize - 1] = inputSize - poolSize;
+
+ return sequence;
+}
+
+template <typename scalar_t>
+static void fractional_max_pool2d_out_single_batch_frame(
+ scalar_t* input,
+ scalar_t* output,
+ int64_t* indices,
+ scalar_t* randomSamples,
+ int numPlanes,
+ int inputW, int inputH,
+ int outputW, int outputH,
+ int poolSizeW, int poolSizeH) {
+ int plane;
+#pragma omp parallel for private(plane)
+ for (plane = 0; plane < numPlanes; ++plane) {
+ /* each plane contains 2 random samples, one for W and one for H */
+ scalar_t* randomSamplesForPlane = randomSamples + plane * 2;
+
+ /* Generate interval sequence */
+ auto sequenceW = fractional_max_pool2d_generate_intervals<scalar_t>(
+ randomSamplesForPlane[0], inputW, outputW, poolSizeW);
+ auto sequenceH = fractional_max_pool2d_generate_intervals<scalar_t>(
+ randomSamplesForPlane[1], inputH, outputH, poolSizeH);
+
+ /* loop over output */
+ int h, w;
+
+ scalar_t* inputForPlane = input + plane * inputW * inputH;
+ scalar_t* outputForPlane = output + plane * outputW * outputH;
+ int64_t* indicesForPlane = indices + plane * outputW * outputH;
+
+ for (h = 0; h < outputH; ++h) {
+ int inputHStart = sequenceH[h];
+
+ for (w = 0; w < outputW; ++w) {
+ int inputWStart = sequenceW[w];
+
+ scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
+ int64_t maxIndex = -1;
+
+ int h2, w2;
+ for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
+ for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
+ AT_ASSERT(h2 >= 0 && h2 < inputH);
+ AT_ASSERT(w2 >= 0 && w2 < inputW);
+
+ int planeIndex = h2 * inputW + w2;
+ scalar_t val = inputForPlane[planeIndex];
+ if (val > maxVal) {
+ maxVal = val;
+ maxIndex = planeIndex;
+ }
+ }
+ }
+
+ AT_ASSERT(maxVal != -std::numeric_limits<scalar_t>::infinity());
+ AT_ASSERT(maxIndex != -1);
+
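+        /* the stored index is the argmax position flattened as h2 * inputW + w2
+           within the plane; the backward pass scatters gradients with it */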
+ outputForPlane[h * outputW + w] = maxVal;
+ indicesForPlane[h * outputW + w] = maxIndex;
+ }
+ }
+
+ }
+}
+
+template <typename scalar_t>
+static void fractional_max_pool2d_out_frame(
+ scalar_t* input,
+ scalar_t* output,
+ int64_t* indices,
+ scalar_t* randomSamples,
+ int numBatch, int numPlanes,
+ int inputW, int inputH,
+ int outputW, int outputH,
+ int poolSizeW, int poolSizeH) {
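+  /* for a single batch, parallelize over planes inside the frame kernel;
+     for multiple batches, parallelize over the batch dimension instead */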
+ if(numBatch == 1) {
+ fractional_max_pool2d_out_single_batch_frame<scalar_t>(
+ input,
+ output,
+ indices,
+ randomSamples,
+ numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH
+ );
+ return;
+ }
+ int batch;
+#pragma omp parallel for private(batch)
+  for (batch = 0; batch < numBatch; ++batch) {
+ fractional_max_pool2d_out_single_batch_frame<scalar_t>(
+ input + batch * numPlanes * inputH * inputW,
+ output + batch * numPlanes * outputH * outputW,
+ indices + batch * numPlanes * outputH * outputW,
+ randomSamples + batch * numPlanes * 2,
+ numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
+ }
+ }
+
+void fractional_max_pool2d_out_cpu_template(
+ const at::Tensor& input_,
+ at::Tensor& output,
+ IntList output_size,
+ IntList pool_size,
+ at::Tensor& indices,
+ const at::Tensor& randomSamples) {
+
+ int numBatch = 1;
+ int planeDim = 0;
+ int heightDim = 1;
+ int widthDim = 2;
+ int outputH = output_size[0];
+ int outputW = output_size[1];
+ int poolSizeH = pool_size[0];
+ int poolSizeW = pool_size[1];
+
+ /* get contiguous input */
+ auto input = input_.contiguous();
+
+ int ndims = input.ndimension();
+ AT_CHECK(input.numel() > 0 && (ndims == 3 || ndims == 4),
+ "non-empty 3D or 4D (batch mode) tensor expected for input, but got: ",
+ ndims);
+
+ if (ndims == 4) {
+ numBatch = input.size(0);
+ planeDim++;
+ heightDim++;
+ widthDim++;
+ }
+
+ /* sizes */
+ int numPlanes = input.size(planeDim);
+ int inputH = input.size(heightDim);
+ int inputW = input.size(widthDim);
+
+ AT_CHECK(outputH + poolSizeH - 1 <= inputH,
+ "fractional_max_pool2d(): pool height ", poolSizeH,
+ " too large relative to input height ", inputH);
+ AT_CHECK(outputW + poolSizeW - 1 <= inputW,
+ "fractional_max_pool2d(): pool width ", poolSizeW,
+ " too large relative to input width ", inputW);
+
+ if (ndims == 3) {
+ /* resize output */
+ output.resize_({numPlanes, outputH, outputW});
+ /* indices will contain the locations for each output point */
+ indices.resize_({numPlanes, outputH, outputW});
+ } else {
+ output.resize_({numBatch, numPlanes, outputH, outputW});
+ /* indices will contain the locations for each output point */
+ indices.resize_({numBatch, numPlanes, outputH, outputW});
+ }
+
+ AT_DISPATCH_FLOATING_TYPES(input.type(),
+ "fractional_max_pool2d_out_frame", [&] {
+ auto input_data = input.data<scalar_t>();
+ auto output_data = output.data<scalar_t>();
+ auto indices_data = indices.data<int64_t>();
+ auto randomSamples_data = randomSamples.data<scalar_t>();
+ fractional_max_pool2d_out_frame<scalar_t>(
+ input_data,
+ output_data,
+ indices_data,
+ randomSamples_data,
+ numBatch, numPlanes,
+ inputW, inputH,
+ outputW, outputH,
+ poolSizeW, poolSizeH);
+ }
+ );
+}
+
+template <typename scalar_t>
+static void fractional_max_pool2d_backward_out_single_batch_frame(
+ scalar_t* gradInput,
+ scalar_t* gradOutput,
+ int64_t* indices,
+ int numPlanes,
+ int inputW, int inputH,
+ int outputW, int outputH) {
+ int plane;
+#pragma omp parallel for private(plane)
+ for (plane = 0; plane < numPlanes; plane++) {
+ scalar_t* gradInputForPlane = gradInput + plane * inputW * inputH;
+ scalar_t* gradOutputForPlane = gradOutput + plane * outputW * outputH;
+ int64_t* indicesForPlane = indices + plane * outputW * outputH;
+
+ int h, w;
+ for (h = 0; h < outputH; ++h) {
+ for (w = 0; w < outputW; ++w) {
+ int outputIndex = h * outputW + w;
+ int64_t index = indicesForPlane[outputIndex];
+ AT_ASSERT(index >= 0 && index < inputW * inputH);
+
+ gradInputForPlane[index] += gradOutputForPlane[outputIndex];
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+static void fractional_max_pool2d_backward_out_frame(
+ scalar_t* gradInput,
+ scalar_t* gradOutput,
+ int64_t* indices,
+ int numBatch, int numPlanes,
+ int inputW, int inputH,
+ int outputW, int outputH) {
+ if(numBatch == 1) {
+ fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
+ gradInput, gradOutput, indices,
+ numPlanes,
+ inputW, inputH, outputW, outputH
+ );
+ return;
+ }
+ int batch;
+#pragma omp parallel for private(batch)
+  for (batch = 0; batch < numBatch; ++batch) {
+ fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
+ gradInput + batch * numPlanes * inputH * inputW,
+ gradOutput + batch * numPlanes * outputH * outputW,
+ indices + batch * numPlanes * outputH * outputW,
+ numPlanes, inputW, inputH, outputW, outputH);
+ }
+}
+
+Tensor& fractional_max_pool2d_backward_out_cpu_template(
+ const at::Tensor& input,
+ const at::Tensor& gradOutput_,
+ at::Tensor& gradInput,
+ IntList output_size,
+ IntList pool_size /* unused */,
+ const at::Tensor& indices) {
+
+ int numBatch = 1;
+ int planeDim = 0;
+ int heightDim = 1;
+ int widthDim = 2;
+
+ int outputH = output_size[0];
+ int outputW = output_size[1];
+
+ int ndims = input.ndimension();
+ if (ndims == 4) {
+ numBatch = input.size(0);
+ planeDim = 1;
+ heightDim++;
+ widthDim++;
+ }
+
+ /* sizes */
+ int numPlanes = input.size(planeDim);
+ int inputH = input.size(heightDim);
+ int inputW = input.size(widthDim);
+
+ /* get contiguous gradOutput */
+ auto gradOutput = gradOutput_.contiguous();
+
+ AT_CHECK(outputW == gradOutput.size(widthDim),
+ "fractional_max_pool2d_backward(): gradOutput width unexpected");
+ AT_CHECK(outputH == gradOutput.size(heightDim),
+ "fractional_max_pool2d_backward(): gradOutput height unexpected");
+
+ /* resize */
+ gradInput.resize_as_(input);
+ gradInput.zero_();
+
+ /* backprop */
+ AT_DISPATCH_FLOATING_TYPES(
+ input.type(), "fractional_max_pool2d_backward_out_frame", [&] {
+ auto gradInput_data = gradInput.data<scalar_t>();
+ auto gradOutput_data = gradOutput.data<scalar_t>();
+ auto indices_data = indices.data<int64_t>();
+ fractional_max_pool2d_backward_out_frame<scalar_t>(
+ gradInput_data,
+ gradOutput_data,
+ indices_data,
+ numBatch, numPlanes,
+ inputW, inputH,
+ outputW, outputH
+ );
+ }
+ );
+ return gradInput;
+}
+
+} // namespace
+
+std::tuple<Tensor&, Tensor&> fractional_max_pool2d_out_cpu(
+ at::Tensor& output,
+ at::Tensor& indices,
+ const at::Tensor& input,
+ IntList pool_size,
+ IntList output_size,
+ const at::Tensor& randomSamples)
+{
+ fractional_max_pool2d_out_cpu_template(
+ input,
+ output,
+ output_size,
+ pool_size,
+ indices,
+ randomSamples);
+ return std::tuple<Tensor&, Tensor&>(output, indices);
+}
+
+std::tuple<Tensor, Tensor> fractional_max_pool2d_cpu(
+ const at::Tensor& input,
+ IntList pool_size,
+ IntList output_size,
+ const at::Tensor& randomSamples)
+{
+ Tensor output = at::empty({0}, input.options());
+ Tensor indices = at::empty({0}, input.options().dtype(kLong));
+ fractional_max_pool2d_out_cpu_template(
+ input,
+ output,
+ output_size,
+ pool_size,
+ indices,
+ randomSamples);
+ return std::tuple<Tensor, Tensor>(output, indices);
+}
+
+Tensor& fractional_max_pool2d_backward_out_cpu(
+ at::Tensor& gradInput,
+ const at::Tensor& gradOutput_,
+ const at::Tensor& input,
+ IntList pool_size,
+ IntList output_size,
+ const at::Tensor& indices)
+{
+ gradInput.resize_as_(input);
+ fractional_max_pool2d_backward_out_cpu_template(
+ input,
+ gradOutput_,
+ gradInput,
+ output_size,
+ pool_size,
+ indices);
+ return gradInput;
+}
+
+Tensor fractional_max_pool2d_backward_cpu(
+ const at::Tensor& gradOutput_,
+ const at::Tensor& input,
+ IntList pool_size,
+ IntList output_size,
+ const at::Tensor& indices)
+{
+ Tensor gradInput = at::empty({0}, input.options());
+ fractional_max_pool2d_backward_out_cpu_template(
+ input,
+ gradOutput_,
+ gradInput,
+ output_size,
+ pool_size,
+ indices);
+ return gradInput;
+}
+
+} // at::native
+} // at
return at::legacy::th::_thnn_avg_pool3d_backward(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad);
}
-std::tuple<Tensor &,Tensor &> fractional_max_pool2d_out(Tensor & output, Tensor & indices, const Tensor & self, IntList kernel_size, IntList output_size, const Tensor & random_samples) {
- return at::legacy::th::_thnn_fractional_max_pool2d_forward_out(output, indices, self, kernel_size, output_size, random_samples);
-}
-
-std::tuple<Tensor,Tensor> fractional_max_pool2d(const Tensor & self, IntList kernel_size, IntList output_size, const Tensor & random_samples) {
- return at::legacy::th::_thnn_fractional_max_pool2d_forward(self, kernel_size, output_size, random_samples);
-}
-
-Tensor & fractional_max_pool2d_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, IntList kernel_size, IntList output_size, const Tensor & indices) {
- return at::legacy::th::_thnn_fractional_max_pool2d_backward_out(grad_input, grad_output, self, kernel_size, output_size, indices);
-}
-
-Tensor fractional_max_pool2d_backward(const Tensor & grad_output, const Tensor & self, IntList kernel_size, IntList output_size, const Tensor & indices) {
- return at::legacy::th::_thnn_fractional_max_pool2d_backward(grad_output, self, kernel_size, output_size, indices);
-}
-
std::tuple<Tensor &,Tensor &> max_pool2d_with_indices_out(Tensor & output, Tensor & indices, const Tensor & self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) {
return at::legacy::th::_thnn_max_pool2d_with_indices_forward_out(output, indices, self, kernel_size, stride, padding, dilation, ceil_mode);
}
--- /dev/null
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/Utils.h>
+#include <c10/util/Exception.h>
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+
+namespace at {
+namespace native {
+
+using namespace at::cuda::detail;
+
+namespace {
+
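+// Device-side analogue of the CPU interval generator: rather than precomputing a
+// vector of window starts per plane, each thread derives the start offset for its
+// own output index from the plane's random sample.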
+template <typename scalar_t, typename accscalar_t>
+__device__ inline int get_interval(accscalar_t sample,
+ int index, int inputSize, int outputSize, int poolSize) {
+ accscalar_t alpha = static_cast<accscalar_t>(inputSize - poolSize) /
+ static_cast<accscalar_t>(outputSize - 1);
+ if (index == outputSize - 1) {
+ return inputSize - poolSize;
+ } else {
+ return static_cast<int>((index + sample) * alpha) -
+ static_cast<int>(sample * alpha);
+ }
+}
+
+template <typename scalar_t>
+__global__ void fractional_max_pool2d_out_cuda_frame(
+ PackedTensorAccessor<scalar_t, 4> output,
+ PackedTensorAccessor<int64_t, 4> indices,
+ PackedTensorAccessor<scalar_t, 4> input,
+ PackedTensorAccessor<scalar_t, 3> samples,
+ int poolSizeH, int poolSizeW) {
+
+ using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
+
+ int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
+ int plane = blockIdx.y;
+ int batch = blockIdx.z;
+
+ // Each thread generates a specific output point
+ if (ourOutputPoint < output.size(2) * output.size(3)) {
+ int outputW = ourOutputPoint % output.size(3);
+ int outputH = ourOutputPoint / output.size(3);
+
+ int poolW = get_interval<scalar_t, accscalar_t>(
+ static_cast<accscalar_t>(samples[batch][plane][0]),
+ outputW, input.size(3), output.size(3), poolSizeW);
+ int poolH = get_interval<scalar_t, accscalar_t>(
+ static_cast<accscalar_t>(samples[batch][plane][1]),
+ outputH, input.size(2), output.size(2), poolSizeH);
+
+ scalar_t maxVal = at::numeric_limits<scalar_t>::lowest();
+ int maxIndex = -1;
+
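+    // the retired THCUNN kernel templated on pool widths 2..7 to unroll the inner
+    // loop; both branches below compute the same running max over the window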
+ for (int h = poolH; h < poolH + poolSizeH; ++h) {
+ if (poolSizeW < 2 || poolSizeW > 7) {
+ for (int w = poolW; w < poolW + poolSizeW; ++w) {
+ scalar_t val = input[batch][plane][h][w];
+ // for consistency with THNN, favor the first max
+ if (val > maxVal) {
+ maxIndex = h * input.size(3) + w;
+ maxVal = val;
+ }
+ }
+ } else {
+ for (int i = 0; i < poolSizeW; ++i) {
+ int w = i + poolW;
+ scalar_t val = input[batch][plane][h][w];
+ // for consistency with THNN, favor the first max
+ if (val > maxVal) {
+ maxIndex = h * input.size(3) + w;
+ maxVal = val;
+ }
+ }
+ }
+ }
+
+ assert(maxVal != at::numeric_limits<scalar_t>::lowest());
+ assert(maxIndex != -1);
+
+ indices[batch][plane][outputH][outputW] = maxIndex;
+ output[batch][plane][outputH][outputW] = maxVal;
+ }
+}
+
+template <typename scalar_t>
+__global__ void fractional_max_pool2d_backward_out_cuda_frame(
+ PackedTensorAccessor<scalar_t, 4> gradInput,
+ PackedTensorAccessor<scalar_t, 4> gradOutput,
+ PackedTensorAccessor<int64_t, 4> indices) {
+ // Output (h, w) point that this thread is responsible for
+ int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
+ int plane = blockIdx.y;
+ int batch = blockIdx.z;
+
+ // Each thread generates a specific output point
+ if (ourOutputPoint < gradOutput.size(2) *
+ gradOutput.size(3)) {
+ int outputW = ourOutputPoint % gradOutput.size(3);
+ int outputH = ourOutputPoint / gradOutput.size(3);
+
+ int index = indices[batch][plane][outputH][outputW];
+ assert(index >= 0);
+ int inputW = index % gradInput.size(3);
+ int inputH = index / gradInput.size(3);
+ assert(inputH < gradInput.size(2));
+
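+    // pooling windows may overlap, so several output gradients can land on the
+    // same input element; accumulate atomically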
+ atomicAdd(
+ &gradInput[batch][plane][inputH][inputW],
+ gradOutput[batch][plane][outputH][outputW]
+ );
+ }
+}
+
+void fractional_max_pool2d_out_cuda_template(
+ Tensor & output,
+ Tensor& indices,
+ const Tensor& input,
+ IntList pool_size,
+ IntList output_size,
+ const Tensor& randomSamples) {
+ int planeDim = 0;
+ int dimh = 1;
+ int dimw = 2;
+ int numBatch = 1;
+
+ int ndims = input.ndimension();
+ AT_CHECK(input.numel() > 0,
+ "fractional_max_pool2d(): expected input to have non-empty ",
+ "spatial dimensions.");
+
+ AT_CHECK((ndims == 3 || ndims == 4),
+ "non-empty 3D or 4D (batch mode) tensor expected for input");
+
+ if (ndims == 4) {
+ numBatch = input.size(0);
+ planeDim++;
+ dimh++;
+ dimw++;
+ }
+
+ /* sizes */
+ int numPlanes = input.size(planeDim);
+ int inputH = input.size(dimh);
+ int inputW = input.size(dimw);
+
+ int outputH = output_size[0];
+ int outputW = output_size[1];
+ int poolSizeH = pool_size[0];
+ int poolSizeW = pool_size[1];
+
+ AT_CHECK(outputH + poolSizeH - 1 <= inputH,
+ "fractional_max_pool2d(): pool_size height ", poolSizeH,
+ " too large relative to input height ", inputH);
+  AT_CHECK(outputW + poolSizeW - 1 <= inputW,
+    "fractional_max_pool2d(): pool_size width ", poolSizeW,
+    " too large relative to input width ", inputW);
+
+ if (ndims == 3) {
+ /* resize output */
+ output.resize_({numPlanes, outputH, outputW});
+ /* indices will contain the locations for each output point */
+ indices.resize_({numPlanes, outputH, outputW});
+ } else {
+ output.resize_({numBatch, numPlanes, outputH, outputW});
+ indices.resize_({numBatch, numPlanes, outputH, outputW});
+ }
+
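+  // the kernel indexes 4D (batch, plane, h, w) accessors; view non-batched
+  // (3D) tensors as having a batch dimension of 1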
+ auto output_ = output;
+ auto input_ = input;
+ auto indices_ = indices;
+
+ if(ndims == 3) {
+ output_ = output_.reshape({1, numPlanes, outputH, outputW});
+ indices_ = indices_.reshape({1, numPlanes, outputH, outputW});
+ input_ = input_.reshape({1, input.size(0), input.size(1), input.size(2)});
+ }
+
+ // block is limited to 4 warps
+ // grid handles overflow per each plane
+ int outputPlaneSize = output_.size(2) *
+ output_.size(3);
+ dim3 grid((outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
+ input_.size(1),
+ input_.size(0));
+ dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(),
+ "fractional_max_pool2d_out_cuda_frame",
+ [&] {
+ auto devInput = input_.packed_accessor<scalar_t, 4>();
+ auto devOutput = output_.packed_accessor<scalar_t, 4>();
+ auto devIndices = indices_.packed_accessor<int64_t, 4>();
+ auto devSamples = randomSamples.packed_accessor<scalar_t, 3>();
+ fractional_max_pool2d_out_cuda_frame<scalar_t>
+ <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+ devOutput, devIndices, devInput, devSamples,
+ poolSizeH, poolSizeW);
+ }
+ );
+  cudaError_t err = cudaGetLastError();
+  AT_CHECK(err == cudaSuccess,
+    "fractional_max_pool2d_out_cuda_frame failed with error code ",
+    err);
+}
+
+void fractional_max_pool2d_backward_out_cuda_template(
+ Tensor& gradInput,
+ const Tensor& gradOutput,
+ const Tensor& input,
+ IntList pool_size /* unused */,
+ IntList output_size,
+ const Tensor& indices)
+{
+ int dimh = 1;
+ int dimw = 2;
+
+ int ndims = input.ndimension();
+ if (ndims == 4) {
+ dimh++;
+ dimw++;
+ }
+
+ /* sizes */
+ int inputH = input.size(dimh);
+ int inputW = input.size(dimw);
+
+ int outputH = output_size[0];
+ int outputW = output_size[1];
+
+ AT_CHECK(outputH == gradOutput.size(dimh),
+ "fractional_max_pool2d(): gradOutput height unexpected");
+ AT_CHECK(outputW == gradOutput.size(dimw),
+ "fractional_max_pool2d(): gradOutput width unexpected");
+
+ /* resize */
+ gradInput.resize_as_(input);
+ gradInput.zero_();
+
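+  // as in the forward pass, give non-batched (3D) tensors a batch dimension of 1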
+ auto gradInput_ = gradInput;
+ auto gradOutput_ = gradOutput;
+ auto indices_ = indices;
+
+ if(ndims == 3) {
+ gradInput_ = gradInput_.reshape({1, input.size(0), inputH, inputW});
+ gradOutput_ = gradOutput_.reshape({1, gradOutput.size(0), outputH, outputW});
+ indices_ = indices_.reshape({1, indices_.size(0), outputH, outputW});
+ }
+
+ /* backprop */
+ // block is limited to 4 warps
+ // grid handles overflow per each plane
+ int outputPlaneSize = gradOutput_.size(2) *
+ gradOutput_.size(3);
+ dim3 grid((outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
+ gradInput_.size(1),
+ gradInput_.size(0));
+ dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
+
+  auto devIndices = indices_.packed_accessor<int64_t, 4>();
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.type(),
+ "fractional_max_pool2d_backward_out_cuda_frame",
+ [&] {
+ auto devGradInput = gradInput_.packed_accessor<scalar_t, 4>();
+ auto devGradOutput = gradOutput_.packed_accessor<scalar_t, 4>();
+ fractional_max_pool2d_backward_out_cuda_frame<scalar_t>
+ <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+ devGradInput, devGradOutput, devIndices);
+ }
+ );
+  cudaError_t err = cudaGetLastError();
+  AT_CHECK(err == cudaSuccess,
+    "fractional_max_pool2d_backward_out_cuda_frame failed with error code ",
+    err);
+}
+
+}// namespace
+
+std::tuple<Tensor&, Tensor&> fractional_max_pool2d_out_cuda(
+ at::Tensor& output,
+ at::Tensor& indices,
+ const at::Tensor& input,
+ IntList pool_size,
+ IntList output_size,
+ const at::Tensor& randomSamples)
+{
+ fractional_max_pool2d_out_cuda_template(
+ output,
+ indices,
+ input,
+ pool_size,
+ output_size,
+ randomSamples);
+ return std::tuple<Tensor&, Tensor&>(output, indices);
+}
+
+std::tuple<Tensor, Tensor> fractional_max_pool2d_cuda(
+ const at::Tensor& input,
+ IntList pool_size,
+ IntList output_size,
+ const at::Tensor& randomSamples)
+{
+ Tensor output = at::empty({0}, input.options());
+ Tensor indices = at::empty({0}, input.options().dtype(kLong));
+ fractional_max_pool2d_out_cuda_template(
+ output,
+ indices,
+ input,
+ pool_size,
+ output_size,
+ randomSamples);
+ return std::tuple<Tensor, Tensor>(output, indices);
+}
+
+Tensor& fractional_max_pool2d_backward_out_cuda(
+ at::Tensor& gradInput,
+ const at::Tensor& gradOutput_,
+ const at::Tensor& input,
+ IntList pool_size,
+ IntList output_size,
+ const at::Tensor& indices)
+{
+ fractional_max_pool2d_backward_out_cuda_template(
+ gradInput,
+ gradOutput_,
+ input,
+ pool_size,
+ output_size,
+ indices);
+ return gradInput;
+}
+
+Tensor fractional_max_pool2d_backward_cuda(
+ const at::Tensor& gradOutput_,
+ const at::Tensor& input,
+ IntList pool_size,
+ IntList output_size,
+ const at::Tensor& indices)
+{
+ Tensor gradInput = at::empty({0}, input.options());
+ fractional_max_pool2d_backward_out_cuda_template(
+ gradInput,
+ gradOutput_,
+ input,
+ pool_size,
+ output_size,
+ indices);
+ return gradInput;
+}
+
+}// at::native
+}// at
- func: fractional_max_pool2d_out(Tensor output, Tensor indices, Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples) -> (Tensor output, Tensor indices)
python_module: nn
+ dispatch:
+ CPU: fractional_max_pool2d_out_cpu
+ CUDA: fractional_max_pool2d_out_cuda
- func: fractional_max_pool2d(Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples) -> (Tensor output, Tensor indices)
python_module: nn
+ dispatch:
+ CPU: fractional_max_pool2d_cpu
+ CUDA: fractional_max_pool2d_cuda
- func: fractional_max_pool2d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor indices) -> Tensor
python_module: nn
+ dispatch:
+ CPU: fractional_max_pool2d_backward_out_cpu
+ CUDA: fractional_max_pool2d_backward_out_cuda
- func: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor indices) -> Tensor
python_module: nn
+ dispatch:
+ CPU: fractional_max_pool2d_backward_cpu
+ CUDA: fractional_max_pool2d_backward_cuda
- func: max_pool2d_with_indices_out(Tensor output, Tensor indices, Tensor self, IntList[2] kernel_size, IntList[2] stride={}, IntList[2] padding=0, IntList[2] dilation=1, bool ceil_mode=false) -> (Tensor output, Tensor indices)
python_module: nn
output: 'false'
grad_input: 'false'
-- name: _thnn_fractional_max_pool2d(Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples)
-  cname: SpatialFractionalMaxPooling
-  scalar_check:
-    output: 'false'
-    grad_input: 'false'
-
- name: _thnn_max_pool2d_with_indices(Tensor self, IntList[2] kernel_size, IntList[2] stride={}, IntList[2] padding=0, IntList[2] dilation=1, bool ceil_mode=false)
cname: SpatialDilatedMaxPooling
default_init:
${CMAKE_CURRENT_SOURCE_DIR}/SpatialDepthwiseConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialDilatedConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialDilatedMaxPooling.cu
-${CMAKE_CURRENT_SOURCE_DIR}/SpatialFractionalMaxPooling.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullDilatedConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxPooling.cu
+++ /dev/null
-#include <THCUNN/THCUNN.h>
-#include <THCUNN/common.h>
-#include <THC/THCDeviceTensor.cuh>
-#include <THC/THCDeviceTensorUtils.cuh>
-#include <THC/THCDeviceUtils.cuh>
-#include <TH/THHalf.h>
-#include <THCUNN/THCHalfAutoNumerics.cuh>
-#include <THC/THCAtomics.cuh>
-
-#include <cfloat>
-
-template <typename Dtype, typename Acctype>
-__device__ inline int getInterval(Acctype sample,
- int index,
- int inputSize,
- int outputSize,
- int poolSize) {
- Acctype alpha = (Acctype)(inputSize - poolSize) / (Acctype) (outputSize - 1);
- if (index == outputSize - 1) {
- return inputSize - poolSize;
- } else {
- return (int) ((index + sample) * alpha) - (int) (sample * alpha);
- }
-}
-
-// We template on poolSizeW to allow the innermost loop to be unrolled
-template <int PoolSizeWStatic, typename Dtype, typename Acctype>
-__global__ void SpatialFractionalMaxPooling_updateOutput(
- THCDeviceTensor<Dtype, 4> input,
- THCDeviceTensor<Dtype, 4> output,
- THCDeviceTensor<THCIndex_t, 4> indices,
- THCDeviceTensor<Dtype, 3> samples,
- int poolSizeW, int poolSizeH) {
-
- // Output (h, w) point that this thread is responsible for
- int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
- int plane = blockIdx.y;
- int batch = blockIdx.z;
-
- // Each thread generates a specific output point
- if (ourOutputPoint < output.getSize(2) * output.getSize(3)) {
- int outputW = ourOutputPoint % output.getSize(3);
- int outputH = ourOutputPoint / output.getSize(3);
-
- int poolW = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][0]), outputW,
- input.getSize(3), output.getSize(3), poolSizeW);
- int poolH = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][1]), outputH,
- input.getSize(2), output.getSize(2), poolSizeH);
-
- Dtype maxVal = THCNumerics<Dtype>::min();
- int maxIndex = -1;
-
- for (int h = poolH; h < poolH + poolSizeH; ++h) {
- if (PoolSizeWStatic == -1) {
- for (int w = poolW; w < poolW + poolSizeW; ++w) {
- Dtype val = input[batch][plane][h][w];
- // for consistency with THNN, favor the first max
- if (val > maxVal) {
- maxIndex = h * input.getSize(3) + w;
- maxVal = val;
- }
- }
- } else {
-#pragma unroll
- for (int i = 0; i < PoolSizeWStatic; ++i) {
- int w = i + poolW;
- Dtype val = input[batch][plane][h][w];
- // for consistency with THNN, favor the first max
- if (val > maxVal) {
- maxIndex = h * input.getSize(3) + w;
- maxVal = val;
- }
- }
- }
- }
-
- assert(THCNumerics<Dtype>::ne(maxVal, THCNumerics<Dtype>::min()));
- assert(maxIndex != -1);
-
- // +1 for Lua index
- indices[batch][plane][outputH][outputW] = maxIndex + TH_INDEX_BASE;
- output[batch][plane][outputH][outputW] = maxVal;
- }
-}
-
-template <typename Dtype>
-__global__ void SpatialFractionalMaxPooling_updateGradInput(
- THCDeviceTensor<Dtype, 4> gradInput,
- THCDeviceTensor<Dtype, 4> gradOutput,
- THCDeviceTensor<THCIndex_t, 4> indices) {
- // Output (h, w) point that this thread is responsible for
- int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
- int plane = blockIdx.y;
- int batch = blockIdx.z;
-
- // Each thread generates a specific output point
- if (ourOutputPoint < gradOutput.getSize(2) * gradOutput.getSize(3)) {
- int outputW = ourOutputPoint % gradOutput.getSize(3);
- int outputH = ourOutputPoint / gradOutput.getSize(3);
-
- int index = indices[batch][plane][outputH][outputW] - TH_INDEX_BASE;
- assert(index >= 0);
- int inputW = index % gradInput.getSize(3);
- int inputH = index / gradInput.getSize(3);
- assert(inputH < gradInput.getSize(2));
-
- atomicAdd(gradInput[batch][plane][inputH][inputW].data(),
- gradOutput[batch][plane][outputH][outputW]);
- }
-}
-
-#include <THCUNN/generic/SpatialFractionalMaxPooling.cu>
-#include <THC/THCGenerateFloatTypes.h>
+++ /dev/null
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/SpatialFractionalMaxPooling.cu"
-#else
-
-void THNN_(SpatialFractionalMaxPooling_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCTensor *output,
- int outputW, int outputH,
- int poolSizeW, int poolSizeH,
- THCIndexTensor *indices,
- THCTensor *randomSamples)
-{
- int planeDim = 0;
- int dimh = 1;
- int dimw = 2;
- int64_t numBatch = 1;
-
- int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
- THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
- "non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s");
-
- if (numInputDims == 4) {
- numBatch = THCTensor_(size)(state, input, 0);
- planeDim++;
- dimh++;
- dimw++;
- }
-
- /* sizes */
- int64_t numPlanes = THCTensor_(size)(state, input, planeDim);
- int64_t inputH = THCTensor_(size)(state, input, dimh);
- int64_t inputW = THCTensor_(size)(state, input, dimw);
-
- THArgCheck(outputH + poolSizeH - 1 <= inputH, 6,
- "poolSizeH (%d) too large relative to input height (%d)",
- poolSizeH, inputH);
- THArgCheck(outputW + poolSizeW - 1 <= inputW, 5,
- "poolSizeW (%d) too large relative to input width (%d)",
- poolSizeW, inputW);
-
- THCDeviceTensor<scalar_t, 4> devInput;
- THCDeviceTensor<scalar_t, 4> devOutput;
- THCDeviceTensor<THCIndex_t, 4> devIndices;
- THCDeviceTensor<scalar_t, 3> devSamples =
- toDeviceTensor<scalar_t, 3>(state, randomSamples);
-
- if (numInputDims == 3) {
- /* resize output */
- THCTensor_(resize3d)(state, output, numPlanes, outputH, outputW);
- /* indices will contain the locations for each output point */
- THCIndexTensor_(resize3d)(state, indices, numPlanes, outputH, outputW);
-
- devInput = toDeviceTensor<scalar_t, 3>(state, input).upcastOuter<4>();
- devOutput = toDeviceTensor<scalar_t, 3>(state, output).upcastOuter<4>();
- devIndices = toDeviceTensor<THCIndex_t, 3>(state, indices).upcastOuter<4>();
- } else {
- THCTensor_(resize4d)(state, output, numBatch, numPlanes, outputH, outputW);
- /* indices will contain the locations for each output point */
- THCIndexTensor_(resize4d)(state, indices, numBatch, numPlanes, outputH, outputW);
-
- devInput = toDeviceTensor<scalar_t, 4>(state, input);
- devOutput = toDeviceTensor<scalar_t, 4>(state, output);
- devIndices = toDeviceTensor<THCIndex_t, 4>(state, indices);
- }
-
- // block is limited to 4 warps
- // grid handles overflow per each plane
- int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3);
- dim3 grid(THCCeilDiv(outputPlaneSize, 128),
- devInput.getSize(1),
- devInput.getSize(0));
- dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
-
-#define SFMP_UPDATE_OUTPUT(POOL_W) \
- SpatialFractionalMaxPooling_updateOutput<POOL_W, scalar_t, accreal> \
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
- devInput, devOutput, devIndices, devSamples, poolSizeW, poolSizeH);
-
-#define SFMP_UPDATE_OUTPUT_CASE(POOL_W) \
- case POOL_W: SFMP_UPDATE_OUTPUT(POOL_W); break
-
- switch (poolSizeW) {
- SFMP_UPDATE_OUTPUT_CASE(2);
- SFMP_UPDATE_OUTPUT_CASE(3);
- SFMP_UPDATE_OUTPUT_CASE(4);
- SFMP_UPDATE_OUTPUT_CASE(5);
- SFMP_UPDATE_OUTPUT_CASE(6);
- SFMP_UPDATE_OUTPUT_CASE(7);
- default:
- // dynamic pool width
- SFMP_UPDATE_OUTPUT_CASE(-1);
- }
- THCudaCheck(cudaGetLastError());
-}
-
-void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- int outputW, int outputH,
- int poolSizeW, int poolSizeH,
- THCIndexTensor *indices)
-{
- int dimh = 1;
- int dimw = 2;
-
- int64_t numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
- if (numInputDims == 4) {
- dimh++;
- dimw++;
- }
-
- /* sizes */
- int64_t inputH = THCTensor_(size)(state, input, dimh);
- int64_t inputW = THCTensor_(size)(state, input, dimw);
-
- THArgCheck(outputH == THCTensor_(size)(state, gradOutput, dimh), 3,
- "gradOutput height unexpected");
- THArgCheck(outputW == THCTensor_(size)(state, gradOutput, dimw), 3,
- "gradOutput width unexpected");
-
- /* resize */
- THCTensor_(resizeAs)(state, gradInput, input);
- THCTensor_(zero)(state, gradInput);
-
- THCDeviceTensor<scalar_t, 4> devGradInput;
- THCDeviceTensor<scalar_t, 4> devGradOutput;
- THCDeviceTensor<THCIndex_t, 4> devIndices;
-
- /* backprop */
- if (numInputDims == 3) {
- devGradInput = toDeviceTensor<scalar_t, 3>(state, gradInput).upcastOuter<4>();
- devGradOutput = toDeviceTensor<scalar_t, 3>(state, gradOutput).upcastOuter<4>();
- devIndices = toDeviceTensor<THCIndex_t, 3>(state, indices).upcastOuter<4>();
- } else {
- devGradInput = toDeviceTensor<scalar_t, 4>(state, gradInput);
- devGradOutput = toDeviceTensor<scalar_t, 4>(state, gradOutput);
- devIndices = toDeviceTensor<THCIndex_t, 4>(state, indices);
- }
-
- // block is limited to 4 warps
- // grid handles overflow per each plane
- int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3);
- dim3 grid(THCCeilDiv(outputPlaneSize, 128),
- devGradInput.getSize(1),
- devGradInput.getSize(0));
- dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
-
- SpatialFractionalMaxPooling_updateGradInput
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
- devGradInput, devGradOutput, devIndices);
- THCudaCheck(cudaGetLastError());
-}
-
-#endif
int dilationW, int dilationH,
bool ceil_mode);
-THC_API void THNN_(SpatialFractionalMaxPooling_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCTensor *output,
- int outputW, int outputH,
- int poolSizeW, int poolSizeH,
- THCIndexTensor *indices,
- THCTensor *randomSamples);
-
-THC_API void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- int outputW, int outputH,
- int poolSizeW, int poolSizeH,
- THCIndexTensor *indices);
-
THC_API void THNN_(SpatialFullConvolution_updateOutput)(
THCState *state,
THCTensor *input,
+++ /dev/null
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/SpatialFractionalMaxPooling.c"
-#else
-
-static int64_t* THNN_(SpatialFractionalMaxPooling_generateIntervals)(
- scalar_t sample,
- int64_t inputSize,
- int64_t outputSize,
- int poolSize) {
- scalar_t alpha = (scalar_t) (inputSize - poolSize) / (scalar_t) (outputSize - 1);
- int64_t* sequence = (int64_t*) THAlloc(sizeof(int64_t) * outputSize);
-
- int64_t i;
- for (i = 0; i < outputSize - 1; ++i) {
- sequence[i] =
- (int64_t) ((i + sample) * alpha) - (int64_t) (sample * alpha);
- }
- sequence[outputSize - 1] = inputSize - poolSize;
-
- return sequence;
-}
-
-static void THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
- scalar_t* input,
- scalar_t* output,
- THIndex_t* indices,
- scalar_t* randomSamples,
- int64_t numPlanes,
- int64_t inputW, int64_t inputH,
- int64_t outputW, int64_t outputH,
- int poolSizeW, int poolSizeH) {
- int64_t plane;
-#pragma omp parallel for private(plane)
- for (plane = 0; plane < numPlanes; ++plane) {
- /* each plane contains 2 random samples, one for W and one for H */
- scalar_t* randomSamplesForPlane = randomSamples + plane * 2;
-
- /* Generate interval sequence */
- int64_t* sequenceW =
- THNN_(SpatialFractionalMaxPooling_generateIntervals)(
- randomSamplesForPlane[0], inputW, outputW, poolSizeW);
- int64_t* sequenceH =
- THNN_(SpatialFractionalMaxPooling_generateIntervals)(
- randomSamplesForPlane[1], inputH, outputH, poolSizeH);
-
- /* loop over output */
- int64_t h, w;
-
- scalar_t* inputForPlane = input + plane * inputW * inputH;
- scalar_t* outputForPlane = output + plane * outputW * outputH;
- THIndex_t* indicesForPlane = indices + plane * outputW * outputH;
-
- for (h = 0; h < outputH; ++h) {
- int64_t inputHStart = sequenceH[h];
-
- for (w = 0; w < outputW; ++w) {
- int64_t inputWStart = sequenceW[w];
-
- scalar_t maxVal = -THInf;
- int64_t maxIndex = -1;
-
- int64_t h2, w2;
- for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
- for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
- THAssert(h2 >= 0 && h2 < inputH);
- THAssert(w2 >= 0 && w2 < inputW);
-
- int64_t planeIndex = h2 * inputW + w2;
- scalar_t val = inputForPlane[planeIndex];
- if (val > maxVal) {
- maxVal = val;
- maxIndex = planeIndex;
- }
- }
- }
-
- THAssert(maxVal != -THInf);
- THAssert(maxIndex != -1);
-
- outputForPlane[h * outputW + w] = maxVal;
- /* +1 to lua index */
- indicesForPlane[h * outputW + w] = maxIndex + TH_INDEX_BASE;
- }
- }
-
- THFree(sequenceW);
- THFree(sequenceH);
- }
-}
-
-void THNN_(SpatialFractionalMaxPooling_updateOutput)(
- THNNState *state,
- THTensor *input,
- THTensor *output,
- int outputW, int outputH,
- int poolSizeW, int poolSizeH,
- THIndexTensor *indices,
- THTensor *randomSamples) {
-
- int64_t numBatch = 1;
- int planeDim = 0;
- int heightDim = 1;
- int widthDim = 2;
-
- int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input);
- THNN_ARGCHECK(!input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
- "non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s");
-
- if (numInputDims == 4) {
- numBatch = THTensor_(size)(input, 0);
- planeDim++;
- heightDim++;
- widthDim++;
- }
-
- /* sizes */
- int64_t numPlanes = THTensor_(size)(input, planeDim);
- int64_t inputH = THTensor_(size)(input, heightDim);
- int64_t inputW = THTensor_(size)(input, widthDim);
-
- THArgCheck(outputH + poolSizeH - 1 <= inputH, 7,
- "poolSizeH (%d) too large relative to input height (%d)",
- poolSizeH, inputH);
- THArgCheck(outputW + poolSizeW - 1 <= inputW, 6,
- "poolSizeW (%d) too large relative to input width (%d)",
- poolSizeW, inputW);
-
- /* get contiguous input */
- input = THTensor_(newContiguous)(input);
-
- if (numInputDims == 3) {
- /* resize output */
- THTensor_(resize3d)(output, numPlanes, outputH, outputW);
- /* indices will contain the locations for each output point */
- THIndexTensor_(resize3d)(indices, numPlanes, outputH, outputW);
-
- THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
- input->data<scalar_t>(),
- output->data<scalar_t>(),
- THIndexTensor_(data)(indices),
- randomSamples->data<scalar_t>(),
- numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
- } else {
- THTensor_(resize4d)(output, numBatch, numPlanes, outputH, outputW);
- /* indices will contain the locations for each output point */
- THIndexTensor_(resize4d)(indices, numBatch, numPlanes, outputH, outputW);
-
- int64_t batch;
-#pragma omp parallel for private(batch)
- for (batch = 0; batch < numBatch; ++batch) {
- THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
- input->data<scalar_t>() + batch * numPlanes * inputH * inputW,
- output->data<scalar_t>() + batch * numPlanes * outputH * outputW,
- THIndexTensor_(data)(indices) + batch * numPlanes * outputH * outputW,
- randomSamples->data<scalar_t>() + batch * numPlanes * 2,
- numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
- }
- }
-
- /* cleanup */
- c10::raw::intrusive_ptr::decref(input);
-}
-
-static void THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
- scalar_t* gradInput,
- scalar_t* gradOutput,
- THIndex_t* indices,
- int64_t numPlanes,
- int64_t inputW, int64_t inputH,
- int64_t outputW, int64_t outputH) {
- int64_t plane;
-#pragma omp parallel for private(plane)
- for (plane = 0; plane < numPlanes; plane++) {
- scalar_t* gradInputForPlane = gradInput + plane * inputW * inputH;
- scalar_t* gradOutputForPlane = gradOutput + plane * outputW * outputH;
- THIndex_t* indicesForPlane = indices + plane * outputW * outputH;
-
- int64_t h, w;
- for (h = 0; h < outputH; ++h) {
- for (w = 0; w < outputW; ++w) {
- int64_t outputIndex = h * outputW + w;
- int64_t index = indicesForPlane[outputIndex] - TH_INDEX_BASE;
- THAssert(index >= 0 && index < inputW * inputH);
-
- gradInputForPlane[index] += gradOutputForPlane[outputIndex];
- }
- }
- }
-}
-
-void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradInput,
- int outputW, int outputH,
- int poolSizeW, int poolSizeH,
- THIndexTensor *indices) {
-
- int64_t numBatch = 1;
- int planeDim = 0;
- int heightDim = 1;
- int widthDim = 2;
-
- int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input);
- if (numInputDims == 4) {
- numBatch = THTensor_(size)(input, 0);
- planeDim = 1;
- heightDim++;
- widthDim++;
- }
-
- /* sizes */
- int64_t numPlanes = THTensor_(size)(input, planeDim);
- int64_t inputH = THTensor_(size)(input, heightDim);
- int64_t inputW = THTensor_(size)(input, widthDim);
-
- THArgCheck(outputW == THTensor_(size)(gradOutput, widthDim), 3,
- "gradOutput width unexpected");
- THArgCheck(outputH == THTensor_(size)(gradOutput, heightDim), 3,
- "gradOutput height unexpected");
-
- /* get contiguous gradOutput */
- gradOutput = THTensor_(newContiguous)(gradOutput);
-
- /* resize */
- THTensor_(resizeAs)(gradInput, input);
- THTensor_(zero)(gradInput);
-
- /* backprop */
- if (numInputDims == 3) {
- THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
- gradInput->data<scalar_t>(),
- gradOutput->data<scalar_t>(),
- THIndexTensor_(data)(indices),
- numPlanes, inputW, inputH, outputW, outputH);
- } else {
- int64_t batch;
-#pragma omp parallel for private(batch)
- for (batch = 0; batch < numBatch; ++batch) {
- THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
- gradInput->data<scalar_t>() + batch * numPlanes * inputH * inputW,
- gradOutput->data<scalar_t>() + batch * numPlanes * outputH * outputW,
- THIndexTensor_(data)(indices) + batch * numPlanes * outputH * outputW,
- numPlanes, inputW, inputH, outputW, outputH);
- }
- }
-
- /* cleanup */
- c10::raw::intrusive_ptr::decref(gradOutput);
-}
-
-#endif
bool ceil_mode,
bool count_include_pad);
-TH_API void THNN_(SpatialFractionalMaxPooling_updateOutput)(
- THNNState *state,
- THTensor *input,
- THTensor *output,
- int outputW, int outputH,
- int kW, int kH,
- THIndexTensor *indices,
- THTensor *randomSamples);
-TH_API void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradInput,
- int outputW, int outputH,
- int kW, int kH,
- THIndexTensor *indices);
-
TH_API void THNN_(SpatialDilatedConvolution_updateOutput)(
THNNState *state,
THTensor *input,
#include <THNN/generic/SpatialAveragePooling.c>
#include <TH/THGenerateFloatTypes.h>
-#include <THNN/generic/SpatialFractionalMaxPooling.c>
-#include <TH/THGenerateFloatTypes.h>
-
#include <THNN/generic/SpatialDilatedMaxPooling.c>
#include <TH/THGenerateFloatTypes.h>
pickle=False)
+def fractional_max_pool2d_test(test_case):
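+    # Builds a new_module_tests entry for FractionalMaxPool2d, driven either by an
+    # output_ratio ('ratio') or an explicit output_size ('size'); fixed
+    # _random_samples keep the pooling regions deterministic across runs.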
+ random_samples = torch.DoubleTensor(1, 3, 2).uniform_()
+ if test_case == 'ratio':
+ return dict(
+ constructor=lambda: nn.FractionalMaxPool2d(
+ 2, output_ratio=0.5, _random_samples=random_samples),
+ input_size=(1, 3, 5, 7),
+ fullname='FractionalMaxPool2d_ratio')
+ elif test_case == 'size':
+ return dict(
+            constructor=lambda: nn.FractionalMaxPool2d(
+                (2, 3), output_size=(4, 3), _random_samples=random_samples),
+ input_size=(1, 3, 7, 6),
+ fullname='FractionalMaxPool2d_size')
+
+
new_module_tests = [
poissonnllloss_no_reduce_test(),
bceloss_no_reduce_test(),
multimarginloss_p_no_reduce_test(),
multimarginloss_margin_no_reduce_test(),
multimarginloss_weights_no_reduce_test(),
+ fractional_max_pool2d_test('ratio'),
+ fractional_max_pool2d_test('size'),
dict(
module_name='BatchNorm1d',
constructor_args=(10,),
test_cuda=(not TEST_WITH_ROCM)
),
dict(
- constructor=lambda: nn.FractionalMaxPool2d(
- 2, output_ratio=0.5, _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
- input_size=(1, 3, 5, 5),
- fullname='FractionalMaxPool2d_ratio',
- ),
- dict(
- constructor=lambda: nn.FractionalMaxPool2d((2, 2), output_size=(
- 4, 4), _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
- input_size=(1, 3, 7, 7),
- fullname='FractionalMaxPool2d_size',
- test_cuda=False,
- ),
- dict(
module_name='PixelShuffle',
constructor_args=(3,),
input_size=(1, 9, 4, 4),
test_case.assertEqual(cpu_output, gpu_output, self.precision)
# Run backwards on CPU and GPU and compare results
- for i in range(5):
+ for _ in range(5):
cpu_gradOutput = cpu_output.clone().normal_()
gpu_gradOutput = cpu_gradOutput.type('torch.cuda.FloatTensor')
cpu_gradInput = test_case._backward(cpu_module, cpu_input, cpu_output, cpu_gradOutput)