Get control_flow_ops.py ready to support de/serializing nested control flow.
authorSkye Wanderman-Milne <skyewm@google.com>
Mon, 5 Feb 2018 16:41:54 +0000 (08:41 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 16:45:21 +0000 (08:45 -0800)
With this change, ControlFlowContexts keep track of their nested
contexts (the reverse lookup as ControlFlowContext.outer_context).
This is to enable de/serializing the nested contexts of each "root"
context, and only adding the root contexts to collections. This allows
for simple deserialization of each root context by recursively
deserializing its nested contexts.

The de/serialization logic is disabled and the corresponding
control_flow.proto changes are omitted for now for forwards
compatability (i.e. three-week-old binaries must be ready to accept
the new proto format once its commited). After this is committed for
three weeks, I'll commit a follow-up change enabling the new behavior.

Design note: I chose to serialize the nested contexts, rather than the
outer contexts, because it makes it easy to deserialize the contexts
in topological order and to assign the right outer context. If we
serialized the outer contexts, there'd need to be some mechanism for
either sorting all the serialized contexts first, or deserializing all
of them and then doing another pass to assign the outer contexts.

PiperOrigin-RevId: 184533406

tensorflow/python/ops/control_flow_ops.py
tensorflow/python/ops/control_flow_ops_test.py

index 33a9263..bcd187d 100644 (file)
@@ -50,6 +50,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import collections
 import functools
 
@@ -1498,7 +1499,10 @@ class ControlFlowContext(object):
   """
 
   def __init__(self, values_def=None, import_scope=None):
+    self._nested_contexts = []
     self._outer_context = ops.get_default_graph()._get_control_flow_context()
+    if self._outer_context:
+      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
     self._context_stack = []
     if values_def:
       self._init_values_from_proto(values_def, import_scope=import_scope)
@@ -1551,7 +1555,17 @@ class ControlFlowContext(object):
   def back_prop(self):
     raise NotImplementedError("Abstract method")
 
-  def _to_proto(self, export_scope=None):
+  @abc.abstractmethod
+  def to_control_flow_context_def(self, context_def, export_scope=None):
+    """Serializes this into `context_def`.
+
+    Args:
+      context_def: a `ControlFlowContextDef` protocol buffer.
+      export_scope: Optional `string`. Name scope to remove.
+    """
+    raise NotImplementedError("Abstract method")
+
+  def _to_values_def(self, export_scope=None):
     """Converts the values to a `ValuesDef` protocol buffer.
 
     Args:
@@ -1568,11 +1582,6 @@ class ControlFlowContext(object):
       values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
     return values_def
 
-  @staticmethod
-  def _from_proto(values_def, import_scope=None):
-    """Returns a `ControlFlowContext` created from `values_def`."""
-    return ControlFlowContext(values_def=values_def, import_scope=import_scope)
-
   def AddName(self, name):
     self._values.add(name)
 
@@ -1751,8 +1760,15 @@ class CondContext(ControlFlowContext):
       context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
                                                     export_scope)
       context_def.branch = self._branch
-      context_def.values_def.MergeFrom(
-          super(CondContext, self)._to_proto(export_scope))
+      context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def(
+          export_scope))
+      # TODO(b/72868227): enable this once the corresponding control_flow.proto
+      # changes have been checked in (they aren't checked in and this is
+      # disabled for now to ensure forwards compatibility).
+      if False:  # pylint: disable=using-constant-test
+        for nested in self._nested_contexts:
+          nested_def = context_def.nested_contexts.add()
+          nested.to_control_flow_context_def(nested_def)
 
       return context_def
     else:
@@ -1761,7 +1777,21 @@ class CondContext(ControlFlowContext):
   @staticmethod
   def from_proto(context_def, import_scope=None):
     """Returns a `CondContext` object created from `context_def`."""
-    return CondContext(context_def=context_def, import_scope=import_scope)
+    ret = CondContext(context_def=context_def,
+                      import_scope=import_scope)
+
+    # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
+    # control_flow.proto changes have been checked in (they aren't checked in
+    # and this is here for now to ensure forwards compatibility).
+    if hasattr(context_def, "nested_contexts"):
+      ret.Enter()
+      for nested_def in context_def.nested_contexts:
+        from_control_flow_context_def(nested_def)
+      ret.Exit()
+    return ret
+
+  def to_control_flow_context_def(self, context_def, export_scope=None):
+    context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
 
   def AddValue(self, val):
     """Add `val` to the current context and its outer context recursively."""
