From: Priya Gupta Date: Thu, 24 May 2018 00:58:42 +0000 (-0700) Subject: Aggregating IndexedSlices: Do not require first element to be IndexedSlices. X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~140 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=00b6afc6ddf30fde104c5d2908a6d97ed414a58f;p=platform%2Fupstream%2Ftensorflow.git Aggregating IndexedSlices: Do not require first element to be IndexedSlices. PiperOrigin-RevId: 197821479 --- diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 8dd7831..4bff134 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -343,7 +343,7 @@ def unpack_small_tensors(tower_grads, packing): def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" - if isinstance(values[0], ops.IndexedSlices): + if any(isinstance(v, ops.IndexedSlices) for v in values): return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access else: return accumulation_fn(values) diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 1e808fd..7385cb7 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -1020,7 +1020,6 @@ def _AggregateIndexedSlicesGradients(grads): elif len(grads) == 1: return grads[0] else: - assert isinstance(grads[0], ops.IndexedSlices) grads = math_ops._as_indexed_slices_list( # pylint: disable=protected-access [g for g in grads if g is not None]) grads = [_HandleNestedIndexedSlices(x) for x in grads] # pylint: disable=protected-access