From 78d10e5800a058c6d1865c5282aaa4094f7bc36d Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 26 Feb 2018 19:58:18 -0800 Subject: [PATCH] Fix bug in deserializing CondContexts. PiperOrigin-RevId: 187121244 --- tensorflow/python/ops/control_flow_ops.py | 11 +++++-- tensorflow/python/training/saver_test.py | 49 +++++++++++++++++++++---------- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index b16901e..0815527 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1716,8 +1716,15 @@ class CondContext(ControlFlowContext): self._pivot = g.as_graph_element( ops.prepend_name_scope(context_def.pivot_name, import_scope)) self._branch = context_def.branch - super(CondContext, self).__init__( - values_def=context_def.values_def, import_scope=import_scope) + super(CondContext, self).__init__(values_def=context_def.values_def, + import_scope=import_scope) + # The predicate and pivot ops appear in self._values, but don't have self + # set as their control context. The __init__ call above will set self for + # all values, so manually override the predicate and pivot contexts here. + # pylint: disable=protected-access + self._pred.op._set_control_flow_context(self.outer_context) + self._pivot.op._set_control_flow_context(self.outer_context) + # pylint: enable=protected-access @property def pred(self): diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index b366ed3..b758cea 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -2041,29 +2041,24 @@ class MetaGraphTest(test.TestCase): self._testGraphExtensionRestore(test_dir) self._testRestoreFromTrainGraphWithControlContext(test_dir) - def testNestedWhileLoops(self): - test_dir = self._get_test_dir("nested_whiles") + def _testWhileLoopAndGradientSerDes(self, outer_body_fn): + # Build a while loop with `outer_body_fn`, export it, and verify that it can + # be imported and the gradient can be built and run correctly. + + test_dir = self._get_test_dir("nested_control_flow") filename = os.path.join(test_dir, "metafile") saver_ckpt = os.path.join(test_dir, "saver.ckpt") - # Create two simple nested while loops. + # Create while loop using `outer_body_fn`. with ops_lib.Graph().as_default(): - def body(i, x): - _, r = control_flow_ops.while_loop(lambda j, y: j < 3, - lambda j, y: (j + 1, y + x), - [0, 0]) - return i + 1, x + r - var = variables.Variable(0) var_name = var.name - - _, output = control_flow_ops.while_loop(lambda i, x: i < 5, body, + _, output = control_flow_ops.while_loop(lambda i, x: i < 5, outer_body_fn, [0, var]) output_name = output.name - init_op = variables.global_variables_initializer() - # Generate a MetaGraphDef containing the nested loops. + # Generate a MetaGraphDef containing the while loop. with session.Session() as sess: sess.run(init_op) sess.run(output) @@ -2071,8 +2066,8 @@ class MetaGraphTest(test.TestCase): saver.save(sess, saver_ckpt) saver.export_meta_graph(filename) - # Build and run the gradients of the nested while loop. We use this below - # to verify that the gradients are correct with an imported MetaGraphDef. + # Build and run the gradients of the while loop. We use this below to + # verify that the gradients are correct with an imported MetaGraphDef. grad = gradients_impl.gradients([output], [var]) with session.Session() as sess: sess.run(init_op) @@ -2096,6 +2091,30 @@ class MetaGraphTest(test.TestCase): actual_grad_value = sess.run(grad) self.assertEqual(expected_grad_value, actual_grad_value) + def testNestedWhileLoopsSerDes(self): + # Test two simple nested while loops. + def body(i, x): + _, r = control_flow_ops.while_loop(lambda j, y: j < 3, + lambda j, y: (j + 1, y + x), + [0, 0]) + return i + 1, x + r + self._testWhileLoopAndGradientSerDes(body) + + def testNestedControlFlowSerDes(self): + # Test while loop in a cond in a while loop. + # pylint: disable=g-long-lambda + def body(i, x): + cond_result = control_flow_ops.cond( + i > 0, + lambda: control_flow_ops.while_loop( + lambda j, y: j < 3, + lambda j, y: (j + 1, y + x), + [0, 0])[1], + lambda: x) + return i + 1, cond_result + # pylint: enable=g-long-lambda + self._testWhileLoopAndGradientSerDes(body) + def testStrippedOpListDef(self): with self.test_session(): # Creates a graph. -- 2.7.4