From d0b207e68bc4e390cbd3dd64c8f116ba0a162d3e Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Wed, 8 Sep 2021 22:07:12 -0700 Subject: [PATCH] Migrate uses of THCReduceApplyUtils to cuda_utils::BlockReduce (#64713) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64713 Resubmit of #64442 Test Plan: Imported from OSS Reviewed By: VitalyFedyunin Differential Revision: D30825646 Pulled By: ngimel fbshipit-source-id: 66b06bd0b30b401833e337920681d19d96b11f9d --- aten/src/ATen/native/cuda/Embedding.cu | 16 ++-- aten/src/ATen/native/cuda/MultinomialKernel.cu | 39 ++++++--- aten/src/ATen/native/cuda/TensorModeKernel.cu | 21 +++-- aten/src/ATen/native/cuda/TensorModeKernel.cuh | 71 ++++++++------- aten/src/ATen/native/cuda/block_reduce.cuh | 7 +- aten/src/THC/CMakeLists.txt | 2 - aten/src/THC/THCReduceApplyUtils.cu | 35 -------- aten/src/THC/THCReduceApplyUtils.cuh | 117 ------------------------- 8 files changed, 92 insertions(+), 216 deletions(-) delete mode 100644 aten/src/THC/THCReduceApplyUtils.cu delete mode 100644 aten/src/THC/THCReduceApplyUtils.cuh diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 155b389..37747c0 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -5,14 +5,11 @@ #include #include -#include -#include -#include - #include #include #include +#include namespace at { namespace native { @@ -202,8 +199,7 @@ __global__ void renorm_kernel( } } - using Op = ReduceAdd; - v = reduceBlock(sdata, blockDim.x, v, Op(), 0); + v = cuda_utils::BlockReduceSum(v, sdata); if (tid == 0) { sdata[0] = std::pow(v, static_cast(1.0 / norm_type)); @@ -314,13 +310,17 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, num_indices ); + constexpr int num_threads = 128; + static_assert(num_threads % C10_WARP_SIZE == 0 && + num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, + "BlockReduceSum requires all warps be active"); dim3 grid = num_unique_indices.item(); - dim3 block = 128; + dim3 block = num_threads; int dim = self.stride(0); AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_renorm_cuda_", [&] { using accscalar_t = acc_type; - renorm_kernel<<>>( + renorm_kernel<<>>( self.data_ptr(), unique_indices.data_ptr(), static_cast(max_norm), diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 65c45e7..6efc350 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -10,10 +10,6 @@ #include #include -#include -#include -#include - #include #include #include @@ -22,6 +18,20 @@ namespace at { namespace native { namespace { +int64_t div_up(int64_t a, int64_t b) { + return (a + (b - 1)) / b; +} + +template +inline __device__ bool _isinf(T x) { return ::isinf(x); } + +inline __device__ bool _isinf(c10::Half x) { + return ::isinf(static_cast(x)); +} +inline __device__ bool _isinf(c10::BFloat16 x) { + return ::isinf(static_cast(x)); +} + #define MAX_NUM_BLOCKS 200 // Normalizes the L1 norm of every row to 1; used by multinomial @@ -36,13 +46,13 @@ __global__ void renormRowsL1(scalar_t* dist, long rows, long cols) { scalar_t sum = static_cast(0); for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) { val = dist[row * cols + col]; - CUDA_KERNEL_ASSERT(!THCNumerics::lt(val, zero)); // ! < 0 for NaN handling + CUDA_KERNEL_ASSERT(!(val < zero)); // ! 
< 0 for NaN handling sum = sum + val; } - sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd(), zero); + sum = cuda_utils::BlockReduceSum(sum, smem); if (threadIdx.x == 0) { - CUDA_KERNEL_ASSERT(!THCNumerics::lt(val, zero)); // ! < 0 for NaN handling + CUDA_KERNEL_ASSERT(!(val < zero)); // ! < 0 for NaN handling smem[0] = sum; } __syncthreads(); @@ -64,14 +74,15 @@ void renormRows(Tensor& t) { auto props = at::cuda::getCurrentDeviceProperties(); CUDA_KERNEL_ASSERT(props != NULL); int numSM = props->multiProcessorCount; - int maxThreads = props->maxThreadsPerBlock; + const int64_t maxThreads = std::min( + props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads); dim3 grid(rows < numSM * 4 ? rows : numSM * 4); - dim3 block(cols < maxThreads ? cols : maxThreads); + dim3 block(std::min(maxThreads, C10_WARP_SIZE * div_up(cols, C10_WARP_SIZE))); AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(), "renormRows_cuda", [&] { renormRowsL1 - <<>>(t.data_ptr(), rows, cols); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -190,9 +201,9 @@ __global__ void sampleMultinomialOnce( scalar_t val; for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) { val = dist[curDist * stride_dist + cat * stride_categories]; - CUDA_KERNEL_ASSERT(!THCNumerics::isnan(val)); - CUDA_KERNEL_ASSERT(!THCNumerics::isinf(val)); - CUDA_KERNEL_ASSERT(val >= zero); + CUDA_KERNEL_ASSERT(!at::_isnan(val)); + CUDA_KERNEL_ASSERT(!_isinf(val)); + CUDA_KERNEL_ASSERT(!(val < zero)); sum = sum + static_cast(val); } @@ -202,7 +213,7 @@ __global__ void sampleMultinomialOnce( // Broadcast sum and sample value if (threadIdx.x == 0) { // Make sure the sum of our distribution didn't overflow - CUDA_KERNEL_ASSERT(!THCNumerics::isinf(sum)); + CUDA_KERNEL_ASSERT(!_isinf(val)); CUDA_KERNEL_ASSERT(sum > accZero); foundPos = 0; diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cu b/aten/src/ATen/native/cuda/TensorModeKernel.cu index c24d7fb..dbbcdc9 100644 --- a/aten/src/ATen/native/cuda/TensorModeKernel.cu +++ b/aten/src/ATen/native/cuda/TensorModeKernel.cu @@ -138,11 +138,13 @@ void handle_fused_mode( cuda::detail::TensorInfo& ti_indices, int64_t slice_size, int64_t slices) { - const dim3 block(size / 2); + constexpr int num_threads = size / 2; + static_assert(num_threads % C10_WARP_SIZE == 0 && + num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, ""); const auto memsize = (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int)); compute_mode - <<>>( + <<>>( self.data_ptr(), ti_values, ti_indices, slice_size, slices); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -182,16 +184,19 @@ void fused_mode( break; case 128: case 64: - handle_fused_mode<128, scalar_t>( - grid, self, ti_values, ti_indices, slice_size, slices); - break; case 32: case 16: case 8: case 4: - case 2: - handle_fused_mode<32, scalar_t>( - grid, self, ti_values, ti_indices, slice_size, slices); + case 2: { + if (ceilPowerOf2 > 2 * C10_WARP_SIZE) { + handle_fused_mode<128, scalar_t>( + grid, self, ti_values, ti_indices, slice_size, slices); + } else { + handle_fused_mode<2 * C10_WARP_SIZE, scalar_t>( + grid, self, ti_values, ti_indices, slice_size, slices); + } + } break; case 1: default: diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cuh b/aten/src/ATen/native/cuda/TensorModeKernel.cuh index f45b9b3..d29454b 100644 --- a/aten/src/ATen/native/cuda/TensorModeKernel.cuh +++ b/aten/src/ATen/native/cuda/TensorModeKernel.cuh @@ -3,8 +3,7 @@ #include #include #include - -#include +#include namespace at { namespace native { @@ -88,11 +87,10 @@ __device__ T 
reduceBlockWithNThreadLocalReductions( for (int i = 1; i < N; ++i) { ++offset; T next = offset < numVals ? threadVals[i] : init; - local = reduceOp(local, next); + local = reduceOp.combine(local, next); } - return reduceBlock( - smem, blockDim.x < numVals ? blockDim.x : numVals, local, reduceOp, init); + return cuda_utils::BlockReduce(local, reduceOp, init, smem); } template @@ -354,13 +352,24 @@ __global__ void compute_mode( struct ModeUnsignedPair max = {0, 0}; + struct MaxOp { + inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const { + return b.val > a.val ? b : a; + } + + inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const { + ModeUnsignedPair ret; + ret.index = WARP_SHFL_DOWN(acc.index, offset); + ret.val = WARP_SHFL_DOWN(acc.val, offset); + return ret; + } + } max_op; + max = reduceBlockWithNThreadLocalReductions<2>( uupmem, uup, sliceSize, - [=] GPU_LAMBDA(const auto& a, const auto& b) { - return b.val > a.val ? b : a; - }, + max_op, max); // Store the mode in shared memory for use in finding the mode in the input @@ -374,43 +383,43 @@ __global__ void compute_mode( } __syncthreads(); // broadcast mode - // Finally, we need to find the "an" index of the mode in the input Tensor. - // The API does not constrain which index we pick, so it can be any of the - // indices that contain the mode. We will do a reduction to find the index. We - // go back to using the (index, flag) buffer arrangement. First, we mark - // indices that are equal to the mode, i.e B[i] = true if input[i] == mode, - // and initialize C[i] to be the index + // Finally, we need to find "an" index of the mode in the input + // Tensor. The API does not constrain which index we pick, but here + // we always pick the largest index. We store the index if the value + // is the mode, or 0 otherwise. Then find the maximum value. // // Again we reduce 2 elements in the thread's registers prior to the // block-wide reduction - struct ModeUnsignedBoolPair ubpp[2]; + unsigned mode_index[2] = {0u, 0u}; if (tidx * 2 < sliceSize) { - ubpp[0].flag = input[linearOffset + (tidx * 2)] == mode; - ubpp[0].val = tidx * 2; + const unsigned idx = tidx * 2; + mode_index[0] = input[linearOffset + idx] == mode ? idx : 0u; } if (tidx * 2 + 1 < sliceSize) { - ubpp[1].flag = input[linearOffset + (tidx * 2 + 1)] == mode; - ubpp[1].val = tidx * 2 + 1; + const unsigned idx = tidx * 2 + 1; + mode_index[1] = input[linearOffset + idx] == mode ? idx : 0u; } - // Then we perform a similar reduction to the one above, except this time we - // update the element if the element at the base position is not equal to the - // mode and the element at the offset position is. At the end, C[0] will - // contain an index with the mode. - struct ModeUnsignedBoolPair match = {0, false}; + struct MaxIndexOp { + inline __device__ unsigned combine(unsigned a, unsigned b) const { + return b > a ? b : a; + } + + inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } + } max_index_op; - match = reduceBlockWithNThreadLocalReductions<2>( - ubpmem, - ubpp, + int64_t index = reduceBlockWithNThreadLocalReductions<2>( + reinterpret_cast(&shmem[0]), + mode_index, sliceSize, - [=] GPU_LAMBDA(const auto& a, const auto& b) { return b.flag ? b : a; }, - match); + max_index_op, + 0u); // Finally, we have the mode, and an index where it occurs. 
We use a single // thread to place this in the appropriate output position if (tidx == 0) { - int64_t index = match.val; - unsigned int outputOffset = at::cuda::detail::IndexToOffset::get( blockId, values); diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index a3f600e4..095ce7a 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -10,6 +10,11 @@ namespace native { namespace cuda_utils { constexpr int kCUDABlockReduceNumThreads = 512; +// Algorithmic limitation: BlockReduce does two WarpReduce calls, each +// of which reduces C10_WARP_SIZE elements. So, at most +// C10_WARP_SIZE**2 elements can be reduced at a time. +// NOTE: This is >= the max block size on current hardware anyway (1024). +constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; // Sums `val` accross all threads in a warp. // @@ -41,7 +46,7 @@ __inline__ __device__ T BlockReduceSum(T val, T* shared) { shared[wid] = val; } __syncthreads(); - val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : 0; + val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : T(0); if (wid == 0) { val = WarpReduceSum(val); } diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index ab7f72b..ca666b4 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -11,7 +11,6 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THCTensor.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/THCReduceApplyUtils.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCSleep.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCStorage.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cu @@ -30,7 +29,6 @@ install(FILES THCTensor.h THCTensorCopy.h THCTensorCopy.hpp - THCReduceApplyUtils.cuh THCTensorMathReduce.cuh THCAsmUtils.cuh THCAtomics.cuh diff --git a/aten/src/THC/THCReduceApplyUtils.cu b/aten/src/THC/THCReduceApplyUtils.cu deleted file mode 100644 index d586f8c..0000000 --- a/aten/src/THC/THCReduceApplyUtils.cu +++ /dev/null @@ -1,35 +0,0 @@ -#include - -#include -#include - -// Maximum size per grid dimension that we assume (compute capability >= 2.0) -#define MAX_GRID_SIZE 65535LL - -void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg) { - int64_t dims = THCudaTensor_nDimensionLegacyAll(state, tensor); - THArgCheck(dims <= MAX_CUTORCH_DIMS, arg, CUTORCH_DIM_WARNING); -} - -bool THC_getGridFromTiles(ptrdiff_t gridTiles, dim3& grid) { - if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { - return false; - } - - int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; - int64_t gridY = 1; - int64_t gridZ = 1; - - if (gridTiles > MAX_GRID_SIZE) { - gridTiles = THCCeilDiv(gridTiles, (ptrdiff_t) MAX_GRID_SIZE); - gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; - - if (gridTiles > MAX_GRID_SIZE) { - gridTiles = THCCeilDiv(gridTiles, (ptrdiff_t) MAX_GRID_SIZE); - gridZ = gridTiles > MAX_GRID_SIZE ? 
MAX_GRID_SIZE : gridTiles; - } - } - - grid = dim3(gridX, gridY, gridZ); - return true; -} diff --git a/aten/src/THC/THCReduceApplyUtils.cuh b/aten/src/THC/THCReduceApplyUtils.cuh deleted file mode 100644 index cf0d923..0000000 --- a/aten/src/THC/THCReduceApplyUtils.cuh +++ /dev/null @@ -1,117 +0,0 @@ -#ifndef THC_REDUCE_APPLY_UTILS_INC -#define THC_REDUCE_APPLY_UTILS_INC - -#include -#include -#include -#include -#include -#include - -// Enum that indicates whether tensor arguments are read/write or -// read-only -enum TensorArgType { ReadWrite, ReadOnly }; - -// Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads: -// (1, 2), (3, 4), (5, 6), (7, 8), then the return in threadVals for thread 0 -// is (1 + 3 + 5 + 7, 2 + 4 + 6 + 8) = (16, 20) -// -// If smem is not used again, there is no need to __syncthreads before this -// call. However, if smem will be used, e.g., this function is called in a loop, -// then __syncthreads is needed either before or afterwards to prevent non-0 -// threads overriding smem in the next loop before num-0 thread reads from it. -template -__device__ void reduceNValuesInBlock(T *smem, - T threadVals[N], - const unsigned int numVals, - ReduceOp reduceOp, - T init) { - if (numVals == 0) { - #pragma unroll - for (int i = 0; i < N; ++i) { - threadVals[i] = init; - } - return; - } - - // We store each of the N values contiguously, so if N = 2, all values for - // the first threadVal for each thread in the block are stored followed by - // all of the values for the second threadVal for each thread in the block - if (threadIdx.x < numVals) { - #pragma unroll - for (int i = 0; i < N; ++i) { - smem[i * numVals + threadIdx.x] = threadVals[i]; - } - } - __syncthreads(); - - // Number of lanes in the final reduction --> this is used to determine - // where to put the outputs of each of the n things we are reducing. If - // nLP = 32, then we have the 32 outputs for the first threadVal, - // followed by the 32 outputs for the second threadVal, etc. - const unsigned int numLanesParticipating = min(numVals, warpSize); - - if (numVals > warpSize && ((threadIdx.x / warpSize) == 0 )) { - #pragma unroll - for (int i = 0; i < N; ++i) { - threadVals[i] = threadIdx.x < numVals ? threadVals[i] : init; - } - - for (int i = warpSize + threadIdx.x; i < numVals; i += warpSize) { - #pragma unroll - for (int j = 0; j < N; ++j) { - threadVals[j] = reduceOp(threadVals[j], smem[j * numVals + i]); - } - } - - #pragma unroll - for (int i = 0; i < N; ++i) { - smem[i * numLanesParticipating + threadIdx.x] = threadVals[i]; - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - if (numLanesParticipating == 32) { - #pragma unroll - for (int i = 0; i < N; ++i) { - #pragma unroll - for (int j = 1; j < 32; ++j) { - threadVals[i] = reduceOp(threadVals[i], smem[i * 32 + j]); - } - } - } else { - #pragma unroll - for (int i = 0; i < N; ++i) { - for (int j = 1; j < numLanesParticipating; ++j) { - threadVals[i] = reduceOp(threadVals[i], smem[i * numVals + j]); - } - } - } - } -} - -// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will -// return the reduced value -// -// If smem is not used again, there is no need to __syncthreads before this -// call. However, if smem will be used, e.g., this function is called in a loop, -// then __syncthreads is needed either before or afterwards to prevent non-0 -// threads overriding smem in the next loop before num-0 thread reads from it. 
-template -__device__ T reduceBlock(T* smem, - const unsigned int numVals, - T threadVal, - ReduceOp reduceOp, - T init) { - reduceNValuesInBlock(smem, &threadVal, numVals, reduceOp, init); - return threadVal; -} - -// Make sure the given tensor doesn't have too many dimensions -void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg); - -// Produces a grid with at least one point per tile -TORCH_CUDA_CU_API bool THC_getGridFromTiles(ptrdiff_t gridTiles, dim3& grid); - -#endif // THC_REDUCE_APPLY_UTILS_INC -- 2.7.4
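For reference, here is a minimal standalone sketch of the two-level warp-shuffle reduction that cuda_utils::BlockReduceSum is built on, and of why the new kCUDABlockReduceMaxThreads limit is C10_WARP_SIZE * C10_WARP_SIZE: the first pass reduces each warp with shuffle-down and leaves one partial result per warp in shared memory, and the second pass reduces those partials with a single warp, so at most C10_WARP_SIZE partials (and hence C10_WARP_SIZE * C10_WARP_SIZE threads) can take part. The helper names below (warp_reduce_sum, block_reduce_sum, sum_kernel) and the launch parameters are illustrative only and assume a 32-lane warp; the real helpers live in aten/src/ATen/native/cuda/block_reduce.cuh.

// Standalone sketch of a shuffle-based block reduction (not the ATen code).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;  // assumption: NVIDIA-style 32-lane warps

// Pass 1: reduce `val` across one warp; lane 0 ends up with the warp's sum.
__device__ float warp_reduce_sum(float val) {
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

// Block-wide sum; `shared` needs one float per warp in the block.
// Assumes blockDim.x is a multiple of kWarpSize, the same condition the
// static_asserts added in Embedding.cu and TensorModeKernel.cu enforce
// before calling the real BlockReduceSum.
__device__ float block_reduce_sum(float val, float* shared) {
  const int lane = threadIdx.x % kWarpSize;
  const int wid = threadIdx.x / kWarpSize;
  val = warp_reduce_sum(val);   // per-warp partial sums
  if (lane == 0) {
    shared[wid] = val;          // one partial per warp
  }
  __syncthreads();
  // Pass 2: warp 0 reduces the per-warp partials; lanes beyond the number
  // of warps contribute the identity element (0 for a sum).
  val = (threadIdx.x < blockDim.x / kWarpSize) ? shared[lane] : 0.f;
  if (wid == 0) {
    val = warp_reduce_sum(val);
  }
  return val;                   // thread 0 holds the block-wide sum
}

__global__ void sum_kernel(const float* in, float* out, int n) {
  __shared__ float shared[kWarpSize];  // upper bound: one slot per warp
  float v = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    v += in[i];
  }
  v = block_reduce_sum(v, shared);
  if (threadIdx.x == 0) {
    *out = v;
  }
}

int main() {
  constexpr int n = 1 << 16;
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.f;
  sum_kernel<<<1, 256>>>(in, out, n);  // 256 threads: a multiple of kWarpSize
  cudaDeviceSynchronize();
  printf("sum = %f (expected %d)\n", *out, n);
  cudaFree(in);
  cudaFree(out);
  return 0;
}

The generic cuda_utils::BlockReduce used by TensorModeKernel.cuh follows the same two-pass structure but takes the reduction op as an argument instead of hard-coding addition, which is why this patch replaces the GPU lambdas with small functor structs (MaxOp, MaxIndexOp) that provide both combine() and a warp_shfl_down() for their accumulator types.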