of segment entries with 0-entries set to 1 to allow division by N.
"""
# bincount doesn't support negative indices so we use unsorted_segment_sum
- ones_tensor = array_ops.ones(segment_ids.shape, dtype=data.dtype)
+ segment_ids_shape = array_ops.shape_internal(segment_ids)
+ ones_tensor = array_ops.ones(segment_ids_shape, dtype=data.dtype)
N = gen_math_ops.unsorted_segment_sum(ones_tensor, segment_ids, num_segments)
# add dimensions for all non-reduced axes
ndims_output = data.shape.ndims - segment_ids.shape.ndims