Automated g4 rollback of changelist 184573795
authorAlexandre Passos <apassos@google.com>
Mon, 5 Feb 2018 22:47:23 +0000 (14:47 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 22:51:48 +0000 (14:51 -0800)
PiperOrigin-RevId: 184590080

tensorflow/python/kernel_tests/control_flow_ops_py_test.py
tensorflow/python/ops/control_flow_ops.py

index 15ff0ec..4fafc36 100644 (file)
@@ -704,36 +704,6 @@ class ControlFlowTest(test.TestCase):
       r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
       self.assertEqual(10000, r.eval())
 
-  def testWhileExternalControlDependencies(self):
-    with self.test_session():
-      v = variables.Variable(0.0)
-      v.initializer.run()
-      increment = v.assign_add(1.0)
-
-      def body_fn(i):
-        with ops.control_dependencies([increment]):
-          return i + i
-
-      result = control_flow_ops.while_loop(cond=lambda i: i < 1,
-                                           body=body_fn, loop_vars=[1])
-      result.eval()
-      self.assertAllEqual(v.eval(), 1.0)
-
-  def testWhileExternalControlDependenciesNoInput(self):
-    with self.test_session():
-      v = variables.Variable(0.0)
-      v.initializer.run()
-      increment = v.assign_add(1.0)
-
-      def body_fn(unused_i):
-        with ops.control_dependencies([increment]):
-          return constant_op.constant(5, name="five")
-
-      result = control_flow_ops.while_loop(cond=lambda i: i < 5,
-                                           body=body_fn, loop_vars=[0])
-      result.eval()
-      self.assertAllEqual(v.eval(), 1.0)
-
   def testWhileWithRefs_1(self):
     with self.test_session() as sess:
       x = variables.Variable(0)._ref()  # pylint: disable=protected-access
index 87ff0ab..bcd187d 100644 (file)
@@ -1631,13 +1631,10 @@ class ControlFlowContext(object):
         ctxt = util.GetOutputContext(x)
         if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
           internal_control_inputs.append(x)
-    external_control_inputs = []
     if len(internal_control_inputs) != len(op.control_inputs):
-      external_control_inputs = list(set(op.control_inputs)
-                                     - set(internal_control_inputs))
       op._remove_all_control_inputs()
       op._add_control_inputs(internal_control_inputs)
-    return internal_control_inputs, external_control_inputs
+    return internal_control_inputs
 
   # pylint: enable=protected-access
 
@@ -2435,12 +2432,14 @@ class WhileContext(ControlFlowContext):
   def _AddOpInternal(self, op):
     """Add `op` to the current context.
 
-    We move any external control dependencies of the op to the loop pivot, to
-    ensure they get executed.
+    In the case that op has only external data inputs, we remove all of its
+    external control inputs so all its inputs are in the same while loop
+    context. This is valid because op now has an Enter input that has all
+    the right control dependency.
     """
     if not op.inputs:
       # Remove any external control dependency on this op
-      control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
+      control_inputs = self._RemoveExternalControlEdges(op)
       # Add a control edge from the control pivot to this op.
       if not control_inputs:
         # pylint: disable=protected-access
@@ -2453,20 +2452,14 @@ class WhileContext(ControlFlowContext):
         x = op.inputs[index]
         real_x = self.AddValue(x)
         if real_x != x:
-          op._update_input(index, real_x)  # pylint: disable=protected-access
-      # Remove any external control dependency on this op and move then to an
-      # Enter node.
-      _, external_inputs = self._RemoveExternalControlEdges(op)
+          op._update_input(index, real_x)
+      # Remove any external control dependency on this op.
+      self._RemoveExternalControlEdges(op)
       # Add a control dependency to prevent loop invariants from
       # enabling ops that should not be executed.
       self._MaybeAddControlDependency(op)
       for x in op.outputs:
         self._values.add(x.name)
-    if external_inputs:
-      # Make the pivot depend on external control inputs
-      pred = self._pivot_for_pred.op.inputs[0]
-      assert util.IsLoopEnter(pred.op)
-      pred.op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
     if self._outer_context or not util.IsLoopExit(op):
       op.graph.prevent_fetching(op)
       for x in op.outputs: