[TF CriticalSection] Bugfix when Execute() inside a while_loop has a dep on a Variabl...
authorEugene Brevdo <ebrevdo@google.com>
Wed, 21 Mar 2018 20:28:11 +0000 (13:28 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 20:31:07 +0000 (13:31 -0700)
PiperOrigin-RevId: 189957569

tensorflow/contrib/framework/python/ops/critical_section_ops.py
tensorflow/contrib/framework/python/ops/critical_section_test.py

index 1893d7b..bd764ed 100644 (file)
@@ -308,7 +308,19 @@ class CriticalSection(object):
       all_args_dict.pop(input_.op._id, None)
     all_args_dict.pop(lock_op._id, None)
 
-    lock_op._add_control_inputs(all_args_dict.values())
+    all_args = all_args_dict.values()
+
+    if not all_args:
+      # No control dependencies to add; return early.
+      return
+
+    # This group is important: it ensures that any ops in all_args
+    # outside the control context of the lock_op (and this fn, which
+    # runs in the same context) are added to this context before
+    # being added to the control dependencies of lock_op.
+    all_args = control_flow_ops.group(*all_args)
+
+    lock_op._add_control_input(all_args)
     # pylint: enable=protected-access
 
   def _is_self_handle(self, x):
index e24140b..ba66029 100644 (file)
@@ -316,6 +316,20 @@ class CriticalSectionTest(test.TestCase):
         ValueError, "requested exclusive resource access"):
       cs1.execute(lambda: v2 + 1)
 
+  def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
+    cs = critical_section_ops.CriticalSection()
+    v = resource_variable_ops.ResourceVariable(0, name="v")
+    # Make sure that the control dependencies on v do not cause issues
+    # in the lock_op's automatic control dependency adder.
+    #
+    # Note, here v must be a resource variable (or something similar),
+    # otherwise it gets hoisted into the while_loop by the time we add
+    # control dependencies to the lock_op.
+    out = control_flow_ops.while_loop(
+        lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0])
+    self.evaluate(v.initializer)
+    self.assertEqual(10, self.evaluate(out))
+
   # TODO(ebrevdo): Re-enable once CriticalSection is in core.
   #
   # def testCriticalSectionAndExecuteOpSaverRoundTrip(self):