Add support for IndexedSlices in Distribution Strategy all reduce. Issue reported...
author     Priya Gupta <priyag@google.com>
           Wed, 23 May 2018 23:10:30 +0000 (16:10 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Wed, 23 May 2018 23:13:09 +0000 (16:13 -0700)
PiperOrigin-RevId: 197806955

tensorflow/contrib/distribute/python/BUILD
tensorflow/contrib/distribute/python/cross_tower_ops.py
tensorflow/contrib/distribute/python/cross_tower_ops_test.py
tensorflow/contrib/distribute/python/cross_tower_utils.py
tensorflow/contrib/distribute/python/cross_tower_utils_test.py [new file with mode: 0644]
tensorflow/python/ops/gradients_impl.py
tensorflow/python/ops/gradients_test.py
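Background for reviewers: `IndexedSlices` is the sparse gradient representation
that TensorFlow produces for ops such as `tf.gather` and embedding lookups,
which is why cross-tower reduction needs to handle it alongside dense tensors.
A minimal TF 1.x-era sketch (made-up values) of where such a value comes from:

```python
import tensorflow as tf

params = tf.Variable([[1., 2.], [3., 4.], [5., 6.]])
rows = tf.gather(params, [0, 2])
# The gradient of a gather w.r.t. its params is an IndexedSlices rather than
# a dense Tensor, so Distribution Strategy must be able to reduce these too.
grad = tf.gradients(tf.reduce_sum(rows), params)[0]
print(isinstance(grad, tf.IndexedSlices))  # True
```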

diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 00161b2..3118dea 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -445,6 +445,7 @@ py_library(
     srcs = ["cross_tower_utils.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":values",
         "//tensorflow/contrib/nccl:nccl_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_ops",
@@ -452,6 +453,24 @@ py_library(
     ],
 )
 
+cuda_py_test(
+    name = "cross_tower_utils_test",
+    srcs = ["cross_tower_utils_test.py"],
+    additional_deps = [
+        ":combinations",
+        ":cross_tower_utils",
+        "@absl_py//absl/testing:parameterized",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:test",
+    ],
+    tags = [
+        "no_pip",
+    ],
+)
+
 py_library(
     name = "cross_tower_ops",
     srcs = ["cross_tower_ops.py"],
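Note: since the new cross_tower_utils_test target is a cuda_py_test tagged
no_pip, it presumably requires a CUDA-enabled build to run, e.g. via
`bazel test //tensorflow/contrib/distribute/python:cross_tower_utils_test`.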
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index c6a1bf6..a411b88 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -77,12 +77,12 @@ def _all_devices_match(value_destination_pairs):
   return True
 
 
-def _simple_broadcast(tensor, destinations):
+def _simple_broadcast(value, destinations):
   index = {}
   devices = _get_devices_from(destinations)
   for d in devices:
-    with ops.device(d):
-      index[d] = array_ops.identity(tensor)
+    index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
+        value, d)
   return value_lib.Mirrored(index)
 
 
@@ -98,7 +98,9 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
         continue
       count += len(v_list)
       # Sum within each device before aggregating across devices.
-      v = math_ops.add_n(v_list)
+      # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
+      v = cross_tower_utils.aggregate_tensors_or_indexed_slices(
+          v_list, math_ops.add_n)
     else:
       count += 1
     all_values.append(v)
@@ -107,11 +109,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
 
   with ops.device(reduce_to_device):
     with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
-      if method_string == "sum":
-        reduced = accumulation_fn(all_values)
-      elif method_string == "mean":
-        reduced = accumulation_fn(all_values) / count
-      else:
+      reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
+          all_values, accumulation_fn)
+      if method_string == "mean":
+        reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
+            reduced, count)
+      elif method_string != "sum":
         raise ValueError("`method_string` must be 'sum' or 'mean'")
   return reduced
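To illustrate what the new IndexedSlices branch above computes: aggregation
works by concatenating values and indices, so duplicate indices may survive in
the result, but densifying sums them, which is why results with and without
duplicates are semantically equal. A minimal sketch with made-up values:

```python
import tensorflow as tf

shape = tf.constant([5, 2], dtype=tf.int64)
a = tf.IndexedSlices(tf.constant([[1., 2.]]), tf.constant([1]), shape)
b = tf.IndexedSlices(tf.constant([[3., 4.]]), tf.constant([1]), shape)
# Concatenation-style aggregation keeps the duplicate index 1 ...
summed = tf.IndexedSlices(tf.concat([a.values, b.values], 0),
                          tf.concat([a.indices, b.indices], 0),
                          a.dense_shape)
# ... but converting to dense sums the duplicate rows: row 1 becomes [4., 6.].
dense = tf.convert_to_tensor(summed)
```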
 
@@ -444,10 +447,18 @@ class AllReduceCrossTowerOps(CrossTowerOps):
     super(AllReduceCrossTowerOps, self).__init__()
 
   def _reduce(self, method_string, per_device_value, destinations):
+    contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
+        per_device_value)
     if ((destinations is None or _devices_match(per_device_value, destinations))
-        and not context.executing_eagerly()):
+        and not context.executing_eagerly()
+        and not contains_indexed_slices):
       return self._batch_all_reduce(method_string, [per_device_value])[0]
     else:
+      if contains_indexed_slices:
+        logging.log_first_n(
+            logging.WARN,
+            "Efficient allreduce is not supported for IndexedSlices.", 10)
+
       devices = _get_devices_from(destinations or per_device_value)
       reduce_to_device = devices[0]
       reduced = _simple_reduce(per_device_value, reduce_to_device,
@@ -455,14 +466,18 @@ class AllReduceCrossTowerOps(CrossTowerOps):
       return self.broadcast(reduced, devices)
 
   def _batch_reduce(self, method_string, value_destination_pairs):
-    if (_all_devices_match(value_destination_pairs) and
-        not context.executing_eagerly()):
+    all_devices_match = _all_devices_match(value_destination_pairs)
+    contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
+        value_destination_pairs)
+    if (all_devices_match and not context.executing_eagerly()
+        and not contains_indexed_slices):
       return self._batch_all_reduce(method_string,
                                     [v[0] for v in value_destination_pairs])
     else:
-      if not context.executing_eagerly():
+      if not all_devices_match:
         logging.warning("Efficient batch_reduce is not supported if "
                         "destinations are different.")
+
       return [
           self._reduce(method_string, t, destinations=v)
           for t, v in value_destination_pairs
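A hypothetical distillation of the dispatch added to _reduce/_batch_reduce
above: the packed batch all-reduce runs only when devices match, execution is
in graph mode, and no value is an IndexedSlices; any other case routes to the
reduce-to-one-device-then-broadcast fallback.

```python
def _use_batch_all_reduce(devices_match, executing_eagerly, has_indexed_slices):
  # Mirrors the three guards above; any failed condition means fallback.
  return devices_match and not executing_eagerly and not has_indexed_slices

assert _use_batch_all_reduce(True, False, False)     # dense values, graph mode
assert not _use_batch_all_reduce(True, False, True)  # IndexedSlices present
assert not _use_batch_all_reduce(True, True, False)  # eager execution
```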
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 7c7b087..2a26632 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.training import device_util
 
 
 def _make_per_device(values, devices):
@@ -56,19 +57,46 @@ def _fake_mirrored(value, devices):
       {d: v for d, v in zip(devices, [value] * len(devices))})
 
 
+def _make_indexed_slices(values, indices, dense_shape, device):
+  with ops.device(device):
+    tensor = ops.IndexedSlices(
+        values=constant_op.constant(values),
+        indices=constant_op.constant(indices),
+        dense_shape=constant_op.constant(dense_shape))
+  return tensor
+
+
+def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
+  return value_lib.Mirrored({
+      d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices
+  })
+
+
 _cpu_device = "/device:CPU:0"
 
 
 class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
 
-  def _assert_value_equal(self, left, right):
+  def _assert_indexed_slices_equal(self, left, right):
+    self.assertIsInstance(left, ops.IndexedSlices)
+    self.assertIsInstance(right, ops.IndexedSlices)
+    self.assertEqual(device_util.resolve(left.device),
+                     device_util.resolve(right.device))
+    self.assertAllEqual(
+        self.evaluate(ops.convert_to_tensor(left)),
+        self.evaluate(ops.convert_to_tensor(right)))
+
+  def _assert_values_equal(self, left, right):
     if isinstance(left, list):
       for l, r in zip(left, right):
-        self._assert_value_equal(l, r)
+        self._assert_values_equal(l, r)
     else:
       self.assertEqual(type(left), type(right))
       self.assertEqual(left.devices, right.devices)
-      if context.executing_eagerly():
+      if isinstance(list(left._index.values())[0], ops.IndexedSlices):
+        # Use items() rather than iteritems() so this works under PY2AND3.
+        for (d, v) in left._index.items():
+          self._assert_indexed_slices_equal(v, right._index[d])
+      elif context.executing_eagerly():
         self.assertEqual([v.numpy() for v in left._index.values()],
                          list(right._index.values()))
       else:
@@ -143,29 +171,29 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
 
     # test reduce()
     for destinations in all_destinations:
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.reduce("mean", per_device, destinations=destinations),
           _fake_mirrored(mean, destinations or per_device))
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.reduce(
               "mean", per_device_2, destinations=destinations),
           _fake_mirrored(mean_2, destinations or per_device))
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.reduce("sum", per_device, destinations=destinations),
           _fake_mirrored(mean * len(devices), destinations or per_device))
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.reduce(
               "sum", per_device_2, destinations=destinations),
           _fake_mirrored(mean_2 * len(devices), destinations or per_device))
 
     # test batch_reduce()
     for d1, d2 in itertools.product(all_destinations, all_destinations):
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.batch_reduce(
               "mean", [(per_device, d1), (per_device_2, d2)]),
           [_fake_mirrored(mean, d1 or per_device),
            _fake_mirrored(mean_2, d2 or per_device_2)])
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.batch_reduce(
               "sum", [(per_device, d1), (per_device_2, d2)]),
           [_fake_mirrored(mean * len(devices), d1 or per_device),
@@ -176,7 +204,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
       if destinations is None:
         continue
       else:
-        self._assert_value_equal(
+        self._assert_values_equal(
             cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
             _fake_mirrored(1., destinations))
 
@@ -184,16 +212,14 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
     device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
                     [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
     result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
-    self.assertTrue(
-        isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
     self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
     self.assertEqual(result.num_packs, 8)
 
     # if there are only 4 devices
     device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
     result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
-    self.assertTrue(
-        isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
     self.assertEqual(result.all_reduce_alg, "nccl")
     self.assertEqual(result.num_packs, 1)
 
@@ -202,8 +228,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
                     [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
                     [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
     result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
-    self.assertTrue(
-        isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
     self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
     self.assertEqual(result.num_packs, 8)
 
@@ -211,11 +236,85 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
     device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
                     [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
     result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
-    self.assertTrue(
-        isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
     self.assertEqual(result.all_reduce_alg, "nccl")
     self.assertEqual(result.num_packs, 1)
 
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      required_gpus=1))
+  def testSimpleReduceWithIndexedSlices(self):
+    devices = ["/cpu:0", "/gpu:0"]
+    t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
+    t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
+    per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
+    result = cross_tower_ops_lib._simple_reduce(per_device, devices[0],
+                                                math_ops.add_n, "sum")
+
+    # Test that the result is semantically equal to both the concatenated
+    # IndexedSlices (which still contains the duplicate index 1) and the
+    # equivalent IndexedSlices with the duplicates summed into unique indices.
+    total_with_dups = _make_indexed_slices(
+        [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0])
+    total_without_dups = _make_indexed_slices(
+        [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0])
+    self._assert_indexed_slices_equal(total_with_dups, result)
+    self._assert_indexed_slices_equal(total_without_dups, result)
+
+  @combinations.generate(combinations.combine(
+      cross_tower_ops_instance=[
+          combinations.NamedObject(
+              "ReductionToOneDeviceCrossTowerOps",
+              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
+          combinations.NamedObject(
+              "AllReduceCrossTowerOps",
+              cross_tower_ops_lib.AllReduceCrossTowerOps())
+      ],
+      method_string=["sum", "mean"],
+      batch_reduce=[True, False],
+      mode=["graph", "eager"],
+      required_gpus=1))
+  def testIndexedSlicesAllReduce(self, cross_tower_ops_instance,
+                                 method_string, batch_reduce):
+    devices = ["/cpu:0", "/gpu:0"]
+    dense_shape = [5, 2]
+    t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
+    t1 = _make_indexed_slices(
+        [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1])
+    per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
+
+    if batch_reduce:
+      result = cross_tower_ops_instance.batch_reduce(method_string,
+                                                     [(per_device, devices)])
+    else:
+      result = cross_tower_ops_instance.reduce(method_string, per_device,
+                                               devices)
+
+    total_indices_with_dups = [1, 1, 3]
+    total_indices_without_dups = [1, 3]
+
+    if method_string == "sum":
+      total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
+      total_values_without_dups = [[4., 6.], [5., 6.]]
+    else:
+      assert method_string == "mean"
+      total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
+      total_values_without_dups = [[2., 3.], [2.5, 3.]]
+
+    total_mirrored_with_dups = _make_mirrored_indexed_slices(
+        devices, total_values_with_dups, total_indices_with_dups, dense_shape)
+    total_mirrored_without_dups = _make_mirrored_indexed_slices(
+        devices, total_values_without_dups, total_indices_without_dups,
+        dense_shape)
+
+    # Test that the result is semantically equal to both the concatenated
+    # IndexedSlices and the equivalent with duplicate indices summed up.
+    if batch_reduce:
+      total_mirrored_with_dups = [total_mirrored_with_dups]
+      total_mirrored_without_dups = [total_mirrored_without_dups]
+
+    self._assert_values_equal(total_mirrored_with_dups, result)
+    self._assert_values_equal(total_mirrored_without_dups, result)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index fc04e21..8dd7831 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -21,9 +21,11 @@ from __future__ import print_function
 import collections as pycoll
 
 from tensorflow.contrib import nccl
+from tensorflow.contrib.distribute.python import values as value_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 
 
@@ -337,3 +339,46 @@ def unpack_small_tensors(tower_grads, packing):
         new_gv_list.insert(idx, gv[gi])
     new_tower_grads.append(new_gv_list)
   return new_tower_grads
+
+
+def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
+  """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
+  if isinstance(values[0], ops.IndexedSlices):
+    return gradients_impl._AggregateIndexedSlicesGradients(values)  # pylint: disable=protected-access
+  else:
+    return accumulation_fn(values)
+
+
+def divide_by_n_tensors_or_indexed_slices(value, n):
+  if isinstance(value, ops.IndexedSlices):
+    value = gradients_impl._HandleNestedIndexedSlices(value)  # pylint: disable=protected-access
+    return ops.IndexedSlices(
+        value.values / n, value.indices, value.dense_shape)
+  else:
+    return value / n
+
+
+def copy_tensor_or_indexed_slices_to_device(value, device):
+  with ops.device(device):
+    if isinstance(value, ops.IndexedSlices):
+      copied_values = array_ops.identity(value.values)
+      copied_indices = array_ops.identity(value.indices)
+      copied_shape = array_ops.identity(value.dense_shape)
+      result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
+    else:
+      result = array_ops.identity(value)
+  return result
+
+
+def contains_indexed_slices(value):
+  """Check whether the value is `IndexedSlices` or contains `IndexedSlices`."""
+  if isinstance(value, ops.IndexedSlices):
+    return True
+  elif isinstance(value, (list, tuple, pycoll.Sequence)) and value:
+    return any(contains_indexed_slices(v) for v in value)
+  elif isinstance(value, value_lib.DistributedValues):
+    return contains_indexed_slices(list(value._index.values()))  # pylint: disable=protected-access
+  elif isinstance(value, value_lib.MapOutput):
+    return contains_indexed_slices(value.get())
+  else:
+    return False
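A quick usage sketch of the helpers above (values are made up, and
math_ops._as_indexed_slices is a private convenience converter; the new test
file below exercises the same paths more thoroughly):

```python
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops

t0 = math_ops._as_indexed_slices(constant_op.constant([[1., 2.], [3., 4.]]))
t1 = math_ops._as_indexed_slices(constant_op.constant([[5., 6.], [7., 8.]]))

total = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
mean = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(total, 2)
assert cross_tower_utils.contains_indexed_slices([t0, t1])
```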
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
new file mode 100644
index 0000000..4ef8db6
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
@@ -0,0 +1,152 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cross_tower_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import device_util
+
+
+class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
+
+  def _assert_values_equal(self, left, right):
+    self.assertAllEqual(
+        self.evaluate(ops.convert_to_tensor(left)),
+        self.evaluate(ops.convert_to_tensor(right)))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testAggregateTensors(self):
+    t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
+    t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
+    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
+    result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
+    self._assert_values_equal(total, result)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testAggregateIndexedSlices(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
+    result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
+    self.assertIsInstance(result, ops.IndexedSlices)
+    self._assert_values_equal(total, result)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testDivideTensor(self):
+    t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
+    n = 2
+    expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
+    result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
+    self._assert_values_equal(expected, result)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testDivideIndexedSlices(self):
+    t = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    n = 2
+    expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
+    result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
+    self.assertIsInstance(result, ops.IndexedSlices)
+    self._assert_values_equal(expected, result)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testIsIndexedSlices(self):
+    t = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    self.assertTrue(cross_tower_utils.contains_indexed_slices(t))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testContainsIndexedSlices_List(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1]))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testContainsIndexedSlices_Tuple(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1)))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testContainsIndexedSlices_PerDevice(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1})
+    self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testContainsIndexedSlices_PerDeviceMapOutput(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    per_device = value_lib.PerDevice({
+        "/gpu:0": value_lib.MapOutput([t0]),
+        "/cpu:0": value_lib.MapOutput([t1])})
+    self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))
+
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      required_gpus=1))
+  def testCopyTensor(self):
+    with ops.device("/cpu:0"):
+      t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
+    destination = "/gpu:0"
+    result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
+        t, destination)
+
+    self._assert_values_equal(t, result)
+    self.assertEqual(device_util.resolve(destination),
+                     device_util.resolve(result.device))
+
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      required_gpus=1))
+  def testCopyIndexedSlices(self):
+    with ops.device("/cpu:0"):
+      t = math_ops._as_indexed_slices(
+          constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    destination = "/gpu:0"
+    result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
+        t, destination)
+
+    self.assertIsInstance(result, ops.IndexedSlices)
+    self._assert_values_equal(t, result)
+    self.assertEqual(device_util.resolve(destination),
+                     device_util.resolve(result.device))
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 716b54f..1e808fd 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -1006,21 +1006,33 @@ def _AggregatedGrads(grads,
         logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                      tensor_shape, used)
       else:
-        out_grad = math_ops._as_indexed_slices_list(
-            [g for g in out_grad if g is not None])
-        out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
-        # Form IndexedSlices out of the concatenated values and
-        # indices.
-        out_grads[i] = ops.IndexedSlices(
-            array_ops.concat([x.values for x in out_grad], 0),
-            array_ops.concat([x.indices for x in out_grad], 0),
-            out_grad[0].dense_shape)
+        out_grads[i] = _AggregateIndexedSlicesGradients(out_grad)
     else:  # not out_grad
       # out_grads[i] is [], thus its aggregation is simply None.
       out_grads[i] = None
   return out_grads
 
 
+def _AggregateIndexedSlicesGradients(grads):
+  """Aggregates gradients of type `IndexedSlices` by concatenation."""
+  if len(grads) < 1:
+    return None
+  elif len(grads) == 1:
+    return grads[0]
+  else:
+    assert isinstance(grads[0], ops.IndexedSlices)
+    grads = math_ops._as_indexed_slices_list(  # pylint: disable=protected-access
+        [g for g in grads if g is not None])
+    grads = [_HandleNestedIndexedSlices(x) for x in grads]  # pylint: disable=protected-access
+    # Form IndexedSlices out of the concatenated values and indices.
+    concat_grad = ops.IndexedSlices(
+        array_ops.concat([x.values for x in grads], axis=0),
+        array_ops.concat([x.indices for x in grads], axis=0),
+        grads[0].dense_shape)
+
+    return concat_grad
+
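A hypothetical smoke check of the refactored helper (note that it is a private
API): aggregation by concatenation yields the same dense result as elementwise
addition once the IndexedSlices is densified.

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops

t0 = math_ops._as_indexed_slices(constant_op.constant([[1., 2.], [3., 4.]]))
t1 = math_ops._as_indexed_slices(constant_op.constant([[5., 6.], [7., 8.]]))
agg = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
# Densifying sums values at the concatenated duplicate indices:
dense = ops.convert_to_tensor(agg)  # == [[6., 8.], [10., 12.]]
```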
+
 # TODO(vrv): Make this available when we want to make it public.
 def _hessian_vector_product(ys, xs, v):
   """Multiply the Hessian of `ys` wrt `xs` by `v`.
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 70d500a..6891501 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -946,5 +946,53 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
       self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
 
 
+class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
+
+  def _assert_indexed_slices_equal(self, left, right):
+    self.assertAllEqual(
+        self.evaluate(ops.convert_to_tensor(left)),
+        self.evaluate(ops.convert_to_tensor(right)))
+
+  def testNoGradients(self):
+    self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([]))
+
+  def testOneGradient(self):
+    t = math_ops._as_indexed_slices(constant_op.constant(
+        [[1., 2.], [0, 0], [3., 4.]]))
+    result = gradients_impl._AggregateIndexedSlicesGradients([t])
+    self._assert_indexed_slices_equal(t, result)
+
+  def testMultipleGradients(self):
+    t0 = math_ops._as_indexed_slices(constant_op.constant(
+        [[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(constant_op.constant(
+        [[0., 0.], [5, 6], [7., 8.]]))
+    total = constant_op.constant(
+        [[1., 2.], [5, 6], [10., 12.]])
+    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+    self._assert_indexed_slices_equal(total, result)
+
+  def testMultipleGradientsWithNones(self):
+    t0 = math_ops._as_indexed_slices(constant_op.constant(
+        [[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(constant_op.constant(
+        [[0., 0.], [5, 6], [7., 8.]]))
+    t3 = None
+    total = constant_op.constant(
+        [[1., 2.], [5, 6], [10., 12.]])
+    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3])
+    self._assert_indexed_slices_equal(total, result)
+
+  def testMixedTensorAndIndexedSlices(self):
+    t0 = math_ops._as_indexed_slices(constant_op.constant(
+        [[1., 2.], [0, 0], [3., 4.]]))
+    t1 = constant_op.constant(
+        [[0., 0.], [5, 6], [7., 8.]])
+    total = constant_op.constant(
+        [[1., 2.], [5, 6], [10., 12.]])
+    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+    self._assert_indexed_slices_equal(total, result)
+
+
 if __name__ == "__main__":
   googletest.main()