From b5193b6a8116dec81a8ed88a424be8e8ceba3de6 Mon Sep 17 00:00:00 2001
From: jiej
Date: Thu, 14 Feb 2019 14:40:13 -0800
Subject: [PATCH] Second PR to restore reverted commit (#16224) (#17040)

Summary:
update:
1. global_reduce now checks should_block_y_reduce first. This avoids enabling
   global_reduce without block_y_reduce, which previously led to accessing
   shared memory during the global reduce without having allocated it.
2. updating the block_y_reduce heuristics, which improves performance on tiny
   tensors.
3. adding a test case covering the cases where the old code could trigger an
   illegal memory access.

TensorIterator cuda launch configs update (#16224)

Update launch configs for TensorIterator gpu_reduce_kernel. Enable a flexible
block dimension to improve efficiency for reduction cases with a small fast
dimension.

Previously, TensorIterator launched blocks with a fixed 32x16 thread shape.
For cases like:

  import torch
  torch.randn(2**20, 4, device='cuda').sum(0)

the fixed launch config does not handle coalesced memory access efficiently.
The updated launch configuration enables a flexible block dimension. Combined
with the improved reduction scheme (flexible vertical / horizontal reduction
instead of the fixed warp / block reduction in the old code), it ensures an
efficient memory access pattern even when reducing over a dimension with a
small stride. A worked example of the new block-shape heuristic is sketched
at the end of this summary.

Possible future improvements:
1. Precise dynamic shared memory allocation.
2. Using warp shuffle for vertical (block_y) reduction.
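
For illustration only (not part of the patch): the host-side sketch below
mirrors the rounding logic of last_pow2 and ReduceConfig::set_block_dimension
introduced in the diff, assuming a warp size of 32 and, as example extents, a
stride-1 extent of 4 with 2**20 elements along the other extent, roughly the
randn(2**20, 4).sum(0) case above. It shows the block shape the new heuristic
picks instead of the old fixed 32x16.

  // Standalone sketch of the new block-shape heuristic (illustration only;
  // the warp size of 32 and the example extents are assumptions).
  #include <algorithm>
  #include <cstdio>

  static int last_pow2(int n) {  // largest power of two <= n
    n |= (n >> 1); n |= (n >> 2); n |= (n >> 4); n |= (n >> 8); n |= (n >> 16);
    return std::max(1, n - (n >> 1));
  }

  int main() {
    const int warp_size = 32;            // assumed
    const int max_num_threads = 512;     // ReduceConfig::MAX_NUM_THREADS
    const long long dim0 = 4;            // example: fast (stride-1) extent
    const long long dim1 = 1LL << 20;    // example: remaining extent

    int dim0_pow2 = dim0 < max_num_threads ? last_pow2((int)dim0) : max_num_threads;
    int dim1_pow2 = dim1 < max_num_threads ? last_pow2((int)dim1) : max_num_threads;
    int block_width  = std::min(dim0_pow2, warp_size);
    int block_height = std::min(dim1_pow2, max_num_threads / block_width);
    block_width      = std::min(dim0_pow2, max_num_threads / block_height);

    // Prints "block = 4 x 128": narrow along the stride-1 extent and tall
    // along the other extent, instead of the old fixed 32 x 16 shape.
    std::printf("block = %d x %d\n", block_width, block_height);
    return 0;
  }
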
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16224

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17040

Differential Revision: D14078295

Pulled By: umanwizard

fbshipit-source-id: ecc55054a5a4035e731f0196d633412225c3b06c
---
 aten/src/ATen/native/cuda/Reduce.cuh | 163 +++++++++++++++++++++++------------
 test/test_cuda.py                    |   4 +
 2 files changed, 112 insertions(+), 55 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index acd45c8..afeed51 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -25,11 +25,22 @@
 static inline int64_t div_up(int64_t a, int64_t b) {
   return (a + b - 1) / b;
 }
 
+// returns the largest power of 2 that is <= n
+static inline int last_pow2(int n) {
+  n |= (n >> 1);
+  n |= (n >> 2);
+  n |= (n >> 4);
+  n |= (n >> 8);
+  n |= (n >> 16);
+  return std::max(1, n - (n >> 1));
+}
+
 struct ReduceConfig {
-  static constexpr int LANE = 0;
-  static constexpr int WARP = 1;
+  static constexpr int BLOCK_X = 0;
+  static constexpr int BLOCK_Y = 1;
   static constexpr int CTA = 2;
-  static constexpr int NUM_THREADS = 512;
+
+  static constexpr int MAX_NUM_THREADS = 512;
 
   ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
     : element_size_bytes(element_size_bytes)
@@ -45,6 +56,19 @@ struct ReduceConfig {
   int input_mult[3] = {0, 0, 0};
   int output_mult[2] = {0, 0};
 
+  int block_width;
+  int block_height;
+  int num_threads;
+
+  void set_block_dimension(int64_t dim0, int64_t dim1) {
+    int dim0_pow2 = dim0 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim0)) : MAX_NUM_THREADS;
+    int dim1_pow2 = dim1 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim1)) : MAX_NUM_THREADS;
+    block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
+    block_height = std::min(dim1_pow2, int(MAX_NUM_THREADS / block_width));
+    block_width = std::min(dim0_pow2, int(MAX_NUM_THREADS / block_height));
+    num_threads = block_width * block_height;
+  }
+
   int split_input(int parallelism) {
     int step = step_input;
     step_input *= parallelism;
@@ -58,20 +82,19 @@
   }
 
   dim3 block() const {
-    int warp_size = at::cuda::warp_size();
-    return dim3(warp_size, NUM_THREADS / warp_size);
+    return dim3(block_width, block_height);
   }
 
   dim3 grid() const {
     return dim3(div_up(num_outputs, step_output), ctas_per_output);
   }
 
-  C10_HOST_DEVICE bool should_warp_reduce() const {
-    return input_mult[LANE] != 0;
+  C10_HOST_DEVICE bool should_block_x_reduce() const {
+    return input_mult[BLOCK_X] != 0;
   }
 
-  C10_HOST_DEVICE bool should_block_reduce() const {
-    return input_mult[WARP] != 0;
+  C10_HOST_DEVICE bool should_block_y_reduce() const {
+    return input_mult[BLOCK_Y] != 0;
   }
 
   C10_HOST_DEVICE bool should_global_reduce() const {
@@ -80,16 +103,16 @@
 
   C10_DEVICE bool should_store(int output_idx) const {
     return output_idx < num_outputs &&
-      (!should_warp_reduce() || threadIdx.x == 0) &&
-      (!should_block_reduce() || threadIdx.y == 0);
+      (!should_block_x_reduce() || threadIdx.x == 0) &&
+      (!should_block_y_reduce() || threadIdx.y == 0);
   }
 
   C10_HOST_DEVICE int input_idx() const {
     int lane = threadIdx.x;
     int warp = threadIdx.y;
     int cta2 = blockIdx.y;
-    return (lane * input_mult[LANE] +
-            warp * input_mult[WARP] +
+    return (lane * input_mult[BLOCK_X] +
+            warp * input_mult[BLOCK_Y] +
             cta2 * input_mult[CTA]);
   }
 
@@ -97,8 +120,8 @@
     int lane = threadIdx.x;
     int warp = threadIdx.y;
     int cta1 = blockIdx.x;
-    return (lane * output_mult[LANE] +
-            warp * output_mult[WARP] +
+    return (lane * output_mult[BLOCK_X] +
+            warp * output_mult[BLOCK_Y] +
             cta1 * step_output);
   }
 
@@ -108,17 +131,19 @@
 
   C10_DEVICE int staging_memory_offset(int cta2) const {
     int offset = cta2 + blockIdx.x * gridDim.y;
-    if (!should_warp_reduce()) {
+    if (!should_block_x_reduce()) {
       offset = threadIdx.x + offset * blockDim.x;
     }
     return offset;
   }
 
   int shared_memory_size() const {
-    if (!should_block_reduce()) {
+    if (!should_block_y_reduce() &&
+        (!should_block_x_reduce() ||
+          block_width <= at::cuda::warp_size())) {
       return 0;
     }
-    return element_size_bytes * NUM_THREADS;
+    return element_size_bytes * num_threads;
   }
 
   int64_t global_memory_size() const {
@@ -126,7 +151,7 @@
       return 0;
     }
     auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
-    if (!should_warp_reduce()) {
+    if (!should_block_x_reduce()) {
       size *= block().x;
     }
     return size;
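
The input_idx / output_idx accessors above combine per-dimension multipliers
that split_input / split_output record. As an illustration (not part of the
patch; MockConfig, the 4x128 block shape, and the split order are assumptions
chosen to match the contiguous case discussed in the summary), the
self-contained sketch below shows how those multipliers map a thread's
coordinates onto its first input element.

  // Illustration of the multiplier bookkeeping used by ReduceConfig
  // (hypothetical MockConfig, not the PyTorch type).
  #include <cstdio>

  struct MockConfig {
    static constexpr int BLOCK_X = 0;
    static constexpr int BLOCK_Y = 1;
    static constexpr int CTA = 2;
    int input_mult[3] = {0, 0, 0};
    int step_input = 1;

    // Mirrors ReduceConfig::split_input: remember the current step as the
    // multiplier for this dimension, then widen the step.
    int split_input(int parallelism) {
      int step = step_input;
      step_input *= parallelism;
      return step;
    }

    int input_idx(int lane, int warp, int cta2) const {
      return lane * input_mult[BLOCK_X] +
             warp * input_mult[BLOCK_Y] +
             cta2 * input_mult[CTA];
    }
  };

  int main() {
    const int block_width = 4, block_height = 128;  // assumed block shape
    MockConfig config;
    // Contiguous case: split the input across lanes, then across block rows.
    config.input_mult[MockConfig::BLOCK_X] = config.split_input(block_width);   // = 1
    config.input_mult[MockConfig::BLOCK_Y] = config.split_input(block_height);  // = 4
    // Thread (x = 3, y = 2) in the first CTA row starts at element 3*1 + 2*4 = 11.
    std::printf("input_idx = %d\n", config.input_idx(3, 2, 0));
    return 0;
  }
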
@@ -267,6 +292,7 @@ struct ReduceOp {
   }
 
   C10_DEVICE void run() const {
+    extern __shared__ char shared_memory[];
     index_t output_idx = config.output_idx();
     index_t input_idx = config.input_idx();
     auto base_offsets = output_calc.get(output_idx);
@@ -276,17 +302,17 @@
       auto input_slice = (const char*)src + base_offsets[1];
       value = thread_reduce((const scalar_t*)input_slice);
     }
-    bool should_block_reduce = config.should_block_reduce();
-    if (should_block_reduce) {
-      value = block_reduce(value);
+    bool should_block_y_reduce = config.should_block_y_reduce();
+    if (should_block_y_reduce) {
+      value = block_y_reduce(value, shared_memory);
     }
-    if (config.should_warp_reduce() && (!should_block_reduce || threadIdx.y == 0)) {
-      value = warp_reduce(value);
+    if (config.should_block_x_reduce()) {
+      value = block_x_reduce(value, shared_memory);
     }
     auto out = (out_scalar_t*)((char*)dst + base_offsets[0]);
     if (config.should_global_reduce()) {
-      value = global_reduce(value, out);
+      value = global_reduce(value, out, shared_memory);
     } else if (config.should_store(output_idx)) {
       if (accumulate) {
         value = accumulate_in_output(out, value);
@@ -331,22 +357,38 @@
     return value;
   }
 
-  C10_DEVICE arg_t warp_reduce(arg_t value) const {
-    for (int offset = 1; offset < warpSize; offset <<= 1) {
+  C10_DEVICE arg_t block_x_reduce(arg_t value, char* shared_memory) const {
+    int dim_x = blockDim.x;
+    arg_t* shared = (arg_t*)shared_memory;
+    if (dim_x > warpSize) {
+      int address_base = threadIdx.x + threadIdx.y*blockDim.x;
+      shared[address_base] = value;
+      for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
+        __syncthreads();
+        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
+          arg_t other = shared[address_base + offset];
+          value = ops.combine(value, other);
+          shared[address_base] = value;
+        }
+      }
+      dim_x = warpSize;
+    }
+
+    __syncthreads();
+
+    for (int offset = 1; offset < dim_x; offset <<= 1) {
       arg_t other = ops.warp_shfl_down(value, offset);
       value = ops.combine(value, other);
     }
     return value;
   }
 
-  C10_DEVICE arg_t block_reduce(arg_t value) const {
-    extern __shared__ char shared_memory[];
+  C10_DEVICE arg_t block_y_reduce(arg_t value, char* shared_memory) const {
     arg_t* shared = (arg_t*)shared_memory;
     shared[config.shared_memory_offset(0)] = value;
-    int num_warps = (blockDim.x * blockDim.y) / warpSize;
-    for (int offset = num_warps / 2; offset > 0; offset >>= 1) {
+    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
       __syncthreads();
-      if (threadIdx.y < offset && threadIdx.y + offset < num_warps) {
+      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
         arg_t other = shared[config.shared_memory_offset(offset)];
         value = ops.combine(value, other);
         shared[config.shared_memory_offset(0)] = value;
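
block_x_reduce above first folds the upper lanes of a wide block row into its
first warpSize lanes through shared memory, then finishes within a warp using
ops.warp_shfl_down. The standalone CUDA sketch below shows the same two-phase
pattern for a plain float sum; the kernel name, the direct use of
__shfl_down_sync, and the power-of-two blockDim.x are assumptions for the
example, not part of the patch.

  // Two-phase row reduction (shared memory, then warp shuffles), specialized
  // to summing floats. Assumes blockDim.x is a power of two and
  // blockDim.x * blockDim.y is a multiple of the warp size.
  #include <cstdio>
  #include <cuda_runtime.h>

  __device__ float row_reduce_sum(float value, float* shared) {
    int dim_x = blockDim.x;
    int address_base = threadIdx.x + threadIdx.y * blockDim.x;
    if (dim_x > warpSize) {
      // Phase 1: halve the active width through shared memory until one warp
      // per row holds the partial sums.
      shared[address_base] = value;
      for (int offset = dim_x / 2; offset >= warpSize; offset >>= 1) {
        __syncthreads();
        if (threadIdx.x < offset) {
          value += shared[address_base + offset];
          shared[address_base] = value;
        }
      }
      dim_x = warpSize;
    }
    __syncthreads();
    // Phase 2: finish within the warp; lane 0 of each row ends up with the sum.
    for (int offset = 1; offset < dim_x; offset <<= 1) {
      value += __shfl_down_sync(0xffffffff, value, offset);
    }
    return value;
  }

  __global__ void row_sums(const float* in, float* out) {
    extern __shared__ float shared[];
    int row = blockIdx.x * blockDim.y + threadIdx.y;
    float v = row_reduce_sum(in[row * blockDim.x + threadIdx.x], shared);
    if (threadIdx.x == 0) {
      out[row] = v;
    }
  }

  int main() {
    const int width = 64, rows = 128;   // example sizes
    float *in, *out;
    cudaMallocManaged(&in, width * rows * sizeof(float));
    cudaMallocManaged(&out, rows * sizeof(float));
    for (int i = 0; i < width * rows; i++) in[i] = 1.0f;

    dim3 block(width, 8);
    dim3 grid(rows / block.y);
    row_sums<<<grid, block, block.x * block.y * sizeof(float)>>>(in, out);
    cudaDeviceSynchronize();
    std::printf("out[0] = %.0f (expected %d)\n", out[0], width);

    cudaFree(in);
    cudaFree(out);
    return 0;
  }
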
@@ -356,19 +398,17 @@
   }
 
   C10_DEVICE bool mark_block_finished() const {
-    extern __shared__ int is_last_block_done_shared[];
+    __shared__ bool is_last_block_done_shared;
     __syncthreads();
     if (threadIdx.x == 0 && threadIdx.y == 0) {
       int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
-      is_last_block_done_shared[0] = (prev_blocks_finished == gridDim.y - 1);
+      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
     }
 
     __syncthreads();
 
-    bool is_last_block_done = is_last_block_done_shared[0];
-    __syncthreads();
-    return is_last_block_done;
+    return is_last_block_done_shared;
   }
 
   template
@@ -409,7 +449,7 @@
     return ops.project(value);
   }
 
-  C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out) const {
+  C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out, char* shared_memory) const {
     arg_t* reduce_buffer = (arg_t*)buffer;
 
     bool should_store = config.should_store(config.output_idx());
@@ -424,7 +464,7 @@
 
     if (is_last_block_done) {
       value = ident;
-      if (config.should_warp_reduce()) {
+      if (config.should_block_x_reduce()) {
         index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
         index_t step = blockDim.x * blockDim.y;
         for (; input_offset < config.ctas_per_output; input_offset += step) {
@@ -441,9 +481,9 @@
           value = ops.combine(value, next);
         }
       }
-      value = block_reduce(value);
-      if (config.should_warp_reduce()) {
-        value = warp_reduce(value);
+      value = block_y_reduce(value, shared_memory);
+      if (config.should_block_x_reduce()) {
+        value = block_x_reduce(value, shared_memory);
       }
       if (should_store) {
         if (accumulate) {
@@ -461,6 +501,7 @@ template <int nt, typename R>
 static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
   dim3 block = config.block();
   dim3 grid = config.grid();
+  auto stream = at::cuda::getCurrentCUDAStream();
   int shared_memory = config.shared_memory_size();
   reduce_kernel<nt, R><<<grid, block, shared_memory, stream>>>(reduction);
 
@@ -487,10 +528,6 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
 
   char* out_data = (char*)iter.data_ptr(0);
   const char* in_data = (char*)iter.data_ptr(1);
 
-
-  int warp_size = at::cuda::warp_size();
-  int warps_per_cta = ReduceConfig::NUM_THREADS / warp_size;
-
   // Start by assuming that each thread handles a single output and all
   // the inputs for that output.
   int64_t num_outputs = iter.num_output_elements();
@@ -498,27 +535,43 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
 
   auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output);
 
+  int64_t dim0;
+  int64_t dim1;
+  // adjust block size to fit width to fast changing dimension
+  if (iter.strides(/*arg=*/1)[0] == sizeof(scalar_t)) {
+    dim0 = iter.shape()[0];
+    dim1 = num_outputs;
+  } else {
+    dim0 = iter.shape()[iter.num_reduce_dims()];
+    dim1 = inputs_per_output;
+  }
+
+  config.set_block_dimension(dim0, dim1);
+
+  int block_width = config.block_width;
+  int block_height = config.block_height;
+
   if (iter.ndim() == 0 || iter.strides(/*arg=*/1)[0] == sizeof(scalar_t)) {
     // Split the input across lanes if the input is contiguous in the reduced
     // dimension. This will require reduction between threads using warp
-    // shuffle instructions.
-    config.input_mult[0] = config.split_input(warp_size);
+    // shuffle instructions and shared memory (if block_width > warpSize).
+    config.input_mult[0] = config.split_input(block_width);
   } else {
     // Otherwise split the output across lanes in a warp.
-    config.output_mult[0] = config.split_output(warp_size);
+    config.output_mult[0] = config.split_output(block_width);
   }
 
-  if (config.values_per_thread() >= warps_per_cta * 16) {
+  if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= 256) {
     // Divide the input across warps in a thread-block, if that leaves at least
     // 16 elements to be summed by each thread. This will require inter-warp
    // reduction using shared memory.
-    config.input_mult[1] = config.split_input(warps_per_cta);
+    config.input_mult[1] = config.split_input(block_height);
   } else {
     // Otherwise, each warp handles a separate output.
-    config.output_mult[1] = config.split_output(warps_per_cta);
+    config.output_mult[1] = config.split_output(block_height);
   }
 
-  if (config.values_per_thread() >= 256 && num_outputs <= 4096) {
+  if (config.input_mult[1] != 0 && config.values_per_thread() >= 256 && num_outputs <= 4096) {
     // Divide the input across thread-blocks if the amount of work per-thread
     // is large enough and the size of the output is small enough. This will
     // require a reduction using global memory.
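
Summary item 1 corresponds to the config.input_mult[1] != 0 guard just above:
the global (inter-CTA) reduction may only be enabled when the block_y
reduction is too, because global_reduce calls block_y_reduce and therefore
needs the dynamic shared memory sized by shared_memory_size(). The minimal
mock below (hypothetical names, illustration only) restates that invariant as
a runnable check.

  // Hypothetical mock restating the invariant behind the guard above:
  // a config that performs a global reduction must also request shared
  // memory for the block_y reduction it runs.
  #include <cassert>

  struct MockReduceConfig {
    static constexpr int BLOCK_X = 0, BLOCK_Y = 1, CTA = 2;
    int element_size_bytes = 4;
    int num_threads = 512;
    int block_width = 32;
    int input_mult[3] = {0, 0, 0};

    bool should_block_y_reduce() const { return input_mult[BLOCK_Y] != 0; }
    bool should_global_reduce() const { return input_mult[CTA] != 0; }

    int shared_memory_size(int warp_size = 32) const {
      if (!should_block_y_reduce() &&
          (input_mult[BLOCK_X] == 0 || block_width <= warp_size)) {
        return 0;  // no block_y and no wide block_x reduction: no smem needed
      }
      return element_size_bytes * num_threads;
    }
  };

  int main() {
    MockReduceConfig config;
    config.input_mult[MockReduceConfig::BLOCK_X] = 1;  // lanes split the input
    // The old condition could set input_mult[CTA] while input_mult[BLOCK_Y]
    // stayed 0, so shared_memory_size() returned 0 even though the global
    // reduction path runs block_y_reduce. The new guard keeps them in sync:
    if (config.should_block_y_reduce()) {              // mirrors input_mult[1] != 0
      config.input_mult[MockReduceConfig::CTA] = 1;
    }
    assert(!config.should_global_reduce() || config.shared_memory_size() > 0);
    return 0;
  }
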
@@ -556,7 +609,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
     reduce.accumulate = iter.should_accumulate();
     reduce.final_output = iter.is_final_output();
 
-    launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+    launch_reduce_kernel<ReduceConfig::MAX_NUM_THREADS>(config, reduce);
   } else {
     auto output_calc = make_output_calculator(iter);
     auto input_calc = make_input_calculator(iter);
@@ -574,7 +627,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
     reduce.accumulate = false;
     reduce.final_output = true;
 
-    launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+    launch_reduce_kernel<ReduceConfig::MAX_NUM_THREADS>(config, reduce);
   }
 }
 
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 26fc9fe..e2ebe43 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1969,6 +1969,10 @@ class TestCuda(TestCase):
         self.assertEqual(gpu_tensor1[0], 1)
         self.assertEqual(gpu_tensor0[0], 2)
 
+    def test_reduction_gpu_memory_accessing(self):
+        x = torch.ones(512, 8, dtype=torch.float32, device='cuda')
+        torch.sum(x, 0)
+
     def test_sum_cpu_gpu_mismatch(self):
         x = torch.randn(20, dtype=torch.float32, device='cuda')
         y = torch.randn(1, dtype=torch.float32)
-- 
2.7.4