From 237c0c3c7a551c6b27c5108ce30769bd41a542be Mon Sep 17 00:00:00 2001 From: Chandler Zuo Date: Wed, 16 Jan 2019 14:01:39 -0800 Subject: [PATCH] Port the backend of FractionalMaxPool3d from TH to ATen (#15575) Summary: 1. Port the FractionalMaxPool3d implementation from THNN/THCUNN to ATen. 2. Expose this function to Python module nn. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15575 Differential Revision: D13612848 Pulled By: chandlerzuo fbshipit-source-id: 5f474b39005efa7788e984e8a805456dcdc43f6c --- aten/src/ATen/native/FractionalMaxPool3d.cpp | 419 +++++++++++++++++++++ aten/src/ATen/native/cuda/FractionalMaxPool3d.cu | 415 ++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 24 ++ aten/src/THCUNN/CMakeLists.txt | 1 - aten/src/THCUNN/VolumetricFractionalMaxPooling.cu | 120 ------ aten/src/THCUNN/generic/THCUNN.h | 18 - .../generic/VolumetricFractionalMaxPooling.cu | 168 --------- .../THNN/generic/VolumetricFractionalMaxPooling.c | 279 -------------- aten/src/THNN/init.cpp | 3 - test/common_nn.py | 25 ++ tools/autograd/derivatives.yaml | 7 + torch/nn/functional.py | 70 ++++ torch/nn/modules/__init__.py | 6 +- torch/nn/modules/pooling.py | 60 +++ 14 files changed, 1023 insertions(+), 592 deletions(-) create mode 100644 aten/src/ATen/native/FractionalMaxPool3d.cpp create mode 100644 aten/src/ATen/native/cuda/FractionalMaxPool3d.cu delete mode 100644 aten/src/THCUNN/VolumetricFractionalMaxPooling.cu delete mode 100644 aten/src/THCUNN/generic/VolumetricFractionalMaxPooling.cu delete mode 100644 aten/src/THNN/generic/VolumetricFractionalMaxPooling.c diff --git a/aten/src/ATen/native/FractionalMaxPool3d.cpp b/aten/src/ATen/native/FractionalMaxPool3d.cpp new file mode 100644 index 0000000..22f1bed --- /dev/null +++ b/aten/src/ATen/native/FractionalMaxPool3d.cpp @@ -0,0 +1,419 @@ +#include "ATen/ATen.h" +#include "ATen/NativeFunctions.h" + +#include +#include + +namespace at { +namespace native { +namespace { + +template +static 
std::vector generate_intervals( + scalar_t sample, + int64_t inputSize, + int64_t outputSize, + int64_t poolSize) { + scalar_t alpha = static_cast(inputSize - poolSize) / + static_cast(outputSize - 1); + std::vector sequence(outputSize); + + for (int i = 0; i < outputSize - 1; ++i) { + sequence[i] = + static_cast((i + sample) * alpha) - static_cast(sample * alpha); + } + sequence[outputSize - 1] = inputSize - poolSize; + + return sequence; +} + +template +static void fractional_max_pool3d_out_single_batch_frame( + scalar_t* input, + scalar_t* output, + int64_t* indices, + scalar_t* randomSamples, + int64_t numPlanes, + int64_t inputT, int64_t inputH, int64_t inputW, + int64_t outputT, int64_t outputH, int64_t outputW, + int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) { + int64_t plane; +#pragma omp parallel for private(plane) + for (plane = 0; plane < numPlanes; ++plane) { + /* each plane contains 3 random samples, + one for T, one for W, and one for H */ + scalar_t* randomSamplesForPlane = randomSamples + plane * 3; + + /* Generate interval sequence */ + auto sequenceT = generate_intervals( + randomSamplesForPlane[0], inputT, outputT, poolSizeT); + auto sequenceH = generate_intervals( + randomSamplesForPlane[1], inputH, outputH, poolSizeH); + auto sequenceW = generate_intervals( + randomSamplesForPlane[2], inputW, outputW, poolSizeW); + + /* loop over output */ + int64_t t, h, w; + + scalar_t* inputForPlane = input + plane * inputT * inputH * inputW; + scalar_t* outputForPlane = output + plane * outputT * outputH * outputW; + int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW; + + for (t = 0; t < outputT; ++t) { + int64_t inputTStart = sequenceT[t]; + + for (h = 0; h < outputH; ++h) { + int64_t inputHStart = sequenceH[h]; + + for (w = 0; w < outputW; ++w) { + int64_t inputWStart = sequenceW[w]; + + scalar_t maxVal = -std::numeric_limits::infinity(); + int64_t maxIndex = -1; + + int64_t t2, h2, w2; + for (t2 = inputTStart; t2 < 
inputTStart + poolSizeT; ++t2) { + for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) { + for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) { + AT_ASSERT(t2 >= 0 && t2 < inputT); + AT_ASSERT(h2 >= 0 && h2 < inputH); + AT_ASSERT(w2 >= 0 && w2 < inputW); + + int64_t planeIndex = t2 * inputH * inputW + h2 * inputW + w2; + scalar_t val = inputForPlane[planeIndex]; + if (val > maxVal) { + maxVal = val; + maxIndex = planeIndex; + } + } + } + } + + AT_ASSERT(maxVal != -std::numeric_limits::infinity()); + AT_ASSERT(maxIndex != -1); + + outputForPlane[t * outputH * outputW + h * outputW + w] = maxVal; + indicesForPlane[t * outputH * outputW + h * outputW + w] = maxIndex; + } + } + } + + } +} + +template +static void fractional_max_pool3d_out_frame( + scalar_t* input, + scalar_t* output, + int64_t* indices, + scalar_t* randomSamples, + int64_t numBatch, int64_t numPlanes, + int64_t inputT, int64_t inputH, int64_t inputW, + int64_t outputT, int64_t outputH, int64_t outputW, + int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) { + if(numBatch == 1) { + fractional_max_pool3d_out_single_batch_frame( + input, output, indices, randomSamples, + numPlanes, + inputT, inputH, inputW, + outputT, outputH, outputW, + poolSizeT, poolSizeH, poolSizeW + ); + return; + } + int64_t batch; +#pragma omp parallel for private(batch) + for(batch = 0; batch < numBatch; ++batch) { + fractional_max_pool3d_out_single_batch_frame( + input + batch * numPlanes * inputW * inputH * inputT, + output + batch * numPlanes * outputW * outputH * outputT, + indices + batch * numPlanes * outputW * outputH * outputT, + randomSamples + batch * numPlanes * 3, + numPlanes, + inputT, inputH, inputW, + outputT, outputH, outputW, + poolSizeT, poolSizeH, poolSizeW + ); + } + } + +void fractional_max_pool3d_out_cpu_template( + Tensor& output, + Tensor& indices, + const Tensor& input_, + IntList pool_size, + IntList output_size, + const Tensor& randomSamples) { + + int64_t outputT = output_size[0]; + 
int64_t outputH = output_size[1]; + int64_t outputW = output_size[2]; + int64_t poolSizeT = pool_size[0]; + int64_t poolSizeH = pool_size[1]; + int64_t poolSizeW = pool_size[2]; + + int64_t numBatch = 1; + int64_t planeDim = 0; + int64_t timeDim = 1; + int64_t heightDim = 2; + int64_t widthDim = 3; + + int64_t ndims = input_.ndimension(); + AT_CHECK(input_.numel() != 0 && (ndims == 4 || ndims == 5), + "fractional_max_pool3d_out(): non-empty 4D or 5D (batch mode) tensor ", + " expected for input, but got: ", ndims); + + if (ndims == 5) { + numBatch = input_.size(0); + planeDim++; + timeDim++; + heightDim++; + widthDim++; + } + + /* sizes */ + int64_t numPlanes = input_.size(planeDim); + int64_t inputT = input_.size(timeDim); + int64_t inputH = input_.size(heightDim); + int64_t inputW = input_.size(widthDim); + + AT_CHECK(outputT + poolSizeT - 1 < inputT, + "fractional_max_pool3d_out(): pool time ", poolSizeT, + " too large relative to input time ", inputT); + AT_CHECK(outputW + poolSizeW - 1 < inputW, + "fractional_max_pool3d_out(): pool width ", poolSizeW, + " too large relative to input width ", inputW); + AT_CHECK(outputH + poolSizeH - 1 < inputH, + "fractional_max_pool3d_out(): pool height ", poolSizeH, + " too large relative to input height ", inputH); + + /* get contiguous input */ + auto input = input_.contiguous(); + + if (ndims == 4) { + /* resize output */ + output.resize_({numPlanes, outputT, outputH, outputW}); + /* indices will contain the locations for each output point */ + indices.resize_({numPlanes, outputT, outputH, outputW}); + } else { + output.resize_({numBatch, numPlanes, outputT, outputH, outputW}); + /* indices will contain the locations for each output point */ + indices.resize_({numBatch, numPlanes, outputT, outputH, outputW}); + } + AT_DISPATCH_FLOATING_TYPES( + input.type(), + "fractional_max_pool3d_out_frame", + [&] { + fractional_max_pool3d_out_frame( + input.data(), + output.data(), + indices.data(), + randomSamples.data(), + numBatch, 
numPlanes, + inputT, inputH, inputW, + outputT, outputH, outputW, + poolSizeT, poolSizeH, poolSizeW + ); + } + ); +} + +template +static void fractional_max_pool3d_backward_out_single_batch_frame( + scalar_t* gradInput, + scalar_t* gradOutput, + int64_t* indices, + int64_t numPlanes, + int64_t inputT, int64_t inputH, int64_t inputW, + int64_t outputT, int64_t outputH, int64_t outputW) { + int64_t plane; +#pragma omp parallel for private(plane) + for (plane = 0; plane < numPlanes; plane++) { + scalar_t* gradInputForPlane = gradInput + plane * inputT * inputH * inputW; + scalar_t* gradOutputForPlane = gradOutput + + plane * outputT * outputH * outputW; + int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW; + + int64_t h, w, t; + for (t = 0; t < outputT; ++t) { + for (h = 0; h < outputH; ++h) { + for (w = 0; w < outputW; ++w) { + int64_t outputIndex = t * outputH * outputW + h * outputW + w; + int64_t index = indicesForPlane[outputIndex]; + AT_ASSERT(index >= 0 && index < inputT * inputH * inputW); + gradInputForPlane[index] += gradOutputForPlane[outputIndex]; + } + } + } + } +} + +template +static void fractional_max_pool3d_backward_out_frame( + scalar_t* gradInput, + scalar_t* gradOutput, + int64_t* indices, + int64_t numBatch, int64_t numPlanes, + int64_t inputT, int64_t inputH, int64_t inputW, + int64_t outputT, int64_t outputH, int64_t outputW) { + if(numBatch == 1) { + fractional_max_pool3d_backward_out_single_batch_frame( + gradInput, gradOutput, indices, + numPlanes, + inputT, inputH, inputW, + outputT, outputH, outputW + ); + return; + } + int64_t batch; +#pragma omp parallel for private(batch) + for(batch = 0; batch < numBatch; ++ batch) { + fractional_max_pool3d_backward_out_single_batch_frame( + gradInput + batch * numPlanes * inputW * inputH * inputT, + gradOutput + batch * numPlanes * outputW * outputH * outputT, + indices + batch * numPlanes * outputW * outputH * outputT, + numPlanes, + inputT, inputH, inputW, + outputT, outputH, 
outputW + ); + } + } + + +void fractional_max_pool3d_backward_out_cpu_template( + const Tensor& input, + const Tensor& gradOutput_, + Tensor& gradInput, + IntList output_size, + IntList pool_size /* unused */, + const Tensor& indices) { + + int64_t outputT = output_size[0]; + int64_t outputH = output_size[1]; + int64_t outputW = output_size[2]; + + int64_t numBatch = 1; + int64_t planeDim = 0; + int64_t timeDim = 1; + int64_t heightDim = 2; + int64_t widthDim = 3; + + int64_t ndims = input.ndimension(); + if (ndims == 5) { + numBatch = input.size(0); + planeDim = 1; + heightDim++; + widthDim++; + timeDim++; + } + + /* sizes */ + int64_t numPlanes = input.size(planeDim); + int64_t inputT = input.size(timeDim); + int64_t inputH = input.size(heightDim); + int64_t inputW = input.size(widthDim); + + AT_CHECK(outputT == gradOutput_.size(timeDim), + "fractional_max_pool3d_backward_out(): gradOutput time unexpected"); + AT_CHECK(outputH == gradOutput_.size(heightDim), + "fractional_max_pool3d_backward_out(): ", + "gradOutput height unexpected"); + AT_CHECK(outputW == gradOutput_.size(widthDim), + "fractional_max_pool3d_backward_out(): gradOutput width unexpected"); + + /* get contiguous gradOutput */ + auto gradOutput = gradOutput_.contiguous(); + + /* resize */ + gradInput.resize_as_(input); + gradInput.zero_(); + + /* backprop */ + AT_DISPATCH_FLOATING_TYPES( + input.type(), + "fractional_max_pool3d_backward_out_frame", + [&]{ + fractional_max_pool3d_backward_out_frame( + gradInput.data(), + gradOutput.data(), + indices.data(), + numBatch, numPlanes, + inputT, inputH, inputW, + outputT, outputH, outputW + ); + } + ); +} + +}// namespace + +std::tuple fractional_max_pool3d_out_cpu( + at::Tensor& output, + at::Tensor& indices, + const at::Tensor& input, + IntList pool_size, + IntList output_size, + const at::Tensor& randomSamples) { + fractional_max_pool3d_out_cpu_template( + output, + indices, + input, + pool_size, + output_size, + randomSamples); + return 
std::tuple(output, indices); +} + +std::tuple fractional_max_pool3d_cpu( + const at::Tensor& input, + IntList pool_size, + IntList output_size, + const at::Tensor& randomSamples) { + Tensor output = at::empty(output_size, input.options()); + Tensor indices = at::empty(output_size, at::kLong); + fractional_max_pool3d_out_cpu_template( + output, + indices, + input, + pool_size, + output_size, + randomSamples); + return std::tuple(output, indices); +} + +Tensor& fractional_max_pool3d_backward_out_cpu( + at::Tensor& gradInput, + const at::Tensor& gradOutput_, + const at::Tensor& input, + IntList pool_size, + IntList output_size, + const at::Tensor& indices) { + fractional_max_pool3d_backward_out_cpu_template( + input, + gradOutput_, + gradInput, + output_size, + pool_size, + indices); + return gradInput; +} + +Tensor fractional_max_pool3d_backward_cpu( + const at::Tensor& gradOutput_, + const at::Tensor& input, + IntList pool_size, + IntList output_size, + const at::Tensor& indices) { + Tensor gradInput = at::empty({0}, input.options()); + fractional_max_pool3d_backward_out_cpu_template( + input, + gradOutput_, + gradInput, + output_size, + pool_size, + indices); + return gradInput; +} + +}// native +}// at diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu new file mode 100644 index 0000000..04f7a70 --- /dev/null +++ b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu @@ -0,0 +1,415 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +using namespace at::cuda::detail; + +namespace { + +template +__device__ inline int64_t get_intervals( + accscalar_t sample, + int64_t index, + int64_t inputSize, + int64_t outputSize, + int64_t poolSize) { + accscalar_t alpha = static_cast(inputSize - poolSize) / + static_cast(outputSize - 1); + if (index == outputSize - 1) { + return inputSize - 
poolSize; + } else { + return static_cast((index + sample) * alpha) - \ + static_cast(sample * alpha); + } + } + +template +__global__ void fractional_max_pool3d_out_frame( + PackedTensorAccessor input, + PackedTensorAccessor output, + PackedTensorAccessor indices, + PackedTensorAccessor samples, + int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) { + using accscalar_t = at::acc_type; + // Output (t, h, w) point that this thread is responsible for + int64_t ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x; + int64_t plane = blockIdx.y; + int64_t batch = blockIdx.z; + // Each thread generates a specific output point + if (ourOutputPoint < output.size(2) * output.size(3) * + output.size(4)){ + int64_t outputT = ourOutputPoint / (output.size(3) * + output.size(4)); + int64_t outputH = (ourOutputPoint / output.size(4)) % + output.size(3); + int64_t outputW = ourOutputPoint % output.size(4); + + int64_t poolT = get_intervals( + static_cast(samples[batch][plane][0]), + outputT, input.size(2), output.size(2), poolSizeT); + int64_t poolH = get_intervals( + static_cast(samples[batch][plane][1]), + outputH, input.size(3), output.size(3), poolSizeH); + int64_t poolW = get_intervals( + static_cast(samples[batch][plane][2]), + outputW, input.size(4), output.size(4), poolSizeW); + + scalar_t maxVal = at::numeric_limits::lowest(); + int64_t maxIndex = -1; + + for(int64_t t = poolT; t < poolT + poolSizeT; ++ t) { + for (int64_t h = poolH; h < poolH + poolSizeH; ++h) { + if(poolSizeW < 2 || poolSizeW > 7) { + for (int64_t w = poolW; w < poolW + poolSizeW; ++w) { + scalar_t val = input[batch][plane][t][h][w]; + // for consistency with THNN, favor the first max + if (val > maxVal) { + maxIndex = t * input.size(3) * + input.size(4) + h * input.size(4) + w; + maxVal = val; + } + } + } else { + for (int64_t i = 0; i < poolSizeW; ++i) { + int64_t w = i + poolW; + scalar_t val = input[batch][plane][t][h][w]; + // for consistency with THNN, favor the first max + if (val > 
maxVal) { + maxIndex = t * input.size(3) * input.size(4) + + h * input.size(4) + w; + maxVal = val; + } + } + } + } + } + + assert(maxVal != at::numeric_limits::lowest()); + assert(maxIndex != -1); + + indices[batch][plane][outputT][outputH][outputW] = maxIndex; + output[batch][plane][outputT][outputH][outputW] = maxVal; + } + } + +template +__global__ void fractional_max_pool3d_backward_out_frame( + PackedTensorAccessor gradInput, + PackedTensorAccessor gradOutput, + PackedTensorAccessor indices) { + // Output (h, w) point that this thread is responsible for + int64_t ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x; + int64_t plane = blockIdx.y; + int64_t batch = blockIdx.z; + + // Each thread generates a specific output point + if (ourOutputPoint < gradOutput.size(2) * + gradOutput.size(3) * gradOutput.size(4)) { + int64_t outputW = ourOutputPoint % gradOutput.size(4); + int64_t outputH = (ourOutputPoint / gradOutput.size(4)) % + gradOutput.size(3); + int64_t outputT = ourOutputPoint / (gradOutput.size(3) * + gradOutput.size(4)); + + int64_t index = indices[batch][plane][outputT][outputH][outputW]; + assert(index >= 0); + int64_t inputW = index % gradInput.size(4); + int64_t inputH = (index / gradInput.size(4)) % + gradInput.size(3); + int64_t inputT = index / (gradInput.size(3) * + gradInput.size(4)); + assert(inputT < gradInput.size(2)); + + atomicAdd( + &gradInput[batch][plane][inputT][inputH][inputW], + gradOutput[batch][plane][outputT][outputH][outputW] + ); + } + } + +void fractional_max_pool3d_out_cuda_template( + Tensor& output, + Tensor& indices, + const Tensor& input, + IntList pool_size, + IntList output_size, + const Tensor& randomSamples) { + int64_t planeDim = 0; + int64_t dimt = 1; + int64_t dimh = 2; + int64_t dimw = 3; + int64_t numBatch = 1; + + int64_t outputT = output_size[0]; + int64_t outputH = output_size[1]; + int64_t outputW = output_size[2]; + int64_t poolSizeT = pool_size[0]; + int64_t poolSizeH = pool_size[1]; + int64_t 
poolSizeW = pool_size[2]; + + int64_t ndims = input.ndimension(); + AT_CHECK( + input.numel() != 0 && (ndims == 4 || ndims == 5), + "fractional_max_pool3d_out_cuda_template(): ", + "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ", + ndims); + + if (ndims == 5) { + numBatch = input.size(0); + planeDim++; + dimt++; + dimh++; + dimw++; + } + + /* sizes */ + int64_t numPlanes = input.size(planeDim); + int64_t inputT = input.size(dimt); + int64_t inputH = input.size(dimh); + int64_t inputW = input.size(dimw); + + AT_CHECK( + outputT + poolSizeT - 1 < inputT, + "fractional_max_pool3d_out_cuda_template(): ", + "pool time (", poolSizeT, ") too large relative to input time (", + inputT, ")"); + AT_CHECK( + outputH + poolSizeH - 1 < inputH, + "fractional_max_pool3d_out_cuda_template(): ", + "pool height (", poolSizeH, ") too large relative to input height (", + inputH, ")"); + AT_CHECK( + outputW + poolSizeW - 1 < inputW, + "fractional_max_pool3d_out_cuda_template(): ", + "pool width (", poolSizeW, ") too large relative to input width (", + inputW, ")"); + + if (ndims == 4) { + /* resize output */ + output.resize_({numPlanes, outputT, outputH, outputW}); + /* indices will contain the locations for each output point */ + indices.resize_({numPlanes, outputT, outputH, outputW}); + } else { + /* resize output */ + output.resize_({numBatch, numPlanes, outputT, outputH, outputW}); + /* indices will contain the locations for each output point */ + indices.resize_({numBatch, numPlanes, outputT, outputH, outputW}); + } + + auto output_ = output; + auto indices_ = indices; + auto input_ = input; + if(ndims == 4) { + output_ = output_.reshape({1, numPlanes, outputT, outputH, outputW}); + indices_ = indices_.reshape({1, numPlanes, outputT, outputH, outputW}); + input_ = input_.reshape({1, numPlanes, inputT, inputH, inputW}); + } + + // block is limited to 4 warps + // grid handles overflow per each plane + int64_t outputPlaneSize = output_.size(2) * + 
output_.size(3) * output_.size(4); + dim3 grid( + (outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128) + input_.size(1), + input_.size(0)); + dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.type(), + "fractional_max_pool3d_out_frame", + [&]{ + fractional_max_pool3d_out_frame + <<>>( + input_.packed_accessor(), + output_.packed_accessor(), + indices_.packed_accessor(), + randomSamples.packed_accessor(), + poolSizeT, poolSizeH, poolSizeW + ); + } + ); + AT_CHECK(cudaGetLastError() == cudaSuccess, + "fractional_max_pool2d_out_cuda_template failed with error code ", + cudaGetLastError()); + } + +void fractional_max_pool3d_backward_out_cuda_template( + Tensor& gradInput, + const Tensor& gradOutput, + const Tensor& input, + IntList pool_size /* unused */, + IntList output_size, + const Tensor& indices) { + int64_t dimt = 1; + int64_t dimh = 2; + int64_t dimw = 3; + + int64_t outputT = output_size[0]; + int64_t outputH = output_size[1]; + int64_t outputW = output_size[2]; + + int64_t ndims = input.ndimension(); + if (ndims == 5) { + dimt++; + dimh++; + dimw++; + } + + /* sizes */ + int64_t inputT = input.size(dimt); + int64_t inputH = input.size(dimh); + int64_t inputW = input.size(dimw); + + AT_CHECK( + outputT == gradOutput.size(dimt), + "fractional_max_pool3d_backward_out_cuda_template(): ", + "gradOutput time unexpected" + ); + AT_CHECK( + outputH == gradOutput.size(dimh), + "fractional_max_pool3d_backward_out_cuda_template(): ", + "gradOutput height unexpected" + ); + AT_CHECK( + outputW == gradOutput.size(dimw), + "fractional_max_pool3d_backward_out_cuda_template(): ", + "gradOutput width unexpected" + ); + + /* resize */ + gradInput.resize_as_(input); + gradInput.zero_(); + + auto gradInput_ = gradInput; + auto gradOutput_ = gradOutput; + auto indices_ = indices; + + if(ndims == 4) { + gradInput_ = gradInput_.reshape({1, gradInput.size(0), inputT, + inputH, inputW}); + gradOutput_ = 
gradOutput_.reshape({1, gradOutput.size(0), outputT, + outputH, outputW}); + indices_ = indices_.reshape({1, indices.size(0), outputT, outputH, + outputW}); + } + + /* backprop */ + // block is limited to 4 warps + // grid handles overflow per each plane + int64_t outputPlaneSize = gradOutput_.size(2) * + gradOutput_.size(3) * gradOutput_.size(4); + dim3 grid( + (outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128) + gradInput_.size(1), + gradInput_.size(0)); + dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + gradOutput.type(), + "fractional_max_pool3d_backward_out_frame", + [&] { + fractional_max_pool3d_backward_out_frame + <<>>( + gradInput_.packed_accessor(), + gradOutput_.packed_accessor(), + indices_.packed_accessor() + ); + } + ); + AT_CHECK(cudaGetLastError() == cudaSuccess, + "fractional_max_pool2d_out_cuda_template failed with error code ", + cudaGetLastError()); + } + +}// namespace + +std::tuple fractional_max_pool3d_out_cuda( + at::Tensor& output, + at::Tensor& indices, + const at::Tensor& input, + IntList pool_size, + IntList output_size, + const at::Tensor& randomSamples) { + fractional_max_pool3d_out_cuda_template( + output, + indices, + input, + pool_size, + output_size, + randomSamples + ); + return std::tuple(output, indices); + } + +std::tuple fractional_max_pool3d_cuda( + const at::Tensor& input, + IntList pool_size, + IntList output_size, + const at::Tensor& randomSamples) { + Tensor output = at::empty({0}, input.options()); + Tensor indices = at::empty({0}, input.options().dtype(kLong)); + fractional_max_pool3d_out_cuda_template( + output, + indices, + input, + pool_size, + output_size, + randomSamples + ); + return std::tuple(output, indices); + } + +Tensor& fractional_max_pool3d_backward_out_cuda( + at::Tensor& gradInput, + const at::Tensor& gradOutput_, + const at::Tensor& input, + IntList pool_size, + IntList output_size, + const at::Tensor& indices) { + 
fractional_max_pool3d_backward_out_cuda_template( + gradInput, + gradOutput_, + input, + pool_size, + output_size, + indices + ); + return gradInput; + } + +Tensor fractional_max_pool3d_backward_cuda( + const at::Tensor& gradOutput, + const at::Tensor& input, + IntList pool_size, + IntList output_size, + const at::Tensor& indices) { + Tensor gradInput = at::empty({0}, input.options()); + fractional_max_pool3d_backward_out_cuda_template( + gradInput, + gradOutput, + input, + pool_size, + output_size, + indices + ); + return gradInput; + } + +}// native +}// at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index fd6ed67..4f257ee 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3417,6 +3417,30 @@ CPU: fractional_max_pool2d_backward_cpu CUDA: fractional_max_pool2d_backward_cuda +- func: fractional_max_pool3d_out(Tensor output, Tensor indices, Tensor self, IntList[3] kernel_size, IntList[3] output_size, Tensor random_samples) -> (Tensor output, Tensor indices) + python_module: nn + dispatch: + CPU: fractional_max_pool3d_out_cpu + CUDA: fractional_max_pool3d_out_cuda + +- func: fractional_max_pool3d(Tensor self, IntList[3] kernel_size, IntList[3] output_size, Tensor random_samples) -> (Tensor output, Tensor indices) + python_module: nn + dispatch: + CPU: fractional_max_pool3d_cpu + CUDA: fractional_max_pool3d_cuda + +- func: fractional_max_pool3d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[3] kernel_size, IntList[3] output_size, Tensor indices) -> Tensor + python_module: nn + dispatch: + CPU: fractional_max_pool3d_backward_out_cpu + CUDA: fractional_max_pool3d_backward_out_cuda + +- func: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, IntList[3] kernel_size, IntList[3] output_size, Tensor indices) -> Tensor + python_module: nn + dispatch: + CPU: fractional_max_pool3d_backward_cpu + CUDA: 
fractional_max_pool3d_backward_cuda + - func: max_pool2d_with_indices_out(Tensor output, Tensor indices, Tensor self, IntList[2] kernel_size, IntList[2] stride={}, IntList[2] padding=0, IntList[2] dilation=1, bool ceil_mode=false) -> (Tensor output, Tensor indices) python_module: nn diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt index 39b1ec1..8a09ab1 100644 --- a/aten/src/THCUNN/CMakeLists.txt +++ b/aten/src/THCUNN/CMakeLists.txt @@ -58,7 +58,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricAveragePooling.cu ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricConvolution.cu ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricDilatedConvolution.cu ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricDilatedMaxPooling.cu -${CMAKE_CURRENT_SOURCE_DIR}/VolumetricFractionalMaxPooling.cu ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricFullConvolution.cu ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricFullDilatedConvolution.cu ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricMaxPooling.cu diff --git a/aten/src/THCUNN/VolumetricFractionalMaxPooling.cu b/aten/src/THCUNN/VolumetricFractionalMaxPooling.cu deleted file mode 100644 index 96c6e38..0000000 --- a/aten/src/THCUNN/VolumetricFractionalMaxPooling.cu +++ /dev/null @@ -1,120 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -template -__device__ inline int getInterval(Acctype sample, - int index, - int inputSize, - int outputSize, - int poolSize) { - Acctype alpha = (Acctype)(inputSize - poolSize) / (Acctype) (outputSize - 1); - if (index == outputSize - 1) { - return inputSize - poolSize; - } else { - return (int) ((index + sample) * alpha) - (int) (sample * alpha); - } -} - -// We template on poolSizeW to allow the innermost loop to be unrolled -template -__global__ void VolumetricFractionalMaxPooling_updateOutput( - THCDeviceTensor input, - THCDeviceTensor output, - THCDeviceTensor indices, - THCDeviceTensor samples, - int poolSizeT, int poolSizeW, int poolSizeH) { - - // Output (h, w) point that this thread is 
responsible for - int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x; - int plane = blockIdx.y; - int batch = blockIdx.z; - - // Each thread generates a specific output point - if (ourOutputPoint < output.getSize(2) * output.getSize(3) * output.getSize(4)){ - int outputT = ourOutputPoint % output.getSize(4); - int outputW = (ourOutputPoint / output.getSize(4)) % output.getSize(3); - int outputH = ourOutputPoint / (output.getSize(3)*output.getSize(4)); - - int poolT = getInterval(ScalarConvert::to(samples[batch][plane][0]), outputT, - input.getSize(4), output.getSize(4), poolSizeT); - int poolW = getInterval(ScalarConvert::to(samples[batch][plane][1]), outputW, - input.getSize(3), output.getSize(3), poolSizeW); - int poolH = getInterval(ScalarConvert::to(samples[batch][plane][2]), outputH, - input.getSize(2), output.getSize(2), poolSizeH); - - Dtype maxVal = THCNumerics::min(); - int maxIndex = -1; - - for (int h = poolH; h < poolH + poolSizeH; ++h) { - for (int w = poolW; w < poolW + poolSizeW; ++w) { - if (PoolSizeTStatic == -1) { - for (int t = poolT; t < poolT + poolSizeT; ++t) { - Dtype val = input[batch][plane][h][w][t]; - // for consistency with THNN, favor the first max - if (val > maxVal) { - maxIndex = h * input.getSize(3)*input.getSize(4) + w * input.getSize(4) + t; - maxVal = val; - } - } - } else { -#pragma unroll - for (int i = 0; i < PoolSizeTStatic; ++i) { - int t = i + poolT; - Dtype val = input[batch][plane][h][w][t]; - // for consistency with THNN, favor the first max - if (val > maxVal) { - maxIndex = h * input.getSize(3)*input.getSize(4) + w * input.getSize(4) + t; - maxVal = val; - } - } - } - } - } - - assert(THCNumerics::ne(maxVal, THCNumerics::min())); - assert(maxIndex != -1); - - // +1 for Lua index - indices[batch][plane][outputH][outputW][outputT] = maxIndex + TH_INDEX_BASE; - output[batch][plane][outputH][outputW][outputT] = maxVal; - } -} - -template -__global__ void VolumetricFractionalMaxPooling_updateGradInput( - 
THCDeviceTensor gradInput, - THCDeviceTensor gradOutput, - THCDeviceTensor indices) { - // Output (h, w) point that this thread is responsible for - int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x; - int plane = blockIdx.y; - int batch = blockIdx.z; - - // Each thread generates a specific output point - if (ourOutputPoint < gradOutput.getSize(2) * gradOutput.getSize(3) * gradOutput.getSize(4)) { - int outputT = ourOutputPoint % gradOutput.getSize(4); - int outputW = (ourOutputPoint / gradOutput.getSize(4)) % gradOutput.getSize(3); - int outputH = ourOutputPoint / (gradOutput.getSize(3)*gradOutput.getSize(4)); - - int index = indices[batch][plane][outputH][outputW][outputT] - TH_INDEX_BASE; - assert(index >= 0); - int inputT = index % gradInput.getSize(4); - int inputW = (index / gradInput.getSize(4)) % gradInput.getSize(3); - int inputH = index / (gradInput.getSize(3) * gradInput.getSize(4)); - assert(inputH < gradInput.getSize(2)); - - atomicAdd(gradInput[batch][plane][inputH][inputW][inputT].data(), - gradOutput[batch][plane][outputH][outputW][outputT]); - } -} - -#include -#include diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h index ecbd9d9..9ce4f85 100644 --- a/aten/src/THCUNN/generic/THCUNN.h +++ b/aten/src/THCUNN/generic/THCUNN.h @@ -1301,24 +1301,6 @@ THC_API void THNN_(VolumetricDilatedMaxPooling_updateGradInput)( int dilationT, int dilationW, int dilationH, bool ceilMode); -THC_API void THNN_(VolumetricFractionalMaxPooling_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - int outputT, int outputW, int outputH, - int poolSizeT, int poolSizeW, int poolSizeH, - THCIndexTensor *indices, - THCTensor *randomSamples); - -THC_API void THNN_(VolumetricFractionalMaxPooling_updateGradInput)( - THCState *state, - THCTensor *input, - THCTensor *gradOutput, - THCTensor *gradInput, - int outputT, int outputW, int outputH, - int poolSizeT, int poolSizeW, int poolSizeH, - THCIndexTensor *indices); - 
THC_API void THNN_(VolumetricFullConvolution_updateOutput)( THCState *state, THCTensor *input, diff --git a/aten/src/THCUNN/generic/VolumetricFractionalMaxPooling.cu b/aten/src/THCUNN/generic/VolumetricFractionalMaxPooling.cu deleted file mode 100644 index 36ae581..0000000 --- a/aten/src/THCUNN/generic/VolumetricFractionalMaxPooling.cu +++ /dev/null @@ -1,168 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THCUNN/generic/VolumetricFractionalMaxPooling.cu" -#else - -void THNN_(VolumetricFractionalMaxPooling_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - int outputT, int outputW, int outputH, - int poolSizeT, int poolSizeW, int poolSizeH, - THCIndexTensor *indices, - THCTensor *randomSamples) -{ - int planeDim = 0; - int dimh = 1; - int dimw = 2; - int dimt = 3; - int64_t numBatch = 1; - - int64_t numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input); - THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 4 || numInputDims == 5), 2, input, - "non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s"); - - if (numInputDims == 5) { - numBatch = THCTensor_(size)(state, input, 0); - planeDim++; - dimh++; - dimw++; - dimt++; - } - - /* sizes */ - int64_t numPlanes = THCTensor_(size)(state, input, planeDim); - int64_t inputH = THCTensor_(size)(state, input, dimh); - int64_t inputW = THCTensor_(size)(state, input, dimw); - int64_t inputT = THCTensor_(size)(state, input, dimt); - - THArgCheck(outputH + poolSizeH - 1 < inputH, 7, - "poolSizeH (%d) too large relative to input height (%d)", - poolSizeH, inputH); - THArgCheck(outputW + poolSizeW - 1 < inputW, 6, - "poolSizeW (%d) too large relative to input width (%d)", - poolSizeW, inputW); - THArgCheck(outputT + poolSizeT - 1 < inputW, 5, - "poolSizeT (%d) too large relative to input time (%d)", - poolSizeT, inputT); - - THCDeviceTensor devInput; - THCDeviceTensor devOutput; - THCDeviceTensor devIndices; - THCDeviceTensor devSamples = - 
toDeviceTensor(state, randomSamples); - - if (numInputDims == 4) { - /* resize output */ - THCTensor_(resize4d)(state, output, numPlanes, outputH, outputW, outputT); - /* indices will contain the locations for each output point */ - THCIndexTensor_(resize4d)(state, indices, numPlanes, outputH, outputW, outputT); - - devInput = toDeviceTensor(state, input).upcastOuter<5>(); - devOutput = toDeviceTensor(state, output).upcastOuter<5>(); - devIndices = toDeviceTensor(state, indices).upcastOuter<5>(); - } else { - THCTensor_(resize5d)(state, output, numBatch, numPlanes, outputH, outputW, outputT); - /* indices will contain the locations for each output point */ - THCIndexTensor_(resize5d)(state, indices, numBatch, numPlanes, outputH, outputW, outputT); - - devInput = toDeviceTensor(state, input); - devOutput = toDeviceTensor(state, output); - devIndices = toDeviceTensor(state, indices); - } - - // block is limited to 4 warps - // grid handles overflow per each plane - int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3) * devOutput.getSize(4); - dim3 grid(THCCeilDiv(outputPlaneSize, 128), - devInput.getSize(1), - devInput.getSize(0)); - dim3 block(outputPlaneSize > 128 ? 
128 : outputPlaneSize); - -#define SFMP_UPDATE_OUTPUT(POOL_W) \ - VolumetricFractionalMaxPooling_updateOutput \ - <<>>( \ - devInput, devOutput, devIndices, devSamples, poolSizeT, poolSizeW, poolSizeH); - -#define SFMP_UPDATE_OUTPUT_CASE(POOL_W) \ - case POOL_W: SFMP_UPDATE_OUTPUT(POOL_W); break - - switch (poolSizeW) { - SFMP_UPDATE_OUTPUT_CASE(2); - SFMP_UPDATE_OUTPUT_CASE(3); - SFMP_UPDATE_OUTPUT_CASE(4); - SFMP_UPDATE_OUTPUT_CASE(5); - SFMP_UPDATE_OUTPUT_CASE(6); - SFMP_UPDATE_OUTPUT_CASE(7); - default: - // dynamic pool width - SFMP_UPDATE_OUTPUT_CASE(-1); - } - THCudaCheck(cudaGetLastError()); -} - -void THNN_(VolumetricFractionalMaxPooling_updateGradInput)( - THCState *state, - THCTensor *input, - THCTensor *gradOutput, - THCTensor *gradInput, - int outputT, int outputW, int outputH, - int poolSizeT, int poolSizeW, int poolSizeH, - THCIndexTensor *indices) -{ - int dimh = 1; - int dimw = 2; - int dimt = 3; - - int64_t numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input); - if (numInputDims == 5) { - dimh++; - dimw++; - dimt++; - } - - /* sizes */ - int64_t inputH = THCTensor_(size)(state, input, dimh); - int64_t inputW = THCTensor_(size)(state, input, dimw); - int64_t inputT = THCTensor_(size)(state, input, dimt); - - THArgCheck(outputH == THCTensor_(size)(state, gradOutput, dimh), 3, - "gradOutput height unexpected"); - THArgCheck(outputW == THCTensor_(size)(state, gradOutput, dimw), 3, - "gradOutput width unexpected"); - THArgCheck(outputT == THCTensor_(size)(state, gradOutput, dimt), 3, - "gradOutput time unexpected"); - - /* resize */ - THCTensor_(resizeAs)(state, gradInput, input); - THCTensor_(zero)(state, gradInput); - - THCDeviceTensor devGradInput; - THCDeviceTensor devGradOutput; - THCDeviceTensor devIndices; - - /* backprop */ - if (numInputDims == 4) { - devGradInput = toDeviceTensor(state, gradInput).upcastOuter<5>(); - devGradOutput = toDeviceTensor(state, gradOutput).upcastOuter<5>(); - devIndices = toDeviceTensor(state, 
indices).upcastOuter<5>(); - } else { - devGradInput = toDeviceTensor(state, gradInput); - devGradOutput = toDeviceTensor(state, gradOutput); - devIndices = toDeviceTensor(state, indices); - } - - // block is limited to 4 warps - // grid handles overflow per each plane - int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3) * devGradOutput.getSize(4); - dim3 grid(THCCeilDiv(outputPlaneSize, 128), - devGradInput.getSize(1), - devGradInput.getSize(0)); - dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize); - - VolumetricFractionalMaxPooling_updateGradInput - <<>>( - devGradInput, devGradOutput, devIndices); - THCudaCheck(cudaGetLastError()); -} - -#endif diff --git a/aten/src/THNN/generic/VolumetricFractionalMaxPooling.c b/aten/src/THNN/generic/VolumetricFractionalMaxPooling.c deleted file mode 100644 index 0726eb2..0000000 --- a/aten/src/THNN/generic/VolumetricFractionalMaxPooling.c +++ /dev/null @@ -1,279 +0,0 @@ -#ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "THNN/generic/VolumetricFractionalMaxPooling.c" -#else - -static int64_t* THNN_(VolumetricFractionalMaxPooling_generateIntervals)( - scalar_t sample, - int64_t inputSize, - int64_t outputSize, - int poolSize) { - scalar_t alpha = (scalar_t) (inputSize - poolSize) / (scalar_t) (outputSize - 1); - int64_t* sequence = (int64_t*) THAlloc(sizeof(int64_t) * outputSize); - - int64_t i; - for (i = 0; i < outputSize - 1; ++i) { - sequence[i] = - (int64_t) ((i + sample) * alpha) - (int64_t) (sample * alpha); - } - sequence[outputSize - 1] = inputSize - poolSize; - - return sequence; -} - -static void THNN_(VolumetricFractionalMaxPooling_updateOutput_frame)( - scalar_t* input, - scalar_t* output, - THIndex_t* indices, - scalar_t* randomSamples, - int64_t numPlanes, - int64_t inputT, int64_t inputW, int64_t inputH, - int64_t outputT, int64_t outputW, int64_t outputH, - int poolSizeT, int poolSizeW, int poolSizeH) { - int64_t plane; -#pragma omp parallel for private(plane) - for (plane = 0; 
plane < numPlanes; ++plane) { - /* each plane contains 3 random samples, one for T, one for W, and one for H */ - scalar_t* randomSamplesForPlane = randomSamples + plane * 3; - - /* Generate interval sequence */ - int64_t* sequenceT = - THNN_(VolumetricFractionalMaxPooling_generateIntervals)( - randomSamplesForPlane[0], inputT, outputT, poolSizeT); - int64_t* sequenceW = - THNN_(VolumetricFractionalMaxPooling_generateIntervals)( - randomSamplesForPlane[1], inputW, outputW, poolSizeW); - int64_t* sequenceH = - THNN_(VolumetricFractionalMaxPooling_generateIntervals)( - randomSamplesForPlane[2], inputH, outputH, poolSizeH); - - /* loop over output */ - int64_t h, w, t; - - scalar_t* inputForPlane = input + plane * inputT * inputW * inputH; - scalar_t* outputForPlane = output + plane * outputT * outputW * outputH; - THIndex_t* indicesForPlane = indices + plane * outputT * outputW * outputH; - - for (h = 0; h < outputH; ++h) { - int64_t inputHStart = sequenceH[h]; - - for (w = 0; w < outputW; ++w) { - int64_t inputWStart = sequenceW[w]; - - for (t = 0; t < outputT; ++t) { - int64_t inputTStart = sequenceT[t]; - - scalar_t maxVal = -THInf; - int64_t maxIndex = -1; - - int64_t h2, w2, t2; - for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) { - for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) { - for (t2 = inputTStart; t2 < inputTStart + poolSizeT; ++t2) { - THAssert(h2 >= 0 && h2 < inputH); - THAssert(w2 >= 0 && w2 < inputW); - THAssert(t2 >= 0 && t2 < inputT); - - int64_t planeIndex = h2 * inputW * inputT + w2 * inputT + t2; - scalar_t val = inputForPlane[planeIndex]; - if (val > maxVal) { - maxVal = val; - maxIndex = planeIndex; - } - } - } - } - - THAssert(maxVal != -THInf); - THAssert(maxIndex != -1); - - outputForPlane[h * outputW * outputT + w * outputT + t] = maxVal; - /* +1 to lua index */ - indicesForPlane[h * outputW * outputT + w * outputT + t] = maxIndex + TH_INDEX_BASE; - } - } - } - - THFree(sequenceT); - THFree(sequenceW); - 
THFree(sequenceH); - } -} - -void THNN_(VolumetricFractionalMaxPooling_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, - int outputT, int outputW, int outputH, - int poolSizeT, int poolSizeW, int poolSizeH, - THIndexTensor *indices, - THTensor *randomSamples) { - - int64_t numBatch = 1; - int planeDim = 0; - int heightDim = 1; - int widthDim = 2; - int timeDim = 3; - - int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input); - THNN_ARGCHECK(!input->is_empty() && (numInputDims == 4 || numInputDims == 5), 2, input, - "non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s"); - - if (numInputDims == 5) { - numBatch = THTensor_(size)(input, 0); - planeDim++; - heightDim++; - widthDim++; - timeDim++; - } - - /* sizes */ - int64_t numPlanes = THTensor_(size)(input, planeDim); - int64_t inputH = THTensor_(size)(input, heightDim); - int64_t inputW = THTensor_(size)(input, widthDim); - int64_t inputT = THTensor_(size)(input, timeDim); - - THArgCheck(outputH + poolSizeH - 1 < inputH, 9, - "poolSizeH (%d) too large relative to input height (%d)", - poolSizeH, inputH); - THArgCheck(outputW + poolSizeW - 1 < inputW, 8, - "poolSizeW (%d) too large relative to input width (%d)", - poolSizeW, inputW); - THArgCheck(outputT + poolSizeT - 1 < inputT, 7, - "poolSizeT (%d) too large relative to input time (%d)", - poolSizeT, inputT); - - /* get contiguous input */ - input = THTensor_(newContiguous)(input); - - if (numInputDims == 4) { - /* resize output */ - THTensor_(resize4d)(output, numPlanes, outputH, outputW, outputT); - /* indices will contain the locations for each output point */ - THIndexTensor_(resize4d)(indices, numPlanes, outputH, outputW, outputT); - - THNN_(VolumetricFractionalMaxPooling_updateOutput_frame)( - input->data(), - output->data(), - THIndexTensor_(data)(indices), - randomSamples->data(), - numPlanes, inputT, inputW, inputH, - outputT, outputW, outputH, poolSizeT, poolSizeW, poolSizeH); - } else { - 
THTensor_(resize5d)(output, numBatch, numPlanes, outputH, outputW, outputT); - /* indices will contain the locations for each output point */ - THIndexTensor_(resize5d)(indices, numBatch, numPlanes, outputH, outputW, outputT); - - int64_t batch; -#pragma omp parallel for private(batch) - for (batch = 0; batch < numBatch; ++batch) { - THNN_(VolumetricFractionalMaxPooling_updateOutput_frame)( - input->data() + batch * numPlanes * inputH * inputW * inputT, - output->data() + batch * numPlanes * outputH * outputW * outputT, - THIndexTensor_(data)(indices) + batch * numPlanes * outputH * outputW * outputT, - randomSamples->data() + batch * numPlanes * 3, - numPlanes, inputT, inputW, inputH, - outputT, outputW, outputH, poolSizeT, poolSizeW, poolSizeH); - } - } - - /* cleanup */ - c10::raw::intrusive_ptr::decref(input); -} - -static void THNN_(VolumetricFractionalMaxPooling_updateGradInput_frame)( - scalar_t* gradInput, - scalar_t* gradOutput, - THIndex_t* indices, - int64_t numPlanes, - int64_t inputT, int64_t inputW, int64_t inputH, - int64_t outputT, int64_t outputW, int64_t outputH) { - int64_t plane; -#pragma omp parallel for private(plane) - for (plane = 0; plane < numPlanes; plane++) { - scalar_t* gradInputForPlane = gradInput + plane * inputT * inputW * inputH; - scalar_t* gradOutputForPlane = gradOutput + plane * outputT * outputW * outputH; - THIndex_t* indicesForPlane = indices + plane * outputT * outputW * outputH; - - int64_t h, w, t; - for (h = 0; h < outputH; ++h) { - for (w = 0; w < outputW; ++w) { - for (t = 0; t < outputT; ++t) { - int64_t outputIndex = h * outputW * outputT + w * outputT + t; - int64_t index = indicesForPlane[outputIndex] - TH_INDEX_BASE; - THAssert(index >= 0 && index < inputT * inputW * inputH); - - gradInputForPlane[index] += gradOutputForPlane[outputIndex]; - } - } - } - } -} - -void THNN_(VolumetricFractionalMaxPooling_updateGradInput)( - THNNState *state, - THTensor *input, - THTensor *gradOutput, - THTensor *gradInput, - int 
outputT, int outputW, int outputH, - int poolSizeT, int poolSizeW, int poolSizeH, - THIndexTensor *indices) { - - int64_t numBatch = 1; - int planeDim = 0; - int heightDim = 1; - int widthDim = 2; - int timeDim = 3; - - int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input); - if (numInputDims == 5) { - numBatch = THTensor_(size)(input, 0); - planeDim = 1; - heightDim++; - widthDim++; - timeDim++; - } - - /* sizes */ - int64_t numPlanes = THTensor_(size)(input, planeDim); - int64_t inputH = THTensor_(size)(input, heightDim); - int64_t inputW = THTensor_(size)(input, widthDim); - int64_t inputT = THTensor_(size)(input, timeDim); - - THArgCheck(outputT == THTensor_(size)(gradOutput, timeDim), 3, - "gradOutput time unexpected"); - THArgCheck(outputW == THTensor_(size)(gradOutput, widthDim), 3, - "gradOutput width unexpected"); - THArgCheck(outputH == THTensor_(size)(gradOutput, heightDim), 3, - "gradOutput height unexpected"); - - /* get contiguous gradOutput */ - gradOutput = THTensor_(newContiguous)(gradOutput); - - /* resize */ - THTensor_(resizeAs)(gradInput, input); - THTensor_(zero)(gradInput); - - /* backprop */ - if (numInputDims == 4) { - THNN_(VolumetricFractionalMaxPooling_updateGradInput_frame)( - gradInput->data(), - gradOutput->data(), - THIndexTensor_(data)(indices), - numPlanes, inputT, inputW, inputH, outputT, outputW, outputH); - } else { - int64_t batch; -#pragma omp parallel for private(batch) - for (batch = 0; batch < numBatch; ++batch) { - THNN_(VolumetricFractionalMaxPooling_updateGradInput_frame)( - gradInput->data() + batch * numPlanes * inputH * inputW * inputT, - gradOutput->data() + batch * numPlanes * outputH * outputW * outputT, - THIndexTensor_(data)(indices) + batch * numPlanes * outputH * outputW * outputT, - numPlanes, inputT, inputW, inputH, outputT, outputW, outputH); - } - } - - /* cleanup */ - c10::raw::intrusive_ptr::decref(gradOutput); -} - -#endif diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp index 
45b073a..50a30c7 100644 --- a/aten/src/THNN/init.cpp +++ b/aten/src/THNN/init.cpp @@ -193,9 +193,6 @@ #include #include -#include -#include - #include #include diff --git a/test/common_nn.py b/test/common_nn.py index 7372468..5196ce3 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -848,6 +848,28 @@ def fractional_max_pool2d_test(test_case): fullname='FractionalMaxPool2d_size') +def fractional_max_pool3d_test(test_case): + random_samples = torch.DoubleTensor(2, 4, 3).uniform_() + if test_case == 'ratio': + return dict( + constructor=lambda: nn.FractionalMaxPool3d( + 2, output_ratio=0.5, _random_samples=random_samples), + input_size=(2, 4, 5, 5, 5), + fullname='FractionalMaxPool3d_ratio') + elif test_case == 'size': + return dict( + constructor=lambda: nn.FractionalMaxPool3d((2, 2, 2), output_size=( + 4, 4, 4), _random_samples=random_samples), + input_size=(2, 4, 7, 7, 7), + fullname='FractionalMaxPool3d_size') + elif test_case == 'asymsize': + return dict( + constructor=lambda: nn.FractionalMaxPool3d((4, 2, 3), output_size=( + 10, 3, 2), _random_samples=random_samples), + input_size=(2, 4, 16, 7, 5), + fullname='FractionalMaxPool3d_asymsize') + + new_module_tests = [ poissonnllloss_no_reduce_test(), bceloss_no_reduce_test(), @@ -892,6 +914,9 @@ new_module_tests = [ multimarginloss_weights_no_reduce_test(), fractional_max_pool2d_test('ratio'), fractional_max_pool2d_test('size'), + fractional_max_pool3d_test('ratio'), + fractional_max_pool3d_test('size'), + fractional_max_pool3d_test('asymsize'), dict( module_name='BatchNorm1d', constructor_args=(10,), diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9ec8f9e..5eef168 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1069,6 +1069,9 @@ - name: fractional_max_pool2d(Tensor self, IntList kernel_size, IntList output_size, Tensor random_samples) self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, indices) +- name: 
fractional_max_pool3d(Tensor self, IntList kernel_size, IntList output_size, Tensor random_samples) + self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, indices) + - name: max_pool2d_with_indices(Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, indices) @@ -1164,6 +1167,10 @@ grad_output: max_pool_double_backward(grad, indices, 2) self: zeros_like(self) +- name: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList output_size, Tensor indices) + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + - name: glu_backward(Tensor grad_output, Tensor self, int64_t dim) grad_output: glu_double_backward_grad_output(grad, self, dim) self: glu_double_backward(grad, grad_output, self, dim) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 9f1494c..7c94060 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -369,6 +369,76 @@ fractional_max_pool2d = torch._jit_internal.boolean_dispatch( @weak_script +def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, + output_ratio=None, return_indices=False, + _random_samples=None): + # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa + r"""Applies 3D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + Args: + kernel_size: the size of the window to take a max over. 
+ Can be a single number :math:`k` (for a cubic kernel of :math:`k \times k \times k`) + or a tuple :math:`(kT \times kH \times kW)` + output_size: the target output size of the form :math:`oT \times oH \times oW`. + Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output + :math:`oH \times oH \times oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :func:`~torch.nn.functional.max_unpool3d`. + + Examples:: + >>> input = torch.randn(20, 16, 50, 32, 16) + >>> # pool of cubic window of size=3, and target output size 13x12x11 + >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11)) + >>> # pool of cubic window and target output size being half of input size + >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5)) + + .. _Fractional MaxPooling: + http://arxiv.org/abs/1412.6071 + """ + if output_size is None and output_ratio is None: + raise ValueError("fractional_max_pool3d requires specifying either " + "an output_size or an output_ratio") + if output_size is None: + _output_ratio = _triple(torch.jit._unwrap_optional(output_ratio)) + _output_size = [int(input.size(2) * _output_ratio[0]), + int(input.size(3) * _output_ratio[1]), + int(input.size(4) * _output_ratio[2])] + else: + _output_size = torch.jit._unwrap_optional(output_size) + + if _random_samples is None: + _random_samples = torch.rand(input.size(0), input.size(1), 3, dtype=input.dtype, device=input.device) + else: + _random_samples = torch.jit._unwrap_optional(_random_samples) + return torch._C._nn.fractional_max_pool3d(input, kernel_size, _output_size, _random_samples) + + +@weak_script +def _fractional_max_pool3d(input, kernel_size, output_size=None, + output_ratio=None, return_indices=False, + _random_samples=None): + # type: (Tensor, 
BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tensor # noqa + return fractional_max_pool3d_with_indices(input, kernel_size, output_size, + output_ratio, return_indices, + _random_samples)[0] + +fractional_max_pool3d = torch._jit_internal.boolean_dispatch( + arg_name='return_indices', + arg_index=4, + default=False, + if_true=fractional_max_pool3d_with_indices, + if_false=_fractional_max_pool3d) + + +@weak_script def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py index f268375..a14d6d9 100644 --- a/torch/nn/modules/__init__.py +++ b/torch/nn/modules/__init__.py @@ -11,8 +11,8 @@ from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLos SmoothL1Loss, SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, PoissonNLLLoss from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \ - MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, LPPool1d, LPPool2d, AdaptiveMaxPool1d, \ - AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d + MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \ + AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm @@ -38,7 +38,7 @@ __all__ = [ 
'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d', - 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', + 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', 'FractionalMaxPool3d', 'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout', diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index e92c5ea..6d6129d 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -705,6 +705,66 @@ class FractionalMaxPool2d(Module): @weak_module +class FractionalMaxPool3d(Module): + r"""Applies a 3D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number k (for a cubic kernel of k x k x k) or a tuple `(kt x kh x kw)` + output_size: the target output size of the image of the form `oT x oH x oW`. + Can be a tuple `(oT, oH, oW)` or a single number oH for a cubic output `oH x oH x oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. 
+ Useful to pass to :meth:`nn.MaxUnpool3d`. Default: ``False`` + + Examples: + >>> # pool of cubic window of size=3, and target output size 13x12x11 + >>> m = nn.FractionalMaxPool3d(3, output_size=(13, 12, 11)) + >>> # pool of cubic window and target output size being half of input size + >>> m = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)) + >>> input = torch.randn(20, 16, 50, 32, 16) + >>> output = m(input) + + .. _Fractional MaxPooling: + http://arxiv.org/abs/1412.6071 + """ + __constants__ = ['kernel_size', 'return_indices', 'output_size', + 'output_ratio'] + + def __init__(self, kernel_size, output_size=None, output_ratio=None, + return_indices=False, _random_samples=None): + super(FractionalMaxPool3d, self).__init__() + self.kernel_size = _triple(kernel_size) + self.return_indices = return_indices + self.register_buffer('_random_samples', _random_samples) + self.output_size = _triple(output_size) if output_size is not None else None + self.output_ratio = _triple(output_ratio) if output_ratio is not None else None + if output_size is None and output_ratio is None: + raise ValueError("FractionalMaxPool3d requires specifying either " + "an output size, or a pooling ratio") + if output_size is not None and output_ratio is not None: + raise ValueError("only one of output_size and output_ratio may be specified") + if self.output_ratio is not None: + if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1 and 0 < self.output_ratio[2] < 1): + raise ValueError("output_ratio must be between 0 and 1 (got {})" + .format(output_ratio)) + + @weak_script_method + def forward(self, input): + return F.fractional_max_pool3d( + input, self.kernel_size, self.output_size, self.output_ratio, + self.return_indices, + _random_samples=self._random_samples) + + +@weak_module class _LPPoolNd(Module): __constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode'] -- 2.7.4