Compute shape of segment_ids dynamically in _unsorted_segment_N
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 3 May 2018 01:35:55 +0000 (18:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 3 May 2018 01:38:37 +0000 (18:38 -0700)
PiperOrigin-RevId: 195186950

tensorflow/python/ops/math_ops.py

index 7ac3bd8..ab5997e 100644 (file)
@@ -2515,7 +2515,8 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
       of segment entries with 0-entries set to 1 to allow division by N.
   """
   # bincount doesn't support negative indices so we use unsorted_segment_sum
-  ones_tensor = array_ops.ones(segment_ids.shape, dtype=data.dtype)
+  segment_ids_shape = array_ops.shape_internal(segment_ids)
+  ones_tensor = array_ops.ones(segment_ids_shape, dtype=data.dtype)
   N = gen_math_ops.unsorted_segment_sum(ones_tensor, segment_ids, num_segments)
   # add dimensions for all non-reduced axes
   ndims_output = data.shape.ndims - segment_ids.shape.ndims