From: A. Unique TensorFlower Date: Thu, 8 Feb 2018 11:18:36 +0000 (-0800) Subject: Automated g4 rollback of changelist 184303789 X-Git-Tag: upstream/v1.7.0~31^2~892 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=14f5cfc159dff6855bde7ac5b0e037eec0229e89;p=platform%2Fupstream%2Ftensorflow.git Automated g4 rollback of changelist 184303789 PiperOrigin-RevId: 184970903 --- diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index c929218..9745d38 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops @@ -115,6 +116,19 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): non_neg_concat_dim) out_grads = array_ops.split(grad, sizes, non_neg_concat_dim) else: + if constant_op.is_constant(concat_dim): + # If concat_dim is a constant defined in a different context, + # then we duplicate it in the current context to avoid passing it + # through an Enter node. + # This is a small optimization in general, but it is required when + # compiling with XLA, as XLA needs the concat input to be folded into a + # constant. + grad_context = control_flow_util.GetOutputContext(grad.op) + dim_context = control_flow_util.GetOutputContext(concat_dim.op) + if dim_context != grad_context: + value = tensor_util.constant_value(concat_dim) + concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype) + # Using mod here for convenience since concat_dim is already verified # in concat implementation to be within the allowed [-rank, rank) range. non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])