From 1d6a57edc0be0dcc0c92eb2610b88420a7b7be51 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 12 Mar 2018 11:02:29 -0700 Subject: [PATCH] Fix race in C API. RecordMutation could race with ExtendSessionGraphHelper, which would release the graph lock and only keep the session lock when extending the session. Also makes sure thread annotations are on declarations, not definitions (otherwise they have no effect). PiperOrigin-RevId: 188747158 --- tensorflow/c/c_api.cc | 38 ++++++++++++++++---------------------- tensorflow/c/c_api_internal.h | 12 +++++++----- tensorflow/c/python_api.cc | 3 +-- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 8b9b3da..778cb66 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -63,6 +63,7 @@ limitations under the License. // brain namespace because we are defining 'extern "C"' functions. using tensorflow::AllocationDescription; using tensorflow::DataType; +using tensorflow::ExtendSessionGraphHelper; using tensorflow::Graph; using tensorflow::GraphDef; using tensorflow::mutex_lock; @@ -640,11 +641,11 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, } void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type) - EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + const char* mutation_type) { // If any session has already run this node_id, mark this session as // unrunnable. for (auto it : graph->sessions) { + mutex_lock session_lock(it.first->mu); if (it.first->last_num_graph_nodes > op.node.id()) { it.second = FailedPrecondition( "Operation '", op.node.DebugString(), "' was changed by ", @@ -713,10 +714,12 @@ Status LoadLibrary(const char* library_filename, void** result, // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). 
-bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) - EXCLUSIVE_LOCKS_REQUIRED(session->mu) { +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { if (session->graph != nullptr) { + // Take the graph lock before the session lock to avoid deadlock. This is + // safe since session->graph does not change. session->graph->mu.lock(); + mutex_lock session_lock(session->mu); const Graph& graph = session->graph->graph; status->status = session->graph->sessions[session]; @@ -2571,12 +2574,9 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). - { - mutex_lock l(session->mu); - if (session->extend_before_run && - !tensorflow::ExtendSessionGraphHelper(session, status)) { - return; - } + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { + return; } TF_Run_Setup(noutputs, output_values, status); @@ -2612,12 +2612,9 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, const char** handle, TF_Status* status) { *handle = nullptr; - { - mutex_lock l(session->mu); - if (session->extend_before_run && - !tensorflow::ExtendSessionGraphHelper(session, status)) { - return; - } + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { + return; } std::vector<string> input_names(ninputs); @@ -2659,12 +2656,9 @@ void TF_SessionPRun(TF_Session* session, const char* handle, // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). 
- { - mutex_lock l(session->mu); - if (session->extend_before_run && - !tensorflow::ExtendSessionGraphHelper(session, status)) { - return; - } + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { + return; } TF_Run_Setup(noutputs, output_values, status); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 2523393..e885a69 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -124,16 +124,16 @@ struct TF_Session { TF_Session(tensorflow::Session* s, TF_Graph* g); tensorflow::Session* session; - TF_Graph* graph; + TF_Graph* const graph; - tensorflow::mutex mu; + tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu); int last_num_graph_nodes; // If true, TF_SessionRun and similar methods will call // ExtendSessionGraphHelper before running the graph (this is the default // public behavior). Can be set to false if the caller needs to call // ExtendSessionGraphHelper manually. - bool extend_before_run GUARDED_BY(mu); + std::atomic<bool> extend_before_run; }; struct TF_ImportGraphDefOptions { @@ -211,9 +211,11 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, TF_Status* status); void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type); + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu); -bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status); +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) + LOCKS_EXCLUDED(session->graph->mu, session->mu); } // end namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 26683f5..cd60453 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -105,9 +105,8 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) { } void ExtendSession(TF_Session* session, TF_Status* status) { - mutex_lock l(session->mu); - session->extend_before_run = false; ExtendSessionGraphHelper(session, 
status); + session->extend_before_run = false; } } // namespace tensorflow -- 2.7.4