Previously, the master would send a job name and task index in an
otherwise-empty ServerDef, and the worker would unquestioningly use
those to build its worker name. However, this would lead to errors if
the worker had a local name like "/job:worker/replica:1/task:0",
because the ServerDef doesn't support non-zero replica IDs, and so the
local worker would end up with an inconsistent view of what its worker name
should be. In particular `WorkerSession::worker_name` would disagree
with the device names added during graph partitioning by the master,
which would lead to runtime failures ("InvalidArgumentError: Invalid
rendezvous key").
PiperOrigin-RevId: 193733855
deps = [
":session_mgr",
":worker_env",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
workers[i].name = &worker_names[i];
workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
workers[i].request.set_session_handle(handle_);
- if (options.cluster_def) {
- *workers[i].request.mutable_server_def()->mutable_cluster() =
- *options.cluster_def;
- workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
- // Session state is always isolated when ClusterSpec propagation
- // is in use.
- workers[i].request.set_isolate_session_state(true);
- } else {
- workers[i].request.set_isolate_session_state(
- session_opts_.config.isolate_session_state());
- }
DeviceNameUtils::ParsedName name;
if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
  // A worker name that cannot be parsed must surface as an error. Returning
  // `status` untouched (presumably still OK at this point -- confirm against
  // its declaration) would report success for a request that was never sent.
  status = errors::Internal("Could not parse name ", worker_names[i]);
  return status;
}
- workers[i].request.mutable_server_def()->set_job_name(name.job);
- workers[i].request.mutable_server_def()->set_task_index(name.task);
+ if (options.cluster_def) {
+ *workers[i].request.mutable_server_def()->mutable_cluster() =
+ *options.cluster_def;
+ workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
+ workers[i].request.mutable_server_def()->set_job_name(name.job);
+ workers[i].request.mutable_server_def()->set_task_index(name.task);
+ // Session state is always isolated when ClusterSpec propagation
+ // is in use.
+ workers[i].request.set_isolate_session_state(true);
+ } else {
+ // NOTE(mrry): Do not set any component of the ServerDef,
+ // because the worker will use its local configuration.
+ workers[i].request.set_isolate_session_state(
+ session_opts_.config.isolate_session_state());
+ }
}
for (size_t i = 0; i < worker_names.size(); ++i) {
new GraphMgr(worker_env, worker_env->device_mgr)))),
worker_cache_factory_(std::move(worker_cache_factory)) {}
+/* static */
string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
return strings::StrCat("/job:", server_def.job_name(), "/replica:0/task:",
server_def.task_index());
return errors::InvalidArgument("Session must be non-empty.");
}
- const string worker_name = WorkerNameFromServerDef(server_def);
-
WorkerCacheInterface* worker_cache = nullptr;
+ string worker_name;
if (server_def.cluster().job().empty()) {
worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
+ worker_name = legacy_session_->worker_name;
} else {
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
+ worker_name = WorkerNameFromServerDef(server_def);
}
if (worker_cache != nullptr && default_worker_cache_.get() != nullptr) {
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/cluster.pb.h"
namespace tensorflow {
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
+// Creates a session from a ServerDef that carries a non-empty ClusterDef
+// (job "worker", task 3) and checks that the resulting WorkerSession's
+// worker_name is "/job:worker/replica:0/task:3".
+TEST_F(SessionMgrTest, CreateSessionClusterDefWorkerName) {
+ ServerDef server_def;
+ server_def.set_job_name("worker");
+ server_def.set_task_index(3);
+ auto job = server_def.mutable_cluster()->add_job();
+ job->set_name("worker");
+ job->mutable_tasks()->insert({3, "localhost:3333"});
+
+ string session_handle = "test_session_handle";
+ TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def, true));
+ std::shared_ptr<WorkerSession> session;
+ TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
+ // Abort the test on a null session: the next check dereferences it.
+ ASSERT_NE(nullptr, session)
+ << "Session for " << session_handle << " was null";
+ EXPECT_EQ("/job:worker/replica:0/task:3", session->worker_name);
+ TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
+}
+
+// Creates a session from an empty ServerDef (no ClusterDef) and checks that
+// the WorkerSession's worker_name is "/job:mnist/replica:0/task:0".
+// NOTE(review): that name presumably comes from the fixture's default/legacy
+// session configuration -- confirm against the SessionMgrTest setup.
+TEST_F(SessionMgrTest, CreateSessionDefaultWorkerName) {
+ ServerDef server_def;
+ string session_handle = "test_session_handle";
+ TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def, true));
+ std::shared_ptr<WorkerSession> session;
+ TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
+ // Abort the test on a null session: the next check dereferences it.
+ ASSERT_NE(nullptr, session)
+ << "Session for " << session_handle << " was null";
+ EXPECT_EQ("/job:mnist/replica:0/task:0", session->worker_name);
+ TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
+}
+
TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
ServerDef server_def;
server_def.set_job_name("worker");