srcs = ["cross_tower_utils.py"],
srcs_version = "PY2AND3",
deps = [
+ ":values",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
],
)
# GPU-enabled test target for cross_tower_utils. The "no_pip" tag keeps it
# out of the pip-package test suite.
cuda_py_test(
    name = "cross_tower_utils_test",
    srcs = ["cross_tower_utils_test.py"],
    additional_deps = [
        ":combinations",
        ":cross_tower_utils",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
    ],
    tags = [
        "no_pip",
    ],
)
+
py_library(
name = "cross_tower_ops",
srcs = ["cross_tower_ops.py"],
return True
def _simple_broadcast(value, destinations):
  """Copy `value` (a tensor or IndexedSlices) to every destination device."""
  devices = _get_devices_from(destinations)
  return value_lib.Mirrored({
      device: cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
          value, device)
      for device in devices
  })
continue
count += len(v_list)
# Sum within each device before aggregating across devices.
- v = math_ops.add_n(v_list)
+ # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
+ v = cross_tower_utils.aggregate_tensors_or_indexed_slices(
+ v_list, math_ops.add_n)
else:
count += 1
all_values.append(v)
with ops.device(reduce_to_device):
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- if method_string == "sum":
- reduced = accumulation_fn(all_values)
- elif method_string == "mean":
- reduced = accumulation_fn(all_values) / count
- else:
+ reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
+ all_values, accumulation_fn)
+ if method_string == "mean":
+ reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
+ reduced, count)
+ elif method_string != "sum":
raise ValueError("`method_string` must be 'sum' or 'mean'")
return reduced
super(AllReduceCrossTowerOps, self).__init__()
def _reduce(self, method_string, per_device_value, destinations):
+ contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
+ per_device_value)
if ((destinations is None or _devices_match(per_device_value, destinations))
- and not context.executing_eagerly()):
+ and not context.executing_eagerly()
+ and not contains_indexed_slices):
return self._batch_all_reduce(method_string, [per_device_value])[0]
else:
+ if contains_indexed_slices:
+ logging.log_first_n(
+ logging.WARN,
+ "Efficient allreduce is not supported for IndexedSlices.", 10)
+
devices = _get_devices_from(destinations or per_device_value)
reduce_to_device = devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
return self.broadcast(reduced, devices)
def _batch_reduce(self, method_string, value_destination_pairs):
- if (_all_devices_match(value_destination_pairs) and
- not context.executing_eagerly()):
+ all_devices_match = _all_devices_match(value_destination_pairs)
+ contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
+ value_destination_pairs)
+ if (all_devices_match and not context.executing_eagerly()
+ and not contains_indexed_slices):
return self._batch_all_reduce(method_string,
[v[0] for v in value_destination_pairs])
else:
- if not context.executing_eagerly():
+ if not all_devices_match:
logging.warning("Efficient batch_reduce is not supported if "
"destinations are different.")
+
return [
self._reduce(method_string, t, destinations=v)
for t, v in value_destination_pairs
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.training import device_util
def _make_per_device(values, devices):
{d: v for d, v in zip(devices, [value] * len(devices))})
def _make_indexed_slices(values, indices, dense_shape, device):
  """Build an `IndexedSlices` from python lists, placed on `device`."""
  with ops.device(device):
    values_t = constant_op.constant(values)
    indices_t = constant_op.constant(indices)
    shape_t = constant_op.constant(dense_shape)
    return ops.IndexedSlices(
        values=values_t, indices=indices_t, dense_shape=shape_t)
+
+
def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
  """Mirror identical `IndexedSlices` content onto each of `devices`."""
  index = {}
  for device in devices:
    index[device] = _make_indexed_slices(values, indices, dense_shape, device)
  return value_lib.Mirrored(index)
