Enable de/serialization of nested control flow.
authorSkye Wanderman-Milne <skyewm@google.com>
Mon, 26 Feb 2018 22:38:31 +0000 (14:38 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
This is a follow-up to the previous commit
(https://github.com/tensorflow/tensorflow/commit/23851760b7b099214bdd4f1b88156d7ac2bdd2a2).
It adds the new proto schemas, enables the behavior for reading and
writing the new protos, and adds a test for de/serializing nested
while loops.

There's still a bug preventing deserializing conds, which will be addressed
in another change.

PiperOrigin-RevId: 187082713

tensorflow/core/protobuf/control_flow.proto
tensorflow/python/ops/control_flow_ops.py
tensorflow/python/training/saver_test.py

index 2c9476a..3c05b4f 100644 (file)
@@ -17,6 +17,15 @@ message ValuesDef {
   map<string, string> external_values = 2;
 }
 
+// Container for any kind of control flow context. Any other control flow
+// contexts that are added below should also be added here.
+message ControlFlowContextDef {
+  oneof ctxt {
+    CondContextDef cond_ctxt = 1;
+    WhileContextDef while_ctxt = 2;
+  }
+}
+
 // Protocol buffer representing a CondContext object.
 message CondContextDef {
   // Name of the context.
@@ -33,6 +42,9 @@ message CondContextDef {
 
   // Values and external values in control flow context.
   ValuesDef values_def = 5;
+
+  // Contexts contained inside this context (e.g. nested conds).
+  repeated ControlFlowContextDef nested_contexts = 6;
 }
 
 // Protocol buffer representing a WhileContext object.
@@ -70,5 +82,8 @@ message WhileContextDef {
   // Optional name of the maximum_iterations tensor.
   string maximum_iterations_name = 11;
 
-  // Next available id: 12.
+  // Contexts contained inside this context (e.g. nested whiles).
+  repeated ControlFlowContextDef nested_contexts = 12;
+
+  // Next available id: 13.
 }
index 152578c..b16901e 100644 (file)
@@ -1765,13 +1765,9 @@ class CondContext(ControlFlowContext):
       context_def.branch = self._branch
       context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def(
           export_scope))
-      # TODO(b/72868227): enable this once the corresponding control_flow.proto
-      # changes have been checked in (they aren't checked in and this is
-      # disabled for now to ensure forwards compatibility).
-      if False:  # pylint: disable=using-constant-test
-        for nested in self._nested_contexts:
-          nested_def = context_def.nested_contexts.add()
-          nested.to_control_flow_context_def(nested_def)
+      for nested in self._nested_contexts:
+        nested_def = context_def.nested_contexts.add()
+        nested.to_control_flow_context_def(nested_def)
 
       return context_def
     else:
@@ -1783,14 +1779,10 @@ class CondContext(ControlFlowContext):
     ret = CondContext(context_def=context_def,
                       import_scope=import_scope)
 
-    # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
-    # control_flow.proto changes have been checked in (they aren't checked in
-    # and this is here for now to ensure forwards compatibility).
-    if hasattr(context_def, "nested_contexts"):
-      ret.Enter()
-      for nested_def in context_def.nested_contexts:
-        from_control_flow_context_def(nested_def)
-      ret.Exit()
+    ret.Enter()
+    for nested_def in context_def.nested_contexts:
+      from_control_flow_context_def(nested_def)
+    ret.Exit()
     return ret
 
   def to_control_flow_context_def(self, context_def, export_scope=None):
@@ -2108,10 +2100,7 @@ def cond(pred,
     # Only add non-nested conds to the collection. Any nested control flow will
     # be encapsulated in the root context.
     assert context_t.outer_context == context_f.outer_context
-    # TODO(b/72868227): remove "if True..." once the corresponding
-    # control_flow.proto changes have been checked in (they aren't checked in
-    # and this is disabled for now to ensure forwards compatibility).
-    if True or context_t.outer_context is None:
+    if context_t.outer_context is None:
       ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
       ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
 
@@ -2334,13 +2323,9 @@ class WhileContext(ControlFlowContext):
       context_def.values_def.MergeFrom(
           super(WhileContext, self)._to_values_def(
               export_scope=export_scope))
-      # TODO(b/72868227): remove "if True..." once the corresponding
-      # control_flow.proto changes have been checked in (they aren't checked in
-      # and this is disabled for now to ensure forwards compatibility).
-      if False:  # pylint: disable=using-constant-test
-        for nested in self._nested_contexts:
-          nested_def = context_def.nested_contexts.add()
-          nested.to_control_flow_context_def(nested_def)
+      for nested in self._nested_contexts:
+        nested_def = context_def.nested_contexts.add()
+        nested.to_control_flow_context_def(nested_def)
 
       return context_def
     else:
@@ -2362,14 +2347,10 @@ class WhileContext(ControlFlowContext):
     """
     ret = WhileContext(context_def=context_def,
                        import_scope=import_scope)
-    # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
-    # control_flow.proto changes have been checked in (they aren't checked in
-    # and this is disabled for now to ensure forwards compatibility).
-    if hasattr(context_def, "nested_contexts"):
-      ret.Enter()
-      for nested_def in context_def.nested_contexts:
-        from_control_flow_context_def(nested_def, import_scope=import_scope)
-      ret.Exit()
+    ret.Enter()
+    for nested_def in context_def.nested_contexts:
+      from_control_flow_context_def(nested_def, import_scope=import_scope)
+    ret.Exit()
     return ret
 
   def GetWhileContext(self):
@@ -3214,10 +3195,7 @@ def while_loop(cond,
         swap_memory=swap_memory)
     # Only add non-nested loops to the collection. Any nested control flow will
     # be encapsulated in the root context.
-    # TODO(b/72868227): enable condition once the corresponding
-    # control_flow.proto changes have been checked in (they aren't checked in
-    # and this is disabled for now to ensure forwards compatibility).
-    if True or loop_context.outer_context is None:
+    if loop_context.outer_context is None:
       ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
     result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
     if maximum_iterations is not None:
index f00f98d..b366ed3 100644 (file)
@@ -53,6 +53,7 @@ from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import partitioned_variables
@@ -2040,6 +2041,61 @@ class MetaGraphTest(test.TestCase):
     self._testGraphExtensionRestore(test_dir)
     self._testRestoreFromTrainGraphWithControlContext(test_dir)
 
+  def testNestedWhileLoops(self):
+    test_dir = self._get_test_dir("nested_whiles")
+    filename = os.path.join(test_dir, "metafile")
+    saver_ckpt = os.path.join(test_dir, "saver.ckpt")
+
+    # Create two simple nested while loops.
+    with ops_lib.Graph().as_default():
+      def body(i, x):
+        _, r = control_flow_ops.while_loop(lambda j, y: j < 3,
+                                           lambda j, y: (j + 1, y + x),
+                                           [0, 0])
+        return i + 1, x + r
+
+      var = variables.Variable(0)
+      var_name = var.name
+
+      _, output = control_flow_ops.while_loop(lambda i, x: i < 5, body,
+                                              [0, var])
+      output_name = output.name
+
+      init_op = variables.global_variables_initializer()
+
+      # Generate a MetaGraphDef containing the nested loops.
+      with session.Session() as sess:
+        sess.run(init_op)
+        sess.run(output)
+        saver = saver_module.Saver()
+        saver.save(sess, saver_ckpt)
+        saver.export_meta_graph(filename)
+
+      # Build and run the gradients of the nested while loop. We use this below
+      # to verify that the gradients are correct with an imported MetaGraphDef.
+      grad = gradients_impl.gradients([output], [var])
+      with session.Session() as sess:
+        sess.run(init_op)
+        expected_grad_value = sess.run(grad)
+
+    # Restore the MetaGraphDef into a new Graph.
+    with ops_lib.Graph().as_default():
+      with session.Session() as sess:
+        saver = saver_module.import_meta_graph(filename)
+        saver.restore(sess, saver_ckpt)
+
+      # Make sure we can still build gradients and get the same result.
+      var = ops_lib.get_default_graph().get_tensor_by_name(var_name)
+      output = ops_lib.get_default_graph().get_tensor_by_name(output_name)
+      grad = gradients_impl.gradients([output], [var])
+
+      init_op = variables.global_variables_initializer()
+
+      with session.Session() as sess:
+        sess.run(init_op)
+        actual_grad_value = sess.run(grad)
+        self.assertEqual(expected_grad_value, actual_grad_value)
+
   def testStrippedOpListDef(self):
     with self.test_session():
       # Creates a graph.