From: Priya Gupta
Date: Wed, 23 May 2018 23:10:30 +0000 (-0700)
Subject: Add support for IndexedSlices in Distribution Strategy all reduce. Issue reported...
X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~151
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=fa5c52e31f30e8cb88a5452e3b4aefc786fb8852;p=platform%2Fupstream%2Ftensorflow.git

Add support for IndexedSlices in Distribution Strategy all reduce. Issue
reported in #19069

PiperOrigin-RevId: 197806955
---

diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 00161b2..3118dea 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -445,6 +445,7 @@ py_library(
     srcs = ["cross_tower_utils.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":values",
         "//tensorflow/contrib/nccl:nccl_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_ops",
@@ -452,6 +453,24 @@ py_library(
     ],
 )

+cuda_py_test(
+    name = "cross_tower_utils_test",
+    srcs = ["cross_tower_utils_test.py"],
+    additional_deps = [
+        ":combinations",
+        ":cross_tower_utils",
+        "@absl_py//absl/testing:parameterized",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:test",
+    ],
+    tags = [
+        "no_pip",
+    ],
+)
+
 py_library(
     name = "cross_tower_ops",
     srcs = ["cross_tower_ops.py"],
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index c6a1bf6..a411b88 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -77,12 +77,12 @@ def _all_devices_match(value_destination_pairs):
   return True


-def _simple_broadcast(tensor, destinations):
+def _simple_broadcast(value, destinations):
   index = {}
   devices = _get_devices_from(destinations)
   for d in devices:
-    with ops.device(d):
-      index[d] = array_ops.identity(tensor)
+    index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
+        value, d)
   return value_lib.Mirrored(index)


@@ -98,7 +98,9 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
         continue
       count += len(v_list)
       # Sum within each device before aggregating across devices.
-      v = math_ops.add_n(v_list)
+      # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
+      v = cross_tower_utils.aggregate_tensors_or_indexed_slices(
+          v_list, math_ops.add_n)
     else:
       count += 1
     all_values.append(v)
@@ -107,11 +109,12 @@

   with ops.device(reduce_to_device):
     with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
-      if method_string == "sum":
-        reduced = accumulation_fn(all_values)
-      elif method_string == "mean":
-        reduced = accumulation_fn(all_values) / count
-      else:
+      reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
+          all_values, accumulation_fn)
+      if method_string == "mean":
+        reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
+            reduced, count)
+      elif method_string != "sum":
         raise ValueError("`method_string` must be 'sum' or 'mean'")
   return reduced

@@ -444,10 +447,18 @@ class AllReduceCrossTowerOps(CrossTowerOps):
     super(AllReduceCrossTowerOps, self).__init__()

   def _reduce(self, method_string, per_device_value, destinations):
+    contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
+        per_device_value)
     if ((destinations is None or _devices_match(per_device_value, destinations))
-        and not context.executing_eagerly()):
+        and not context.executing_eagerly()
+        and not contains_indexed_slices):
       return self._batch_all_reduce(method_string, [per_device_value])[0]
     else:
+      if contains_indexed_slices:
+        logging.log_first_n(
+            logging.WARN,
+            "Efficient allreduce is not supported for IndexedSlices.", 10)
+
       devices = _get_devices_from(destinations or per_device_value)
       reduce_to_device = devices[0]
       reduced = _simple_reduce(per_device_value, reduce_to_device,
@@ -455,14 +466,18 @@ class AllReduceCrossTowerOps(CrossTowerOps):
     return self.broadcast(reduced, devices)

   def _batch_reduce(self, method_string, value_destination_pairs):
-    if (_all_devices_match(value_destination_pairs) and
-        not context.executing_eagerly()):
+    all_devices_match = _all_devices_match(value_destination_pairs)
+    contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
+        value_destination_pairs)
+    if (all_devices_match and not context.executing_eagerly()
+        and not contains_indexed_slices):
       return self._batch_all_reduce(method_string,
                                     [v[0] for v in value_destination_pairs])
     else:
-      if not context.executing_eagerly():
+      if not all_devices_match:
         logging.warning("Efficient batch_reduce is not supported if "
                         "destinations are different.")
+
       return [
           self._reduce(method_string, t, destinations=v)
           for t, v in value_destination_pairs
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 7c7b087..2a26632 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.training import device_util


 def _make_per_device(values, devices):
@@ -56,19 +57,46 @@ def _fake_mirrored(value, devices):
       {d: v for d, v in zip(devices, [value] * len(devices))})


+def _make_indexed_slices(values, indices, dense_shape, device):
+  with ops.device(device):
+    tensor = ops.IndexedSlices(
+        values=constant_op.constant(values),
+        indices=constant_op.constant(indices),
+        dense_shape=constant_op.constant(dense_shape))
+  return tensor
+
+
+def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
+  return value_lib.Mirrored({
+      d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices
+  })
+
+
 _cpu_device = "/device:CPU:0"


 class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):

-  def _assert_value_equal(self, left, right):
+  def _assert_indexed_slices_equal(self, left, right):
+    self.assertIsInstance(left, ops.IndexedSlices)
+    self.assertIsInstance(right, ops.IndexedSlices)
+    self.assertEqual(device_util.resolve(left.device),
+                     device_util.resolve(right.device))
+    self.assertAllEqual(
+        self.evaluate(ops.convert_to_tensor(left)),
+        self.evaluate(ops.convert_to_tensor(right)))
+
+  def _assert_values_equal(self, left, right):
     if isinstance(left, list):
       for l, r in zip(left, right):
-        self._assert_value_equal(l, r)
+        self._assert_values_equal(l, r)
     else:
       self.assertEqual(type(left), type(right))
       self.assertEqual(left.devices, right.devices)
-      if context.executing_eagerly():
+      if isinstance(list(left._index.values())[0], ops.IndexedSlices):
+        for (d, v) in left._index.items():
+          self._assert_indexed_slices_equal(v, right._index[d])
+      elif context.executing_eagerly():
         self.assertEqual([v.numpy() for v in left._index.values()],
                          list(right._index.values()))
       else:
@@ -143,29 +171,29 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):

     # test reduce()
     for destinations in all_destinations:
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.reduce("mean", per_device,
                                  destinations=destinations),
           _fake_mirrored(mean, destinations or per_device))
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.reduce(
               "mean", per_device_2, destinations=destinations),
           _fake_mirrored(mean_2, destinations or per_device))
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.reduce("sum", per_device,
                                  destinations=destinations),
           _fake_mirrored(mean * len(devices), destinations or per_device))
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.reduce(
               "sum", per_device_2, destinations=destinations),
           _fake_mirrored(mean_2 * len(devices), destinations or per_device))

     # test batch_reduce()
     for d1, d2 in itertools.product(all_destinations, all_destinations):
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.batch_reduce(
               "mean", [(per_device, d1), (per_device_2, d2)]),
           [_fake_mirrored(mean, d1 or per_device),
            _fake_mirrored(mean_2, d2 or per_device_2)])
-      self._assert_value_equal(
+      self._assert_values_equal(
           cross_tower_ops.batch_reduce(
               "sum", [(per_device, d1), (per_device_2, d2)]),
           [_fake_mirrored(mean * len(devices), d1 or per_device),
@@ -176,7 +204,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
       if destinations is None:
         continue
       else:
-        self._assert_value_equal(
+        self._assert_values_equal(
            cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
            _fake_mirrored(1., destinations))

@@ -184,16 +212,14 @@
     device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
                     [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
     result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
-    self.assertTrue(
-        isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
     self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
     self.assertEqual(result.num_packs, 8)

     # if there are only 4 devices
     device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
     result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
-    self.assertTrue(
-        isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
     self.assertEqual(result.all_reduce_alg, "nccl")
     self.assertEqual(result.num_packs, 1)

@@ -202,8 +228,7 @@
                     [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
                     [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
     result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
-    self.assertTrue(
-        isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
     self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
     self.assertEqual(result.num_packs, 8)

@@ -211,11 +236,85 @@
     device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
                     [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
     result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
-    self.assertTrue(
-        isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
     self.assertEqual(result.all_reduce_alg, "nccl")
     self.assertEqual(result.num_packs, 1)

+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      required_gpus=1))
+  def testSimpleReduceWithIndexedSlices(self):
+    devices = ["/cpu:0", "/gpu:0"]
+    t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
+    t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
+    per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
+    result = cross_tower_ops_lib._simple_reduce(per_device, devices[0],
+                                                math_ops.add_n, "sum")
+
+    # Test that the result is semantically equal to both the concatenated
+    # IndexedSlices with and without duplicate indices.
+    total_with_dups = _make_indexed_slices(
+        [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0])
+    total_without_dups = _make_indexed_slices(
+        [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0])
+    self._assert_indexed_slices_equal(total_with_dups, result)
+    self._assert_indexed_slices_equal(total_without_dups, result)
+
+  @combinations.generate(combinations.combine(
+      cross_tower_ops_instance=[
+          combinations.NamedObject(
+              "ReductionToOneDeviceCrossTowerOps",
+              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
+          combinations.NamedObject(
+              "AllReduceCrossTowerOps",
+              cross_tower_ops_lib.AllReduceCrossTowerOps())
+      ],
+      method_string=["sum", "mean"],
+      batch_reduce=[True, False],
+      mode=["graph", "eager"],
+      required_gpus=1))
+  def testIndexedSlicesAllReduce(self, cross_tower_ops_instance,
+                                 method_string, batch_reduce):
+    devices = ["/cpu:0", "/gpu:0"]
+    dense_shape = [5, 2]
+    t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
+    t1 = _make_indexed_slices(
+        [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1])
+    per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
+
+    if batch_reduce:
+      result = cross_tower_ops_instance.batch_reduce(method_string,
+                                                     [(per_device, devices)])
+    else:
+      result = cross_tower_ops_instance.reduce(method_string, per_device,
+                                               devices)
+
+    total_indices_with_dups = [1, 1, 3]
+    total_indices_without_dups = [1, 3]
+
+    if method_string == "sum":
+      total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
+      total_values_without_dups = [[4., 6.], [5., 6.]]
+    else:
+      assert method_string == "mean"
+      total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
+      total_values_without_dups = [[2., 3.], [2.5, 3.]]
+
+    total_mirrored_with_dups = _make_mirrored_indexed_slices(
+        devices, total_values_with_dups, total_indices_with_dups, dense_shape)
+    total_mirrored_without_dups = _make_mirrored_indexed_slices(
+        devices, total_values_without_dups, total_indices_without_dups,
+        dense_shape)
+
+    # Test that the result is semantically equal to both the concatenated
+    # IndexedSlices, as well as when the duplicate indices are summed up.
+    if batch_reduce:
+      total_mirrored_with_dups = [total_mirrored_with_dups]
+      total_mirrored_without_dups = [total_mirrored_without_dups]
+
+    self._assert_values_equal(total_mirrored_with_dups, result)
+    self._assert_values_equal(total_mirrored_without_dups, result)
+

 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index fc04e21..8dd7831 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -21,9 +21,11 @@ from __future__ import print_function
 import collections as pycoll

 from tensorflow.contrib import nccl
+from tensorflow.contrib.distribute.python import values as value_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops


@@ -337,3 +339,46 @@ def unpack_small_tensors(tower_grads, packing):
         new_gv_list.insert(idx, gv[gi])
     new_tower_grads.append(new_gv_list)
   return new_tower_grads
+
+
+def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
+  """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
+  if isinstance(values[0], ops.IndexedSlices):
+    return gradients_impl._AggregateIndexedSlicesGradients(values)  # pylint: disable=protected-access
+  else:
+    return accumulation_fn(values)
+
+
+def divide_by_n_tensors_or_indexed_slices(value, n):
+  if isinstance(value, ops.IndexedSlices):
+    value = gradients_impl._HandleNestedIndexedSlices(value)  # pylint: disable=protected-access
+    return ops.IndexedSlices(
+        value.values / n, value.indices, value.dense_shape)
+  else:
+    return value / n
+
+
+def copy_tensor_or_indexed_slices_to_device(value, device):
+  with ops.device(device):
+    if isinstance(value, ops.IndexedSlices):
+      copied_values = array_ops.identity(value.values)
+      copied_indices = array_ops.identity(value.indices)
+      copied_shape = array_ops.identity(value.dense_shape)
+      result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
+    else:
+      result = array_ops.identity(value)
+  return result
+
+
+def contains_indexed_slices(value):
+  """Check whether the value is `IndexedSlices` or contains `IndexedSlices`."""
+  if isinstance(value, ops.IndexedSlices):
+    return True
+  elif isinstance(value, (list, tuple, pycoll.Sequence)) and value:
+    return any(contains_indexed_slices(v) for v in value)
+  elif isinstance(value, value_lib.DistributedValues):
+    return contains_indexed_slices(list(value._index.values()))  # pylint: disable=protected-access
+  elif isinstance(value, value_lib.MapOutput):
+    return contains_indexed_slices(value.get())
+  else:
+    return False
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
new file mode 100644
index 0000000..4ef8db6
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
@@ -0,0 +1,152 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cross_tower_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import device_util
+
+
+class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
+
+  def _assert_values_equal(self, left, right):
+    self.assertAllEqual(
+        self.evaluate(ops.convert_to_tensor(left)),
+        self.evaluate(ops.convert_to_tensor(right)))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testAggregateTensors(self):
+    t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
+    t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
+    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
+    result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
+    self._assert_values_equal(total, result)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testAggregateIndexedSlices(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
+    result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
+    self.assertIsInstance(result, ops.IndexedSlices)
+    self._assert_values_equal(total, result)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testDivideTensor(self):
+    t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
+    n = 2
+    expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
+    result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
+    self._assert_values_equal(expected, result)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testDivideIndexedSlices(self):
+    t = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    n = 2
+    expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
+    result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
+    self.assertIsInstance(result, ops.IndexedSlices)
+    self._assert_values_equal(expected, result)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testIsIndexedSlices(self):
+    t = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    self.assertTrue(cross_tower_utils.contains_indexed_slices(t))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testContainsIndexedSlices_List(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1]))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testContainsIndexedSlices_Tuple(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1)))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testContainsIndexedSlices_PerDevice(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1})
+    self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testContainsIndexedSlices_PerDeviceMapOutput(self):
+    t0 = math_ops._as_indexed_slices(
+        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(
+        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+    per_device = value_lib.PerDevice({
+        "/gpu:0": value_lib.MapOutput([t0]),
+        "/cpu:0": value_lib.MapOutput([t1])})
+    self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))
+
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      required_gpus=1))
+  def testCopyTensor(self):
+    with ops.device("/cpu:0"):
+      t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
+    destination = "/gpu:0"
+    result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
+        t, destination)
+
+    self._assert_values_equal(t, result)
+    self.assertEqual(device_util.resolve(destination),
+                     device_util.resolve(result.device))
+
+  @combinations.generate(combinations.combine(
+      mode=["graph", "eager"],
+      required_gpus=1))
+  def testCopyIndexedSlices(self):
+    with ops.device("/cpu:0"):
+      t = math_ops._as_indexed_slices(
+          constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+    destination = "/gpu:0"
+    result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
+        t, destination)
+
+    self.assertIsInstance(result, ops.IndexedSlices)
+    self._assert_values_equal(t, result)
+    self.assertEqual(device_util.resolve(destination),
+                     device_util.resolve(result.device))
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 716b54f..1e808fd 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -1006,21 +1006,33 @@ def _AggregatedGrads(grads,
         logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                      tensor_shape, used)
       else:
-        out_grad = math_ops._as_indexed_slices_list(
-            [g for g in out_grad if g is not None])
-        out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
-        # Form IndexedSlices out of the concatenated values and
-        # indices.
-        out_grads[i] = ops.IndexedSlices(
-            array_ops.concat([x.values for x in out_grad], 0),
-            array_ops.concat([x.indices for x in out_grad], 0),
-            out_grad[0].dense_shape)
+        out_grads[i] = _AggregateIndexedSlicesGradients(out_grad)
     else:  # not out_grad
       # out_grads[i] is [], thus its aggregation is simply None.
       out_grads[i] = None
   return out_grads


+def _AggregateIndexedSlicesGradients(grads):
+  """Aggregates gradients of type `IndexedSlices` by concatenation."""
+  if len(grads) < 1:
+    return None
+  elif len(grads) == 1:
+    return grads[0]
+  else:
+    assert isinstance(grads[0], ops.IndexedSlices)
+    grads = math_ops._as_indexed_slices_list(  # pylint: disable=protected-access
+        [g for g in grads if g is not None])
+    grads = [_HandleNestedIndexedSlices(x) for x in grads]  # pylint: disable=protected-access
+    # Form IndexedSlices out of the concatenated values and indices.
+    concat_grad = ops.IndexedSlices(
+        array_ops.concat([x.values for x in grads], axis=0),
+        array_ops.concat([x.indices for x in grads], axis=0),
+        grads[0].dense_shape)
+
+    return concat_grad
+
+
 # TODO(vrv): Make this available when we want to make it public.
 def _hessian_vector_product(ys, xs, v):
   """Multiply the Hessian of `ys` wrt `xs` by `v`.
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 70d500a..6891501 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -946,5 +946,53 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
       self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])


+class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
+
+  def _assert_indexed_slices_equal(self, left, right):
+    self.assertAllEqual(
+        self.evaluate(ops.convert_to_tensor(left)),
+        self.evaluate(ops.convert_to_tensor(right)))
+
+  def testNoGradients(self):
+    self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([]))
+
+  def testOneGradient(self):
+    t = math_ops._as_indexed_slices(constant_op.constant(
+        [[1., 2.], [0, 0], [3., 4.]]))
+    result = gradients_impl._AggregateIndexedSlicesGradients([t])
+    self._assert_indexed_slices_equal(t, result)
+
+  def testMultipleGradients(self):
+    t0 = math_ops._as_indexed_slices(constant_op.constant(
+        [[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(constant_op.constant(
+        [[0., 0.], [5, 6], [7., 8.]]))
+    total = constant_op.constant(
+        [[1., 2.], [5, 6], [10., 12.]])
+    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+    self._assert_indexed_slices_equal(total, result)
+
+  def testMultipleGradientsWithNones(self):
+    t0 = math_ops._as_indexed_slices(constant_op.constant(
+        [[1., 2.], [0, 0], [3., 4.]]))
+    t1 = math_ops._as_indexed_slices(constant_op.constant(
+        [[0., 0.], [5, 6], [7., 8.]]))
+    t3 = None
+    total = constant_op.constant(
+        [[1., 2.], [5, 6], [10., 12.]])
+    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3])
+    self._assert_indexed_slices_equal(total, result)
+
+  def testMixedTensorAndIndexedSlices(self):
+    t0 = math_ops._as_indexed_slices(constant_op.constant(
+        [[1., 2.], [0, 0], [3., 4.]]))
+    t1 = constant_op.constant(
+        [[0., 0.], [5, 6], [7., 8.]])
+    total = constant_op.constant(
+        [[1., 2.], [5, 6], [10., 12.]])
+    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+    self._assert_indexed_slices_equal(total, result)
+
+
 if __name__ == "__main__":
   googletest.main()
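
Illustration (not part of the patch): the IndexedSlices path above aggregates
by concatenating per-tower values and indices rather than densifying, so
duplicate indices survive aggregation and are only summed when the result is
converted to a dense tensor; that is why the tests accept both the
with-duplicates and without-duplicates forms. A minimal standalone sketch of
the same idea, assuming TensorFlow 1.x graph mode; aggregate_indexed_slices is
a hypothetical helper that mirrors _AggregateIndexedSlicesGradients:

import tensorflow as tf


def aggregate_indexed_slices(slices_list):
  # Concatenate values and indices across towers; duplicate indices are kept.
  return tf.IndexedSlices(
      values=tf.concat([s.values for s in slices_list], axis=0),
      indices=tf.concat([s.indices for s in slices_list], axis=0),
      dense_shape=slices_list[0].dense_shape)


# Two per-tower sparse gradients over a [5, 2] dense variable, mirroring the
# values used in testSimpleReduceWithIndexedSlices.
t0 = tf.IndexedSlices(tf.constant([[1., 2.]]), tf.constant([1]),
                      tf.constant([5, 2], dtype=tf.int64))
t1 = tf.IndexedSlices(tf.constant([[3., 4.], [5., 6.]]), tf.constant([1, 3]),
                      tf.constant([5, 2], dtype=tf.int64))

agg = aggregate_indexed_slices([t0, t1])
with tf.Session() as sess:
  # Densifying sums the duplicated index 1, yielding
  # [[0. 0.] [4. 6.] [0. 0.] [5. 6.] [0. 0.]].
  print(sess.run(tf.convert_to_tensor(agg)))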