Implements Gather operator for arbitrary axis, sharing the code with BatchGather. (#13756)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13756
This implements a general Gather operator for an arbitrary axis, sharing code with BatchGather.
- CPU gather & batch gather logic is now shared through caffe2::gather_helper, for any axis.
- Shared CUDA kernel moved to gather_op.cuh, for any axis.
- Gradients for axis > 0 delegate to BatchGatherGradientOp, which now takes an axis argument.
- BatchGatherOp doc strings updated to describe the correct rank (q + (r - 1)) and output.
- Added tests for axis == 2.
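
For reviewers who want the semantics at a glance, here is a minimal pure-Python sketch of gathering along an arbitrary axis on a flat row-major buffer. It is an illustration only, not the actual caffe2::gather_helper code; the function name `gather_along_axis` and its parameters are hypothetical. It shows why the output rank is q + (r - 1): the gathered axis of the r-dimensional data is replaced by the q dimensions of the indices.

```python
from math import prod


def gather_along_axis(data, shape, indices, idx_shape, axis=0):
    """Gather slices of `data` along `axis` (hypothetical helper, illustration only).

    `data` and `indices` are flat row-major lists; `shape` and `idx_shape`
    are their logical shapes. Returns (flat_output, output_shape).
    """
    outer = prod(shape[:axis])       # number of batches in front of the axis
    axis_dim = shape[axis]           # size of the axis being indexed
    block = prod(shape[axis + 1:])   # contiguous elements behind the axis

    out = []
    for b in range(outer):
        base = b * axis_dim * block
        for idx in indices:
            if not 0 <= idx < axis_dim:
                raise IndexError(f"index {idx} out of range for axis size {axis_dim}")
            start = base + idx * block
            out.extend(data[start:start + block])

    # The gathered axis is replaced by the index dimensions: rank q + (r - 1).
    out_shape = list(shape[:axis]) + list(idx_shape) + list(shape[axis + 1:])
    return out, out_shape
```

For example, with data 0..11 of shape (2, 3, 2), gathering indices [2, 0] along axis 1 yields shape [2, 2, 2]. With axis == 0 the outer loop runs once, which is the plain Gather case; with axis == 1 it matches the BatchGather layout.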
GatherOp supports index wrapping for axis == 0 by default; this behavior was added earlier for ONNX.
This diff also extends it to work in the CUDA kernel. Added a "wrap_indices" argument that specifies
whether this wrapping should be done; set it to true if you'd like wrapping for any axis.
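
A rough sketch of the index wrapping described above, under the assumption that "wrapping" means a negative index counts from the end of the gathered axis (the axis size is added to it). The helper name `resolve_index` is hypothetical and this is not the operator's actual code; only the "wrap_indices" argument name comes from this diff.

```python
def resolve_index(idx, axis_dim, wrap_indices=False):
    """Map a possibly negative index to a valid position (illustration only).

    Assumption: wrapping adds the axis size to negative indices, so -1
    refers to the last element along the gathered axis.
    """
    if wrap_indices and idx < 0:
        idx += axis_dim
    if not 0 <= idx < axis_dim:
        raise IndexError(f"index {idx} out of range for axis size {axis_dim}")
    return idx
```

With an axis of size 4, an index of -1 resolves to 3 when wrapping is enabled and is rejected otherwise.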
TBD: Update gradients to support negative indices (separate diff).
TBD: Once we have operator versioning, we'd like to update GatherOp to NOT support axis 0 wrapping
by default, but rather do it only if wrap_indices is set.
Reviewed By: dzhulgakov
Differential Revision: D12983815
fbshipit-source-id: 8add9d67b47fe8c5ba7a335f581ca0530b205cd7