Add IsCondSwitch.
authorJacques Pienaar <jpienaar@google.com>
Wed, 9 May 2018 20:03:45 +0000 (13:03 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 20:21:02 +0000 (13:21 -0700)
* Switch nodes are not part of the cond contexts of the tf.cond that they are the switches for, so check the contexts of the outputs of the switch to determine if a cond switch.
* Include the pivot of a cond in its cond context (there is one pivot per CondContext)
* If a cond is nested in a while loop, then the switch nodes of the cond is in the control flow context of the while loop, so only return that it is a loop switch if it isn't a cond switch.

PiperOrigin-RevId: 196015879

tensorflow/python/kernel_tests/control_flow_util_test.py
tensorflow/python/ops/control_flow_ops.py
tensorflow/python/ops/control_flow_util.py

index 39e96f7..5138ad5 100644 (file)
@@ -19,9 +19,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import gen_control_flow_ops
 from tensorflow.python.platform import test
 
@@ -66,6 +70,80 @@ class ControlFlowUtilTest(test.TestCase):
 
     self.assertFalse(control_flow_util.IsLoopExit(test_ops.int_output().op))
 
+  def build_test_graph(self):
+    g = ops.Graph()
+    with g.as_default():
+
+      def while_loop(x):
+
+        def b(x):
+          with ops.name_scope("NestedCond"):
+            return control_flow_ops.cond(
+                math_ops.less(x, 100), lambda: math_ops.add(x, 1),
+                lambda: math_ops.add(x, 2))
+
+        c = lambda x: math_ops.less(x, 10000)
+        with ops.name_scope("OuterWhile"):
+          return control_flow_ops.while_loop(c, b, [x])
+
+      x = array_ops.placeholder(dtypes.int32)
+      with ops.name_scope("OuterCond"):
+        control_flow_ops.cond(
+            math_ops.less(x, 1000), lambda: while_loop(x),
+            lambda: math_ops.add(x, 2))
+    return g
+
+  def testIsCondSwitch(self):
+    g = self.build_test_graph()
+
+    cond_switch = [
+        "OuterCond/cond/Switch",
+        "OuterCond/cond/OuterWhile/while/Switch",
+        "OuterCond/cond/OuterWhile/while/NestedCond/cond/Switch",
+        "OuterCond/cond/OuterWhile/while/NestedCond/cond/Add/Switch",
+        "OuterCond/cond/OuterWhile/while/NestedCond/cond/Add_1/Switch",
+        "OuterCond/cond/Add/Switch",
+    ]
+    for n in g.get_operations():
+      if control_flow_util.IsSwitch(n):
+        self.assertTrue(
+            control_flow_util.IsCondSwitch(n) != control_flow_util.IsLoopSwitch(
+                n))
+      if n.name in cond_switch:
+        self.assertTrue(control_flow_util.IsSwitch(n))
+        self.assertTrue(
+            control_flow_util.IsCondSwitch(n),
+            msg="Mismatch for {}".format(n.name))
+        self.assertFalse(
+            control_flow_util.IsLoopSwitch(n),
+            msg="Mismatch for {}".format(n.name))
+      else:
+        self.assertFalse(
+            control_flow_util.IsCondSwitch(n),
+            msg="Mismatch for {}".format(n.name))
+
+  def testIsLoopSwitch(self):
+    g = self.build_test_graph()
+
+    loop_switch = ["OuterCond/cond/OuterWhile/while/Switch_1"]
+    for n in g.get_operations():
+      if control_flow_util.IsSwitch(n):
+        self.assertTrue(
+            control_flow_util.IsCondSwitch(n) != control_flow_util.IsLoopSwitch(
+                n))
+      if n.name in loop_switch:
+        self.assertTrue(control_flow_util.IsSwitch(n))
+        self.assertFalse(
+            control_flow_util.IsCondSwitch(n),
+            msg="Mismatch for {}".format(n.name))
+        self.assertTrue(
+            control_flow_util.IsLoopSwitch(n),
+            msg="Mismatch for {}".format(n.name))
+      else:
+        self.assertFalse(
+            control_flow_util.IsLoopSwitch(n),
+            msg="Mismatch for {}".format(n.name))
+
 
 if __name__ == "__main__":
   test.main()
index 5f60dab..5ebdb19 100644 (file)
@@ -1685,12 +1685,12 @@ class CondContext(ControlFlowContext):
       self._pivot = pivot  # The predicate tensor in this branch
       self._branch = branch  # 0 or 1 representing this branch
 
-      # Values considered to have been already seen in this context. They are
-      # not included in this context.
+      # Values considered to have been already seen in this context. pred is not
+      # included in this context.
       self._values.add(pred.name)
       self._external_values[pred.name] = pred
       self._values.add(pivot.name)
-      self._external_values[pivot.name] = pivot
+      pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access
 
   def _init_from_proto(self, context_def, import_scope=None):
     """Creates a new `CondContext` from protocol buffer.
index eee3110..41f16ac 100644 (file)
@@ -63,11 +63,32 @@ def IsLoopExit(op):
   return op.type == "Exit" or op.type == "RefExit"
 
 
+def IsCondSwitch(op):
+  """Return true if `op` is the Switch for a conditional."""
+  if not IsSwitch(op):
+    return False
+  if not op.outputs:
+    return False
+  # Switch nodes are not part of the cond control flow context that they
+  # represent, so consider the consumers of its outputs to determine if it is
+  # cond switch or not. A switch is a cond switch iff all its consumers are in
+  # cond contexts.
+  is_cond_switch = True
+  for o in op.outputs:
+    for c in o.consumers():
+      ctxt = c._get_control_flow_context()  # pylint: disable=protected-access
+      if IsLoopEnter(c):
+        ctxt = ctxt.outer_context
+      is_cond_switch = is_cond_switch and (ctxt is not None and
+                                           ctxt.IsCondContext())
+  return is_cond_switch
+
+
 def IsLoopSwitch(op):
   """Return true if `op` is the Switch for a while loop."""
   if IsSwitch(op):
     ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
-    return ctxt and ctxt.IsWhileContext()
+    return ctxt is not None and ctxt.IsWhileContext() and not IsCondSwitch(op)
   return False