Collective Ops Part 6
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 9 May 2018 19:26:06 +0000 (12:26 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 20:19:17 +0000 (13:19 -0700)
Distributed-mode implementations of CollectiveRemoteAccess.
Extend Worker interface with corresponding new methods.

This change is part of a series of changes introducing infrastructure
for collective ops and initial implementations of reduction and broadcast.

PiperOrigin-RevId: 196010718

19 files changed:
tensorflow/core/BUILD
tensorflow/core/distributed_runtime/BUILD
tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
tensorflow/core/distributed_runtime/collective_rma_distributed.cc [new file with mode: 0644]
tensorflow/core/distributed_runtime/collective_rma_distributed.h [new file with mode: 0644]
tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc [new file with mode: 0644]
tensorflow/core/distributed_runtime/rpc/BUILD
tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
tensorflow/core/distributed_runtime/test_utils.h
tensorflow/core/distributed_runtime/worker.cc
tensorflow/core/distributed_runtime/worker.h
tensorflow/core/distributed_runtime/worker_interface.h
tensorflow/core/protobuf/transport_options.proto [new file with mode: 0644]
tensorflow/core/protobuf/worker.proto
tensorflow/core/protobuf/worker_service.proto

index 76ff372..ccb8488 100644 (file)
@@ -224,6 +224,7 @@ ADDITIONAL_CORE_PROTO_SRCS = [
     "protobuf/named_tensor.proto",
     "protobuf/saved_model.proto",
     "protobuf/tensorflow_server.proto",
+    "protobuf/transport_options.proto",
     "util/test_log.proto",
 ]
 
index 256ce52..18b7069 100644 (file)
@@ -453,6 +453,40 @@ cc_library(
 )
 
 cc_library(
+    name = "collective_rma_distributed",
+    srcs = ["collective_rma_distributed.cc"],
+    hdrs = ["collective_rma_distributed.h"],
+    deps = [
+        ":worker_cache",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",  # protobuf::Any
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:worker_proto_cc",
+    ],
+)
+
+tf_cc_test(
+    name = "collective_rma_distributed_test",
+    size = "small",
+    srcs = ["collective_rma_distributed_test.cc"],
+    deps = [
+        ":collective_rma_distributed",
+        ":device_resolver_distributed",
+        ":test_utils",
+        "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core:worker_proto_cc",
+    ],
+)
+
+cc_library(
     name = "collective_param_resolver_distributed",
     srcs = ["collective_param_resolver_distributed.cc"],
     hdrs = ["collective_param_resolver_distributed.h"],
index ecf5db8..7a93b54 100644 (file)
@@ -284,7 +284,6 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed(
     const GroupRecCallback& done) {
   VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
           << " dev: " << device << " is_leader=" << (group_leader_.empty());
-  VLOG(0) << "cp: " << cp->ToString();
   if (group_leader_.empty()) {
     // This is the group leader, so resolution is local.
     return CompleteGroupLocal(device, cp, done);
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
new file mode 100644 (file)
index 0000000..54adcb9
--- /dev/null
@@ -0,0 +1,206 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/platform/protobuf_internal.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Supports client side cancellation of WorkerInterface calls via
+// registration with a CancellationManager.
+//
+// TODO(tucker): Maybe unify this with CancellableCall in
+// collective_param_resolver_distributed.cc.
+class CancellableCall {
+ public:
+  CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
+                  WorkerCacheInterface* wc)
+      : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) {
+    wi_ = wc_->CreateWorker(remote_worker_);
+  }
+  virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
+
+  virtual void IssueCall(const StatusCallback& done) = 0;
+
+  void Start(const StatusCallback& done) {
+    CancellationToken token = cancel_mgr_->get_cancellation_token();
+    const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
+        token, [this, token]() { opts_.StartCancel(); });
+    if (not_yet_cancelled) {
+      IssueCall([this, token, done](const Status& s) {
+        cancel_mgr_->DeregisterCallback(token);
+        done(s);
+      });
+    } else {
+      done(errors::Cancelled("RPC Request was cancelled"));
+    }
+  }
+
+ protected:
+  mutable mutex mu_;
+  CancellationManager* cancel_mgr_;  // Not owned
+  const string remote_worker_;
+  WorkerCacheInterface* wc_;  // Not owned
+  WorkerInterface* wi_;       // Owned by wc_, must be released.
+  CallOptions opts_;
+};
+
+class RecvBufCall : public CancellableCall {
+ public:
+  RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
+              const string& key, Device* to_device,
+              DeviceContext* to_device_ctx,
+              const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+              const DeviceLocality& client_locality,
+              const DeviceLocality& server_locality,
+              CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
+      : CancellableCall(cancel_mgr, peer_task, wc) {
+    req_.set_step_id(step_id);
+    req_.set_buf_rendezvous_key(key);
+    *req_.mutable_client_locality() = client_locality;
+    *req_.mutable_server_locality() = server_locality;
+    req_.set_num_bytes(to_tensor->TotalBytes());
+    req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor)));
+    req_.set_src_device(peer_device);
+    req_.set_dst_device(to_device->name());
+  }
+
+  ~RecvBufCall() override {}
+
+  void IssueCall(const StatusCallback& done) override {
+    wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
+  }
+
+  RecvBufRequest req_;
+  RecvBufResponse resp_;
+};
+
+}  // namespace
+
+void CollectiveRemoteAccessDistributed::RecvFromPeer(
+    const string& peer_device, const string& peer_task, bool peer_is_local,
+    const string& key, Device* to_device, DeviceContext* to_device_ctx,
+    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+    const DeviceLocality& client_locality, const StatusCallback& done) {
+  if (peer_is_local) {
+    CollectiveRemoteAccessLocal::RecvFromPeer(
+        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
+        to_alloc_attr, to_tensor, client_locality, done);
+    return;
+  }
+
+  // State that needs to be threaded through a couple of async calls
+  // in order to make this function completely non-blocking.
+  struct State {
+    DeviceLocality server_locality;
+    std::unique_ptr<RecvBufCall> call;
+  };
+  State* state = new State;
+
+  // Logic to be executed on the RecvBufferAsync callback.
+  auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
+                            to_device_ctx, to_tensor, done](const Status& s) {
+    std::unique_ptr<State> del_on_exit(state);
+    if (s.ok()) {
+      // In this generic implementation the bytes come back in the
+      // RPC response protobuf rather than via RDMA so we need to copy
+      // them into the destination tensor here.
+      RecvBufRespExtra extra;
+      state->call->resp_.transport_options().UnpackTo(&extra);
+      int64 num_bytes = extra.tensor_content().size();
+      if (num_bytes != to_tensor->TotalBytes()) {
+        done(errors::Internal("RecvBufResponse returned ", num_bytes,
+                              " bytes where to_tensor expected ",
+                              to_tensor->TotalBytes()));
+        return;
+      }
+      if (to_device->tensorflow_gpu_device_info()) {
+        // Move the bytes into a CPU tensor then use tensor-to-tensor copy.
+        // Use GPU-registered memory for the CPU tensor so the transfer
+        // goes faster.
+        Device* cpu_dev = nullptr;
+        Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
+        if (!status.ok()) {
+          done(status);
+          return;
+        }
+        AllocatorAttributes cpu_attr;
+        cpu_attr.set_gpu_compatible(true);
+        Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
+                                        to_tensor->dtype(), to_tensor->shape());
+        memcpy(DMAHelper::base(cpu_tensor), extra.tensor_content().data(),
+               num_bytes);
+        // Then copy it to the GPU.
+        CopyTensor::ViaDMA("",  // edge name (non-existent)
+                           nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
+                           to_device, cpu_attr, to_alloc_attr, cpu_tensor,
+                           to_tensor,
+                           [this, cpu_tensor, done](const Status& s) {
+                             delete cpu_tensor;
+                             // This callback must not block, so execute
+                             // done in another thread.
+                             SchedClosure([s, done] { done(s); });
+                           });
+        return;
+      } else {
+        // CPU device
+        memcpy(DMAHelper::base(to_tensor), extra.tensor_content().data(),
+               num_bytes);
+      }
+    }
+    if (!s.ok() && errors::IsFailedPrecondition(s)) {
+      dev_resolver_->ClearTask(peer_task);
+    }
+
+    done(s);
+  };
+
+  // Logic to execute once we have the device locality for the server-side
+  // device.
+  auto dev_locality_callback = [this, state, peer_device, peer_task, key,
+                                to_device, to_device_ctx, to_alloc_attr,
+                                to_tensor, client_locality,
+                                recv_buf_callback](const Status& s) {
+    if (!s.ok()) {
+      recv_buf_callback(s);
+    } else {
+      state->call.reset(new RecvBufCall(
+          step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
+          to_alloc_attr, to_tensor, client_locality, state->server_locality,
+          &cancel_mgr_, worker_cache_));
+      state->call->Start(recv_buf_callback);
+    }
+  };
+
+  dev_resolver_->GetLocalityAsync(
+      peer_device, peer_task, &state->server_locality, dev_locality_callback);
+}
+
+void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
+  CollectiveRemoteAccessLocal::StartAbort(s);
+  cancel_mgr_.StartCancel();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
new file mode 100644 (file)
index 0000000..cfa9110
--- /dev/null
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+class WorkerCacheInterface;
+
+// Extend CollectiveRemoteAccessLocal with access to remote peers.
+class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
+ public:
+  CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
+                                    DeviceResolverInterface* dev_resolver,
+                                    WorkerCacheInterface* worker_cache,
+                                    int64 step_id)
+      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+        worker_cache_(worker_cache) {}
+
+  ~CollectiveRemoteAccessDistributed() override {}
+
+  void RecvFromPeer(const string& peer_device, const string& peer_task,
+                    bool peer_is_local, const string& key, Device* to_device,
+                    DeviceContext* to_device_ctx,
+                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+                    const DeviceLocality& client_locality,
+                    const StatusCallback& done) override;
+
+  void StartAbort(const Status& s) override;
+
+ protected:
+  WorkerCacheInterface* worker_cache_;  // Not owned
+  CancellationManager cancel_mgr_;
+};
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
new file mode 100644 (file)
index 0000000..a552f81
--- /dev/null
@@ -0,0 +1,356 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/test_utils.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+// The only interesting method on CollectiveRemoteAccessDistributed
+// that's not on CollectiveRemoteAccessLocal is RecvFromPeer which
+// issues a RecvBufAsync call against a WorkerInterface.  That's all
+// that's tested here.  Note that RecvFromPeer can do a
+// DeviceResolverInterface::GetDeviceLocalityAsync call in preparation
+// for the RecvBufAsync.
+
+namespace tensorflow {
+namespace {
+
+static Device* NewDevice(const string& type, const string& name) {
+  class FakeDevice : public Device {
+   public:
+    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
+    Status Sync() override { return Status::OK(); }
+    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+  };
+  DeviceAttributes attr;
+  attr.set_name(name);
+  attr.set_device_type(type);
+  attr.mutable_locality()->set_numa_node(3);  // a non-default value
+  return new FakeDevice(attr);
+}
+
+static int64 kStepId = 123;
+
+class FakeWorker : public TestWorkerInterface {
+ public:
+  FakeWorker(const string& name, DeviceMgr* dev_mgr,
+             DeviceResolverDistributed* dres)
+      : name_(name),
+        device_mgr_(dev_mgr),
+        device_resolver_(dres),
+        buf_rendezvous_(kStepId) {}
+
+  // Direct access to a BufRendezvous that holds whatever the remote
+  // worker is supposed to have.
+  BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
+
+  void GetStatusAsync(const GetStatusRequest* request,
+                      GetStatusResponse* response,
+                      StatusCallback done) override {
+    std::vector<DeviceAttributes> dev_attr;
+    device_mgr_->ListDeviceAttributes(&dev_attr);
+    for (const auto& da : dev_attr) {
+      *response->add_device_attributes() = da;
+    }
+    done(Status::OK());
+  }
+
+  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                    RecvBufResponse* response, StatusCallback done) override {
+    opts->SetCancelCallback([this]() {
+      // Within this test the call is satisfied by a process-local
+      // BufRendezvous table. In real application the BufRendezvous
+      // would be on the other side of a network hop, so call
+      // BufRendezvous::StartAbort() from a separate thread to be
+      // more consistent with that situation and avoid mutex deadlock.
+      SchedClosure([this]() {
+        Env::Default()->SleepForMicroseconds(100);
+        buf_rendezvous_.StartAbort(errors::Internal("Cancelled"));
+      });
+    });
+    buf_rendezvous_.ConsumeBuf(
+        request->buf_rendezvous_key(),
+        [this, opts, request, response, done](const Status& s,
+                                              BufRendezvous::Hook* h) {
+          if (s.ok()) {
+            opts->ClearCancelCallback();
+            // Since this is not really RDMA into pre-allocated memory send the
+            // bytes in the response.
+            RecvBufRespExtra extra;
+            int64 num_bytes = h->prod_value->TotalBytes();
+            extra.set_tensor_content(string(
+                reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
+                num_bytes));
+            response->mutable_transport_options()->PackFrom(extra);
+          }
+          done(s);
+          if (h) BufRendezvous::DoneWithHook(h);
+        });
+  }
+
+ private:
+  string name_;
+  DeviceMgr* device_mgr_;
+  DeviceResolverDistributed* device_resolver_;
+  BufRendezvous buf_rendezvous_;
+};
+
+class FakeCache : public TestWorkerCache {
+ public:
+  // Override the Locality methods to actually pass through to the
+  // worker.
+  bool GetDeviceLocalityNonBlocking(const string& device,
+                                    DeviceLocality* locality) override {
+    return false;
+  }
+
+  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+                              StatusCallback done) override {
+    string task_name;
+    string dev_part;
+    if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
+      done(errors::Internal("failed to parse device name"));
+      return;
+    }
+    auto it = workers_.find(task_name);
+    if (it == workers_.end()) {
+      done(errors::Internal("failed to find worker ", task_name));
+      return;
+    }
+    WorkerInterface* wi = it->second;
+    GetStatusRequest req;
+    GetStatusResponse resp;
+    Notification note;
+    Status status = wi->GetStatus(&req, &resp);
+    if (!status.ok()) {
+      done(status);
+      return;
+    }
+    for (const auto& it : resp.device_attributes()) {
+      if (it.name() == device) {
+        *locality = it.locality();
+        done(Status::OK());
+        return;
+      }
+    }
+    done(errors::Internal("device not found: ", device));
+  }
+};
+
+class CollRMADistTest : public ::testing::Test {
+ protected:
+  CollRMADistTest() {}
+
+  ~CollRMADistTest() override {
+    for (DeviceMgr* dm : device_mgrs_) {
+      delete dm;
+    }
+    for (auto it : dev_resolvers_) {
+      delete it.second;
+    }
+    for (FakeWorker* w : workers_) {
+      delete w;
+    }
+  }
+
+  void SetUp() override {
+    const int num_workers = 2;
+    const int num_devices = 1;
+    string device_type = "CPU";
+    ConfigProto config;
+    string dev0_worker_name;
+    for (int w = 0; w < num_workers; ++w) {
+      string name = strings::StrCat("/job:worker/replica:0/task:", w);
+      if (w == 0) {
+        dev0_worker_name = name;
+        // TODO(tucker): Change to use config when available.
+        // config.set_collective_group_leader(name);
+      }
+      DefineWorker(config, name, device_type, num_devices);
+    }
+    // All tests simulate requests from worker 0 to worker 1.
+    rma_.reset(new CollectiveRemoteAccessDistributed(
+        device_mgrs_[0], dev_resolvers_[dev0_worker_name], &wc_, kStepId));
+
+    const int kNumElts = 8;
+    expected_value_ = Tensor(DT_FLOAT, {kNumElts});
+    to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
+    auto exp_alias = expected_value_.flat<float>();
+    auto to_alias = to_tensor_.flat<float>();
+    for (int i = 0; i < kNumElts; ++i) {
+      exp_alias(i) = i;
+      to_alias(i) = -1;
+    }
+  }
+
+  void DefineWorker(const ConfigProto& config, const string& worker_name,
+                    const string& device_type, int num_devices) {
+    std::vector<Device*> devices;
+    for (int i = 0; i < num_devices; ++i) {
+      devices.push_back(NewDevice(
+          device_type,
+          strings::StrCat(worker_name, "/device:", device_type, ":", i)));
+    }
+    DeviceMgr* dev_mgr = new DeviceMgr(devices);
+    device_mgrs_.push_back(dev_mgr);
+    std::vector<string>* dv = &dev_by_task_[worker_name];
+    for (auto d : devices) {
+      dv->push_back(d->name());
+    }
+    DeviceResolverDistributed* dev_res =
+        new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
+    dev_resolvers_[worker_name] = dev_res;
+    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
+    workers_.push_back(fw);
+    wc_.AddWorker(worker_name, fw);
+  }
+
+  void ValidateResultTensor() {
+    ASSERT_EQ(expected_value_.NumElements(), to_tensor_.NumElements());
+    for (int i = 0; i < to_tensor_.NumElements(); ++i) {
+      EXPECT_FLOAT_EQ(expected_value_.flat<float>()(i),
+                      to_tensor_.flat<float>()(i));
+    }
+  }
+
+  FakeCache wc_;
+  CancellationManager cm_;
+  std::vector<DeviceMgr*> device_mgrs_;
+  std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
+  std::unordered_map<string, std::vector<string>> dev_by_task_;
+  std::vector<FakeWorker*> workers_;
+  std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
+  mutex mu_;
+  int num_done_ GUARDED_BY(mu_);
+  condition_variable done_;
+  Tensor expected_value_;
+  Tensor to_tensor_;
+  CallOptions opts_;
+  DeviceLocality device_locality_;
+  AllocatorAttributes alloc_attr_;
+};
+
+TEST_F(CollRMADistTest, ProdFirstOK) {
+  Notification consumer_note;
+  Notification producer_note;
+  Status consumer_status;
+  Status producer_status;
+  FakeWorker* wi = workers_[1];
+  const string kBufKey = "fake_buf_key";
+  wi->buf_rendezvous()->ProvideBuf(
+      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
+      AllocatorAttributes(),
+      [this, &producer_note, &producer_status](const Status& s) {
+        producer_status.Update(s);
+        producer_note.Notify();
+      });
+  Status status;
+  Device* dst_device = nullptr;
+  string dev_name = "CPU:0";
+  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  DeviceContext* to_device_ctx = nullptr;
+  rma_->RecvFromPeer(
+      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
+      "/job:worker/replica:0/task:1",                     // peer_task
+      false,                                              // peer_is_local
+      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+      device_locality_,
+      [this, &consumer_status, &consumer_note](const Status& s) {
+        consumer_status = s;
+        consumer_note.Notify();
+      });
+  consumer_note.WaitForNotification();
+  TF_EXPECT_OK(consumer_status);
+  producer_note.WaitForNotification();
+  TF_EXPECT_OK(producer_status);
+  ValidateResultTensor();
+}
+
+TEST_F(CollRMADistTest, ConsFirstOK) {
+  Notification consumer_note;
+  Notification producer_note;
+  Status consumer_status;
+  Status producer_status;
+  FakeWorker* wi = workers_[1];
+  const string kBufKey = "fake_buf_key";
+  Status status;
+  Device* dst_device = nullptr;
+  string dev_name = "CPU:0";
+  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  DeviceContext* to_device_ctx = nullptr;
+  rma_->RecvFromPeer(
+      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
+      "/job:worker/replica:0/task:1",                     // peer_task
+      false,                                              // peer_is_local
+      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+      device_locality_,
+      [this, &consumer_status, &consumer_note](const Status& s) {
+        consumer_status = s;
+        consumer_note.Notify();
+      });
+  wi->buf_rendezvous()->ProvideBuf(
+      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
+      AllocatorAttributes(),
+      [this, &producer_note, &producer_status](const Status& s) {
+        producer_status.Update(s);
+        producer_note.Notify();
+      });
+  consumer_note.WaitForNotification();
+  TF_EXPECT_OK(consumer_status);
+  producer_note.WaitForNotification();
+  TF_EXPECT_OK(producer_status);
+  ValidateResultTensor();
+}
+
+TEST_F(CollRMADistTest, ConsFirstAbort) {
+  Notification consumer_note;
+  Status consumer_status;
+  const string kBufKey = "fake_buf_key";
+  Status status;
+  Device* dst_device = nullptr;
+  string dev_name = "CPU:0";
+  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  DeviceContext* to_device_ctx = nullptr;
+  rma_->RecvFromPeer(
+      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
+      "/job:worker/replica:0/task:1",                     // peer_task
+      false,                                              // peer_is_local
+      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+      device_locality_,
+      [this, &consumer_status, &consumer_note](const Status& s) {
+        consumer_status = s;
+        consumer_note.Notify();
+      });
+  rma_->StartAbort(errors::Internal("Deliberate Failure"));
+  consumer_note.WaitForNotification();
+  EXPECT_EQ(consumer_status.error_message(), "Cancelled");
+}
+
+}  // namespace
+}  // namespace tensorflow
index c2719f5..40028ee 100644 (file)
@@ -171,6 +171,7 @@ tf_cuda_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:worker_proto_cc",
         "//tensorflow/core/distributed_runtime:graph_mgr",
         "//tensorflow/core/distributed_runtime:recent_request_ids",
