From: A. Unique TensorFlower
Date: Thu, 26 Apr 2018 19:42:54 +0000 (-0700)
Subject: For tf.gradients(), do not backpropagate through integer tensors.
X-Git-Tag: upstream/v1.9.0_rc1~206^2~1^2~31
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f63750645826df65b05cad505546a86f0e347674;p=platform%2Fupstream%2Ftensorflow.git

For tf.gradients(), do not backpropagate through integer tensors.

All integer tensors are now considered constant with respect to all `xs`.
This fixes a bug in gradients through tf.while_loop.

PiperOrigin-RevId: 194438529
---

diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index 7624d6e..f332aa2 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -472,7 +472,9 @@ class TensorArrayTest(xla_test.XLATestCase):
       self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1])
 
   def testTensorArrayGradientWriteRead(self):
-    for dtype in self.numeric_types:
+    for dtype in self.float_types:
+      self._testTensorArrayGradientWriteReadType(dtype)
+    for dtype in self.complex_types:
       self._testTensorArrayGradientWriteReadType(dtype)
 
   def _testTensorArrayGradientWritePackConcatAndRead(self):
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index fac7aff..e22f978 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -250,7 +250,7 @@ class BatchOpsTest(test.TestCase):
   def testUnbatchGrad(self):
     """Tests that batch and unbatch are differentiable."""
     with self.test_session() as sess:
-      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+      inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
           batch_timeout_micros=36000000, grad_timeout_micros=1000000,
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 29a593f..b2f678f 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -175,7 +175,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
 
   def testCompilationInGradient(self):
     with self.test_session():
-      x = constant_op.constant([[3]])
+      x = constant_op.constant([[3.]])
       y_nc = math_ops.matmul(x, x, name="not_compiled")
       with jit.experimental_jit_scope():
         y_c = math_ops.matmul(y_nc, y_nc, name="compiled")
@@ -200,11 +200,11 @@ class CompilationEnabledInGradientTest(test.TestCase):
     with self.test_session(graph=ops.Graph()):
       with jit.experimental_jit_scope():
         # XlaScope 0
-        a1 = constant_op.constant([[1]])
+        a1 = constant_op.constant([[1.]])
         a1t = math_ops.matmul(a1, a1)
         with jit.experimental_jit_scope():
           # XlaScope 1
-          a2 = constant_op.constant([[1]])
+          a2 = constant_op.constant([[1.]])
           a2t = math_ops.matmul(a2, a2)
 
     self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
@@ -222,11 +222,11 @@ class CompilationEnabledInGradientTest(test.TestCase):
     with self.test_session(graph=ops.Graph()):
       with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
         # XlaScope 0
-        a1 = constant_op.constant([[1]])
+        a1 = constant_op.constant([[1.]])
         a1t = math_ops.matmul(a1, a1)
         with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
           # XlaScope 1
-          a2 = constant_op.constant([[1]])
+          a2 = constant_op.constant([[1.]])
           a2t = math_ops.matmul(a2, a2)
 
     self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index 0af282a..820c167 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -51,18 +51,15 @@ from tensorflow.python.util import compat
 
 class IteratorTest(test.TestCase):
 
-  def testAttemptingGradientsRaiseExceptions(self):
-    component = constant_op.constant([1])
-    side = constant_op.constant(0)
+  def testNoGradients(self):
+    component = constant_op.constant([1.])
+    side = constant_op.constant(0.)
     add = lambda x: x + side
     dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
     value = dataset.make_one_shot_iterator().get_next()
-    with self.assertRaisesRegexp(LookupError, "No gradient defined"):
-      gradients_impl.gradients(value, component)
-    with self.assertRaisesRegexp(LookupError, "No gradient defined"):
-      gradients_impl.gradients(value, side)
-    with self.assertRaisesRegexp(LookupError, "No gradient defined"):
-      gradients_impl.gradients(value, [component, side])
+    self.assertIsNone(gradients_impl.gradients(value, component)[0])
+    self.assertIsNone(gradients_impl.gradients(value, side)[0])
+    self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])
 
   def testCapturingStateInOneShotRaisesException(self):
     var = variables.Variable(37.0, name="myvar")
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 1828c98..185f6d9 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -309,7 +309,7 @@ class FunctionTest(test.TestCase):
     def g(x):
       return backprop.gradients_function(f, [0])(x)[0]
 
-    self.assertAllEqual(2, g(constant_op.constant(2)))
+    self.assertAllEqual(2, g(constant_op.constant(2.)))
 
   def testGraphModeEagerGradError(self):
     with context.graph_mode():
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index e5b1576..0532ed4 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -476,11 +476,12 @@ class ScopedMetaGraphTest(test.TestCase):
     # Create a simple while loop.
     with ops.Graph().as_default():
       with ops.name_scope("export"):
-        var = variables.Variable(0)
+        var = variables.Variable(0.)
         var_name = var.name
-        _, output = control_flow_ops.while_loop(lambda i, x: i < 5,
-                                                lambda i, x: (i + 1, x + i),
-                                                [0, var])
+        _, output = control_flow_ops.while_loop(
+            lambda i, x: i < 5,
+            lambda i, x: (i + 1, x + math_ops.cast(i, dtypes.float32)),
+            [0, var])
         output_name = output.name
 
     # Generate a MetaGraphDef containing the while loop with an export scope.
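The iterator test above captures the user-visible side of this change: asking for a gradient through a tf.data iterator now yields None instead of raising LookupError. A minimal sketch of that behavior, assuming TF 1.x graph mode and using the public tf.* aliases in place of the internal test modules:

    import tensorflow as tf  # TF 1.x graph mode assumed

    component = tf.constant([1.])
    side = tf.constant(0.)
    dataset = tf.data.Dataset.from_tensor_slices(component).map(
        lambda x: x + side)
    value = dataset.make_one_shot_iterator().get_next()

    # The iterator output is not reachable from `component` through any
    # backpropagatable path, so each requested gradient is simply None.
    print(tf.gradients(value, component))          # [None]
    print(tf.gradients(value, [component, side]))  # [None, None]
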
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 5a20eeb..7acca0a 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -730,7 +730,7 @@ class GradSliceChecker(object):
     analytic_grad2 = 2 * slice_val
 
     dy = variables.Variable(
-        array_ops.ones(shape=slice_var.get_shape(), dtype=dtypes.int32))
+        array_ops.ones(shape=slice_var.get_shape(), dtype=dtypes.float32))
     assign = dy.assign(slice_var)
     slice_val_grad, = gradients_impl.gradients(slice_val, self.var, grad_ys=dy)
     slice_val_grad2, = gradients_impl.gradients(
@@ -755,7 +755,8 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase):
   def testGradient(self):
     with self.test_session(use_gpu=True) as sess:
       var = variables.Variable(
-          array_ops.reshape(math_ops.range(1, 97, 1), shape=(6, 4, 4)))
+          array_ops.reshape(
+              math_ops.range(1, 97, 1, dtype=dtypes.float32), shape=(6, 4, 4)))
       init = variables.global_variables_initializer()
       sess.run(init)
 
@@ -774,7 +775,7 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase):
 
   def testGradientZero(self):
     with self.test_session(use_gpu=True) as sess:
-      var = variables.Variable(8)
+      var = variables.Variable(8.)
       init = variables.global_variables_initializer()
       sess.run(init)
       grad = GradSliceChecker(self, sess, var, np.array(8))
@@ -782,11 +783,11 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase):
 
   def testInt64Indices(self):
     with self.test_session(use_gpu=True) as sess:
-      a = math_ops.range(3)
+      a = math_ops.range(3, dtype=dtypes.float32)
       index = constant_op.constant(1, dtype=dtypes.int64)
-      b = 2 * a[index]
+      b = 2. * a[index]
       grad, = gradients_impl.gradients(b, a)
-      self.assertAllEqual(sess.run(grad), [0, 2, 0])
+      self.assertAllEqual(sess.run(grad), [0., 2., 0.])
 
 
 class StridedSliceGradTypeTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 209411c..77e6f5f 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -2222,14 +2222,14 @@ class ControlFlowTest(test.TestCase):
 
   def testWhileWithRefsWithGradients_1(self):
    with self.test_session() as sess:
-      x = variables.Variable(0)._ref()  # pylint: disable=protected-access
+      x = variables.Variable(0.)._ref()  # pylint: disable=protected-access
       i = constant_op.constant(0)
       c = lambda i, x: math_ops.less(i, 10)
 
-      self.assertEqual(x.dtype, dtypes.int32_ref)
+      self.assertEqual(x.dtype, dtypes.float32_ref)
 
       def body(i, x):
-        self.assertEqual(x.dtype, dtypes.int32_ref)
+        self.assertEqual(x.dtype, dtypes.float32_ref)
         return [i + 1, gen_array_ops.ref_identity(x)]
 
       r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
@@ -2240,7 +2240,7 @@ class ControlFlowTest(test.TestCase):
       variables.global_variables_initializer().run()
 
       self.assertEqual(r[0].dtype, dtypes.int32)
-      self.assertEqual(r[1].dtype, dtypes.int32_ref)
+      self.assertEqual(r[1].dtype, dtypes.float32_ref)
 
       value_i, value_x, value_x_grad = sess.run(r + grad)
 
@@ -2443,6 +2443,63 @@ class ControlFlowTest(test.TestCase):
       r = gradients_impl.gradients(r, y)[0]
       self.assertEqual(388.0, r.eval())
 
+  def testWhileGradientWithNontrainablePath1(self):
+    q = variables.Variable([7., 8.])
+
+    def cond(_, y):
+      del y
+      return False
+
+    def body(x, _):
+      return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
+
+    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
+    dy_dq, = gradients_impl.gradients(y, q)
+    self.assertIsNotNone(dy_dq)
+    with self.test_session() as sess:
+      sess.run(q.initializer)
+      self.assertAllClose([0., 0.], sess.run(dy_dq))
+
+  def testWhileGradientWithNontrainablePath2(self):
+    q = variables.Variable([7., 8.])
+
+    def cond(_, y):
+      return math_ops.equal(y, 0.)
+
+    def body(x, _):
+      zero = constant_op.constant(0, dtype=dtypes.int64)
+      return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
+
+    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
+    dy_dq, = gradients_impl.gradients(y, q)
+    self.assertIsNotNone(dy_dq)
+    with self.test_session() as sess:
+      sess.run(q.initializer)
+      self.assertAllClose([1., 1.], sess.run(dy_dq))
+
+  def testIssue16504(self):
+    c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
+    w = variables.Variable(
+        initial_value=np.ones(100), dtype=dtypes.float32) / 100
+    k = variables.Variable(0, dtype=dtypes.int32)
+    chg_w = constant_op.constant(np.inf, dtype=dtypes.float32)
+
+    def cond(k, _, chg_w):
+      return math_ops.logical_and(k < 10, chg_w > 1e-3)
+
+    def body(k, w, chg_w):
+      grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w)
+      w_n = w * math_ops.exp(-0.1 * grad)
+      w_n /= math_ops.reduce_sum(w_n)
+      chg_w = (
+          math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum(
+              math_ops.abs(w)))
+      return k + 1, w_n, chg_w
+
+    _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w])
+    grad, = gradients_impl.gradients(w, c)
+    self.assertIsNotNone(grad)
+
   def testStopGradMultiFlows(self):
     with self.test_session():
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index a4b30e4..159cba5 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -113,22 +113,23 @@ class DynamicStitchTestBase(object):
           constant_op.constant([[5, 2], [0, 3]])
       ]
       data = [
-          constant_op.constant([61, 62]),
-          constant_op.constant([[41, 42], [11, 12]]),
-          constant_op.constant([[[51, 52], [21, 22]], [[1, 2], [31, 32]]])
+          constant_op.constant([61., 62.]),
+          constant_op.constant([[41., 42.], [11., 12.]]),
+          constant_op.constant([[[51., 52.], [21., 22.]],
+                                [[1., 2.], [31., 32.]]])
       ]
       stitched_t = self.stitch_op(indices, data)
       stitched_val = stitched_t.eval()
-      correct = 10 * np.arange(7)[:, None] + [1, 2]
+      correct = 10. * np.arange(7)[:, None] + [1., 2.]
       self.assertAllEqual(correct, stitched_val)
       self.assertEqual([7, 2], stitched_t.get_shape().as_list())
       # Test gradients
-      stitched_grad = 7 * stitched_val
+      stitched_grad = 7. * stitched_val
       grads = gradients_impl.gradients(stitched_t, indices + data,
                                        stitched_grad)
       self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
       for datum, grad in zip(data, sess.run(grads[3:])):
-        self.assertAllEqual(7 * datum.eval(), grad)
+        self.assertAllEqual(7. * datum.eval(), grad)
 
   def testErrorIndicesMultiDimensional(self):
     indices = [
diff --git a/tensorflow/python/kernel_tests/gradient_correctness_test.py b/tensorflow/python/kernel_tests/gradient_correctness_test.py
index 10fe4f5..e93c623 100644
--- a/tensorflow/python/kernel_tests/gradient_correctness_test.py
+++ b/tensorflow/python/kernel_tests/gradient_correctness_test.py
@@ -40,6 +40,71 @@ class GradientCorrectnessTest(test.TestCase):
       # [dexp(x)/dx + d(log(exp(x)))/dx] @ x=1 == exp(1) + 1
       self.assertAllClose(grad_vals[0], exp1_plus_one)
 
+  def testIdentityGradient(self):
+    x = constant_op.constant(3.)
+    dx_dx, = gradients_impl.gradients(x, x)
+    with self.test_session() as sess:
+      self.assertAllClose(1., sess.run(dx_dx))
+
+  def testIntegerIdentityGradient(self):
+    x = constant_op.constant(3)
+    dx_dx, = gradients_impl.gradients(x, x)
+    with self.test_session() as sess:
+      self.assertAllClose(1, sess.run(dx_dx))
+
+  def testGradientWithIntegerPath(self):
+    x = constant_op.constant([3.9, 4.1])
+    k = math_ops.to_float(math_ops.to_int32(x))
+    y = x * k
+    dy_dx, = gradients_impl.gradients(y, x)
+    with self.test_session() as sess:
+      self.assertAllClose([3., 4.], sess.run(dy_dx))
+
+  def testNoIntegerGradient1(self):
+    x = constant_op.constant([3.9, 4.1])
+    k = math_ops.to_float(math_ops.to_int32(x))
+    y = k * k
+    dy_dx, = gradients_impl.gradients(y, x)
+    self.assertIsNone(dy_dx)
+
+  def testNoIntegerGradient2(self):
+    k = constant_op.constant([3, 4])
+    x = math_ops.to_float(k)
+    y = x * x
+    dy_dk, = gradients_impl.gradients(y, k)
+    self.assertIsNone(dy_dk)
+
+  def testNoIntegerGradient3(self):
+    k = constant_op.constant([3, 4])
+    m = k * k
+    dm_dk, = gradients_impl.gradients(m, k)
+    self.assertIsNone(dm_dk)
+
+  def testNoIntegerGradient4(self):
+    k = constant_op.constant([3, 4])
+    m = k * k * k
+    dm_dk, = gradients_impl.gradients(m, k)
+    self.assertIsNone(dm_dk)
+
+  def testNoIntegerGradient5(self):
+    k = constant_op.constant([3, 4])
+    m = k * k
+    n = m * m
+    dn_dk, = gradients_impl.gradients(n, k)
+    self.assertIsNone(dn_dk)
+
+  def testNoIntegerGradient6(self):
+    k = constant_op.constant(3)
+    x = math_ops.to_float(k)
+    grad_1, = gradients_impl.gradients(k * k, k)
+    grad_2, = gradients_impl.gradients(x * x, k)
+    grad_3, = gradients_impl.gradients(math_ops.square(k), k)
+    grad_4, = gradients_impl.gradients(math_ops.square(x), k)
+    self.assertIsNone(grad_1)
+    self.assertIsNone(grad_2)
+    self.assertIsNone(grad_3)
+    self.assertIsNone(grad_4)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/kernel_tests/nth_element_op_test.py b/tensorflow/python/kernel_tests/nth_element_op_test.py
index 58cd46d..1b8f021 100644
--- a/tensorflow/python/kernel_tests/nth_element_op_test.py
+++ b/tensorflow/python/kernel_tests/nth_element_op_test.py
@@ -154,14 +154,14 @@ class NthElementTest(test.TestCase):
 
   def testGradients(self):
     with self.test_session(use_gpu=False) as sess:
-      inputs = array_ops.placeholder(dtypes.int32, shape=[3, 5])
+      inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
       values = nn_ops.nth_element(inputs, 3)
       grad = sess.run(
           gradients_impl.gradients(
               values, inputs, grad_ys=[[-1., 2., 5.]]),
-          feed_dict={inputs: [[2, -1, 1000, 3, 1000],
-                              [1, 5, 2, 4, 3],
-                              [2, 2, 2, 2, 2],
+          feed_dict={inputs: [[2., -1., 1000., 3., 1000.],
+                              [1., 5., 2., 4., 3.],
+                              [2., 2., 2., 2., 2.],
                              ]})
     self.assertAllClose(grad[0], [[0, 0, -0.5, 0, -0.5],
                                   [0, 0, 0, 2, 0],
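The gradient_correctness tests above pin down the new rule: a gradient flows around an integer tensor only when a purely float path also exists. A condensed sketch of the two central cases, with the public tf.* names standing in for the internal test modules (tf.to_float/tf.to_int32 as used in the tests; TF 1.x graph mode assumed):

    import tensorflow as tf  # TF 1.x graph mode assumed

    x = tf.constant([3.9, 4.1])
    k = tf.to_float(tf.to_int32(x))  # float -> int -> float round trip

    # The int32 hop makes k constant w.r.t. x, so d(k*k)/dx is None ...
    print(tf.gradients(k * k, x))  # [None]

    # ... while y = x * k still differentiates, with k held constant:
    dy_dx, = tf.gradients(x * k, x)
    with tf.Session() as sess:
      print(sess.run(dy_dx))  # [3., 4.]
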
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index a834675..918bbd3 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -615,8 +615,7 @@ class TensorArrayTest(test.TestCase):
       self.assertAllEqual(c(-2.0), grad_vals[1])
 
   def testTensorArrayGradientWriteRead(self):
-    for dtype in (np.float32, np.float64, np.int32, np.int64, np.complex64,
-                  np.complex128):
+    for dtype in (np.float32, np.float64, np.complex64, np.complex128):
       self._testTensorArrayGradientWriteReadType(dtype)
 
   def _testTensorArrayGradientWritePackConcatAndRead(self):
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index 6ab931f..fa7c6a0 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -197,13 +197,15 @@ class TopKTest(test.TestCase):
 
   def testTopKGradients(self):
     with self.test_session(use_gpu=True) as sess:
-      inputs = array_ops.placeholder(dtypes.int32, shape=[2, 5])
+      inputs = array_ops.placeholder(dtypes.float32, shape=[2, 5])
       values, _ = nn_ops.top_k(inputs, 3)
       grad = sess.run(
           gradients_impl.gradients(
-              values, inputs, grad_ys=[[[1, 2, 3], [4, 5, 6]]]),
-          feed_dict={inputs: [[2, -1, 1000, 3, 4], [1, 5, 2, 4, 3]]})[0]
-      self.assertEqual(grad.tolist(), [[0, 0, 1, 3, 2], [0, 4, 0, 5, 6]])
+              values, inputs, grad_ys=[[[1., 2., 3.], [4., 5., 6.]]]),
+          feed_dict={inputs: [[2., -1., 1000., 3., 4.],
+                              [1., 5., 2., 4., 3.]]})[0]
+      self.assertEqual(
+          grad.tolist(), [[0., 0., 1., 3., 2.], [0., 4., 0., 5., 6.]])
 
 
 class TopKBenchmark(test.Benchmark):
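The implementation change, in the diff that follows, hangs off a single dtype predicate: reachability from `xs` is now traced only through "backpropagatable" tensors, and any `ys` op left unreachable is dropped from the work queue, so its gradient comes back as None. A standalone sketch of the same dtype filter (function and constant names here are hypothetical; the real code below reuses the existing _IsTrainable helper, which is assumed to cover the usual float and complex dtypes):

    from tensorflow.python.framework import dtypes

    _TRAINABLE = (dtypes.float16, dtypes.float32, dtypes.float64,
                  dtypes.complex64, dtypes.complex128)
    _EXTRA = (dtypes.bfloat16, dtypes.resource, dtypes.variant)

    def is_backpropagatable(tensor):
      # Gradients may flow through floats, complex numbers, bfloat16, and
      # the opaque resource/variant handles; integers and bools never qualify.
      return tensor.dtype.base_dtype in _TRAINABLE + _EXTRA
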
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 13420b7..581ba7d 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -121,7 +121,8 @@ def _MarkReachedOps(from_ops, reached_ops):
     if not reached_ops[op._id]:
       reached_ops[op._id] = True
       for output in op.outputs:
-        queue.extend(output.consumers())
+        if _IsBackpropagatable(output):
+          queue.extend(output.consumers())
 
 
 def _GatherInputs(to_ops, reached_ops):
@@ -163,16 +164,19 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
     colocate_gradients_with_ops: Python bool.  See docstring of gradients().
 
   Returns:
-    A tuple containing: (1) a list of integers indexed by operation id,
-    indicating the number of backprop inputs to this operation, and (2)
-    a ControlFlowState object which is not None if the ops between from_ops
-    and to_ops contain control flow loops.
+    A tuple containing: (1) the subset of to_ops ids reachable from from_ops
+    by a path of zero or more backpropagatable tensors, (2) a list of integers
+    indexed by operation id, indicating the number of backprop inputs to this
+    operation, and (3) a ControlFlowState object which is not None if the ops
+    between from_ops and to_ops contain control flow loops.
   """
   # Mark reachable ops from from_ops.
   reached_ops = [False] * (graph._last_id + 1)
-  for op in to_ops:
-    reached_ops[op._id] = True
   _MarkReachedOps(from_ops, reached_ops)
+  # reached_ops[X] iff X is reachable from from_ops by a path of zero or more
+  # backpropagatable tensors.
+
+  reachable_to_ops = set(op._id for op in to_ops if reached_ops[op._id])  # pylint: disable=protected-access
 
   # Mark between ops.
   between_ops = [False] * (graph._last_id + 1)
@@ -189,6 +193,8 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
         reached_ops[op._id] = False
         for inp in op.inputs:
           queue.append(inp.op)
+  # between_ops[X] iff X is on a path of zero or more backpropagatable tensors
+  # between from_ops and to_ops
 
   # 'loop_state' is None if there are no while loops.
   loop_state = control_flow_ops.MaybeCreateControlFlowState(
@@ -201,7 +207,7 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
       if between_ops[x.op._id]:
         pending_count[x.op._id] += 1
 
-  return pending_count, loop_state
+  return reachable_to_ops, pending_count, loop_state
 
 
 def _AsList(x):
@@ -294,6 +300,13 @@ def _IsTrainable(tensor):
                              dtypes.complex64, dtypes.complex128)
 
 
+def _IsBackpropagatable(tensor):
+  if _IsTrainable(tensor):
+    return True
+  dtype = dtypes.as_dtype(tensor.dtype)
+  return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant)
+
+
 def _VerifyGeneratedGradients(grads, op):
   """Verify that gradients are valid in number and type.
 
@@ -460,6 +473,9 @@ def gradients(ys,
   backpropagation stops at both `tf.stop_gradient` nodes and nodes in
   `stop_gradients`, whichever is encountered first.
 
+  All integer tensors are considered constant with respect to all `xs`, as if
+  they were included in `stop_gradients`.
+
   Args:
     ys: A `Tensor` or list of tensors to be differentiated.
     xs: A `Tensor` or list of tensors to be used for differentiation.
@@ -539,7 +555,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
   to_ops = [t.op for t in ys]
   from_ops = [t.op for t in xs]
   stop_gradient_ops = [t.op for t in stop_gradients]
-  pending_count, loop_state = _PendingCount(
+  reachable_to_ops, pending_count, loop_state = _PendingCount(
       ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
 
   # Iterate over the collected ops.
@@ -564,7 +580,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
         # another output's gradient.
         # pylint: disable=protected-access
         ready = (pending_count[op._id] == 0)
-        if ready and op._id not in to_ops_set:
+        if ready and op._id not in to_ops_set and op._id in reachable_to_ops:
          to_ops_set.add(op._id)
           queue.append(op)
         # pylint: enable=protected-access
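Putting it together, this is the tf.while_loop bug the commit message refers to: an integer loop variable (here the int64 output of argmin) used to derail gradient construction for the whole loop. A minimal sketch of the now-working case, adapted from testWhileGradientWithNontrainablePath1 above to the public tf.* API (TF 1.x graph mode assumed):

    import tensorflow as tf  # TF 1.x graph mode assumed

    q = tf.Variable([7., 8.])

    def body(x, _):
      # x is int64; it rides through the loop unchanged and only feeds the
      # float output through an explicit cast.
      return x, tf.cast(x, tf.float32) + tf.reduce_sum(q)

    _, y = tf.while_loop(lambda x, y: False, body, (tf.argmin(q), 0.))
    dy_dq, = tf.gradients(y, q)  # no longer raises; the int path is constant

    with tf.Session() as sess:
      sess.run(q.initializer)
      print(sess.run(dy_dq))  # [0., 0.] -- cond is always False here
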