From f8f921c828fb2c97da7c7b80c01390ccec90ae40 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Mon, 5 Feb 2018 13:03:53 -0800 Subject: [PATCH] Fixes issue where external control dependencies in while loops are dropped. Fixes #15891 PiperOrigin-RevId: 184573795 --- .../kernel_tests/control_flow_ops_py_test.py | 30 ++++++++++++++++++++++ tensorflow/python/ops/control_flow_ops.py | 25 +++++++++++------- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 4fafc36..15ff0ec 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -704,6 +704,36 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) self.assertEqual(10000, r.eval()) + def testWhileExternalControlDependencies(self): + with self.test_session(): + v = variables.Variable(0.0) + v.initializer.run() + increment = v.assign_add(1.0) + + def body_fn(i): + with ops.control_dependencies([increment]): + return i + i + + result = control_flow_ops.while_loop(cond=lambda i: i < 1, + body=body_fn, loop_vars=[1]) + result.eval() + self.assertAllEqual(v.eval(), 1.0) + + def testWhileExternalControlDependenciesNoInput(self): + with self.test_session(): + v = variables.Variable(0.0) + v.initializer.run() + increment = v.assign_add(1.0) + + def body_fn(unused_i): + with ops.control_dependencies([increment]): + return constant_op.constant(5, name="five") + + result = control_flow_ops.while_loop(cond=lambda i: i < 5, + body=body_fn, loop_vars=[0]) + result.eval() + self.assertAllEqual(v.eval(), 1.0) + def testWhileWithRefs_1(self): with self.test_session() as sess: x = variables.Variable(0)._ref() # pylint: disable=protected-access diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index bcd187d..87ff0ab 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1631,10 +1631,13 @@ class ControlFlowContext(object): ctxt = util.GetOutputContext(x) if ctxt is not None and ctxt.GetWhileContext() == while_ctxt: internal_control_inputs.append(x) + external_control_inputs = [] if len(internal_control_inputs) != len(op.control_inputs): + external_control_inputs = list(set(op.control_inputs) + - set(internal_control_inputs)) op._remove_all_control_inputs() op._add_control_inputs(internal_control_inputs) - return internal_control_inputs + return internal_control_inputs, external_control_inputs # pylint: enable=protected-access @@ -2432,14 +2435,12 @@ class WhileContext(ControlFlowContext): def _AddOpInternal(self, op): """Add `op` to the current context. - In the case that op has only external data inputs, we remove all of its - external control inputs so all its inputs are in the same while loop - context. This is valid because op now has an Enter input that has all - the right control dependency. + We move any external control dependencies of the op to the loop pivot, to + ensure they get executed. """ if not op.inputs: # Remove any external control dependency on this op - control_inputs = self._RemoveExternalControlEdges(op) + control_inputs, external_inputs = self._RemoveExternalControlEdges(op) # Add a control edge from the control pivot to this op. if not control_inputs: # pylint: disable=protected-access @@ -2452,14 +2453,20 @@ class WhileContext(ControlFlowContext): x = op.inputs[index] real_x = self.AddValue(x) if real_x != x: - op._update_input(index, real_x) - # Remove any external control dependency on this op. - self._RemoveExternalControlEdges(op) + op._update_input(index, real_x) # pylint: disable=protected-access + # Remove any external control dependency on this op and move then to an + # Enter node. + _, external_inputs = self._RemoveExternalControlEdges(op) # Add a control dependency to prevent loop invariants from # enabling ops that should not be executed. self._MaybeAddControlDependency(op) for x in op.outputs: self._values.add(x.name) + if external_inputs: + # Make the pivot depend on external control inputs + pred = self._pivot_for_pred.op.inputs[0] + assert util.IsLoopEnter(pred.op) + pred.op._add_control_inputs(external_inputs) # pylint: disable=protected-access if self._outer_context or not util.IsLoopExit(op): op.graph.prevent_fetching(op) for x in op.outputs: -- 2.7.4