Revert D30735341: Migrate uses of THCReduceApplyUtils to cuda_utils::BlockReduce
authorNatalia Gimelshein <ngimel@fb.com>
Wed, 8 Sep 2021 21:25:42 +0000 (14:25 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 21:27:40 +0000 (14:27 -0700)
Test Plan: revert-hammer

Differential Revision:
D30735341 (https://github.com/pytorch/pytorch/commit/a5ad08ec704a3f765814eacf5c393e871c0174e1)

Original commit changeset: 3cb58bed8f1f

fbshipit-source-id: 874dd0f93b24a99694db42a15714834069d402bc

aten/src/ATen/native/cuda/Embedding.cu
aten/src/ATen/native/cuda/MultinomialKernel.cu
aten/src/ATen/native/cuda/TensorModeKernel.cu
aten/src/ATen/native/cuda/TensorModeKernel.cuh
aten/src/ATen/native/cuda/block_reduce.cuh
aten/src/THC/CMakeLists.txt
aten/src/THC/THCReduceApplyUtils.cu [new file with mode: 0644]
aten/src/THC/THCReduceApplyUtils.cuh [new file with mode: 0644]

index 354c106..155b389 100644 (file)
@@ -5,11 +5,14 @@
 #include <c10/util/Exception.h>
 #include <c10/macros/Macros.h>
 
+#include <THC/THCDeviceUtils.cuh>
+#include <THC/THCTensorMathReduce.cuh>
+#include <THC/THCReduceApplyUtils.cuh>
+
 #include <ATen/cuda/cub.cuh>
 
 #include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
-#include <ATen/native/cuda/block_reduce.cuh>
 
 namespace at { namespace native {
 
@@ -199,7 +202,8 @@ __global__ void renorm_kernel(
     }
   }
 
-  v = cuda_utils::BlockReduceSum(v, sdata);
+  using Op = ReduceAdd<accscalar_t>;
+  v = reduceBlock<accscalar_t>(sdata, blockDim.x, v, Op(), 0);
 
   if (tid == 0) {
     sdata[0] = std::pow(v, static_cast<accscalar_t>(1.0 / norm_type));
@@ -311,15 +315,12 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
     );
 
     dim3 grid = num_unique_indices.item<int64_t>();
-    constexpr int num_threads = 128;
+    dim3 block = 128;
     int dim = self.stride(0);
 
     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_renorm_cuda_", [&] {
       using accscalar_t = acc_type<scalar_t, true>;
-      static_assert(num_threads % C10_WARP_SIZE == 0 &&
-                    num_threads <= cuda_utils::kCUDABlockReduceMaxThreads,
-                    "BlockReduceSum requires all warps be active");
-      renorm_kernel<<<grid, num_threads, (num_threads / C10_WARP_SIZE) * sizeof(accscalar_t), stream>>>(
+      renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
         self.data_ptr<scalar_t>(),
         unique_indices.data_ptr<index_t>(),
         static_cast<accscalar_t>(max_norm),
index 6efc350..65c45e7 100644 (file)
 #include <ATen/cuda/CUDAGraphsUtils.cuh>
 #include <ATen/native/cuda/block_reduce.cuh>
 
+#include <THC/THCReduceApplyUtils.cuh>
+#include <THC/THCTensorMathReduce.cuh>
+#include <THC/THCNumerics.cuh>
+
 #include <curand.h>
 #include <curand_kernel.h>
 #include <curand_philox4x32_x.h>
@@ -18,20 +22,6 @@ namespace at { namespace native {
 
 namespace {
 
-int64_t div_up(int64_t a, int64_t b) {
-  return (a + (b - 1)) / b;
-}
-
-template <typename T>
-inline __device__ bool _isinf(T x) { return ::isinf(x); }
-
-inline __device__ bool _isinf(c10::Half x) {
-  return ::isinf(static_cast<float>(x));
-}
-inline __device__ bool _isinf(c10::BFloat16 x) {
-  return ::isinf(static_cast<float>(x));
-}
-
 #define MAX_NUM_BLOCKS 200
 
 // Normalizes the L1 norm of every row to 1; used by multinomial
@@ -46,13 +36,13 @@ __global__ void renormRowsL1(scalar_t* dist, long rows, long cols) {
     scalar_t sum = static_cast<scalar_t>(0);
     for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) {
       val = dist[row * cols + col];
-      CUDA_KERNEL_ASSERT(!(val < zero)); // ! < 0 for NaN handling
+      CUDA_KERNEL_ASSERT(!THCNumerics<scalar_t>::lt(val, zero)); // ! < 0 for NaN handling
       sum = sum + val;
     }
 
-    sum = cuda_utils::BlockReduceSum(sum, smem);
+    sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<scalar_t>(), zero);
     if (threadIdx.x == 0) {
-      CUDA_KERNEL_ASSERT(!(val < zero)); // ! < 0 for NaN handling
+      CUDA_KERNEL_ASSERT(!THCNumerics<scalar_t>::lt(val, zero)); // ! < 0 for NaN handling
       smem[0] = sum;
     }
     __syncthreads();
@@ -74,15 +64,14 @@ void renormRows(Tensor& t) {
   auto props = at::cuda::getCurrentDeviceProperties();
   CUDA_KERNEL_ASSERT(props != NULL);
   int numSM = props->multiProcessorCount;
-  const int64_t maxThreads = std::min(
-      props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads);
+  int maxThreads = props->maxThreadsPerBlock;
 
   dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
-  dim3 block(std::min(maxThreads, C10_WARP_SIZE * div_up(cols, C10_WARP_SIZE)));
+  dim3 block(cols < maxThreads ? cols : maxThreads);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(), "renormRows_cuda", [&] {
     renormRowsL1<scalar_t>
-        <<<grid, block, (block.x / C10_WARP_SIZE) * sizeof(scalar_t),
+        <<<grid, block, block.x * sizeof(scalar_t),
         at::cuda::getCurrentCUDAStream()>>>(t.data_ptr<scalar_t>(),
             rows, cols);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -201,9 +190,9 @@ __global__ void sampleMultinomialOnce(
     scalar_t val;
     for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
       val = dist[curDist * stride_dist + cat * stride_categories];
-      CUDA_KERNEL_ASSERT(!at::_isnan(val));
-      CUDA_KERNEL_ASSERT(!_isinf(val));
-      CUDA_KERNEL_ASSERT(!(val < zero));
+      CUDA_KERNEL_ASSERT(!THCNumerics<scalar_t>::isnan(val));
+      CUDA_KERNEL_ASSERT(!THCNumerics<scalar_t>::isinf(val));
+      CUDA_KERNEL_ASSERT(val >= zero);
       sum = sum + static_cast<accscalar_t>(val);
     }
 
@@ -213,7 +202,7 @@ __global__ void sampleMultinomialOnce(
     // Broadcast sum and sample value
     if (threadIdx.x == 0) {
       // Make sure the sum of our distribution didn't overflow
-      CUDA_KERNEL_ASSERT(!_isinf(val));
+      CUDA_KERNEL_ASSERT(!THCNumerics<accscalar_t>::isinf(sum));
       CUDA_KERNEL_ASSERT(sum > accZero);
 
       foundPos = 0;
index dbbcdc9..c24d7fb 100644 (file)
@@ -138,13 +138,11 @@ void handle_fused_mode(
     cuda::detail::TensorInfo<int64_t, unsigned int>& ti_indices,
     int64_t slice_size,
     int64_t slices) {
-  constexpr int num_threads = size / 2;
-  static_assert(num_threads % C10_WARP_SIZE == 0 &&
-                num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, "");
+  const dim3 block(size / 2);
   const auto memsize =
       (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int));
   compute_mode<scalar_t, size>
-      <<<grid, num_threads, memsize, at::cuda::getCurrentCUDAStream()>>>(
+      <<<grid, block, memsize, at::cuda::getCurrentCUDAStream()>>>(
           self.data_ptr<scalar_t>(), ti_values, ti_indices, slice_size, slices);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -184,19 +182,16 @@ void fused_mode(
       break;
     case 128:
     case 64:
+      handle_fused_mode<128, scalar_t>(
+          grid, self, ti_values, ti_indices, slice_size, slices);
+      break;
     case 32:
     case 16:
     case 8:
     case 4:
-    case 2: {
-      if (ceilPowerOf2 > 2 * C10_WARP_SIZE) {
-        handle_fused_mode<128, scalar_t>(
-            grid, self, ti_values, ti_indices, slice_size, slices);
-      } else {
-        handle_fused_mode<2 * C10_WARP_SIZE, scalar_t>(
-            grid, self, ti_values, ti_indices, slice_size, slices);
-      }
-    }
+    case 2:
+      handle_fused_mode<32, scalar_t>(
+          grid, self, ti_values, ti_indices, slice_size, slices);
       break;
     case 1:
     default:
index d29454b..f45b9b3 100644 (file)
@@ -3,7 +3,8 @@
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
-#include <ATen/native/cuda/block_reduce.cuh>
+
+#include <THC/THCReduceApplyUtils.cuh>
 
 namespace at {
 namespace native {
@@ -87,10 +88,11 @@ __device__ T reduceBlockWithNThreadLocalReductions(
   for (int i = 1; i < N; ++i) {
     ++offset;
     T next = offset < numVals ? threadVals[i] : init;
-    local = reduceOp.combine(local, next);
+    local = reduceOp(local, next);
   }
 
-  return cuda_utils::BlockReduce(local, reduceOp, init, smem);
+  return reduceBlock<T, ReduceOp>(
+      smem, blockDim.x < numVals ? blockDim.x : numVals, local, reduceOp, init);
 }
 
 template <typename T>
@@ -352,24 +354,13 @@ __global__ void compute_mode(
 
   struct ModeUnsignedPair max = {0, 0};
 
-  struct MaxOp {
-    inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const {
-      return b.val > a.val ? b : a;
-    }
-
-    inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const {
-      ModeUnsignedPair ret;
-      ret.index = WARP_SHFL_DOWN(acc.index, offset);
-      ret.val = WARP_SHFL_DOWN(acc.val, offset);
-      return ret;
-    }
-  } max_op;
-
   max = reduceBlockWithNThreadLocalReductions<2>(
       uupmem,
       uup,
       sliceSize,
-      max_op,
+      [=] GPU_LAMBDA(const auto& a, const auto& b) {
+        return b.val > a.val ? b : a;
+      },
       max);
 
   // Store the mode in shared memory for use in finding the mode in the input
@@ -383,43 +374,43 @@ __global__ void compute_mode(
   }
   __syncthreads(); // broadcast mode
 
-  // Finally, we need to find "an" index of the mode in the input
-  // Tensor. The API does not constrain which index we pick, but here
-  // we always pick the largest index. We store the index if the value
-  // is the mode, or 0 otherwise. Then find the maximum value.
+  // Finally, we need to find the "an" index of the mode in the input Tensor.
+  // The API does not constrain which index we pick, so it can be any of the
+  // indices that contain the mode. We will do a reduction to find the index. We
+  // go back to using the (index, flag) buffer arrangement. First, we mark
+  // indices that are equal to the mode, i.e B[i] = true if input[i] == mode,
+  // and initialize C[i] to be the index
   //
   // Again we reduce 2 elements in the thread's registers prior to the
   // block-wide reduction
-  unsigned mode_index[2] = {0u, 0u};
+  struct ModeUnsignedBoolPair ubpp[2];
   if (tidx * 2 < sliceSize) {
-    const unsigned idx = tidx * 2;
-    mode_index[0] = input[linearOffset + idx] == mode ? idx : 0u;
+    ubpp[0].flag = input[linearOffset + (tidx * 2)] == mode;
+    ubpp[0].val = tidx * 2;
   }
   if (tidx * 2 + 1 < sliceSize) {
-    const unsigned idx = tidx * 2 + 1;
-    mode_index[1] = input[linearOffset + idx] == mode ? idx : 0u;
+    ubpp[1].flag = input[linearOffset + (tidx * 2 + 1)] == mode;
+    ubpp[1].val = tidx * 2 + 1;
   }
 
-  struct MaxIndexOp {
-    inline __device__ unsigned combine(unsigned a, unsigned b) const {
-      return b > a ? b : a;
-    }
-
-    inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const {
-      return WARP_SHFL_DOWN(acc, offset);
-    }
-  } max_index_op;
+  // Then we perform a similar reduction to the one above, except this time we
+  // update the element if the element at the base position is not equal to the
+  // mode and the element at the offset position is. At the end, C[0] will
+  // contain an index with the mode.
+  struct ModeUnsignedBoolPair match = {0, false};
 
-  int64_t index = reduceBlockWithNThreadLocalReductions<2>(
-      reinterpret_cast<unsigned*>(&shmem[0]),
-      mode_index,
+  match = reduceBlockWithNThreadLocalReductions<2>(
+      ubpmem,
+      ubpp,
       sliceSize,
-      max_index_op,
-      0u);
+      [=] GPU_LAMBDA(const auto& a, const auto& b) { return b.flag ? b : a; },
+      match);
 
   // Finally, we have the mode, and an index where it occurs. We use a single
   // thread to place this in the appropriate output position
   if (tidx == 0) {
+    int64_t index = match.val;
+
     unsigned int outputOffset =
         at::cuda::detail::IndexToOffset<T, unsigned int, -1>::get(
             blockId, values);
index 095ce7a..a3f600e 100644 (file)
@@ -10,11 +10,6 @@ namespace native {
 namespace cuda_utils {
 
 constexpr int kCUDABlockReduceNumThreads = 512;
-// Algorithmic limitation: BlockReduce does two WarpReduce calls, each
-// of which reduces C10_WARP_SIZE elements. So, at most
-// C10_WARP_SIZE**2 elements can be reduced at a time.
-// NOTE: This is >= the max block size on current hardware anyway (1024).
-constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
 
 // Sums `val` accross all threads in a warp.
 //
@@ -46,7 +41,7 @@ __inline__ __device__ T BlockReduceSum(T val, T* shared) {
     shared[wid] = val;
   }
   __syncthreads();
-  val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : T(0);
+  val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : 0;
   if (wid == 0) {
     val = WarpReduceSum(val);
   }
index ca666b4..ab7f72b 100644 (file)
@@ -11,6 +11,7 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
   ${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THCTensor.cpp
 
+  ${CMAKE_CURRENT_SOURCE_DIR}/THCReduceApplyUtils.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/THCSleep.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/THCStorage.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cu
@@ -29,6 +30,7 @@ install(FILES
           THCTensor.h
           THCTensorCopy.h
           THCTensorCopy.hpp
+          THCReduceApplyUtils.cuh
           THCTensorMathReduce.cuh
           THCAsmUtils.cuh
           THCAtomics.cuh
diff --git a/aten/src/THC/THCReduceApplyUtils.cu b/aten/src/THC/THCReduceApplyUtils.cu
new file mode 100644 (file)
index 0000000..d586f8c
--- /dev/null
@@ -0,0 +1,35 @@
+#include <THC/THCReduceApplyUtils.cuh>
+
+#include <assert.h>
+#include <stdlib.h>
+
+// Maximum size per grid dimension that we assume (compute capability >= 2.0)
+#define MAX_GRID_SIZE 65535LL
+
+void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg) {
+  int64_t dims = THCudaTensor_nDimensionLegacyAll(state, tensor);
+  THArgCheck(dims <= MAX_CUTORCH_DIMS, arg, CUTORCH_DIM_WARNING);
+}
+
+bool THC_getGridFromTiles(ptrdiff_t gridTiles, dim3& grid) {
+  if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
+    return false;
+  }
+
+  int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
+  int64_t gridY = 1;
+  int64_t gridZ = 1;
+
+  if (gridTiles > MAX_GRID_SIZE) {
+    gridTiles = THCCeilDiv(gridTiles, (ptrdiff_t) MAX_GRID_SIZE);
+    gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
+
+    if (gridTiles > MAX_GRID_SIZE) {
+      gridTiles = THCCeilDiv(gridTiles, (ptrdiff_t) MAX_GRID_SIZE);
+      gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
+    }
+  }
+
+  grid = dim3(gridX, gridY, gridZ);
+  return true;
+}
diff --git a/aten/src/THC/THCReduceApplyUtils.cuh b/aten/src/THC/THCReduceApplyUtils.cuh
new file mode 100644 (file)
index 0000000..cf0d923
--- /dev/null
@@ -0,0 +1,117 @@
+#ifndef THC_REDUCE_APPLY_UTILS_INC
+#define THC_REDUCE_APPLY_UTILS_INC
+
+#include <cuda.h>
+#include <assert.h>
+#include <THC/THCGeneral.h>
+#include <THC/THCTensor.h>
+#include <THC/THCDeviceUtils.cuh>
+#include <THC/THCTensorInfo.cuh>
+
+// Enum that indicates whether tensor arguments are read/write or
+// read-only
+enum TensorArgType { ReadWrite, ReadOnly };
+
+// Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads:
+// (1, 2), (3, 4), (5, 6), (7, 8), then the return in threadVals for thread 0
+// is (1 + 3 + 5 + 7, 2 + 4 + 6 + 8) = (16, 20)
+//
+// If smem is not used again, there is no need to __syncthreads before this
+// call. However, if smem will be used, e.g., this function is called in a loop,
+// then __syncthreads is needed either before or afterwards to prevent non-0
+// threads overriding smem in the next loop before num-0 thread reads from it.
+template <typename T, typename ReduceOp, int N>
+__device__ void reduceNValuesInBlock(T *smem,
+                             T threadVals[N],
+                             const unsigned int numVals,
+                             ReduceOp reduceOp,
+                             T init) {
+  if (numVals == 0) {
+    #pragma unroll
+    for (int i = 0; i < N; ++i) {
+      threadVals[i] = init;
+    }
+    return;
+  }
+
+  // We store each of the N values contiguously, so if N = 2, all values for
+  // the first threadVal for each thread in the block are stored followed by
+  // all of the values for the second threadVal for each thread in the block
+  if (threadIdx.x < numVals) {
+    #pragma unroll
+    for (int i = 0; i < N; ++i) {
+      smem[i * numVals + threadIdx.x] = threadVals[i];
+    }
+  }
+  __syncthreads();
+
+  // Number of lanes in the final reduction --> this is used to determine
+  // where to put the outputs of each of the n things we are reducing. If
+  // nLP = 32, then we have the 32 outputs for the first threadVal,
+  // followed by the 32 outputs for the second threadVal, etc.
+  const unsigned int numLanesParticipating = min(numVals, warpSize);
+
+  if (numVals > warpSize && ((threadIdx.x / warpSize) == 0 )) {
+    #pragma unroll
+    for (int i = 0; i < N; ++i) {
+      threadVals[i] = threadIdx.x < numVals ? threadVals[i] : init;
+    }
+
+    for (int i = warpSize + threadIdx.x; i < numVals; i += warpSize) {
+      #pragma unroll
+      for (int j = 0; j < N; ++j) {
+        threadVals[j] = reduceOp(threadVals[j], smem[j * numVals + i]);
+      }
+    }
+
+    #pragma unroll
+    for (int i = 0; i < N; ++i) {
+      smem[i * numLanesParticipating + threadIdx.x] = threadVals[i];
+    }
+  }
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    if (numLanesParticipating == 32) {
+      #pragma unroll
+      for (int i = 0; i < N; ++i) {
+        #pragma unroll
+        for (int j = 1; j < 32; ++j) {
+          threadVals[i] = reduceOp(threadVals[i], smem[i * 32 + j]);
+        }
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < N; ++i) {
+        for (int j = 1; j < numLanesParticipating; ++j) {
+          threadVals[i] = reduceOp(threadVals[i], smem[i * numVals + j]);
+        }
+      }
+    }
+  }
+}
+
+// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
+// return the reduced value
+//
+// If smem is not used again, there is no need to __syncthreads before this
+// call. However, if smem will be used, e.g., this function is called in a loop,
+// then __syncthreads is needed either before or afterwards to prevent non-0
+// threads overriding smem in the next loop before num-0 thread reads from it.
+template <typename T, typename ReduceOp>
+__device__ T reduceBlock(T* smem,
+                         const unsigned int numVals,
+                         T threadVal,
+                         ReduceOp reduceOp,
+                         T init) {
+  reduceNValuesInBlock<T, ReduceOp, 1>(smem, &threadVal, numVals, reduceOp, init);
+  return threadVal;
+}
+
+// Make sure the given tensor doesn't have too many dimensions
+void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg);
+
+// Produces a grid with at least one point per tile
+TORCH_CUDA_CU_API bool THC_getGridFromTiles(ptrdiff_t gridTiles, dim3& grid);
+
+#endif // THC_REDUCE_APPLY_UTILS_INC