Fix StringPiece use-after-free in MasterSession::ReffedClientGraph.
authorDerek Murray <mrry@google.com>
Fri, 6 Apr 2018 05:37:49 +0000 (22:37 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 05:40:16 +0000 (22:40 -0700)
Use the owned ClientGraph as the source for the node_to_name_ map, rather than the borrowed GraphExecutionState (which can be deleted while the ReffedClientGraph is in use).

PiperOrigin-RevId: 191847023

tensorflow/core/distributed_runtime/master_session.cc

index 01da54f..64adf35 100644 (file)
@@ -66,8 +66,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
                     std::unique_ptr<ClientGraph> cg,
                     const SessionOptions& session_opts,
                     const StatsPublisherFactory& stats_publisher_factory,
-                    GraphExecutionState* execution_state, bool is_partial,
-                    WorkerCacheInterface* worker_cache, bool should_deregister)
+                    bool is_partial, WorkerCacheInterface* worker_cache,
+                    bool should_deregister)
       : session_handle_(handle),
         client_graph_(std::move(cg)),
         session_opts_(session_opts),
@@ -80,8 +80,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
 
     stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
 
-    // Initialize a name to node map for testing that fetches are reachable.
-    for (Node* n : execution_state->full_graph()->nodes()) {
+    // Initialize a name to node map for processing device stats.
+    for (Node* n : client_graph_->graph.nodes()) {
       name_to_node_.insert({n->name(), n});
     }
   }
@@ -829,8 +829,6 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
 // TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
 // on once at setup time to prevent us from computing the dependencies
 // everytime.
-// TODO(suharshs,mrry): Consider removing the need for execution_state to reduce
-// contention.
 Status MasterSession::ReffedClientGraph::CheckFetches(
     const RunStepRequestWrapper& req, const RunState* run_state,
     GraphExecutionState* execution_state) {
@@ -840,8 +838,8 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
     // Skip if already fed.
     if (input.second) continue;
     TensorId id(ParseTensorName(input.first));
-    const auto it = name_to_node_.find(id.first);
-    if (it == name_to_node_.end()) {
+    const Node* n = execution_state->get_node_by_name(id.first.ToString());
+    if (n == nullptr) {
       return errors::NotFound("Feed ", input.first, ": not found");
     }
     pending_feeds.insert(id);
@@ -856,11 +854,11 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
   for (size_t i = 0; i < req.num_fetches(); ++i) {
     const string& fetch = req.fetch_name(i);
     const TensorId id(ParseTensorName(fetch));
-    auto it = name_to_node_.find(id.first);
-    if (it == name_to_node_.end()) {
+    const Node* n = execution_state->get_node_by_name(id.first.ToString());
+    if (n == nullptr) {
       return errors::NotFound("Fetch ", fetch, ": not found");
     }
-    stack.push_back(it->second);
+    stack.push_back(n);
   }
 
   // Any tensor needed for fetches can't be in pending_feeds.
@@ -1293,8 +1291,8 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
       WorkerCacheInterface* worker_cache = get_worker_cache();
       auto entry = new ReffedClientGraph(
           handle_, opts, std::move(client_graph), session_opts_,
-          stats_publisher_factory_, execution_state_.get(), is_partial,
-          worker_cache, !should_delete_worker_sessions_);
+          stats_publisher_factory_, is_partial, worker_cache,
+          !should_delete_worker_sessions_);
       iter = m->insert({hash, entry}).first;
       VLOG(1) << "Preparing to execute new graph";
     }