From 00b6afc6ddf30fde104c5d2908a6d97ed414a58f Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Wed, 23 May 2018 17:58:42 -0700 Subject: [PATCH] Aggregating IndexedSlices: Do not require first element to be IndexedSlices. PiperOrigin-RevId: 197821479 --- tensorflow/contrib/distribute/python/cross_tower_utils.py | 2 +- tensorflow/python/ops/gradients_impl.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 8dd7831..4bff134 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -343,7 +343,7 @@ def unpack_small_tensors(tower_grads, packing): def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" - if isinstance(values[0], ops.IndexedSlices): + if any(isinstance(v, ops.IndexedSlices) for v in values): return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access else: return accumulation_fn(values) diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 1e808fd..7385cb7 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -1020,7 +1020,6 @@ def _AggregateIndexedSlicesGradients(grads): elif len(grads) == 1: return grads[0] else: - assert isinstance(grads[0], ops.IndexedSlices) grads = math_ops._as_indexed_slices_list( # pylint: disable=protected-access [g for g in grads if g is not None]) grads = [_HandleNestedIndexedSlices(x) for x in grads] # pylint: disable=protected-access -- 2.7.4