index 5b7b74c..1acf1fb 100644 (file)
@@ -54,6 +54,7 @@ class GrpcRemoteWorker : public WorkerInterface {
         cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
         cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
         recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
+        recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
         logging_(Method(GrpcWorkerMethod::kLogging)),
         tracing_(Method(GrpcWorkerMethod::kTracing)),
         completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
@@ -118,6 +119,11 @@ class GrpcRemoteWorker : public WorkerInterface {
     IssueRequest(request, response, cleanupall_, std::move(done));
   }
 
+  void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request,
+                    RecvBufResponse* response, StatusCallback done) override {
+    IssueRequest(request, response, recvbuf_, std::move(done), call_opts);
+  }
+
   void CompleteGroupAsync(CallOptions* call_opts,
                           const CompleteGroupRequest* request,
                           CompleteGroupResponse* response,
@@ -239,6 +245,7 @@ class GrpcRemoteWorker : public WorkerInterface {
   const ::grpc::string cleanupgraph_;
   const ::grpc::string cleanupall_;
   const ::grpc::string recvtensor_;
+  const ::grpc::string recvbuf_;
   const ::grpc::string logging_;
   const ::grpc::string tracing_;
   const ::grpc::string completegroup_;
index 26fad1f..4383e41 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 #include "grpc++/alarm.h"
 #include "grpc++/server_builder.h"
 
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
@@ -37,10 +38,12 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/distributed_runtime/worker_session.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
 
 namespace tensorflow {
@@ -159,6 +162,9 @@ class GrpcWorkerService : public AsyncServiceInterface {
       for (int i = 0; i < 1000; ++i) {
         EnqueueRecvTensorRequestRaw();
       }
+      for (int i = 0; i < 500; ++i) {
+        ENQUEUE_REQUEST(RecvBuf, true);
+      }
       for (int i = 0; i < 100; ++i) {
         ENQUEUE_REQUEST(RunGraph, true);
       }
@@ -170,9 +176,9 @@ class GrpcWorkerService : public AsyncServiceInterface {
       ENQUEUE_REQUEST(Tracing, false);
 
       for (int i = 0; i < 10; ++i) {
-        ENQUEUE_REQUEST(CompleteGroup, false);
-        ENQUEUE_REQUEST(CompleteInstance, false);
-        ENQUEUE_REQUEST(GetStepSequence, false);
+        ENQUEUE_REQUEST(CompleteGroup, true);
+        ENQUEUE_REQUEST(CompleteInstance, true);
+        ENQUEUE_REQUEST(GetStepSequence, true);
       }
 
       void* tag;
@@ -322,6 +328,20 @@ class GrpcWorkerService : public AsyncServiceInterface {
       ENQUEUE_REQUEST(Tracing, false);
     }
 
+    void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
+      Schedule([this, call]() {
+        CallOptions* call_opts = new CallOptions;
+        call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
+        worker_->RecvBufAsync(call_opts, &call->request, &call->response,
+                              [call, call_opts](const Status& s) {
+                                call->ClearCancelCallback();
+                                delete call_opts;
+                                call->SendResponse(ToGrpcStatus(s));
+                              });
+      });
+      ENQUEUE_REQUEST(RecvBuf, true);
+    }
+
     void CompleteGroupHandler(
         WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
       Schedule([this, call]() {
@@ -334,7 +354,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
                                       call->SendResponse(ToGrpcStatus(s));
                                     });
       });
-      ENQUEUE_REQUEST(CompleteGroup, false);
+      ENQUEUE_REQUEST(CompleteGroup, true);
     }
 
     void CompleteInstanceHandler(
@@ -360,7 +380,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
             &call->request, &call->response,
             [call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); });
       });
