from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
non_neg_concat_dim)
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
else:
- if constant_op.is_constant(concat_dim):
- # If concat_dim is a constant defined in a different context,
- # then we duplicate it in the current context to avoid passing it
- # through an Enter node.
- # This is a small optimization in general, but it is required when
- # compiling with XLA, as XLA needs the concat input to be folded into a
- # constant.
- grad_context = control_flow_util.GetOutputContext(grad.op)
- dim_context = control_flow_util.GetOutputContext(concat_dim.op)
- if dim_context != grad_context:
- value = tensor_util.constant_value(concat_dim)
- concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)
-
# Using mod here for convenience since concat_dim is already verified
# in concat implementation to be within the allowed [-rank, rank) range.
non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])