Implements Gather operator for arbitrary axis, sharing the code with BatchGather...
authorMichael Antonov <michael.antonov@oculus.com>
Tue, 4 Dec 2018 19:42:43 +0000 (11:42 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 19:54:28 +0000 (11:54 -0800)
commit773f4d8081634884ec421236165bd2ae19417503
treed9fe46fca8f7a04c3c18cbfc4d7bc264757956b3
parent16558a1e9d55a9d29d46245931ca5158684c9fda
Implements Gather operator for arbitrary axis, sharing the code with BatchGather. (#13756)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13756

This implements general Gather operator for arbitrary axis, sharing the code with BatchGather.
 - CPU gather & batch gather logic is now shared through caffe2::gather_helper, for any axis.
 - Shared CUDA kernel moved to gather_op.cuh, for any axis.
 - Gradients of axis > 0 delegate to BatchGatherGradientOp which now has axis argument.
 - BatchGatherOp doc strings updated to have correct rank (q + (r -1)) and output.
 - Added tests for axis == 2.

GatherOp supports index wrapping for axis == 0 by default, which was earlier for ONNX.
This diff also extends it to work in Cuda kernel. Added "wrap_indices" argument which specifies
wheather this wrapping should be done; set it to true if you'd like wrapping for any axis.

TBD: Update gradients to support negative indices (separate diff).
TBD: Once we have operator versioning, we'd like to update GatherOp to NOT support axis 0 wrapping
by default, but rather do it only if wrap_indices is set.

Reviewed By: dzhulgakov

Differential Revision: D12983815

fbshipit-source-id: 8add9d67b47fe8c5ba7a335f581ca0530b205cd7
caffe2/core/operator.h
caffe2/operators/batch_gather_ops.cc
caffe2/operators/batch_gather_ops.cu
caffe2/operators/batch_gather_ops.h
caffe2/operators/gather_op.cc
caffe2/operators/gather_op.cu
caffe2/operators/gather_op.cuh [new file with mode: 0644]
caffe2/operators/gather_op.h
caffe2/python/operator_test/gather_ops_test.py
tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py