Scope and decorator to automatically add control dependencies.
author    Alexandre Passos <apassos@google.com>
          Mon, 12 Feb 2018 21:00:57 +0000 (13:00 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Mon, 12 Feb 2018 21:04:51 +0000 (13:04 -0800)
Should mimic the desired behavior of eager code.

For now, only straight-line code and conditionals are supported.
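
A minimal usage sketch of the decorator (mirroring the new testDecorator in
function_test.py; graph mode, session creation, and variable initialization
are elided):

  from tensorflow.python.eager import function
  from tensorflow.python.ops import resource_variable_ops

  v = resource_variable_ops.ResourceVariable(1.0)

  @function.automatic_control_dependencies
  def f():
    # Both assigns run whenever f()'s result runs, in program order.
    v.assign(v + 1)
    v.assign(2 * v)
    return v.read_value()

  # In a graph-mode session (with variables initialized), evaluating f()
  # yields 4.0: (1 + 1) * 2.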

PiperOrigin-RevId: 185421760

tensorflow/python/eager/function.py
tensorflow/python/eager/function_test.py

index 767d719..d352d67 100644 (file)
@@ -36,6 +36,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_module
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
@@ -813,3 +814,208 @@ def make_defun_op(func, *args, **kwds):
   if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
     raise ValueError("Tensor keyword arguments are not supported.")
   return _defun_internal(name, func, args, kwds)
+
+
+class AutomaticControlDependencies(object):
+  """Context manager to automatically add control dependencies.
+
+  Code under this context manager will act as if a sensible set of control
+  dependencies were present. More specifically:
+    1. All stateful ops in the scope will execute
+    2. Stateful ops which modify the same resource will execute in program order
+
+  Note: creating variables in an automatic control dependencies context is not
+  supported (the value of the variables will never change as they will keep
+  getting reinitialized).
+
+  NOT THREAD SAFE
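+
+  Example (a minimal sketch, mirroring testBasic in function_test.py; `v` is
+  a resource variable created outside the scope and initialized to 1.0):
+
+    with AutomaticControlDependencies() as c:
+      v.assign(v + 1)
+      v.assign(2 * v)
+      val = v.read_value()
+      c.mark_as_return(val)
+    # Running `val` now also runs both assigns, in program order, so `val`
+    # evaluates to 4.0.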
+  """
+
+  def __init__(self):
+    self._returned_tensors = set()
+
+  def mark_as_return(self, tensor):
+    """Marks `tensor` as a result; ops which must run will run before it."""
+    self._returned_tensors.add(tensor)
+
+  def __enter__(self):
+    if context.in_eager_mode():
+      return self
+    # This code assumes no other thread is adding ops to the graph while
+    # we're adding ops to the graph.
+    # TODO(apassos): Fix this by locking the graph or using a temporary
+    # graph (but that would mess up devices and collections at least,
+    # probably other things as well).
+    self._graph = ops.get_default_graph()
+    self._n_operations = len(self._graph.get_operations())
+    return self
+
+  def _process_switch(self, switch_op, ops_which_must_run,
+                      last_op_using_resource_tensor, merge_for_resource):
+    """Processes a switch node for a resource input.
+
+    When TensorFlow creates a cond, it creates a control flow context for each
+    branch of the cond. Each external tensor accessed by that branch is routed
+    through a switch op, which is created in the graph _after_ the op which
+    uses that tensor.
+
+    If the resource comes from another switch op we process that one first.
+
+    _process_switch creates a corresponding merge node for the switch node. This
+    merge node is added to the outer control flow context of the switch
+    node. We also ensure that:
+
+      1. The switch node executes after the previous op which used the resource
+         tensor
+
+      2. Any op which uses a resource output of the switch node executes before
+         the merge for the switch node.
+
+      3. The next op which uses the input resource to the switch node (which
+         might be another switch node for the other branch of the conditional)
+         will execute after the merge node is done.
+
+      4. The merge node is marked as must_run so it will run even if no
+         subsequent operation uses the resource.
+
+    Args:
+      switch_op: the switch op to be processed
+      ops_which_must_run: the set of ops which must run
+      last_op_using_resource_tensor: map from resource tensor to last op using
+        it
+      merge_for_resource: map from resource tensor to merge which must follow
+        all usages of it.
+    """
+    inp = switch_op.inputs[0]
+    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
+      self._process_switch(inp.op, ops_which_must_run,
+                           last_op_using_resource_tensor, merge_for_resource)
+    if switch_op.outputs[0] in merge_for_resource:
+      return
+    new_merge = control_flow_ops.merge(switch_op.outputs,
+                                       name="artificial_merge")
+    new_merge[0].op._control_flow_context = (  # pylint: disable=protected-access
+        switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
+    # Ensures the merge always runs
+    ops_which_must_run.add(new_merge[0].op)
+    if inp in last_op_using_resource_tensor:
+      # Ensures the switch executes after the previous op using the resource.
+      switch_op._add_control_input(last_op_using_resource_tensor[inp])  # pylint: disable=protected-access
+    # Ensure the next op outside the cond happens after the merge.
+    last_op_using_resource_tensor[inp] = new_merge[0].op
+    if inp in merge_for_resource:
+      merge_for_resource[inp]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
+    for o in switch_op.outputs:
+      # Ensures the merge will execute after all ops inside the cond
+      merge_for_resource[o] = new_merge[0].op
+
+  def __exit__(self, unused_type, unused_value, unused_traceback):
+    if context.in_eager_mode():
+      return
+
+    if self._graph is not ops.get_default_graph():
+      raise RuntimeError(
+          "Graph changed while trying to add control dependencies.")
+
+    # map from resource tensor to the last op which used it
+    last_op_using_resource_tensor = {}
+    # set of ops which must run even if no subsequent op uses their outputs
+    # (stateful ops and the merges closing conditionals)
+    ops_which_must_run = set()
+    # map from resource tensor to the merge which must follow all usages of it
+    merge_for_resource = {}
+
+    new_operations = self._graph.get_operations()[self._n_operations:]
+
+    # Ensures that uses of resource tensors get serialized properly and all
+    # execute. This is done by keeping a map from resource tensor to the last op
+    # in graph-construction order which used it (last_op_using_resource_tensor).
+    #
+    # Conditionals are written in TensorFlow such that every external tensor
+    # accessed in the conditional goes through a switch op and every return
+    # tensor (it's guaranteed that there will be at least one) goes through a
+    # merge op.
+    #
+    # To handle conditionals, switches are handled in a special way (see
+    # comments for _process_switch). Merge nodes created by TF's conditional
+    # logic (as opposed to by _process_switch) are forced to run and also get a
+    # control dependency added to them to ensure all stateful ops inside their
+    # control flow context run.
+    #
+    # We also ensure that if an op is using a resource output by a switch node
+    # (that is, a resource tensor for which there's a value in
+    # merge_for_resource) this op will run before the merge for that resource.
+    #
+    # We try to add control inputs to nodes respecting their control flow
+    # contexts to avoid dead nodes propagating everywhere and leading to
+    # "retval[0] doesn't have value" errors. If a node gets a control dependency
+    # on a dead node (i.e. a node from an untaken control flow branch), that
+    # node will be marked as dead unless it's a merge node.
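+    #
+    # Concretely (sketching the case exercised by testCondMustRun below): for
+    # a control_flow_ops.cond(p, true_fn, false_fn) whose branches each call
+    # v.assign(...), each branch reads v's handle through a Switch op.
+    # _process_switch builds an artificial Merge over the Switch's outputs and
+    # marks it must-run; the branch's assign becomes a control input of that
+    # Merge, and any later op using v is sequenced after it. The taken
+    # branch's assign therefore runs even though the cond's output is unused,
+    # and a subsequent v.read_value() observes the update.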
+    #
+    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
+    # test that it works. Support while loops. Support init_scope escaping from
+    # this.
+    for op in new_operations:
+      control_inputs = set()
+      # Ensure stateful ops run
+      if self._graph._registered_ops[op.type].is_stateful:  # pylint: disable=protected-access
+        ops_which_must_run.add(op)
+      # Ignore switches (they're handled separately)
+      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
+        continue
+      # Make merges trigger all other computation which must run
+      if op.type == "Merge":
+        for o in ops_which_must_run:
+          op._add_control_input(o)  # pylint: disable=protected-access
+          for inp in o.inputs:
+            if inp in last_op_using_resource_tensor:
+              last_op_using_resource_tensor[inp] = op
+        ops_which_must_run = set([op])
+        continue
+      for inp in op.inputs:
+        if inp.dtype == dtypes_module.resource:
+          # Deal with switches, finally.
+          if inp.op.type == "Switch":
+            self._process_switch(inp.op, ops_which_must_run,
+                                 last_op_using_resource_tensor,
+                                 merge_for_resource)
+          # Ensure uses of resources are serialized
+          if inp in last_op_using_resource_tensor:
+            if (last_op_using_resource_tensor[inp]._control_flow_context  # pylint: disable=protected-access
+                is op._control_flow_context):  # pylint: disable=protected-access
+              control_inputs.add(last_op_using_resource_tensor[inp])
+          # Ensure merges happen after the closing of a cond block
+          if inp in merge_for_resource:
+            merge_for_resource[inp]._add_control_input(op)  # pylint: disable=protected-access
+          last_op_using_resource_tensor[inp] = op
+      control_inputs = [c for c in control_inputs
+                        if c._control_flow_context is op._control_flow_context]  # pylint: disable=protected-access
+      op._add_control_inputs(control_inputs)  # pylint: disable=protected-access
+
+    # Ensure all ops which must run do run
+    for r in self._returned_tensors:
+      r.op._add_control_inputs(  # pylint: disable=protected-access
+          [o for o in ops_which_must_run
+           if o._control_flow_context is r.op._control_flow_context])  # pylint: disable=protected-access
+
+
+def automatic_control_dependencies(f):
+  """Wraps f to automatically insert control dependencies.
+
+  The inserted dependencies ensure that:
+    1. All stateful ops in f run when the result of f runs.
+    2. Updates to the same resources happen in order.
+
+  Args:
+    f: the function to be wrapped.
+
+  Returns:
+    The wrapped function.
+  """
+
+  def wrapper(*args, **kwds):
+    with AutomaticControlDependencies() as a:
+      result = f(*args, **kwds)
+      for t in nest.flatten(result):
+        a.mark_as_return(t)
+      return result
+
+  return tf_decorator.make_decorator(f, wrapper)
index 3e8e67a..431d938 100644 (file)
@@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
@@ -595,5 +596,172 @@ class FunctionTest(test.TestCase):
         create_variable()
 
 
+class AutomaticControlDependenciesTest(test.TestCase):
+
+  def testBasic(self):
+    with context.graph_mode(), self.test_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      variables.global_variables_initializer().run()
+      with function.AutomaticControlDependencies() as c:
+        v.assign(v + 1)
+        v.assign(2 * v)
+        val = v.read_value()
+        c.mark_as_return(val)
+      self.assertAllEqual(val.eval(), 4.0)
+
+  def testCondMustRun(self):
+    with context.graph_mode(), self.test_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      variables.global_variables_initializer().run()
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with function.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          v.assign(v + 1)
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        val = v.read_value()
+        c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
+
+  def testCondMustRunSeparateRead(self):
+    with context.graph_mode(), self.test_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      variables.global_variables_initializer().run()
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with function.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          v.assign(v + 1)
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        one = constant_op.constant(1.0)
+        c.mark_as_return(one)
+      one.eval(feed_dict={p: False})
+      self.assertAllEqual(v.read_value().eval(), 5.0)
+      one.eval(feed_dict={p: True})
+      self.assertAllEqual(v.read_value().eval(), 6.0)
+
+  def testCondNested(self):
+    with context.graph_mode(), self.test_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      variables.global_variables_initializer().run()
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      q = array_ops.placeholder(dtype=dtypes.bool)
+      with function.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          v.assign(v + 1, name='true')
+          return 1.0
+
+        def false_fn():
+
+          def inner_true_fn():
+            v.assign(v * 2, name='false_true')
+            return 2.0
+
+          def inner_false_fn():
+            v.assign(v * 3, name='false_false')
+            return 3.0
+
+          control_flow_ops.cond(q, inner_true_fn, inner_false_fn)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        with ops.name_scope('final'):
+          val = v.read_value()
+        c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0)
+      self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
+
+  def testCondOneBranch(self):
+    with context.graph_mode(), self.test_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      variables.global_variables_initializer().run()
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with function.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        val = v.read_value()
+        c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
+
+  def testCondOneBranchUpdateBefore(self):
+    with context.graph_mode(), self.test_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      variables.global_variables_initializer().run()
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with function.AutomaticControlDependencies() as c:
+        v.assign(v * 2)
+
+        def true_fn():
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        val = v.read_value()
+        c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
+
+  def testCondOneBranchUpdateAfter(self):
+    with context.graph_mode(), self.test_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      variables.global_variables_initializer().run()
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with function.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        v.assign(v * 2)
+        val = v.read_value()
+        c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0)
+
+  def testDecorator(self):
+    with context.graph_mode(), self.test_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      variables.global_variables_initializer().run()
+
+      @function.automatic_control_dependencies
+      def f():
+        v.assign(v + 1)
+        v.assign(2 * v)
+        return v.read_value()
+
+      self.assertAllEqual(f().eval(), 4.0)
+
+
 if __name__ == '__main__':
   test.main()