For tf.gradients(), do not backpropagate through integer tensors.
author     A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 26 Apr 2018 19:42:54 +0000 (12:42 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 19:45:22 +0000 (12:45 -0700)
All integer tensors are now considered constant with respect to all `xs`.
This fixes a bug in gradients through tf.while_loop.

PiperOrigin-RevId: 194438529
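
For illustration only (not part of this commit), a minimal sketch of the resulting behavior, mirroring testNoIntegerGradient1 below; assumes TF 1.x graph mode:

    import tensorflow as tf

    x = tf.constant([3.9, 4.1])
    k = tf.to_float(tf.to_int32(x))  # only path from x to y goes through an integer tensor
    y = k * k
    dy_dx, = tf.gradients(y, x)
    print(dy_dx)  # None: integer tensors are treated as constant with respect to x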

14 files changed:
tensorflow/compiler/tests/tensor_array_ops_test.py
tensorflow/contrib/batching/python/ops/batch_ops_test.py
tensorflow/contrib/compiler/jit_test.py
tensorflow/python/data/kernel_tests/iterator_ops_test.py
tensorflow/python/eager/function_test.py
tensorflow/python/framework/meta_graph_test.py
tensorflow/python/kernel_tests/array_ops_test.py
tensorflow/python/kernel_tests/control_flow_ops_py_test.py
tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
tensorflow/python/kernel_tests/gradient_correctness_test.py
tensorflow/python/kernel_tests/nth_element_op_test.py
tensorflow/python/kernel_tests/tensor_array_ops_test.py
tensorflow/python/kernel_tests/topk_op_test.py
tensorflow/python/ops/gradients_impl.py

diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index 7624d6e..f332aa2 100644
@@ -472,7 +472,9 @@ class TensorArrayTest(xla_test.XLATestCase):
       self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1])
 
   def testTensorArrayGradientWriteRead(self):
-    for dtype in self.numeric_types:
+    for dtype in self.float_types:
+      self._testTensorArrayGradientWriteReadType(dtype)
+    for dtype in self.complex_types:
       self._testTensorArrayGradientWriteReadType(dtype)
 
   def _testTensorArrayGradientWritePackConcatAndRead(self):
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index fac7aff..e22f978 100644
@@ -250,7 +250,7 @@ class BatchOpsTest(test.TestCase):
   def testUnbatchGrad(self):
     """Tests that batch and unbatch are differentiable."""
     with self.test_session() as sess:
-      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+      inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
           batch_timeout_micros=36000000, grad_timeout_micros=1000000,
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 29a593f..b2f678f 100644
@@ -175,7 +175,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
 
   def testCompilationInGradient(self):
     with self.test_session():
-      x = constant_op.constant([[3]])
+      x = constant_op.constant([[3.]])
       y_nc = math_ops.matmul(x, x, name="not_compiled")
       with jit.experimental_jit_scope():
         y_c = math_ops.matmul(y_nc, y_nc, name="compiled")
@@ -200,11 +200,11 @@ class CompilationEnabledInGradientTest(test.TestCase):
     with self.test_session(graph=ops.Graph()):
       with jit.experimental_jit_scope():
         # XlaScope 0
-        a1 = constant_op.constant([[1]])
+        a1 = constant_op.constant([[1.]])
         a1t = math_ops.matmul(a1, a1)
       with jit.experimental_jit_scope():
         # XlaScope 1
-        a2 = constant_op.constant([[1]])
+        a2 = constant_op.constant([[1.]])
         a2t = math_ops.matmul(a2, a2)
 
       self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
@@ -222,11 +222,11 @@ class CompilationEnabledInGradientTest(test.TestCase):
     with self.test_session(graph=ops.Graph()):
       with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
         # XlaScope 0
-        a1 = constant_op.constant([[1]])
+        a1 = constant_op.constant([[1.]])
         a1t = math_ops.matmul(a1, a1)
       with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
         # XlaScope 1
-        a2 = constant_op.constant([[1]])
+        a2 = constant_op.constant([[1.]])
         a2t = math_ops.matmul(a2, a2)
 
       self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index 0af282a..820c167 100644
@@ -51,18 +51,15 @@ from tensorflow.python.util import compat
 
 class IteratorTest(test.TestCase):
 
-  def testAttemptingGradientsRaiseExceptions(self):
-    component = constant_op.constant([1])
-    side = constant_op.constant(0)
+  def testNoGradients(self):
+    component = constant_op.constant([1.])
+    side = constant_op.constant(0.)
     add = lambda x: x + side
     dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
     value = dataset.make_one_shot_iterator().get_next()
-    with self.assertRaisesRegexp(LookupError, "No gradient defined"):
-      gradients_impl.gradients(value, component)
-    with self.assertRaisesRegexp(LookupError, "No gradient defined"):
-      gradients_impl.gradients(value, side)
-    with self.assertRaisesRegexp(LookupError, "No gradient defined"):
-      gradients_impl.gradients(value, [component, side])
+    self.assertIsNone(gradients_impl.gradients(value, component)[0])
+    self.assertIsNone(gradients_impl.gradients(value, side)[0])
+    self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])
 
   def testCapturingStateInOneShotRaisesException(self):
     var = variables.Variable(37.0, name="myvar")
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 1828c98..185f6d9 100644
@@ -309,7 +309,7 @@ class FunctionTest(test.TestCase):
     def g(x):
       return backprop.gradients_function(f, [0])(x)[0]
 
-    self.assertAllEqual(2, g(constant_op.constant(2)))
+    self.assertAllEqual(2, g(constant_op.constant(2.)))
 
   def testGraphModeEagerGradError(self):
     with context.graph_mode():
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index e5b1576..0532ed4 100644
@@ -476,11 +476,12 @@ class ScopedMetaGraphTest(test.TestCase):
     # Create a simple while loop.
     with ops.Graph().as_default():
       with ops.name_scope("export"):
-        var = variables.Variable(0)
+        var = variables.Variable(0.)
         var_name = var.name
-        _, output = control_flow_ops.while_loop(lambda i, x: i < 5,
-                                                lambda i, x: (i + 1, x + i),
-                                                [0, var])
+        _, output = control_flow_ops.while_loop(
+            lambda i, x: i < 5,
+            lambda i, x: (i + 1, x + math_ops.cast(i, dtypes.float32)),
+            [0, var])
         output_name = output.name
 
       # Generate a MetaGraphDef containing the while loop with an export scope.
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 5a20eeb..7acca0a 100644
@@ -730,7 +730,7 @@ class GradSliceChecker(object):
     analytic_grad2 = 2 * slice_val
 
     dy = variables.Variable(
-        array_ops.ones(shape=slice_var.get_shape(), dtype=dtypes.int32))
+        array_ops.ones(shape=slice_var.get_shape(), dtype=dtypes.float32))
     assign = dy.assign(slice_var)
     slice_val_grad, = gradients_impl.gradients(slice_val, self.var, grad_ys=dy)
     slice_val_grad2, = gradients_impl.gradients(
@@ -755,7 +755,8 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase):
   def testGradient(self):
     with self.test_session(use_gpu=True) as sess:
       var = variables.Variable(
-          array_ops.reshape(math_ops.range(1, 97, 1), shape=(6, 4, 4)))
+          array_ops.reshape(
+              math_ops.range(1, 97, 1, dtype=dtypes.float32), shape=(6, 4, 4)))
       init = variables.global_variables_initializer()
       sess.run(init)
 
@@ -774,7 +775,7 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase):
 
   def testGradientZero(self):
     with self.test_session(use_gpu=True) as sess:
-      var = variables.Variable(8)
+      var = variables.Variable(8.)
       init = variables.global_variables_initializer()
       sess.run(init)
       grad = GradSliceChecker(self, sess, var, np.array(8))
@@ -782,11 +783,11 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase):
 
   def testInt64Indices(self):
     with self.test_session(use_gpu=True) as sess:
-      a = math_ops.range(3)
+      a = math_ops.range(3, dtype=dtypes.float32)
       index = constant_op.constant(1, dtype=dtypes.int64)
-      b = 2 * a[index]
+      b = 2. * a[index]
       grad, = gradients_impl.gradients(b, a)
-      self.assertAllEqual(sess.run(grad), [0, 2, 0])
+      self.assertAllEqual(sess.run(grad), [0., 2., 0.])
 
 
 class StridedSliceGradTypeTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 209411c..77e6f5f 100644
@@ -2222,14 +2222,14 @@ class ControlFlowTest(test.TestCase):
 
   def testWhileWithRefsWithGradients_1(self):
     with self.test_session() as sess:
-      x = variables.Variable(0)._ref()  # pylint: disable=protected-access
+      x = variables.Variable(0.)._ref()  # pylint: disable=protected-access
       i = constant_op.constant(0)
       c = lambda i, x: math_ops.less(i, 10)
 
-      self.assertEqual(x.dtype, dtypes.int32_ref)
+      self.assertEqual(x.dtype, dtypes.float32_ref)
 
       def body(i, x):
-        self.assertEqual(x.dtype, dtypes.int32_ref)
+        self.assertEqual(x.dtype, dtypes.float32_ref)
         return [i + 1, gen_array_ops.ref_identity(x)]
 
       r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
@@ -2240,7 +2240,7 @@ class ControlFlowTest(test.TestCase):
       variables.global_variables_initializer().run()
 
       self.assertEqual(r[0].dtype, dtypes.int32)
-      self.assertEqual(r[1].dtype, dtypes.int32_ref)
+      self.assertEqual(r[1].dtype, dtypes.float32_ref)
 
       value_i, value_x, value_x_grad = sess.run(r + grad)
 
@@ -2443,6 +2443,63 @@ class ControlFlowTest(test.TestCase):
       r = gradients_impl.gradients(r, y)[0]
       self.assertEqual(388.0, r.eval())
 
+  def testWhileGradientWithNontrainablePath1(self):
+    q = variables.Variable([7., 8.])
+
+    def cond(_, y):
+      del y
+      return False
+
+    def body(x, _):
+      return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
+
+    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
+    dy_dq, = gradients_impl.gradients(y, q)
+    self.assertIsNotNone(dy_dq)
+    with self.test_session() as sess:
+      sess.run(q.initializer)
+      self.assertAllClose([0., 0.], sess.run(dy_dq))
+
+  def testWhileGradientWithNontrainablePath2(self):
+    q = variables.Variable([7., 8.])
+
+    def cond(_, y):
+      return math_ops.equal(y, 0.)
+
+    def body(x, _):
+      zero = constant_op.constant(0, dtype=dtypes.int64)
+      return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
+
+    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
+    dy_dq, = gradients_impl.gradients(y, q)
+    self.assertIsNotNone(dy_dq)
+    with self.test_session() as sess:
+      sess.run(q.initializer)
+      self.assertAllClose([1., 1.], sess.run(dy_dq))
+
+  def testIssue16504(self):
+    c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
+    w = variables.Variable(
+        initial_value=np.ones(100), dtype=dtypes.float32) / 100
+    k = variables.Variable(0, dtype=dtypes.int32)
+    chg_w = constant_op.constant(np.inf, dtype=dtypes.float32)
+
+    def cond(k, _, chg_w):
+      return math_ops.logical_and(k < 10, chg_w > 1e-3)
+
+    def body(k, w, chg_w):
+      grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w)
+      w_n = w * math_ops.exp(-0.1 * grad)
+      w_n /= math_ops.reduce_sum(w_n)
+      chg_w = (
+          math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum(
+              math_ops.abs(w)))
+      return k + 1, w_n, chg_w
+
+    _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w])
+    grad, = gradients_impl.gradients(w, c)
+    self.assertIsNotNone(grad)
+
   def testStopGradMultiFlows(self):
     with self.test_session():
 
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index a4b30e4..159cba5 100644
@@ -113,22 +113,23 @@ class DynamicStitchTestBase(object):
           constant_op.constant([[5, 2], [0, 3]])
       ]
       data = [
-          constant_op.constant([61, 62]),
-          constant_op.constant([[41, 42], [11, 12]]),
-          constant_op.constant([[[51, 52], [21, 22]], [[1, 2], [31, 32]]])
+          constant_op.constant([61., 62.]),
+          constant_op.constant([[41., 42.], [11., 12.]]),
+          constant_op.constant([[[51., 52.], [21., 22.]],
+                                [[1., 2.], [31., 32.]]])
       ]
       stitched_t = self.stitch_op(indices, data)
       stitched_val = stitched_t.eval()
-      correct = 10 * np.arange(7)[:, None] + [1, 2]
+      correct = 10. * np.arange(7)[:, None] + [1., 2.]
       self.assertAllEqual(correct, stitched_val)
       self.assertEqual([7, 2], stitched_t.get_shape().as_list())
       # Test gradients
-      stitched_grad = 7 * stitched_val
+      stitched_grad = 7. * stitched_val
       grads = gradients_impl.gradients(stitched_t, indices + data,
                                        stitched_grad)
       self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
       for datum, grad in zip(data, sess.run(grads[3:])):
-        self.assertAllEqual(7 * datum.eval(), grad)
+        self.assertAllEqual(7. * datum.eval(), grad)
 
   def testErrorIndicesMultiDimensional(self):
     indices = [
diff --git a/tensorflow/python/kernel_tests/gradient_correctness_test.py b/tensorflow/python/kernel_tests/gradient_correctness_test.py
index 10fe4f5..e93c623 100644
@@ -40,6 +40,71 @@ class GradientCorrectnessTest(test.TestCase):
       # [dexp(x)/dx + d(log(exp(x)))/dx] @ x=1 == exp(1) + 1
       self.assertAllClose(grad_vals[0], exp1_plus_one)
 
+  def testIdentityGradient(self):
+    x = constant_op.constant(3.)
+    dx_dx, = gradients_impl.gradients(x, x)
+    with self.test_session() as sess:
+      self.assertAllClose(1., sess.run(dx_dx))
+
+  def testIntegerIdentityGradient(self):
+    x = constant_op.constant(3)
+    dx_dx, = gradients_impl.gradients(x, x)
+    with self.test_session() as sess:
+      self.assertAllClose(1, sess.run(dx_dx))
+
+  def testGradientWithIntegerPath(self):
+    x = constant_op.constant([3.9, 4.1])
+    k = math_ops.to_float(math_ops.to_int32(x))
+    y = x * k
+    dy_dx, = gradients_impl.gradients(y, x)
+    with self.test_session() as sess:
+      self.assertAllClose([3., 4.], sess.run(dy_dx))
+
+  def testNoIntegerGradient1(self):
+    x = constant_op.constant([3.9, 4.1])
+    k = math_ops.to_float(math_ops.to_int32(x))
+    y = k * k
+    dy_dx, = gradients_impl.gradients(y, x)
+    self.assertIsNone(dy_dx)
+
+  def testNoIntegerGradient2(self):
+    k = constant_op.constant([3, 4])
+    x = math_ops.to_float(k)
+    y = x * x
+    dy_dk, = gradients_impl.gradients(y, k)
+    self.assertIsNone(dy_dk)
+
+  def testNoIntegerGradient3(self):
+    k = constant_op.constant([3, 4])
+    m = k * k
+    dm_dk, = gradients_impl.gradients(m, k)
+    self.assertIsNone(dm_dk)
+
+  def testNoIntegerGradient4(self):
+    k = constant_op.constant([3, 4])
+    m = k * k * k
+    dm_dk, = gradients_impl.gradients(m, k)
+    self.assertIsNone(dm_dk)
+
+  def testNoIntegerGradient5(self):
+    k = constant_op.constant([3, 4])
+    m = k * k
+    n = m * m
+    dn_dk, = gradients_impl.gradients(n, k)
+    self.assertIsNone(dn_dk)
+
+  def testNoIntegerGradient6(self):
+    k = constant_op.constant(3)
+    x = math_ops.to_float(k)
+    grad_1, = gradients_impl.gradients(k * k, k)
+    grad_2, = gradients_impl.gradients(x * x, k)
+    grad_3, = gradients_impl.gradients(math_ops.square(k), k)
+    grad_4, = gradients_impl.gradients(math_ops.square(x), k)
+    self.assertIsNone(grad_1)
+    self.assertIsNone(grad_2)
+    self.assertIsNone(grad_3)
+    self.assertIsNone(grad_4)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/kernel_tests/nth_element_op_test.py b/tensorflow/python/kernel_tests/nth_element_op_test.py
index 58cd46d..1b8f021 100644
@@ -154,14 +154,14 @@ class NthElementTest(test.TestCase):
 
   def testGradients(self):
     with self.test_session(use_gpu=False) as sess:
-      inputs = array_ops.placeholder(dtypes.int32, shape=[3, 5])
+      inputs = array_ops.placeholder(dtypes.float32, shape=[3, 5])
       values = nn_ops.nth_element(inputs, 3)
       grad = sess.run(
           gradients_impl.gradients(
               values, inputs, grad_ys=[[-1., 2., 5.]]),
-          feed_dict={inputs: [[2, -1, 1000, 3, 1000],
-                              [1, 5, 2, 4, 3],
-                              [2, 2, 2, 2, 2],
+          feed_dict={inputs: [[2., -1., 1000., 3., 1000.],
+                              [1., 5., 2., 4., 3.],
+                              [2., 2., 2., 2., 2.],
                              ]})
     self.assertAllClose(grad[0], [[0, 0, -0.5, 0, -0.5],
                                   [0, 0, 0, 2, 0],
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index a834675..918bbd3 100644
@@ -615,8 +615,7 @@ class TensorArrayTest(test.TestCase):
       self.assertAllEqual(c(-2.0), grad_vals[1])
 
   def testTensorArrayGradientWriteRead(self):
-    for dtype in (np.float32, np.float64, np.int32, np.int64, np.complex64,
-                  np.complex128):
+    for dtype in (np.float32, np.float64, np.complex64, np.complex128):
       self._testTensorArrayGradientWriteReadType(dtype)
 
   def _testTensorArrayGradientWritePackConcatAndRead(self):
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index 6ab931f..fa7c6a0 100644
@@ -197,13 +197,15 @@ class TopKTest(test.TestCase):
 
   def testTopKGradients(self):
     with self.test_session(use_gpu=True) as sess:
-      inputs = array_ops.placeholder(dtypes.int32, shape=[2, 5])
+      inputs = array_ops.placeholder(dtypes.float32, shape=[2, 5])
       values, _ = nn_ops.top_k(inputs, 3)
       grad = sess.run(
           gradients_impl.gradients(
-              values, inputs, grad_ys=[[[1, 2, 3], [4, 5, 6]]]),
-          feed_dict={inputs: [[2, -1, 1000, 3, 4], [1, 5, 2, 4, 3]]})[0]
-    self.assertEqual(grad.tolist(), [[0, 0, 1, 3, 2], [0, 4, 0, 5, 6]])
+              values, inputs, grad_ys=[[[1., 2., 3.], [4., 5., 6.]]]),
+          feed_dict={inputs: [[2., -1., 1000., 3., 4.],
+                              [1., 5., 2., 4., 3.]]})[0]
+    self.assertEqual(
+        grad.tolist(), [[0., 0., 1., 3., 2.], [0., 4., 0., 5., 6.]])
 
 
 class TopKBenchmark(test.Benchmark):
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 13420b7..581ba7d 100644
@@ -121,7 +121,8 @@ def _MarkReachedOps(from_ops, reached_ops):
     if not reached_ops[op._id]:
       reached_ops[op._id] = True
       for output in op.outputs:
-        queue.extend(output.consumers())
+        if _IsBackpropagatable(output):
+          queue.extend(output.consumers())
 
 
 def _GatherInputs(to_ops, reached_ops):
@@ -163,16 +164,19 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
     colocate_gradients_with_ops: Python bool.  See docstring of gradients().
 
   Returns:
-    A tuple containing: (1) a list of integers indexed by operation id,
-    indicating the number of backprop inputs to this operation, and (2)
-    a ControlFlowState object which is not None if the ops between from_ops
-    and to_ops contain control flow loops.
+    A tuple containing: (1) the subset of to_ops ids reachable from from_ops
+    by a path of zero or more backpropagatable tensors, (2) a list of integers
+    indexed by operation id, indicating the number of backprop inputs to this
+    operation, and (3) a ControlFlowState object which is not None if the ops
+    between from_ops and to_ops contain control flow loops.
   """
   # Mark reachable ops from from_ops.
   reached_ops = [False] * (graph._last_id + 1)
-  for op in to_ops:
-    reached_ops[op._id] = True
   _MarkReachedOps(from_ops, reached_ops)
+  # reached_ops[X] iff X is reachable from from_ops by a path of zero or more
+  # backpropagatable tensors.
+
+  reachable_to_ops = set(op._id for op in to_ops if reached_ops[op._id])  # pylint: disable=protected-access
 
   # Mark between ops.
   between_ops = [False] * (graph._last_id + 1)
@@ -189,6 +193,8 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
       reached_ops[op._id] = False
       for inp in op.inputs:
         queue.append(inp.op)
+  # between_ops[X] iff X is on a path of zero or more backpropagatable tensors
+  # between from_ops and to_ops
 
   # 'loop_state' is None if there are no while loops.
   loop_state = control_flow_ops.MaybeCreateControlFlowState(
@@ -201,7 +207,7 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
       if between_ops[x.op._id]:
         pending_count[x.op._id] += 1
 
-  return pending_count, loop_state
+  return reachable_to_ops, pending_count, loop_state
 
 
 def _AsList(x):
@@ -294,6 +300,13 @@ def _IsTrainable(tensor):
                               dtypes.complex64, dtypes.complex128)
 
 
+def _IsBackpropagatable(tensor):
+  if _IsTrainable(tensor):
+    return True
+  dtype = dtypes.as_dtype(tensor.dtype)
+  return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant)
+
+
 def _VerifyGeneratedGradients(grads, op):
   """Verify that gradients are valid in number and type.
 
@@ -460,6 +473,9 @@ def gradients(ys,
   backpropagation stops at both `tf.stop_gradient` nodes and nodes in
   `stop_gradients`, whichever is encountered first.
 
+  All integer tensors are considered constant with respect to all `xs`, as if
+  they were included in `stop_gradients`.
+
   Args:
     ys: A `Tensor` or list of tensors to be differentiated.
     xs: A `Tensor` or list of tensors to be used for differentiation.
@@ -539,7 +555,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
     to_ops = [t.op for t in ys]
     from_ops = [t.op for t in xs]
     stop_gradient_ops = [t.op for t in stop_gradients]
-    pending_count, loop_state = _PendingCount(
+    reachable_to_ops, pending_count, loop_state = _PendingCount(
         ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
 
     # Iterate over the collected ops.
@@ -564,7 +580,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
       # another output's gradient.
       # pylint: disable=protected-access
       ready = (pending_count[op._id] == 0)
-      if ready and op._id not in to_ops_set:
+      if ready and op._id not in to_ops_set and op._id in reachable_to_ops:
         to_ops_set.add(op._id)
         queue.append(op)
       # pylint: enable=protected-access