[tf.contrib CriticalSection] Avoid deadlocks using additional control dependencies...
authorEugene Brevdo <ebrevdo@google.com>
Wed, 21 Mar 2018 15:25:34 +0000 (08:25 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 15:27:54 +0000 (08:27 -0700)
PiperOrigin-RevId: 189910726

tensorflow/contrib/framework/python/ops/critical_section_ops.py
tensorflow/contrib/framework/python/ops/critical_section_test.py

index cc19372..1893d7b 100644 (file)
@@ -24,10 +24,8 @@ import collections
 # from tensorflow.core.protobuf import critical_section_pb2
 
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_resource_variable_ops
@@ -48,6 +46,26 @@ class _ExecutionSignature(
   pass
 
 
+def _identity(x):
+  """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`."""
+  if isinstance(x, tensor_array_ops.TensorArray):
+    return x.identity()
+  elif isinstance(x, ops.Operation):
+    return control_flow_ops.group(x)
+  elif context.executing_eagerly() and x is None:
+    return None
+  else:
+    return array_ops.identity(x)
+
+
+def _get_colocation(op):
+  """Get colocation symbol from op, if any."""
+  try:
+    return op.get_attr("_class")
+  except ValueError:
+    return None
+
+
 class CriticalSection(object):
   """Critical section.
 
@@ -180,8 +198,8 @@ class CriticalSection(object):
       The tensors returned from `fn(*args, **kwargs)`.
 
     Raises:
-      ValueError: If `fn` attempts to use this `CriticalSection` in any nested
-        way.
+      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
+        or lazy way that may cause a deadlock.
       ValueError: If `exclusive_resource_access` is not provided (is `True`) and
         another `CriticalSection` has an execution requesting the same
        resources as in `*args`, `**kwargs`, and any additionally captured
@@ -193,69 +211,52 @@ class CriticalSection(object):
     exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)
 
     with ops.name_scope(name, "critical_section_execute", []):
-      lock = gen_resource_variable_ops.mutex_lock(self._handle)
-
-      with ops.control_dependencies([lock]):
-        c_known_ops = set()
-        c_captured_tensors = set()
 
-        def add_op_internal(op):
-          c_known_ops.add(op)
-          for i in op.inputs:
-            if i.op not in c_known_ops:
-              c_captured_tensors.add(i)
+      # Ensure that mutex locking only happens *after* all args and
+      # kwargs have been executed.  This avoids certain types of deadlocks.
+      lock = gen_resource_variable_ops.mutex_lock(self._handle)
 
-        c = function.HelperContext(add_op_internal)
-        with c:
+      if not context.executing_eagerly():
+        # NOTE(ebrevdo): This is to ensure we don't pick up spurious
+        # Operations created by other threads.
+        with ops.get_default_graph()._lock:  # pylint: disable=protected-access
+          existing_ops = ops.get_default_graph().get_operations()
+          with ops.control_dependencies([lock]):
+            r = fn(*args, **kwargs)
+          # TODO(ebrevdo): If creating critical sections in a python loop, this
+          # makes graph creation time quadratic.  Revisit if this
+          # becomes a problem.
+          created_ops = (set(ops.get_default_graph().get_operations())
+                         .difference(existing_ops))
+      else:
+        with ops.control_dependencies([lock]):
           r = fn(*args, **kwargs)
 
-        resource_inputs = set([
-            x for x in
-            list(nest.flatten(args)) + nest.flatten(kwargs.values()) +
-            list(c_captured_tensors)
-            if tensor_util.is_tensor(x) and x.dtype == dtypes.resource])
-
-      if self._handle in resource_inputs:
-        raise ValueError("The function fn attempts to access the "
-                         "CriticalSection in which it would be running.  "
-                         "This is illegal and would cause deadlocks.  "
-                         "CriticalSection: %s." % self._handle)
-
       if not context.executing_eagerly():
-        # Collections and op introspection does not work in eager
-        # mode.  This is generally ok; since eager mode (as of
-        # writing) executes sequentially anyway.
-        for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
-          sg_handle_name = ops.convert_to_tensor(sg.handle).name
-          self_handle_name = ops.convert_to_tensor(self._handle).name
-          if sg_handle_name == self_handle_name:
-            # Other executions in the same critical section are allowed.
-            continue
-          if not (exclusive_resource_access or sg.exclusive_resource_access):
-            # Neither execution requested exclusive access.
-            continue
-          resource_intersection = resource_inputs.intersection(sg.resources)
-          if resource_intersection:
-            raise ValueError(
-                "This execution would access resources: %s.  Either this "
-                "lock (CriticalSection: %s) or lock '%s' "
-                "(CriticalSection: %s) requested exclusive resource access "
-                "of this resource.  Did you mean to call execute with keyword "
-                "argument exclusive_resource_access=False?" %
-                (list(resource_intersection), self._handle.name,
-                 sg.op.name, sg.handle.name))
-
-      def identity(x):  # pylint: disable=invalid-name
-        if isinstance(x, tensor_array_ops.TensorArray):
-          return x.identity()
-        elif isinstance(x, ops.Operation):
-          return control_flow_ops.group(x)
-        elif context.executing_eagerly() and x is None:
-          return None
-        else:
-          return array_ops.identity(x)
-
-      r_flat = [identity(x) for x in nest.flatten(r)]
+        self._add_control_dependencies_to_lock(created_ops, lock.op)
+
+        # captured_resources is a set of resources that are directly
+        # accessed only by ops created during fn(), not by any
+        # ancestors of those ops in the graph.
+        captured_resources = set([
+            input_ for op in created_ops
+            for input_ in op.inputs
+            if input_.dtype == dtypes.resource
+        ])
+
+        # NOTE(ebrevdo): The only time self._is_self_handle() is True
+        # in this call is if one of the recently created ops, within
+        # the execute(), itself attempts to access the
+        # CriticalSection.  This will cause a deadlock.
+        if any(self._is_self_handle(x) for x in captured_resources):
+          raise ValueError("The function fn attempts to directly access the "
+                           "CriticalSection in which it would be running.  "
+                           "This is illegal and would cause deadlocks.")
+
+        self._check_multiple_access_to_resources(
+            captured_resources, exclusive_resource_access)
+
+      r_flat = [_identity(x) for x in nest.flatten(r)]
 
       with ops.control_dependencies(r_flat):
         # The identity must run on the same machine as self._handle
@@ -268,23 +269,93 @@ class CriticalSection(object):
 
         # Make sure that if any element of r is accessed, all of
         # them are executed together.
-        r = nest.pack_sequence_as(
-            r, control_flow_ops.tuple(nest.flatten(r)))
+        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))
 
       with ops.control_dependencies([ensure_lock_exists]):
-        outputs = nest.map_structure(identity, r)
+        outputs = nest.map_structure(_identity, r)
 
       if not context.executing_eagerly():
         signature = _ExecutionSignature(
             op=lock.op,
             handle=self._handle,
-            resources=list(resource_inputs),
+            resources=list(captured_resources),
             exclusive_resource_access=exclusive_resource_access)
         ops.add_to_collections(
             CRITICAL_SECTION_EXECUTIONS, signature)
 
       return outputs
 
+  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
+    """To avoid deadlocks, all args must be executed before lock_op."""
+    # Get all arguments (explicit and captured) of all ops created by fn().
+    all_args = set([input_.op for op in created_ops for input_ in op.inputs])
+    all_args.update(
+        input_op for op in created_ops for input_op in op.control_inputs)
+    # Unfortunately, we can't use sets throughout because TF seems to
+    # create new Operation objects for the same op sometimes; and we
+    # can't rely on id(op).
+
+    # pylint: disable=protected-access
+    all_args_dict = dict((op._id, op) for op in all_args)
+
+    # Remove ops created within fn, or that lock_op already has a
+    # control dependency on.  Also remove a possible self-loop.
+    for op in created_ops:
+      all_args_dict.pop(op._id, None)
+    for op in lock_op.control_inputs:
+      all_args_dict.pop(op._id, None)
+    for input_ in lock_op.inputs:
+      all_args_dict.pop(input_.op._id, None)
+    all_args_dict.pop(lock_op._id, None)
+
+    lock_op._add_control_inputs(all_args_dict.values())
+    # pylint: enable=protected-access
+
+  def _is_self_handle(self, x):
+    """Check if the tensor `x` is the same Mutex as `self._handle`."""
+    return (x.op.type == "MutexV2"
+            # blank shared_name means the op will create a unique one.
+            and x.op.get_attr("shared_name")
+            and (x.op.get_attr("shared_name") ==
+                 self._handle.op.get_attr("shared_name"))
+            and (x.op.device == self._handle.op.device
+                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))
+
+  def _check_multiple_access_to_resources(
+      self, captured_resources, exclusive_resource_access):
+    """Raise if captured_resources are accessed by another CriticalSection.
+
+    Args:
+      captured_resources: Set of tensors of type resource.
+      exclusive_resource_access: Whether this execution requires exclusive
+        resource access.
+
+    Raises:
+      ValueError: If any tensors in `captured_resources` are also accessed
+        by another `CriticalSection`, and at least one of them requires
+        exclusive resource access.
+    """
+    # Collections and op introspection does not work in eager
+    # mode.  This is generally ok; since eager mode (as of
+    # writing) executes sequentially anyway.
+    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
+      if self._is_self_handle(sg.handle):
+        # Other executions in the same critical section are allowed.
+        continue
+      if not (exclusive_resource_access or sg.exclusive_resource_access):
+        # Neither execution requested exclusive access.
+        continue
+      resource_intersection = captured_resources.intersection(sg.resources)
+      if resource_intersection:
+        raise ValueError(
+            "This execution would access resources: %s.  Either this "
+            "lock (CriticalSection: %s) or lock '%s' "
+            "(CriticalSection: %s) requested exclusive resource access "
+            "of this resource.  Did you mean to call execute with keyword "
+            "argument exclusive_resource_access=False?" %
+            (list(resource_intersection), self._handle.name,
+             sg.op.name, sg.handle.name))
+
   # TODO(ebrevdo): Re-enable once CriticalSection is in core.
 
   # def to_proto(self, export_scope=None):
index c916592..e24140b 100644 (file)
@@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 # TODO(ebrevdo): Re-enable once CriticalSection is in core.
 # from tensorflow.python.training import saver as saver_lib
 
@@ -37,7 +38,7 @@ class CriticalSectionTest(test.TestCase):
     v = resource_variable_ops.ResourceVariable(0.0, name="v")
 
     def fn(a, b):
-      c = v.read_value()
+      c = v.value()
       with ops.control_dependencies([c]):
         nv = v.assign_add(a * b)
         with ops.control_dependencies([nv]):
@@ -143,12 +144,148 @@ class CriticalSectionTest(test.TestCase):
     # This does not work properly in eager mode.  Eager users will
     # just hit a deadlock if they do this.  But at least it'll be easier
     # to debug.
+    cs = critical_section_ops.CriticalSection()
+    def fn(x):
+      return cs.execute(lambda y: y + 1, x)
+    with self.assertRaisesRegexp(
+        ValueError,
+        r"attempts to directly access the CriticalSection in which it "
+        r"would be running"):
+      cs.execute(fn, 1.0)
+
+  def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
+    # This one is subtle; and we're being overly cautious here.  The
+    # deadlock we are ensuring we catch is:
+    #
+    # to_capture = CS[lambda x: x + 1](1.0)
+    # deadlocked = CS[lambda x: x + to_capture](1.0)
+    #
+    # This would have caused a deadlock because executing `deadlocked` will
+    # lock the mutex on CS; but then due to dependencies, will attempt
+    # to compute `to_capture`.  This computation requires locking CS,
+    # but that is not possible now because CS is already locked by
+    # `deadlocked`.
+    #
+    # We check that CriticalSection.execute properly inserts new
+    # control dependencies to its lock to ensure all captured
+    # operations are finished before anything runs within the critical section.
+    cs = critical_section_ops.CriticalSection(shared_name="cs")
+    fn = array_ops.identity
+    to_capture = cs.execute(fn, 1.0)
+    fn_captures = lambda x: x + to_capture
+    to_capture_too = array_ops.identity(to_capture)
+
+    ex_0 = cs.execute(fn_captures, 1.0)
+
+    with ops.control_dependencies([to_capture]):
+      # This is OK because to_capture will execute before this next call
+      ex_1 = cs.execute(fn_captures, 1.0)
+
+    dependency = array_ops.identity(to_capture)
+
+    fn_captures_dependency = lambda x: x + dependency
+
+    ex_2 = cs.execute(fn_captures_dependency, 1.0)
+
+    with ops.control_dependencies([to_capture_too]):
+      ex_3 = cs.execute(fn_captures_dependency, 1.0)
+
+    # Ensure there's no actual deadlock on to_execute.
+    self.assertEquals(2.0, self.evaluate(ex_0))
+    self.assertEquals(2.0, self.evaluate(ex_1))
+    self.assertEquals(2.0, self.evaluate(ex_2))
+    self.assertEquals(2.0, self.evaluate(ex_3))
+
+  def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
+    cs = critical_section_ops.CriticalSection(shared_name="cs")
+
+    def body_implicit_capture(i, j):
+      # This would have caused a deadlock if not for logic in execute
+      # that inserts additional control dependencies onto the lock op:
+      #   * Loop body argument j is captured by fn()
+      #   * i is running in parallel to move forward the execution
+      #   * j is not being checked by the predicate function
+      #   * output of cs.execute() is returned as next j.
+      fn = lambda: j + 1
+      return (i + 1, cs.execute(fn))
+
+    (i_n, j_n) = control_flow_ops.while_loop(
+        lambda i, _: i < 1000,
+        body_implicit_capture,
+        [0, 0],
+        parallel_iterations=25)
+    logging.warn(
+        "\n==============\nRunning "
+        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+        "body_implicit_capture'\n"
+        "==============\n")
+    self.assertEquals((1000, 1000), self.evaluate((i_n, j_n)))
+    logging.warn(
+        "\n==============\nSuccessfully finished running "
+        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+        "body_implicit_capture'\n"
+        "==============\n")
+
+    def body_implicit_capture_protected(i, j):
+      # This version is ok because we manually add a control
+      # dependency on j, which is an argument to the while_loop body
+      # and captured by fn.
+      fn = lambda: j + 1
+      with ops.control_dependencies([j]):
+        return (i + 1, cs.execute(fn))
+
+    (i_n, j_n) = control_flow_ops.while_loop(
+        lambda i, _: i < 1000,
+        body_implicit_capture_protected,
+        [0, 0],
+        parallel_iterations=25)
+    logging.warn(
+        "\n==============\nRunning "
+        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+        "body_implicit_capture_protected'\n"
+        "==============\n")
+    self.assertEquals((1000, 1000), self.evaluate((i_n, j_n)))
+    logging.warn(
+        "\n==============\nSuccessfully finished running "
+        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+        "body_implicit_capture_protected'\n"
+        "==============\n")
+
+    def body_args_capture(i, j):
+      # This version is ok because j is an argument to fn and we can
+      # ensure there's a control dependency on j.
+      fn = lambda x: x + 1
+      return (i + 1, cs.execute(fn, j))
+
+    (i_n, j_n) = control_flow_ops.while_loop(
+        lambda i, _: i < 1000,
+        body_args_capture,
+        [0, 0],
+        parallel_iterations=25)
+    logging.warn(
+        "\n==============\nRunning "
+        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+        "body_args_capture'\n"
+        "==============\n")
+    self.assertEquals((1000, 1000), self.evaluate((i_n, j_n)))
+    logging.warn(
+        "\n==============\nSuccessfully finished running "
+        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+        "body_args_capture'\n"
+        "==============\n")
+
+  def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
+    # This does not work properly in eager mode.  Eager users will
+    # just hit a deadlock if they do this.  But at least it'll be easier
+    # to debug.
     cs = critical_section_ops.CriticalSection(shared_name="cs")
+    cs_same = critical_section_ops.CriticalSection(shared_name="cs")
     def fn(x):
-      return cs.execute(lambda x: x+1, x)
+      return cs_same.execute(lambda x: x+1, x)
     with self.assertRaisesRegexp(
         ValueError,
-        r"attempts to access the CriticalSection in which it would be running"):
+        r"attempts to directly access the CriticalSection in which it "
+        r"would be running"):
       cs.execute(fn, 1.0)
 
   def testMultipleCSExecutionsRequestSameResource(self):