-      ENQUEUE_REQUEST(GetStepSequence, false);
+      ENQUEUE_REQUEST(GetStepSequence, true);
     }
 #undef ENQUEUE_REQUEST
 
@@ -485,6 +505,74 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
       });
 }
 
+void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                              RecvBufResponse* response, StatusCallback done) {
+  // This is a generic, low performance implementation appropriate for grpc.
+  CollectiveExecutor::Handle ce_handle(
+      env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
+  CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
+  rma->buf_rendezvous()->ConsumeBuf(
+      request->buf_rendezvous_key(),
+      [this, opts, request, response, done](const Status& status,
+                                            BufRendezvous::Hook* hook) {
+        Status s = status;
+        if (s.ok()) {
+          if (!DMAHelper::CanUseDMA(hook->prod_value)) {
+            s = errors::Internal("Tensor value for key ",
+                                 request->buf_rendezvous_key(),
+                                 " is not of a type supported by RecvBuf");
+          }
+        }
+        if (s.ok()) {
+          // The RPC source tensor needs to be in CPU RAM.  If not already
+          // there make a copy using memory appropriate to the purpose.
+          const size_t num_bytes = hook->prod_value->TotalBytes();
+          const bool on_host =
+              hook->prod_dev->attributes().device_type() == "CPU" ||
+              hook->prod_attr.on_host();
+          if ((!on_host) && (num_bytes > 0)) {
+            Device* cpu_dev = nullptr;
+            s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
+            if (s.ok()) {
+              AllocatorAttributes cpu_attr;
+              cpu_attr.set_gpu_compatible(true);
+              cpu_attr.set_nic_compatible(true);
+              Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
+                                              hook->prod_value->dtype(),
+                                              hook->prod_value->shape());
+              hook->prod_ctx->CopyDeviceTensorToCPU(
+                  hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
+                  [this, num_bytes, response, done, hook,
+                   cpu_tensor](const Status& s) {
+                    if (s.ok()) {
+                      RecvBufRespExtra extra;
+                      extra.set_tensor_content(reinterpret_cast<const char*>(
+                                                   DMAHelper::base(cpu_tensor)),
+                                               num_bytes);
+                      response->mutable_transport_options()->PackFrom(extra);
+                    }
+                    response->set_send_start_micros(env_->env->NowMicros());
+                    done(s);
+                    BufRendezvous::DoneWithHook(hook);
+                    delete cpu_tensor;
+                  });
+              return;
+            }
+          } else {
+            // Tensor is on CPU.
+            RecvBufRespExtra extra;
+            extra.set_tensor_content(reinterpret_cast<const char*>(
+                                         DMAHelper::base(hook->prod_value)),
+                                     num_bytes);
+            response->mutable_transport_options()->PackFrom(extra);
+          }
+        }
+        response->set_send_start_micros(env_->env->NowMicros());
+        done(s);
+        BufRendezvous::DoneWithHook(hook);
+      });
+}
+
 void GrpcWorker::LoggingAsync(const LoggingRequest* request,
                               LoggingResponse* response, StatusCallback done) {
   auto env = this->env();
index fbddbda..c0ed088 100644 (file)
@@ -43,6 +43,9 @@ class GrpcWorker : public Worker {
   virtual void LoggingAsync(const LoggingRequest* request,
                             LoggingResponse* response, StatusCallback done);
 
+  virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                            RecvBufResponse* response, StatusCallback done);
+
   WorkerEnv* env();
 
  private:
index a91cc06..38cc2b8 100644 (file)
@@ -46,6 +46,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
       return "/tensorflow.WorkerService/CleanupAll";
     case GrpcWorkerMethod::kRecvTensor:
       return "/tensorflow.WorkerService/RecvTensor";
+    case GrpcWorkerMethod::kRecvBuf:
+      return "/tensorflow.WorkerService/RecvBuf";
     case GrpcWorkerMethod::kLogging:
       return "/tensorflow.WorkerService/Logging";
     case GrpcWorkerMethod::kTracing:
index c5104c6..da27083 100644 (file)
@@ -81,6 +81,7 @@ enum class GrpcWorkerMethod {
   kCleanupGraph,
   kCleanupAll,
   kRecvTensor,
+  kRecvBuf,
   kLogging,
   kTracing,
   kCompleteGroup,
index 0ed0782..48d8384 100644 (file)
@@ -93,6 +93,11 @@ class TestWorkerInterface : public WorkerInterface {
     done(errors::Unimplemented("RunGraphAsync"));
   }
 
+  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                    RecvBufResponse* response, StatusCallback done) override {
+    done(errors::Unimplemented("RecvBufAsync"));
+  }
+
   void CompleteGroupAsync(CallOptions* opts,
                           const CompleteGroupRequest* request,
                           CompleteGroupResponse* response,
index d682ac8..4e6500f 100644 (file)
@@ -337,6 +337,15 @@ void Worker::TracingAsync(const TracingRequest* request,
   done(errors::Unimplemented("Tracing"));
 }
 
+void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                          RecvBufResponse* response, StatusCallback done) {
+  // The base Worker class does not implement RecvBufAsync because
+  // it is not currently used for worker-to-worker communication. Use a
+  // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
+  // instead.
+  done(errors::Unimplemented("Worker::RecvBufAsync()"));
+}
+
 void Worker::CompleteGroupAsync(CallOptions* opts,
                                 const CompleteGroupRequest* request,
                                 CompleteGroupResponse* response,
index b5a9ada..91eb27c 100644 (file)
@@ -90,6 +90,9 @@ class Worker : public WorkerInterface {
   void TracingAsync(const TracingRequest* request, TracingResponse* response,
                     StatusCallback done) override;
 
+  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                    RecvBufResponse* response, StatusCallback done) override;
+
   void CompleteGroupAsync(CallOptions* opts,
                           const CompleteGroupRequest* request,
                           CompleteGroupResponse* response,
index bad31d2..a50ac3b 100644 (file)
@@ -112,6 +112,9 @@ class WorkerInterface {
   virtual void TracingAsync(const TracingRequest* request,
                             TracingResponse* response, StatusCallback done) = 0;
 
+  virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                            RecvBufResponse* response, StatusCallback done) = 0;
+
   virtual void CompleteGroupAsync(CallOptions* opts,
                                   const CompleteGroupRequest* request,
                                   CompleteGroupResponse* response,
diff --git a/tensorflow/core/protobuf/transport_options.proto b/tensorflow/core/protobuf/transport_options.proto
new file mode 100644 (file)
index 0000000..d7b1bdd
--- /dev/null
@@ -0,0 +1,8 @@
+syntax = "proto3";
+
+package tensorflow;
+
+// Extra data needed on a non-RDMA RecvBufResponse.
+message RecvBufRespExtra {
+  bytes tensor_content = 1;
+};
index 602f6a1..f7816e9 100644 (file)
@@ -418,6 +418,60 @@ message TracingResponse {
 
 ////////////////////////////////////////////////////////////////////////////////
 //
+// Raw data transfers in support of Collective Ops.
+// These methods are experimental and subject to change.
+//
+// The intention is to allow collectives to take advantage of the most
+// efficient methods available on a platform, e.g. RDMA, and not be
+// constrained to use the RPC system in use by other methods.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RecvBufRequest {
+  // Use of the fields below may vary by implementation.  For example
+  // the buf_ptr and num_bytes may be set only for local operations and
+  // not sent on the wire, or only sent on the wire in one direction.
+
+  // Used at server side to find the correct BufRendezvous.
+  int64 step_id = 1;
+
+  // Arbitrary string identifying a BufRendezvous entry.
+  string buf_rendezvous_key = 2;
+
+  // Size of value expected, must agree with BufRendezvous entry.
+  int64 num_bytes = 3;
+
+  // When RDMA is in use, address of destination field on client.
+  fixed64 buf_ptr = 4;
+
+  // Optional information on client-side device locality.
+  DeviceLocality client_locality = 5;
+
+  // Optional information on server-side device locality.
+  DeviceLocality server_locality = 6;
+
+  // Optional, implementation-specific data.
+  google.protobuf.Any transport_options = 7;
+  // Optional, for annotating the timeline.
+  string src_device = 8;
+  string dst_device = 9;
+}
+
+message RecvBufResponse {
+  // Use of the fields below may vary by implementation.  Comments give
+  // intended use.
+
+  fixed64 buf_ptr = 1;  // Address of source field on server.
+  int64 num_bytes = 2;  // Byte length of buf_ptr field, if set.
+  bool is_dead = 3;     // True if value is 'dead' like a tensor.
+  // Optional, implementation-specific data.
+  google.protobuf.Any transport_options = 4;
+  // Optional, for timeline.
+  int64 send_start_micros = 5;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
 // Collective Op dynamic group resolution messages.
 //
 ////////////////////////////////////////////////////////////////////////////////
index 01c76c0..e0c27f3 100644 (file)
@@ -74,6 +74,10 @@ service WorkerService {
   rpc Tracing(TracingRequest) returns (TracingResponse);
 
   // See worker.proto for details.
+  rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) {
+  }
+
+  // See worker.proto for details.
   rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);
 
   // See worker.proto for details.