From fa3a9bcabfea46bb3a4c63f559b50cc066d484e7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 May 2018 12:26:06 -0700 Subject: [PATCH] Collective Ops Part 6 Distributed-mode implementations of CollectiveRemoteAccess. Extend Worker interface with corresponding new methods. This change is part of a series of changes introducing infrastructure for collective ops and initial implementations of reduction and broadcast. PiperOrigin-RevId: 196010718 --- tensorflow/core/BUILD | 1 + tensorflow/core/distributed_runtime/BUILD | 34 ++ .../collective_param_resolver_distributed.cc | 1 - .../collective_rma_distributed.cc | 206 ++++++++++++ .../collective_rma_distributed.h | 50 +++ .../collective_rma_distributed_test.cc | 356 +++++++++++++++++++++ tensorflow/core/distributed_runtime/rpc/BUILD | 1 + .../distributed_runtime/rpc/grpc_remote_worker.cc | 7 + .../distributed_runtime/rpc/grpc_worker_service.cc | 98 +++++- .../distributed_runtime/rpc/grpc_worker_service.h | 3 + .../rpc/grpc_worker_service_impl.cc | 2 + .../rpc/grpc_worker_service_impl.h | 1 + tensorflow/core/distributed_runtime/test_utils.h | 5 + tensorflow/core/distributed_runtime/worker.cc | 9 + tensorflow/core/distributed_runtime/worker.h | 3 + .../core/distributed_runtime/worker_interface.h | 3 + tensorflow/core/protobuf/transport_options.proto | 8 + tensorflow/core/protobuf/worker.proto | 54 ++++ tensorflow/core/protobuf/worker_service.proto | 4 + 19 files changed, 840 insertions(+), 6 deletions(-) create mode 100644 tensorflow/core/distributed_runtime/collective_rma_distributed.cc create mode 100644 tensorflow/core/distributed_runtime/collective_rma_distributed.h create mode 100644 tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc create mode 100644 tensorflow/core/protobuf/transport_options.proto diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 76ff372..ccb8488 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -224,6 +224,7 @@ 
ADDITIONAL_CORE_PROTO_SRCS = [ "protobuf/named_tensor.proto", "protobuf/saved_model.proto", "protobuf/tensorflow_server.proto", + "protobuf/transport_options.proto", "util/test_log.proto", ] diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 256ce52..18b7069 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -453,6 +453,40 @@ cc_library( ) cc_library( + name = "collective_rma_distributed", + srcs = ["collective_rma_distributed.cc"], + hdrs = ["collective_rma_distributed.h"], + deps = [ + ":worker_cache", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", # protobuf::Any + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:worker_proto_cc", + ], +) + +tf_cc_test( + name = "collective_rma_distributed_test", + size = "small", + srcs = ["collective_rma_distributed_test.cc"], + deps = [ + ":collective_rma_distributed", + ":device_resolver_distributed", + ":test_utils", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core:worker_proto_cc", + ], +) + +cc_library( name = "collective_param_resolver_distributed", srcs = ["collective_param_resolver_distributed.cc"], hdrs = ["collective_param_resolver_distributed.h"], diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index ecf5db8..7a93b54 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -284,7 +284,6 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed( const GroupRecCallback& done) 
{ VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key << " dev: " << device << " is_leader=" << (group_leader_.empty()); - VLOG(0) << "cp: " << cp->ToString(); if (group_leader_.empty()) { // This is the group leader, so resolution is local. return CompleteGroupLocal(device, cp, done); diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc new file mode 100644 index 0000000..54adcb9 --- /dev/null +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -0,0 +1,206 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h" + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/platform/protobuf_internal.h" +#include "tensorflow/core/protobuf/transport_options.pb.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +namespace { + +// Supports client side cancellation of WorkerInterface calls via +// registration with a CancellationManager. +// +// TODO(tucker): Maybe unify this with CancellableCall in +// collective_param_resolver_distributed.cc. +class CancellableCall { + public: + CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker, + WorkerCacheInterface* wc) + : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) { + wi_ = wc_->CreateWorker(remote_worker_); + } + virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); } + + virtual void IssueCall(const StatusCallback& done) = 0; + + void Start(const StatusCallback& done) { + CancellationToken token = cancel_mgr_->get_cancellation_token(); + const bool not_yet_cancelled = cancel_mgr_->RegisterCallback( + token, [this, token]() { opts_.StartCancel(); }); + if (not_yet_cancelled) { + IssueCall([this, token, done](const Status& s) { + cancel_mgr_->DeregisterCallback(token); + done(s); + }); + } else { + done(errors::Cancelled("RPC Request was cancelled")); + } + } + + protected: + mutable mutex mu_; + CancellationManager* cancel_mgr_; // Not owned + const string remote_worker_; + WorkerCacheInterface* wc_; // Not owned + WorkerInterface* wi_; // Owned by wc_, must be 
released. + CallOptions opts_; +}; + +// Cancellable RecvBufAsync RPC to the worker named by peer_task. The +// constructor fills in the complete RecvBufRequest; the reply is left in +// resp_ for the caller to read once the done callback has fired. +class RecvBufCall : public CancellableCall { + public: + RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task, + const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + const DeviceLocality& server_locality, + CancellationManager* cancel_mgr, WorkerCacheInterface* wc) + : CancellableCall(cancel_mgr, peer_task, wc) { + req_.set_step_id(step_id); + req_.set_buf_rendezvous_key(key); + *req_.mutable_client_locality() = client_locality; + *req_.mutable_server_locality() = server_locality; + req_.set_num_bytes(to_tensor->TotalBytes()); + // Base address of the destination tensor, sent as an integer. + req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor))); + req_.set_src_device(peer_device); + req_.set_dst_device(to_device->name()); + } + + ~RecvBufCall() override {} + + void IssueCall(const StatusCallback& done) override { + wi_->RecvBufAsync(&opts_, &req_, &resp_, done); + } + + RecvBufRequest req_; + RecvBufResponse resp_; +}; + +} // namespace + +void CollectiveRemoteAccessDistributed::RecvFromPeer( + const string& peer_device, const string& peer_task, bool peer_is_local, + const string& key, Device* to_device, DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, const StatusCallback& done) { + if (peer_is_local) { + CollectiveRemoteAccessLocal::RecvFromPeer( + peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, done); + return; + } + + // State that needs to be threaded through a couple of async calls + // in order to make this function completely non-blocking. + struct State { + DeviceLocality server_locality; + std::unique_ptr<RecvBufCall> call; + }; + State* state = new State; + + // Logic to be executed on the RecvBufAsync callback. 
+ auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr, + to_device_ctx, to_tensor, done](const Status& s) { + // Take ownership of the heap-allocated State so it is freed on every + // exit path of this callback. + std::unique_ptr<State> del_on_exit(state); + if (s.ok()) { + // In this generic implementation the bytes come back in the + // RPC response protobuf rather than via RDMA so we need to copy + // them into the destination tensor here. + RecvBufRespExtra extra; + if (!state->call->resp_.transport_options().UnpackTo(&extra)) { + done(errors::Internal( + "RecvBufResponse transport_options is not a RecvBufRespExtra")); + return; + } + int64 num_bytes = extra.tensor_content().size(); + if (num_bytes != to_tensor->TotalBytes()) { + done(errors::Internal("RecvBufResponse returned ", num_bytes, + " bytes where to_tensor expected ", + to_tensor->TotalBytes())); + return; + } + if (to_device->tensorflow_gpu_device_info()) { + // Move the bytes into a CPU tensor then use tensor-to-tensor copy. + // Use GPU-registered memory for the CPU tensor so the transfer + // goes faster. + Device* cpu_dev = nullptr; + Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev); + if (!status.ok()) { + done(status); + return; + } + AllocatorAttributes cpu_attr; + cpu_attr.set_gpu_compatible(true); + Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr), + to_tensor->dtype(), to_tensor->shape()); + memcpy(DMAHelper::base(cpu_tensor), extra.tensor_content().data(), + num_bytes); + // Then copy it to the GPU. + CopyTensor::ViaDMA("", // edge name (non-existent) + nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev, + to_device, cpu_attr, to_alloc_attr, cpu_tensor, + to_tensor, + [cpu_tensor, done](const Status& s) { + delete cpu_tensor; + // This callback must not block, so execute + // done in another thread. 
+ SchedClosure([s, done] { done(s); }); + }); + return; + } else { + // CPU device + memcpy(DMAHelper::base(to_tensor), extra.tensor_content().data(), + num_bytes); + } + } + if (!s.ok() && errors::IsFailedPrecondition(s)) { + dev_resolver_->ClearTask(peer_task); + } + + done(s); + }; + + // Logic to execute once we have the device locality for the server-side + // device. + auto dev_locality_callback = [this, state, peer_device, peer_task, key, + to_device, to_device_ctx, to_alloc_attr, + to_tensor, client_locality, + recv_buf_callback](const Status& s) { + if (!s.ok()) { + recv_buf_callback(s); + } else { + state->call.reset(new RecvBufCall( + step_id_, peer_device, peer_task, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, state->server_locality, + &cancel_mgr_, worker_cache_)); + state->call->Start(recv_buf_callback); + } + }; + + dev_resolver_->GetLocalityAsync( + peer_device, peer_task, &state->server_locality, dev_locality_callback); +} + +void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) { + CollectiveRemoteAccessLocal::StartAbort(s); + cancel_mgr_.StartCancel(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h new file mode 100644 index 0000000..cfa9110 --- /dev/null +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_ +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +class WorkerCacheInterface; + +// Extend CollectiveRemoteAccessLocal with access to remote peers. +// Buffers held by a remote task are fetched with a RecvBufAsync call on a +// WorkerInterface obtained from worker_cache. +class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { + public: + // None of the pointer arguments are owned by this object. + CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr, + DeviceResolverInterface* dev_resolver, + WorkerCacheInterface* worker_cache, + int64 step_id) + : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id), + worker_cache_(worker_cache) {} + + ~CollectiveRemoteAccessDistributed() override {} + + // If peer_is_local, defers to CollectiveRemoteAccessLocal::RecvFromPeer. + // Otherwise looks up the locality of peer_device, issues a RecvBufAsync + // call to peer_task and copies the returned bytes into *to_tensor. + // done is invoked with the final status. + void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + const StatusCallback& done) override; + + // Aborts via the base class, then cancels every call registered with + // cancel_mgr_. + void StartAbort(const Status& s) override; + + protected: + WorkerCacheInterface* worker_cache_; // Not owned + CancellationManager cancel_mgr_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_ diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc new file mode 100644 index 0000000..a552f81 --- /dev/null +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc @@ -0,0 +1,356 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h" + +#include "google/protobuf/any.pb.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/test_utils.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/transport_options.pb.h" +#include "tensorflow/core/protobuf/worker.pb.h" +#include "tensorflow/core/util/device_name_utils.h" + +// The only interesting method on CollectiveRemoteAccessDistributed +// that differs from CollectiveRemoteAccessLocal is RecvFromPeer which +// issues a RecvBufAsync call against a WorkerInterface. That's all +// that's tested here. Note that RecvFromPeer can do a +// DeviceResolverInterface::GetLocalityAsync call in preparation +// for the RecvBufAsync. 
+ +namespace tensorflow { +namespace { + +// Returns a newly allocated fake Device with the given type and name. Its +// allocator is null and Sync() always succeeds. +static Device* NewDevice(const string& type, const string& name) { + class FakeDevice : public Device { + public: + explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} + Status Sync() override { return Status::OK(); } + Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } + }; + DeviceAttributes attr; + attr.set_name(name); + attr.set_device_type(type); + attr.mutable_locality()->set_numa_node(3); // a non-default value + return new FakeDevice(attr); +} + +static const int64 kStepId = 123; + +// In-process stand-in for a remote worker: answers GetStatusAsync from its +// DeviceMgr and RecvBufAsync from a local BufRendezvous. +class FakeWorker : public TestWorkerInterface { + public: + FakeWorker(const string& name, DeviceMgr* dev_mgr, + DeviceResolverDistributed* dres) + : name_(name), + device_mgr_(dev_mgr), + device_resolver_(dres), + buf_rendezvous_(kStepId) {} + + // Direct access to a BufRendezvous that holds whatever the remote + // worker is supposed to have. + BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; } + + void GetStatusAsync(const GetStatusRequest* request, + GetStatusResponse* response, + StatusCallback done) override { + std::vector<DeviceAttributes> dev_attr; + device_mgr_->ListDeviceAttributes(&dev_attr); + for (const auto& da : dev_attr) { + *response->add_device_attributes() = da; + } + done(Status::OK()); + } + + void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) override { + opts->SetCancelCallback([this]() { + // Within this test the call is satisfied by a process-local + // BufRendezvous table. In real application the BufRendezvous + // would be on the other side of a network hop, so call + // BufRendezvous::StartAbort() from a separate thread to be + // more consistent with that situation and avoid mutex deadlock. 
+ SchedClosure([this]() { + Env::Default()->SleepForMicroseconds(100); + buf_rendezvous_.StartAbort(errors::Internal("Cancelled")); + }); + }); + buf_rendezvous_.ConsumeBuf( + request->buf_rendezvous_key(), + [opts, response, done](const Status& s, BufRendezvous::Hook* h) { + if (s.ok()) { + opts->ClearCancelCallback(); + // Since this is not really RDMA into pre-allocated memory send the + // bytes in the response. + RecvBufRespExtra extra; + int64 num_bytes = h->prod_value->TotalBytes(); + extra.set_tensor_content(string( + reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)), + num_bytes)); + response->mutable_transport_options()->PackFrom(extra); + } + done(s); + // Release the producer's hook only after the consumer has been told. + if (h) BufRendezvous::DoneWithHook(h); + }); + } + + private: + string name_; + DeviceMgr* device_mgr_; + DeviceResolverDistributed* device_resolver_; + BufRendezvous buf_rendezvous_; +}; + +class FakeCache : public TestWorkerCache { + public: + // Override the Locality methods to actually pass through to the + // worker. 
+ bool GetDeviceLocalityNonBlocking(const string& device, + DeviceLocality* locality) override { + return false; + } + + // Finds the worker whose task name prefixes `device`, asks it for its + // device attributes via GetStatus, and reports the matching locality. + void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, + StatusCallback done) override { + string task_name; + string dev_part; + if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) { + done(errors::Internal("failed to parse device name")); + return; + } + auto it = workers_.find(task_name); + if (it == workers_.end()) { + done(errors::Internal("failed to find worker ", task_name)); + return; + } + WorkerInterface* wi = it->second; + GetStatusRequest req; + GetStatusResponse resp; + Status status = wi->GetStatus(&req, &resp); + if (!status.ok()) { + done(status); + return; + } + for (const auto& da : resp.device_attributes()) { + if (da.name() == device) { + *locality = da.locality(); + done(Status::OK()); + return; + } + } + done(errors::Internal("device not found: ", device)); + } +}; + +class CollRMADistTest : public ::testing::Test { + protected: + CollRMADistTest() {} + + ~CollRMADistTest() override { + for (DeviceMgr* dm : device_mgrs_) { + delete dm; + } + for (const auto& it : dev_resolvers_) { + delete it.second; + } + for (FakeWorker* w : workers_) { + delete w; + } + } + + void SetUp() override { + const int num_workers = 2; + const int num_devices = 1; + string device_type = "CPU"; + ConfigProto config; + string dev0_worker_name; + for (int w = 0; w < num_workers; ++w) { + string name = strings::StrCat("/job:worker/replica:0/task:", w); + if (w == 0) { + dev0_worker_name = name; + // TODO(tucker): Change to use config when available. + // config.set_collective_group_leader(name); + } + DefineWorker(config, name, device_type, num_devices); + } + // All tests simulate requests from worker 0 to worker 1. 
+ rma_.reset(new CollectiveRemoteAccessDistributed( + device_mgrs_[0], dev_resolvers_[dev0_worker_name], &wc_, kStepId)); + + const int kNumElts = 8; + expected_value_ = Tensor(DT_FLOAT, {kNumElts}); + to_tensor_ = Tensor(DT_FLOAT, {kNumElts}); + auto exp_alias = expected_value_.flat<float>(); + auto to_alias = to_tensor_.flat<float>(); + for (int i = 0; i < kNumElts; ++i) { + exp_alias(i) = i; + to_alias(i) = -1; + } + } + + // Creates num_devices fake devices for worker_name plus the DeviceMgr, + // DeviceResolverDistributed and FakeWorker that go with them, and + // registers the worker with wc_. + void DefineWorker(const ConfigProto& config, const string& worker_name, + const string& device_type, int num_devices) { + std::vector<Device*> devices; + for (int i = 0; i < num_devices; ++i) { + devices.push_back(NewDevice( + device_type, + strings::StrCat(worker_name, "/device:", device_type, ":", i))); + } + DeviceMgr* dev_mgr = new DeviceMgr(devices); + device_mgrs_.push_back(dev_mgr); + std::vector<string>* dv = &dev_by_task_[worker_name]; + for (auto d : devices) { + dv->push_back(d->name()); + } + DeviceResolverDistributed* dev_res = + new DeviceResolverDistributed(dev_mgr, &wc_, worker_name); + dev_resolvers_[worker_name] = dev_res; + FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res); + workers_.push_back(fw); + wc_.AddWorker(worker_name, fw); + } + + // Checks that to_tensor_ now holds the same floats as expected_value_. + void ValidateResultTensor() { + ASSERT_EQ(expected_value_.NumElements(), to_tensor_.NumElements()); + for (int i = 0; i < to_tensor_.NumElements(); ++i) { + EXPECT_FLOAT_EQ(expected_value_.flat<float>()(i), + to_tensor_.flat<float>()(i)); + } + } + + FakeCache wc_; + CancellationManager cm_; + std::vector<DeviceMgr*> device_mgrs_; + std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_; + std::unordered_map<string, std::vector<string>> dev_by_task_; + std::vector<FakeWorker*> workers_; + std::unique_ptr<CollectiveRemoteAccessDistributed> rma_; + mutex mu_; + int num_done_ GUARDED_BY(mu_) = 0; + condition_variable done_; + Tensor expected_value_; + Tensor to_tensor_; + CallOptions opts_; + DeviceLocality device_locality_; + AllocatorAttributes alloc_attr_; +}; + +TEST_F(CollRMADistTest, ProdFirstOK) { + Notification consumer_note; + Notification producer_note; + Status consumer_status; + Status producer_status; + FakeWorker* wi = 
workers_[1]; + const string kBufKey = "fake_buf_key"; + wi->buf_rendezvous()->ProvideBuf( + kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_, + AllocatorAttributes(), + [this, &producer_note, &producer_status](const Status& s) { + producer_status.Update(s); + producer_note.Notify(); + }); + Status status; + Device* dst_device = nullptr; + string dev_name = "CPU:0"; + TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device)); + DeviceContext* to_device_ctx = nullptr; + rma_->RecvFromPeer( + "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev + "/job:worker/replica:0/task:1", // peer_task + false, // peer_is_local + kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, + device_locality_, + [this, &consumer_status, &consumer_note](const Status& s) { + consumer_status = s; + consumer_note.Notify(); + }); + consumer_note.WaitForNotification(); + TF_EXPECT_OK(consumer_status); + producer_note.WaitForNotification(); + TF_EXPECT_OK(producer_status); + ValidateResultTensor(); +} + +TEST_F(CollRMADistTest, ConsFirstOK) { + Notification consumer_note; + Notification producer_note; + Status consumer_status; + Status producer_status; + FakeWorker* wi = workers_[1]; + const string kBufKey = "fake_buf_key"; + Status status; + Device* dst_device = nullptr; + string dev_name = "CPU:0"; + TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device)); + DeviceContext* to_device_ctx = nullptr; + rma_->RecvFromPeer( + "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev + "/job:worker/replica:0/task:1", // peer_task + false, // peer_is_local + kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, + device_locality_, + [this, &consumer_status, &consumer_note](const Status& s) { + consumer_status = s; + consumer_note.Notify(); + }); + wi->buf_rendezvous()->ProvideBuf( + kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_, + AllocatorAttributes(), + [this, &producer_note, &producer_status](const Status& 
s) { + producer_status.Update(s); + producer_note.Notify(); + }); + consumer_note.WaitForNotification(); + TF_EXPECT_OK(consumer_status); + producer_note.WaitForNotification(); + TF_EXPECT_OK(producer_status); + ValidateResultTensor(); +} + +TEST_F(CollRMADistTest, ConsFirstAbort) { + Notification consumer_note; + Status consumer_status; + const string kBufKey = "fake_buf_key"; + Status status; + Device* dst_device = nullptr; + string dev_name = "CPU:0"; + TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device)); + DeviceContext* to_device_ctx = nullptr; + rma_->RecvFromPeer( + "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev + "/job:worker/replica:0/task:1", // peer_task + false, // peer_is_local + kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, + device_locality_, + [this, &consumer_status, &consumer_note](const Status& s) { + consumer_status = s; + consumer_note.Notify(); + }); + rma_->StartAbort(errors::Internal("Deliberate Failure")); + consumer_note.WaitForNotification(); + EXPECT_EQ(consumer_status.error_message(), "Cancelled"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index c2719f5..40028ee 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -171,6 +171,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:graph_mgr", "//tensorflow/core/distributed_runtime:recent_request_ids", diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 5b7b74c..1acf1fb 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ 
-54,6 +54,7 @@ class GrpcRemoteWorker : public WorkerInterface { cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)), cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)), recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)), + recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)), logging_(Method(GrpcWorkerMethod::kLogging)), tracing_(Method(GrpcWorkerMethod::kTracing)), completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)), @@ -118,6 +119,11 @@ class GrpcRemoteWorker : public WorkerInterface { IssueRequest(request, response, cleanupall_, std::move(done)); } + void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) override { + IssueRequest(request, response, recvbuf_, std::move(done), call_opts); + } + void CompleteGroupAsync(CallOptions* call_opts, const CompleteGroupRequest* request, CompleteGroupResponse* response, @@ -239,6 +245,7 @@ class GrpcRemoteWorker : public WorkerInterface { const ::grpc::string cleanupgraph_; const ::grpc::string cleanupall_; const ::grpc::string recvtensor_; + const ::grpc::string recvbuf_; const ::grpc::string logging_; const ::grpc::string tracing_; const ::grpc::string completegroup_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 26fad1f..4383e41 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -20,6 +20,7 @@ limitations under the License. #include "grpc++/alarm.h" #include "grpc++/server_builder.h" +#include "tensorflow/core/common_runtime/buf_rendezvous.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -37,10 +38,12 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_session.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/protobuf/transport_options.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" namespace tensorflow { @@ -159,6 +162,9 @@ class GrpcWorkerService : public AsyncServiceInterface { for (int i = 0; i < 1000; ++i) { EnqueueRecvTensorRequestRaw(); } + for (int i = 0; i < 500; ++i) { + ENQUEUE_REQUEST(RecvBuf, true); + } for (int i = 0; i < 100; ++i) { ENQUEUE_REQUEST(RunGraph, true); } @@ -170,9 +176,9 @@ class GrpcWorkerService : public AsyncServiceInterface { ENQUEUE_REQUEST(Tracing, false); for (int i = 0; i < 10; ++i) { - ENQUEUE_REQUEST(CompleteGroup, false); - ENQUEUE_REQUEST(CompleteInstance, false); - ENQUEUE_REQUEST(GetStepSequence, false); + ENQUEUE_REQUEST(CompleteGroup, true); + ENQUEUE_REQUEST(CompleteInstance, true); + ENQUEUE_REQUEST(GetStepSequence, true); } void* tag; @@ -322,6 +328,20 @@ class GrpcWorkerService : public AsyncServiceInterface { ENQUEUE_REQUEST(Tracing, false); } + void RecvBufHandler(WorkerCall* call) { + Schedule([this, call]() { + CallOptions* call_opts = new CallOptions; + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + worker_->RecvBufAsync(call_opts, &call->request, &call->response, + [call, call_opts](const Status& s) { + call->ClearCancelCallback(); + delete call_opts; + call->SendResponse(ToGrpcStatus(s)); + }); + }); + ENQUEUE_REQUEST(RecvBuf, true); + } + void CompleteGroupHandler( WorkerCall* call) { Schedule([this, call]() { @@ -334,7 +354,7 @@ class GrpcWorkerService : public AsyncServiceInterface { call->SendResponse(ToGrpcStatus(s)); }); }); - 
ENQUEUE_REQUEST(CompleteGroup, false); + ENQUEUE_REQUEST(CompleteGroup, true); } void CompleteInstanceHandler( @@ -360,7 +380,7 @@ class GrpcWorkerService : public AsyncServiceInterface { &call->request, &call->response, [call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); }); }); - ENQUEUE_REQUEST(GetStepSequence, false); + ENQUEUE_REQUEST(GetStepSequence, true); } #undef ENQUEUE_REQUEST @@ -485,6 +505,74 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, }); } +void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) { + // This is a generic, low performance implementation appropriate for grpc. + CollectiveExecutor::Handle ce_handle( + env_->collective_executor_mgr->FindOrCreate(request->step_id()), true); + CollectiveRemoteAccess* rma = ce_handle.get()->remote_access(); + rma->buf_rendezvous()->ConsumeBuf( + request->buf_rendezvous_key(), + [this, opts, request, response, done](const Status& status, + BufRendezvous::Hook* hook) { + Status s = status; + if (s.ok()) { + if (!DMAHelper::CanUseDMA(hook->prod_value)) { + s = errors::Internal("Tensor value for key ", + request->buf_rendezvous_key(), + " is not of a type supported by RecvBuf"); + } + } + if (s.ok()) { + // The RPC source tensor needs to be in CPU RAM. If not already + // there make a copy using memory appropriate to the purpose. 
+ const size_t num_bytes = hook->prod_value->TotalBytes(); + const bool on_host = + hook->prod_dev->attributes().device_type() == "CPU" || + hook->prod_attr.on_host(); + if ((!on_host) && (num_bytes > 0)) { + Device* cpu_dev = nullptr; + s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev); + if (s.ok()) { + AllocatorAttributes cpu_attr; + cpu_attr.set_gpu_compatible(true); + cpu_attr.set_nic_compatible(true); + Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr), + hook->prod_value->dtype(), + hook->prod_value->shape()); + hook->prod_ctx->CopyDeviceTensorToCPU( + hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor, + [this, num_bytes, response, done, hook, + cpu_tensor](const Status& s) { + if (s.ok()) { + RecvBufRespExtra extra; + extra.set_tensor_content(reinterpret_cast( + DMAHelper::base(cpu_tensor)), + num_bytes); + response->mutable_transport_options()->PackFrom(extra); + } + response->set_send_start_micros(env_->env->NowMicros()); + done(s); + BufRendezvous::DoneWithHook(hook); + delete cpu_tensor; + }); + return; + } + } else { + // Tensor is on CPU. 
+ RecvBufRespExtra extra; + extra.set_tensor_content(reinterpret_cast( + DMAHelper::base(hook->prod_value)), + num_bytes); + response->mutable_transport_options()->PackFrom(extra); + } + } + response->set_send_start_micros(env_->env->NowMicros()); + done(s); + BufRendezvous::DoneWithHook(hook); + }); +} + void GrpcWorker::LoggingAsync(const LoggingRequest* request, LoggingResponse* response, StatusCallback done) { auto env = this->env(); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h index fbddbda..c0ed088 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -43,6 +43,9 @@ class GrpcWorker : public Worker { virtual void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, StatusCallback done); + virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done); + WorkerEnv* env(); private: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc index a91cc06..38cc2b8 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc @@ -46,6 +46,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) { return "/tensorflow.WorkerService/CleanupAll"; case GrpcWorkerMethod::kRecvTensor: return "/tensorflow.WorkerService/RecvTensor"; + case GrpcWorkerMethod::kRecvBuf: + return "/tensorflow.WorkerService/RecvBuf"; case GrpcWorkerMethod::kLogging: return "/tensorflow.WorkerService/Logging"; case GrpcWorkerMethod::kTracing: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index c5104c6..da27083 100644 --- 
a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -81,6 +81,7 @@ enum class GrpcWorkerMethod { kCleanupGraph, kCleanupAll, kRecvTensor, + kRecvBuf, kLogging, kTracing, kCompleteGroup, diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h index 0ed0782..48d8384 100644 --- a/tensorflow/core/distributed_runtime/test_utils.h +++ b/tensorflow/core/distributed_runtime/test_utils.h @@ -93,6 +93,11 @@ class TestWorkerInterface : public WorkerInterface { done(errors::Unimplemented("RunGraphAsync")); } + void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) override { + done(errors::Unimplemented("RecvBufAsync")); + } + void CompleteGroupAsync(CallOptions* opts, const CompleteGroupRequest* request, CompleteGroupResponse* response, diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index d682ac8..4e6500f 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -337,6 +337,15 @@ void Worker::TracingAsync(const TracingRequest* request, done(errors::Unimplemented("Tracing")); } +void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) { + // The base Worker class does not implement RecvBufAsync because + // it is not currently used for worker-to-worker communication. Use a + // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`) + // instead. 
+ done(errors::Unimplemented("Worker::RecvBufAsync()")); +} + void Worker::CompleteGroupAsync(CallOptions* opts, const CompleteGroupRequest* request, CompleteGroupResponse* response, diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index b5a9ada..91eb27c 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -90,6 +90,9 @@ class Worker : public WorkerInterface { void TracingAsync(const TracingRequest* request, TracingResponse* response, StatusCallback done) override; + void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) override; + void CompleteGroupAsync(CallOptions* opts, const CompleteGroupRequest* request, CompleteGroupResponse* response, diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index bad31d2..a50ac3b 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -112,6 +112,9 @@ class WorkerInterface { virtual void TracingAsync(const TracingRequest* request, TracingResponse* response, StatusCallback done) = 0; + virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) = 0; + virtual void CompleteGroupAsync(CallOptions* opts, const CompleteGroupRequest* request, CompleteGroupResponse* response, diff --git a/tensorflow/core/protobuf/transport_options.proto b/tensorflow/core/protobuf/transport_options.proto new file mode 100644 index 0000000..d7b1bdd --- /dev/null +++ b/tensorflow/core/protobuf/transport_options.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + +package tensorflow; + +// Extra data needed on a non-RDMA RecvBufResponse. 
+message RecvBufRespExtra { + bytes tensor_content = 1; +}; diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 602f6a1..f7816e9 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -418,6 +418,60 @@ message TracingResponse { //////////////////////////////////////////////////////////////////////////////// // +// Raw data transfers in support of Collective Ops. +// These methods are experimental and subject to change. +// +// The intention is to allow collectives to take advantage of the most +// efficient methods available on a platform, e.g. RDMA, and not be +// constrained to use the RPC system in use by other methods. +// +//////////////////////////////////////////////////////////////////////////////// + +message RecvBufRequest { + // Use of the fields below may vary by implementation. For example + // the buf_ptr and num_bytes may be set only for local operations and + // not sent on the wire, or only sent on the wire in one direction. + + // Used at server side to find the correct BufRendezvous. + int64 step_id = 1; + + // Arbitrary string identifying a BufRendezvous entry. + string buf_rendezvous_key = 2; + + // Size of value expected, must agree with BufRendezvous entry. + int64 num_bytes = 3; + + // When RDMA is in use, address of destination field on client. + fixed64 buf_ptr = 4; + + // Optional information on client-side device locality. + DeviceLocality client_locality = 5; + + // Optional information on server-side device locality. + DeviceLocality server_locality = 6; + + // Optional, implementation-specific data. + google.protobuf.Any transport_options = 7; + // Optional, for annotating the timeline. + string src_device = 8; + string dst_device = 9; +} + +message RecvBufResponse { + // Use of the fields below may vary by implementation. Comments give + // intended use. + + fixed64 buf_ptr = 1; // Address of source field on server. 
+ int64 num_bytes = 2; // Byte length of buf_ptr field, if set. + bool is_dead = 3; // True if value is 'dead' like a tensor. + // Optional, implementation-specific data. + google.protobuf.Any transport_options = 4; + // Optional, for timeline. + int64 send_start_micros = 5; +} + +//////////////////////////////////////////////////////////////////////////////// +// // Collective Op dynamic group resolution messages. // //////////////////////////////////////////////////////////////////////////////// diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto index 01c76c0..e0c27f3 100644 --- a/tensorflow/core/protobuf/worker_service.proto +++ b/tensorflow/core/protobuf/worker_service.proto @@ -74,6 +74,10 @@ service WorkerService { rpc Tracing(TracingRequest) returns (TracingResponse); // See worker.proto for details. + rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) { + } + + // See worker.proto for details. rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse); // See worker.proto for details. -- 2.7.4