[ClusterFLR] Prolong the lifetime of the RunGraphRequest until the call has completed.
authorDerek Murray <mrry@google.com>
Fri, 2 Mar 2018 00:00:17 +0000 (16:00 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Mar 2018 00:04:36 +0000 (16:04 -0800)
Some WorkerService implementations rely on the request object remaining live until the callback is called.

PiperOrigin-RevId: 187548140

tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc

index 3a8d5912369525253904bd700dfdc6e3eb26e0ae..0c5c4d59edc8c73d6bcac3ce0f9ec0b77495fb58 100644 (file)
@@ -175,32 +175,33 @@ void ClusterFunctionLibraryRuntime::Run(
     return;
   }
 
-  RunGraphRequest req;
-  req.set_session_handle(worker_session_->session_name);
-  req.set_graph_handle(function_data->graph_handle);
+  RunGraphRequest* req = new RunGraphRequest;
+  req->set_session_handle(worker_session_->session_name);
+  req->set_graph_handle(function_data->graph_handle);
   // Borrowed from master_session.cc
   const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
-  req.set_step_id(step_id);
+  req->set_step_id(step_id);
   int i = 0;
   for (const auto& send_key : function_data->send_keys) {
-    NamedTensorProto* send = req.add_send();
+    NamedTensorProto* send = req->add_send();
     send->set_name(send_key);
     args[i].AsProtoTensorContent(send->mutable_tensor());
     i++;
   }
   const std::vector<string>& recv_keys = function_data->recv_keys;
   for (const auto& recv_key : recv_keys) {
-    req.add_recv_key(recv_key);
+    req->add_recv_key(recv_key);
   }
 
   RunGraphResponse* resp = new RunGraphResponse();
   CallOptions* call_options = new CallOptions();
   wi->RunGraphAsync(
-      call_options, &req, resp,
-      [call_options, resp, rets, recv_keys, done](const Status& status) {
+      call_options, req, resp,
+      [call_options, req, resp, rets, recv_keys, done](const Status& status) {
         if (!status.ok()) {
           done(status);
           delete call_options;
+          delete req;
           delete resp;
           return;
         }
@@ -212,25 +213,28 @@ void ClusterFunctionLibraryRuntime::Run(
         for (const auto& recv_key : recv_keys) {
           TensorProto* tp = mapped_recvs[recv_key];
           if (tp == nullptr) {
+            done(errors::Internal("Could not find key: ", recv_key));
             delete call_options;
+            delete req;
             delete resp;
-            done(errors::Internal("Could not find key: ", recv_key));
             return;
           }
           Tensor t;
           if (t.FromProto(*tp)) {
             rets->push_back(t);
           } else {
-            delete call_options;
-            delete resp;
             done(errors::Internal("Could not convert tensor proto: ",
                                   tp->DebugString()));
+            delete call_options;
+            delete req;
+            delete resp;
             return;
           }
         }
+        done(status);
         delete call_options;
+        delete req;
         delete resp;
-        done(status);
       });
 }