Automated g4 rollback of changelist 184303789
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Feb 2018 11:18:36 +0000 (03:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Feb 2018 11:23:08 +0000 (03:23 -0800)
PiperOrigin-RevId: 184970903

tensorflow/python/ops/array_grad.py

index c929218..9745d38 100644 (file)
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sparse_ops
@@ -115,6 +116,19 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
                                                         non_neg_concat_dim)
       out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
     else:
+      if constant_op.is_constant(concat_dim):
+        # If concat_dim is a constant defined in a different context,
+        # then we duplicate it in the current context to avoid passing it
+        # through an Enter node.
+        # This is a small optimization in general, but it is required when
+        # compiling with XLA, as XLA needs the concat input to be folded into a
+        # constant.
+        grad_context = control_flow_util.GetOutputContext(grad.op)
+        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
+        if dim_context != grad_context:
+          value = tensor_util.constant_value(concat_dim)
+          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)
+
       # Using mod here for convenience since concat_dim is already verified
       # in concat implementation to be within the allowed [-rank, rank) range.
       non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])