New Mutex operations for a distributed-happy and Function-less CriticalSection.
authorEugene Brevdo <ebrevdo@google.com>
Fri, 23 Feb 2018 00:18:34 +0000 (16:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Feb 2018 00:22:41 +0000 (16:22 -0800)
Original idea by Alex Passos; impl and cancellation handling by Eugene Brevdo with help from Alex.

PiperOrigin-RevId: 186692306

15 files changed:
tensorflow/contrib/framework/BUILD
tensorflow/contrib/framework/python/ops/critical_section_ops.py
tensorflow/contrib/framework/python/ops/critical_section_test.py
tensorflow/core/api_def/base_api/api_def_ConsumeMutexLock.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_CriticalSectionOp.pbtxt [deleted file]
tensorflow/core/api_def/base_api/api_def_ExecuteInCriticalSection.pbtxt [deleted file]
tensorflow/core/api_def/base_api/api_def_MutexLock.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_MutexV2.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/critical_section.cc [deleted file]
tensorflow/core/kernels/mutex_ops.cc [new file with mode: 0644]
tensorflow/core/ops/compat/ops_history.v1.pbtxt
tensorflow/core/ops/resource_variable_ops.cc
tensorflow/python/eager/function.py
tensorflow/python/ops/control_flow_ops.py

index 9e5f54f0973eae899ca65e4098358107053cb7d4..dbdb5cfaaca1a687fefb81cee200295d5cbb7fd5 100644 (file)
@@ -185,11 +185,14 @@ cuda_py_test(
     additional_deps = [
         "//tensorflow/python:client_testlib",
         ":framework_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:gradients",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:tensor_array_ops",
     ],
 )
 
index 182fec924febb74a23b82b1664d137f033f3b1b4..3c5c55ed656432a33f19462130a9e58c2ab14efb 100644 (file)
@@ -27,7 +27,11 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.util import nest
 
 
@@ -38,7 +42,8 @@ CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
 
 class _ExecutionSignature(
     collections.namedtuple("_ExecutionSignature",
-                           ("op", "exclusive_resource_access"))):
+                           ("op", "handle",
+                            "resources", "exclusive_resource_access"))):
   """A class storing an `ExecuteInCriticalResource` op and associated attrs."""
   pass
 
