From c2d6faafc48b251faa24a342dc063d9fa624421e Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 5 Apr 2018 22:37:49 -0700 Subject: [PATCH] Fix StringPiece use-after-free in MasterSession::ReffedClientGraph. Use the owned ClientGraph as the source for the node_to_name_ map, rather than the borrowed GraphExecutionState (which can be deleted while the ReffedClientGraph is in use). PiperOrigin-RevId: 191847023 --- .../core/distributed_runtime/master_session.cc | 24 ++++++++++------------ 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 01da54f..64adf35 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -66,8 +66,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { std::unique_ptr cg, const SessionOptions& session_opts, const StatsPublisherFactory& stats_publisher_factory, - GraphExecutionState* execution_state, bool is_partial, - WorkerCacheInterface* worker_cache, bool should_deregister) + bool is_partial, WorkerCacheInterface* worker_cache, + bool should_deregister) : session_handle_(handle), client_graph_(std::move(cg)), session_opts_(session_opts), @@ -80,8 +80,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts); - // Initialize a name to node map for testing that fetches are reachable. - for (Node* n : execution_state->full_graph()->nodes()) { + // Initialize a name to node map for processing device stats. + for (Node* n : client_graph_->graph.nodes()) { name_to_node_.insert({n->name(), n}); } } @@ -829,8 +829,6 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats( // TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends // on once at setup time to prevent us from computing the dependencies // everytime. -// TODO(suharshs,mrry): Consider removing the need for execution_state to reduce -// contention. Status MasterSession::ReffedClientGraph::CheckFetches( const RunStepRequestWrapper& req, const RunState* run_state, GraphExecutionState* execution_state) { @@ -840,8 +838,8 @@ Status MasterSession::ReffedClientGraph::CheckFetches( // Skip if already fed. if (input.second) continue; TensorId id(ParseTensorName(input.first)); - const auto it = name_to_node_.find(id.first); - if (it == name_to_node_.end()) { + const Node* n = execution_state->get_node_by_name(id.first.ToString()); + if (n == nullptr) { return errors::NotFound("Feed ", input.first, ": not found"); } pending_feeds.insert(id); @@ -856,11 +854,11 @@ Status MasterSession::ReffedClientGraph::CheckFetches( for (size_t i = 0; i < req.num_fetches(); ++i) { const string& fetch = req.fetch_name(i); const TensorId id(ParseTensorName(fetch)); - auto it = name_to_node_.find(id.first); - if (it == name_to_node_.end()) { + const Node* n = execution_state->get_node_by_name(id.first.ToString()); + if (n == nullptr) { return errors::NotFound("Fetch ", fetch, ": not found"); } - stack.push_back(it->second); + stack.push_back(n); } // Any tensor needed for fetches can't be in pending_feeds. @@ -1293,8 +1291,8 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, WorkerCacheInterface* worker_cache = get_worker_cache(); auto entry = new ReffedClientGraph( handle_, opts, std::move(client_graph), session_opts_, - stats_publisher_factory_, execution_state_.get(), is_partial, - worker_cache, !should_delete_worker_sessions_); + stats_publisher_factory_, is_partial, worker_cache, + !should_delete_worker_sessions_); iter = m->insert({hash, entry}).first; VLOG(1) << "Preparing to execute new graph"; } -- 2.7.4