@@ -2067,9 +2097,15 @@ def cond(pred,
     merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
     merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
 
-    # Add to collections
-    ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
-    ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
+    # Only add non-nested conds to the collection. Any nested control flow will
+    # be encapsulated in the root context.
+    assert context_t.outer_context == context_f.outer_context
+    # TODO(b/72868227): remove "if True..." once the corresponding
+    # control_flow.proto changes have been checked in (they aren't checked in
+    # and this is disabled for now to ensure forwards compatibility).
+    if True or context_t.outer_context is None:
+      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
+      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
 
     merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges)
 
@@ -2277,12 +2313,23 @@ class WhileContext(ControlFlowContext):
           ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
       ])
       context_def.values_def.MergeFrom(
-          super(WhileContext, self)._to_proto(export_scope=export_scope))
+          super(WhileContext, self)._to_values_def(
+              export_scope=export_scope))
+      # TODO(b/72868227): remove "if True..." once the corresponding
+      # control_flow.proto changes have been checked in (they aren't checked in
+      # and this is disabled for now to ensure forwards compatibility).
+      if False:  # pylint: disable=using-constant-test
+        for nested in self._nested_contexts:
+          nested_def = context_def.nested_contexts.add()
+          nested.to_control_flow_context_def(nested_def)
 
       return context_def
     else:
       return None
 
+  def to_control_flow_context_def(self, context_def, export_scope=None):
+    context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
+
   @staticmethod
   def from_proto(context_def, import_scope=None):
     """Returns a `WhileContext` object created from `context_def`.
@@ -2294,7 +2341,17 @@ class WhileContext(ControlFlowContext):
     Returns:
       A `WhileContext` Python object.
     """
-    return WhileContext(context_def=context_def, import_scope=import_scope)
+    ret = WhileContext(context_def=context_def,
+                       import_scope=import_scope)
+    # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
+    # control_flow.proto changes have been checked in (they aren't checked in
+    # and this is disabled for now to ensure forwards compatibility).
+    if hasattr(context_def, "nested_contexts"):
+      ret.Enter()
+      for nested_def in context_def.nested_contexts:
+        from_control_flow_context_def(nested_def, import_scope=import_scope)
+      ret.Exit()
+    return ret
 
   def GetWhileContext(self):
     return self
@@ -3092,7 +3149,13 @@ def while_loop(cond,
         parallel_iterations=parallel_iterations,
         back_prop=back_prop,
         swap_memory=swap_memory)
-    ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
+    # Only add non-nested loops to the collection. Any nested control flow will
+    # be encapsulated in the root context.
+    # TODO(b/72868227): enable condition once the corresponding
+    # control_flow.proto changes have been checked in (they aren't checked in
+    # and this is disabled for now to ensure forwards compatibility).
+    if True or loop_context.outer_context is None:
+      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
     result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
     if maximum_iterations is not None:
       return result[1]
@@ -3540,6 +3603,26 @@ class XLAControlFlowContext(ControlFlowContext):
     return x
 
 
+def from_control_flow_context_def(context_def, import_scope=None):
+  """Deserializes `context_def` into the appropriate ControlFlowContext.
+
+  Args:
+    context_def: ControlFlowContextDef proto
+    import_scope: Optional `string`. Name scope to add.
+
+  Returns:
+    A ControlFlowContext subclass
+  """
+  if context_def.HasField("cond_ctxt"):
+    return CondContext.from_proto(context_def.cond_ctxt,
+                                  import_scope=import_scope)
+  if context_def.HasField("while_ctxt"):
+    return WhileContext.from_proto(context_def.while_ctxt,
+                                   import_scope=import_scope)
+  raise NotImplementedError("Unknown ControlFlowContextDef field: %s"
+                            % context_def.WhichOneof("ctxt"))
+
+
 ops.register_proto_function(
     ops.GraphKeys.COND_CONTEXT,
     proto_type=control_flow_pb2.CondContextDef,
index cc5a42b..f942f47 100644 (file)
@@ -483,8 +483,8 @@ class ContextTest(test_util.TensorFlowTestCase):
       c._values = ["a", "b"]
       c._external_values = {"a": b1}
 
-      c_with_scope = control_flow_ops.ControlFlowContext._from_proto(
-          c._to_proto(), import_scope="test_scope")
+      c_with_scope = control_flow_ops.ControlFlowContext(
+          values_def=c._to_values_def(), import_scope="test_scope")
 
       # _values and _external_values should be have scope prepended.
       self.assertEquals(
@@ -494,8 +494,8 @@ class ContextTest(test_util.TensorFlowTestCase):
 
       # Calling _to_proto() with export_scope should remove "test_scope".
       self.assertProtoEquals(
-          c._to_proto(),
-          c_with_scope._to_proto(export_scope="test_scope"))
+          c._to_values_def(),
+          c_with_scope._to_values_def(export_scope="test_scope"))
 
 
 def _GetNestedShape(nested):