TensorIterator cuda launch configs update (#16224)
authorJie <jiej@nvidia.com>
Thu, 7 Feb 2019 07:05:49 +0000 (23:05 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Feb 2019 07:10:41 +0000 (23:10 -0800)
commit49443d49fb0dace0029062f58391ae3509fd1626
tree74e7c8f3d41840f1c140eaef8cbd209117f93cbd
parentb2135b2b72171adc24b1fda499ff5a5d73e3a4e6
TensorIterator cuda launch configs update (#16224)

Summary:
Update launch configs for TensorIterator gpu_reduce_kernel. Enable flexible
block dimension to improve efficiency for reduction cases with small fast
dimension.

Previously TensorIterator launches blocks with fixed 32x16 threads.
For cases like:

  import torch
  torch.randn(2**20, 4, device='cuda').sum(0)

The fixed launch config does handle coalesced memory access efficiently.

Updated launch configure enables flexible block dimension. Combining with
improved reduction scheme (using flexible vertical / horizontal reduction
instead of limited warp / block reduction in the old code), it ensures optimal
memory access pattern even with reduction on dimension with small stride.

Possible future improvements:
1. Precise dynamic shared memory allocation.
2. Using warp shuffle for vertical (block_y) reduction.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16224

Differential Revision: D13806753

Pulled By: soumith

fbshipit-source-id: 37e45c7767b5748cf9ecf894fad306e040e2f79f
aten/src/ATen/native/cuda/Reduce.cuh