From 54ba7aace17cc25d4d063e174ad3a15db5447085 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Mon, 12 Feb 2018 13:00:57 -0800 Subject: [PATCH] Scope and decorator to automatically add control dependencies. Should mimic the desired behavior of eager code. For now supports only straight-line code and conditionals. PiperOrigin-RevId: 185421760 --- tensorflow/python/eager/function.py | 206 +++++++++++++++++++++++++++++++ tensorflow/python/eager/function_test.py | 168 +++++++++++++++++++++++++ 2 files changed, 374 insertions(+) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 767d719..d352d67 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.util import compat from tensorflow.python.util import nest @@ -813,3 +814,208 @@ def make_defun_op(func, *args, **kwds): if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): raise ValueError("Tensor keyword arguments are not supported.") return _defun_internal(name, func, args, kwds) + + +class AutomaticControlDependencies(object): + """Context manager to automatically add control dependencies. + + Code under this context manager will act as if a sensible set of control + dependencies were present. More specifically: + 1. All stateful ops in the scope will execute + 2. Stateful ops which modify the same resource will execute in program order + + Note: creating variables in an automatic control dependencies context is not + supported (the value of the variables will never change as they will keep + getting reinitialized). + + NOT THREAD SAFE + """ + + def __init__(self): + self._returned_tensors = set() + + def mark_as_return(self, tensor): + self._returned_tensors.add(tensor) + + def __enter__(self): + if context.in_eager_mode(): + return self + # This code assumes no other thread is adding ops to the graph while + # we're adding ops to the graph. + # TODO(apassos): Fix this by locking the graph or using a temporary + # graph (but that would mess up devices and collections at least, + # probably other things as well). + self._graph = ops.get_default_graph() + self._n_operations = len(self._graph.get_operations()) + return self + + def _process_switch(self, switch_op, ops_which_must_run, + last_op_using_resource_tensor, merge_for_resource): + """Processes a switch node for a resource input. + + When tensorflow creates a cond, it creates a control flow context for each + branch of the cond. Each external tensor accessed by that branch is routed + through a switch op, which gets created in the graph _after_ the op which + uses that tensor get created. + + If the resource comes from another switch op we process that one first. + + _process_switch creates a corresponding merge node for the switch node. This + merge node is added to the outer control flow context of the switch + node. We also ensure that: + + 1. The switch node executes after the previous op which used the resource + tensor + + 2. Any op which uses a resource output of the switch node executes before + the merge for the switch node. + + 3. The next op which uses the input resource to the switch node (which + might be another switch node for the other branch of the conditional) + will execute after the merge node is done. + + 4. The merge node is marked as must_run so it will run even if no + subsequent operation uses the resource. + + Args: + switch_op: the switch op to be processed + ops_which_must_run: the set of ops which must run + last_op_using_resource_tensor: map from resource tensor to last op using + it + merge_for_resource: map from resource tensor to merge which must follow + all usages of it. + """ + inp = switch_op.inputs[0] + if inp.dtype == dtypes_module.resource and inp.op.type == "Switch": + self._process_switch(inp.op, ops_which_must_run, + last_op_using_resource_tensor, merge_for_resource) + if switch_op.outputs[0] in merge_for_resource: + return + new_merge = control_flow_ops.merge(switch_op.outputs, + name="artificial_merge") + new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access + switch_op._control_flow_context.outer_context) # pylint: disable=protected-access + # Ensures the merge always runs + ops_which_must_run.add(new_merge[0].op) + if inp in last_op_using_resource_tensor: + # Ensures the switch exectutes after the previous op using the resource. + switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access + # Ensure the next op outside the cond happens after the merge. + last_op_using_resource_tensor[inp] = new_merge[0].op + if inp in merge_for_resource: + merge_for_resource[inp]._add_control_input(new_merge[0].op) # pylint: disable=protected-access + for o in switch_op.outputs: + # Ensures the merge will execute after all ops inside the cond + merge_for_resource[o] = new_merge[0].op + + def __exit__(self, unused_type, unused_value, unused_traceback): + if context.in_eager_mode(): + return + + if self._graph is not ops.get_default_graph(): + raise RuntimeError( + "Graph changed while trying to add control dependencies.") + + # map from resource tensor to the last op which used it + last_op_using_resource_tensor = {} + # set of conditional and loop exits + ops_which_must_run = set() + # merge which must depend on ops which use this resource + merge_for_resource = {} + + new_operations = self._graph.get_operations()[self._n_operations:] + + # Ensures that uses of resource tensors get serialized properly and all + # execute. This is done by keeping a map from resource tensor to the last op + # in graph-construction order which used it (last_op_using_resource_tensor). + # + # Conditionals are written in TensorFlow such that every external tensor + # accessed in the conditional goes through a switch op and every return + # tensor (it's guaranteed that there will be at least one) goes through a + # merge op. + # + # To handle conditionals, switches are handled in a special way (see + # comments for _process_switch). Merge nodes created by TF's conditional + # logic (as opposed to by _process_switch) are forced to run and also get a + # control dependency added to them to ensure all stateful ops inside their + # control flow context run. + # + # We also ensure that if an op is using a resource output by a switch node + # (that is, a resource tensor for which there's a value in + # merge_for_resource) this op will run before the merge for that resource. + # + # We try to add control inputs to nodes respecting their control flow + # contexts to avoid dead nodes propagating everywhere and leading to + # "retval[0] doesn't have value" errors. If a node gets a control dependency + # on a dead node (i.e. a note from an untaken control flow branch) that node + # will be marked as dead unless it's a merge node. + # + # TODO(apassos): serialize non-resource-taking stateful ops as well, and + # test that it works. Support while loops. Support init_scope escaping from + # this. + for op in new_operations: + control_inputs = set() + # Ensure stateful ops run + if self._graph._registered_ops[op.type].is_stateful: # pylint: disable=protected-access + ops_which_must_run.add(op) + # Ignore switches (they're handled separately) + if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: + continue + # Make merges trigger all other computation which must run + if op.type == "Merge": + for o in ops_which_must_run: + op._add_control_input(o) # pylint: disable=protected-access + for inp in o.inputs: + if inp in last_op_using_resource_tensor: + last_op_using_resource_tensor[inp] = op + ops_which_must_run = set([op]) + continue + for inp in op.inputs: + if inp.dtype == dtypes_module.resource: + # Deal with switches, finally. + if inp.op.type == "Switch": + self._process_switch(inp.op, ops_which_must_run, + last_op_using_resource_tensor, + merge_for_resource) + # Ensure uses of resources are serialized + if inp in last_op_using_resource_tensor: + if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access + is op._control_flow_context): # pylint: disable=protected-access + control_inputs.add(last_op_using_resource_tensor[inp]) + # Ensure merges happen after the closing of a cond block + if inp in merge_for_resource: + merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access + last_op_using_resource_tensor[inp] = op + control_inputs = [c for c in control_inputs + if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access + op._add_control_inputs(control_inputs) # pylint: disable=protected-access + + # Ensure all ops which must run do run + for r in self._returned_tensors: + r.op._add_control_inputs( # pylint: disable=protected-access + [o for o in ops_which_must_run + if o._control_flow_context is r.op._control_flow_context]) # pylint: disable=protected-access + + +def automatic_control_dependencies(f): + """Wraps f to automatically insert control dependencies. + + The inserted dependencies ensure that: + 1. All stateful ops in f run when the result of f runs + 2. Updates to the same resources happen in order. + + Args: + f: the function to be wrapped. + + Returns: + The wrapped function. + """ + + def wrapper(*args, **kwds): + with AutomaticControlDependencies() as a: + result = f(*args, **kwds) + for t in nest.flatten(result): + a.mark_as_return(t) + return result + + return tf_decorator.make_decorator(f, wrapper) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 3e8e67a..431d938 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -595,5 +596,172 @@ class FunctionTest(test.TestCase): create_variable() +class AutomaticControlDependenciesTest(test.TestCase): + + def testBasic(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + with function.AutomaticControlDependencies() as c: + v.assign(v + 1) + v.assign(2 * v) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(), 4.0) + + def testCondMustRun(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + v.assign(v + 1) + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0) + + def testCondMustRunSeparateRead(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + v.assign(v + 1) + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + one = constant_op.constant(1.0) + c.mark_as_return(one) + one.eval(feed_dict={p: False}) + self.assertAllEqual(v.read_value().eval(), 5.0) + one.eval(feed_dict={p: True}) + self.assertAllEqual(v.read_value().eval(), 6.0) + + def testCondNested(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + q = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + v.assign(v + 1, name='true') + return 1.0 + + def false_fn(): + + def inner_true_fn(): + v.assign(v * 2, name='false_true') + return 2.0 + + def inner_false_fn(): + v.assign(v * 3, name='false_false') + return 3.0 + + control_flow_ops.cond(q, inner_true_fn, inner_false_fn) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + with ops.name_scope('final'): + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0) + self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0) + self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0) + self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0) + + def testCondOneBranch(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0) + + def testCondOneBranchUpdateBefore(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + v.assign(v * 2) + + def true_fn(): + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0) + + def testCondOneBranchUpdateAfter(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + v.assign(v * 2) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0) + + def testDecorator(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + + @function.automatic_control_dependencies + def f(): + v.assign(v + 1) + v.assign(2 * v) + return v.read_value() + + self.assertAllEqual(f().eval(), 4.0) + + if __name__ == '__main__': test.main() -- 2.7.4