Aggregating IndexedSlices: Do not require first element to be IndexedSlices.
authorPriya Gupta <priyag@google.com>
Thu, 24 May 2018 00:58:42 +0000 (17:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 01:01:17 +0000 (18:01 -0700)
PiperOrigin-RevId: 197821479

tensorflow/contrib/distribute/python/cross_tower_utils.py
tensorflow/python/ops/gradients_impl.py

index 8dd7831..4bff134 100644 (file)
@@ -343,7 +343,7 @@ def unpack_small_tensors(tower_grads, packing):
 
 def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
   """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
-  if isinstance(values[0], ops.IndexedSlices):
+  if any(isinstance(v, ops.IndexedSlices) for v in values):
     return gradients_impl._AggregateIndexedSlicesGradients(values)  # pylint: disable=protected-access
   else:
     return accumulation_fn(values)
index 1e808fd..7385cb7 100644 (file)
@@ -1020,7 +1020,6 @@ def _AggregateIndexedSlicesGradients(grads):
   elif len(grads) == 1:
     return grads[0]
   else:
-    assert isinstance(grads[0], ops.IndexedSlices)
     grads = math_ops._as_indexed_slices_list(  # pylint: disable=protected-access
         [g for g in grads if g is not None])
     grads = [_HandleNestedIndexedSlices(x) for x in grads]  # pylint: disable=protected-access