Fix some edge cases around scalar indices in the gather expander
authorSanjoy Das <sanjoy@google.com>
Tue, 20 Mar 2018 06:15:42 +0000 (23:15 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 06:20:15 +0000 (23:20 -0700)
commite2e67c528316be8ea4f624af8757e80d7f00b5b6
treec372505a6c50f1c5476504dfe2766c6a7f2fa2b0
parent2311e9ced599d08f705afd631ee45cf027d05618
Fix some edge cases around scalar indices in the gather expander

I discovered these when changing the tf2xla bridge to directly emit gather
operations.

 - DeScalarizeGatherIndices was assuming that gather_indices must be of at least
   rank 1.  Fix this to be more general.

 - We were passing in the wrong version of gather indices to
   ExpandFirstDimIntoNDims.  We don't strictly need to pass in
   transposed_gather_indices (since if transposed_gather_indices is rank 1 then
   the transpose has to be an identity transpose), passing in
   descalarized_gather_indices would also have been fine, but
   transposed_gather_indices seems more uniform.

 - ExpandGatherDimsInAccumulator was assuming that gather_indices must be of at
   least rank 1 (by calling CollapseFirstNDims).  Fix this to be more general.

 - We were trying to go through with emitting zero sized gather operations.  I
   don't think it is worth dealing with all of the edge cases this would expose
   so now we just punt to ZeroSizedHloElimination.

PiperOrigin-RevId: 189696444
tensorflow/compiler/xla/service/gather_expander.cc
tensorflow/compiler/xla/service/hlo_creation_utils.cc
tensorflow/compiler/xla/service/hlo_creation_utils.h
tensorflow/compiler/xla/tests/gather_operation_test.cc