return (a + b - 1) / b;
}
+// returns the largest power of 2 no greater than n (i.e. 2^floor(log2(n)))
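+// e.g. last_pow2(1) == 1, last_pow2(5) == 4, last_pow2(64) == 64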
+static inline int last_pow2(int n) {
+ n |= (n >> 1);
+ n |= (n >> 2);
+ n |= (n >> 4);
+ n |= (n >> 8);
+ n |= (n >> 16);
+ return std::max(1, n - (n >> 1));
+}
+
struct ReduceConfig {
- static constexpr int LANE = 0;
- static constexpr int WARP = 1;
+ static constexpr int BLOCK_X = 0;
+ static constexpr int BLOCK_Y = 1;
static constexpr int CTA = 2;
- static constexpr int NUM_THREADS = 512;
+
+ static constexpr int MAX_NUM_THREADS = 512;
ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
: element_size_bytes(element_size_bytes)
int input_mult[3] = {0, 0, 0};
int output_mult[2] = {0, 0};
+ int block_width;
+ int block_height;
+ int num_threads;
+
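+ // Pick a power-of-2 block shape: start with the width capped at one warp,
+ // grow the height until the block is full, then let the width grow back into
+ // any remaining thread budget. For example, dim0 = 100 and dim1 = 3 with
+ // warp_size = 32 give dim0_pow2 = 64 and dim1_pow2 = 2, so block_width = 32,
+ // block_height = 2, and block_width then widens to min(64, 512 / 2) = 64,
+ // i.e. a 64x2 block of 128 threads.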
+ void set_block_dimension(int64_t dim0, int64_t dim1) {
+ int dim0_pow2 = dim0 < MAX_NUM_THREADS ? last_pow2(static_cast<int>(dim0)) : MAX_NUM_THREADS;
+ int dim1_pow2 = dim1 < MAX_NUM_THREADS ? last_pow2(static_cast<int>(dim1)) : MAX_NUM_THREADS;
+ block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
+ block_height = std::min(dim1_pow2, int(MAX_NUM_THREADS / block_width));
+ block_width = std::min(dim0_pow2, int(MAX_NUM_THREADS / block_height));
+ num_threads = block_width * block_height;
+ }
+
int split_input(int parallelism) {
int step = step_input;
step_input *= parallelism;
}
dim3 block() const {
- int warp_size = at::cuda::warp_size();
- return dim3(warp_size, NUM_THREADS / warp_size);
+ return dim3(block_width, block_height);
}
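+ // grid.x covers the outputs in chunks of step_output; grid.y splits a single
+ // reduction across ctas_per_output blocks (the global-reduce path).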
dim3 grid() const {
return dim3(div_up(num_outputs, step_output), ctas_per_output);
}
- C10_HOST_DEVICE bool should_warp_reduce() const {
- return input_mult[LANE] != 0;
+ C10_HOST_DEVICE bool should_block_x_reduce() const {
+ return input_mult[BLOCK_X] != 0;
}
- C10_HOST_DEVICE bool should_block_reduce() const {
- return input_mult[WARP] != 0;
+ C10_HOST_DEVICE bool should_block_y_reduce() const {
+ return input_mult[BLOCK_Y] != 0;
}
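+ // A non-zero input_mult entry means that block dimension steps across
+ // different inputs of the same output, so partial values must be reduced
+ // across it; a zero entry means the dimension covers distinct outputs.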
C10_HOST_DEVICE bool should_global_reduce() const {
C10_DEVICE bool should_store(int output_idx) const {
return output_idx < num_outputs &&
- (!should_warp_reduce() || threadIdx.x == 0) &&
- (!should_block_reduce() || threadIdx.y == 0);
+ (!should_block_x_reduce() || threadIdx.x == 0) &&
+ (!should_block_y_reduce() || threadIdx.y == 0);
}
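+ // Linear index of the first input element this thread reduces; each
+ // multiplier is either zero (that dimension does not partition the input)
+ // or the step assigned by split_input().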
C10_HOST_DEVICE int input_idx() const {
int lane = threadIdx.x;
int warp = threadIdx.y;
int cta2 = blockIdx.y;
- return (lane * input_mult[LANE] +
- warp * input_mult[WARP] +
+ return (lane * input_mult[BLOCK_X] +
+ warp * input_mult[BLOCK_Y] +
cta2 * input_mult[CTA]);
}
int lane = threadIdx.x;
int warp = threadIdx.y;
int cta1 = blockIdx.x;
- return (lane * output_mult[LANE] +
- warp * output_mult[WARP] +
+ return (lane * output_mult[BLOCK_X] +
+ warp * output_mult[BLOCK_Y] +
cta1 * step_output);
}
C10_DEVICE int staging_memory_offset(int cta2) const {
int offset = cta2 + blockIdx.x * gridDim.y;
- if (!should_warp_reduce()) {
+ if (!should_block_x_reduce()) {
offset = threadIdx.x + offset * blockDim.x;
}
return offset;
}
int shared_memory_size() const {
- if (!should_block_reduce()) {
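+ // Shared memory is needed for a y-reduction, or for an x-reduction that
+ // spans more than one warp (block_x_reduce spills to shared memory above
+ // warpSize).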
+ if (!should_block_y_reduce() &&
+ (!should_block_x_reduce() ||
+ block_width <= at::cuda::warp_size())) {
return 0;
}
- return element_size_bytes * NUM_THREADS;
+ return element_size_bytes * num_threads;
}
int64_t global_memory_size() const {
return 0;
}
auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
- if (!should_warp_reduce()) {
+ if (!should_block_x_reduce()) {
size *= block().x;
}
return size;
}
C10_DEVICE void run() const {
+ extern __shared__ char shared_memory[];
index_t output_idx = config.output_idx();
index_t input_idx = config.input_idx();
auto base_offsets = output_calc.get(output_idx);
auto input_slice = (const char*)src + base_offsets[1];
value = thread_reduce((const scalar_t*)input_slice);
}
- bool should_block_reduce = config.should_block_reduce();
- if (should_block_reduce) {
- value = block_reduce(value);
+ bool should_block_y_reduce = config.should_block_y_reduce();
+ if (should_block_y_reduce) {
+ value = block_y_reduce(value, shared_memory);
}
- if (config.should_warp_reduce() && (!should_block_reduce || threadIdx.y == 0)) {
- value = warp_reduce(value);
+ if (config.should_block_x_reduce()) {
+ value = block_x_reduce(value, shared_memory);
}
auto out = (out_scalar_t*)((char*)dst + base_offsets[0]);
if (config.should_global_reduce()) {
- value = global_reduce(value, out);
+ value = global_reduce(value, out, shared_memory);
} else if (config.should_store(output_idx)) {
if (accumulate) {
value = accumulate_in_output<can_accumulate_in_output>(out, value);
return value;
}
- C10_DEVICE arg_t warp_reduce(arg_t value) const {
- for (int offset = 1; offset < warpSize; offset <<= 1) {
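+ // Reduce along threadIdx.x in two phases: a shared-memory tree until the
+ // active width fits in a single warp, then warp shuffles for the last steps.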
+ C10_DEVICE arg_t block_x_reduce(arg_t value, char* shared_memory) const {
+ int dim_x = blockDim.x;
+ arg_t* shared = (arg_t*)shared_memory;
+ if (dim_x > warpSize) {
+ int address_base = threadIdx.x + threadIdx.y*blockDim.x;
+ shared[address_base] = value;
+ for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
+ __syncthreads();
+ if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
+ arg_t other = shared[address_base + offset];
+ value = ops.combine(value, other);
+ shared[address_base] = value;
+ }
+ }
+ dim_x = warpSize;
+ }
+
+ __syncthreads();
+
+ for (int offset = 1; offset < dim_x; offset <<= 1) {
arg_t other = ops.warp_shfl_down(value, offset);
value = ops.combine(value, other);
}
return value;
}
- C10_DEVICE arg_t block_reduce(arg_t value) const {
- extern __shared__ char shared_memory[];
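+ // Reduce along threadIdx.y with a shared-memory tree; each step halves the
+ // number of active rows, combining pairs of partial values.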
+ C10_DEVICE arg_t block_y_reduce(arg_t value, char* shared_memory) const {
arg_t* shared = (arg_t*)shared_memory;
shared[config.shared_memory_offset(0)] = value;
- int num_warps = (blockDim.x * blockDim.y) / warpSize;
- for (int offset = num_warps / 2; offset > 0; offset >>= 1) {
+ for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
__syncthreads();
- if (threadIdx.y < offset && threadIdx.y + offset < num_warps) {
+ if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
arg_t other = shared[config.shared_memory_offset(offset)];
value = ops.combine(value, other);
shared[config.shared_memory_offset(0)] = value;
}
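+ // Each block bumps the per-output-column semaphore after publishing its
+ // partial result; the block that sees the count hit gridDim.y - 1 is the
+ // last one done and will perform the final combine.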
C10_DEVICE bool mark_block_finished() const {
- extern __shared__ int is_last_block_done_shared[];
+ __shared__ bool is_last_block_done_shared;
__syncthreads();
if (threadIdx.x == 0 && threadIdx.y == 0) {
int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
- is_last_block_done_shared[0] = (prev_blocks_finished == gridDim.y - 1);
+ is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
}
__syncthreads();
- bool is_last_block_done = is_last_block_done_shared[0];
- __syncthreads();
- return is_last_block_done;
+ return is_last_block_done_shared;
}
template <bool can_acc>
return ops.project(value);
}
- C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out) const {
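+ // The last block to finish re-reads every CTA's staged partial result from
+ // reduce_buffer and reduces them with the same block-level primitives.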
+ C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out, char* shared_memory) const {
arg_t* reduce_buffer = (arg_t*)buffer;
bool should_store = config.should_store(config.output_idx());
if (is_last_block_done) {
value = ident;
- if (config.should_warp_reduce()) {
+ if (config.should_block_x_reduce()) {
index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
index_t step = blockDim.x * blockDim.y;
for (; input_offset < config.ctas_per_output; input_offset += step) {
value = ops.combine(value, next);
}
}
- value = block_reduce(value);
- if (config.should_warp_reduce()) {
- value = warp_reduce(value);
+ value = block_y_reduce(value, shared_memory);
+ if (config.should_block_x_reduce()) {
+ value = block_x_reduce(value, shared_memory);
}
if (should_store) {
if (accumulate) {
static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
dim3 block = config.block();
dim3 grid = config.grid();
+
auto stream = at::cuda::getCurrentCUDAStream();
int shared_memory = config.shared_memory_size();
reduce_kernel<nt, R><<<grid, block, shared_memory, stream>>>(reduction);
char* out_data = (char*)iter.data_ptr(0);
const char* in_data = (char*)iter.data_ptr(1);
-
- int warp_size = at::cuda::warp_size();
- int warps_per_cta = ReduceConfig::NUM_THREADS / warp_size;
-
// Start by assuming that each thread handles a single output and all
// the inputs for that output.
int64_t num_outputs = iter.num_output_elements();
auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output);
+ int64_t dim0;
+ int64_t dim1;
+ // Adjust the block shape so that the block width lines up with the
+ // fastest-changing dimension of the input.
+ if (iter.strides(/*arg=*/1)[0] == sizeof(scalar_t)) {
+ dim0 = iter.shape()[0];
+ dim1 = num_outputs;
+ } else {
+ dim0 = iter.shape()[iter.num_reduce_dims()];
+ dim1 = inputs_per_output;
+ }
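+ // If the reduced dimension is contiguous, the block width spans inputs of a
+ // single output and the height spans outputs; otherwise the width spans
+ // outputs and the height spans the reduction.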
+
+ config.set_block_dimension(dim0, dim1);
+
+ int block_width = config.block_width;
+ int block_height = config.block_height;
+
if (iter.ndim() == 0 || iter.strides(/*arg=*/1)[0] == sizeof(scalar_t)) {
// Split the input across lanes if the input is contiguous in the reduced
// dimension. This will require reduction between threads using warp
- // shuffle instructions.
- config.input_mult[0] = config.split_input(warp_size);
+ // shuffle instructions and shared memory (if block_width > warpSize).
+ config.input_mult[0] = config.split_input(block_width);
} else {
- // Otherwise split the output across lanes in a warp.
+ // Otherwise split the output across the x-dimension of the block.
- config.output_mult[0] = config.split_output(warp_size);
+ config.output_mult[0] = config.split_output(block_width);
}
- if (config.values_per_thread() >= warps_per_cta * 16) {
+ if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= 256) {
- // Divide the input across warps in a thread-block, if that leaves at least
- // 16 elements to be summed by each thread. This will require inter-warp
- // reduction using shared memory.
+ // Divide the input across the y-dimension of the block if that leaves at
+ // least 16 elements to be summed by each thread. This will require a
+ // block-wide reduction using shared memory.
- config.input_mult[1] = config.split_input(warps_per_cta);
+ config.input_mult[1] = config.split_input(block_height);
} else {
- // Otherwise, each warp handles a separate output.
+ // Otherwise, each row of the block handles a separate output.
- config.output_mult[1] = config.split_output(warps_per_cta);
+ config.output_mult[1] = config.split_output(block_height);
}
- if (config.values_per_thread() >= 256 && num_outputs <= 4096) {
+ if (config.input_mult[1] != 0 && config.values_per_thread() >= 256 && num_outputs <= 4096) {
// Divide the input across thread-blocks if the amount of work per-thread
// is large enough and the size of the output is small enough. This will
// require a reduction using global memory.
reduce.accumulate = iter.should_accumulate();
reduce.final_output = iter.is_final_output();
- launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+ launch_reduce_kernel<ReduceConfig::MAX_NUM_THREADS>(config, reduce);
} else {
auto output_calc = make_output_calculator<uint64_t>(iter);
auto input_calc = make_input_calculator<uint64_t>(iter);
reduce.accumulate = false;
reduce.final_output = true;
- launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+ launch_reduce_kernel<ReduceConfig::MAX_NUM_THREADS>(config, reduce);
}
}