From da87d648b3dae56635008ef8f86a9fa41518cbbc Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Tue, 17 Aug 2021 16:51:34 -0700
Subject: [PATCH] `F.avg_pool3d` CUDA backward: gpuAtomicAddNoReturn -> fastAtomicAdd (#63387)

Summary:
Rel: https://github.com/pytorch/pytorch/issues/62695

In the following two tables, `kernel_size` is set to 3 and `stride` to 2. In the benchmark, input tensors have the shape (N, C, n_features, n_features, n_features).
Tested on an RTX 3080 with CUDA 11.4 Update 1.

## This PR

| N | C | n_features | dtype | time |
|----:|----:|-------------:|:--------------|------------:|
| 32 | 3 | 8 | torch.float16 | 7.46846e-05 |
| 32 | 3 | 8 | torch.float32 | 8.18968e-05 |
| 32 | 3 | 32 | torch.float16 | 0.000156748 |
| 32 | 3 | 32 | torch.float32 | 0.000165236 |
| 32 | 3 | 128 | torch.float16 | 0.00549854 |
| 32 | 3 | 128 | torch.float32 | 0.008926 |

## master (6acd87f)

| N | C | n_features | dtype | time |
|----:|----:|-------------:|:--------------|------------:|
| 32 | 3 | 8 | torch.float16 | 7.60436e-05 |
| 32 | 3 | 8 | torch.float32 | 7.55072e-05 |
| 32 | 3 | 32 | torch.float16 | 0.000189292 |
| 32 | 3 | 32 | torch.float32 | 0.000168645 |
| 32 | 3 | 128 | torch.float16 | 0.00699538 |
| 32 | 3 | 128 | torch.float32 | 0.00890226 |

The ratio of master's time to this PR's time is as follows:

| N | C | n_features | master / PR |
|---:|---:|---------------:|----------------:|
| 32 | 3 | 8 | 1.018 |
| 32 | 3 | 32 | 1.208 |
| 32 | 3 | 128 | 1.272 |

cc: xwang233 ptrblck ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63387

Reviewed By: mruberry

Differential Revision: D30381434

Pulled By: ngimel

fbshipit-source-id: 3b97aee4b0d457a0277a0d31ac56d4151134c099
---
 aten/src/ATen/native/cuda/AveragePool3d.cu | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/cuda/AveragePool3d.cu b/aten/src/ATen/native/cuda/AveragePool3d.cu
index 671b354..6c712af 100644
--- a/aten/src/ATen/native/cuda/AveragePool3d.cu
+++ b/aten/src/ATen/native/cuda/AveragePool3d.cu
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include
 #include
 #include
@@ -210,7 +211,7 @@ __global__ void avg_pool3d_cuda_update_grad_input_atomic(
   int dT, int dH, int dW,
   int padT, int padH, int padW,
   bool count_include_pad,
-  int offsetZ, int divisor_override)
+  int offsetZ, int divisor_override, const int gradInput_numel)
 {
   int oCol = blockIdx.x * blockDim.x + threadIdx.x;
   int oRow = blockIdx.y * blockDim.y + threadIdx.y;
@@ -253,7 +254,8 @@ __global__ void avg_pool3d_cuda_update_grad_input_atomic(
       {
         for (int iCol = wstart; iCol < wend; ++iCol)
         {
-          gpuAtomicAddNoReturn(&gradInput[slice][iFrame][iRow][iCol], val);
+          const int index = slice * gradInput.stride(0) + iFrame * gradInput.stride(1) + iRow * gradInput.stride(2) + iCol * gradInput.stride(3);
+          fastAtomicAdd(gradInput.data(), index, gradInput_numel, val, true);
         }
       }
     }
@@ -568,7 +570,7 @@ TORCH_IMPL_FUNC(avg_pool3d_backward_out_cuda) (
           dT, dH, dW,
           padT, padH, padW,
           count_include_pad,
-          offsetZ, divisor);
+          offsetZ, divisor, work_grad_input.numel());
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }
     else {
-- 
2.7.4
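
For context, a minimal sketch of how timings like those in the summary might be collected, assuming CUDA-event timing of only the `F.avg_pool3d` backward pass. The shapes, `kernel_size=3`, and `stride=2` follow the summary; the timing helper, warmup count, and iteration count are assumptions rather than details taken from this PR.

```python
# Hypothetical benchmark sketch (not part of the patch): times only the
# avg_pool3d backward kernel using CUDA events. The input shape
# (N, C, n_features, n_features, n_features), kernel_size=3, and stride=2
# follow the summary; warmup/iteration counts are assumptions.
import torch
import torch.nn.functional as F

def time_avg_pool3d_backward(N, C, n_features, dtype, iters=100, warmup=10):
    x = torch.randn(N, C, n_features, n_features, n_features,
                    device="cuda", dtype=dtype, requires_grad=True)
    out = F.avg_pool3d(x, kernel_size=3, stride=2)
    grad_out = torch.randn_like(out)

    # Warm up so one-time kernel launch and allocator costs are excluded.
    for _ in range(warmup):
        torch.autograd.grad(out, x, grad_out, retain_graph=True)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        # Re-run only the backward pass; retain_graph keeps the graph alive.
        torch.autograd.grad(out, x, grad_out, retain_graph=True)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0 / iters  # seconds per backward

if __name__ == "__main__":
    for n_features in (8, 32, 128):
        for dtype in (torch.float16, torch.float32):
            t = time_avg_pool3d_backward(32, 3, n_features, dtype)
            print(f"N=32 C=3 n_features={n_features} {dtype}: {t:.6g}")
```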