"protobuf/named_tensor.proto",
"protobuf/saved_model.proto",
"protobuf/tensorflow_server.proto",
+ "protobuf/transport_options.proto",
"util/test_log.proto",
]
)
cc_library(
+ name = "collective_rma_distributed",  # Library for collective_rma_distributed.{h,cc}.
+ srcs = ["collective_rma_distributed.cc"],
+ hdrs = ["collective_rma_distributed.h"],
+ deps = [
+ ":worker_cache",  # WorkerCacheInterface / WorkerInterface used to reach the peer.
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal", # protobuf::Any
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:worker_proto_cc",  # RecvBufRequest / RecvBufResponse.
+ ],
+)
+
+tf_cc_test(
+ name = "collective_rma_distributed_test",  # Unit test for :collective_rma_distributed.
+ size = "small",
+ srcs = ["collective_rma_distributed_test.cc"],
+ deps = [
+ ":collective_rma_distributed",  # Library under test.
+ ":device_resolver_distributed",
+ ":test_utils",  # TestWorkerInterface / TestWorkerCache bases for the fakes.
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
name = "collective_param_resolver_distributed",
srcs = ["collective_param_resolver_distributed.cc"],
hdrs = ["collective_param_resolver_distributed.h"],
const GroupRecCallback& done) {
VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
<< " dev: " << device << " is_leader=" << (group_leader_.empty());
- VLOG(0) << "cp: " << cp->ToString();
if (group_leader_.empty()) {
// This is the group leader, so resolution is local.
return CompleteGroupLocal(device, cp, done);
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/platform/protobuf_internal.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Base class for one WorkerInterface call that can be cancelled from the
+// client side.  Start() registers a callback with a CancellationManager so
+// that cancelling the manager invokes StartCancel() on this call's
+// CallOptions.
+//
+// TODO(tucker): Maybe unify this with CancellableCall in
+// collective_param_resolver_distributed.cc.
+class CancellableCall {
+ public:
+  // Acquires a WorkerInterface for `remote_worker` from `wc`; it is handed
+  // back to `wc` in the destructor.
+  CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
+                  WorkerCacheInterface* wc)
+      : cancel_mgr_(cancel_mgr),
+        remote_worker_(remote_worker),
+        wc_(wc),
+        wi_(wc_->CreateWorker(remote_worker_)) {}
+
+  virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
+
+  // Subclasses issue the actual RPC here, passing `done` as its completion
+  // callback.
+  virtual void IssueCall(const StatusCallback& done) = 0;
+
+  // Issues the call unless cancellation has already been requested, in which
+  // case `done` is invoked immediately with a Cancelled status.
+  void Start(const StatusCallback& done) {
+    const CancellationToken token = cancel_mgr_->get_cancellation_token();
+    const bool registered =
+        cancel_mgr_->RegisterCallback(token, [this]() { opts_.StartCancel(); });
+    if (!registered) {
+      done(errors::Cancelled("RPC Request was cancelled"));
+      return;
+    }
+    IssueCall([this, token, done](const Status& status) {
+      // Deregister before reporting completion: `done` may destroy this
+      // object.
+      cancel_mgr_->DeregisterCallback(token);
+      done(status);
+    });
+  }
+
+ protected:
+  mutable mutex mu_;
+  CancellationManager* cancel_mgr_;  // Not owned
+  const string remote_worker_;
+  WorkerCacheInterface* wc_;  // Not owned
+  WorkerInterface* wi_;       // Owned by wc_, must be released.
+  CallOptions opts_;
+};
+
+// CancellableCall that performs WorkerInterface::RecvBufAsync against the
+// worker for `peer_task`.  The request is fully populated at construction;
+// the response is available in resp_ once the call completes.
+class RecvBufCall : public CancellableCall {
+ public:
+  RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
+              const string& key, Device* to_device,
+              DeviceContext* to_device_ctx,
+              const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+              const DeviceLocality& client_locality,
+              const DeviceLocality& server_locality,
+              CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
+      : CancellableCall(cancel_mgr, peer_task, wc) {
+    // Which rendezvous entry to read.
+    req_.set_step_id(step_id);
+    req_.set_buf_rendezvous_key(key);
+    // The two endpoints of the transfer.
+    req_.set_src_device(peer_device);
+    req_.set_dst_device(to_device->name());
+    req_.mutable_client_locality()->CopyFrom(client_locality);
+    req_.mutable_server_locality()->CopyFrom(server_locality);
+    // Description of the destination buffer.
+    req_.set_num_bytes(to_tensor->TotalBytes());
+    req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor)));
+  }
+
+  ~RecvBufCall() override {}
+
+  // Sends req_ to the remote worker; `done` runs when the RPC completes.
+  void IssueCall(const StatusCallback& done) override {
+    wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
+  }
+
+  RecvBufRequest req_;
+  RecvBufResponse resp_;
+};
+
+} // namespace
+
+// Receives the buffer registered under `key` on `peer_device` into
+// `*to_tensor`.  A local peer is delegated to CollectiveRemoteAccessLocal.
+// For a remote peer the server-side device locality is resolved first, then
+// a cancellable RecvBufAsync call is issued to `peer_task`; `done` receives
+// the final status.
+void CollectiveRemoteAccessDistributed::RecvFromPeer(
+    const string& peer_device, const string& peer_task, bool peer_is_local,
+    const string& key, Device* to_device, DeviceContext* to_device_ctx,
+    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+    const DeviceLocality& client_locality, const StatusCallback& done) {
+  if (peer_is_local) {
+    CollectiveRemoteAccessLocal::RecvFromPeer(
+        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
+        to_alloc_attr, to_tensor, client_locality, done);
+    return;
+  }
+
+  // Heap-allocated state shared by the chained async callbacks below so that
+  // this function never blocks.  It is freed by recv_buf_callback.
+  struct State {
+    DeviceLocality server_locality;
+    std::unique_ptr<RecvBufCall> call;
+  };
+  State* state = new State;
+
+  // Runs when the RecvBufAsync call completes, or directly with an error
+  // status if locality resolution failed.
+  auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
+                            to_device_ctx, to_tensor, done](const Status& s) {
+    std::unique_ptr<State> state_deleter(state);
+    if (!s.ok()) {
+      if (errors::IsFailedPrecondition(s)) {
+        dev_resolver_->ClearTask(peer_task);
+      }
+      done(s);
+      return;
+    }
+
+    // In this generic implementation the bytes come back in the RPC
+    // response protobuf rather than via RDMA, so they must be copied into
+    // the destination tensor here.
+    RecvBufRespExtra extra;
+    state->call->resp_.transport_options().UnpackTo(&extra);
+    const int64 num_bytes = extra.tensor_content().size();
+    if (num_bytes != to_tensor->TotalBytes()) {
+      done(errors::Internal("RecvBufResponse returned ", num_bytes,
+                            " bytes where to_tensor expected ",
+                            to_tensor->TotalBytes()));
+      return;
+    }
+
+    if (!to_device->tensorflow_gpu_device_info()) {
+      // CPU device: copy straight into the destination tensor.
+      memcpy(DMAHelper::base(to_tensor), extra.tensor_content().data(),
+             num_bytes);
+      done(s);
+      return;
+    }
+
+    // GPU device: stage the bytes in a CPU tensor, then use a
+    // tensor-to-tensor copy.  The staging tensor is requested with
+    // gpu_compatible set so the transfer goes faster.
+    Device* cpu_dev = nullptr;
+    const Status lookup_status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
+    if (!lookup_status.ok()) {
+      done(lookup_status);
+      return;
+    }
+    AllocatorAttributes cpu_attr;
+    cpu_attr.set_gpu_compatible(true);
+    Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
+                                    to_tensor->dtype(), to_tensor->shape());
+    memcpy(DMAHelper::base(cpu_tensor), extra.tensor_content().data(),
+           num_bytes);
+    CopyTensor::ViaDMA("",  // edge name (non-existent)
+                       nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
+                       to_device, cpu_attr, to_alloc_attr, cpu_tensor,
+                       to_tensor,
+                       [cpu_tensor, done](const Status& copy_status) {
+                         delete cpu_tensor;
+                         // This callback must not block, so execute
+                         // done in another thread.
+                         SchedClosure(
+                             [copy_status, done] { done(copy_status); });
+                       });
+  };
+
+  // Runs once the locality of the server-side device is known (or could not
+  // be determined).
+  auto dev_locality_callback = [this, state, peer_device, peer_task, key,
+                                to_device, to_device_ctx, to_alloc_attr,
+                                to_tensor, client_locality,
+                                recv_buf_callback](const Status& s) {
+    if (!s.ok()) {
+      // Let recv_buf_callback free `state` and report the error.
+      recv_buf_callback(s);
+      return;
+    }
+    state->call.reset(new RecvBufCall(
+        step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
+        to_alloc_attr, to_tensor, client_locality, state->server_locality,
+        &cancel_mgr_, worker_cache_));
+    state->call->Start(recv_buf_callback);
+  };
+
+  dev_resolver_->GetLocalityAsync(
+      peer_device, peer_task, &state->server_locality, dev_locality_callback);
+}
+
+void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
+ CollectiveRemoteAccessLocal::StartAbort(s);  // Let the local base class abort with status `s` first.
+ cancel_mgr_.StartCancel();  // Then cancel RecvBuf calls that RecvFromPeer registered with cancel_mgr_.
+}
+
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+class WorkerCacheInterface;
+
+// Extend CollectiveRemoteAccessLocal with access to remote peers, reached through WorkerInterface objects obtained from a WorkerCacheInterface.
+class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
+ public:
+ CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ WorkerCacheInterface* worker_cache,
+ int64 step_id)
+ : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+ worker_cache_(worker_cache) {}
+
+ ~CollectiveRemoteAccessDistributed() override {}
+
+ void RecvFromPeer(const string& peer_device, const string& peer_task,
+ bool peer_is_local, const string& key, Device* to_device,
+ DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override;  // Local peers go to the base class; remote peers use RecvBufAsync.
+
+ void StartAbort(const Status& s) override;  // Also cancels calls registered with cancel_mgr_.
+
+ protected:
+ WorkerCacheInterface* worker_cache_; // Not owned
+ CancellationManager cancel_mgr_;  // Passed to each remote RecvBuf call; cancelled by StartAbort().
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/test_utils.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+// The only interesting method on CollectiveRemoteAccessDistributed
+// that's not on CollectiveRemoteAccessLocal is RecvFromPeer, which
+// issues a RecvBufAsync call against a WorkerInterface. That's all
+// that's tested here. Note that RecvFromPeer can do a
+// DeviceResolverInterface::GetLocalityAsync call in preparation
+// for the RecvBufAsync.
+
+namespace tensorflow {
+namespace {
+
+// Returns a newly allocated fake Device with the given type and name.  Its
+// locality carries a non-default NUMA node so tests can tell a looked-up
+// locality apart from a default-constructed one.
+static Device* NewDevice(const string& type, const string& name) {
+  // Minimal Device: Sync() always succeeds and no allocator is provided.
+  class FakeDevice : public Device {
+   public:
+    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
+    Status Sync() override { return Status::OK(); }
+    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+  };
+
+  DeviceAttributes attributes;
+  attributes.set_device_type(type);
+  attributes.set_name(name);
+  attributes.mutable_locality()->set_numa_node(3);  // a non-default value
+  return new FakeDevice(attributes);
+}
+
+// Step id shared by the fake workers' BufRendezvous and the object under
+// test.  Declared const: it is named and used as a constant.
+static const int64 kStepId = 123;
+
+// WorkerInterface fake for the "remote" side of the test.  GetStatusAsync
+// reports the devices of its DeviceMgr, and RecvBufAsync is served from a
+// process-local BufRendezvous.
+class FakeWorker : public TestWorkerInterface {
+ public:
+  FakeWorker(const string& name, DeviceMgr* dev_mgr,
+             DeviceResolverDistributed* dres)
+      : name_(name),
+        device_mgr_(dev_mgr),
+        device_resolver_(dres),
+        buf_rendezvous_(kStepId) {}
+
+  // Direct access to a BufRendezvous that holds whatever the remote
+  // worker is supposed to have.
+  BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
+
+  // Fills `response` with the attributes of every device in device_mgr_.
+  void GetStatusAsync(const GetStatusRequest* request,
+                      GetStatusResponse* response,
+                      StatusCallback done) override {
+    std::vector<DeviceAttributes> attrs;
+    device_mgr_->ListDeviceAttributes(&attrs);
+    for (const DeviceAttributes& attr : attrs) {
+      response->add_device_attributes()->CopyFrom(attr);
+    }
+    done(Status::OK());
+  }
+
+  // Consumes the buffer named by the request and returns its bytes packed
+  // into response->transport_options as a RecvBufRespExtra.
+  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                    RecvBufResponse* response, StatusCallback done) override {
+    opts->SetCancelCallback([this]() {
+      // Within this test the call is satisfied by a process-local
+      // BufRendezvous table.  In real application the BufRendezvous
+      // would be on the other side of a network hop, so call
+      // BufRendezvous::StartAbort() from a separate thread to be
+      // more consistent with that situation and avoid mutex deadlock.
+      SchedClosure([this]() {
+        Env::Default()->SleepForMicroseconds(100);
+        buf_rendezvous_.StartAbort(errors::Internal("Cancelled"));
+      });
+    });
+    buf_rendezvous_.ConsumeBuf(
+        request->buf_rendezvous_key(),
+        [this, opts, request, response, done](const Status& s,
+                                              BufRendezvous::Hook* h) {
+          if (s.ok()) {
+            opts->ClearCancelCallback();
+            // Since this is not really RDMA into pre-allocated memory send
+            // the bytes in the response.
+            const int64 num_bytes = h->prod_value->TotalBytes();
+            const char* src =
+                reinterpret_cast<const char*>(DMAHelper::base(h->prod_value));
+            RecvBufRespExtra extra;
+            extra.set_tensor_content(src, num_bytes);
+            response->mutable_transport_options()->PackFrom(extra);
+          }
+          done(s);
+          // No hook is supplied on some error paths.
+          if (h != nullptr) {
+            BufRendezvous::DoneWithHook(h);
+          }
+        });
+  }
+
+ private:
+  string name_;
+  DeviceMgr* device_mgr_;
+  DeviceResolverDistributed* device_resolver_;
+  BufRendezvous buf_rendezvous_;
+};
+
+// TestWorkerCache whose locality lookups go through the registered worker's
+// GetStatus call rather than being answered from a cache.
+class FakeCache : public TestWorkerCache {
+ public:
+  // Override the Locality methods to actually pass through to the
+  // worker.
+
+  // Always reports a miss so callers fall back to the async lookup.
+  bool GetDeviceLocalityNonBlocking(const string& device,
+                                    DeviceLocality* locality) override {
+    return false;
+  }
+
+  // Finds the worker for the task named in `device`, asks it for its device
+  // attributes, and copies the locality of the matching device into
+  // `*locality`.  `done` is invoked with an Internal error if the name does
+  // not parse, the worker is unknown, or the device is not reported.
+  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+                              StatusCallback done) override {
+    string task_name;
+    string dev_part;
+    if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
+      done(errors::Internal("failed to parse device name"));
+      return;
+    }
+    auto it = workers_.find(task_name);
+    if (it == workers_.end()) {
+      done(errors::Internal("failed to find worker ", task_name));
+      return;
+    }
+    WorkerInterface* wi = it->second;
+    GetStatusRequest req;
+    GetStatusResponse resp;
+    Status status = wi->GetStatus(&req, &resp);
+    if (!status.ok()) {
+      done(status);
+      return;
+    }
+    for (const auto& dev_attr : resp.device_attributes()) {
+      if (dev_attr.name() == device) {
+        *locality = dev_attr.locality();
+        done(Status::OK());
+        return;
+      }
+    }
+    done(errors::Internal("device not found: ", device));
+  }
+};
+
+// Fixture that defines two fake workers (tasks 0 and 1), each with one fake
+// CPU device, all registered in a FakeCache.  rma_ is the object under test
+// and acts on behalf of worker 0.
+class CollRMADistTest : public ::testing::Test {
+ protected:
+  // num_done_ is initialized explicitly so it never holds an indeterminate
+  // value.
+  CollRMADistTest() : num_done_(0) {}
+
+  ~CollRMADistTest() override {
+    for (DeviceMgr* dm : device_mgrs_) {
+      delete dm;
+    }
+    // Iterate by reference: there is no need to copy each map entry just to
+    // delete the resolver it points to.
+    for (const auto& entry : dev_resolvers_) {
+      delete entry.second;
+    }
+    for (FakeWorker* w : workers_) {
+      delete w;
+    }
+  }
+
+  // Builds the workers, the object under test, and the source / destination
+  // tensors.  expected_value_ holds 0..7 and to_tensor_ is filled with -1 so
+  // a successful transfer is observable.
+  void SetUp() override {
+    const int num_workers = 2;
+    const int num_devices = 1;
+    string device_type = "CPU";
+    ConfigProto config;
+    string dev0_worker_name;
+    for (int w = 0; w < num_workers; ++w) {
+      string name = strings::StrCat("/job:worker/replica:0/task:", w);
+      if (w == 0) {
+        dev0_worker_name = name;
+        // TODO(tucker): Change to use config when available.
+        // config.set_collective_group_leader(name);
+      }
+      DefineWorker(config, name, device_type, num_devices);
+    }
+    // All tests simulate requests from worker 0 to worker 1.
+    rma_.reset(new CollectiveRemoteAccessDistributed(
+        device_mgrs_[0], dev_resolvers_[dev0_worker_name], &wc_, kStepId));
+
+    const int kNumElts = 8;
+    expected_value_ = Tensor(DT_FLOAT, {kNumElts});
+    to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
+    auto exp_alias = expected_value_.flat<float>();
+    auto to_alias = to_tensor_.flat<float>();
+    for (int i = 0; i < kNumElts; ++i) {
+      exp_alias(i) = i;
+      to_alias(i) = -1;
+    }
+  }
+
+  // Creates `num_devices` fake devices of `device_type` for `worker_name`,
+  // plus the DeviceMgr, DeviceResolverDistributed and FakeWorker that go with
+  // them, and registers the worker with wc_.
+  void DefineWorker(const ConfigProto& config, const string& worker_name,
+                    const string& device_type, int num_devices) {
+    std::vector<Device*> devices;
+    for (int i = 0; i < num_devices; ++i) {
+      devices.push_back(NewDevice(
+          device_type,
+          strings::StrCat(worker_name, "/device:", device_type, ":", i)));
+    }
+    DeviceMgr* dev_mgr = new DeviceMgr(devices);
+    device_mgrs_.push_back(dev_mgr);
+    std::vector<string>* dv = &dev_by_task_[worker_name];
+    for (auto d : devices) {
+      dv->push_back(d->name());
+    }
+    DeviceResolverDistributed* dev_res =
+        new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
+    dev_resolvers_[worker_name] = dev_res;
+    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
+    workers_.push_back(fw);
+    wc_.AddWorker(worker_name, fw);
+  }
+
+  // Checks that to_tensor_ now holds the same values as expected_value_.
+  void ValidateResultTensor() {
+    ASSERT_EQ(expected_value_.NumElements(), to_tensor_.NumElements());
+    for (int i = 0; i < to_tensor_.NumElements(); ++i) {
+      EXPECT_FLOAT_EQ(expected_value_.flat<float>()(i),
+                      to_tensor_.flat<float>()(i));
+    }
+  }
+
+  FakeCache wc_;
+  CancellationManager cm_;
+  std::vector<DeviceMgr*> device_mgrs_;
+  std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
+  std::unordered_map<string, std::vector<string>> dev_by_task_;
+  std::vector<FakeWorker*> workers_;
+  std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
+  mutex mu_;
+  int num_done_ GUARDED_BY(mu_);
+  condition_variable done_;
+  Tensor expected_value_;
+  Tensor to_tensor_;
+  CallOptions opts_;
+  DeviceLocality device_locality_;
+  AllocatorAttributes alloc_attr_;
+};
+
+// The producer provides the buffer before the consumer asks for it.  Both
+// sides should complete OK and to_tensor_ should receive the expected values.
+TEST_F(CollRMADistTest, ProdFirstOK) {
+  Notification consumer_note;
+  Notification producer_note;
+  Status consumer_status;
+  Status producer_status;
+  FakeWorker* wi = workers_[1];
+  const string kBufKey = "fake_buf_key";
+  // Resolve the destination device before any callback is registered, and
+  // stop the test if that fails rather than passing a null device on.
+  Device* dst_device = nullptr;
+  string dev_name = "CPU:0";
+  TF_ASSERT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  wi->buf_rendezvous()->ProvideBuf(
+      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
+      AllocatorAttributes(),
+      [&producer_note, &producer_status](const Status& s) {
+        producer_status.Update(s);
+        producer_note.Notify();
+      });
+  DeviceContext* to_device_ctx = nullptr;
+  rma_->RecvFromPeer(
+      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
+      "/job:worker/replica:0/task:1",                     // peer_task
+      false,                                              // peer_is_local
+      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+      device_locality_,
+      [&consumer_status, &consumer_note](const Status& s) {
+        consumer_status = s;
+        consumer_note.Notify();
+      });
+  consumer_note.WaitForNotification();
+  TF_EXPECT_OK(consumer_status);
+  producer_note.WaitForNotification();
+  TF_EXPECT_OK(producer_status);
+  ValidateResultTensor();
+}
+
+// The consumer asks for the buffer before the producer provides it.  Both
+// sides should complete OK and to_tensor_ should receive the expected values.
+TEST_F(CollRMADistTest, ConsFirstOK) {
+  Notification consumer_note;
+  Notification producer_note;
+  Status consumer_status;
+  Status producer_status;
+  FakeWorker* wi = workers_[1];
+  const string kBufKey = "fake_buf_key";
+  Device* dst_device = nullptr;
+  string dev_name = "CPU:0";
+  // Stop the test if the lookup fails rather than passing a null device on.
+  TF_ASSERT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  DeviceContext* to_device_ctx = nullptr;
+  rma_->RecvFromPeer(
+      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
+      "/job:worker/replica:0/task:1",                     // peer_task
+      false,                                              // peer_is_local
+      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+      device_locality_,
+      [&consumer_status, &consumer_note](const Status& s) {
+        consumer_status = s;
+        consumer_note.Notify();
+      });
+  wi->buf_rendezvous()->ProvideBuf(
+      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
+      AllocatorAttributes(),
+      [&producer_note, &producer_status](const Status& s) {
+        producer_status.Update(s);
+        producer_note.Notify();
+      });
+  consumer_note.WaitForNotification();
+  TF_EXPECT_OK(consumer_status);
+  producer_note.WaitForNotification();
+  TF_EXPECT_OK(producer_status);
+  ValidateResultTensor();
+}
+
+// The consumer asks for a buffer that is never provided, then the receive
+// is aborted.  The consumer callback should report the "Cancelled" error
+// that FakeWorker raises from its cancel callback.
+TEST_F(CollRMADistTest, ConsFirstAbort) {
+  Notification consumer_note;
+  Status consumer_status;
+  const string kBufKey = "fake_buf_key";
+  Device* dst_device = nullptr;
+  string dev_name = "CPU:0";
+  // Stop the test if the lookup fails rather than passing a null device on.
+  TF_ASSERT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  DeviceContext* to_device_ctx = nullptr;
+  rma_->RecvFromPeer(
+      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
+      "/job:worker/replica:0/task:1",                     // peer_task
+      false,                                              // peer_is_local
+      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+      device_locality_,
+      [&consumer_status, &consumer_note](const Status& s) {
+        consumer_status = s;
+        consumer_note.Notify();
+      });
+  rma_->StartAbort(errors::Internal("Deliberate Failure"));
+  consumer_note.WaitForNotification();
+  EXPECT_EQ(consumer_status.error_message(), "Cancelled");
+}
+
+} // namespace
+} // namespace tensorflow
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:recent_request_ids",
cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
+ recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
logging_(Method(GrpcWorkerMethod::kLogging)),
tracing_(Method(GrpcWorkerMethod::kTracing)),
completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
IssueRequest(request, response, cleanupall_, std::move(done));
}
+ void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) override {
+ IssueRequest(request, response, recvbuf_, std::move(done), call_opts);  // recvbuf_ is the method string for GrpcWorkerMethod::kRecvBuf.
+ }
+
void CompleteGroupAsync(CallOptions* call_opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
const ::grpc::string cleanupgraph_;
const ::grpc::string cleanupall_;
const ::grpc::string recvtensor_;
+ const ::grpc::string recvbuf_;
const ::grpc::string logging_;
const ::grpc::string tracing_;
const ::grpc::string completegroup_;
#include "grpc++/alarm.h"
#include "grpc++/server_builder.h"
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
for (int i = 0; i < 1000; ++i) {
EnqueueRecvTensorRequestRaw();
}
+ for (int i = 0; i < 500; ++i) {
+ ENQUEUE_REQUEST(RecvBuf, true);
+ }
for (int i = 0; i < 100; ++i) {
ENQUEUE_REQUEST(RunGraph, true);
}
ENQUEUE_REQUEST(Tracing, false);
for (int i = 0; i < 10; ++i) {
- ENQUEUE_REQUEST(CompleteGroup, false);
- ENQUEUE_REQUEST(CompleteInstance, false);
- ENQUEUE_REQUEST(GetStepSequence, false);
+ ENQUEUE_REQUEST(CompleteGroup, true);
+ ENQUEUE_REQUEST(CompleteInstance, true);
+ ENQUEUE_REQUEST(GetStepSequence, true);
}
void* tag;
ENQUEUE_REQUEST(Tracing, false);
}
+  // Handles one incoming RecvBuf call: schedules the worker's RecvBufAsync
+  // with a per-call CallOptions wired to the call's cancel callback, then
+  // enqueues a fresh RecvBuf request slot.
+  void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
+    auto handle_call = [this, call]() {
+      CallOptions* options = new CallOptions;
+      call->SetCancelCallback([options]() { options->StartCancel(); });
+      // Clear the cancel callback before `options` is freed, since that
+      // callback dereferences it.
+      auto on_done = [call, options](const Status& s) {
+        call->ClearCancelCallback();
+        delete options;
+        call->SendResponse(ToGrpcStatus(s));
+      };
+      worker_->RecvBufAsync(options, &call->request, &call->response, on_done);
+    };
+    Schedule(handle_call);
+    ENQUEUE_REQUEST(RecvBuf, true);
+  }
+
void CompleteGroupHandler(
WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
Schedule([this, call]() {
call->SendResponse(ToGrpcStatus(s));
});
});
- ENQUEUE_REQUEST(CompleteGroup, false);
+ ENQUEUE_REQUEST(CompleteGroup, true);
}
void CompleteInstanceHandler(
&call->request, &call->response,
[call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); });
});
- ENQUEUE_REQUEST(GetStepSequence, false);
+ ENQUEUE_REQUEST(GetStepSequence, true);
}
#undef ENQUEUE_REQUEST
});
}
+// Generic RecvBuf implementation for gRPC.  Consumes the buffer registered
+// under request->buf_rendezvous_key() in the BufRendezvous of the collective
+// executor for request->step_id(), and returns its bytes packed into
+// response->transport_options as a RecvBufRespExtra.  A value that is not
+// already in host memory is first copied into a CPU tensor.  `done` is
+// invoked with the resulting status.
+void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                              RecvBufResponse* response, StatusCallback done) {
+  // This is a generic, low performance implementation appropriate for grpc.
+  CollectiveExecutor::Handle ce_handle(
+      env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
+  CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
+  rma->buf_rendezvous()->ConsumeBuf(
+      request->buf_rendezvous_key(),
+      [this, request, response, done](const Status& status,
+                                      BufRendezvous::Hook* hook) {
+        Status s = status;
+        if (s.ok()) {
+          if (!DMAHelper::CanUseDMA(hook->prod_value)) {
+            s = errors::Internal("Tensor value for key ",
+                                 request->buf_rendezvous_key(),
+                                 " is not of a type supported by RecvBuf");
+          }
+        }
+        if (s.ok()) {
+          // The RPC source tensor needs to be in CPU RAM.  If not already
+          // there make a copy using memory appropriate to the purpose.
+          const size_t num_bytes = hook->prod_value->TotalBytes();
+          const bool on_host =
+              hook->prod_dev->attributes().device_type() == "CPU" ||
+              hook->prod_attr.on_host();
+          if ((!on_host) && (num_bytes > 0)) {
+            Device* cpu_dev = nullptr;
+            s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
+            if (s.ok()) {
+              AllocatorAttributes cpu_attr;
+              cpu_attr.set_gpu_compatible(true);
+              cpu_attr.set_nic_compatible(true);
+              Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
+                                              hook->prod_value->dtype(),
+                                              hook->prod_value->shape());
+              hook->prod_ctx->CopyDeviceTensorToCPU(
+                  hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
+                  [this, num_bytes, response, done, hook,
+                   cpu_tensor](const Status& s) {
+                    if (s.ok()) {
+                      RecvBufRespExtra extra;
+                      extra.set_tensor_content(reinterpret_cast<const char*>(
+                                                   DMAHelper::base(cpu_tensor)),
+                                               num_bytes);
+                      response->mutable_transport_options()->PackFrom(extra);
+                    }
+                    response->set_send_start_micros(env_->env->NowMicros());
+                    done(s);
+                    BufRendezvous::DoneWithHook(hook);
+                    delete cpu_tensor;
+                  });
+              // The copy callback above completes the call.
+              return;
+            }
+          } else {
+            // Tensor is on CPU.
+            RecvBufRespExtra extra;
+            extra.set_tensor_content(reinterpret_cast<const char*>(
+                                         DMAHelper::base(hook->prod_value)),
+                                     num_bytes);
+            response->mutable_transport_options()->PackFrom(extra);
+          }
+        }
+        response->set_send_start_micros(env_->env->NowMicros());
+        done(s);
+        // When ConsumeBuf reports an error there may be no hook to release,
+        // so only hand it back when one was supplied.
+        if (hook != nullptr) {
+          BufRendezvous::DoneWithHook(hook);
+        }
+      });
+}
+
void GrpcWorker::LoggingAsync(const LoggingRequest* request,
LoggingResponse* response, StatusCallback done) {
auto env = this->env();
virtual void LoggingAsync(const LoggingRequest* request,
LoggingResponse* response, StatusCallback done);
+ virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done);
+
WorkerEnv* env();
private:
return "/tensorflow.WorkerService/CleanupAll";
case GrpcWorkerMethod::kRecvTensor:
return "/tensorflow.WorkerService/RecvTensor";
+ case GrpcWorkerMethod::kRecvBuf:
+ return "/tensorflow.WorkerService/RecvBuf";
case GrpcWorkerMethod::kLogging:
return "/tensorflow.WorkerService/Logging";
case GrpcWorkerMethod::kTracing:
kCleanupGraph,
kCleanupAll,
kRecvTensor,
+ kRecvBuf,
kLogging,
kTracing,
kCompleteGroup,
done(errors::Unimplemented("RunGraphAsync"));
}
+ void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) override {
+ done(errors::Unimplemented("RecvBufAsync"));  // Stub: always completes with Unimplemented.
+ }
+
void CompleteGroupAsync(CallOptions* opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
done(errors::Unimplemented("Tracing"));
}
+void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) {
+ // The base Worker class does not implement RecvBufAsync because
+ // it is not currently used for worker-to-worker communication. Use a
+ // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
+ // instead.
+ done(errors::Unimplemented("Worker::RecvBufAsync()"));  // Request and response are left untouched.
+}
+
void Worker::CompleteGroupAsync(CallOptions* opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
void TracingAsync(const TracingRequest* request, TracingResponse* response,
StatusCallback done) override;
+ void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) override;
+
void CompleteGroupAsync(CallOptions* opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
virtual void TracingAsync(const TracingRequest* request,
TracingResponse* response, StatusCallback done) = 0;
+ virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) = 0;
+
virtual void CompleteGroupAsync(CallOptions* opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
--- /dev/null
+syntax = "proto3";
+
+package tensorflow;
+
+// Extra data needed on a non-RDMA RecvBufResponse.  The sender packs it into
+// the response's transport_options field.
+message RecvBufRespExtra {
+  // Raw bytes of the value being transferred.
+  bytes tensor_content = 1;
+}
////////////////////////////////////////////////////////////////////////////////
//
+// Raw data transfers in support of Collective Ops.
+// These methods are experimental and subject to change.
+//
+// The intention is to allow collectives to take advantage of the most
+// efficient methods available on a platform, e.g. RDMA, and not be
+// constrained to use the RPC system in use by other methods.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RecvBufRequest {
+ // Use of the fields below may vary by implementation. For example
+ // the buf_ptr and num_bytes may be set only for local operations and
+ // not sent on the wire, or only sent on the wire in one direction.
+
+ // Used at server side to find the correct BufRendezvous.
+ int64 step_id = 1;
+
+ // Arbitrary string identifying a BufRendezvous entry.
+ string buf_rendezvous_key = 2;
+
+ // Size of value expected, must agree with BufRendezvous entry.
+ int64 num_bytes = 3;
+
+ // When RDMA is in use, address of destination field on client.
+ fixed64 buf_ptr = 4;
+
+ // Optional information on client-side device locality.
+ DeviceLocality client_locality = 5;
+
+ // Optional information on server-side device locality.
+ DeviceLocality server_locality = 6;
+
+ // Optional, implementation-specific data.
+ google.protobuf.Any transport_options = 7;
+ // Optional, for annotating the timeline.
+ string src_device = 8;  // Name of the peer device that holds the value.
+ string dst_device = 9;  // Name of the device that receives the value.
+}
+
+message RecvBufResponse {
+ // Use of the fields below may vary by implementation. Comments give
+ // intended use.
+
+ fixed64 buf_ptr = 1; // Address of source field on server.
+ int64 num_bytes = 2; // Byte length of buf_ptr field, if set.
+ bool is_dead = 3; // True if value is 'dead' like a tensor.
+ // Optional, implementation-specific data. The generic gRPC path packs a RecvBufRespExtra holding the value's bytes here.
+ google.protobuf.Any transport_options = 4;
+ // Optional, for timeline. The gRPC server sets this to its current time in micros just before completing the call.
+ int64 send_start_micros = 5;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
// Collective Op dynamic group resolution messages.
//
////////////////////////////////////////////////////////////////////////////////
rpc Tracing(TracingRequest) returns (TracingResponse);
// See worker.proto for details.
+ rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) {
+ }
+
+ // See worker.proto for details.
rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);
// See worker.proto for details.