from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.util import compat
from tensorflow.python.util import nest
if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
raise ValueError("Tensor keyword arguments are not supported.")
return _defun_internal(name, func, args, kwds)
+
+
+class AutomaticControlDependencies(object):
+ """Context manager to automatically add control dependencies.
+
+ Code under this context manager will act as if a sensible set of control
+ dependencies were present. More specifically:
+ 1. All stateful ops in the scope will execute
+ 2. Stateful ops which modify the same resource will execute in program order
+
+ Note: creating variables in an automatic control dependencies context is not
+ supported (the value of the variables will never change as they will keep
+ getting reinitialized).
+
+ NOT THREAD SAFE
+ """
+
+ def __init__(self):
+ self._returned_tensors = set()
+
+ def mark_as_return(self, tensor):
+ self._returned_tensors.add(tensor)
+
+ def __enter__(self):
+ if context.in_eager_mode():
+ return self
+ # This code assumes no other thread is adding ops to the graph while
+ # we're adding ops to the graph.
+ # TODO(apassos): Fix this by locking the graph or using a temporary
+ # graph (but that would mess up devices and collections at least,
+ # probably other things as well).
+ self._graph = ops.get_default_graph()
+ self._n_operations = len(self._graph.get_operations())
+ return self
+
+ def _process_switch(self, switch_op, ops_which_must_run,
+ last_op_using_resource_tensor, merge_for_resource):
+ """Processes a switch node for a resource input.
+
+ When tensorflow creates a cond, it creates a control flow context for each
+ branch of the cond. Each external tensor accessed by that branch is routed
+ through a switch op, which gets created in the graph _after_ the op which
+ uses that tensor get created.
+
+ If the resource comes from another switch op we process that one first.
+
+ _process_switch creates a corresponding merge node for the switch node. This
+ merge node is added to the outer control flow context of the switch
+ node. We also ensure that:
+
+ 1. The switch node executes after the previous op which used the resource
+ tensor
+
+ 2. Any op which uses a resource output of the switch node executes before
+ the merge for the switch node.
+
+ 3. The next op which uses the input resource to the switch node (which
+ might be another switch node for the other branch of the conditional)
+ will execute after the merge node is done.
+
+ 4. The merge node is marked as must_run so it will run even if no
+ subsequent operation uses the resource.
+
+ Args:
+ switch_op: the switch op to be processed
+ ops_which_must_run: the set of ops which must run
+ last_op_using_resource_tensor: map from resource tensor to last op using
+ it
+ merge_for_resource: map from resource tensor to merge which must follow
+ all usages of it.
+ """
+ inp = switch_op.inputs[0]
+ if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
+ self._process_switch(inp.op, ops_which_must_run,
+ last_op_using_resource_tensor, merge_for_resource)
+ if switch_op.outputs[0] in merge_for_resource:
+ return
+ new_merge = control_flow_ops.merge(switch_op.outputs,
+ name="artificial_merge")
+ new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access
+ switch_op._control_flow_context.outer_context) # pylint: disable=protected-access
+ # Ensures the merge always runs
+ ops_which_must_run.add(new_merge[0].op)
+ if inp in last_op_using_resource_tensor:
+ # Ensures the switch exectutes after the previous op using the resource.
+ switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access
+ # Ensure the next op outside the cond happens after the merge.
+ last_op_using_resource_tensor[inp] = new_merge[0].op
+ if inp in merge_for_resource:
+ merge_for_resource[inp]._add_control_input(new_merge[0].op) # pylint: disable=protected-access
+ for o in switch_op.outputs:
+ # Ensures the merge will execute after all ops inside the cond
+ merge_for_resource[o] = new_merge[0].op
+
+ def __exit__(self, unused_type, unused_value, unused_traceback):
+ if context.in_eager_mode():
+ return
+
+ if self._graph is not ops.get_default_graph():
+ raise RuntimeError(
+ "Graph changed while trying to add control dependencies.")
+
+ # map from resource tensor to the last op which used it
+ last_op_using_resource_tensor = {}
+ # set of conditional and loop exits
+ ops_which_must_run = set()
+ # merge which must depend on ops which use this resource
+ merge_for_resource = {}
+
+ new_operations = self._graph.get_operations()[self._n_operations:]
+
+ # Ensures that uses of resource tensors get serialized properly and all
+ # execute. This is done by keeping a map from resource tensor to the last op
+ # in graph-construction order which used it (last_op_using_resource_tensor).
+ #
+ # Conditionals are written in TensorFlow such that every external tensor
+ # accessed in the conditional goes through a switch op and every return
+ # tensor (it's guaranteed that there will be at least one) goes through a
+ # merge op.
+ #
+ # To handle conditionals, switches are handled in a special way (see
+ # comments for _process_switch). Merge nodes created by TF's conditional
+ # logic (as opposed to by _process_switch) are forced to run and also get a
+ # control dependency added to them to ensure all stateful ops inside their
+ # control flow context run.
+ #
+ # We also ensure that if an op is using a resource output by a switch node
+ # (that is, a resource tensor for which there's a value in
+ # merge_for_resource) this op will run before the merge for that resource.
+ #
+ # We try to add control inputs to nodes respecting their control flow
+ # contexts to avoid dead nodes propagating everywhere and leading to
+ # "retval[0] doesn't have value" errors. If a node gets a control dependency
+ # on a dead node (i.e. a note from an untaken control flow branch) that node
+ # will be marked as dead unless it's a merge node.
+ #
+ # TODO(apassos): serialize non-resource-taking stateful ops as well, and
+ # test that it works. Support while loops. Support init_scope escaping from
+ # this.
+ for op in new_operations:
+ control_inputs = set()
+ # Ensure stateful ops run
+ if self._graph._registered_ops[op.type].is_stateful: # pylint: disable=protected-access
+ ops_which_must_run.add(op)
+ # Ignore switches (they're handled separately)
+ if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
+ continue
+ # Make merges trigger all other computation which must run
+ if op.type == "Merge":
+ for o in ops_which_must_run:
+ op._add_control_input(o) # pylint: disable=protected-access
+ for inp in o.inputs:
+ if inp in last_op_using_resource_tensor:
+ last_op_using_resource_tensor[inp] = op
+ ops_which_must_run = set([op])
+ continue
+ for inp in op.inputs:
+ if inp.dtype == dtypes_module.resource:
+ # Deal with switches, finally.
+ if inp.op.type == "Switch":
+ self._process_switch(inp.op, ops_which_must_run,
+ last_op_using_resource_tensor,
+ merge_for_resource)
+ # Ensure uses of resources are serialized
+ if inp in last_op_using_resource_tensor:
+ if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access
+ is op._control_flow_context): # pylint: disable=protected-access
+ control_inputs.add(last_op_using_resource_tensor[inp])
+ # Ensure merges happen after the closing of a cond block
+ if inp in merge_for_resource:
+ merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
+ last_op_using_resource_tensor[inp] = op
+ control_inputs = [c for c in control_inputs
+ if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access
+ op._add_control_inputs(control_inputs) # pylint: disable=protected-access
+
+ # Ensure all ops which must run do run
+ for r in self._returned_tensors:
+ r.op._add_control_inputs( # pylint: disable=protected-access
+ [o for o in ops_which_must_run
+ if o._control_flow_context is r.op._control_flow_context]) # pylint: disable=protected-access
+
+
+def automatic_control_dependencies(f):
+ """Wraps f to automatically insert control dependencies.
+
+ The inserted dependencies ensure that:
+ 1. All stateful ops in f run when the result of f runs
+ 2. Updates to the same resources happen in order.
+
+ Args:
+ f: the function to be wrapped.
+
+ Returns:
+ The wrapped function.
+ """
+
+ def wrapper(*args, **kwds):
+ with AutomaticControlDependencies() as a:
+ result = f(*args, **kwds)
+ for t in nest.flatten(result):
+ a.mark_as_return(t)
+ return result
+
+ return tf_decorator.make_decorator(f, wrapper)
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
create_variable()
+class AutomaticControlDependenciesTest(test.TestCase):
+
+ def testBasic(self):
+ with context.graph_mode(), self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ variables.global_variables_initializer().run()
+ with function.AutomaticControlDependencies() as c:
+ v.assign(v + 1)
+ v.assign(2 * v)
+ val = v.read_value()
+ c.mark_as_return(val)
+ self.assertAllEqual(val.eval(), 4.0)
+
+ def testCondMustRun(self):
+ with context.graph_mode(), self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ variables.global_variables_initializer().run()
+ p = array_ops.placeholder(dtype=dtypes.bool)
+ with function.AutomaticControlDependencies() as c:
+
+ def true_fn():
+ v.assign(v + 1)
+ return 0.0
+
+ def false_fn():
+ v.assign(v + 4)
+ return 1.0
+
+ control_flow_ops.cond(p, true_fn, false_fn)
+ val = v.read_value()
+ c.mark_as_return(val)
+ self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
+ self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
+
+ def testCondMustRunSeparateRead(self):
+ with context.graph_mode(), self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ variables.global_variables_initializer().run()
+ p = array_ops.placeholder(dtype=dtypes.bool)
+ with function.AutomaticControlDependencies() as c:
+
+ def true_fn():
+ v.assign(v + 1)
+ return 0.0
+
+ def false_fn():
+ v.assign(v + 4)
+ return 1.0
+
+ control_flow_ops.cond(p, true_fn, false_fn)
+ one = constant_op.constant(1.0)
+ c.mark_as_return(one)
+ one.eval(feed_dict={p: False})
+ self.assertAllEqual(v.read_value().eval(), 5.0)
+ one.eval(feed_dict={p: True})
+ self.assertAllEqual(v.read_value().eval(), 6.0)
+
+ def testCondNested(self):
+ with context.graph_mode(), self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ variables.global_variables_initializer().run()
+ p = array_ops.placeholder(dtype=dtypes.bool)
+ q = array_ops.placeholder(dtype=dtypes.bool)
+ with function.AutomaticControlDependencies() as c:
+
+ def true_fn():
+ v.assign(v + 1, name='true')
+ return 1.0
+
+ def false_fn():
+
+ def inner_true_fn():
+ v.assign(v * 2, name='false_true')
+ return 2.0
+
+ def inner_false_fn():
+ v.assign(v * 3, name='false_false')
+ return 3.0
+
+ control_flow_ops.cond(q, inner_true_fn, inner_false_fn)
+ return 1.0
+
+ control_flow_ops.cond(p, true_fn, false_fn)
+ with ops.name_scope('final'):
+ val = v.read_value()
+ c.mark_as_return(val)
+ self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0)
+ self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0)
+ self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0)
+ self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
+
+ def testCondOneBranch(self):
+ with context.graph_mode(), self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ variables.global_variables_initializer().run()
+ p = array_ops.placeholder(dtype=dtypes.bool)
+ with function.AutomaticControlDependencies() as c:
+
+ def true_fn():
+ return 0.0
+
+ def false_fn():
+ v.assign(v + 4)
+ return 1.0
+
+ control_flow_ops.cond(p, true_fn, false_fn)
+ val = v.read_value()
+ c.mark_as_return(val)
+ self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
+ self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
+
+ def testCondOneBranchUpdateBefore(self):
+ with context.graph_mode(), self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ variables.global_variables_initializer().run()
+ p = array_ops.placeholder(dtype=dtypes.bool)
+ with function.AutomaticControlDependencies() as c:
+ v.assign(v * 2)
+
+ def true_fn():
+ return 0.0
+
+ def false_fn():
+ v.assign(v + 4)
+ return 1.0
+
+ control_flow_ops.cond(p, true_fn, false_fn)
+ val = v.read_value()
+ c.mark_as_return(val)
+ self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0)
+ self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
+
+ def testCondOneBranchUpdateAfter(self):
+ with context.graph_mode(), self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ variables.global_variables_initializer().run()
+ p = array_ops.placeholder(dtype=dtypes.bool)
+ with function.AutomaticControlDependencies() as c:
+
+ def true_fn():
+ return 0.0
+
+ def false_fn():
+ v.assign(v + 4)
+ return 1.0
+
+ control_flow_ops.cond(p, true_fn, false_fn)
+ v.assign(v * 2)
+ val = v.read_value()
+ c.mark_as_return(val)
+ self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0)
+ self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0)
+
+ def testDecorator(self):
+ with context.graph_mode(), self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ variables.global_variables_initializer().run()
+
+ @function.automatic_control_dependencies
+ def f():
+ v.assign(v + 1)
+ v.assign(2 * v)
+ return v.read_value()
+
+ self.assertAllEqual(f().eval(), 4.0)
+
+
if __name__ == '__main__':
test.main()