additional_deps = [
"//tensorflow/python:client_testlib",
":framework_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:platform_test",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:tensor_array_ops",
],
)
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
class _ExecutionSignature(
collections.namedtuple("_ExecutionSignature",
- ("op", "exclusive_resource_access"))):
+ ("op", "handle",
+ "resources", "exclusive_resource_access"))):
"""A class storing an `ExecuteInCriticalResource` op and associated attrs."""
pass
```
"""
- def __init__(self, name=None, critical_section_def=None, import_scope=None):
+ def __init__(self, name=None, shared_name=None,
+ critical_section_def=None, import_scope=None):
"""Creates a critical section."""
- if critical_section_def and name is not None:
+ if critical_section_def and shared_name is not None:
- raise ValueError("critical_section_def and name are mutually exclusive.")
+ raise ValueError("critical_section_def and shared_name are "
+ "mutually exclusive.")
if critical_section_def:
self._init_from_proto(critical_section_def, import_scope=import_scope)
else:
- self._init_from_args(name)
+ self._init_from_args(name, shared_name)
- def _init_from_proto(self, critical_section_def, import_scope):
+ def _init_from_proto(self, critical_section_def, import_scope): # pylint: disable=invalid-name
raise NotImplementedError("Not yet implemented")
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# assert isinstance(
# critical_section_def.critical_section_name,
# import_scope=import_scope))
- def _init_from_args(self, name):
+ def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name
"""Initialize the CriticalSection from constructor arguments."""
with ops.name_scope(name, "CriticalSection", []) as name:
with ops.control_dependencies(None):
# pylint: disable=protected-access
- handle_name = ops._name_from_scope_name(name)
container = ops.get_default_graph()._container
# pylint: enable=protected-access
+ if shared_name is None:
+ shared_name = name
if container is None:
container = ""
- self._handle = gen_resource_variable_ops.critical_section_op(
- shared_name=handle_name, name=name)
+ self._handle = gen_resource_variable_ops.mutex_v2(
+ shared_name=shared_name, container=container, name=name)
+
if context.in_graph_mode():
ops.add_to_collections(CRITICAL_SECTIONS, self)
name = kwargs.pop("name", None)
exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)
- args = nest.map_structure(ops.convert_to_tensor, args)
with ops.name_scope(name, "critical_section_execute", []):
- fn_op = function.make_defun_op(fn, *args, **kwargs)
- flat_dtypes = nest.flatten(fn_op.output_dtypes)
- flat_shapes = nest.flatten(fn_op.output_shapes)
- all_inputs = nest.flatten(args) + fn_op.captured_inputs
- if self._handle in all_inputs:
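+ # Acquire the lock: a scalar variant tensor holding a shared reference
+ # to the mutex lock. The mutex is released once the last reference is
+ # destroyed (see consume_mutex_lock below).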
+ lock = gen_resource_variable_ops.mutex_lock(self._handle)
+
+ with ops.control_dependencies([lock]):
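+ # Run fn while recording every op it creates and every tensor it
+ # captures from outside, so we can find the resources it touches.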
+ c_known_ops = set()
+ c_captured_tensors = set()
+
+ def add_op_internal(op):
+ c_known_ops.add(op)
+ for i in op.inputs:
+ if i.op not in c_known_ops:
+ c_captured_tensors.add(i)
+
+ c = function.HelperContext(add_op_internal)
+ with c:
+ r = fn(*args, **kwargs)
+
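+ # Collect all resource-dtype tensors visible to fn: explicit args,
+ # kwargs, and implicitly captured tensors.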
+ resource_inputs = set([
+ x for x in
+ list(nest.flatten(args)) + nest.flatten(kwargs.values()) +
+ list(c_captured_tensors)
+ if tensor_util.is_tensor(x) and x.dtype == dtypes.resource])
+
+ if self._handle in resource_inputs:
raise ValueError("The function fn attempts to access the "
- "CriticalSection in which it would be running. This "
- "is illegal and would cause deadlocks. "
+ "CriticalSection in which it would be running. "
+ "This is illegal and would cause deadlocks. "
"CriticalSection: %s." % self._handle)
if context.in_graph_mode():
# Collections and op introspection do not work in eager
# mode. This is generally ok, since eager mode (as of this
# writing) executes sequentially anyway.
- all_input_resources = [
- x for x in all_inputs if x.dtype == dtypes.resource]
for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
- if sg.op.inputs[0].name == self._handle.name:
+ if sg.handle.name == self._handle.name:
# Other executions in the same critical section are allowed.
continue
if not (exclusive_resource_access or sg.exclusive_resource_access):
# Neither execution requested exclusive access.
continue
- sg_input_names = [y.name for y in sg.op.inputs[1:]]
- for res in all_input_resources:
- if res.name in sg_input_names:
- raise ValueError(
- "This execution would access resource %s; but either this "
- "execution (CriticalSection: %s) or Execution '%s' "
- "(CriticalSection: %s) requested exclusive resource access "
- "of this resource for their critical section. Did you mean "
- "to call execute with keyword argument "
- "exclusive_resource_access=False?"
- % (res.name,
- self.name,
- sg.op.name,
- sg.op.inputs[0].op.name))
-
- flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
- critical_section=self._handle,
- arguments=all_inputs,
- f=fn_op,
- output_types=flat_dtypes,
- output_shapes=flat_shapes)
+ resource_intersection = resource_inputs.intersection(sg.resources)
+ if resource_intersection:
+ raise ValueError(
+ "This execution would access resources: %s. Either this "
+ "lock (CriticalSection: %s) or lock '%s' "
+ "(CriticalSection: %s) requested exclusive resource access "
+ "of this resource. Did you mean to call execute with keyword "
+ "argument exclusive_resource_access=False?" %
+ (list(resource_intersection), self._handle.name,
+ sg.op.name, sg.handle.name))
+
+ def identity(x): # pylint: disable=invalid-name
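+ """Identity-like helper handling TensorArrays, Operations, and None."""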
+ if isinstance(x, tensor_array_ops.TensorArray):
+ return x.identity()
+ elif isinstance(x, ops.Operation):
+ return control_flow_ops.group(x)
+ elif context.in_eager_mode() and x is None:
+ return None
+ else:
+ return array_ops.identity(x)
+
+ r_flat = [identity(x) for x in nest.flatten(r)]
+
+ with ops.control_dependencies(r_flat):
+ # The consume op must run on the same device as self._handle.
+ with ops.colocate_with(self._handle):
+ # Do not use array_ops.identity as there are special
+ # optimizations within TensorFlow which seem to elide it
+ # even when optimizations are disabled(!).
+ ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
+ lock)
+
+ # Make sure that if any element of r is accessed, all of
+ # them are executed together.
+ r = nest.pack_sequence_as(
+ r, control_flow_ops.tuple(nest.flatten(r)))
+
+ with ops.control_dependencies([ensure_lock_exists]):
+ outputs = nest.map_structure(identity, r)
if context.in_graph_mode():
- if isinstance(flat_outputs, ops.Operation):
- flat_outputs = [flat_outputs]
- op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
- else flat_outputs[0])
signature = _ExecutionSignature(
- op=op,
+ op=lock.op,
+ handle=self._handle,
+ resources=list(resource_inputs),
exclusive_resource_access=exclusive_resource_access)
ops.add_to_collections(
CRITICAL_SECTION_EXECUTIONS, signature)
- return (flat_outputs[0]
- if (len(flat_outputs) == 1
- and isinstance(flat_outputs[0], ops.Operation))
- else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))
+ return outputs
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# def _execution_to_proto_fn(execution_signature, export_scope=None):
# """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
+# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
# Args:
# execution_signature: Instance of `_ExecutionSignature`.
# def _execution_from_proto_fn(op_def, import_scope=None):
# """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
+# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
# assert isinstance(
# op_def, critical_section_pb2.CriticalSectionExecutionDef)
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import critical_section_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
@test_util.run_in_graph_and_eager_modes()
def testCreateCriticalSection(self):
- cs = critical_section_ops.CriticalSection(name="cs")
+ cs = critical_section_ops.CriticalSection(shared_name="cs")
v = resource_variable_ops.ResourceVariable(0.0, name="v")
def fn(a, b):
with ops.control_dependencies([nv]):
return array_ops.identity(c)
- num_concurrent = 1000
+ num_concurrent = 100
r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)]
self.evaluate(v.initializer)
r_value = self.evaluate(r)
self.assertAllClose([2.0 * i for i in range(num_concurrent)],
sorted(r_value))
+ @test_util.run_in_graph_and_eager_modes()
+ def testCriticalSectionWithControlFlow(self):
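+ # Run cs.execute under an outer tf.cond, with an inner cond inside
+ # fn, across all four combinations of branch predicates.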
+ for outer_cond in [False, True]:
+ for inner_cond in [False, True]:
+ cs = critical_section_ops.CriticalSection(shared_name="cs")
+ v = resource_variable_ops.ResourceVariable(0.0, name="v")
+ num_concurrent = 100
+
+ # pylint: disable=cell-var-from-loop
+ def fn(a, b):
+ c = v.read_value()
+ def true_fn():
+ with ops.control_dependencies([c]):
+ nv = v.assign_add(a * b)
+ with ops.control_dependencies([nv]):
+ return array_ops.identity(c)
+ return control_flow_ops.cond(
+ array_ops.identity(inner_cond), true_fn, lambda: c)
+
+ def execute():
+ return cs.execute(fn, 1.0, 2.0)
+
+ r = [
+ control_flow_ops.cond(array_ops.identity(outer_cond),
+ execute,
+ v.read_value)
+ for _ in range(num_concurrent)
+ ]
+ # pylint: enable=cell-var-from-loop
+
+ self.evaluate(v.initializer)
+ r_value = self.evaluate(r)
+ if inner_cond and outer_cond:
+ self.assertAllClose([2.0 * i for i in range(num_concurrent)],
+ sorted(r_value))
+ else:
+ self.assertAllClose([0] * num_concurrent, r_value)
+
+ def testCriticalSectionInParallelDoesntDeadlockOnError(self):
+ # No eager-mode execution of this test: eager does not run fn()
+ # in parallel, and parallel execution is where the deadlock could
+ # potentially occur in graph mode.
+ cs = critical_section_ops.CriticalSection(shared_name="cs")
+ v = resource_variable_ops.ResourceVariable(0.0, name="v")
+
+ def fn(i):
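+ # Fail the assert for even i, so some parallel executions raise
+ # while others hold and release the lock normally.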
+ error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
+ with ops.control_dependencies([error]):
+ return v.read_value()
+ num_concurrent = 2
+ r = [cs.execute(fn, i) for i in range(num_concurrent)]
+ self.evaluate(v.initializer)
+ for _ in range(100):
+ with self.assertRaisesOpError("Error"):
+ self.evaluate(r)
+
@test_util.run_in_graph_and_eager_modes()
def testCreateCriticalSectionFnReturnsOp(self):
- cs = critical_section_ops.CriticalSection(name="cs")
+ cs = critical_section_ops.CriticalSection(shared_name="cs")
v = resource_variable_ops.ResourceVariable(0.0, name="v")
def fn_return_op(a, b):
with ops.control_dependencies([c]):
nv = v.assign_add(a * b)
with ops.control_dependencies([nv]):
- return ()
+ return control_flow_ops.no_op()
num_concurrent = 100
r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)]
final_v = self.evaluate(v)
self.assertAllClose(2.0 * num_concurrent, final_v)
- def testCreateCriticalSectionRaw(self):
- cs = critical_section_ops.CriticalSection(name="cs")
- v = resource_variable_ops.ResourceVariable(0.0, name="v")
-
- @function.Defun(dtypes.float32, dtypes.float32)
- def fn(a, b):
- c = v.read_value()
- with ops.control_dependencies([c]):
- nv = v.assign_add(a * b)
- with ops.control_dependencies([nv]):
- return array_ops.identity(c)
-
- def execute(fn, *args):
- output_args = fn.definition.signature.output_arg
- return resource_variable_ops.execute_in_critical_section(
- critical_section=cs._handle,
- arguments=list(args) + fn.captured_inputs,
- f=fn,
- output_types=[out.type for out in output_args],
- output_shapes=[tensor_shape.TensorShape(None) for _ in output_args])
-
- num_concurrent = 1000
- r = [execute(fn, 1.0, 2.0)[0] for _ in range(num_concurrent)]
- self.evaluate(v.initializer)
- r_value = self.evaluate(r)
- self.assertAllClose([2.0 * i for i in range(num_concurrent)],
- sorted(r_value))
-
def testCollection(self):
- cs = critical_section_ops.CriticalSection(name="cs")
+ cs = critical_section_ops.CriticalSection(shared_name="cs")
self.assertIn(
cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
- execute_op = cs.execute(lambda x: x + 1, 1.0).op
+ execute = cs.execute(lambda x: x + 1, 1.0, name="my_execute")
+ execute_op = [
+ x for x in execute.graph.get_operations()
+ if "my_execute" in x.name and "MutexLock" in x.type
+ ][0]
self.assertIn(
execute_op,
[signature.op for signature in
ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
- @test_util.run_in_graph_and_eager_modes()
def testRecursiveCriticalSectionAccessIsIllegal(self):
- cs = critical_section_ops.CriticalSection(name="cs")
+ # This does not work properly in eager mode: eager users who try
+ # it will simply hit a deadlock. Graph mode can detect the recursion
+ # and raise an error instead, which is easier to debug.
+ cs = critical_section_ops.CriticalSection(shared_name="cs")
def fn(x):
return cs.execute(lambda x: x+1, x)
with self.assertRaisesRegexp(
# self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name)
# def testToProto(self):
- # cs = critical_section_ops.CriticalSection(name="cs")
+ # cs = critical_section_ops.CriticalSection(shared_name="cs")
# proto = cs.to_proto()
# self.assertEqual(proto.critical_section_name, cs._handle.name)
# cs_copy = critical_section_ops.CriticalSection.from_proto(proto)
--- /dev/null
+op {
+ graph_op_name: "ConsumeMutexLock"
+ in_arg {
+ name: "mutex_lock"
+ description: <<END
+A tensor returned by `MutexLock`.
+END
+ }
+ summary: "This op consumes a lock created by `MutexLock`."
+ description: <<END
+This op exists to consume a tensor created by `MutexLock` (other than
+direct control dependencies). It should be the only op that consumes the
+tensor, and will raise an error if it is not. Its only purpose is to keep
+the mutex-lock tensor alive until this op consumes it, after which the
+lock is released.
+
+**NOTE**: This operation must run on the same device as its input. This may
+be enforced via the `colocate_with` mechanism.
+END
+}
+++ /dev/null
-op {
- graph_op_name: "CriticalSectionOp"
- attr {
- name: "container"
- description: <<END
-the container this critical section is placed in.
-END
- }
- attr {
- name: "shared_name"
- description: <<END
-the name by which this critical section is referred to.
-END
- }
- summary: "Creates a handle to a CriticalSection resource."
-}
+++ /dev/null
-op {
- graph_op_name: "ExecuteInCriticalSection"
- in_arg {
- name: "critical_section"
- description: <<END
-The handle of the `critical_section`.
-END
- }
- in_arg {
- name: "arguments"
- description: <<END
-Arguments for `f`, including any captured inputs appended at the end.
-END
- }
- out_arg {
- name: "outputs"
- description: <<END
-The outputs of `f`.
-END
- }
- attr {
- name: "f"
- description: <<END
-The `Function` to execute.
-END
- }
- summary: "Executes function `f` within critical section `critical_section`."
- description: <<END
-While `f` is running in `critical_section`, no other functions which wish to
-use this critical section may run.
-
-Often the use case is that two executions of the same graph, in parallel,
-wish to run `f`; and we wish to ensure that only one of them executes
-at a time. This is especially important if `f` modifies one or more
-variables at a time.
-
-It is also useful if two separate functions must share a resource, but we
-wish to ensure the usage is exclusive.
-
-The signature of `f` is expected to be:
-
-```
- outputs <- F(arguments)
-```
-Typically, but this is not required, `arguments` contain resources. The
-primary purpose of this op is to limit access to these resources to one
-execution of `F` at a time.
-END
-}
--- /dev/null
+op {
+ graph_op_name: "MutexLock"
+ in_arg {
+ name: "mutex"
+ description: <<END
+The mutex resource to lock.
+END
+ }
+ out_arg {
+ name: "mutex_lock"
+ description: <<END
+A tensor that keeps a shared pointer to a lock on the mutex;
+when the Tensor is destroyed, the use count on the shared pointer is decreased
+by 1. When it reaches 0, the lock is released.
+END
+ }
+ summary: "Locks a mutex resource. The output is the lock. So long as the lock tensor"
+ description: <<END
+is alive, any other request to use `MutexLock` with this mutex will wait.
+
+This is particularly useful for creating a critical section when used in
+conjunction with `ConsumeMutexLock`:
+
+```python
+
+mutex = mutex_v2(
+ shared_name=handle_name, container=container, name=name)
+
+def execute_in_critical_section(fn, *args, **kwargs):
+ lock = gen_resource_variable_ops.mutex_lock(mutex)
+
+ with ops.control_dependencies([lock]):
+ r = fn(*args, **kwargs)
+
+ with ops.control_dependencies(nest.flatten(r)):
+ with ops.colocate_with(mutex):
+ ensure_lock_exists = consume_mutex_lock(lock)
+
+ # Make sure that if any element of r is accessed, all of
+ # them are executed together.
+ r = nest.map_structure(tf.identity, r)
+
+ with ops.control_dependencies([ensure_lock_exists]):
+ return nest.map_structure(tf.identity, r)
+```
+
+While `fn` is running in the critical section, no other functions which wish to
+use this critical section may run.
+
+Often the use case is that two executions of the same graph, in parallel,
+wish to run `fn`; and we wish to ensure that only one of them executes
+at a time. This is especially important if `fn` modifies one or more
+variables at a time.
+
+It is also useful if two separate functions must share a resource, but we
+wish to ensure the usage is exclusive.
+END
+}
--- /dev/null
+op {
+ graph_op_name: "MutexV2"
+ out_arg {
+ name: "resource"
+ description: <<END
+The mutex resource.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this mutex is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this mutex is named in the given container
+with this shared_name. Otherwise, the node name is used instead.
+END
+ }
+ summary: "Creates a Mutex resource that can be locked by `MutexLock`."
+}
srcs = ["resource_variable_ops.cc"],
deps = [
":bounds_check",
- ":critical_section",
":dense_update_functor",
":gather_functor",
+ ":mutex_ops",
":scatter_functor",
":state",
":training_op_helpers",
)
tf_kernel_library(
- name = "critical_section",
- prefix = "critical_section",
- deps = STATE_DEPS + [":captured_function"],
+ name = "mutex_ops",
+ prefix = "mutex_ops",
+ deps = STATE_DEPS + [":ops_util"],
)
tf_cc_test(
# Excluded due to experimental status:
"debug_ops.*",
"scatter_nd_op*",
- "critical_section.*",
+ "mutex_ops.*",
"batch_kernels.*",
],
),
+++ /dev/null
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include <deque>
-#include <utility>
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/kernels/captured_function.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-class CriticalSection : public ResourceBase {
- public:
- explicit CriticalSection() : is_locked_(false) {}
- ~CriticalSection() override {
- // Wait for all closures to finish running.
- mutex_lock lock(mu_);
- while (!closures_.empty()) {
- queue_empty_cv_.wait(lock);
- }
- }
-
- private:
- friend class ExecuteInCriticalSectionOp;
-
- void Acquire(std::function<void()> closure) {
- std::function<void()> next;
- {
- mutex_lock ml(mu_);
- if (is_locked_) {
- closures_.push_back(std::move(closure));
- } else {
- // This branch is the common case. Avoid the queue.
- is_locked_ = true;
- next = std::move(closure);
- }
- }
- if (next) {
- next();
- }
- }
-
- void Release() {
- std::function<void()> next;
- {
- mutex_lock ml(mu_);
- CHECK(is_locked_);
- if (!closures_.empty()) {
- // if queue is not empty, start the next entry off the queue.
- std::swap(next, closures_.front());
- closures_.pop_front();
- } else {
- is_locked_ = false;
- queue_empty_cv_.notify_all();
- }
- }
- if (next) {
- next();
- }
- }
-
- string DebugString() override {
- tf_shared_lock ml(mu_);
- return strings::StrCat("CriticalSection(locked: ", is_locked_,
- " queue_size: ", closures_.size(), ")");
- }
-
- private:
- mutex mu_;
- std::deque<std::function<void()>> closures_ GUARDED_BY(mu_);
- bool is_locked_ GUARDED_BY(mu_);
- condition_variable queue_empty_cv_ GUARDED_BY(mu_);
-};
-
-class ExecuteInCriticalSectionOp : public AsyncOpKernel {
- public:
- explicit ExecuteInCriticalSectionOp(OpKernelConstruction* c)
- : AsyncOpKernel(c) {
- OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
- }
-
- public:
- void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
- CriticalSection* critical_section = nullptr;
- OP_REQUIRES_OK_ASYNC(c,
- LookupOrCreateResource<CriticalSection>(
- c, HandleFromInput(c, 0), &critical_section,
- [this, c](CriticalSection** ptr) {
- *ptr = new CriticalSection;
- return Status::OK();
- }),
- done);
- // No need to Unref critical_section; the Closure below will take
- // care of the Unref associated with this execution.
-
- auto* execution = new Closure{std::move(done), c, critical_section, &func_};
- execution->Start();
- }
-
- private:
- class Closure {
- public:
- AsyncOpKernel::DoneCallback done_;
- OpKernelContext* ctx_;
- CriticalSection* cs_;
- FunctionLibraryRuntime::Handle handle_;
- FunctionLibraryRuntime::Options opts_;
- std::vector<Tensor> arguments_t_;
- std::vector<Tensor> output_t_;
- NameAttrList* func_;
-
- explicit Closure(AsyncOpKernel::DoneCallback done, OpKernelContext* ctx,
- CriticalSection* critical_section, NameAttrList* func)
- : done_(std::move(done)),
- ctx_(ctx),
- cs_(critical_section),
- handle_(-1),
- func_(func) {}
-
- ~Closure();
-
- void Start() {
- // Perform ExecuteFunction isnide a separate thread to avoid
- // having lightweight Functions be inlined in this thread.
- // That inlining would in turn inline DoneAndDelete inside the
- // same thread. Since DoneAndDelete can call the next
- // ExecuteFunction in the CriticalSection, this can cause a
- // stack overflow.
- cs_->Acquire(
- [this]() { (*ctx_->runner())([this]() { ExecuteFunction(); }); });
- }
-
- private:
- void ExecuteFunction();
- void DoneAndDelete(const Status& status);
- };
-
- NameAttrList func_;
-};
-
-void ExecuteInCriticalSectionOp::Closure::ExecuteFunction() {
- // Arguments to a Function are in the order:
- // concat(<formal arguments>, <captured arguments>)
- OpInputList arguments;
- Status s = ctx_->input_list("arguments", &arguments);
- if (!s.ok()) {
- DoneAndDelete(s);
- return;
- }
-
- arguments_t_.reserve(arguments.size());
- for (const Tensor& t : arguments) {
- arguments_t_.push_back(t);
- }
-
- auto* function_library = ctx_->function_library();
- s = function_library->Instantiate(func_->name(), AttrSlice(&func_->attr()),
- &handle_);
- if (!s.ok()) {
- DoneAndDelete(s);
- return;
- }
-
- opts_.step_id = CapturedFunction::generate_step_id();
- auto* step_container =
- new ScopedStepContainer(opts_.step_id, [this](const string& name) {
- ctx_->resource_manager()->Cleanup(name).IgnoreError();
- });
- opts_.cancellation_manager = ctx_->cancellation_manager();
- opts_.step_container = step_container;
- opts_.runner = ctx_->runner();
-
- function_library->Run(opts_, handle_, arguments_t_, &output_t_,
- [this](const Status& s) { DoneAndDelete(s); });
-}
-
-void ExecuteInCriticalSectionOp::Closure::DoneAndDelete(const Status& status) {
- cs_->Release();
-
- if (!status.ok()) {
- ctx_->SetStatus(status);
- } else {
- OpOutputList output;
- const Status s = ctx_->output_list("outputs", &output);
- if (!s.ok()) {
- ctx_->SetStatus(s);
- } else if (output_t_.size() != output.size()) {
- ctx_->SetStatus(errors::Internal(
- "Could not set all outputs. Expected output size is ", output.size(),
- " but function set ", output_t_.size(), " output values."));
- } else {
- for (int i = 0; i < output_t_.size(); ++i) {
- output.set(i, output_t_[i]);
- }
- }
- }
-
- delete opts_.step_container;
- opts_.step_container = nullptr;
- done_();
- cs_->Unref();
- delete this;
-}
-
-ExecuteInCriticalSectionOp::Closure::~Closure() {
- CHECK(!opts_.step_container)
- << "Initialized closure destroyed without calling Done";
-}
-
-REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection").Device(DEVICE_CPU),
- ExecuteInCriticalSectionOp);
-
-REGISTER_KERNEL_BUILDER(Name("CriticalSectionOp").Device(DEVICE_CPU),
- ResourceHandleOp<CriticalSection>);
-
-// TODO(ebrevdo): Re-enable once the cross-device function execution works.
-#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection")
- .Device(DEVICE_GPU)
- .HostMemory("critical_section"),
- ExecuteInCriticalSectionOp);
-REGISTER_KERNEL_BUILDER(
- Name("CriticalSectionOp").Device(DEVICE_GPU).HostMemory("resource"),
- ResourceHandleOp<CriticalSection>);
-#endif // GOOGLE_CUDA
-
-} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include <deque>
+#include <utility>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+class Mutex : public ResourceBase {
+ public:
+ explicit Mutex(OpKernelContext* c, const string& name)
+ : locked_(false),
+ thread_pool_(new thread::ThreadPool(
+ c->env(), ThreadOptions(),
+ strings::StrCat("mutex_lock_thread_", SanitizeThreadSuffix(name)),
+ 1 /* num_threads */, false /* low_latency_hint */)),
+ name_(name) {
+ VLOG(2) << "Creating mutex with name " << name << ": " << this;
+ }
+
+ string DebugString() override { return strings::StrCat("Mutex ", name_); }
+
+ class LockReleaser {
+ public:
+ explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {}
+
+ LockReleaser(const LockReleaser&) = delete;
+ LockReleaser& operator=(const LockReleaser&) = delete;
+
+ virtual ~LockReleaser() {
+ VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_;
+ if (mutex_) {
+ mutex_lock lock(mutex_->mu_);
+ mutex_->locked_ = false;
+ mutex_->cv_.notify_all();
+ VLOG(3) << "Destroying LockReleaser " << this
+ << ": sent notifications.";
+ }
+ }
+
+ private:
+ Mutex* mutex_;
+ };
+
+ struct SharedLockReleaser {
+ std::shared_ptr<LockReleaser> shared_lock;
+
+ explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
+ : shared_lock(std::forward<decltype(lock)>(lock)) {
+ VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
+ << " count is: " << shared_lock.use_count();
+ }
+
+ SharedLockReleaser(SharedLockReleaser&& rhs)
+ : shared_lock(std::move(rhs.shared_lock)) {
+ VLOG(3) << "Moving SharedLockReleaser of " << shared_lock.get()
+ << " count is: " << shared_lock.use_count();
+ }
+
+ SharedLockReleaser(const SharedLockReleaser& rhs)
+ : shared_lock(rhs.shared_lock) {
+ VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()
+ << " count is: " << shared_lock.use_count();
+ }
+
+ ~SharedLockReleaser() {
+ VLOG(3) << "Destroying SharedLockReleaser of " << shared_lock.get()
+ << " count is: " << shared_lock.use_count();
+ }
+
+ void Encode(VariantTensorData*) const {
+ // Not supported.
+ }
+
+ bool Decode(const VariantTensorData&) {
+ return false; // Not supported.
+ }
+ };
+
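+ // Acquires the mutex asynchronously. Waiters are serialized on the
+ // single-threaded thread_pool_, so no executor thread blocks while
+ // waiting; a cancellation callback wakes waiters through cv_.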
+ void AcquireAsync(
+ OpKernelContext* c,
+ std::function<void(const Status& s, SharedLockReleaser lock)> fn) {
+ CancellationManager* cm = c->cancellation_manager();
+ CancellationToken token{};
+ bool* cancelled = nullptr;
+ if (cm) {
+ cancelled = new bool(false); // GUARDED_BY(mu_);
+ token = cm->get_cancellation_token();
+ const bool already_cancelled =
+ !cm->RegisterCallback(token, [this, cancelled]() {
+ mutex_lock lock(mu_);
+ *cancelled = true;
+ cv_.notify_all();
+ });
+ if (already_cancelled) {
+ delete cancelled;
+ fn(errors::Cancelled("Lock acquisition cancelled."),
+ SharedLockReleaser{nullptr});
+ return;
+ }
+ }
+ thread_pool_->Schedule(std::bind(
+ [this, c, cm, cancelled,
+ token](std::function<void(const Status& s, SharedLockReleaser&& lock)>
+ fn_) {
+ bool local_locked;
+ {
+ mutex_lock lock(mu_);
+ while (locked_ && !(cancelled && *cancelled)) {
+ cv_.wait(lock);
+ }
+ local_locked = locked_ = !(cancelled && *cancelled);
+ }
+ if (cm) {
+ cm->DeregisterCallback(token);
+ delete cancelled;
+ }
+ if (local_locked) { // Not cancelled.
+ fn_(Status::OK(),
+ SharedLockReleaser{std::make_shared<LockReleaser>(this)});
+ } else {
+ fn_(errors::Cancelled("Lock acqusition cancelled."),
+ SharedLockReleaser{nullptr});
+ }
+ },
+ std::move(fn)));
+ }
+
+ private:
+ mutex mu_;
+ condition_variable cv_ GUARDED_BY(mu_);
+ bool locked_ GUARDED_BY(mu_);
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+ string name_;
+};
+
+} // namespace
+
+class MutexLockOp : public AsyncOpKernel {
+ public:
+ explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {}
+
+ public:
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ Mutex* mutex = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ c,
+ LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex,
+ [this, c](Mutex** ptr) {
+ *ptr = new Mutex(
+ c, HandleFromInput(c, 0).name());
+ return Status::OK();
+ }),
+ done);
+
+ Tensor* variant;
+ OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant),
+ done);
+
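+ // The output is a scalar DT_VARIANT holding a SharedLockReleaser;
+ // the mutex is unlocked when the last copy of the variant is
+ // destroyed.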
+ mutex->AcquireAsync(
+ c, std::bind(
+ [this, c, variant, mutex](DoneCallback done_,
+ // End of bound arguments.
+ const Status& s,
+ Mutex::SharedLockReleaser&& lock) {
+ core::ScopedUnref unref(mutex);
+ VLOG(2) << "Finished locking mutex " << mutex
+ << " with lock: " << lock.shared_lock.get()
+ << " status: " << s.ToString();
+ if (s.ok()) {
+ variant->scalar<Variant>()() = std::move(lock);
+ } else {
+ c->SetStatus(s);
+ }
+ done_();
+ },
+ std::move(done), std::placeholders::_1, std::placeholders::_2));
+ }
+};
+
+class ConsumeMutexLockOp : public OpKernel {
+ public:
+ explicit ConsumeMutexLockOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ VLOG(2) << "Executing ConsumeMutexLockOp";
+ const Tensor& lock_t = c->input(0);
+ OP_REQUIRES(
+ c, lock_t.dims() == 0,
+ errors::InvalidArgument("Expected input to be a scalar, saw shape: ",
+ lock_t.shape().DebugString()));
+ OP_REQUIRES(
+ c, lock_t.dtype() == DT_VARIANT,
+ errors::InvalidArgument("Expected input to be a variant, saw type: ",
+ DataTypeString(lock_t.dtype())));
+ const auto* lock =
+ lock_t.scalar<Variant>()().get<Mutex::SharedLockReleaser>();
+ OP_REQUIRES(c, lock,
+ errors::InvalidArgument(
+ "Expected input to contain a SharedLockReleaser "
+ "object, but saw variant: '",
+ lock_t.scalar<Variant>()().DebugString(), "'"));
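+ // The lock variant should hold the only remaining reference; once the
+ // input tensor is deallocated after this kernel runs, the LockReleaser
+ // destructor releases the mutex.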
+ const int use_count = lock->shared_lock.use_count();
+ OP_REQUIRES(
+ c, use_count == 1,
+ errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",
+ use_count));
+ }
+
+ bool IsExpensive() override { return false; }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp);
+
+REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_CPU),
+ ResourceHandleOp<Mutex>);
+
+REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU),
+ ConsumeMutexLockOp);
+
+} // namespace tensorflow
}
}
}
-op {
- name: "CriticalSectionOp"
- output_arg {
- name: "resource"
- type: DT_RESOURCE
- }
- attr {
- name: "container"
- type: "string"
- default_value {
- s: ""
- }
- }
- attr {
- name: "shared_name"
- type: "string"
- default_value {
- s: ""
- }
- }
- is_stateful: true
-}
op {
name: "CropAndResize"
input_arg {
}
}
}
-op {
- name: "ExecuteInCriticalSection"
- input_arg {
- name: "critical_section"
- type: DT_RESOURCE
- }
- input_arg {
- name: "arguments"
- type_list_attr: "Targuments"
- }
- output_arg {
- name: "outputs"
- type_list_attr: "output_types"
- }
- attr {
- name: "f"
- type: "func"
- }
- attr {
- name: "Targuments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
- is_stateful: true
-}
-op {
- name: "ExecuteInCriticalSection"
- input_arg {
- name: "critical_section"
- type: DT_RESOURCE
- }
- input_arg {
- name: "arguments"
- type_list_attr: "Targuments"
- }
- output_arg {
- name: "outputs"
- type_list_attr: "output_types"
- }
- attr {
- name: "f"
- type: "func"
- }
- attr {
- name: "Targuments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- }
- is_stateful: true
-}
op {
name: "Exit"
input_arg {
return Status::OK();
});
-REGISTER_OP("CriticalSectionOp")
+REGISTER_OP("MutexV2")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Output("resource: resource")
return Status::OK();
});
-REGISTER_OP("ExecuteInCriticalSection")
- .Input("critical_section: resource")
- .Input("arguments: Targuments")
- .Output("outputs: output_types")
- .Attr("f: func")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 0")
- .Attr("output_shapes: list(shape) >= 0")
+REGISTER_OP("MutexLock")
+ .Input("mutex: resource")
+ .Output("mutex_lock: variant")
+ .SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- for (int i = 0; i < output_shapes.size(); ++i) {
- ShapeHandle s;
- TF_RETURN_IF_ERROR(
- c->MakeShapeFromPartialTensorShape(output_shapes[i], &s));
- c->set_output(i, s);
- }
+ c->set_output(0, c->Scalar());
return Status::OK();
});
+REGISTER_OP("ConsumeMutexLock")
+ .Input("mutex_lock: variant")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) { return Status::OK(); });
+
} // namespace tensorflow
ops.EagerTensor, _convert_to_graph_tensor, priority=-1)
-class _CapturingContext(object):
- """Tracks references to Tensors outside this context while it is active."""
+# pylint: disable=invalid-name
+class HelperContext(object):
+ """ControlFlowContext with a customizable AddOp method."""
- def __init__(self):
- # known_ops are ops which are created while this context is active
- self.known_ops = set()
+ def __init__(self, add_op_internal):
+ self._add_op_internal = add_op_internal
+ self._values = set() # control flow code sometimes updates this.
+
+ def _AddOpInternal(self, op):
+ self._add_op_internal(op)
+
+ @property
+ def outer_context(self):
+ return self._outer_context
+
+ def GetWhileContext(self):
+ if self._outer_context:
+ return self._outer_context.GetWhileContext()
+
+ def IsWhileContext(self):
+ return False
+
+ def IsCondContext(self):
+ return False
- # captured_tensors are all tensors referenced to by ops in this context but
- # not produced in it
- self.captured_tensors = set()
+ def IsXLAContext(self):
+ return False
def AddOp(self, op): # pylint: disable=invalid-name
- if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
- raise ValueError("tfe.defun cannot capture variables created without "
- "using tf.get_variable. Op: %s" % op)
- self.known_ops.add(op)
- for i in op.inputs:
- if i.op not in self.known_ops:
- self.captured_tensors.add(i)
+ self._AddOpInternal(op)
+ if self._outer_context:
+ self._outer_context.AddOp(op)
+
+ def AddName(self, _):
+ pass
+
+ def AddInnerOp(self, op):
+ self._AddOpInternal(op)
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
+ def AddValue(self, val):
+ if self._outer_context:
+ return self._outer_context.AddValue(val)
+ else:
+ return val
def __enter__(self):
+ # pylint: disable=protected-access
self._g = ops.get_default_graph()
- self._old = self._g._get_control_flow_context() # pylint: disable=protected-access
- self._g._set_control_flow_context(self) # pylint: disable=protected-access
+ self._outer_context = self._g._get_control_flow_context()
+ self._g._set_control_flow_context(self)
+ self._nested_contexts = (
+ self._outer_context._nested_contexts
+ if self._outer_context is not None else None)
+ # pylint: enable=protected-access
- def __exit__(self, _, __, ___): # pylint: disable=invalid-name
- self._g._set_control_flow_context(self._old) # pylint: disable=protected-access
+ def __exit__(self, *_):
+ self._g._set_control_flow_context(self._outer_context) # pylint: disable=protected-access
+# pylint: enable=invalid-name
def _forward_name(n):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
with self._graph.as_default(), context.graph_mode():
- c = _CapturingContext()
+ c_known_ops = set()
+ c_captured_tensors = set()
+
+ def add_op_internal(op):
+ if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
+ raise ValueError("tfe.defun cannot capture variables created without "
+ "using tf.get_variable. Op: %s" % op)
+ c_known_ops.add(op)
+ for i in op.inputs:
+ if i.op not in c_known_ops:
+ c_captured_tensors.add(i)
+
+ c = HelperContext(add_op_internal)
+
with c:
filtered_outputs = [x for x in self._returns if x is not None]
self._out_grad_placeholders = [
grad for grad in _flatten(in_gradients) if grad is not None)
output_shapes = tuple(grad.shape for grad in backward_outputs)
- captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
+ captures = list(sorted(c_captured_tensors, key=lambda x: x.name))
forward_name = _forward_name(self._func_name)
self._forward_fdef = _EagerDefinedFunction(
forward_name, self._graph, self._ops, self._input_placeholders,
# means rerunning the function-defining code will always define the same
# function, which is useful if we serialize this etc.
function_def_ops = tuple(x
- for x in sorted(c.known_ops, key=lambda x: x.name)
+ for x in sorted(c_known_ops, key=lambda x: x.name)
if x not in all_ignored_ops)
bname = _backward_name(self._func_name)
self._backward_function = GraphModeFunction(
if context.in_eager_mode():
return tensors
with ops.name_scope(name, "tuple", tensors) as name:
- gating_ops = [t.op for t in tensors if t is not None]
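+ # Accept Operations and plain Python values in addition to Tensors;
+ # Operations are gated via group() below instead of identity().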
+ tensors = [t if (isinstance(t, ops.Operation)
+ or tensor_util.is_tensor(t)
+ or t is None)
+ else ops.convert_to_tensor(t) for t in tensors]
+ gating_ops = [t if isinstance(t, ops.Operation) else t.op for t in tensors
+ if t is not None]
if control_inputs:
for c in control_inputs:
if isinstance(c, ops.Tensor):
gate = group(*gating_ops)
tpl = []
for t in tensors:
- if t is not None:
+ if tensor_util.is_tensor(t):
tpl.append(with_dependencies([gate], t))
+ elif isinstance(t, ops.Operation):
+ with ops.control_dependencies([gate]):
+ tpl.append(group(t))
else:
tpl.append(None)
return tpl