from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.platform import test
self.assertFalse(control_flow_util.IsLoopExit(test_ops.int_output().op))
+ def build_test_graph(self):
+ g = ops.Graph()
+ with g.as_default():
+
+ def while_loop(x):
+
+ def b(x):
+ with ops.name_scope("NestedCond"):
+ return control_flow_ops.cond(
+ math_ops.less(x, 100), lambda: math_ops.add(x, 1),
+ lambda: math_ops.add(x, 2))
+
+ c = lambda x: math_ops.less(x, 10000)
+ with ops.name_scope("OuterWhile"):
+ return control_flow_ops.while_loop(c, b, [x])
+
+ x = array_ops.placeholder(dtypes.int32)
+ with ops.name_scope("OuterCond"):
+ control_flow_ops.cond(
+ math_ops.less(x, 1000), lambda: while_loop(x),
+ lambda: math_ops.add(x, 2))
+ return g
+
+ def testIsCondSwitch(self):
+ g = self.build_test_graph()
+
+ cond_switch = [
+ "OuterCond/cond/Switch",
+ "OuterCond/cond/OuterWhile/while/Switch",
+ "OuterCond/cond/OuterWhile/while/NestedCond/cond/Switch",
+ "OuterCond/cond/OuterWhile/while/NestedCond/cond/Add/Switch",
+ "OuterCond/cond/OuterWhile/while/NestedCond/cond/Add_1/Switch",
+ "OuterCond/cond/Add/Switch",
+ ]
+ for n in g.get_operations():
+ if control_flow_util.IsSwitch(n):
+ self.assertTrue(
+ control_flow_util.IsCondSwitch(n) != control_flow_util.IsLoopSwitch(
+ n))
+ if n.name in cond_switch:
+ self.assertTrue(control_flow_util.IsSwitch(n))
+ self.assertTrue(
+ control_flow_util.IsCondSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+ self.assertFalse(
+ control_flow_util.IsLoopSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+ else:
+ self.assertFalse(
+ control_flow_util.IsCondSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+
+ def testIsLoopSwitch(self):
+ g = self.build_test_graph()
+
+ loop_switch = ["OuterCond/cond/OuterWhile/while/Switch_1"]
+ for n in g.get_operations():
+ if control_flow_util.IsSwitch(n):
+ self.assertTrue(
+ control_flow_util.IsCondSwitch(n) != control_flow_util.IsLoopSwitch(
+ n))
+ if n.name in loop_switch:
+ self.assertTrue(control_flow_util.IsSwitch(n))
+ self.assertFalse(
+ control_flow_util.IsCondSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+ self.assertTrue(
+ control_flow_util.IsLoopSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+ else:
+ self.assertFalse(
+ control_flow_util.IsLoopSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+
if __name__ == "__main__":
test.main()
self._pivot = pivot # The predicate tensor in this branch
self._branch = branch # 0 or 1 representing this branch
- # Values considered to have been already seen in this context. They are
- # not included in this context.
+ # Values considered to have been already seen in this context. pred is not
+ # included in this context.
self._values.add(pred.name)
self._external_values[pred.name] = pred
self._values.add(pivot.name)
- self._external_values[pivot.name] = pivot
+ pivot.op._set_control_flow_context(self) # pylint: disable=protected-access
def _init_from_proto(self, context_def, import_scope=None):
"""Creates a new `CondContext` from protocol buffer.
return op.type == "Exit" or op.type == "RefExit"
+def IsCondSwitch(op):
+ """Return true if `op` is the Switch for a conditional."""
+ if not IsSwitch(op):
+ return False
+ if not op.outputs:
+ return False
+ # Switch nodes are not part of the cond control flow context that they
+ # represent, so consider the consumers of its outputs to determine if it is
+ # cond switch or not. A switch is a cond switch iff all its consumers are in
+ # cond contexts.
+ is_cond_switch = True
+ for o in op.outputs:
+ for c in o.consumers():
+ ctxt = c._get_control_flow_context() # pylint: disable=protected-access
+ if IsLoopEnter(c):
+ ctxt = ctxt.outer_context
+ is_cond_switch = is_cond_switch and (ctxt is not None and
+ ctxt.IsCondContext())
+ return is_cond_switch
+
+
def IsLoopSwitch(op):
"""Return true if `op` is the Switch for a while loop."""
if IsSwitch(op):
ctxt = op._get_control_flow_context() # pylint: disable=protected-access
- return ctxt and ctxt.IsWhileContext()
+ return ctxt is not None and ctxt.IsWhileContext() and not IsCondSwitch(op)
return False