"framework/attr_value_util.h",
"framework/bfloat16.h",
"framework/cancellation.h",
+ "framework/collective.h",
"framework/common_shape_fns.h",
"framework/control_flow.h", # TODO(josh11b): Make internal?
"framework/dataset.h",
CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/allocator_retry.h",
"common_runtime/bfc_allocator.h",
+ "common_runtime/collective_executor_mgr.h",
+ "common_runtime/collective_param_resolver_local.h",
+ "common_runtime/collective_rma_local.h",
+ "common_runtime/device_resolver_local.h",
+ "common_runtime/buf_rendezvous.h",
"common_runtime/build_graph_options.h",
"common_runtime/constant_folding.h",
"common_runtime/copy_tensor.h",
"common_runtime/accumulate_n_optimizer.cc",
"common_runtime/allocator_retry.cc",
"common_runtime/bfc_allocator.cc",
+ "common_runtime/buf_rendezvous.cc",
"common_runtime/build_graph_options.cc",
+ "common_runtime/collective_executor_mgr.cc",
+ "common_runtime/collective_param_resolver_local.cc",
+ "common_runtime/collective_rma_local.cc",
"common_runtime/constant_folding.cc",
"common_runtime/copy_tensor.cc",
"common_runtime/costmodel_manager.cc",
"common_runtime/device.cc",
"common_runtime/device_factory.cc",
"common_runtime/device_mgr.cc",
+ "common_runtime/device_resolver_local.cc",
"common_runtime/device_set.cc",
"common_runtime/executor.cc",
"common_runtime/function.cc",
name = "higher_level_tests",
size = "small",
srcs = [
+ "common_runtime/buf_rendezvous_test.cc",
+ "common_runtime/collective_executor_mgr_test.cc",
+ "common_runtime/collective_param_resolver_local_test.cc",
+ "common_runtime/collective_rma_local_test.cc",
+ "common_runtime/device_resolver_local_test.cc",
"common_runtime/device_set_test.cc",
"common_runtime/optimization_registry_test.cc",
"common_runtime/pending_counts_test.cc",
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+
+namespace tensorflow {
+
+BufRendezvous::~BufRendezvous() {
+ mutex_lock l(mu_);
+ if (!hook_table_.empty()) {
+ PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
+ &hook_table_);
+ }
+}
+
+void BufRendezvous::StartAbort(const Status& s) {
+ CHECK(!s.ok());
+ HookTable dummy_table;
+ {
+ mutex_lock l(mu_);
+ status_.Update(s);
+ hook_table_.swap(dummy_table);
+ }
+ PurgeTable(s, &dummy_table);
+}
+
+void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
+ for (auto& it : *table) {
+ Hook* h = it.second;
+ if (h->cons_cb != nullptr) {
+ h->cons_cb(s, nullptr);
+ }
+ if (h->prod_cb != nullptr) {
+ h->prod_cb(s);
+ }
+ delete h;
+ }
+ table->clear();
+}
+
+string BufRendezvous::Hook::DebugString() const {
+ return strings::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
+ ", ctx:", reinterpret_cast<uint64>(prod_ctx),
+ ", val:", reinterpret_cast<uint64>(prod_value),
+ ", pcb:", reinterpret_cast<uint64>(&prod_cb),
+ ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
+}
+
+void BufRendezvous::ProvideBuf(const string& key, Device* dev,
+ DeviceContext* dev_ctx, const Tensor* v,
+ const AllocatorAttributes& attr,
+ const ProducerCallback& done) {
+ Hook* h = nullptr;
+ Status providebuf_status;
+ do {
+ mutex_lock l(mu_);
+ if (!status_.ok()) {
+ providebuf_status = status_;
+ break;
+ } else {
+ auto it = hook_table_.find(key);
+ if (it == hook_table_.end()) {
+ h = new Hook;
+ it = hook_table_.insert(std::make_pair(key, h)).first;
+ } else {
+ if (it->second->prod_cb != nullptr) {
+ providebuf_status = errors::Internal(
+ "BufRendezvous::ProvideBuf already called for key ", key);
+ break;
+ }
+ h = it->second;
+ }
+ // Populate Hook with all of the prod values.
+ h->prod_dev = dev;
+ h->prod_ctx = dev_ctx;
+ h->prod_value = v;
+ h->prod_attr = attr;
+ h->prod_cb = done;
+ // If consumer is waiting, kick off right away, removing Hook from table.
+ if (h->cons_cb != nullptr) {
+ hook_table_.erase(it);
+ } else {
+ h = nullptr;
+ }
+ }
+ } while (false);
+ if (h) {
+ h->cons_cb(Status::OK(), h);
+ }
+ if (!providebuf_status.ok()) {
+ done(providebuf_status);
+ }
+}
+
+void BufRendezvous::ConsumeBuf(const string& key,
+ const ConsumerCallback& done) {
+ Hook* existing_hook = nullptr;
+ Status consumebuf_status;
+ do {
+ mutex_lock l(mu_);
+ if (!status_.ok()) {
+ consumebuf_status = status_;
+ break;
+ }
+ auto it = hook_table_.find(key);
+ if (it != hook_table_.end()) {
+ // Prepare to consume immediately.
+ if (it->second->cons_cb) {
+ consumebuf_status =
+ errors::Internal("Second consumer arrived for key ", key);
+ break;
+ }
+ existing_hook = it->second;
+ hook_table_.erase(it);
+ existing_hook->cons_cb = done;
+ } else {
+ // Hang consumer callback on the Hook.
+ Hook* h = new Hook;
+ hook_table_[key] = h;
+ h->cons_cb = done;
+ return;
+ }
+ } while (false);
+ if (existing_hook) {
+ existing_hook->cons_cb(Status::OK(), existing_hook);
+ return;
+ }
+ if (!consumebuf_status.ok()) {
+ done(consumebuf_status, nullptr);
+ return;
+ }
+}
+
+/*static*/
+void BufRendezvous::DoneWithHook(Hook* h) {
+ h->prod_cb(Status::OK());
+ delete h;
+}
+
+void BufRendezvous::LogContents() {
+ mutex_lock l(mu_);
+ LOG(INFO) << strings::StrCat("BufRendezvous ",
+ strings::Hex(reinterpret_cast<uint64>(this)),
+ " step_id=", step_id_, " current contents:");
+ for (auto it : hook_table_) {
+ LOG(INFO) << it.first << ":" << it.second->DebugString();
+ }
+}
+
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+#define TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+
+#include <functional>
+#include <string>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+class Device;
+class DeviceContext;
+class Tensor;
+
+// EXPERIMENTAL: RDMA-oriented producer/consumer rendezvous on a local
+// Tensor value for which DMAHelper::CanUseDMA() is true, i.e. dense
+// numeric types. Similar to Rendezvous but never owns a Ref on the
+// tensor; instead it uses an explicit callback to the producer when
+// the consumer side is finished with the value. This allows the
+// producer to perform in-place updates on the source buffer or to take
+// other actions that depend on knowing the consumer has passed a certain
+// execution point.
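+//
+// A minimal usage sketch (illustrative only; `dev`, `dev_ctx`, `t` and
+// `attr` are placeholder values, not part of this CL):
+//
+//   BufRendezvous br(step_id);
+//   br.ProvideBuf("key", dev, dev_ctx, &t, attr,
+//                 [](const Status& s) {
+//                   // Consumer is done; safe to reuse t's buffer.
+//                 });
+//   br.ConsumeBuf("key", [](const Status& s, BufRendezvous::Hook* h) {
+//     if (s.ok()) {
+//       // ... read *h->prod_value ...
+//       BufRendezvous::DoneWithHook(h);  // invokes the producer callback
+//     }
+//   });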
+class BufRendezvous {
+ public:
+ explicit BufRendezvous(uint64 step_id) : step_id_(step_id) {}
+
+ ~BufRendezvous();
+
+  // Inform all waiting parties that this BufRendezvous is defunct
+ // because of an error Status interrupting the Step.
+ void StartAbort(const Status& s);
+
+ struct Hook;
+ // Provided by the consumer to be called when access to the buffer
+ // is available. If the Status arg is not OK, then hook will not
+ // be populated. Ownership of Hook passes to consumer with the
+ // callback.
+ typedef std::function<void(const Status&, Hook*)> ConsumerCallback;
+ // Provided by the producer to be called when the consumer has finished
+ // reading the buffer and will no longer access it.
+ typedef std::function<void(const Status&)> ProducerCallback;
+
+ struct Hook {
+ Device* prod_dev;
+ DeviceContext* prod_ctx;
+ const Tensor* prod_value;
+ AllocatorAttributes prod_attr;
+ ProducerCallback prod_cb;
+ ConsumerCallback cons_cb;
+ Hook()
+ : prod_dev(nullptr),
+ prod_ctx(nullptr),
+ prod_value(nullptr),
+ prod_cb(nullptr),
+ cons_cb(nullptr) {}
+ string DebugString() const;
+ };
+
+ // Called to advertise availability of a Tensor value corresponding
+ // to key. That value must stay valid until done is called.
+ void ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx,
+ const Tensor* v, const AllocatorAttributes& attr,
+ const ProducerCallback& done);
+
+ // Called to request access to a Tensor value corresponding to key.
+  // The consumer is provided with a Hook as soon as one is available.
+ void ConsumeBuf(const string& key, const ConsumerCallback& done);
+
+ // Consumer must call this function when it's done reading the Hook provided
+ // by the ConsumerCallback. This function will invoke the producer callback
+ // and then delete h.
+ static void DoneWithHook(Hook* h);
+
+ // Write the current contents of the table to the INFO log.
+ void LogContents();
+
+ protected:
+ const uint64 step_id_;
+ mutex mu_;
+ Status status_ GUARDED_BY(mu_);
+ typedef gtl::FlatMap<string, Hook*> HookTable;
+ HookTable hook_table_ GUARDED_BY(mu_);
+
+ void PurgeTable(const Status& s, HookTable* table);
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class BufRendezvousTest : public ::testing::Test {
+ protected:
+ BufRendezvousTest() {
+ br_.reset(new BufRendezvous(123));
+ fake_dev_ptr_ = reinterpret_cast<Device*>(512LLU);
+ fake_dev_ctx_ = reinterpret_cast<DeviceContext*>(1024LLU);
+ a_ = Tensor(DT_FLOAT, TensorShape({24}));
+ b_ = Tensor(DT_FLOAT, TensorShape({24}));
+ }
+
+ Device* fake_dev_ptr_ = nullptr;
+ DeviceContext* fake_dev_ctx_ = nullptr;
+ Tensor a_;
+ Tensor b_;
+ AllocatorAttributes aa_;
+ std::unique_ptr<BufRendezvous> br_;
+};
+
+TEST_F(BufRendezvousTest, CorrectUseProducerFirst) {
+ Status prod_status;
+ Status cons_status;
+ bool prod_callback_called = false;
+ bool cons_callback_called = false;
+ Notification note;
+ br_->ProvideBuf(
+ "key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [¬e, &prod_status, &prod_callback_called](const Status& s) {
+ prod_status = s;
+ prod_callback_called = true;
+ note.Notify();
+ });
+ EXPECT_FALSE(prod_callback_called);
+ br_->ConsumeBuf("key0", [this, &cons_status, &cons_callback_called](
+ const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ cons_callback_called = true;
+ ASSERT_TRUE(h != nullptr);
+ EXPECT_EQ(h->prod_dev, fake_dev_ptr_);
+ EXPECT_EQ(h->prod_ctx, fake_dev_ctx_);
+ EXPECT_EQ(h->prod_value, &a_);
+ br_->DoneWithHook(h);
+ });
+ EXPECT_TRUE(cons_callback_called);
+ note.WaitForNotification();
+ EXPECT_TRUE(prod_callback_called);
+ TF_EXPECT_OK(cons_status);
+ TF_EXPECT_OK(prod_status);
+}
+
+TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) {
+ Status prod_status;
+ Status cons_status;
+ bool prod_callback_called = false;
+ bool cons_callback_called = false;
+ Notification note;
+ br_->ConsumeBuf("key0", [this, &cons_status, &cons_callback_called](
+ const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ cons_callback_called = true;
+ ASSERT_TRUE(h != nullptr);
+ EXPECT_EQ(h->prod_dev, fake_dev_ptr_);
+ EXPECT_EQ(h->prod_ctx, fake_dev_ctx_);
+ EXPECT_EQ(h->prod_value, &a_);
+ br_->DoneWithHook(h);
+ });
+ EXPECT_FALSE(cons_callback_called);
+ br_->ProvideBuf(
+ "key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [¬e, &prod_status, &prod_callback_called](const Status& s) {
+ prod_status = s;
+ prod_callback_called = true;
+ note.Notify();
+ });
+ EXPECT_TRUE(cons_callback_called);
+ note.WaitForNotification();
+ EXPECT_TRUE(prod_callback_called);
+ TF_EXPECT_OK(cons_status);
+ TF_EXPECT_OK(prod_status);
+}
+
+TEST_F(BufRendezvousTest, ErrorDuplicatePut) {
+ bool prod_callback_called = false;
+ br_->ProvideBuf("key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [this, &prod_callback_called](const Status& s) {
+ prod_callback_called = true;
+ });
+ Status bad_status;
+ Notification note;
+ br_->ProvideBuf("key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [&bad_status, ¬e](const Status& s) {
+ bad_status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ EXPECT_FALSE(bad_status.ok());
+ EXPECT_EQ("BufRendezvous::ProvideBuf already called for key key0",
+ bad_status.error_message());
+ EXPECT_FALSE(prod_callback_called);
+ br_.reset();
+}
+
+TEST_F(BufRendezvousTest, ErrorDeleteNonEmpty) {
+ Status cons_status;
+ br_->ConsumeBuf(
+ "key0", [this, &cons_status](const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ EXPECT_EQ(h, nullptr);
+ });
+ EXPECT_TRUE(cons_status.ok());
+ br_.reset();
+ EXPECT_FALSE(cons_status.ok());
+ EXPECT_EQ("Delete called on non-empty BufRendezvous",
+ cons_status.error_message());
+}
+
+TEST_F(BufRendezvousTest, AbortNonEmpty) {
+ Status cons_status;
+ Status prod_status;
+ Notification prod_note;
+ Notification cons_note;
+ br_->ConsumeBuf("key0", [this, &cons_note, &cons_status](
+ const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ cons_note.Notify();
+ });
+ br_->ProvideBuf("key1", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [this, &prod_note, &prod_status](const Status& s) {
+ prod_status = s;
+ prod_note.Notify();
+ });
+ br_->StartAbort(errors::Internal("Falling sky detected"));
+ prod_note.WaitForNotification();
+ cons_note.WaitForNotification();
+ EXPECT_FALSE(prod_status.ok());
+ EXPECT_EQ(prod_status.error_message(), "Falling sky detected");
+ EXPECT_FALSE(cons_status.ok());
+ EXPECT_EQ(cons_status.error_message(), "Falling sky detected");
+}
+
+TEST_F(BufRendezvousTest, AbortEmpty) {
+ br_->StartAbort(errors::Internal("Falling sky detected"));
+}
+
+TEST_F(BufRendezvousTest, UseAfterAbort) {
+ br_->StartAbort(errors::Internal("Falling sky detected"));
+ Status cons_status;
+ Status prod_status;
+ Notification prod_note;
+ Notification cons_note;
+ br_->ConsumeBuf("key0", [this, &cons_note, &cons_status](
+ const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ cons_note.Notify();
+ });
+ br_->ProvideBuf("key1", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [this, &prod_note, &prod_status](const Status& s) {
+ prod_status = s;
+ prod_note.Notify();
+ });
+ prod_note.WaitForNotification();
+ cons_note.WaitForNotification();
+ EXPECT_FALSE(prod_status.ok());
+ EXPECT_EQ(prod_status.error_message(), "Falling sky detected");
+ EXPECT_FALSE(cons_status.ok());
+ EXPECT_EQ(cons_status.error_message(), "Falling sky detected");
+}
+
+} // namespace
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/build_graph_options.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+namespace {
+// TODO(tucker): Temporary class just until a real CollectiveExecutor
+// implementation is submitted in a later CL.
+class DummyCollectiveExecutor : public CollectiveExecutor {
+ public:
+ explicit DummyCollectiveExecutor(CollectiveExecutorMgr* ce_mgr)
+ : CollectiveExecutor(ce_mgr) {}
+
+ ~DummyCollectiveExecutor() override {}
+
+ void RecvFromPeer(const string& peer_device, const string& peer_task,
+ bool peer_is_local, const string& key, Device* to_device,
+ DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override {
+ done(errors::Internal("Unimplemented"));
+ }
+
+ void PostToPeer(const string& peer_device, const string& peer_task,
+ const string& key, Device* from_device,
+ DeviceContext* from_device_ctx,
+ const AllocatorAttributes& from_alloc_attr,
+ const Tensor* from_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override {
+ done(errors::Internal("Unimplemented"));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DummyCollectiveExecutor);
+};
+} // namespace
+
+CollectiveExecutorMgr::CollectiveExecutorMgr(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ ParamResolverInterface* param_resolver)
+ : dev_mgr_(dev_mgr),
+ dev_resolver_(dev_resolver),
+ param_resolver_(param_resolver) {}
+
+CollectiveExecutorMgr::~CollectiveExecutorMgr() {
+ for (auto iter : executor_table_) {
+ iter.second->Unref();
+ }
+}
+
+CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
+ CollectiveExecutor* ce = nullptr;
+ {
+ mutex_lock l(exec_mu_);
+ auto it = executor_table_.find(step_id);
+ if (it != executor_table_.end()) {
+ ce = it->second;
+ } else {
+ ce = new DummyCollectiveExecutor(this);
+ executor_table_[step_id] = ce;
+ }
+ ce->Ref();
+ }
+ return ce;
+}
+
+void CollectiveExecutorMgr::Cleanup(int64 step_id) {
+ CollectiveExecutor* ce = nullptr;
+ {
+ mutex_lock l(exec_mu_);
+ auto it = executor_table_.find(step_id);
+ if (it != executor_table_.end()) {
+ ce = it->second;
+ executor_table_.erase(it);
+ }
+ }
+ if (ce) ce->Unref();
+}
+
+void CollectiveExecutorMgr::GetStepSequenceAsync(
+ const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
+ const StatusCallback& done) {
+ done(errors::Internal(
+ "CollectiveExecutorMgr does not implement GetStepSequence."));
+}
+
+void CollectiveExecutorMgr::RefreshStepIdSequenceAsync(
+ int64 graph_key, const StatusCallback& done) {
+ done(errors::Internal(
+ "CollectiveExecutorMgr does not implement RefreshStepIdSequence."));
+}
+
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+class ConfigProto;
+class DeviceMgr;
+
+class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
+ public:
+ CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ ParamResolverInterface* param_resolver);
+
+ virtual ~CollectiveExecutorMgr();
+
+ CollectiveExecutor* FindOrCreate(int64 step_id) override;
+
+ void Cleanup(int64 step_id) override;
+
+ ParamResolverInterface* GetParamResolver() const override {
+ return param_resolver_.get();
+ }
+
+ DeviceResolverInterface* GetDeviceResolver() const override {
+ return dev_resolver_.get();
+ }
+
+ void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ const StatusCallback& done) override;
+
+ void RefreshStepIdSequenceAsync(int64 graph_key,
+ const StatusCallback& done) override;
+
+ int64 NextStepId(int64 graph_key) override {
+ return CollectiveExecutor::kInvalidId;
+ }
+
+ void RetireStepId(int64 graph_key, int64 step_id) override {}
+
+ protected:
+ const DeviceMgr* dev_mgr_;
+ std::unique_ptr<DeviceResolverInterface> dev_resolver_;
+ std::unique_ptr<ParamResolverInterface> param_resolver_;
+ CollectiveRemoteAccess* remote_access_;
+ string task_name_;
+ mutex exec_mu_;
+ // Map from step_id to CollectiveExecutor
+ gtl::FlatMap<int64, CollectiveExecutor*> executor_table_ GUARDED_BY(exec_mu_);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class CollectiveExecutorMgrTest : public ::testing::Test {
+ protected:
+ CollectiveExecutorMgrTest() {
+ ConfigProto cp;
+ SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ string task_name = "/job:localhost/replica:0/task:0";
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+ cme_.reset(new CollectiveExecutorMgr(
+ cp, device_mgr_.get(), drl,
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl, task_name)));
+ }
+
+ std::unique_ptr<CollectiveExecutorMgr> cme_;
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+};
+
+TEST_F(CollectiveExecutorMgrTest, FindOrCreate) {
+ CollectiveExecutor::Handle* h =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_TRUE(h->get());
+ CollectiveExecutor::Handle* h2 =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_EQ(h->get(), h2->get());
+ CollectiveExecutor* ce = h->get();
+ delete h;
+ delete h2;
+ CollectiveExecutor::Handle h3(cme_->FindOrCreate(1), true);
+ EXPECT_EQ(ce, h3.get());
+ cme_->Cleanup(1);
+}
+
+TEST_F(CollectiveExecutorMgrTest, StepSequenceRelated) {
+ EXPECT_EQ(CollectiveExecutor::kInvalidId, cme_->NextStepId(123));
+ Notification ss_note;
+ Status ss_status;
+ cme_->RefreshStepIdSequenceAsync(
+ 123, [this, &ss_status, &ss_note](const Status& s) {
+ ss_status = s;
+ ss_note.Notify();
+ });
+ ss_note.WaitForNotification();
+ EXPECT_FALSE(ss_status.ok());
+ EXPECT_EQ(ss_status.error_message(),
+ "CollectiveExecutorMgr does not implement RefreshStepIdSequence.");
+ Notification gs_note;
+ Status gs_status;
+ GetStepSequenceRequest* req = nullptr;
+ GetStepSequenceResponse* resp = nullptr;
+ cme_->GetStepSequenceAsync(req, resp,
+ [this, &gs_status, &gs_note](const Status& s) {
+ gs_status = s;
+ gs_note.Notify();
+ });
+ gs_note.WaitForNotification();
+ EXPECT_FALSE(gs_status.ok());
+ EXPECT_EQ(gs_status.error_message(),
+ "CollectiveExecutorMgr does not implement GetStepSequence.");
+}
+
+} // namespace
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+
+namespace tensorflow {
+
+CollectiveParamResolverLocal::CollectiveParamResolverLocal(
+ const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
+ const string& task_name)
+ : dev_mgr_(dev_mgr), dev_resolver_(dev_resolver), task_name_(task_name) {}
+
+void CollectiveParamResolverLocal::CompleteGroupAsync(
+ const CompleteGroupRequest* request, CompleteGroupResponse* response,
+ CancellationManager* cancel_mgr, const StatusCallback& done) {
+ done(
+ errors::Internal("CompleteGroup is not implemented by "
+ "CollectiveParamResolverLocal which is "
+ "intended only for non-distributed deployment."));
+}
+
+void CollectiveParamResolverLocal::CompleteGroupLocal(
+ const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
+ VLOG(1) << "CompleteGroupLocal " << cp << ": " << cp->ToString();
+ std::vector<StatusCallback> to_be_called;
+ GroupRec* gr = nullptr;
+ {
+ mutex_lock l(group_mu_);
+ auto it = group_table_.find(cp->group.group_key);
+ if (it == group_table_.end()) {
+ gr = new GroupRec;
+ gr->group.group_key = cp->group.group_key;
+ gr->group.group_size = cp->group.group_size;
+ gr->group.device_type = cp->group.device_type;
+ group_table_[gr->group.group_key].reset(gr);
+ VLOG(2) << "New group_key=" << gr->group.group_key
+ << " group_size=" << gr->group.group_size;
+ } else {
+ gr = it->second.get();
+ }
+ }
+ Status status;
+ {
+ mutex_lock gr_lock(gr->mu);
+ if (!gr->device_set.empty()) {
+ // Check for consistency with existing GroupRec.
+ if (cp->group.device_type != gr->group.device_type) {
+ status = errors::Internal(
+ "Collective Op ", cp->name, " is assigned to device ", device,
+ " with type ", cp->group.device_type.type_string(),
+ " and group_key ", cp->group.group_key, " but that group has type ",
+ gr->group.device_type.type_string());
+ } else if (cp->group.group_size != gr->group.group_size) {
+ status = errors::Internal(
+ "Collective Op ", cp->name, " has group_size ",
+ cp->group.group_size, " and group_key", cp->group.group_key,
+ " but that group has size ", gr->group.group_size);
+ }
+ }
+ if (status.ok()) {
+ // Insert device if not already present.
+ auto it = gr->device_set.find(device);
+ if (it == gr->device_set.end()) {
+ if (gr->device_set.size() == gr->group.group_size) {
+ // The group is already full.
+ status = errors::Internal(
+ "Collective Op ", cp->name, " is assigned to device ", device,
+ " and group_key ", cp->group.group_key,
+ " but that group doesn't contain that device.");
+ } else {
+ // This is a new device that has not yet joined the group.
+ gr->device_set.insert(device);
+ gr->device_list.push_back(device);
+ DeviceNameUtils::ParsedName parsed_device;
+ DeviceNameUtils::ParseFullName(device, &parsed_device);
+ string task_name = strings::StrCat("/job:", parsed_device.job,
+ "/replica:", parsed_device.replica,
+ "/task:", parsed_device.task);
+ gr->task_set.insert(task_name);
+ gr->task_list.push_back(task_name);
+ gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
+ VLOG(1) << "group_key=" << gr->group.group_key
+ << " group_size=" << gr->group.group_size
+ << " dev_set=" << gr->device_set.size();
+ }
+ }
+ }
+
+ if (status.ok()) {
+ // If the group is not yet complete, queue to wait for it.
+ VLOG(2) << "group_size " << gr->group.group_size << " set size "
+ << gr->device_set.size() << " gr " << gr;
+
+ if (gr->device_set.size() < gr->group.group_size) {
+ gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
+ return;
+ }
+ CHECK_EQ(gr->device_set.size(), gr->group.group_size);
+ if (!gr->waiting.empty()) {
+ std::swap(to_be_called, gr->waiting);
+ }
+ }
+ }
+ done(status, gr);
+ for (int i = 0; i < to_be_called.size(); ++i) {
+ to_be_called[i](Status::OK());
+ }
+}
+
+namespace {
+
+struct DevRec {
+ string task;
+ string device;
+ int original_rank;
+ int local_rank;
+ int global_rank;
+ const DeviceLocality* locality;
+};
+typedef std::unordered_map<string, DevRec> TaskDeviceMap;
+typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;
+
+// Create a populated GlobalDeviceMap from CollInstanceParams and localities.
+GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip,
+ const std::vector<DeviceLocality>& localities) {
+ GlobalDeviceMap gdm;
+ CHECK_EQ(ip.device_names.size(), ip.task_names.size());
+ CHECK_EQ(ip.device_names.size(), localities.size());
+ for (int i = 0; i < ip.device_names.size(); ++i) {
+ TaskDeviceMap& tdm = gdm[ip.task_names[i]];
+ DevRec* dr = &tdm[ip.device_names[i]];
+ dr->task = ip.task_names[i];
+ dr->device = ip.device_names[i];
+ dr->original_rank = i;
+ dr->local_rank = 0; // Will be populated later by OrderTaskDeviceMap.
+ dr->global_rank = 0; // Will be populated later by EstablishGlobalRank.
+ dr->locality = &localities[i];
+ }
+ return gdm;
+}
+
+void OrderTaskDeviceMap(TaskDeviceMap* tdm) {
+ CHECK_GT(tdm->size(), 0); // Should never be called with 0 devices
+ int least_rank = -1;
+ string next_device;
+ std::set<string> selected;
+ // Starting device is one with the least initial rank.
+ for (const auto& it : *tdm) {
+ if (least_rank < 0 || it.second.original_rank < least_rank) {
+ least_rank = it.second.original_rank;
+ next_device = it.second.device;
+ }
+ }
+ CHECK_GE(least_rank, 0);
+ DeviceNameUtils::ParsedName parsed_name;
+ CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
+ // NOTE: InterconnectLink has only a device_id, nothing more, so for
+ // the time being if there's more than one device at a task we
+ // assume they're all GPUs.
+
+ int next_rank = 0;
+ while (true) {
+ selected.insert(next_device);
+ DevRec* dr = &(*tdm)[next_device];
+ dr->local_rank = next_rank;
+ ++next_rank;
+ if (selected.size() == tdm->size()) {
+ break;
+ }
+ // For the present time we assume Locality links only cover GPUs.
+ // For multiple CPUs, just take them in order.
+ const InterconnectLink* best_link = nullptr;
+ if (parsed_name.type == "GPU") {
+ for (const InterconnectLink& il : dr->locality->links().link()) {
+ parsed_name.id = il.device_id();
+ string endpoint_device =
+ DeviceNameUtils::ParsedNameToString(parsed_name);
+ if (selected.find(endpoint_device) != selected.end()) {
+ continue;
+ }
+ if (best_link == nullptr || il.strength() > best_link->strength()) {
+ best_link = &il;
+ }
+ }
+ }
+ if (best_link != nullptr) {
+ // Follow the best edge
+ parsed_name.id = best_link->device_id();
+ next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
+ } else {
+ // No good edges, alas. Pick the lowest initial rank among remaining
+ // devices.
+ least_rank = -1;
+ for (const auto& it : *tdm) {
+ if (selected.find(it.second.device) != selected.end()) {
+ continue;
+ }
+ if (least_rank < 0 || it.second.original_rank < least_rank) {
+ least_rank = it.second.original_rank;
+ next_device = it.second.device;
+ }
+ }
+ CHECK_GE(least_rank, 0);
+ }
+ }
+}
+
+// The first time a shared CollectiveParams is established for a
+// shared set of instances we compute a good rank order for all the
+// devices in the group, that is appropriate for a ring algorithm.
+// This order need not be the same across different instance groups
+// sharing the same device group where there is more than one good
+// order.
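+//
+// Illustrative example (not from this CL): suppose task A contributes
+// devices {a0, a1} with local ranks {0, 1} and task B contributes
+// {b0, b1, b2} with local ranks {0, 1, 2}. If A appears first in
+// cp->instance.task_names, global ranks become a0:0 a1:1 b0:2 b1:3 b2:4,
+// i.e. each task's local order shifted by the device count of the tasks
+// that appeared earlier.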
+GlobalDeviceMap EstablishGlobalRank(
+ CollectiveParams* cp, const std::vector<DeviceLocality>& localities) {
+ VLOG(1) << "EstablishGlobalRank";
+ GlobalDeviceMap gdm = BuildDevRecs(cp->instance, localities);
+ for (auto& iter : gdm) {
+ TaskDeviceMap& tdm = iter.second;
+ OrderTaskDeviceMap(&tdm);
+ }
+ // Connect the global rank order by the order in which tasks first appear.
+ std::set<string> ordered_tasks;
+ int next_rank = 0;
+ for (int i = 0; i < cp->instance.task_names.size(); ++i) {
+ const string& task_name = cp->instance.task_names[i];
+ if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
+ continue;
+ }
+ ordered_tasks.insert(task_name);
+ TaskDeviceMap* tdm = &gdm[task_name];
+ for (auto& it : *tdm) {
+ it.second.global_rank = it.second.local_rank + next_rank;
+ }
+ next_rank += tdm->size();
+ }
+ return gdm;
+}
+
+// Sort cp->instance.device_names lexicographically, but do by first
+// computing a reordering permutation so we can keep cp->instance.task_names
+// in corresponding order.
+void SortDevicesAndTasks(CollectiveParams* cp) {
+ VLOG(1) << "SortDevicesAndTasks " << cp << " instance " << &cp->instance;
+ CHECK(cp);
+ CHECK_EQ(cp->group.group_size, cp->instance.device_names.size());
+ CHECK_EQ(cp->group.group_size, cp->instance.task_names.size());
+ std::vector<int> perm(cp->group.group_size);
+ // TODO(tucker): substitute std::iota when the windows build supports it.
+ // std::iota(perm.begin(), perm.end(), 0);
+ for (int i = 0; i < perm.size(); ++i) {
+ perm[i] = i;
+ }
+ std::sort(perm.begin(), perm.end(), [cp](const int& a, const int& b) {
+ return cp->instance.device_names[a] < cp->instance.device_names[b];
+ });
+ std::vector<string> new_devs;
+ std::vector<string> new_tasks;
+ new_devs.reserve(cp->group.group_size);
+ new_tasks.reserve(cp->group.group_size);
+ for (int pi : perm) {
+ new_devs.push_back(cp->instance.device_names[pi]);
+ new_tasks.push_back(cp->instance.task_names[pi]);
+ }
+ cp->instance.device_names = std::move(new_devs);
+ cp->instance.task_names = std::move(new_tasks);
+ VLOG(1) << "Modified device_names on " << cp;
+}
+
+// Establish the requested number of subdivision permutations based on the
+// ring order implicit in the device order.
+void GenerateSubdivPerms(const string& device, int source_rank,
+ CollectiveParams* cp) {
+ CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
+ cp->instance.impl_details.subdiv_permutations.resize(
+ cp->instance.impl_details.subdiv_offsets.size());
+ // Each subdiv permutation is a ring formed by rotating each
+ // single-task subsequence of devices by an offset. This makes most
+ // sense when each task has the same number of devices but we can't
+ // depend on that being the case so we'll compute something that
+ // works in any case.
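+  //
+  // Worked example (illustrative, not from this CL): two tasks with
+  // dev_per_task = {2, 3} and sorted device order d0 d1 | d2 d3 d4:
+  //   subdiv_offset 0: perm = {0, 1, 2, 3, 4}  (identity)
+  //   subdiv_offset 1: perm = {1, 0, 3, 4, 2}  (each task rotated by 1)
+  // The device at sorted position p takes subdiv_rank perm[p].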
+
+ // Start by counting the devices in each task.
+ // Precondition: device_names must be sorted so that all devices in
+ // the same task are adjacent.
+ VLOG(2) << "Sorted task names: "
+ << str_util::Join(cp->instance.task_names, ", ");
+ std::vector<int> dev_per_task;
+ const string* prior_task_name = &cp->instance.task_names[0];
+ int dev_count = 1;
+ for (int di = 1; di < cp->group.group_size; ++di) {
+ if (cp->instance.task_names[di] != *prior_task_name) {
+ dev_per_task.push_back(dev_count);
+ dev_count = 1;
+ prior_task_name = &cp->instance.task_names[di];
+ } else {
+ ++dev_count;
+ }
+ }
+ dev_per_task.push_back(dev_count);
+ CHECK_EQ(cp->group.num_tasks, dev_per_task.size());
+
+ // Generate a ring permutation for each requested offset.
+ CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
+ VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations "
+ << &cp->instance.impl_details.subdiv_permutations;
+ cp->instance.impl_details.subdiv_permutations.resize(
+ cp->instance.impl_details.subdiv_offsets.size());
+ cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1);
+ for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size();
+ ++sdi) {
+ std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int offset = cp->instance.impl_details.subdiv_offsets[sdi];
+ int prior_dev_count = 0;
+ for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
+ for (int di = 0; di < dev_per_task[ti]; ++di) {
+ int offset_di = (di + offset) % dev_per_task[ti];
+ int permuted_di = prior_dev_count + offset_di;
+ perm.push_back(permuted_di);
+ if (cp->instance.device_names[prior_dev_count + di] == device) {
+ CHECK_EQ(prior_dev_count + di, cp->default_rank);
+ cp->subdiv_rank[sdi] = permuted_di;
+ }
+ }
+ prior_dev_count += dev_per_task[ti];
+ }
+ CHECK_EQ(cp->group.group_size, perm.size());
+ }
+
+ if (cp->instance.type == BROADCAST_COLLECTIVE) {
+ CHECK_GE(source_rank, 0);
+ cp->subdiv_source_rank.resize(
+ cp->instance.impl_details.subdiv_offsets.size(), -1);
+ for (int sdi = 0; sdi < cp->subdiv_source_rank.size(); ++sdi) {
+ for (int j = 0; j < cp->group.group_size; ++j) {
+ if (cp->instance.impl_details.subdiv_permutations[sdi][j] ==
+ source_rank) {
+ cp->subdiv_source_rank[sdi] = j;
+ break;
+ }
+ }
+ CHECK_GE(cp->subdiv_source_rank[sdi], 0);
+ }
+ }
+
+ if (VLOG_IS_ON(1)) {
+ // Log the computed ring order for each subdiv.
+ string buf;
+ for (int sdi = 0;
+ sdi < cp->instance.impl_details.subdiv_permutations.size(); ++sdi) {
+ buf = strings::StrCat("Subdiv ", sdi, " device order:\n");
+ for (int di = 0;
+ di < cp->instance.impl_details.subdiv_permutations[sdi].size();
+ ++di) {
+ int idx = cp->instance.impl_details.subdiv_permutations[sdi][di];
+ strings::StrAppend(&buf, cp->instance.device_names[idx], "\n");
+ }
+ strings::StrAppend(&buf, " subdiv_offsets: ");
+ for (auto o : cp->instance.impl_details.subdiv_offsets)
+ strings::StrAppend(&buf, o, " ");
+ strings::StrAppend(&buf, " SubdivRank: ");
+ for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " ");
+ VLOG(1) << buf;
+ }
+ }
+}
+
+} // namespace
+
+void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
+ CollectiveParams* cp) {
+ cp->task.is_local.resize(cp->group.group_size, false);
+ for (int i = 0; i < cp->group.group_size; ++i) {
+ cp->task.is_local[i] = (cp->instance.task_names[i] == task_name);
+ }
+}
+
+void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
+ CollectiveParams* cp) {
+ CHECK_EQ(cp->group.group_size, cp->instance.device_names.size()) << cp;
+ for (int i = 0; i < cp->group.group_size; ++i) {
+ if (cp->instance.device_names[i] == device) {
+ cp->default_rank = i;
+ break;
+ }
+ }
+}
+
+Status CollectiveParamResolverLocal::InitInstanceSharedParams(
+ GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
+ VLOG(1) << "InitInstanceSharedParams " << ir;
+ ir->shared.instance = cp->instance;
+ {
+ mutex_lock gl(gr->mu);
+ ir->shared.group = gr->group;
+ ir->shared.instance.device_names.assign(gr->device_list.begin(),
+ gr->device_list.end());
+ ir->shared.instance.task_names.assign(gr->task_list.begin(),
+ gr->task_list.end());
+ VLOG(2) << "Initialized names for instance: "
+ << ir->shared.instance.ToString();
+ }
+ ir->shared.default_rank = -1;
+
+  // Sort device_names lexicographically, keeping task_names in
+ // corresponding order.
+ SortDevicesAndTasks(&ir->shared);
+
+ // Get Locality data for all devices.
+
+ // Set is_local and task_names in *shared prior to invoking
+ // GetDeviceLocalitiesAsync. In a distributed context this function can be
+ // called by a derived class, some of the devices may be non-local and
+ // GetDeviceLocalitiesAsync will use those fields to launch RPCs.
+ CompleteTaskIsLocal(task_name_, &ir->shared);
+ std::vector<DeviceLocality> localities;
+ Notification note;
+ Status status;
+ dev_resolver_->GetDeviceLocalitiesAsync(ir->shared.instance, &localities,
+ [¬e, &status](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ if (status.ok()) {
+ CompleteDefaultRanking(gr, cp, ir, localities);
+ }
+ return status;
+}
+
+void CollectiveParamResolverLocal::CompleteDefaultRanking(
+ GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
+ const std::vector<DeviceLocality>& localities) {
+ // Establish an instance-specific default rank order for devices
+ // based on localities. This rank order should be a good ring
+ // order, if possible.
+ GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, localities);
+ // Reflect the new global ranking on shared
+ size_t num_devices = ir->shared.group.group_size;
+ std::vector<string> new_device_names(num_devices, "");
+ std::vector<string> new_task_names(num_devices, "");
+ for (const auto& git : gdm) {
+ const TaskDeviceMap& tdm = git.second;
+ for (const auto& tit : tdm) {
+ const DevRec& dr = tit.second;
+ new_device_names[dr.global_rank] =
+ ir->shared.instance.device_names[dr.original_rank];
+ new_task_names[dr.global_rank] =
+ ir->shared.instance.task_names[dr.original_rank];
+ }
+ }
+
+ ir->shared.instance.device_names = new_device_names;
+ ir->shared.instance.task_names = new_task_names;
+ if (VLOG_IS_ON(2)) {
+ string buf;
+ for (const auto& d : cp->instance.device_names)
+ strings::StrAppend(&buf, "\n", d);
+ VLOG(2) << "Optimized device order for " << ir->shared.name << ": " << buf;
+ }
+}
+
+void CollectiveParamResolverLocal::CallbackWithStatus(
+ const InstanceRecCallback& done, InstanceRec* irec) {
+ Status s;
+ {
+ mutex_lock l(irec->out_mu);
+ s = irec->status;
+ }
+ done(s, irec);
+}
+
+void CollectiveParamResolverLocal::FindInstanceRec(
+ GroupRec* gr, CollectiveParams* cp, const InstanceRecCallback& done) {
+ InstanceRec* irec = nullptr;
+ bool exit_outside_locks = false;
+ {
+ mutex_lock l(instance_mu_);
+ auto it = instance_table_.find(cp->instance.instance_key);
+ if (it != instance_table_.end()) {
+ irec = it->second.get();
+ {
+ mutex_lock l(irec->in_mu);
+ if (irec->is_init) {
+ exit_outside_locks = true;
+ } else {
+ irec->init_waiters.push_back([this, gr, cp, done](InstanceRec* irec) {
+ CallbackWithStatus(done, irec);
+ });
+ return;
+ }
+ }
+ } else {
+ // Create new InstanceRec.
+ irec = new InstanceRec;
+ instance_table_[cp->instance.instance_key].reset(irec);
+ }
+ }
+ if (exit_outside_locks) {
+ CallbackWithStatus(done, irec);
+ return;
+ }
+ // Initialize the new InstanceRec while holding out_mu.
+ {
+ mutex_lock il(irec->out_mu);
+ irec->known.resize(cp->group.group_size, false);
+ irec->status = InitInstanceSharedParams(gr, cp, irec);
+ }
+  // Prepare to invoke any waiters that accumulated during initialization.
+ std::vector<IRConsumer> init_waiters;
+ {
+ mutex_lock tl(instance_mu_);
+ {
+ mutex_lock l(irec->in_mu);
+ irec->is_init = true;
+ if (!irec->init_waiters.empty()) {
+ std::swap(init_waiters, irec->init_waiters);
+ }
+ }
+ }
+ CallbackWithStatus(done, irec);
+ for (auto& f : init_waiters) {
+ f(irec);
+ }
+}
+
+void CollectiveParamResolverLocal::CompleteParamsAsync(
+ const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
+ const StatusCallback& done) {
+ VLOG(1) << "CompleteParams " << device << " for " << cp << ": "
+ << cp->ToString();
+ CompleteGroupLocal(
+ device, cp, [this, device, cp, done](const Status& s, GroupRec* gr) {
+ if (s.ok()) {
+ CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
+ } else {
+ done(s);
+ }
+ });
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceAsync(
+ const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
+ CancellationManager* cancel_mgr, const StatusCallback& done) {
+ done(
+ errors::Internal("CompleteInstance is not implemented by "
+ "CollectiveParamResolverLocal which is "
+ "intended only for non-distributed deployment."));
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceLocal(
+ const string& device, GroupRec* gr, CollectiveParams* cp, bool is_source,
+ const StatusCallback& done) {
+ VLOG(1) << "CompleteInstanceLocal " << device
+ << " instance_key: " << cp->instance.instance_key << " gr " << gr;
+
+ // Populate the group portion of *cp from *gr. Most of it should already
+ // match.
+ DCHECK_EQ(cp->group.group_key, gr->group.group_key);
+ DCHECK_EQ(cp->group.group_size, gr->group.group_size);
+ DCHECK_EQ(cp->group.device_type, gr->group.device_type);
+ cp->group = gr->group;
+
+ // Get the shared InstanceRec for this instance.
+ FindInstanceRec(gr, cp,
+ [this, device, gr, cp, is_source, done](const Status& s,
+ InstanceRec* ir) {
+ if (s.ok()) {
+ CompleteInstanceFromInitializedIRec(device, gr, cp, ir,
+ is_source, done);
+ } else {
+ done(s);
+ }
+ });
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
+ const string& device, GroupRec* gr, CollectiveParams* cp, InstanceRec* ir,
+ bool is_source, const StatusCallback& done) {
+ // Populate the fields common across instance.
+ {
+ mutex_lock l(ir->out_mu);
+ // custom operator= does a deep copy.
+ cp->instance = ir->shared.instance;
+ }
+ // Populate the fields common across task, also default_rank.
+ SetDefaultRank(device, cp);
+ CompleteTaskIsLocal(task_name_, cp);
+ // If broadcast, may need to wait for source discovery.
+ if (cp->instance.type == BROADCAST_COLLECTIVE) {
+ CompleteInstanceSource(ir, cp, is_source,
+ [this, ir, device, cp, done](InstanceRec* irec) {
+ CHECK_EQ(ir, irec);
+ Status s;
+ int source_rank;
+ {
+ mutex_lock l(irec->out_mu);
+ s = irec->status;
+ source_rank = ir->source_rank;
+ }
+ if (s.ok()) {
+ GenerateSubdivPerms(device, source_rank, cp);
+ }
+ done(s);
+ });
+ return;
+ } else {
+ GenerateSubdivPerms(device, 0, cp);
+ }
+ done(Status::OK());
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
+ CollectiveParams* cp,
+ bool is_source,
+ const IRConsumer& f) {
+ std::vector<IRConsumer> ready_waiters;
+ {
+ mutex_lock l(ir->out_mu);
+ CHECK_EQ(cp->group.group_size, ir->known.size());
+ CHECK_GE(cp->default_rank, 0);
+ if (!ir->known[cp->default_rank]) {
+ ir->known[cp->default_rank] = true;
+ ++ir->known_count;
+ if (is_source) {
+ if (ir->source_rank >= 0) {
+ ir->status = errors::Internal("Instance ", cp->instance.instance_key,
+ " already has source ", ir->source_rank,
+ ", recevied second claim from ",
+ cp->default_rank);
+ } else {
+ ir->source_rank = cp->default_rank;
+ }
+ }
+ }
+ if (ir->known_count < ir->shared.group.group_size) {
+ ir->known_waiters.push_back(f);
+ return;
+ }
+ CHECK_EQ(ir->known_count, ir->shared.group.group_size);
+ CHECK_GE(ir->source_rank, 0);
+ if (!ir->known_waiters.empty()) {
+ ready_waiters = std::move(ir->known_waiters);
+ }
+ }
+ f(ir);
+ for (auto& f : ready_waiters) {
+ f(ir);
+ }
+}
+
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+class CompleteGroupRequest;
+class CompleteGroupResponse;
+class CompleteInstanceRequest;
+class CompleteInstanceResponse;
+class DeviceMgr;
+
+// Implements ParamResolverInterface for a single-task context.
+// It also implements the functionality necessary to serve as the
+// group leader for param resolution in a multi-task context.
+class CollectiveParamResolverLocal : public ParamResolverInterface {
+ public:
+ CollectiveParamResolverLocal(const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ const string& task_name);
+
+ ~CollectiveParamResolverLocal() override {}
+
+ void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) override;
+
+ void CompleteGroupAsync(const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) override;
+
+ void CompleteInstanceAsync(const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) override;
+
+ protected:
+ // Used to complete/verify CollGroup.
+ struct GroupRec {
+ CollGroupParams group;
+ mutex mu;
+ Status status GUARDED_BY(mu);
+ std::set<string> device_set GUARDED_BY(mu);
+ std::vector<string> device_list GUARDED_BY(mu);
+ std::set<string> task_set GUARDED_BY(mu);
+ std::vector<string> task_list GUARDED_BY(mu);
+ std::vector<StatusCallback> waiting GUARDED_BY(mu);
+ };
+
+ // Finds the GroupRec that corresponds to cp->group_key.
+ // Also populates cp->group from that group_rec.
+ // Will wait until GroupRec is fully populated or an error arises before
+ // calling done. Callback GroupRec* arg is only valid if status is ok.
+ // Ownership of GroupRec stays with this object and does not pass to the
+ // callback.
+ typedef std::function<void(const Status& s, GroupRec* gr)> GroupRecCallback;
+ void CompleteGroupLocal(const string& device, CollectiveParams* cp,
+ const GroupRecCallback& done)
+ LOCKS_EXCLUDED(group_mu_);
+
+ // Used to complete/verify CollInstance.
+ struct InstanceRec;
+ typedef std::function<void(InstanceRec*)> IRConsumer;
+ struct InstanceRec {
+ // This structure has two mutexes so that a possibly long
+ // initialization can be done without holding the instance_mu_
+ // table lock the whole time (which can cause an excessive number
+ // of threads to block on it), and because the compiler may not
+ // permit mutex locks to be taken in more than one order.
+ //
+ // out_mu guards access to most of the fields.
+    // in_mu guards access to a queue of consumer callbacks wanting to
+ // read the fields guarded by out_mu.
+ //
+ // The in_mu should be locked only while holding instance_mu_; the
+ // out_mu should be locked only while not holding
+ // instance_mu_.
+ //
+ // When is_init is false (the initial value) any potential user
+ // other than the creator should queue a callback on init_waiters.
+ // As soon as the shared member of this structure is fully
+ // initialized is_init will be set true and those callbacks will
+ // be invoked.
+ //
+ // Once inserted in the table this structure will never be replaced
+ // so users can capture the pointer while holding instance_mu_,
+ // drop that lock, then take a lock on out_mu before
+ // reading/modifying its values.
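+    //
+    // Sketch of the intended access pattern (illustrative only;
+    // `instance_key` and `consumer_fn` are placeholders):
+    //
+    //   InstanceRec* irec = nullptr;
+    //   {
+    //     mutex_lock l(instance_mu_);
+    //     irec = instance_table_[instance_key].get();
+    //     mutex_lock il(irec->in_mu);
+    //     if (!irec->is_init) {
+    //       irec->init_waiters.push_back(consumer_fn);  // called later
+    //       return;
+    //     }
+    //   }
+    //   // Both locks dropped; only now is it safe to take out_mu.
+    //   mutex_lock ol(irec->out_mu);
+    //   ... read/modify irec->shared, irec->status ...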
+ mutex in_mu;
+ bool is_init GUARDED_BY(in_mu);
+ std::vector<IRConsumer> init_waiters GUARDED_BY(in_mu);
+
+ // Values to be shared by all instances, constant after initialization.
+ mutex out_mu;
+ CollectiveParams shared GUARDED_BY(out_mu);
+ // If an error occurs during initialization this structure stays in
+ // the table with a non-OK status. Purging the table and restarting
+ // needs to be done at a higher level.
+ Status status GUARDED_BY(out_mu);
+
+ // These fields are used to count the instances that have called
+ // in and become known while resolving broadcast source identity.
+ int source_rank GUARDED_BY(out_mu);
+ int known_count GUARDED_BY(out_mu);
+ std::vector<bool> known GUARDED_BY(out_mu);
+ std::vector<IRConsumer> known_waiters GUARDED_BY(out_mu);
+
+ InstanceRec() : is_init(false), source_rank(-1), known_count(0) {}
+ };
+
+ // Find the InstanceRec with the same instance_key as cp. If it doesn't
+ // already exist, create and initialize from gr and cp.
+ //
+ // Precondition: *gr must be a complete GroupRec, i.e. the value set
+ // by CompleteGroupLocal. *cp must be populated with all the fields
+ // required by InitInstanceSharedParams. Ownership of InstanceRec stays
+ // with this object and does not pass to the callback.
+ typedef std::function<void(const Status& s, InstanceRec* ir)>
+ InstanceRecCallback;
+ void FindInstanceRec(GroupRec* gr, CollectiveParams* cp,
+ const InstanceRecCallback& done)
+ LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
+
+ // Populate *ir with device membership from gr, then initialize to be specific
+ // to cp->instance_key, i.e. order the devices and tasks.
+ //
+ // Preconditions:
+ // cp is populated with all DeviceLocalities
+ Status InitInstanceSharedParams(GroupRec* gr, const CollectiveParams* cp,
+ InstanceRec* ir)
+ EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu) LOCKS_EXCLUDED(gr->mu);
+
+ // Establishes the final order of ir->shared.instance.device_names and
+ // ir->shared.instance.task_names by considering localities of all devices.
+ void CompleteDefaultRanking(GroupRec* gr, const CollectiveParams* cp,
+ InstanceRec* ir,
+ const std::vector<DeviceLocality>& localities)
+ EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu);
+
+ // Finish populating *cp.
+ // Precondition: *gr has been fully populated by CompleteGroupLocal.
+ void CompleteInstanceLocal(const string& device, GroupRec* gr,
+ CollectiveParams* cp, bool is_source,
+ const StatusCallback& done)
+ LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
+
+ // Finish populating *cp from fully initialized *ir.
+ // Precondition: *gr and *ir are fully populated.
+ void CompleteInstanceFromInitializedIRec(const string& device, GroupRec* gr,
+ CollectiveParams* cp,
+ InstanceRec* ir, bool is_source,
+ const StatusCallback& done)
+ LOCKS_EXCLUDED(ir->out_mu);
+
+ // Complete source data for a broadcast instance.
+ // Precondition: *cp has complete group data and default_rank.
+ void CompleteInstanceSource(InstanceRec* ir, CollectiveParams* cp,
+ bool is_source, const IRConsumer& f)
+ LOCKS_EXCLUDED(ir->out_mu);
+
+ // If cp.device_names contains only devices local to this process
+ // populates *localities, else returns an error.
+ Status GetLocalDeviceLocalities(const CollectiveParams& cp,
+ std::vector<DeviceLocality>* localities);
+
+ // Sets CollTaskParams.is_local and CollectiveParams.default_rank.
+ // Precondition: cp->device_names is fully populated and in final order.
+ void CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp);
+
+ // Sets cp->instance_default_rank according to location of device in
+ // current ordering of cp->instance.device_names.
+ void SetDefaultRank(const string& device, CollectiveParams* cp);
+
+ // Helper to grab status under lock, invoke callback out of lock.
+ void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
+ LOCKS_EXCLUDED(irec->out_mu);
+
+ const DeviceMgr* dev_mgr_;
+ DeviceResolverInterface* dev_resolver_;
+ string task_name_;
+ mutex group_mu_;
+ gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
+ GUARDED_BY(group_mu_);
+ mutex instance_mu_;
+ gtl::FlatMap<int32, std::unique_ptr<InstanceRec>> instance_table_
+ GUARDED_BY(instance_mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class CollectiveParamResolverLocalTest : public ::testing::Test {
+ protected:
+ CollectiveParamResolverLocalTest() {
+ ConfigProto cp;
+ SessionOptions options;
+ string task_name = "/job:localhost/replica:0/task:0";
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
+ prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
+ task_name));
+ }
+
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<DeviceResolverLocal> drl_;
+ std::unique_ptr<CollectiveParamResolverLocal> prl_;
+};
+
+TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
+ CollectiveParams cps[NUM_DEVS];
+ Status statuses[NUM_DEVS];
+ Notification note[NUM_DEVS];
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ CollectiveParams* cp = &cps[i];
+ cp->group.group_key = 1;
+ cp->group.group_size = 3;
+ cp->group.device_type = DeviceType("CPU");
+ cp->group.num_tasks = 1;
+ cp->instance.instance_key = 7;
+ cp->instance.type = REDUCTION_COLLECTIVE;
+ cp->instance.data_type = DataType(DT_FLOAT);
+ cp->instance.shape = TensorShape({5});
+ cp->instance.device_names.push_back(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i));
+ cp->instance.impl_details.subdiv_offsets.push_back(0);
+ cp->is_source = false;
+    Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
+ prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
+ nullptr /*CancellationManager*/,
+                               [this, &statuses, &note, i](const Status& s) {
+ statuses[i] = s;
+ note[i].Notify();
+ });
+ });
+ }
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ note[i].WaitForNotification();
+ }
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ TF_ASSERT_OK(statuses[i]);
+ ASSERT_EQ(cps[i].instance.device_names.size(), 3);
+ for (int j = 0; j < NUM_DEVS; ++j) {
+ EXPECT_EQ(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
+ cps[i].instance.device_names[j]);
+ EXPECT_TRUE(cps[i].task.is_local[j]);
+ }
+ EXPECT_EQ(cps[i].subdiv_rank[0], i);
+ EXPECT_EQ(cps[i].subdiv_source_rank.size(), 0);
+ EXPECT_FALSE(cps[i].is_source);
+ EXPECT_EQ(cps[i].default_rank, i);
+ }
+}
+
+TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
+ CollectiveParams cps[NUM_DEVS];
+ Status statuses[NUM_DEVS];
+ Notification note[NUM_DEVS];
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ CollectiveParams* cp = &cps[i];
+ cp->group.group_key = 1;
+ cp->group.group_size = 3;
+ cp->group.device_type = DeviceType("CPU");
+ cp->group.num_tasks = 1;
+ cp->instance.instance_key = 3;
+ cp->instance.type = BROADCAST_COLLECTIVE;
+ cp->instance.data_type = DataType(DT_FLOAT);
+ cp->instance.shape = TensorShape({5});
+ cp->instance.device_names.push_back(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i));
+ cp->instance.impl_details.subdiv_offsets.push_back(0);
+ cp->is_source = (i == 1);
+    Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
+ prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
+ nullptr /*CancellationManager*/,
+                               [this, &statuses, &note, i](const Status& s) {
+ statuses[i] = s;
+ note[i].Notify();
+ });
+ });
+ }
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ note[i].WaitForNotification();
+ }
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ TF_ASSERT_OK(statuses[i]);
+ ASSERT_EQ(cps[i].instance.device_names.size(), 3);
+ for (int j = 0; j < NUM_DEVS; ++j) {
+ EXPECT_EQ(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
+ cps[i].instance.device_names[j]);
+ EXPECT_TRUE(cps[i].task.is_local[j]);
+ }
+ ASSERT_GT(cps[i].subdiv_rank.size(), 0);
+ EXPECT_EQ(cps[i].subdiv_rank[0], i);
+ ASSERT_GT(cps[i].subdiv_source_rank.size(), 0);
+ EXPECT_EQ(cps[i].subdiv_source_rank[0], 1);
+ EXPECT_EQ(cps[i].is_source, (i == 1));
+ EXPECT_EQ(cps[i].default_rank, i);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+
+#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+
+namespace tensorflow {
+
+void CollectiveRemoteAccessLocal::StartAbort(const Status& s) {
+ buf_rendezvous_.StartAbort(s);
+}
+
+void CollectiveRemoteAccessLocal::RecvFromPeer(
+ const string& peer_device, const string& peer_task, bool peer_is_local,
+ const string& key, Device* to_device, DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality, const StatusCallback& done) {
+ VLOG(1) << "RecvFromPeer " << this << " from " << peer_device << " key "
+ << key;
+ if (!peer_is_local) {
+ done(
+ errors::Internal("CollectiveRemoteAccessLocal::RecvFromPeer "
+ "called with peer_is_local=false"));
+ return;
+ }
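+  // Register as consumer of the buffer under this key; the producer side
+  // arrives via PostToPeer, which calls ProvideBuf with the same key.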
+ buf_rendezvous_.ConsumeBuf(
+ key, [this, to_tensor, to_device_ctx, to_device, to_alloc_attr, done](
+ const Status& s, BufRendezvous::Hook* hook) {
+ if (!s.ok()) {
+ done(s);
+ delete hook;
+ } else {
+ int64 recv_bytes = to_tensor->TotalBytes();
+ CHECK_EQ(recv_bytes, hook->prod_value->TotalBytes());
+ MemCpyAsync(hook->prod_ctx, // src DeviceContext
+ to_device_ctx, // dst DeviceContext
+ hook->prod_dev, // src Device
+ to_device, // dst Device
+ hook->prod_attr, // src AllocatorAttributes
+ to_alloc_attr, // dst AllocatorAttributes
+ hook->prod_value, // src Tensor*
+ to_tensor, // dst Tensor*
+ [hook, done](const Status& s) {
+ done(s);
+ hook->prod_cb(s);
+ delete hook;
+ });
+ }
+ });
+}
+
+void CollectiveRemoteAccessLocal::PostToPeer(
+ const string& peer_device, const string& peer_task, const string& key,
+ Device* from_device, DeviceContext* from_device_ctx,
+ const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor,
+ const DeviceLocality& client_locality, const StatusCallback& done) {
+ VLOG(1) << "PostToPeer " << this << " key " << key
+ << " step_id_=" << step_id_;
+ buf_rendezvous_.ProvideBuf(key, from_device, from_device_ctx, from_tensor,
+ from_alloc_attr, done);
+}
+
+/*static*/
+void CollectiveRemoteAccessLocal::MemCpyAsync(
+ DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
+ Device* dst_dev, const AllocatorAttributes& src_attr,
+ const AllocatorAttributes& dst_attr, const Tensor* src, Tensor* dst,
+ const StatusCallback& done) {
+ // We want a real copy to happen, i.e. the bytes inside of src should be
+ // transferred to the buffer backing dst. If src and dst are on different
+ // devices then CopyTensor::ViaDMA will do just that. But if they're both
+ // the same CPU, then it will actually just reset dst to point to src.
+ // Since this routine is used for copying between devices and within a
+ // device, we need to detect and bypass the wrong-semantics case.
+ const DeviceType src_device_type(
+ src_attr.on_host() ? DEVICE_CPU : src_dev->attributes().device_type());
+ const DeviceType dst_device_type(
+ dst_attr.on_host() ? DEVICE_CPU : dst_dev->attributes().device_type());
+ const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
+ const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);
+ if (non_cpu_src) CHECK(src_dev_ctx);
+ if (non_cpu_dst) CHECK(dst_dev_ctx);
+ if (non_cpu_src || non_cpu_dst) {
+ CopyTensor::ViaDMA("", // edge name (non-existent)
+ src_dev_ctx, dst_dev_ctx, src_dev, dst_dev, src_attr,
+ dst_attr, src, dst, done);
+ } else {
+ int64 bytes = src->TotalBytes();
+ DCHECK_EQ(dst->TotalBytes(), bytes);
+ memcpy(DMAHelper::base(dst), DMAHelper::base(src), bytes);
+ done(Status::OK());
+ }
+}
+
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/rendezvous.h"
+
+namespace tensorflow {
+
+// Basic implementation of PerStepCollectiveRemoteAccess.
+class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
+ public:
+ CollectiveRemoteAccessLocal(const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ int64 step_id)
+ : dev_mgr_(dev_mgr),
+ dev_resolver_(dev_resolver),
+ buf_rendezvous_(step_id),
+ step_id_(step_id) {}
+
+ virtual ~CollectiveRemoteAccessLocal() {}
+
+  void StartAbort(const Status& s) override;
+
+ void RecvFromPeer(const string& peer_device, const string& peer_task,
+ bool peer_is_local, const string& key, Device* to_device,
+ DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override;
+
+ void PostToPeer(const string& peer_device, const string& peer_task,
+ const string& key, Device* from_device,
+ DeviceContext* from_device_ctx,
+ const AllocatorAttributes& from_alloc_attr,
+ const Tensor* from_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override;
+
+ void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params,
+ std::vector<DeviceLocality>* localities,
+ const StatusCallback& done) override {
+ dev_resolver_->GetDeviceLocalitiesAsync(ci_params, localities, done);
+ }
+
+ void GetLocalityAsync(const string& device, const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) override {
+ dev_resolver_->GetLocalityAsync(device, task, locality, done);
+ }
+
+ void ClearTask(const string& task) override {
+ dev_resolver_->ClearTask(task);
+ }
+
+ // Copy utility that always copies bytes from src to dst even if
+ // they are on the same device, unlike CopyTensor::ViaDMA which will
+ // just change the dst buffer pointer in that case.
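+  //
+  // A hypothetical same-CPU call (the device contexts may be null here,
+  // since the implementation requires them only for non-CPU devices):
+  //   MemCpyAsync(nullptr /*src_dev_ctx*/, nullptr /*dst_dev_ctx*/,
+  //               cpu_dev, cpu_dev, attr, attr, &src_tensor, &dst_tensor,
+  //               [](const Status& s) { TF_CHECK_OK(s); });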
+ static void MemCpyAsync(DeviceContext* src_dev_ctx,
+ DeviceContext* dst_dev_ctx, Device* src_dev,
+ Device* dst_dev, const AllocatorAttributes& src_attr,
+ const AllocatorAttributes& dst_attr,
+ const Tensor* src, Tensor* dst,
+ const StatusCallback& done);
+
+ protected:
+ const DeviceMgr* dev_mgr_; // not owned
+ DeviceResolverInterface* dev_resolver_; // not owned
+ BufRendezvous buf_rendezvous_;
+ int64 step_id_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+static const int kStepId = 123;
+
+class CollectiveRemoteAccessLocalTest : public ::testing::Test {
+ protected:
+ const string kTaskName = "/job:localhost/replica:0/task:0";
+
+ CollectiveRemoteAccessLocalTest() {
+ ConfigProto cp;
+ SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
+ prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
+ kTaskName));
+ rma_.reset(new CollectiveRemoteAccessLocal(device_mgr_.get(), drl_.get(),
+ kStepId));
+ }
+
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<DeviceResolverLocal> drl_;
+ std::unique_ptr<CollectiveParamResolverLocal> prl_;
+ std::unique_ptr<CollectiveRemoteAccessLocal> rma_;
+};
+
+TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) {
+ Device* cpu0 = nullptr;
+ AllocatorAttributes attr;
+ DeviceLocality dev_locality;
+ TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0));
+ Tensor sink_tensor(DT_FLOAT, TensorShape({8}));
+ Notification recv_note;
+ Status recv_status;
+ rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/,
+ "key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/,
+ attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
+ [this, &recv_note, &recv_status](const Status& s) {
+ recv_status = s;
+ recv_note.Notify();
+ });
+ Tensor source_tensor(DT_FLOAT, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ source_tensor.flat<float>()(i) = i / 2;
+ }
+ // Tensors have distinct storage.
+ EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+ Notification send_note;
+ Status send_status;
+ rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0",
+ cpu0 /*from_device*/, nullptr /*from_device_ctx*/,
+                   attr /*from_alloc_attr*/, &source_tensor, dev_locality,
+ [this, &send_note, &send_status](const Status& s) {
+ send_status = s;
+ send_note.Notify();
+ });
+ recv_note.WaitForNotification();
+ send_note.WaitForNotification();
+ TF_EXPECT_OK(recv_status);
+ TF_EXPECT_OK(send_status);
+ // Sink tensor gets the source tensor values.
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(sink_tensor.flat<float>()(i), i / 2);
+ }
+ // And still has distinct storage.
+ EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+}
+
+TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
+ Device* cpu2 = nullptr;
+ AllocatorAttributes attr;
+ DeviceLocality dev_locality;
+ TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:2", &cpu2));
+ Tensor sink_tensor(DT_FLOAT, TensorShape({8}));
+ Notification recv_note;
+ Status recv_status;
+ rma_->RecvFromPeer(kTaskName + "/device:CPU:1", kTaskName, true /*is_local*/,
+ "key_0", cpu2 /*to_device*/, nullptr /*to_device_ctx*/,
+ attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
+ [this, &recv_note, &recv_status](const Status& s) {
+ recv_status = s;
+ recv_note.Notify();
+ });
+ Tensor source_tensor(DT_FLOAT, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ source_tensor.flat<float>()(i) = i / 2;
+ }
+ // Tensors have distinct storage.
+ EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+ Device* cpu1 = nullptr;
+ TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:1", &cpu1));
+ Notification send_note;
+ Status send_status;
+ rma_->PostToPeer(kTaskName + "/device:CPU:2", kTaskName, "key_0",
+ cpu1 /*from_device*/, nullptr /*from_device_ctx*/,
+                   attr /*from_alloc_attr*/, &source_tensor, dev_locality,
+ [this, &send_note, &send_status](const Status& s) {
+ send_status = s;
+ send_note.Notify();
+ });
+ recv_note.WaitForNotification();
+ send_note.WaitForNotification();
+ TF_EXPECT_OK(recv_status);
+ TF_EXPECT_OK(send_status);
+ // Sink tensor gets the source tensor values.
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(sink_tensor.flat<float>()(i), i / 2);
+ }
+ // And still has distinct storage.
+ EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+}
+
+} // namespace
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+
+namespace tensorflow {
+
+void DeviceResolverLocal::GetDeviceLocalitiesAsync(
+ const CollInstanceParams& ci_params,
+ std::vector<DeviceLocality>* localities, const StatusCallback& done) {
+ localities->clear();
+ for (const string& device_name : ci_params.device_names) {
+ Device* dev;
+ Status s = dev_mgr_->LookupDevice(device_name, &dev);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+ localities->push_back(dev->attributes().locality());
+ }
+ done(Status::OK());
+}
+
+void DeviceResolverLocal::GetLocalityAsync(const string& device,
+ const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) {
+ Device* dev;
+ Status s = dev_mgr_->LookupDevice(device, &dev);
+ if (s.ok()) {
+ *locality = dev->attributes().locality();
+ }
+ done(s);
+}
+
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+
+namespace tensorflow {
+class DeviceMgr;
+
+// Implements DeviceResolverInterface in a single-task context.
+class DeviceResolverLocal : public DeviceResolverInterface {
+ public:
+ DeviceResolverLocal(const DeviceMgr* dev_mgr) : dev_mgr_(dev_mgr) {}
+
+ virtual ~DeviceResolverLocal() {}
+
+ void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params,
+ std::vector<DeviceLocality>* localities,
+ const StatusCallback& done) override;
+
+ void GetLocalityAsync(const string& device, const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) override;
+
+ void ClearTask(const string& task) override {}
+
+ protected:
+ const DeviceMgr* dev_mgr_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class DeviceResolverLocalTest : public ::testing::Test {
+ protected:
+ DeviceResolverLocalTest() {
+ ConfigProto cp;
+ SessionOptions options;
+ string task_name = "/job:localhost/replica:0/task:0";
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
+ }
+
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<DeviceResolverLocal> drl_;
+};
+
+TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesKnown) {
+ CollectiveParams cp;
+ std::vector<DeviceLocality> localities;
+ cp.instance.device_names.push_back(
+ "/job:localhost/replica:0/task:0/device:CPU:1");
+ cp.instance.device_names.push_back(
+ "/job:localhost/replica:0/task:0/device:CPU:2");
+ Notification note;
+ Status status;
+ drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
+                                 [this, &note, &status](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ TF_EXPECT_OK(status);
+ EXPECT_EQ(2, localities.size());
+}
+
+TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesUnknown) {
+ CollectiveParams cp;
+ std::vector<DeviceLocality> localities;
+ // In some builds there may be 1 GPU, but there should never be 9.
+ cp.instance.device_names.push_back(
+ "/job:localhost/replica:0/task:0/device:GPU:9");
+ Notification note;
+ Status status;
+ drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
+                                 [this, &note, &status](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ EXPECT_FALSE(status.ok());
+ EXPECT_EQ(0, localities.size());
+}
+
+} // namespace
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/collective.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+string CollGroupParams::ToString() const {
+ return strings::StrCat("CollGroupParams {group_key=", group_key,
+ " group_size=", group_size,
+ " device_type=", device_type.type_string(),
+ " num_tasks=", num_tasks, "}");
+}
+
+CollInstanceParams& CollInstanceParams::operator=(
+ const CollInstanceParams& other) {
+ if (this != &other) {
+ instance_key = other.instance_key;
+ type = other.type;
+ data_type = other.data_type;
+ shape = other.shape;
+ device_names.clear();
+ device_names.assign(other.device_names.begin(), other.device_names.end());
+ task_names.assign(other.task_names.begin(), other.task_names.end());
+ impl_details.subdiv_offsets.assign(
+ other.impl_details.subdiv_offsets.begin(),
+ other.impl_details.subdiv_offsets.end());
+ impl_details.subdiv_permutations.clear();
+    for (const auto& p : other.impl_details.subdiv_permutations) {
+ impl_details.subdiv_permutations.push_back(
+ std::vector<int>(p.begin(), p.end()));
+ }
+ impl_details.subdiv_source_rank.assign(
+ other.impl_details.subdiv_source_rank.begin(),
+ other.impl_details.subdiv_source_rank.end());
+ }
+ return *this;
+}
+
+string CollInstanceParams::ToString() const {
+ string v = strings::StrCat("CollInstanceParams { instance_key=", instance_key,
+ " type=", type, " data_type=", data_type,
+ " shape=", shape.DebugString(), " devices {");
+ for (const auto& d : device_names) {
+ strings::StrAppend(&v, d, ",");
+ }
+ strings::StrAppend(&v, "} task_names={");
+ for (const auto& n : task_names) {
+ strings::StrAppend(&v, n, ", ");
+ }
+ strings::StrAppend(&v, "}, subdiv_offsets={");
+ for (const auto& d : impl_details.subdiv_offsets) {
+ strings::StrAppend(&v, d, ",");
+ }
+ strings::StrAppend(&v, "}, subdiv_perms={");
+ for (const auto& p : impl_details.subdiv_permutations) {
+ strings::StrAppend(&v, "{");
+ for (const auto& i : p) {
+ strings::StrAppend(&v, i, ",");
+ }
+ strings::StrAppend(&v, "}"); // one subdiv
+ }
+ strings::StrAppend(&v, "}"); // all subdivs
+ return v;
+}
+
+string CollTaskParams::ToString() const {
+ string v = strings::StrCat("CollTaskParams {is_local={");
+ for (const auto& b : is_local) {
+ strings::StrAppend(&v, static_cast<int>(b), ",");
+ }
+ strings::StrAppend(&v, "}}");
+ return v;
+}
+
+string CollectiveParams::ToString() const {
+ string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
+ strings::StrAppend(&v, " ", instance.ToString());
+ strings::StrAppend(&v, " ", task.ToString());
+ strings::StrAppend(&v, " default_rank=", default_rank,
+ " is_source=", is_source, " subdiv_rank={");
+ for (const auto& r : subdiv_rank) {
+ strings::StrAppend(&v, r, ",");
+ }
+  strings::StrAppend(&v, "}");
+  if (!subdiv_source_rank.empty()) {
+    strings::StrAppend(&v, " subdiv_source_rank={");
+    for (const auto& r : subdiv_source_rank) {
+      strings::StrAppend(&v, r, ",");
+    }
+    strings::StrAppend(&v, "}");
+  }
+  strings::StrAppend(&v, "}");
+ return v;
+}
+
+/*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
+ OpKernelContext* ctx) {
+ return ctx->params_;
+}
+
+/*static*/
+int64 CollectiveExecutor::kInvalidId = -1;
+
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+#define TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+class BufRendezvous;
+class CancellationManager;
+class CompleteGroupRequest;
+class CompleteGroupResponse;
+class CompleteInstanceRequest;
+class CompleteInstanceResponse;
+class DeviceLocality;
+class GetStepSequenceRequest;
+class GetStepSequenceResponse;
+class Op;
+class Tensor;
+
+// Types of supported collective operations.
+enum CollectiveType {
+ REDUCTION_COLLECTIVE = 0,
+ BROADCAST_COLLECTIVE,
+ UNDEFINED_COLLECTIVE,
+};
+
+// Data common to all members of a device group.
+// All members share the same device set, but its order is particular
+// to an instance, so it is stored there.
+struct CollGroupParams {
+ int32 group_key;
+ int32 group_size;
+ DeviceType device_type;
+ int32 num_tasks; // number of distinct tasks in group
+ string ToString() const;
+ CollGroupParams() : device_type(DEVICE_CPU) {}
+};
+
+// The best implementation of a collective op depends on many factors
+// including the number of devices involved, the topology of
+// interconnects between them and the sizes of inputs. This structure
+// is used in generating and representing data movement choreography
+// for each specific algorithm, hence it does not have a single, fixed
+// interpretation. On first execution the runtime will update this
+// structure with decisions that will guide all subsequent executions.
+struct CollImplDetails {
+ std::vector<std::vector<int>> subdiv_permutations;
+ std::vector<int> subdiv_offsets;
+ // broadcast only: rank of source in each subdiv
+ std::vector<int> subdiv_source_rank;
+};
+
+// Data common to all members of a collective instance.
+struct CollInstanceParams {
+ int32 instance_key; // Identifies all participating graph nodes.
+ CollectiveType type;
+ DataType data_type;
+ TensorShape shape;
+ // Fully qualified name of device for each member, in default rank order.
+ std::vector<string> device_names;
+ // Task name prefix of corresponding device name.
+ std::vector<string> task_names;
+ CollImplDetails impl_details;
+ string ToString() const;
+ CollInstanceParams& operator=(const struct CollInstanceParams& other);
+};
+
+// Data common to all instance members in the same task.
+struct CollTaskParams {
+ // True for devices that are local to the process, i.e. no RPC needed.
+ std::vector<bool> is_local;
+ string ToString() const;
+};
+
+// Unique to a single CollectiveOp node.
+struct CollectiveParams {
+ CollGroupParams group;
+ CollInstanceParams instance;
+ CollTaskParams task;
+
+ string name; // node name used only for log or error messages
+ int default_rank; // index of this op within device_names
+ bool is_source; // broadcast only
+ // Rank of this device in each subdivision permutation.
+ std::vector<int> subdiv_rank;
+ std::vector<int> subdiv_source_rank;
+ const Tensor* in_tensor; // kernel input
+ Tensor* out_tensor; // kernel output
+ std::unique_ptr<OpKernel> merge_op; // reduction only
+ std::unique_ptr<OpKernel> final_op; // reduction only
+ OpKernelContext* op_context;
+ string ToString() const;
+};
+
+class CollectiveExecutor;
+
+// Interface that provides resolution of device localities.
+class DeviceResolverInterface {
+ public:
+ virtual ~DeviceResolverInterface() {}
+
+  // Collects DeviceLocality protobufs from all of the devices identified
+  // in 'inst_params'.
+ virtual void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
+ std::vector<DeviceLocality>* localities,
+ const StatusCallback& done) = 0;
+
+ // Populate *locality with the DeviceLocality of the specified
+ // device.
+ virtual void GetLocalityAsync(const string& device, const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) = 0;
+
+ // Clear the cache of device data belonging
+ // to the specified task.
+ virtual void ClearTask(const string& task) = 0;
+};
+
+// Interface that provides resolution of shared CollectiveParams fields.
+class ParamResolverInterface {
+ public:
+ virtual ~ParamResolverInterface() {}
+
+ // Called by each collective op at first execution in order to fill out
+ // the CollectiveParams structure with data gathered from the full
+ // (maybe distributed) collection of peer nodes.
+ virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) = 0;
+
+ // Used within a distributed implementation to discover/verify
+ // data shared across a device group.
+ virtual void CompleteGroupAsync(const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) = 0;
+
+ // Used within a distributed implementation to discover/verify data
+ // shared across an instance group.
+ virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) = 0;
+};
+
+// Graphs which utilize Collective Ops in a common instance must
+// execute with identical step_ids even if they are disjoint graphs
+// run by otherwise independent tasks. This interface supplies
+// coordinated step_ids to use in such cases.
+class StepSequenceInterface {
+ public:
+ virtual ~StepSequenceInterface() {}
+
+ // Used with a distributed implementation to coordinate step_id
+ // sequences across tasks.
+ virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ const StatusCallback& done) = 0;
+
+ // Refresh the local per-graph_key step_id sequence from collective
+ // group leader, if applicable.
+ virtual void RefreshStepIdSequenceAsync(int64 graph_key,
+ const StatusCallback& done) = 0;
+
+  // Returns the step_id that should be used for initiating a new execution
+  // on the specified graph. May return the same step_id multiple times if
+  // RetireStepId or RefreshStepIdSequenceAsync is not called.
+ virtual int64 NextStepId(int64 graph_key) = 0;
+
+ // Reports that execution of the given step has completed successfully.
+ // Should be called immediately after a step completes with OK status,
+ // prior to calling NextStepId(). If the step fails, don't call.
+ virtual void RetireStepId(int64 graph_key, int64 step_id) = 0;
+};
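+
+// A hypothetical use of this sequence protocol, assuming `ss` implements
+// StepSequenceInterface and the step finished with status `s`:
+//   int64 step_id = ss->NextStepId(graph_key);
+//   // ... execute the graph for this step ...
+//   if (s.ok()) ss->RetireStepId(graph_key, step_id);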
+
+// Interface that provides access to per-step CollectiveExecutor
+// instances and various distributed resolution capabilities.
+class CollectiveExecutorMgrInterface : public StepSequenceInterface {
+ public:
+ virtual ~CollectiveExecutorMgrInterface() {}
+
+ // Returns the step-specific CollectiveExecutor, creating if one does not
+ // already exist. The caller assumes ownership of one Ref on the object.
+ virtual CollectiveExecutor* FindOrCreate(int64 step_id) = 0;
+
+ // If there is a CollectiveExecutor for step_id, remove it from the
+ // table.
+ virtual void Cleanup(int64 step_id) = 0;
+
+ virtual ParamResolverInterface* GetParamResolver() const = 0;
+
+ virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
+};
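+
+// A hypothetical per-step lifecycle against this interface, assuming `cem`
+// implements CollectiveExecutorMgrInterface; FindOrCreate transfers one Ref
+// to the caller, so release it when the step is done:
+//   CollectiveExecutor* ce = cem->FindOrCreate(step_id);
+//   // ... use ce to execute collectives for this step ...
+//   ce->Unref();
+//   cem->Cleanup(step_id);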
+
+// Interface that a Collective Op implementation uses to exchange data
+// with peers. Note that data exchange is currently limited to types
+// for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric
+// types.
+class PeerAccessInterface {
+ public:
+ virtual ~PeerAccessInterface() {}
+
+ virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
+ bool peer_is_local, const string& key,
+ Device* to_device, DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr,
+ Tensor* to_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) = 0;
+
+ virtual void PostToPeer(const string& peer_device, const string& peer_task,
+ const string& key, Device* from_device,
+ DeviceContext* from_device_ctx,
+ const AllocatorAttributes& from_alloc_attr,
+ const Tensor* from_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) = 0;
+};
+
+class PerStepCollectiveRemoteAccess;
+
+// A step-specific object that can execute a collective operation completely
+// described by a CollectiveParams object.
+class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted {
+ public:
+ virtual void StartAbort(const Status& s) {}
+
+ virtual void ExecuteAsync(OpKernelContext* ctx,
+ const CollectiveParams& col_params,
+ const string& exec_key, StatusCallback done) {
+ done(errors::Internal(
+ "A collective Op has been called in a context in which "
+ "a CollectiveExecutor has not been provided."));
+ }
+
+ virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+ CancellationManager* cancel_mgr,
+ StatusCallback done) {
+ cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
+ }
+
+ virtual PerStepCollectiveRemoteAccess* remote_access() { return nullptr; }
+
+ // Used to designate an invalid group or instance key.
+ static int64 kInvalidId;
+
+ // Lexically scoped handle for Ref.
+ class Handle {
+ public:
+ explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
+ if (!inherit_ref) ce->Ref();
+ }
+ ~Handle() { ce_->Unref(); }
+ CollectiveExecutor* get() const { return ce_; }
+
+ private:
+ CollectiveExecutor* ce_;
+ };
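+
+  // A hypothetical scoped use, assuming `ce` is a CollectiveExecutor* whose
+  // Ref the caller already owns (e.g. from FindOrCreate above):
+  //   CollectiveExecutor::Handle handle(ce, true /*inherit_ref*/);
+  //   handle.get()->ExecuteAsync(ctx, col_params, exec_key, done);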
+
+ protected:
+ explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
+ : cem_(cem) {}
+
+ // For use only by derived classes
+ static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
+ CollectiveExecutorMgrInterface* cem_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
+};
+
+// Interface of a helper object that provides a CollectiveExecutor with
+// all of the remote access it needs.
+class CollectiveRemoteAccess : public PeerAccessInterface,
+ public DeviceResolverInterface {
+ public:
+ virtual ~CollectiveRemoteAccess() {}
+};
+
+// A per-step version of CollectiveRemoteAccess that cleans up outstanding
+// communications in case step execution is abandoned.
+class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
+ public:
+ virtual ~PerStepCollectiveRemoteAccess() {}
+ virtual void StartAbort(const Status& s) = 0;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
void NotifyUseOfPersistentTensor(const Tensor& tensor);
Status status_;
+ friend class CollectiveExecutor; // for access to params_
Params* params_; // not owned
mutable mutex mu_; // mutable so const accessors can acquire the lock
gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);