from __future__ import division
from __future__ import print_function
+import abc
import collections
import functools
"""
def __init__(self, values_def=None, import_scope=None):
+ self._nested_contexts = []
self._outer_context = ops.get_default_graph()._get_control_flow_context()
+ if self._outer_context:
+ self._outer_context._nested_contexts.append(self) # pylint: disable=protected-access
self._context_stack = []
if values_def:
self._init_values_from_proto(values_def, import_scope=import_scope)
def back_prop(self):
raise NotImplementedError("Abstract method")
- def _to_proto(self, export_scope=None):
+ @abc.abstractmethod
+ def to_control_flow_context_def(self, context_def, export_scope=None):
+ """Serializes this into `context_def`.
+
+ Args:
+ context_def: a `ControlFlowContextDef` protocol buffer.
+ export_scope: Optional `string`. Name scope to remove.
+ """
+ raise NotImplementedError("Abstract method")
+
+ def _to_values_def(self, export_scope=None):
"""Converts the values to a `ValuesDef` protocol buffer.
Args:
values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
return values_def
- @staticmethod
- def _from_proto(values_def, import_scope=None):
- """Returns a `ControlFlowContext` created from `values_def`."""
- return ControlFlowContext(values_def=values_def, import_scope=import_scope)
-
def AddName(self, name):
self._values.add(name)
context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
export_scope)
context_def.branch = self._branch
- context_def.values_def.MergeFrom(
- super(CondContext, self)._to_proto(export_scope))
+ context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def(
+ export_scope))
+ # TODO(b/72868227): enable this once the corresponding control_flow.proto
+ # changes have been checked in (they aren't checked in and this is
+ # disabled for now to ensure forwards compatibility).
+ if False: # pylint: disable=using-constant-test
+ for nested in self._nested_contexts:
+ nested_def = context_def.nested_contexts.add()
+ nested.to_control_flow_context_def(nested_def)
return context_def
else:
@staticmethod
def from_proto(context_def, import_scope=None):
"""Returns a `CondContext` object created from `context_def`."""
- return CondContext(context_def=context_def, import_scope=import_scope)
+ ret = CondContext(context_def=context_def,
+ import_scope=import_scope)
+
+ # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
+ # control_flow.proto changes have been checked in (they aren't checked in
+ # and this is here for now to ensure forwards compatibility).
+ if hasattr(context_def, "nested_contexts"):
+ ret.Enter()
+ for nested_def in context_def.nested_contexts:
+ from_control_flow_context_def(nested_def)
+ ret.Exit()
+ return ret
+
+ def to_control_flow_context_def(self, context_def, export_scope=None):
+ context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
def AddValue(self, val):
"""Add `val` to the current context and its outer context recursively."""
merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
- # Add to collections
- ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
- ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
+ # Only add non-nested conds to the collection. Any nested control flow will
+ # be encapsulated in the root context.
+ assert context_t.outer_context == context_f.outer_context
+ # TODO(b/72868227): remove "if True..." once the corresponding
+ # control_flow.proto changes have been checked in (they aren't checked in
+ # and this is disabled for now to ensure forwards compatibility).
+ if True or context_t.outer_context is None:
+ ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
+ ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges)
ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
])
context_def.values_def.MergeFrom(
- super(WhileContext, self)._to_proto(export_scope=export_scope))
+ super(WhileContext, self)._to_values_def(
+ export_scope=export_scope))
+ # TODO(b/72868227): remove "if True..." once the corresponding
+ # control_flow.proto changes have been checked in (they aren't checked in
+ # and this is disabled for now to ensure forwards compatibility).
+ if False: # pylint: disable=using-constant-test
+ for nested in self._nested_contexts:
+ nested_def = context_def.nested_contexts.add()
+ nested.to_control_flow_context_def(nested_def)
return context_def
else:
return None
+ def to_control_flow_context_def(self, context_def, export_scope=None):
+ context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
+
@staticmethod
def from_proto(context_def, import_scope=None):
"""Returns a `WhileContext` object created from `context_def`.
Returns:
A `WhileContext` Python object.
"""
- return WhileContext(context_def=context_def, import_scope=import_scope)
+ ret = WhileContext(context_def=context_def,
+ import_scope=import_scope)
+ # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
+ # control_flow.proto changes have been checked in (they aren't checked in
+ # and this is disabled for now to ensure forwards compatibility).
+ if hasattr(context_def, "nested_contexts"):
+ ret.Enter()
+ for nested_def in context_def.nested_contexts:
+ from_control_flow_context_def(nested_def, import_scope=import_scope)
+ ret.Exit()
+ return ret
def GetWhileContext(self):
return self
parallel_iterations=parallel_iterations,
back_prop=back_prop,
swap_memory=swap_memory)
- ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
+ # Only add non-nested loops to the collection. Any nested control flow will
+ # be encapsulated in the root context.
+ # TODO(b/72868227): enable condition once the corresponding
+ # control_flow.proto changes have been checked in (they aren't checked in
+ # and this is disabled for now to ensure forwards compatibility).
+ if True or loop_context.outer_context is None:
+ ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
if maximum_iterations is not None:
return result[1]
return x
+def from_control_flow_context_def(context_def, import_scope=None):
+ """Deserializes `context_def` into the appropriate ControlFlowContext.
+
+ Args:
+ context_def: ControlFlowContextDef proto
+ import_scope: Optional `string`. Name scope to add.
+
+ Returns:
+ A ControlFlowContext subclass
+ """
+ if context_def.HasField("cond_ctxt"):
+ return CondContext.from_proto(context_def.cond_ctxt,
+ import_scope=import_scope)
+ if context_def.HasField("while_ctxt"):
+ return WhileContext.from_proto(context_def.while_ctxt,
+ import_scope=import_scope)
+ raise NotImplementedError("Unknown ControlFlowContextDef field: %s"
+ % context_def.WhichOneof("ctxt"))
+
+
ops.register_proto_function(
ops.GraphKeys.COND_CONTEXT,
proto_type=control_flow_pb2.CondContextDef,