+
+
_cpu_device = "/device:CPU:0"
class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
- def _assert_value_equal(self, left, right):
+ def _assert_indexed_slices_equal(self, left, right):
+ self.assertIsInstance(left, ops.IndexedSlices)
+ self.assertIsInstance(right, ops.IndexedSlices)
+ self.assertEqual(device_util.resolve(left.device),
+ device_util.resolve(right.device))
+ self.assertAllEqual(
+ self.evaluate(ops.convert_to_tensor(left)),
+ self.evaluate(ops.convert_to_tensor(right)))
+
+ def _assert_values_equal(self, left, right):
if isinstance(left, list):
for l, r in zip(left, right):
- self._assert_value_equal(l, r)
+ self._assert_values_equal(l, r)
else:
self.assertEqual(type(left), type(right))
self.assertEqual(left.devices, right.devices)
- if context.executing_eagerly():
+ if isinstance(list(left._index.values())[0], ops.IndexedSlices):
+ for (d, v) in left._index.iteritems():
+ self._assert_indexed_slices_equal(v, right._index[d])
+ elif context.executing_eagerly():
self.assertEqual([v.numpy() for v in left._index.values()],
list(right._index.values()))
else:
# test reduce()
for destinations in all_destinations:
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.reduce("mean", per_device, destinations=destinations),
_fake_mirrored(mean, destinations or per_device))
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.reduce(
"mean", per_device_2, destinations=destinations),
_fake_mirrored(mean_2, destinations or per_device))
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.reduce("sum", per_device, destinations=destinations),
_fake_mirrored(mean * len(devices), destinations or per_device))
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.reduce(
"sum", per_device_2, destinations=destinations),
_fake_mirrored(mean_2 * len(devices), destinations or per_device))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.batch_reduce(
"mean", [(per_device, d1), (per_device_2, d2)]),
[_fake_mirrored(mean, d1 or per_device),
_fake_mirrored(mean_2, d2 or per_device_2)])
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.batch_reduce(
"sum", [(per_device, d1), (per_device_2, d2)]),
[_fake_mirrored(mean * len(devices), d1 or per_device),
if destinations is None:
continue
else:
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
_fake_mirrored(1., destinations))
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
- self.assertTrue(
- isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+ self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
self.assertEqual(result.num_packs, 8)
# if there are only 4 devices
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
- self.assertTrue(
- isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+ self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
self.assertEqual(result.all_reduce_alg, "nccl")
self.assertEqual(result.num_packs, 1)
[0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
[2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
- self.assertTrue(
- isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+ self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
self.assertEqual(result.num_packs, 8)
device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
[1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
- self.assertTrue(
- isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+ self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
self.assertEqual(result.all_reduce_alg, "nccl")
self.assertEqual(result.num_packs, 1)
  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testSimpleReduceWithIndexedSlices(self):
    """Sums a PerDevice of IndexedSlices from two devices onto one device."""
    devices = ["/cpu:0", "/gpu:0"]
    t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
    t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
    per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
    result = cross_tower_ops_lib._simple_reduce(per_device, devices[0],
                                                math_ops.add_n, "sum")

    # Test that the result is semantically equal to both the concatenated
    # IndexedSlices with and without duplicate indices.
    total_with_dups = _make_indexed_slices(
        [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0])
    total_without_dups = _make_indexed_slices(
        [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0])
    self._assert_indexed_slices_equal(total_with_dups, result)
    self._assert_indexed_slices_equal(total_without_dups, result)
+
  @combinations.generate(combinations.combine(
      cross_tower_ops_instance=[
          combinations.NamedObject(
              "ReductionToOneDeviceCrossTowerOps",
              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
          combinations.NamedObject(
              "AllReduceCrossTowerOps",
              cross_tower_ops_lib.AllReduceCrossTowerOps())
      ],
      method_string=["sum", "mean"],
      batch_reduce=[True, False],
      mode=["graph", "eager"],
      required_gpus=1))
  def testIndexedSlicesAllReduce(self, cross_tower_ops_instance,
                                 method_string, batch_reduce):
    """Reduces IndexedSlices across two devices via reduce()/batch_reduce().

    Parameterized over both cross-tower-ops implementations, both reduction
    methods and both entry points (reduce vs. batch_reduce).
    """
    devices = ["/cpu:0", "/gpu:0"]
    dense_shape = [5, 2]
    t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
    t1 = _make_indexed_slices(
        [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1])
    per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})

    if batch_reduce:
      result = cross_tower_ops_instance.batch_reduce(method_string,
                                                     [(per_device, devices)])
    else:
      result = cross_tower_ops_instance.reduce(method_string, per_device,
                                               devices)

    total_indices_with_dups = [1, 1, 3]
    total_indices_without_dups = [1, 3]

    if method_string == "sum":
      total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
      total_values_without_dups = [[4., 6.], [5., 6.]]
    else:
      assert method_string == "mean"
      # "mean" divides the summed values by the number of devices (2).
      total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
      total_values_without_dups = [[2., 3.], [2.5, 3.]]

    total_mirrored_with_dups = _make_mirrored_indexed_slices(
        devices, total_values_with_dups, total_indices_with_dups, dense_shape)
    total_mirrored_without_dups = _make_mirrored_indexed_slices(
        devices, total_values_without_dups, total_indices_without_dups,
        dense_shape)

    # Test that the result is semantically equal to both the concatenated
    # IndexedSlices, as well as when the duplicate indices are summed up.
    if batch_reduce:
      # batch_reduce returns a list of results, one per input pair.
      total_mirrored_with_dups = [total_mirrored_with_dups]
      total_mirrored_without_dups = [total_mirrored_without_dups]

    self._assert_values_equal(total_mirrored_with_dups, result)
    self._assert_values_equal(total_mirrored_without_dups, result)
+
if __name__ == "__main__":
test.main()
import collections as pycoll
from tensorflow.contrib import nccl
+from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
new_gv_list.insert(idx, gv[gi])
new_tower_grads.append(new_gv_list)
return new_tower_grads
+
+
def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.

  Dispatches on the type of the first element; assumes `values` is homogeneous
  (all dense tensors, or the first one an `IndexedSlices`) — TODO confirm
  behavior for mixed inputs where a dense tensor comes first.
  """
  if not isinstance(values[0], ops.IndexedSlices):
    return accumulation_fn(values)
  # pylint: disable=protected-access
  return gradients_impl._AggregateIndexedSlicesGradients(values)
+
+
def divide_by_n_tensors_or_indexed_slices(value, n):
  """Divide a tensor or `IndexedSlices` elementwise by the scalar `n`."""
  if not isinstance(value, ops.IndexedSlices):
    return value / n
  # Flatten nested IndexedSlices first so that `.values` is a dense tensor.
  flat = gradients_impl._HandleNestedIndexedSlices(value)  # pylint: disable=protected-access
  return ops.IndexedSlices(flat.values / n, flat.indices, flat.dense_shape)
+
+
def copy_tensor_or_indexed_slices_to_device(value, device):
  """Copy a tensor or `IndexedSlices` onto `device` using identity ops."""
  with ops.device(device):
    if not isinstance(value, ops.IndexedSlices):
      return array_ops.identity(value)
    # Copy each component separately so all three land on the target device.
    return ops.IndexedSlices(
        array_ops.identity(value.values),
        array_ops.identity(value.indices),
        array_ops.identity(value.dense_shape))
+
+
def contains_indexed_slices(value):
  """Check whether the value is `IndexedSlices` or contains `IndexedSlices`.

  Recurses into (non-string) sequences, `DistributedValues` and `MapOutput`
  containers.

  Args:
    value: a tensor, an `IndexedSlices`, or a (possibly nested) container of
      either.

  Returns:
    True if `value` is, or transitively contains, an `ops.IndexedSlices`.
  """
  if isinstance(value, ops.IndexedSlices):
    return True
  # Exclude str/bytes explicitly: they register as Sequence and their elements
  # are themselves strings, which would recurse forever. list/tuple are
  # already covered by pycoll.Sequence.
  elif (isinstance(value, pycoll.Sequence) and
        not isinstance(value, (str, bytes)) and value):
    return any(contains_indexed_slices(v) for v in value)
  elif isinstance(value, value_lib.DistributedValues):
    return contains_indexed_slices(list(value._index.values()))  # pylint: disable=protected-access
  elif isinstance(value, value_lib.MapOutput):
    return contains_indexed_slices(value.get())
  else:
    return False
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cross_tower_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import device_util
+
+
class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
  """Tests for the tensor/IndexedSlices helpers in cross_tower_utils."""

  def _assert_values_equal(self, left, right):
    """Assert equality after densifying both operands to tensors."""
    self.assertAllEqual(
        self.evaluate(ops.convert_to_tensor(left)),
        self.evaluate(ops.convert_to_tensor(right)))

  @test_util.run_in_graph_and_eager_modes()
  def testAggregateTensors(self):
    """Dense tensors are summed via the default accumulation_fn (add_n)."""
    t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
    result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
    self._assert_values_equal(total, result)

  @test_util.run_in_graph_and_eager_modes()
  def testAggregateIndexedSlices(self):
    """IndexedSlices inputs aggregate to an IndexedSlices, not a Tensor."""
    t0 = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(
        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
    result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
    self.assertIsInstance(result, ops.IndexedSlices)
    self._assert_values_equal(total, result)

  @test_util.run_in_graph_and_eager_modes()
  def testDivideTensor(self):
    """A dense tensor is divided elementwise by n."""
    t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    n = 2
    expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
    result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
    self._assert_values_equal(expected, result)

  @test_util.run_in_graph_and_eager_modes()
  def testDivideIndexedSlices(self):
    """An IndexedSlices is divided by n and stays an IndexedSlices."""
    t = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    n = 2
    expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
    result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
    self.assertIsInstance(result, ops.IndexedSlices)
    self._assert_values_equal(expected, result)

  @test_util.run_in_graph_and_eager_modes()
  def testIsIndexedSlices(self):
    """A bare IndexedSlices is detected."""
    t = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    self.assertTrue(cross_tower_utils.contains_indexed_slices(t))

  @test_util.run_in_graph_and_eager_modes()
  def testContainsIndexedSlices_List(self):
    """IndexedSlices inside a list are detected."""
    t0 = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(
        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
    self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1]))

  @test_util.run_in_graph_and_eager_modes()
  def testContainsIndexedSlices_Tuple(self):
    """IndexedSlices inside a tuple are detected."""
    t0 = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(
        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
    self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1)))

  @test_util.run_in_graph_and_eager_modes()
  def testContainsIndexedSlices_PerDevice(self):
    """IndexedSlices inside a PerDevice container are detected."""
    t0 = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(
        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
    per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1})
    self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))

  @test_util.run_in_graph_and_eager_modes()
  def testContainsIndexedSlices_PerDeviceMapOutput(self):
    """IndexedSlices nested in MapOutput inside a PerDevice are detected."""
    t0 = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(
        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
    per_device = value_lib.PerDevice({
        "/gpu:0": value_lib.MapOutput([t0]),
        "/cpu:0": value_lib.MapOutput([t1])})
    self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testCopyTensor(self):
    """Copying a dense tensor preserves its value and moves its device."""
    with ops.device("/cpu:0"):
      t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    destination = "/gpu:0"
    result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
        t, destination)

    self._assert_values_equal(t, result)
    self.assertEqual(device_util.resolve(destination),
                     device_util.resolve(result.device))

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testCopyIndexedSlices(self):
    """Copying an IndexedSlices preserves type/value and moves its device."""
    with ops.device("/cpu:0"):
      t = math_ops._as_indexed_slices(
          constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    destination = "/gpu:0"
    result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
        t, destination)

    self.assertIsInstance(result, ops.IndexedSlices)
    self._assert_values_equal(t, result)
    self.assertEqual(device_util.resolve(destination),
                     device_util.resolve(result.device))
+
+
+if __name__ == "__main__":
+ test.main()
logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
tensor_shape, used)
else:
- out_grad = math_ops._as_indexed_slices_list(
- [g for g in out_grad if g is not None])
- out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
- # Form IndexedSlices out of the concatenated values and
- # indices.
- out_grads[i] = ops.IndexedSlices(
- array_ops.concat([x.values for x in out_grad], 0),
- array_ops.concat([x.indices for x in out_grad], 0),
- out_grad[0].dense_shape)
+ out_grads[i] = _AggregateIndexedSlicesGradients(out_grad)
else: # not out_grad
# out_grads[i] is [], thus its aggregation is simply None.
out_grads[i] = None
return out_grads
def _AggregateIndexedSlicesGradients(grads):
  """Aggregates gradients containing `IndexedSlices` by concatenation.

  Args:
    grads: a list of gradients, each an `IndexedSlices`, a dense `Tensor`
      convertible to `IndexedSlices`, or None.

  Returns:
    A single `IndexedSlices` holding the concatenated values and indices of
    all non-None gradients; the single gradient unchanged when only one is
    given; or None when there is nothing to aggregate.
  """
  if len(grads) < 1:
    return None
  elif len(grads) == 1:
    return grads[0]
  else:
    grads = [g for g in grads if g is not None]
    # Every entry may have been None; there is nothing to aggregate then.
    if not grads:
      return None
    # _as_indexed_slices_list converts dense Tensors at any position in the
    # list into IndexedSlices, so mixed Tensor/IndexedSlices input is fine.
    grads = math_ops._as_indexed_slices_list(grads)  # pylint: disable=protected-access
    grads = [_HandleNestedIndexedSlices(x) for x in grads]  # pylint: disable=protected-access
    # Form a single IndexedSlices from the concatenated values and indices.
    concat_grad = ops.IndexedSlices(
        array_ops.concat([x.values for x in grads], axis=0),
        array_ops.concat([x.indices for x in grads], axis=0),
        grads[0].dense_shape)

    return concat_grad
+
+
# TODO(vrv): Make this available when we want to make it public.
def _hessian_vector_product(ys, xs, v):
"""Multiply the Hessian of `ys` wrt `xs` by `v`.
self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
  """Tests for gradients_impl._AggregateIndexedSlicesGradients."""

  def _assert_indexed_slices_equal(self, left, right):
    # Compare after densifying, so duplicate indices are summed consistently.
    self.assertAllEqual(
        self.evaluate(ops.convert_to_tensor(left)),
        self.evaluate(ops.convert_to_tensor(right)))

  def testNoGradients(self):
    """An empty gradient list aggregates to None."""
    self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([]))

  def testOneGradient(self):
    """A single gradient is returned unchanged."""
    t = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    result = gradients_impl._AggregateIndexedSlicesGradients([t])
    self._assert_indexed_slices_equal(t, result)

  def testMultipleGradients(self):
    """Two IndexedSlices gradients aggregate to their elementwise sum."""
    t0 = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(constant_op.constant(
        [[0., 0.], [5, 6], [7., 8.]]))
    total = constant_op.constant(
        [[1., 2.], [5, 6], [10., 12.]])
    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
    self._assert_indexed_slices_equal(total, result)

  def testMultipleGradientsWithNones(self):
    """None entries are ignored during aggregation."""
    t0 = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(constant_op.constant(
        [[0., 0.], [5, 6], [7., 8.]]))
    t3 = None
    total = constant_op.constant(
        [[1., 2.], [5, 6], [10., 12.]])
    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3])
    self._assert_indexed_slices_equal(total, result)

  def testMixedTensorAndIndexedSlices(self):
    """A dense Tensor mixed with an IndexedSlices still aggregates."""
    t0 = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    t1 = constant_op.constant(
        [[0., 0.], [5, 6], [7., 8.]])
    total = constant_op.constant(
        [[1., 2.], [5, 6], [10., 12.]])
    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
    self._assert_indexed_slices_equal(total, result)
+
+
if __name__ == "__main__":
googletest.main()