Migrate uses of THCReduceApplyUtils to cuda_utils::BlockReduce (#64442)
authorPeter Bell <peterbell10@live.co.uk>
Wed, 8 Sep 2021 17:57:30 +0000 (10:57 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 18:02:12 +0000 (11:02 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64442

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30735341

Pulled By: ngimel

fbshipit-source-id: 3cb58bed8f1f5aa32fd49fd37b10c8490bcc645a

aten/src/ATen/native/cuda/Embedding.cu
aten/src/ATen/native/cuda/MultinomialKernel.cu
aten/src/ATen/native/cuda/TensorModeKernel.cu
aten/src/ATen/native/cuda/TensorModeKernel.cuh
aten/src/ATen/native/cuda/block_reduce.cuh
aten/src/THC/CMakeLists.txt
aten/src/THC/THCReduceApplyUtils.cu [deleted file]
aten/src/THC/THCReduceApplyUtils.cuh [deleted file]

index 155b389..354c106 100644 (file)
@@ -5,14 +5,11 @@
 #include <c10/util/Exception.h>
 #include <c10/macros/Macros.h>
 
-#include <THC/THCDeviceUtils.cuh>
-#include <THC/THCTensorMathReduce.cuh>
-#include <THC/THCReduceApplyUtils.cuh>
-
 #include <ATen/cuda/cub.cuh>
 
 #include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
+#include <ATen/native/cuda/block_reduce.cuh>
 
 namespace at { namespace native {
 
@@ -202,8 +199,7 @@ __global__ void renorm_kernel(
     }
   }
 
-  using Op = ReduceAdd<accscalar_t>;
-  v = reduceBlock<accscalar_t>(sdata, blockDim.x, v, Op(), 0);
+  v = cuda_utils::BlockReduceSum(v, sdata);
 
   if (tid == 0) {
     sdata[0] = std::pow(v, static_cast<accscalar_t>(1.0 / norm_type));
@@ -315,12 +311,15 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
     );
 
     dim3 grid = num_unique_indices.item<int64_t>();
-    dim3 block = 128;
+    constexpr int num_threads = 128;
     int dim = self.stride(0);
 
     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_renorm_cuda_", [&] {
       using accscalar_t = acc_type<scalar_t, true>;
-      renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
+      static_assert(num_threads % C10_WARP_SIZE == 0 &&
+                    num_threads <= cuda_utils::kCUDABlockReduceMaxThreads,
+                    "BlockReduceSum requires all warps be active");
+      renorm_kernel<<<grid, num_threads, (num_threads / C10_WARP_SIZE) * sizeof(accscalar_t), stream>>>(
         self.data_ptr<scalar_t>(),
         unique_indices.data_ptr<index_t>(),
         static_cast<accscalar_t>(max_norm),
index 65c45e7..6efc350 100644 (file)
 #include <ATen/cuda/CUDAGraphsUtils.cuh>
 #include <ATen/native/cuda/block_reduce.cuh>
 
-#include <THC/THCReduceApplyUtils.cuh>
-#include <THC/THCTensorMathReduce.cuh>
-#include <THC/THCNumerics.cuh>
-
 #include <curand.h>
 #include <curand_kernel.h>
 #include <curand_philox4x32_x.h>
@@ -22,6 +18,20 @@ namespace at { namespace native {
 
 namespace {
 
+int64_t div_up(int64_t a, int64_t b) {
+  return (a + (b - 1)) / b;
+}
+
+template <typename T>
+inline __device__ bool _isinf(T x) { return ::isinf(x); }
+
+inline __device__ bool _isinf(c10::Half x) {
+  return ::isinf(static_cast<float>(x));
+}
+inline __device__ bool _isinf(c10::BFloat16 x) {
+  return ::isinf(static_cast<float>(x));
+}
+
 #define MAX_NUM_BLOCKS 200
 
 // Normalizes the L1 norm of every row to 1; used by multinomial
@@ -36,13 +46,13 @@ __global__ void renormRowsL1(scalar_t* dist, long rows, long cols) {
     scalar_t sum = static_cast<scalar_t>(0);
     for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) {
       val = dist[row * cols + col];
-      CUDA_KERNEL_ASSERT(!THCNumerics<scalar_t>::lt(val, zero)); // ! < 0 for NaN handling
+      CUDA_KERNEL_ASSERT(!(val < zero)); // ! < 0 for NaN handling
       sum = sum + val;
     }
 
-    sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<scalar_t>(), zero);
+    sum = cuda_utils::BlockReduceSum(sum, smem);
     if (threadIdx.x == 0) {
-      CUDA_KERNEL_ASSERT(!THCNumerics<scalar_t>::lt(val, zero)); // ! < 0 for NaN handling
+      CUDA_KERNEL_ASSERT(!(val < zero)); // ! < 0 for NaN handling
       smem[0] = sum;
     }
     __syncthreads();
@@ -64,14 +74,15 @@ void renormRows(Tensor& t) {
   auto props = at::cuda::getCurrentDeviceProperties();
   CUDA_KERNEL_ASSERT(props != NULL);
   int numSM = props->multiProcessorCount;
-  int maxThreads = props->maxThreadsPerBlock;
+  const int64_t maxThreads = std::min(
+      props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads);
 
   dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
-  dim3 block(cols < maxThreads ? cols : maxThreads);
+  dim3 block(std::min(maxThreads, C10_WARP_SIZE * div_up(cols, C10_WARP_SIZE)));
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(), "renormRows_cuda", [&] {
     renormRowsL1<scalar_t>
-        <<<grid, block, block.x * sizeof(scalar_t),
+        <<<grid, block, (block.x / C10_WARP_SIZE) * sizeof(scalar_t),
         at::cuda::getCurrentCUDAStream()>>>(t.data_ptr<scalar_t>(),
             rows, cols);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -190,9 +201,9 @@ __global__ void sampleMultinomialOnce(
     scalar_t val;
     for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
       val = dist[curDist * stride_dist + cat * stride_categories];
-      CUDA_KERNEL_ASSERT(!THCNumerics<scalar_t>::isnan(val));
-      CUDA_KERNEL_ASSERT(!THCNumerics<scalar_t>::isinf(val));
-      CUDA_KERNEL_ASSERT(val >= zero);
+      CUDA_KERNEL_ASSERT(!at::_isnan(val));
+      CUDA_KERNEL_ASSERT(!_isinf(val));
+      CUDA_KERNEL_ASSERT(!(val < zero));
       sum = sum + static_cast<accscalar_t>(val);
     }
 
@@ -202,7 +213,7 @@ __global__ void sampleMultinomialOnce(
     // Broadcast sum and sample value
     if (threadIdx.x == 0) {
       // Make sure the sum of our distribution didn't overflow
-      CUDA_KERNEL_ASSERT(!THCNumerics<accscalar_t>::isinf(sum));
+      CUDA_KERNEL_ASSERT(!_isinf(val));
       CUDA_KERNEL_ASSERT(sum > accZero);
 
       foundPos = 0;
index c24d7fb..dbbcdc9 100644 (file)
@@ -138,11 +138,13 @@ void handle_fused_mode(
     cuda::detail::TensorInfo<int64_t, unsigned int>& ti_indices,
     int64_t slice_size,
     int64_t slices) {
-  const dim3 block(size / 2);
+  constexpr int num_threads = size / 2;
+  static_assert(num_threads % C10_WARP_SIZE == 0 &&
+                num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, "");
   const auto memsize =
       (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int));
   compute_mode<scalar_t, size>
-      <<<grid, block, memsize, at::cuda::getCurrentCUDAStream()>>>(
+      <<<grid, num_threads, memsize, at::cuda::getCurrentCUDAStream()>>>(
           self.data_ptr<scalar_t>(), ti_values, ti_indices, slice_size, slices);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -182,16 +184,19 @@ void fused_mode(
       break;
     case 128:
     case 64:
-      handle_fused_mode<128, scalar_t>(
-          grid, self, ti_values, ti_indices, slice_size, slices);
-      break;
     case 32:
     case 16:
     case 8:
     case 4:
-    case 2:
-      handle_fused_mode<32, scalar_t>(
-          grid, self, ti_values, ti_indices, slice_size, slices);
+    case 2: {
+      if (ceilPowerOf2 > 2 * C10_WARP_SIZE) {
+        handle_fused_mode<128, scalar_t>(
+            grid, self, ti_values, ti_indices, slice_size, slices);
+      } else {
+        handle_fused_mode<2 * C10_WARP_SIZE, scalar_t>(
+            grid, self, ti_values, ti_indices, slice_size, slices);
+      }
+    }
       break;
     case 1:
     default:
index f45b9b3..d29454b 100644 (file)
@@ -3,8 +3,7 @@
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
-
-#include <THC/THCReduceApplyUtils.cuh>
+#include <ATen/native/cuda/block_reduce.cuh>
 
 namespace at {
 namespace native {
@@ -88,11 +87,10 @@ __device__ T reduceBlockWithNThreadLocalReductions(
   for (int i = 1; i < N; ++i) {
     ++offset;
     T next = offset < numVals ? threadVals[i] : init;
-    local = reduceOp(local, next);
+    local = reduceOp.combine(local, next);
   }
 
-  return reduceBlock<T, ReduceOp>(
-      smem, blockDim.x < numVals ? blockDim.x : numVals, local, reduceOp, init);
+  return cuda_utils::BlockReduce(local, reduceOp, init, smem);
 }
 
 template <typename T>
@@ -354,13 +352,24 @@ __global__ void compute_mode(
 
   struct ModeUnsignedPair max = {0, 0};
 
+  struct MaxOp {
+    inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const {
+      return b.val > a.val ? b : a;
+    }
+
+    inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const {
+      ModeUnsignedPair ret;
+      ret.index = WARP_SHFL_DOWN(acc.index, offset);
+      ret.val = WARP_SHFL_DOWN(acc.val, offset);
+      return ret;
+    }
+  } max_op;
+
   max = reduceBlockWithNThreadLocalReductions<2>(
       uupmem,
       uup,
       sliceSize,
-      [=] GPU_LAMBDA(const auto& a, const auto& b) {
-        return b.val > a.val ? b : a;
-      },
+      max_op,
       max);
 
   // Store the mode in shared memory for use in finding the mode in the input
@@ -374,43 +383,43 @@ __global__ void compute_mode(
   }
   __syncthreads(); // broadcast mode
 
-  // Finally, we need to find the "an" index of the mode in the input Tensor.
-  // The API does not constrain which index we pick, so it can be any of the
-  // indices that contain the mode. We will do a reduction to find the index. We
-  // go back to using the (index, flag) buffer arrangement. First, we mark
-  // indices that are equal to the mode, i.e B[i] = true if input[i] == mode,
-  // and initialize C[i] to be the index
+  // Finally, we need to find "an" index of the mode in the input
+  // Tensor. The API does not constrain which index we pick, but here
+  // we always pick the largest index. We store the index if the value
+  // is the mode, or 0 otherwise. Then find the maximum value.
   //
   // Again we reduce 2 elements in the thread's registers prior to the
   // block-wide reduction
-  struct ModeUnsignedBoolPair ubpp[2];
+  unsigned mode_index[2] = {0u, 0u};
   if (tidx * 2 < sliceSize) {
-    ubpp[0].flag = input[linearOffset + (tidx * 2)] == mode;
-    ubpp[0].val = tidx * 2;
+    const unsigned idx = tidx * 2;
+    mode_index[0] = input[linearOffset + idx] == mode ? idx : 0u;
   }
   if (tidx * 2 + 1 < sliceSize) {
-    ubpp[1].flag = input[linearOffset + (tidx * 2 + 1)] == mode;
-    ubpp[1].val = tidx * 2 + 1;
+    const unsigned idx = tidx * 2 + 1;
+    mode_index[1] = input[linearOffset + idx] == mode ? idx : 0u;
   }
 
-  // Then we perform a similar reduction to the one above, except this time we
-  // update the element if the element at the base position is not equal to the
-  // mode and the element at the offset position is. At the end, C[0] will
-  // contain an index with the mode.
-  struct ModeUnsignedBoolPair match = {0, false};
+  struct MaxIndexOp {
+    inline __device__ unsigned combine(unsigned a, unsigned b) const {
+      return b > a ? b : a;
+    }
+
+    inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const {
+      return WARP_SHFL_DOWN(acc, offset);
+    }
+  } max_index_op;
 
-  match = reduceBlockWithNThreadLocalReductions<2>(
-      ubpmem,
-      ubpp,
+  int64_t index = reduceBlockWithNThreadLocalReductions<2>(
+      reinterpret_cast<unsigned*>(&shmem[0]),
+      mode_index,
       sliceSize,
-      [=] GPU_LAMBDA(const auto& a, const auto& b) { return b.flag ? b : a; },
-      match);
+      max_index_op,
+      0u);
 
   // Finally, we have the mode, and an index where it occurs. We use a single
   // thread to place this in the appropriate output position
   if (tidx == 0) {
-    int64_t index = match.val;
-
     unsigned int outputOffset =
         at::cuda::detail::IndexToOffset<T, unsigned int, -1>::get(
             blockId, values);
index a3f600e..095ce7a 100644 (file)
@@ -10,6 +10,11 @@ namespace native {
 namespace cuda_utils {
 
 constexpr int kCUDABlockReduceNumThreads = 512;
+// Algorithmic limitation: BlockReduce does two WarpReduce calls, each
+// of which reduces C10_WARP_SIZE elements. So, at most
+// C10_WARP_SIZE**2 elements can be reduced at a time.
+// NOTE: This is >= the max block size on current hardware anyway (1024).
+constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
 
 // Sums `val` accross all threads in a warp.
 //
@@ -41,7 +46,7 @@ __inline__ __device__ T BlockReduceSum(T val, T* shared) {
     shared[wid] = val;
   }
   __syncthreads();
-  val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : 0;
+  val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : T(0);
   if (wid == 0) {
     val = WarpReduceSum(val);
   }
index ab7f72b..ca666b4 100644 (file)
@@ -11,7 +11,6 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
   ${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THCTensor.cpp
 
-  ${CMAKE_CURRENT_SOURCE_DIR}/THCReduceApplyUtils.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/THCSleep.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/THCStorage.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cu
@@ -30,7 +29,6 @@ install(FILES
           THCTensor.h
           THCTensorCopy.h
           THCTensorCopy.hpp
-          THCReduceApplyUtils.cuh
           THCTensorMathReduce.cuh
           THCAsmUtils.cuh
           THCAtomics.cuh
diff --git a/aten/src/THC/THCReduceApplyUtils.cu b/aten/src/THC/THCReduceApplyUtils.cu
deleted file mode 100644 (file)
index d586f8c..0000000
+++ /dev/null
@@ -1,35 +0,0 @@
-#include <THC/THCReduceApplyUtils.cuh>
-
-#include <assert.h>
-#include <stdlib.h>
-
-// Maximum size per grid dimension that we assume (compute capability >= 2.0)
-#define MAX_GRID_SIZE 65535LL
-
-void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg) {
-  int64_t dims = THCudaTensor_nDimensionLegacyAll(state, tensor);
-  THArgCheck(dims <= MAX_CUTORCH_DIMS, arg, CUTORCH_DIM_WARNING);
-}
-
-bool THC_getGridFromTiles(ptrdiff_t gridTiles, dim3& grid) {
-  if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
-    return false;
-  }
-
-  int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
-  int64_t gridY = 1;
-  int64_t gridZ = 1;
-
-  if (gridTiles > MAX_GRID_SIZE) {
-    gridTiles = THCCeilDiv(gridTiles, (ptrdiff_t) MAX_GRID_SIZE);
-    gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
-
-    if (gridTiles > MAX_GRID_SIZE) {
-      gridTiles = THCCeilDiv(gridTiles, (ptrdiff_t) MAX_GRID_SIZE);
-      gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
-    }
-  }
-
-  grid = dim3(gridX, gridY, gridZ);
-  return true;
-}
diff --git a/aten/src/THC/THCReduceApplyUtils.cuh b/aten/src/THC/THCReduceApplyUtils.cuh
deleted file mode 100644 (file)
index cf0d923..0000000
+++ /dev/null
@@ -1,117 +0,0 @@
-#ifndef THC_REDUCE_APPLY_UTILS_INC
-#define THC_REDUCE_APPLY_UTILS_INC
-
-#include <cuda.h>
-#include <assert.h>
-#include <THC/THCGeneral.h>
-#include <THC/THCTensor.h>
-#include <THC/THCDeviceUtils.cuh>
-#include <THC/THCTensorInfo.cuh>
-
-// Enum that indicates whether tensor arguments are read/write or
-// read-only
-enum TensorArgType { ReadWrite, ReadOnly };
-
-// Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads:
-// (1, 2), (3, 4), (5, 6), (7, 8), then the return in threadVals for thread 0
-// is (1 + 3 + 5 + 7, 2 + 4 + 6 + 8) = (16, 20)
-//
-// If smem is not used again, there is no need to __syncthreads before this
-// call. However, if smem will be used, e.g., this function is called in a loop,
-// then __syncthreads is needed either before or afterwards to prevent non-0
-// threads overriding smem in the next loop before num-0 thread reads from it.
-template <typename T, typename ReduceOp, int N>
-__device__ void reduceNValuesInBlock(T *smem,
-                             T threadVals[N],
-                             const unsigned int numVals,
-                             ReduceOp reduceOp,
-                             T init) {
-  if (numVals == 0) {
-    #pragma unroll
-    for (int i = 0; i < N; ++i) {
-      threadVals[i] = init;
-    }
-    return;
-  }
-
-  // We store each of the N values contiguously, so if N = 2, all values for
-  // the first threadVal for each thread in the block are stored followed by
-  // all of the values for the second threadVal for each thread in the block
-  if (threadIdx.x < numVals) {
-    #pragma unroll
-    for (int i = 0; i < N; ++i) {
-      smem[i * numVals + threadIdx.x] = threadVals[i];
-    }
-  }
-  __syncthreads();
-
-  // Number of lanes in the final reduction --> this is used to determine
-  // where to put the outputs of each of the n things we are reducing. If
-  // nLP = 32, then we have the 32 outputs for the first threadVal,
-  // followed by the 32 outputs for the second threadVal, etc.
-  const unsigned int numLanesParticipating = min(numVals, warpSize);
-
-  if (numVals > warpSize && ((threadIdx.x / warpSize) == 0 )) {
-    #pragma unroll
-    for (int i = 0; i < N; ++i) {
-      threadVals[i] = threadIdx.x < numVals ? threadVals[i] : init;
-    }
-
-    for (int i = warpSize + threadIdx.x; i < numVals; i += warpSize) {
-      #pragma unroll
-      for (int j = 0; j < N; ++j) {
-        threadVals[j] = reduceOp(threadVals[j], smem[j * numVals + i]);
-      }
-    }
-
-    #pragma unroll
-    for (int i = 0; i < N; ++i) {
-      smem[i * numLanesParticipating + threadIdx.x] = threadVals[i];
-    }
-  }
-  __syncthreads();
-
-  if (threadIdx.x == 0) {
-    if (numLanesParticipating == 32) {
-      #pragma unroll
-      for (int i = 0; i < N; ++i) {
-        #pragma unroll
-        for (int j = 1; j < 32; ++j) {
-          threadVals[i] = reduceOp(threadVals[i], smem[i * 32 + j]);
-        }
-      }
-    } else {
-      #pragma unroll
-      for (int i = 0; i < N; ++i) {
-        for (int j = 1; j < numLanesParticipating; ++j) {
-          threadVals[i] = reduceOp(threadVals[i], smem[i * numVals + j]);
-        }
-      }
-    }
-  }
-}
-
-// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
-// return the reduced value
-//
-// If smem is not used again, there is no need to __syncthreads before this
-// call. However, if smem will be used, e.g., this function is called in a loop,
-// then __syncthreads is needed either before or afterwards to prevent non-0
-// threads overriding smem in the next loop before num-0 thread reads from it.
-template <typename T, typename ReduceOp>
-__device__ T reduceBlock(T* smem,
-                         const unsigned int numVals,
-                         T threadVal,
-                         ReduceOp reduceOp,
-                         T init) {
-  reduceNValuesInBlock<T, ReduceOp, 1>(smem, &threadVal, numVals, reduceOp, init);
-  return threadVal;
-}
-
-// Make sure the given tensor doesn't have too many dimensions
-void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg);
-
-// Produces a grid with at least one point per tile
-TORCH_CUDA_CU_API bool THC_getGridFromTiles(ptrdiff_t gridTiles, dim3& grid);
-
-#endif // THC_REDUCE_APPLY_UTILS_INC