@@ -112,16 +117,18 @@ class CriticalSection(object):
   ```
   """
 
-  def __init__(self, name=None, critical_section_def=None, import_scope=None):
+  def __init__(self, name=None, shared_name=None,
+               critical_section_def=None, import_scope=None):
     """Creates a critical section."""
     if critical_section_def and name is not None:
-      raise ValueError("critical_section_def and name are mutually exclusive.")
+      raise ValueError("critical_section_def and shared_name are "
+                       "mutually exclusive.")
     if critical_section_def:
       self._init_from_proto(critical_section_def, import_scope=import_scope)
     else:
-      self._init_from_args(name)
+      self._init_from_args(name, shared_name)
 
-  def _init_from_proto(self, critical_section_def, import_scope):
+  def _init_from_proto(self, critical_section_def, import_scope):  # pylint: disable=invalid-name
     raise NotImplementedError("Not yet implemented")
     # TODO(ebrevdo): Re-enable once CriticalSection is in core.
     # assert isinstance(
@@ -133,18 +140,20 @@ class CriticalSection(object):
     #         critical_section_def.critical_section_name,
     #         import_scope=import_scope))
 
-  def _init_from_args(self, name):
+  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
     """Initialize the CriticalSection from constructor arguments."""
     with ops.name_scope(name, "CriticalSection", []) as name:
       with ops.control_dependencies(None):
         # pylint: disable=protected-access
-        handle_name = ops._name_from_scope_name(name)
         container = ops.get_default_graph()._container
         # pylint: enable=protected-access
+        if shared_name is None:
+          shared_name = name
         if container is None:
           container = ""
-        self._handle = gen_resource_variable_ops.critical_section_op(
-            shared_name=handle_name, name=name)
+        self._handle = gen_resource_variable_ops.mutex_v2(
+            shared_name=shared_name, container=container, name=name)
+
     if context.in_graph_mode():
       ops.add_to_collections(CRITICAL_SECTIONS, self)
 
@@ -183,68 +192,96 @@ class CriticalSection(object):
     name = kwargs.pop("name", None)
     exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)
 
-    args = nest.map_structure(ops.convert_to_tensor, args)
     with ops.name_scope(name, "critical_section_execute", []):
-      fn_op = function.make_defun_op(fn, *args, **kwargs)
-      flat_dtypes = nest.flatten(fn_op.output_dtypes)
-      flat_shapes = nest.flatten(fn_op.output_shapes)
-      all_inputs = nest.flatten(args) + fn_op.captured_inputs
-      if self._handle in all_inputs:
+      lock = gen_resource_variable_ops.mutex_lock(self._handle)
+
+      with ops.control_dependencies([lock]):
+        c_known_ops = set()
+        c_captured_tensors = set()
+
+        def add_op_internal(op):
+          c_known_ops.add(op)
+          for i in op.inputs:
+            if i.op not in c_known_ops:
+              c_captured_tensors.add(i)
+
+        c = function.HelperContext(add_op_internal)
+        with c:
+          r = fn(*args, **kwargs)
+
+        resource_inputs = set([
+            x for x in
+            list(nest.flatten(args)) + nest.flatten(kwargs.values()) +
+            list(c_captured_tensors)
+            if tensor_util.is_tensor(x) and x.dtype == dtypes.resource])
+
+      if self._handle in resource_inputs:
         raise ValueError("The function fn attempts to access the "
-                         "CriticalSection in which it would be running.  This "
-                         "is illegal and would cause deadlocks.  "
+                         "CriticalSection in which it would be running.  "
+                         "This is illegal and would cause deadlocks.  "
                          "CriticalSection: %s." % self._handle)
 
       if context.in_graph_mode():
         # Collections and op introspection does not work in eager
         # mode.  This is generally ok; since eager mode (as of
         # writing) executes sequentially anyway.
-        all_input_resources = [
-            x for x in all_inputs if x.dtype == dtypes.resource]
         for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
-          if sg.op.inputs[0].name == self._handle.name:
+          if sg.handle.name == self._handle.name:
             # Other executions in the same critical section are allowed.
             continue
           if not (exclusive_resource_access or sg.exclusive_resource_access):
             # Neither execution requested exclusive access.
             continue
-          sg_input_names = [y.name for y in sg.op.inputs[1:]]
-          for res in all_input_resources:
-            if res.name in sg_input_names:
-              raise ValueError(
-                  "This execution would access resource %s; but either this "
-                  "execution (CriticalSection: %s) or Execution '%s' "
-                  "(CriticalSection: %s) requested exclusive resource access "
-                  "of this resource for their critical section.  Did you mean "
-                  "to call execute with keyword argument "
-                  "exclusive_resource_access=False?"
-                  % (res.name,
-                     self.name,
-                     sg.op.name,
-                     sg.op.inputs[0].op.name))
-
-      flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
-          critical_section=self._handle,
-          arguments=all_inputs,
-          f=fn_op,
-          output_types=flat_dtypes,
-          output_shapes=flat_shapes)
+          resource_intersection = resource_inputs.intersection(sg.resources)
+          if resource_intersection:
+            raise ValueError(
+                "This execution would access resources: %s.  Either this "
+                "lock (CriticalSection: %s) or lock '%s' "
+                "(CriticalSection: %s) requested exclusive resource access "
+                "of this resource.  Did you mean to call execute with keyword "
+                "argument exclusive_resource_access=False?" %
+                (list(resource_intersection), self._handle.name,
+                 sg.op.name, sg.handle.name))
+
+      def identity(x):  # pylint: disable=invalid-name
+        if isinstance(x, tensor_array_ops.TensorArray):
+          return x.identity()
+        elif isinstance(x, ops.Operation):
+          return control_flow_ops.group(x)
+        elif context.in_eager_mode() and x is None:
+          return None
+        else:
+          return array_ops.identity(x)
+
+      r_flat = [identity(x) for x in nest.flatten(r)]
+
+      with ops.control_dependencies(r_flat):
+        # The identity must run on the same machine as self._handle
+        with ops.colocate_with(self._handle):
+          # Do not use array_ops.identity as there are special
+          # optimizations within TensorFlow which seem to elide it
+          # even when optimizations are disabled(!).
+          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
+              lock)
+
+        # Make sure that if any element of r is accessed, all of
+        # them are executed together.
+        r = nest.pack_sequence_as(
+            r, control_flow_ops.tuple(nest.flatten(r)))
+
+      with ops.control_dependencies([ensure_lock_exists]):
+        outputs = nest.map_structure(identity, r)
 
       if context.in_graph_mode():
-        if isinstance(flat_outputs, ops.Operation):
-          flat_outputs = [flat_outputs]
-        op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
-              else flat_outputs[0])
         signature = _ExecutionSignature(
-            op=op,
+            op=lock.op,
+            handle=self._handle,
+            resources=list(resource_inputs),
             exclusive_resource_access=exclusive_resource_access)
         ops.add_to_collections(
             CRITICAL_SECTION_EXECUTIONS, signature)
 
-      return (flat_outputs[0]
-              if (len(flat_outputs) == 1
-                  and isinstance(flat_outputs[0], ops.Operation))
-              else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))
+      return outputs
 
   # TODO(ebrevdo): Re-enable once CriticalSection is in core.
 
@@ -276,6 +313,7 @@ class CriticalSection(object):
 
 # def _execution_to_proto_fn(execution_signature, export_scope=None):
 #   """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
+#   # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
 
 #   Args:
 #     execution_signature: Instance of `_ExecutionSignature`.
@@ -298,6 +336,7 @@ class CriticalSection(object):
 
 # def _execution_from_proto_fn(op_def, import_scope=None):
 #   """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
+#   # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
 #   assert isinstance(
 #       op_def, critical_section_pb2.CriticalSectionExecutionDef)
 
index a416724d3ba1719471d70667e140f9cd2daf86c7..c916592ce1979fe3a79cf28ad4bdac44284cce97 100644 (file)
@@ -19,12 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.framework.python.ops import critical_section_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
 # TODO(ebrevdo): Re-enable once CriticalSection is in core.
@@ -35,7 +33,7 @@ class CriticalSectionTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testCreateCriticalSection(self):
-    cs = critical_section_ops.CriticalSection(name="cs")
+    cs = critical_section_ops.CriticalSection(shared_name="cs")
     v = resource_variable_ops.ResourceVariable(0.0, name="v")
 
     def fn(a, b):
@@ -45,16 +43,72 @@ class CriticalSectionTest(test.TestCase):
         with ops.control_dependencies([nv]):
           return array_ops.identity(c)
 
-    num_concurrent = 1000
+    num_concurrent = 100
     r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)]
     self.evaluate(v.initializer)
     r_value = self.evaluate(r)
     self.assertAllClose([2.0 * i for i in range(num_concurrent)],
                         sorted(r_value))
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testCriticalSectionWithControlFlow(self):
+    for outer_cond in [False, True]:
+      for inner_cond in [False, True]:
+        cs = critical_section_ops.CriticalSection(shared_name="cs")
+        v = resource_variable_ops.ResourceVariable(0.0, name="v")
+        num_concurrent = 100
+
+        # pylint: disable=cell-var-from-loop
+        def fn(a, b):
+          c = v.read_value()
+          def true_fn():
+            with ops.control_dependencies([c]):
+              nv = v.assign_add(a * b)
+              with ops.control_dependencies([nv]):
+                return array_ops.identity(c)
+          return control_flow_ops.cond(
+              array_ops.identity(inner_cond), true_fn, lambda: c)
+
+        def execute():
+          return cs.execute(fn, 1.0, 2.0)
+
+        r = [
+            control_flow_ops.cond(array_ops.identity(outer_cond),
+                                  execute,
+                                  v.read_value)
+            for _ in range(num_concurrent)
+        ]
+        # pylint: enable=cell-var-from-loop
+
+        self.evaluate(v.initializer)
+        r_value = self.evaluate(r)
+        if inner_cond and outer_cond:
+          self.assertAllClose([2.0 * i for i in range(num_concurrent)],
+                              sorted(r_value))
+        else:
+          self.assertAllClose([0] * num_concurrent, r_value)
+
+  def testCriticalSectionInParallelDoesntDeadlockOnError(self):
+    # No eager mode execution of this test because eager does not
+    # run fn() in parallel, which is where the deadlock could
+    # potentially occur (in graph mode).
+    cs = critical_section_ops.CriticalSection(shared_name="cs")
+    v = resource_variable_ops.ResourceVariable(0.0, name="v")
+
+    def fn(i):
+      error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
+      with ops.control_dependencies([error]):
+        return v.read_value()
+    num_concurrent = 2
+    r = [cs.execute(fn, i) for i in range(num_concurrent)]
+    self.evaluate(v.initializer)
+    for _ in range(100):
+      with self.assertRaisesOpError("Error"):
+        self.evaluate(r)
+
   @test_util.run_in_graph_and_eager_modes()
   def testCreateCriticalSectionFnReturnsOp(self):
-    cs = critical_section_ops.CriticalSection(name="cs")
+    cs = critical_section_ops.CriticalSection(shared_name="cs")
     v = resource_variable_ops.ResourceVariable(0.0, name="v")
 
     def fn_return_op(a, b):
@@ -62,7 +116,7 @@ class CriticalSectionTest(test.TestCase):
       with ops.control_dependencies([c]):
         nv = v.assign_add(a * b)
         with ops.control_dependencies([nv]):
-          return ()
+          return control_flow_ops.no_op()
 
     num_concurrent = 100
     r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)]
@@ -71,47 +125,25 @@ class CriticalSectionTest(test.TestCase):
     final_v = self.evaluate(v)
     self.assertAllClose(2.0 * num_concurrent, final_v)
 
-  def testCreateCriticalSectionRaw(self):
-    cs = critical_section_ops.CriticalSection(name="cs")
-    v = resource_variable_ops.ResourceVariable(0.0, name="v")
-
-    @function.Defun(dtypes.float32, dtypes.float32)
-    def fn(a, b):
-      c = v.read_value()
-      with ops.control_dependencies([c]):
-        nv = v.assign_add(a * b)
-        with ops.control_dependencies([nv]):
-          return array_ops.identity(c)
-
-    def execute(fn, *args):
-      output_args = fn.definition.signature.output_arg
-      return resource_variable_ops.execute_in_critical_section(
-          critical_section=cs._handle,
-          arguments=list(args) + fn.captured_inputs,
-          f=fn,
-          output_types=[out.type for out in output_args],
-          output_shapes=[tensor_shape.TensorShape(None) for _ in output_args])
-
-    num_concurrent = 1000
-    r = [execute(fn, 1.0, 2.0)[0] for _ in range(num_concurrent)]
-    self.evaluate(v.initializer)
-    r_value = self.evaluate(r)
-    self.assertAllClose([2.0 * i for i in range(num_concurrent)],
-                        sorted(r_value))
-
   def testCollection(self):
-    cs = critical_section_ops.CriticalSection(name="cs")
+    cs = critical_section_ops.CriticalSection(shared_name="cs")
     self.assertIn(
         cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
-    execute_op = cs.execute(lambda x: x + 1, 1.0).op
+    execute = cs.execute(lambda x: x + 1, 1.0, name="my_execute")
+    execute_op = [
+        x for x in execute.graph.get_operations()
+        if "my_execute" in x.name and "MutexLock" in x.type
+    ][0]
     self.assertIn(
         execute_op,
         [signature.op for signature in
          ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
 
-  @test_util.run_in_graph_and_eager_modes()
   def testRecursiveCriticalSectionAccessIsIllegal(self):
-    cs = critical_section_ops.CriticalSection(name="cs")
+    # This does not work properly in eager mode.  Eager users will
+    # just hit a deadlock if they do this.  But at least it'll be easier
+    # to debug.
+    cs = critical_section_ops.CriticalSection(shared_name="cs")
     def fn(x):
       return cs.execute(lambda x: x+1, x)
     with self.assertRaisesRegexp(
@@ -167,7 +199,7 @@ class CriticalSectionTest(test.TestCase):
   #     self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name)
 
   # def testToProto(self):
-  #   cs = critical_section_ops.CriticalSection(name="cs")
+  #   cs = critical_section_ops.CriticalSection(shared_name="cs")
   #   proto = cs.to_proto()
   #   self.assertEqual(proto.critical_section_name, cs._handle.name)
   #   cs_copy = critical_section_ops.CriticalSection.from_proto(proto)
diff --git a/tensorflow/core/api_def/base_api/api_def_ConsumeMutexLock.pbtxt b/tensorflow/core/api_def/base_api/api_def_ConsumeMutexLock.pbtxt
new file mode 100644 (file)
index 0000000..b9db827
--- /dev/null
@@ -0,0 +1,19 @@
+op {
+  graph_op_name: "ConsumeMutexLock"
+  in_arg {
+    name: "mutex_lock"
+    description: <<END
+A tensor returned by `MutexLock`.
+END
+  }
+  summary: "This op consumes a lock created by `MutexLock`."
+  description: <<END
+This op exists to consume a tensor created by `MutexLock` (other than
+direct control dependencies).  It should be the only that consumes the tensor,
+and will raise an error if it is not.  Its only purpose is to keep the
+mutex lock tensor alive until it is consumed by this op.
+
+**NOTE**: This operation must run on the same device as its input.  This may
+be enforced via the `colocate_with` mechanism.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CriticalSectionOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_CriticalSectionOp.pbtxt
deleted file mode 100644 (file)
index 5027fa8..0000000
+++ /dev/null
@@ -1,16 +0,0 @@
-op {
-  graph_op_name: "CriticalSectionOp"
-  attr {
-    name: "container"
-    description: <<END
-the container this critical section is placed in.
-END
-  }
-  attr {
-    name: "shared_name"
-    description: <<END
-the name by which this critical section is referred to.
-END
-  }
-  summary: "Creates a handle to a CriticalSection resource."
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExecuteInCriticalSection.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExecuteInCriticalSection.pbtxt
deleted file mode 100644 (file)
index cd5fc84..0000000
+++ /dev/null
@@ -1,49 +0,0 @@
-op {
-  graph_op_name: "ExecuteInCriticalSection"
-  in_arg {
-    name: "critical_section"
-    description: <<END
-The handle of the `critical_section`.
-END
-  }
-  in_arg {
-    name: "arguments"
-    description: <<END
-Arguments for `f`, including any captured inputs appended at the end.
-END
-  }
-  out_arg {
-    name: "outputs"
-    description: <<END
-The outputs of `f`.
-END
-  }
-  attr {
-    name: "f"
-    description: <<END
-The `Function` to execute.
-END
-  }
-  summary: "Executes function `f` within critical section `critical_section`."
-  description: <<END
-While `f` is running in `critical_section`, no other functions which wish to
-use this critical section may run.
-
-Often the use case is that two executions of the same graph, in parallel,
-wish to run `f`; and we wish to ensure that only one of them executes
-at a time.  This is especially important if `f` modifies one or more
-variables at a time.
-
-It is also useful if two separate functions must share a resource, but we
-wish to ensure the usage is exclusive.
-
-The signature of `f` is expected to be:
-
-```
-  outputs <- F(arguments)
-```
-Typically, but this is not required, `arguments` contain resources.  The
-primary purpose of this op is to limit access to these resources to one
-execution of `F` at a time.
-END
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_MutexLock.pbtxt b/tensorflow/core/api_def/base_api/api_def_MutexLock.pbtxt
new file mode 100644 (file)
index 0000000..cd3eb43
--- /dev/null
@@ -0,0 +1,58 @@
+op {
+  graph_op_name: "MutexLock"
+  in_arg {
+    name: "mutex"
+    description: <<END
+The mutex resource to lock.
+END
+  }
+  out_arg {
+    name: "mutex_lock"
+    description: <<END
+A tensor that keeps a shared pointer to a lock on the mutex;
+when the Tensor is destroyed, the use count on the shared pointer is decreased
+by 1.  When it reaches 0, the lock is released.
+END
+  }
+  summary: "Locks a mutex resource.  The output is the lock.  So long as the lock tensor"
+  description: <<END
+is alive, any other request to use `MutexLock` with this mutex will wait.
+
+This is particularly useful for creating a critical section when used in
+conjunction with `MutexLockIdentity`:
+
+```python
+
+mutex = mutex_v2(
+  shared_name=handle_name, container=container, name=name)
+
+def execute_in_critical_section(fn, *args, **kwargs):
+  lock = gen_resource_variable_ops.mutex_lock(mutex)
+
+  with ops.control_dependencies([lock]):
+    r = fn(*args, **kwargs)
+
+  with ops.control_dependencies(nest.flatten(r)):
+    with ops.colocate_with(mutex):
+      ensure_lock_exists = mutex_lock_identity(lock)
+
+    # Make sure that if any element of r is accessed, all of
+    # them are executed together.
+    r = nest.map_structure(tf.identity, r)
+
+  with ops.control_dependencies([ensure_lock_exists]):
+    return nest.map_structure(tf.identity, r)
+```
+
+While `fn` is running in the critical section, no other functions which wish to
+use this critical section may run.
+
+Often the use case is that two executions of the same graph, in parallel,
+wish to run `fn`; and we wish to ensure that only one of them executes
+at a time.  This is especially important if `fn` modifies one or more
+variables at a time.
+
+It is also useful if two separate functions must share a resource, but we
+wish to ensure the usage is exclusive.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MutexV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_MutexV2.pbtxt
new file mode 100644 (file)
index 0000000..22295ec
--- /dev/null
@@ -0,0 +1,24 @@
+op {
+  graph_op_name: "MutexV2"
+  out_arg {
+    name: "resource"
+    description: <<END
+The mutex resource.
+END
+  }
+  attr {
+    name: "container"
+    description: <<END
+If non-empty, this variable is placed in the given container.
+Otherwise, a default container is used.
+END
+  }
+  attr {
+    name: "shared_name"
+    description: <<END
+If non-empty, this variable is named in the given bucket
+with this shared_name. Otherwise, the node name is used instead.
+END
+  }
+  summary: "Creates a Mutex resource that can be locked by `MutexLock`."
+}
index dc93c76eaee6c3408453a74bac98f5e365364247..dd0de7829f1b60297e91549ae65bd747d16a6749 100644 (file)
@@ -1891,9 +1891,9 @@ tf_kernel_library(
     srcs = ["resource_variable_ops.cc"],
     deps = [
         ":bounds_check",
-        ":critical_section",
         ":dense_update_functor",
         ":gather_functor",
+        ":mutex_ops",
         ":scatter_functor",
         ":state",
         ":training_op_helpers",
@@ -4094,9 +4094,9 @@ tf_kernel_library(
 )
 
 tf_kernel_library(
-    name = "critical_section",
-    prefix = "critical_section",
-    deps = STATE_DEPS + [":captured_function"],
+    name = "mutex_ops",
+    prefix = "mutex_ops",
+    deps = STATE_DEPS + [":ops_util"],
 )
 
 tf_cc_test(
@@ -5048,7 +5048,7 @@ filegroup(
             # Excluded due to experimental status:
             "debug_ops.*",
             "scatter_nd_op*",
-            "critical_section.*",
+            "mutex_ops.*",
             "batch_kernels.*",
         ],
     ),
diff --git a/tensorflow/core/kernels/critical_section.cc b/tensorflow/core/kernels/critical_section.cc
deleted file mode 100644 (file)
index 30a9abf..0000000
+++ /dev/null
@@ -1,246 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include <deque>
-#include <utility>
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/kernels/captured_function.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-class CriticalSection : public ResourceBase {
- public:
-  explicit CriticalSection() : is_locked_(false) {}
-  ~CriticalSection() override {
-    // Wait for all closures to finish running.
-    mutex_lock lock(mu_);
-    while (!closures_.empty()) {
-      queue_empty_cv_.wait(lock);
-    }
-  }
-
- private:
-  friend class ExecuteInCriticalSectionOp;
-
-  void Acquire(std::function<void()> closure) {
-    std::function<void()> next;
-    {
-      mutex_lock ml(mu_);
-      if (is_locked_) {
-        closures_.push_back(std::move(closure));
-      } else {
-        // This branch is the common case.  Avoid the queue.
-        is_locked_ = true;
-        next = std::move(closure);
-      }
-    }
-    if (next) {
-      next();
-    }
-  }
-
-  void Release() {
-    std::function<void()> next;
-    {
-      mutex_lock ml(mu_);
-      CHECK(is_locked_);
-      if (!closures_.empty()) {
-        // if queue is not empty, start the next entry off the queue.
-        std::swap(next, closures_.front());
-        closures_.pop_front();
-      } else {
-        is_locked_ = false;
-        queue_empty_cv_.notify_all();
-      }
-    }
-    if (next) {
-      next();
-    }
-  }
-
-  string DebugString() override {
-    tf_shared_lock ml(mu_);
-    return strings::StrCat("CriticalSection(locked: ", is_locked_,
-                           " queue_size: ", closures_.size(), ")");
-  }
-
- private:
-  mutex mu_;
-  std::deque<std::function<void()>> closures_ GUARDED_BY(mu_);
-  bool is_locked_ GUARDED_BY(mu_);
-  condition_variable queue_empty_cv_ GUARDED_BY(mu_);
-};
-
-class ExecuteInCriticalSectionOp : public AsyncOpKernel {
- public:
-  explicit ExecuteInCriticalSectionOp(OpKernelConstruction* c)
-      : AsyncOpKernel(c) {
-    OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
-  }
-
- public:
-  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
-    CriticalSection* critical_section = nullptr;
-    OP_REQUIRES_OK_ASYNC(c,
-                         LookupOrCreateResource<CriticalSection>(
-                             c, HandleFromInput(c, 0), &critical_section,
-                             [this, c](CriticalSection** ptr) {
-                               *ptr = new CriticalSection;
-                               return Status::OK();
-                             }),
-                         done);
-    // No need to Unref critical_section; the Closure below will take
-    // care of the Unref associated with this execution.
-
-    auto* execution = new Closure{std::move(done), c, critical_section, &func_};
-    execution->Start();
-  }
-
- private:
-  class Closure {
-   public:
-    AsyncOpKernel::DoneCallback done_;
-    OpKernelContext* ctx_;
-    CriticalSection* cs_;
-    FunctionLibraryRuntime::Handle handle_;
-    FunctionLibraryRuntime::Options opts_;
-    std::vector<Tensor> arguments_t_;
-    std::vector<Tensor> output_t_;
-    NameAttrList* func_;
-
-    explicit Closure(AsyncOpKernel::DoneCallback done, OpKernelContext* ctx,
-                     CriticalSection* critical_section, NameAttrList* func)
-        : done_(std::move(done)),
-          ctx_(ctx),
-          cs_(critical_section),
-          handle_(-1),
-          func_(func) {}
-
-    ~Closure();
-
-    void Start() {
-      // Perform ExecuteFunction isnide a separate thread to avoid
-      // having lightweight Functions be inlined in this thread.
-      // That inlining would in turn inline DoneAndDelete inside the
-      // same thread.  Since DoneAndDelete can call the next
-      // ExecuteFunction in the CriticalSection, this can cause a
-      // stack overflow.
-      cs_->Acquire(
-          [this]() { (*ctx_->runner())([this]() { ExecuteFunction(); }); });
-    }
-
-   private:
-    void ExecuteFunction();
-    void DoneAndDelete(const Status& status);
-  };
-
-  NameAttrList func_;
-};
-
-void ExecuteInCriticalSectionOp::Closure::ExecuteFunction() {
-  // Arguments to a Function are in the order:
-  //   concat(<formal arguments>, <captured arguments>)
-  OpInputList arguments;
-  Status s = ctx_->input_list("arguments", &arguments);
-  if (!s.ok()) {
-    DoneAndDelete(s);
-    return;
-  }
-
-  arguments_t_.reserve(arguments.size());
-  for (const Tensor& t : arguments) {
-    arguments_t_.push_back(t);
-  }
-
-  auto* function_library = ctx_->function_library();
-  s = function_library->Instantiate(func_->name(), AttrSlice(&func_->attr()),
-                                    &handle_);
-  if (!s.ok()) {
-    DoneAndDelete(s);
-    return;
-  }
-
-  opts_.step_id = CapturedFunction::generate_step_id();
-  auto* step_container =
-      new ScopedStepContainer(opts_.step_id, [this](const string& name) {
-        ctx_->resource_manager()->Cleanup(name).IgnoreError();
-      });
-  opts_.cancellation_manager = ctx_->cancellation_manager();
-  opts_.step_container = step_container;
-  opts_.runner = ctx_->runner();
-
-  function_library->Run(opts_, handle_, arguments_t_, &output_t_,
-                        [this](const Status& s) { DoneAndDelete(s); });
-}
-
-void ExecuteInCriticalSectionOp::Closure::DoneAndDelete(const Status& status) {
-  cs_->Release();
-
-  if (!status.ok()) {
-    ctx_->SetStatus(status);
-  } else {
-    OpOutputList output;
-    const Status s = ctx_->output_list("outputs", &output);
-    if (!s.ok()) {
-      ctx_->SetStatus(s);
-    } else if (output_t_.size() != output.size()) {
-      ctx_->SetStatus(errors::Internal(
-          "Could not set all outputs.  Expected output size is ", output.size(),
-          " but function set ", output_t_.size(), " output values."));
-    } else {
-      for (int i = 0; i < output_t_.size(); ++i) {
-        output.set(i, output_t_[i]);
-      }
-    }
-  }
-
-  delete opts_.step_container;
-  opts_.step_container = nullptr;
-  done_();
-  cs_->Unref();
-  delete this;
-}
-
-ExecuteInCriticalSectionOp::Closure::~Closure() {
-  CHECK(!opts_.step_container)
-      << "Initialized closure destroyed without calling Done";
-}
-
-REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection").Device(DEVICE_CPU),
-                        ExecuteInCriticalSectionOp);
-
-REGISTER_KERNEL_BUILDER(Name("CriticalSectionOp").Device(DEVICE_CPU),
-                        ResourceHandleOp<CriticalSection>);
-
-// TODO(ebrevdo): Re-enable once the cross-device function execution works.
-#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection")
-                            .Device(DEVICE_GPU)
-                            .HostMemory("critical_section"),
-                        ExecuteInCriticalSectionOp);
-REGISTER_KERNEL_BUILDER(
-    Name("CriticalSectionOp").Device(DEVICE_GPU).HostMemory("resource"),
-    ResourceHandleOp<CriticalSection>);
-#endif  // GOOGLE_CUDA
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mutex_ops.cc b/tensorflow/core/kernels/mutex_ops.cc
new file mode 100644 (file)
index 0000000..b8b1fc7
--- /dev/null
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include <deque>
+#include <utility>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+class Mutex : public ResourceBase {
+ public:
+  explicit Mutex(OpKernelContext* c, const string& name)
+      : locked_(false),
+        thread_pool_(new thread::ThreadPool(
+            c->env(), ThreadOptions(),
+            strings::StrCat("mutex_lock_thread_", SanitizeThreadSuffix(name)),
+            1 /* num_threads */, false /* low_latency_hint */)),
+        name_(name) {
+    VLOG(2) << "Creating mutex with name " << name << ": " << this;
+  }
+
+  string DebugString() override { return strings::StrCat("Mutex ", name_); }
+
+  class LockReleaser {
+   public:
+    explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {}
+
+    LockReleaser(const LockReleaser&) = delete;
+    LockReleaser& operator=(const LockReleaser&) = delete;
+
+    virtual ~LockReleaser() {
+      VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_;
+      if (mutex_) {
+        mutex_lock lock(mutex_->mu_);
+        mutex_->locked_ = false;
+        mutex_->cv_.notify_all();
+        VLOG(3) << "Destroying LockReleaser " << this
+                << ": sent notifications.";
+      }
+    }
+
+   private:
+    Mutex* mutex_;
+  };
+
+  struct SharedLockReleaser {
+    std::shared_ptr<LockReleaser> shared_lock;
+
+    explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
+        : shared_lock(std::forward<decltype(lock)>(lock)) {
+      VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
+              << " count is: " << shared_lock.use_count();
+    }
+
+    SharedLockReleaser(SharedLockReleaser&& rhs)
+        : shared_lock(std::move(rhs.shared_lock)) {
+      VLOG(3) << "Moving SharedLockReleaser of " << shared_lock.get()
+              << " count is: " << shared_lock.use_count();
+    }
+
+    SharedLockReleaser(const SharedLockReleaser& rhs)
+        : shared_lock(rhs.shared_lock) {
+      VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()
+              << " count is: " << shared_lock.use_count();
+    }
+
+    ~SharedLockReleaser() {
+      VLOG(3) << "Destroying SharedLockReleaser of " << shared_lock.get()
+              << " count is: " << shared_lock.use_count();
+    }
+
+    void Encode(VariantTensorData*) const {
+      // Not supported.
+    }
+
+    bool Decode(const VariantTensorData&) {
+      return false;  // Not supported.
+    }
+  };
+
+  void AcquireAsync(
+      OpKernelContext* c,
+      std::function<void(const Status& s, SharedLockReleaser lock)> fn) {
+    CancellationManager* cm = c->cancellation_manager();
+    CancellationToken token{};
+    bool* cancelled = nullptr;
+    if (cm) {
+      cancelled = new bool(false);  // GUARDED_BY(mu_);
+      token = cm->get_cancellation_token();
+      const bool already_cancelled =
+          !cm->RegisterCallback(token, [this, cancelled]() {
+            mutex_lock lock(mu_);
+            *cancelled = true;
+            cv_.notify_all();
+          });
+      if (already_cancelled) {
+        delete cancelled;
+        fn(errors::Cancelled("Lock acquisition cancelled."),
+           SharedLockReleaser{nullptr});
+        return;
+      }
+    }
+    thread_pool_->Schedule(std::bind(
+        [this, c, cm, cancelled,
+         token](std::function<void(const Status& s, SharedLockReleaser&& lock)>
+                    fn_) {
+          bool local_locked;
+          {
+            mutex_lock lock(mu_);
+            while (locked_ && !(cancelled && *cancelled)) {
+              cv_.wait(lock);
+            }
+            local_locked = locked_ = !(cancelled && *cancelled);
+          }
+          if (cm) {
+            cm->DeregisterCallback(token);
+            delete cancelled;
+          }
+          if (local_locked) {  // Not cancelled.
+            fn_(Status::OK(),
+                SharedLockReleaser{std::make_shared<LockReleaser>(this)});
+          } else {
+            fn_(errors::Cancelled("Lock acqusition cancelled."),
+                SharedLockReleaser{nullptr});
+          }
+        },
+        std::move(fn)));
+  }
+
+ private:
+  mutex mu_;
+  condition_variable cv_ GUARDED_BY(mu_);
+  bool locked_ GUARDED_BY(mu_);
+  std::unique_ptr<thread::ThreadPool> thread_pool_;
+  string name_;
+};
+
+}  // namespace
+
+class MutexLockOp : public AsyncOpKernel {
+ public:
+  explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {}
+
+ public:
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    Mutex* mutex = nullptr;
+    OP_REQUIRES_OK_ASYNC(
+        c,
+        LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex,
+                                      [this, c](Mutex** ptr) {
+                                        *ptr = new Mutex(
+                                            c, HandleFromInput(c, 0).name());
+                                        return Status::OK();
+                                      }),
+        done);
+
+    Tensor* variant;
+    OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant),
+                         done);
+
+    mutex->AcquireAsync(
+        c, std::bind(
+               [this, c, variant, mutex](DoneCallback done_,
+                                         // End of bound arguments.
+                                         const Status& s,
+                                         Mutex::SharedLockReleaser&& lock) {
+                 core::ScopedUnref unref(mutex);
+                 VLOG(2) << "Finished locking mutex " << mutex
+                         << " with lock: " << lock.shared_lock.get()
+                         << " status: " << s.ToString();
+                 if (s.ok()) {
+                   variant->scalar<Variant>()() = std::move(lock);
+                 } else {
+                   c->SetStatus(s);
+                 }
+                 done_();
+               },
+               std::move(done), std::placeholders::_1, std::placeholders::_2));
+  }
+};
+
+class ConsumeMutexLockOp : public OpKernel {
+ public:
+  explicit ConsumeMutexLockOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* c) override {
+    VLOG(2) << "Executing ConsumeMutexLockOp";
+    const Tensor& lock_t = c->input(0);
+    OP_REQUIRES(
+        c, lock_t.dims() == 0,
+        errors::InvalidArgument("Expected input to be a scalar, saw shape: ",
+                                lock_t.shape().DebugString()));
+    OP_REQUIRES(
+        c, lock_t.dtype() == DT_VARIANT,
+        errors::InvalidArgument("Expected input to be a variant, saw type: ",
+                                DataTypeString(lock_t.dtype())));
+    const auto* lock =
+        lock_t.scalar<Variant>()().get<Mutex::SharedLockReleaser>();
+    OP_REQUIRES(c, lock,
+                errors::InvalidArgument(
+                    "Expected input to contain a SharedLockReleaser "
+                    "object, but saw variant: '",
+                    lock_t.scalar<Variant>()().DebugString(), "'"));
+    const int use_count = lock->shared_lock.use_count();
+    OP_REQUIRES(
+        c, use_count == 1,
+        errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",
+                                use_count));
+  }
+
+  bool IsExpensive() override { return false; }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp);
+
+REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_CPU),
+                        ResourceHandleOp<Mutex>);
+
+REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU),
+                        ConsumeMutexLockOp);
+
+}  // namespace tensorflow
index 7da2365f62ae57c288666137eb973936d10767d1..3fb17d92d2bec73e0712b06692ceb3ce4e39bd3d 100644 (file)
@@ -12814,28 +12814,6 @@ op {
     }
   }
 }
-op {
-  name: "CriticalSectionOp"
-  output_arg {
-    name: "resource"
-    type: DT_RESOURCE
-  }
-  attr {
-    name: "container"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  attr {
-    name: "shared_name"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
-  is_stateful: true
-}
 op {
   name: "CropAndResize"
   input_arg {
@@ -17433,78 +17411,6 @@ op {
     }
   }
 }
-op {
-  name: "ExecuteInCriticalSection"
-  input_arg {
-    name: "critical_section"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "outputs"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "ExecuteInCriticalSection"
-  input_arg {
-    name: "critical_section"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "arguments"
-    type_list_attr: "Targuments"
-  }
-  output_arg {
-    name: "outputs"
-    type_list_attr: "output_types"
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-  }
-  is_stateful: true
-}
 op {
   name: "Exit"
   input_arg {
index 8dae7e1ff5f872c33dd56509c0349180cec78593..0d8cf78cc2a196cde4a77f53ce912c437648786a 100644 (file)
@@ -211,7 +211,7 @@ REGISTER_OP("ResourceScatterUpdate")
       return Status::OK();
     });
 
-REGISTER_OP("CriticalSectionOp")
+REGISTER_OP("MutexV2")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
     .Output("resource: resource")
@@ -221,24 +221,18 @@ REGISTER_OP("CriticalSectionOp")
       return Status::OK();
     });
 
-REGISTER_OP("ExecuteInCriticalSection")
-    .Input("critical_section: resource")
-    .Input("arguments: Targuments")
-    .Output("outputs: output_types")
-    .Attr("f: func")
-    .Attr("Targuments: list(type) >= 0")
-    .Attr("output_types: list(type) >= 0")
-    .Attr("output_shapes: list(shape) >= 0")
+REGISTER_OP("MutexLock")
+    .Input("mutex: resource")
+    .Output("mutex_lock: variant")
+    .SetIsStateful()
     .SetShapeFn([](InferenceContext* c) {
-      std::vector<PartialTensorShape> output_shapes;
-      TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
-      for (int i = 0; i < output_shapes.size(); ++i) {
-        ShapeHandle s;
-        TF_RETURN_IF_ERROR(
-            c->MakeShapeFromPartialTensorShape(output_shapes[i], &s));
-        c->set_output(i, s);
-      }
+      c->set_output(0, c->Scalar());
       return Status::OK();
     });
 
+REGISTER_OP("ConsumeMutexLock")
+    .Input("mutex_lock: variant")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext* c) { return Status::OK(); });
+
 }  // namespace tensorflow
index 28f5289ffc0ace6f9b6cad7cdd1160a184f882c7..b3317bd3235f432220d9d5d135f1af18a6f43310 100644 (file)
@@ -196,33 +196,66 @@ ops.register_tensor_conversion_function(
     ops.EagerTensor, _convert_to_graph_tensor, priority=-1)
 
 
-class _CapturingContext(object):
-  """Tracks references to Tensors outside this context while it is active."""
+# pylint: disable=invalid-name
+class HelperContext(object):
+  """ControlFlowContext with a customizable AddOp method."""
 
-  def __init__(self):
-    # known_ops are ops which are created while this context is active
-    self.known_ops = set()
+  def __init__(self, add_op_internal):
+    self._add_op_internal = add_op_internal
+    self._values = set()  # control flow code sometimes updates this.
+
+  def _AddOpInternal(self, op):
+    self._add_op_internal(op)
+
+  @property
+  def outer_context(self):
+    return self._outer_context
+
+  def GetWhileContext(self):
+    if self._outer_context:
+      return self._outer_context.GetWhileContext()
+
+  def IsWhileContext(self):
+    return False
+
+  def IsCondContext(self):
+    return False
 
-    # captured_tensors are all tensors referenced to by ops in this context but
-    # not produced in it
-    self.captured_tensors = set()
+  def IsXLAContext(self):
+    return False
 
   def AddOp(self, op):  # pylint: disable=invalid-name
-    if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
-      raise ValueError("tfe.defun cannot capture variables created without "
-                       "using tf.get_variable. Op: %s" % op)
-    self.known_ops.add(op)
-    for i in op.inputs:
-      if i.op not in self.known_ops:
-        self.captured_tensors.add(i)
+    self._AddOpInternal(op)
+    if self._outer_context:
+      self._outer_context.AddOp(op)
+
+  def AddName(self, _):
+    pass
+
+  def AddInnerOp(self, op):
+    self._AddOpInternal(op)
+    if self._outer_context:
+      self._outer_context.AddInnerOp(op)
+
+  def AddValue(self, val):
+    if self._outer_context:
+      return self._outer_context.AddValue(val)
+    else:
+      return val
 
   def __enter__(self):
+    # pylint: disable=protected-access
     self._g = ops.get_default_graph()
-    self._old = self._g._get_control_flow_context()  # pylint: disable=protected-access
-    self._g._set_control_flow_context(self)  # pylint: disable=protected-access
+    self._outer_context = self._g._get_control_flow_context()
+    self._g._set_control_flow_context(self)
+    self._nested_contexts = (
+        self._outer_context._nested_contexts
+        if self._outer_context is not None else None)
+    # pylint: enable=protected-access
 
-  def __exit__(self, _, __, ___):  # pylint: disable=invalid-name
-    self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access
+  def __exit__(self, *_):
+    self._g._set_control_flow_context(self._outer_context)  # pylint: disable=protected-access
+# pylint: enable=invalid-name
 
 
 def _forward_name(n):
@@ -368,7 +401,20 @@ class GraphModeFunction(object):
   def _construct_backprop_function(self):
     """Constructs the backprop function object for this function."""
     with self._graph.as_default(), context.graph_mode():
-      c = _CapturingContext()
+      c_known_ops = set()
+      c_captured_tensors = set()
+
+      def add_op_internal(op):
+        if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
+          raise ValueError("tfe.defun cannot capture variables created without "
+                           "using tf.get_variable. Op: %s" % op)
+        c_known_ops.add(op)
+        for i in op.inputs:
+          if i.op not in c_known_ops:
+            c_captured_tensors.add(i)
+
+      c = HelperContext(add_op_internal)
+
       with c:
         filtered_outputs = [x for x in self._returns if x is not None]
         self._out_grad_placeholders = [
@@ -382,7 +428,7 @@ class GraphModeFunction(object):
         grad for grad in _flatten(in_gradients) if grad is not None)
     output_shapes = tuple(grad.shape for grad in backward_outputs)
 
-    captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
+    captures = list(sorted(c_captured_tensors, key=lambda x: x.name))
     forward_name = _forward_name(self._func_name)
     self._forward_fdef = _EagerDefinedFunction(
         forward_name, self._graph, self._ops, self._input_placeholders,
@@ -395,7 +441,7 @@ class GraphModeFunction(object):
     # means rerunning the function-defining code will always define the same
     # function, which is useful if we serialize this etc.
     function_def_ops = tuple(x
-                             for x in sorted(c.known_ops, key=lambda x: x.name)
+                             for x in sorted(c_known_ops, key=lambda x: x.name)
                              if x not in all_ignored_ops)
     bname = _backward_name(self._func_name)
     self._backward_function = GraphModeFunction(
index b4bfc0fe47402e588fc831b5936ddd65aa93717a..c78a5aa8c2227f016596f9727f7ee6e205843a08 100644 (file)
@@ -3477,7 +3477,12 @@ def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined
   if context.in_eager_mode():
     return tensors
   with ops.name_scope(name, "tuple", tensors) as name:
-    gating_ops = [t.op for t in tensors if t is not None]
+    tensors = [t if (isinstance(t, ops.Operation)
+                     or tensor_util.is_tensor(t)
+                     or t is None)
+               else ops.convert_to_tensor(t) for t in tensors]
+    gating_ops = [t if isinstance(t, ops.Operation) else t.op for t in tensors
+                  if t is not None]
     if control_inputs:
       for c in control_inputs:
         if isinstance(c, ops.Tensor):
@@ -3493,8 +3498,11 @@ def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined
     gate = group(*gating_ops)
     tpl = []
     for t in tensors:
-      if t is not None:
+      if tensor_util.is_tensor(t):
         tpl.append(with_dependencies([gate], t))
+      elif isinstance(t, ops.Operation):
+        with ops.control_dependencies([gate]):
+          tpl.append(group(t))
       else:
         tpl.append(None)
     return tpl