Enabling partitioned variables to work with TPU.
When partitioned variables are used in a TPU training loop, concat gradient
operations are generated, and XLA requires their concat dimension argument
to be a constant (or foldable to a constant).
However, since that constant is defined outside the training loop's while
context, an Enter node is generated to pass it into the loop, which keeps
XLA from treating the argument as a constant.
The fix detects this case and duplicates the (scalar) constant inside the
while context, so that XLA can successfully process the resulting graph.
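The rewrite can be pictured with a minimal sketch on a toy graph model. The
`Node` class, the `duplicate_enter_constants` helper, the "is scalar" check
and all node names below are hypothetical illustrations, not TensorFlow APIs;
the sketch only assumes that a scalar Const reaching a loop-body consumer
through an Enter node can be replaced by a copy of itself created inside the
consumer's while context.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Node:
    # Hypothetical toy graph node, not a TensorFlow class.
    name: str
    op: str                        # e.g. "Const", "Enter", "ConcatV2"
    inputs: List[str] = field(default_factory=list)
    value: Any = None              # payload for "Const" nodes
    context: Optional[str] = None  # enclosing while context; None = outside


def is_scalar(value: Any) -> bool:
    # Assumption: plain Python numbers stand in for scalar tensors.
    return isinstance(value, (int, float))


def duplicate_enter_constants(graph: Dict[str, Node]) -> Dict[str, Node]:
    """Return a new graph where every input coming from an Enter node that is
    fed directly by a scalar Const is rewired to a copy of that Const created
    inside the consumer's while context. The original graph is not mutated."""
    new_graph = dict(graph)
    for node in graph.values():
        new_inputs = []
        for inp in node.inputs:
            src = graph[inp]
            if src.op == "Enter" and len(src.inputs) == 1:
                const = graph[src.inputs[0]]
                if const.op == "Const" and is_scalar(const.value):
                    copy_name = f"{const.name}/{node.context}_copy"
                    if copy_name not in new_graph:
                        # The duplicated constant lives inside the loop body.
                        new_graph[copy_name] = Node(
                            copy_name, "Const", [], const.value, node.context)
                    new_inputs.append(copy_name)
                    continue
            new_inputs.append(inp)
        new_graph[node.name] = Node(
            node.name, node.op, new_inputs, node.value, node.context)
    # The now-unused Enter node is left in place; a dead-node pass could drop it.
    return new_graph


# Toy example: the concat axis is defined outside "train_loop" and enters it.
graph = {
    "axis": Node("axis", "Const", value=0),
    "axis/Enter": Node("axis/Enter", "Enter", ["axis"], context="train_loop"),
    "a": Node("a", "Placeholder", context="train_loop"),
    "b": Node("b", "Placeholder", context="train_loop"),
    "concat": Node("concat", "ConcatV2", ["a", "b", "axis/Enter"],
                   context="train_loop"),
}
rewritten = duplicate_enter_constants(graph)
print(rewritten["concat"].inputs)  # ['a', 'b', 'axis/train_loop_copy']
```

After the rewrite the concat's axis comes from a Const inside the same while
context, which is the property that lets a compiler fold it to a constant.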
PiperOrigin-RevId: 184273245