Collective Ops Part 1
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 29 Mar 2018 00:06:44 +0000 (17:06 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 00:09:25 +0000 (17:09 -0700)
The basic interface definitions, local-only versions of remote-access,
param-resolution, device-resolution and mgr.

A collective op is able to execute synchronously across devices
and across separate graphs. Collective ops to be introduced eventually
include broadcast and all-reduce.  This change is part of a series of
changes that will introduce the necessary infrastructure then the
initial op implementations.

PiperOrigin-RevId: 190860248

19 files changed:
tensorflow/core/BUILD
tensorflow/core/common_runtime/buf_rendezvous.cc [new file with mode: 0644]
tensorflow/core/common_runtime/buf_rendezvous.h [new file with mode: 0644]
tensorflow/core/common_runtime/buf_rendezvous_test.cc [new file with mode: 0644]
tensorflow/core/common_runtime/collective_executor_mgr.cc [new file with mode: 0644]
tensorflow/core/common_runtime/collective_executor_mgr.h [new file with mode: 0644]
tensorflow/core/common_runtime/collective_executor_mgr_test.cc [new file with mode: 0644]
tensorflow/core/common_runtime/collective_param_resolver_local.cc [new file with mode: 0644]
tensorflow/core/common_runtime/collective_param_resolver_local.h [new file with mode: 0644]
tensorflow/core/common_runtime/collective_param_resolver_local_test.cc [new file with mode: 0644]
tensorflow/core/common_runtime/collective_rma_local.cc [new file with mode: 0644]
tensorflow/core/common_runtime/collective_rma_local.h [new file with mode: 0644]
tensorflow/core/common_runtime/collective_rma_local_test.cc [new file with mode: 0644]
tensorflow/core/common_runtime/device_resolver_local.cc [new file with mode: 0644]
tensorflow/core/common_runtime/device_resolver_local.h [new file with mode: 0644]
tensorflow/core/common_runtime/device_resolver_local_test.cc [new file with mode: 0644]
tensorflow/core/framework/collective.cc [new file with mode: 0644]
tensorflow/core/framework/collective.h [new file with mode: 0644]
tensorflow/core/framework/op_kernel.h

index 4726946..7121064 100644 (file)
@@ -455,6 +455,7 @@ tf_cuda_library(
         "framework/attr_value_util.h",
         "framework/bfloat16.h",
         "framework/cancellation.h",
+        "framework/collective.h",
         "framework/common_shape_fns.h",
         "framework/control_flow.h",  # TODO(josh11b): Make internal?
         "framework/dataset.h",
@@ -2172,6 +2173,11 @@ tf_cuda_library(
 CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/allocator_retry.h",
     "common_runtime/bfc_allocator.h",
+    "common_runtime/collective_executor_mgr.h",
+    "common_runtime/collective_param_resolver_local.h",
+    "common_runtime/collective_rma_local.h",
+    "common_runtime/device_resolver_local.h",
+    "common_runtime/buf_rendezvous.h",
     "common_runtime/build_graph_options.h",
     "common_runtime/constant_folding.h",
     "common_runtime/copy_tensor.h",
@@ -2210,7 +2216,11 @@ tf_cuda_library(
         "common_runtime/accumulate_n_optimizer.cc",
         "common_runtime/allocator_retry.cc",
         "common_runtime/bfc_allocator.cc",
+        "common_runtime/buf_rendezvous.cc",
         "common_runtime/build_graph_options.cc",
+        "common_runtime/collective_executor_mgr.cc",
+        "common_runtime/collective_param_resolver_local.cc",
+        "common_runtime/collective_rma_local.cc",
         "common_runtime/constant_folding.cc",
         "common_runtime/copy_tensor.cc",
         "common_runtime/costmodel_manager.cc",
@@ -2218,6 +2228,7 @@ tf_cuda_library(
         "common_runtime/device.cc",
         "common_runtime/device_factory.cc",
         "common_runtime/device_mgr.cc",
+        "common_runtime/device_resolver_local.cc",
         "common_runtime/device_set.cc",
         "common_runtime/executor.cc",
         "common_runtime/function.cc",
@@ -2825,6 +2836,11 @@ tf_cc_tests(
     name = "higher_level_tests",
     size = "small",
     srcs = [
+        "common_runtime/buf_rendezvous_test.cc",
+        "common_runtime/collective_executor_mgr_test.cc",
+        "common_runtime/collective_param_resolver_local_test.cc",
+        "common_runtime/collective_rma_local_test.cc",
+        "common_runtime/device_resolver_local_test.cc",
         "common_runtime/device_set_test.cc",
         "common_runtime/optimization_registry_test.cc",
         "common_runtime/pending_counts_test.cc",
diff --git a/tensorflow/core/common_runtime/buf_rendezvous.cc b/tensorflow/core/common_runtime/buf_rendezvous.cc
new file mode 100644 (file)
index 0000000..b57eb29
--- /dev/null
@@ -0,0 +1,166 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+
+namespace tensorflow {
+
+BufRendezvous::~BufRendezvous() {
+  mutex_lock l(mu_);
+  if (!hook_table_.empty()) {
+    PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
+               &hook_table_);
+  }
+}
+
+void BufRendezvous::StartAbort(const Status& s) {
+  CHECK(!s.ok());
+  HookTable dummy_table;
+  {
+    mutex_lock l(mu_);
+    status_.Update(s);
+    hook_table_.swap(dummy_table);
+  }
+  PurgeTable(s, &dummy_table);
+}
+
+void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
+  for (auto& it : *table) {
+    Hook* h = it.second;
+    if (h->cons_cb != nullptr) {
+      h->cons_cb(s, nullptr);
+    }
+    if (h->prod_cb != nullptr) {
+      h->prod_cb(s);
+    }
+    delete h;
+  }
+  table->clear();
+}
+
+string BufRendezvous::Hook::DebugString() const {
+  return strings::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
+                         ", ctx:", reinterpret_cast<uint64>(prod_ctx),
+                         ", val:", reinterpret_cast<uint64>(prod_value),
+                         ", pcb:", reinterpret_cast<uint64>(&prod_cb),
+                         ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
+}
+
+void BufRendezvous::ProvideBuf(const string& key, Device* dev,
+                               DeviceContext* dev_ctx, const Tensor* v,
+                               const AllocatorAttributes& attr,
+                               const ProducerCallback& done) {
+  Hook* h = nullptr;
+  Status providebuf_status;
+  do {
+    mutex_lock l(mu_);
+    if (!status_.ok()) {
+      providebuf_status = status_;
+      break;
+    } else {
+      auto it = hook_table_.find(key);
+      if (it == hook_table_.end()) {
+        h = new Hook;
+        it = hook_table_.insert(std::make_pair(key, h)).first;
+      } else {
+        if (it->second->prod_cb != nullptr) {
+          providebuf_status = errors::Internal(
+              "BufRendezvous::ProvideBuf already called for key ", key);
+          break;
+        }
+        h = it->second;
+      }
+      // Populate Hook with all of the prod values.
+      h->prod_dev = dev;
+      h->prod_ctx = dev_ctx;
+      h->prod_value = v;
+      h->prod_attr = attr;
+      h->prod_cb = done;
+      // If consumer is waiting, kick off right away, removing Hook from table.
+      if (h->cons_cb != nullptr) {
+        hook_table_.erase(it);
+      } else {
+        h = nullptr;
+      }
+    }
+  } while (false);
+  if (h) {
+    h->cons_cb(Status::OK(), h);
+  }
+  if (!providebuf_status.ok()) {
+    done(providebuf_status);
+  }
+}
+
+void BufRendezvous::ConsumeBuf(const string& key,
+                               const ConsumerCallback& done) {
+  Hook* existing_hook = nullptr;
+  Status consumebuf_status;
+  do {
+    mutex_lock l(mu_);
+    if (!status_.ok()) {
+      consumebuf_status = status_;
+      break;
+    }
+    auto it = hook_table_.find(key);
+    if (it != hook_table_.end()) {
+      // Prepare to consume immediately.
+      if (it->second->cons_cb) {
+        consumebuf_status =
+            errors::Internal("Second consumer arrived for key ", key);
+        break;
+      }
+      existing_hook = it->second;
+      hook_table_.erase(it);
+      existing_hook->cons_cb = done;
+    } else {
+      // Hang consumer callback on the Hook.
+      Hook* h = new Hook;
+      hook_table_[key] = h;
+      h->cons_cb = done;
+      return;
+    }
+  } while (false);
+  if (existing_hook) {
+    existing_hook->cons_cb(Status::OK(), existing_hook);
+    return;
+  }
+  if (!consumebuf_status.ok()) {
+    done(consumebuf_status, nullptr);
+    return;
+  }
+}
+
+/*static*/
+void BufRendezvous::DoneWithHook(Hook* h) {
+  h->prod_cb(Status::OK());
+  delete h;
+}
+
+void BufRendezvous::LogContents() {
+  mutex_lock l(mu_);
+  LOG(INFO) << strings::StrCat("BufRendezvous ",
+                               strings::Hex(reinterpret_cast<uint64>(this)),
+                               " step_id=", step_id_, " current contents:");
+  for (auto it : hook_table_) {
+    LOG(INFO) << it.first << ":" << it.second->DebugString();
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/buf_rendezvous.h b/tensorflow/core/common_runtime/buf_rendezvous.h
new file mode 100644 (file)
index 0000000..e94e88b
--- /dev/null
@@ -0,0 +1,103 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+#define TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+
+#include <functional>
+#include <string>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+class Device;
+class DeviceContext;
+class Tensor;
+
+// EXPERIMENTAL: RDMA oriented producer/consumer rendezvous on a local
+// Tensor value for which DMAHelper::CanUseDMA() is true, i.e. dense
+// numeric types.  Similar to Rendezvous but never owns a Ref on the
+// tensor, instead it uses an explicit callback to the producer when
+// the consumer side is finished with the value.  This allows the
+// producer to perform in-place updates on the source buffer or to take
+// other actions that depend on knowing the consumer has passed a certain
+// execution point.
+class BufRendezvous {
+ public:
+  explicit BufRendezvous(uint64 step_id) : step_id_(step_id) {}
+
+  ~BufRendezvous();
+
+  // Inform all all waiting parties that this BufRendezvous is defunct
+  // because of an error Status interrupting the Step.
+  void StartAbort(const Status& s);
+
+  struct Hook;
+  // Provided by the consumer to be called when access to the buffer
+  // is available.  If the Status arg is not OK, then hook will not
+  // be populated.  Ownership of Hook passes to consumer with the
+  // callback.
+  typedef std::function<void(const Status&, Hook*)> ConsumerCallback;
+  // Provided by the producer to be called when the consumer has finished
+  // reading the buffer and will no longer access it.
+  typedef std::function<void(const Status&)> ProducerCallback;
+
+  struct Hook {
+    Device* prod_dev;
+    DeviceContext* prod_ctx;
+    const Tensor* prod_value;
+    AllocatorAttributes prod_attr;
+    ProducerCallback prod_cb;
+    ConsumerCallback cons_cb;
+    Hook()
+        : prod_dev(nullptr),
+          prod_ctx(nullptr),
+          prod_value(nullptr),
+          prod_cb(nullptr),
+          cons_cb(nullptr) {}
+    string DebugString() const;
+  };
+
+  // Called to advertise availability of a Tensor value corresponding
+  // to key.  That value must stay valid until done is called.
+  void ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx,
+                  const Tensor* v, const AllocatorAttributes& attr,
+                  const ProducerCallback& done);
+
+  // Called to request access to a Tensor value corresponding to key.
+  // Consumer is provide with a Hook as soon as availble.
+  void ConsumeBuf(const string& key, const ConsumerCallback& done);
+
+  // Consumer must call this function when it's done reading the Hook provided
+  // by the ConsumerCallback.  This function will invoke the producer callback
+  // and then delete h.
+  static void DoneWithHook(Hook* h);
+
+  // Write the current contents of the table to the INFO log.
+  void LogContents();
+
+ protected:
+  const uint64 step_id_;
+  mutex mu_;
+  Status status_ GUARDED_BY(mu_);
+  typedef gtl::FlatMap<string, Hook*> HookTable;
+  HookTable hook_table_ GUARDED_BY(mu_);
+
+  void PurgeTable(const Status& s, HookTable* table);
+};
+}  // namespace tensorflow
+#endif  // TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
diff --git a/tensorflow/core/common_runtime/buf_rendezvous_test.cc b/tensorflow/core/common_runtime/buf_rendezvous_test.cc
new file mode 100644 (file)
index 0000000..0e79823
--- /dev/null
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class BufRendezvousTest : public ::testing::Test {
+ protected:
+  BufRendezvousTest() {
+    br_.reset(new BufRendezvous(123));
+    fake_dev_ptr_ = reinterpret_cast<Device*>(512LLU);
+    fake_dev_ctx_ = reinterpret_cast<DeviceContext*>(1024LLU);
+    a_ = Tensor(DT_FLOAT, TensorShape({24}));
+    b_ = Tensor(DT_FLOAT, TensorShape({24}));
+  }
+
+  Device* fake_dev_ptr_ = nullptr;
+  DeviceContext* fake_dev_ctx_ = nullptr;
+  Tensor a_;
+  Tensor b_;
+  AllocatorAttributes aa_;
+  std::unique_ptr<BufRendezvous> br_;
+};
+
+TEST_F(BufRendezvousTest, CorrectUseProducerFirst) {
+  Status prod_status;
+  Status cons_status;
+  bool prod_callback_called = false;
+  bool cons_callback_called = false;
+  Notification note;
+  br_->ProvideBuf(
+      "key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+      [&note, &prod_status, &prod_callback_called](const Status& s) {
+        prod_status = s;
+        prod_callback_called = true;
+        note.Notify();
+      });
+  EXPECT_FALSE(prod_callback_called);
+  br_->ConsumeBuf("key0", [this, &cons_status, &cons_callback_called](
+                              const Status& s, BufRendezvous::Hook* h) {
+    cons_status = s;
+    cons_callback_called = true;
+    ASSERT_TRUE(h != nullptr);
+    EXPECT_EQ(h->prod_dev, fake_dev_ptr_);
+    EXPECT_EQ(h->prod_ctx, fake_dev_ctx_);
+    EXPECT_EQ(h->prod_value, &a_);
+    br_->DoneWithHook(h);
+  });
+  EXPECT_TRUE(cons_callback_called);
+  note.WaitForNotification();
+  EXPECT_TRUE(prod_callback_called);
+  TF_EXPECT_OK(cons_status);
+  TF_EXPECT_OK(prod_status);
+}
+
+TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) {
+  Status prod_status;
+  Status cons_status;
+  bool prod_callback_called = false;
+  bool cons_callback_called = false;
+  Notification note;
+  br_->ConsumeBuf("key0", [this, &cons_status, &cons_callback_called](
+                              const Status& s, BufRendezvous::Hook* h) {
+    cons_status = s;
+    cons_callback_called = true;
+    ASSERT_TRUE(h != nullptr);
+    EXPECT_EQ(h->prod_dev, fake_dev_ptr_);
+    EXPECT_EQ(h->prod_ctx, fake_dev_ctx_);
+    EXPECT_EQ(h->prod_value, &a_);
+    br_->DoneWithHook(h);
+  });
+  EXPECT_FALSE(cons_callback_called);
+  br_->ProvideBuf(
+      "key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+      [&note, &prod_status, &prod_callback_called](const Status& s) {
+        prod_status = s;
+        prod_callback_called = true;
+        note.Notify();
+      });
+  EXPECT_TRUE(cons_callback_called);
+  note.WaitForNotification();
+  EXPECT_TRUE(prod_callback_called);
+  TF_EXPECT_OK(cons_status);
+  TF_EXPECT_OK(prod_status);
+}
+
+TEST_F(BufRendezvousTest, ErrorDuplicatePut) {
+  bool prod_callback_called = false;
+  br_->ProvideBuf("key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+                  [this, &prod_callback_called](const Status& s) {
+                    prod_callback_called = true;
+                  });
+  Status bad_status;
+  Notification note;
+  br_->ProvideBuf("key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+                  [&bad_status, &note](const Status& s) {
+                    bad_status = s;
+                    note.Notify();
+                  });
+  note.WaitForNotification();
+  EXPECT_FALSE(bad_status.ok());
+  EXPECT_EQ("BufRendezvous::ProvideBuf already called for key key0",
+            bad_status.error_message());
+  EXPECT_FALSE(prod_callback_called);
+  br_.reset();
+}
+
+TEST_F(BufRendezvousTest, ErrorDeleteNonEmpty) {
+  Status cons_status;
+  br_->ConsumeBuf(
+      "key0", [this, &cons_status](const Status& s, BufRendezvous::Hook* h) {
+        cons_status = s;
+        EXPECT_EQ(h, nullptr);
+      });
+  EXPECT_TRUE(cons_status.ok());
+  br_.reset();
+  EXPECT_FALSE(cons_status.ok());
+  EXPECT_EQ("Delete called on non-empty BufRendezvous",
+            cons_status.error_message());
+}
+
+TEST_F(BufRendezvousTest, AbortNonEmpty) {
+  Status cons_status;
+  Status prod_status;
+  Notification prod_note;
+  Notification cons_note;
+  br_->ConsumeBuf("key0", [this, &cons_note, &cons_status](
+                              const Status& s, BufRendezvous::Hook* h) {
+    cons_status = s;
+    cons_note.Notify();
+  });
+  br_->ProvideBuf("key1", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+                  [this, &prod_note, &prod_status](const Status& s) {
+                    prod_status = s;
+                    prod_note.Notify();
+                  });
+  br_->StartAbort(errors::Internal("Falling sky detected"));
+  prod_note.WaitForNotification();
+  cons_note.WaitForNotification();
+  EXPECT_FALSE(prod_status.ok());
+  EXPECT_EQ(prod_status.error_message(), "Falling sky detected");
+  EXPECT_FALSE(cons_status.ok());
+  EXPECT_EQ(cons_status.error_message(), "Falling sky detected");
+}
+
+TEST_F(BufRendezvousTest, AbortEmpty) {
+  br_->StartAbort(errors::Internal("Falling sky detected"));
+}
+
+TEST_F(BufRendezvousTest, UseAfterAbort) {
+  br_->StartAbort(errors::Internal("Falling sky detected"));
+  Status cons_status;
+  Status prod_status;
+  Notification prod_note;
+  Notification cons_note;
+  br_->ConsumeBuf("key0", [this, &cons_note, &cons_status](
+                              const Status& s, BufRendezvous::Hook* h) {
+    cons_status = s;
+    cons_note.Notify();
+  });
+  br_->ProvideBuf("key1", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+                  [this, &prod_note, &prod_status](const Status& s) {
+                    prod_status = s;
+                    prod_note.Notify();
+                  });
+  prod_note.WaitForNotification();
+  cons_note.WaitForNotification();
+  EXPECT_FALSE(prod_status.ok());
+  EXPECT_EQ(prod_status.error_message(), "Falling sky detected");
+  EXPECT_FALSE(cons_status.ok());
+  EXPECT_EQ(cons_status.error_message(), "Falling sky detected");
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc
new file mode 100644 (file)
index 0000000..a5c4946
--- /dev/null
@@ -0,0 +1,114 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/build_graph_options.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+namespace {
+// TODO(tucker): Temporary class just until a real CollectiveExecutor
+// implementation is submitted in a later CL.
+class DummyCollectiveExecutor : public CollectiveExecutor {
+ public:
+  explicit DummyCollectiveExecutor(CollectiveExecutorMgr* ce_mgr)
+      : CollectiveExecutor(ce_mgr) {}
+
+  ~DummyCollectiveExecutor() override {}
+
+  void RecvFromPeer(const string& peer_device, const string& peer_task,
+                    bool peer_is_local, const string& key, Device* to_device,
+                    DeviceContext* to_device_ctx,
+                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+                    const DeviceLocality& client_locality,
+                    const StatusCallback& done) override {
+    done(errors::Internal("Unimplemented"));
+  }
+
+  void PostToPeer(const string& peer_device, const string& peer_task,
+                  const string& key, Device* from_device,
+                  DeviceContext* from_device_ctx,
+                  const AllocatorAttributes& from_alloc_attr,
+                  const Tensor* from_tensor,
+                  const DeviceLocality& client_locality,
+                  const StatusCallback& done) override {
+    done(errors::Internal("Unimplemented"));
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(DummyCollectiveExecutor);
+};
+}  // namespace
+
+CollectiveExecutorMgr::CollectiveExecutorMgr(
+    const ConfigProto& config, const DeviceMgr* dev_mgr,
+    DeviceResolverInterface* dev_resolver,
+    ParamResolverInterface* param_resolver)
+    : dev_mgr_(dev_mgr),
+      dev_resolver_(dev_resolver),
+      param_resolver_(param_resolver) {}
+
+CollectiveExecutorMgr::~CollectiveExecutorMgr() {
+  for (auto iter : executor_table_) {
+    iter.second->Unref();
+  }
+}
+
+CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
+  CollectiveExecutor* ce = nullptr;
+  {
+    mutex_lock l(exec_mu_);
+    auto it = executor_table_.find(step_id);
+    if (it != executor_table_.end()) {
+      ce = it->second;
+    } else {
+      ce = new DummyCollectiveExecutor(this);
+      executor_table_[step_id] = ce;
+    }
+    ce->Ref();
+  }
+  return ce;
+}
+
+void CollectiveExecutorMgr::Cleanup(int64 step_id) {
+  CollectiveExecutor* ce = nullptr;
+  {
+    mutex_lock l(exec_mu_);
+    auto it = executor_table_.find(step_id);
+    if (it != executor_table_.end()) {
+      ce = it->second;
+      executor_table_.erase(it);
+    }
+  }
+  if (ce) ce->Unref();
+}
+
+void CollectiveExecutorMgr::GetStepSequenceAsync(
+    const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
+    const StatusCallback& done) {
+  done(errors::Internal(
+      "CollectiveExecutorMgr does not implement GetStepSequence."));
+}
+
+void CollectiveExecutorMgr::RefreshStepIdSequenceAsync(
+    int64 graph_key, const StatusCallback& done) {
+  done(errors::Internal(
+      "CollectiveExecutorMgr does not implement RefreshStepIdSequence."));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
new file mode 100644 (file)
index 0000000..4b42e2b
--- /dev/null
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+class ConfigProto;
+class DeviceMgr;
+
+class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
+ public:
+  CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr,
+                        DeviceResolverInterface* dev_resolver,
+                        ParamResolverInterface* param_resolver);
+
+  virtual ~CollectiveExecutorMgr();
+
+  CollectiveExecutor* FindOrCreate(int64 step_id) override;
+
+  void Cleanup(int64 step_id) override;
+
+  ParamResolverInterface* GetParamResolver() const override {
+    return param_resolver_.get();
+  }
+
+  DeviceResolverInterface* GetDeviceResolver() const override {
+    return dev_resolver_.get();
+  }
+
+  void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+                            GetStepSequenceResponse* response,
+                            const StatusCallback& done) override;
+
+  void RefreshStepIdSequenceAsync(int64 graph_key,
+                                  const StatusCallback& done) override;
+
+  int64 NextStepId(int64 graph_key) override {
+    return CollectiveExecutor::kInvalidId;
+  }
+
+  void RetireStepId(int64 graph_key, int64 step_id) override {}
+
+ protected:
+  const DeviceMgr* dev_mgr_;
+  std::unique_ptr<DeviceResolverInterface> dev_resolver_;
+  std::unique_ptr<ParamResolverInterface> param_resolver_;
+  CollectiveRemoteAccess* remote_access_;
+  string task_name_;
+  mutex exec_mu_;
+  // Map from step_id to CollectiveExecutor
+  gtl::FlatMap<int64, CollectiveExecutor*> executor_table_ GUARDED_BY(exec_mu_);
+};
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
new file mode 100644 (file)
index 0000000..34c9163
--- /dev/null
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class CollectiveExecutorMgrTest : public ::testing::Test {
+ protected:
+  CollectiveExecutorMgrTest() {
+    ConfigProto cp;
+    SessionOptions options;
+    auto* device_count = options.config.mutable_device_count();
+    string task_name = "/job:localhost/replica:0/task:0";
+    device_count->insert({"CPU", NUM_DEVS});
+    TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+    device_mgr_.reset(new DeviceMgr(devices_));
+    DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+    cme_.reset(new CollectiveExecutorMgr(
+        cp, device_mgr_.get(), drl,
+        new CollectiveParamResolverLocal(device_mgr_.get(), drl, task_name)));
+  }
+
+  std::unique_ptr<CollectiveExecutorMgr> cme_;
+  std::vector<Device*> devices_;
+  std::unique_ptr<DeviceMgr> device_mgr_;
+};
+
+TEST_F(CollectiveExecutorMgrTest, FindOrCreate) {
+  CollectiveExecutor::Handle* h =
+      new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+  EXPECT_TRUE(h->get());
+  CollectiveExecutor::Handle* h2 =
+      new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+  EXPECT_EQ(h->get(), h2->get());
+  CollectiveExecutor* ce = h->get();
+  delete h;
+  delete h2;
+  CollectiveExecutor::Handle h3(cme_->FindOrCreate(1), true);
+  EXPECT_EQ(ce, h3.get());
+  cme_->Cleanup(1);
+}
+
+TEST_F(CollectiveExecutorMgrTest, StepSequenceRelated) {
+  EXPECT_EQ(CollectiveExecutor::kInvalidId, cme_->NextStepId(123));
+  Notification ss_note;
+  Status ss_status;
+  cme_->RefreshStepIdSequenceAsync(
+      123, [this, &ss_status, &ss_note](const Status& s) {
+        ss_status = s;
+        ss_note.Notify();
+      });
+  ss_note.WaitForNotification();
+  EXPECT_FALSE(ss_status.ok());
+  EXPECT_EQ(ss_status.error_message(),
+            "CollectiveExecutorMgr does not implement RefreshStepIdSequence.");
+  Notification gs_note;
+  Status gs_status;
+  GetStepSequenceRequest* req = nullptr;
+  GetStepSequenceResponse* resp = nullptr;
+  cme_->GetStepSequenceAsync(req, resp,
+                             [this, &gs_status, &gs_note](const Status& s) {
+                               gs_status = s;
+                               gs_note.Notify();
+                             });
+  gs_note.WaitForNotification();
+  EXPECT_FALSE(gs_status.ok());
+  EXPECT_EQ(gs_status.error_message(),
+            "CollectiveExecutorMgr does not implement GetStepSequence.");
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
new file mode 100644 (file)
index 0000000..b34950b
--- /dev/null
@@ -0,0 +1,666 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+
+namespace tensorflow {
+
+CollectiveParamResolverLocal::CollectiveParamResolverLocal(
+    const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
+    const string& task_name)
+    : dev_mgr_(dev_mgr), dev_resolver_(dev_resolver), task_name_(task_name) {}
+
+void CollectiveParamResolverLocal::CompleteGroupAsync(
+    const CompleteGroupRequest* request, CompleteGroupResponse* response,
+    CancellationManager* cancel_mgr, const StatusCallback& done) {
+  done(
+      errors::Internal("CompleteGroup is not implemented by "
+                       "CollectiveParamResolverLocal which is "
+                       "intended only for non-distributed deployment."));
+}
+
+void CollectiveParamResolverLocal::CompleteGroupLocal(
+    const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
+  VLOG(1) << "CompleteGroupLocal " << cp << ": " << cp->ToString();
+  std::vector<StatusCallback> to_be_called;
+  GroupRec* gr = nullptr;
+  {
+    mutex_lock l(group_mu_);
+    auto it = group_table_.find(cp->group.group_key);
+    if (it == group_table_.end()) {
+      gr = new GroupRec;
+      gr->group.group_key = cp->group.group_key;
+      gr->group.group_size = cp->group.group_size;
+      gr->group.device_type = cp->group.device_type;
+      group_table_[gr->group.group_key].reset(gr);
+      VLOG(2) << "New group_key=" << gr->group.group_key
+              << " group_size=" << gr->group.group_size;
+    } else {
+      gr = it->second.get();
+    }
+  }
+  Status status;
+  {
+    mutex_lock gr_lock(gr->mu);
+    if (!gr->device_set.empty()) {
+      // Check for consistency with existing GroupRec.
+      if (cp->group.device_type != gr->group.device_type) {
+        status = errors::Internal(
+            "Collective Op ", cp->name, " is assigned to device ", device,
+            " with type ", cp->group.device_type.type_string(),
+            " and group_key ", cp->group.group_key, " but that group has type ",
+            gr->group.device_type.type_string());
+      } else if (cp->group.group_size != gr->group.group_size) {
+        status = errors::Internal(
+            "Collective Op ", cp->name, " has group_size ",
+            cp->group.group_size, " and group_key", cp->group.group_key,
+            " but that group has size ", gr->group.group_size);
+      }
+    }
+    if (status.ok()) {
+      // Insert device if not already present.
+      auto it = gr->device_set.find(device);
+      if (it == gr->device_set.end()) {
+        if (gr->device_set.size() == gr->group.group_size) {
+          // The group is already full.
+          status = errors::Internal(
+              "Collective Op ", cp->name, " is assigned to device ", device,
+              " and group_key ", cp->group.group_key,
+              " but that group doesn't contain that device.");
+        } else {
+          // This is a new device that has not yet joined the group.
+          gr->device_set.insert(device);
+          gr->device_list.push_back(device);
+          DeviceNameUtils::ParsedName parsed_device;
+          DeviceNameUtils::ParseFullName(device, &parsed_device);
+          string task_name = strings::StrCat("/job:", parsed_device.job,
+                                             "/replica:", parsed_device.replica,
+                                             "/task:", parsed_device.task);
+          gr->task_set.insert(task_name);
+          gr->task_list.push_back(task_name);
+          gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
+          VLOG(1) << "group_key=" << gr->group.group_key
+                  << " group_size=" << gr->group.group_size
+                  << " dev_set=" << gr->device_set.size();
+        }
+      }
+    }
+
+    if (status.ok()) {
+      // If the group is not yet complete, queue to wait for it.
+      VLOG(2) << "group_size " << gr->group.group_size << " set size "
+              << gr->device_set.size() << " gr " << gr;
+
+      if (gr->device_set.size() < gr->group.group_size) {
+        gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
+        return;
+      }
+      CHECK_EQ(gr->device_set.size(), gr->group.group_size);
+      if (!gr->waiting.empty()) {
+        std::swap(to_be_called, gr->waiting);
+      }
+    }
+  }
+  done(status, gr);
+  for (int i = 0; i < to_be_called.size(); ++i) {
+    to_be_called[i](Status::OK());
+  }
+}
+
+namespace {
+
+struct DevRec {
+  string task;
+  string device;
+  int original_rank;
+  int local_rank;
+  int global_rank;
+  const DeviceLocality* locality;
+};
+typedef std::unordered_map<string, DevRec> TaskDeviceMap;
+typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;
+
+// Create a populated GlobalDeviceMap from CollInstanceParams and localities.
+GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip,
+                             const std::vector<DeviceLocality>& localities) {
+  GlobalDeviceMap gdm;
+  CHECK_EQ(ip.device_names.size(), ip.task_names.size());
+  CHECK_EQ(ip.device_names.size(), localities.size());
+  for (int i = 0; i < ip.device_names.size(); ++i) {
+    TaskDeviceMap& tdm = gdm[ip.task_names[i]];
+    DevRec* dr = &tdm[ip.device_names[i]];
+    dr->task = ip.task_names[i];
+    dr->device = ip.device_names[i];
+    dr->original_rank = i;
+    dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
+    dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
+    dr->locality = &localities[i];
+  }
+  return gdm;
+}
+
+void OrderTaskDeviceMap(TaskDeviceMap* tdm) {
+  CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices
+  int least_rank = -1;
+  string next_device;
+  std::set<string> selected;
+  // Starting device is one with the least initial rank.
+  for (const auto& it : *tdm) {
+    if (least_rank < 0 || it.second.original_rank < least_rank) {
+      least_rank = it.second.original_rank;
+      next_device = it.second.device;
+    }
+  }
+  CHECK_GE(least_rank, 0);
+  DeviceNameUtils::ParsedName parsed_name;
+  CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
+  // NOTE: InterconnectLink has only a device_id, nothing more, so for
+  // the time being if there's more than one device at a task we
+  // assume they're all GPUs.
+
+  int next_rank = 0;
+  while (true) {
+    selected.insert(next_device);
+    DevRec* dr = &(*tdm)[next_device];
+    dr->local_rank = next_rank;
+    ++next_rank;
+    if (selected.size() == tdm->size()) {
+      break;
+    }
+    // For the present time we assume Locality links only cover GPUs.
+    // For multiple CPUs, just take them in order.
+    const InterconnectLink* best_link = nullptr;
+    if (parsed_name.type == "GPU") {
+      for (const InterconnectLink& il : dr->locality->links().link()) {
+        parsed_name.id = il.device_id();
+        string endpoint_device =
+            DeviceNameUtils::ParsedNameToString(parsed_name);
+        if (selected.find(endpoint_device) != selected.end()) {
+          continue;
+        }
+        if (best_link == nullptr || il.strength() > best_link->strength()) {
+          best_link = &il;
+        }
+      }
+    }
+    if (best_link != nullptr) {
+      // Follow the best edge
+      parsed_name.id = best_link->device_id();
+      next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
+    } else {
+      // No good edges, alas. Pick the lowest initial rank among remaining
+      // devices.
+      least_rank = -1;
+      for (const auto& it : *tdm) {
+        if (selected.find(it.second.device) != selected.end()) {
+          continue;
+        }
+        if (least_rank < 0 || it.second.original_rank < least_rank) {
+          least_rank = it.second.original_rank;
+          next_device = it.second.device;
+        }
+      }
+      CHECK_GE(least_rank, 0);
+    }
+  }
+}
+
+// The first time a shared CollectiveParams is established for a
+// shared set of instances we compute a good rank order for all the
+// devices in the group, that is appropriate for a ring algorithm.
+// This order need not be the same across different instance groups
+// sharing the same device group where there is more than one good
+// order.
+GlobalDeviceMap EstablishGlobalRank(
+    CollectiveParams* cp, const std::vector<DeviceLocality>& localities) {
+  VLOG(1) << "EstablishGlobalRank";
+  GlobalDeviceMap gdm = BuildDevRecs(cp->instance, localities);
+  for (auto& iter : gdm) {
+    TaskDeviceMap& tdm = iter.second;
+    OrderTaskDeviceMap(&tdm);
+  }
+  // Connect the global rank order by the order in which tasks first appear.
+  std::set<string> ordered_tasks;
+  int next_rank = 0;
+  for (int i = 0; i < cp->instance.task_names.size(); ++i) {
+    const string& task_name = cp->instance.task_names[i];
+    if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
+      continue;
+    }
+    ordered_tasks.insert(task_name);
+    TaskDeviceMap* tdm = &gdm[task_name];
+    for (auto& it : *tdm) {
+      it.second.global_rank = it.second.local_rank + next_rank;
+    }
+    next_rank += tdm->size();
+  }
+  return gdm;
+}
+
+// Sort cp->instance.device_names lexicographically, but do by first
+// computing a reordering permutation so we can keep cp->instance.task_names
+// in corresponding order.
+void SortDevicesAndTasks(CollectiveParams* cp) {
+  VLOG(1) << "SortDevicesAndTasks " << cp << " instance " << &cp->instance;
+  CHECK(cp);
+  CHECK_EQ(cp->group.group_size, cp->instance.device_names.size());
+  CHECK_EQ(cp->group.group_size, cp->instance.task_names.size());
+  std::vector<int> perm(cp->group.group_size);
+  // TODO(tucker): substitute std::iota when the windows build supports it.
+  // std::iota(perm.begin(), perm.end(), 0);
+  for (int i = 0; i < perm.size(); ++i) {
+    perm[i] = i;
+  }
+  std::sort(perm.begin(), perm.end(), [cp](const int& a, const int& b) {
+    return cp->instance.device_names[a] < cp->instance.device_names[b];
+  });
+  std::vector<string> new_devs;
+  std::vector<string> new_tasks;
+  new_devs.reserve(cp->group.group_size);
+  new_tasks.reserve(cp->group.group_size);
+  for (int pi : perm) {
+    new_devs.push_back(cp->instance.device_names[pi]);
+    new_tasks.push_back(cp->instance.task_names[pi]);
+  }
+  cp->instance.device_names = std::move(new_devs);
+  cp->instance.task_names = std::move(new_tasks);
+  VLOG(1) << "Modified device_names on " << cp;
+}
+
+// Establish the requested number of subdivision permutations based on the
+// ring order implicit in the device order.
+void GenerateSubdivPerms(const string& device, int source_rank,
+                         CollectiveParams* cp) {
+  CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
+  cp->instance.impl_details.subdiv_permutations.resize(
+      cp->instance.impl_details.subdiv_offsets.size());
+  // Each subdiv permutation is a ring formed by rotating each
+  // single-task subsequence of devices by an offset.  This makes most
+  // sense when each task has the same number of devices but we can't
+  // depend on that being the case so we'll compute something that
+  // works in any case.
+
+  // Start by counting the devices in each task.
+  // Precondition: device_names must be sorted so that all devices in
+  // the same task are adjacent.
+  VLOG(2) << "Sorted task names: "
+          << str_util::Join(cp->instance.task_names, ", ");
+  std::vector<int> dev_per_task;
+  const string* prior_task_name = &cp->instance.task_names[0];
+  int dev_count = 1;
+  for (int di = 1; di < cp->group.group_size; ++di) {
+    if (cp->instance.task_names[di] != *prior_task_name) {
+      dev_per_task.push_back(dev_count);
+      dev_count = 1;
+      prior_task_name = &cp->instance.task_names[di];
+    } else {
+      ++dev_count;
+    }
+  }
+  dev_per_task.push_back(dev_count);
+  CHECK_EQ(cp->group.num_tasks, dev_per_task.size());
+
+  // Generate a ring permutation for each requested offset.
+  CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
+  VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations "
+          << &cp->instance.impl_details.subdiv_permutations;
+  cp->instance.impl_details.subdiv_permutations.resize(
+      cp->instance.impl_details.subdiv_offsets.size());
+  cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1);
+  for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size();
+       ++sdi) {
+    std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
+    CHECK_EQ(perm.size(), 0);
+    int offset = cp->instance.impl_details.subdiv_offsets[sdi];
+    int prior_dev_count = 0;
+    for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
+      for (int di = 0; di < dev_per_task[ti]; ++di) {
+        int offset_di = (di + offset) % dev_per_task[ti];
+        int permuted_di = prior_dev_count + offset_di;
+        perm.push_back(permuted_di);
+        if (cp->instance.device_names[prior_dev_count + di] == device) {
+          CHECK_EQ(prior_dev_count + di, cp->default_rank);
+          cp->subdiv_rank[sdi] = permuted_di;
+        }
+      }
+      prior_dev_count += dev_per_task[ti];
+    }
+    CHECK_EQ(cp->group.group_size, perm.size());
+  }
+
+  if (cp->instance.type == BROADCAST_COLLECTIVE) {
+    CHECK_GE(source_rank, 0);
+    cp->subdiv_source_rank.resize(
+        cp->instance.impl_details.subdiv_offsets.size(), -1);
+    for (int sdi = 0; sdi < cp->subdiv_source_rank.size(); ++sdi) {
+      for (int j = 0; j < cp->group.group_size; ++j) {
+        if (cp->instance.impl_details.subdiv_permutations[sdi][j] ==
+            source_rank) {
+          cp->subdiv_source_rank[sdi] = j;
+          break;
+        }
+      }
+      CHECK_GE(cp->subdiv_source_rank[sdi], 0);
+    }
+  }
+
+  if (VLOG_IS_ON(1)) {
+    // Log the computed ring order for each subdiv.
+    string buf;
+    for (int sdi = 0;
+         sdi < cp->instance.impl_details.subdiv_permutations.size(); ++sdi) {
+      buf = strings::StrCat("Subdiv ", sdi, " device order:\n");
+      for (int di = 0;
+           di < cp->instance.impl_details.subdiv_permutations[sdi].size();
+           ++di) {
+        int idx = cp->instance.impl_details.subdiv_permutations[sdi][di];
+        strings::StrAppend(&buf, cp->instance.device_names[idx], "\n");
+      }
+      strings::StrAppend(&buf, " subdiv_offsets: ");
+      for (auto o : cp->instance.impl_details.subdiv_offsets)
+        strings::StrAppend(&buf, o, " ");
+      strings::StrAppend(&buf, " SubdivRank: ");
+      for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " ");
+      VLOG(1) << buf;
+    }
+  }
+}
+
+}  // namespace
+
+void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
+                                                       CollectiveParams* cp) {
+  cp->task.is_local.resize(cp->group.group_size, false);
+  for (int i = 0; i < cp->group.group_size; ++i) {
+    cp->task.is_local[i] = (cp->instance.task_names[i] == task_name);
+  }
+}
+
+void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
+                                                  CollectiveParams* cp) {
+  CHECK_EQ(cp->group.group_size, cp->instance.device_names.size()) << cp;
+  for (int i = 0; i < cp->group.group_size; ++i) {
+    if (cp->instance.device_names[i] == device) {
+      cp->default_rank = i;
+      break;
+    }
+  }
+}
+
+Status CollectiveParamResolverLocal::InitInstanceSharedParams(
+    GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
+  VLOG(1) << "InitInstanceSharedParams " << ir;
+  ir->shared.instance = cp->instance;
+  {
+    mutex_lock gl(gr->mu);
+    ir->shared.group = gr->group;
+    ir->shared.instance.device_names.assign(gr->device_list.begin(),
+                                            gr->device_list.end());
+    ir->shared.instance.task_names.assign(gr->task_list.begin(),
+                                          gr->task_list.end());
+    VLOG(2) << "Initialized names for instance: "
+            << ir->shared.instance.ToString();
+  }
+  ir->shared.default_rank = -1;
+
+  // Sort devce_names lexicographcally, keeping task_names in
+  // corresponding order.
+  SortDevicesAndTasks(&ir->shared);
+
+  // Get Locality data for all devices.
+
+  // Set is_local and task_names in *shared prior to invoking
+  // GetDeviceLocalitiesAsync.  In a distributed context this function can be
+  // called by a derived class, some of the devices may be non-local and
+  // GetDeviceLocalitiesAsync will use those fields to launch RPCs.
+  CompleteTaskIsLocal(task_name_, &ir->shared);
+  std::vector<DeviceLocality> localities;
+  Notification note;
+  Status status;
+  dev_resolver_->GetDeviceLocalitiesAsync(ir->shared.instance, &localities,
+                                          [&note, &status](const Status& s) {
+                                            status = s;
+                                            note.Notify();
+                                          });
+  note.WaitForNotification();
+  if (status.ok()) {
+    CompleteDefaultRanking(gr, cp, ir, localities);
+  }
+  return status;
+}
+
+void CollectiveParamResolverLocal::CompleteDefaultRanking(
+    GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
+    const std::vector<DeviceLocality>& localities) {
+  // Establish an instance-specific default rank order for devices
+  // based on localities.  This rank order should be a good ring
+  // order, if possible.
+  GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, localities);
+  // Reflect the new global ranking on shared
+  size_t num_devices = ir->shared.group.group_size;
+  std::vector<string> new_device_names(num_devices, "");
+  std::vector<string> new_task_names(num_devices, "");
+  for (const auto& git : gdm) {
+    const TaskDeviceMap& tdm = git.second;
+    for (const auto& tit : tdm) {
+      const DevRec& dr = tit.second;
+      new_device_names[dr.global_rank] =
+          ir->shared.instance.device_names[dr.original_rank];
+      new_task_names[dr.global_rank] =
+          ir->shared.instance.task_names[dr.original_rank];
+    }
+  }
+
+  ir->shared.instance.device_names = new_device_names;
+  ir->shared.instance.task_names = new_task_names;
+  if (VLOG_IS_ON(2)) {
+    string buf;
+    for (const auto& d : cp->instance.device_names)
+      strings::StrAppend(&buf, "\n", d);
+    VLOG(2) << "Optimized device order for " << ir->shared.name << ": " << buf;
+  }
+}
+
+void CollectiveParamResolverLocal::CallbackWithStatus(
+    const InstanceRecCallback& done, InstanceRec* irec) {
+  Status s;
+  {
+    mutex_lock l(irec->out_mu);
+    s = irec->status;
+  }
+  done(s, irec);
+}
+
+void CollectiveParamResolverLocal::FindInstanceRec(
+    GroupRec* gr, CollectiveParams* cp, const InstanceRecCallback& done) {
+  InstanceRec* irec = nullptr;
+  bool exit_outside_locks = false;
+  {
+    mutex_lock l(instance_mu_);
+    auto it = instance_table_.find(cp->instance.instance_key);
+    if (it != instance_table_.end()) {
+      irec = it->second.get();
+      {
+        mutex_lock l(irec->in_mu);
+        if (irec->is_init) {
+          exit_outside_locks = true;
+        } else {
+          irec->init_waiters.push_back([this, gr, cp, done](InstanceRec* irec) {
+            CallbackWithStatus(done, irec);
+          });
+          return;
+        }
+      }
+    } else {
+      // Create new InstanceRec.
+      irec = new InstanceRec;
+      instance_table_[cp->instance.instance_key].reset(irec);
+    }
+  }
+  if (exit_outside_locks) {
+    CallbackWithStatus(done, irec);
+    return;
+  }
+  // Initialize the new InstanceRec while holding out_mu.
+  {
+    mutex_lock il(irec->out_mu);
+    irec->known.resize(cp->group.group_size, false);
+    irec->status = InitInstanceSharedParams(gr, cp, irec);
+  }
+  // Prepare to invoke any waiters that accumlated during initialization.
+  std::vector<IRConsumer> init_waiters;
+  {
+    mutex_lock tl(instance_mu_);
+    {
+      mutex_lock l(irec->in_mu);
+      irec->is_init = true;
+      if (!irec->init_waiters.empty()) {
+        std::swap(init_waiters, irec->init_waiters);
+      }
+    }
+  }
+  CallbackWithStatus(done, irec);
+  for (auto& f : init_waiters) {
+    f(irec);
+  }
+}
+
+void CollectiveParamResolverLocal::CompleteParamsAsync(
+    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
+    const StatusCallback& done) {
+  VLOG(1) << "CompleteParams " << device << " for " << cp << ": "
+          << cp->ToString();
+  CompleteGroupLocal(
+      device, cp, [this, device, cp, done](const Status& s, GroupRec* gr) {
+        if (s.ok()) {
+          CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
+        } else {
+          done(s);
+        }
+      });
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceAsync(
+    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
+    CancellationManager* cancel_mgr, const StatusCallback& done) {
+  done(
+      errors::Internal("CompleteInstance is not implemented by "
+                       "CollectiveParamResolverLocal which is "
+                       "intended only for non-distributed deployment."));
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceLocal(
+    const string& device, GroupRec* gr, CollectiveParams* cp, bool is_source,
+    const StatusCallback& done) {
+  VLOG(1) << "CompleteInstanceLocal " << device
+          << " instance_key: " << cp->instance.instance_key << " gr " << gr;
+
+  // Populate the group portion of *cp from *gr.  Most of it should already
+  // match.
+  DCHECK_EQ(cp->group.group_key, gr->group.group_key);
+  DCHECK_EQ(cp->group.group_size, gr->group.group_size);
+  DCHECK_EQ(cp->group.device_type, gr->group.device_type);
+  cp->group = gr->group;
+
+  // Get the shared InstanceRec for this instance.
+  FindInstanceRec(gr, cp,
+                  [this, device, gr, cp, is_source, done](const Status& s,
+                                                          InstanceRec* ir) {
+                    if (s.ok()) {
+                      CompleteInstanceFromInitializedIRec(device, gr, cp, ir,
+                                                          is_source, done);
+                    } else {
+                      done(s);
+                    }
+                  });
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
+    const string& device, GroupRec* gr, CollectiveParams* cp, InstanceRec* ir,
+    bool is_source, const StatusCallback& done) {
+  // Populate the fields common across instance.
+  {
+    mutex_lock l(ir->out_mu);
+    // custom operator= does a deep copy.
+    cp->instance = ir->shared.instance;
+  }
+  // Populate the fields common across task, also default_rank.
+  SetDefaultRank(device, cp);
+  CompleteTaskIsLocal(task_name_, cp);
+  // If broadcast, may need to wait for source discovery.
+  if (cp->instance.type == BROADCAST_COLLECTIVE) {
+    CompleteInstanceSource(ir, cp, is_source,
+                           [this, ir, device, cp, done](InstanceRec* irec) {
+                             CHECK_EQ(ir, irec);
+                             Status s;
+                             int source_rank;
+                             {
+                               mutex_lock l(irec->out_mu);
+                               s = irec->status;
+                               source_rank = ir->source_rank;
+                             }
+                             if (s.ok()) {
+                               GenerateSubdivPerms(device, source_rank, cp);
+                             }
+                             done(s);
+                           });
+    return;
+  } else {
+    GenerateSubdivPerms(device, 0, cp);
+  }
+  done(Status::OK());
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
+                                                          CollectiveParams* cp,
+                                                          bool is_source,
+                                                          const IRConsumer& f) {
+  std::vector<IRConsumer> ready_waiters;
+  {
+    mutex_lock l(ir->out_mu);
+    CHECK_EQ(cp->group.group_size, ir->known.size());
+    CHECK_GE(cp->default_rank, 0);
+    if (!ir->known[cp->default_rank]) {
+      ir->known[cp->default_rank] = true;
+      ++ir->known_count;
+      if (is_source) {
+        if (ir->source_rank >= 0) {
+          ir->status = errors::Internal("Instance ", cp->instance.instance_key,
+                                        " already has source ", ir->source_rank,
+                                        ", recevied second claim from ",
+                                        cp->default_rank);
+        } else {
+          ir->source_rank = cp->default_rank;
+        }
+      }
+    }
+    if (ir->known_count < ir->shared.group.group_size) {
+      ir->known_waiters.push_back(f);
+      return;
+    }
+    CHECK_EQ(ir->known_count, ir->shared.group.group_size);
+    CHECK_GE(ir->source_rank, 0);
+    if (!ir->known_waiters.empty()) {
+      ready_waiters = std::move(ir->known_waiters);
+    }
+  }
+  f(ir);
+  for (auto& f : ready_waiters) {
+    f(ir);
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
new file mode 100644 (file)
index 0000000..ff3415b
--- /dev/null
@@ -0,0 +1,209 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+class CompleteGroupRequest;
+class CompleteGroupResponse;
+class CompleteInstanceRequest;
+class CompleteInstanceResponse;
+class DeviceMgr;
+
+// Implements ParamResolverInterface for a single-task context.
+// It also implements the functionality necessary to serve as the
+// group leader for param resolution in a multi-task context.
+class CollectiveParamResolverLocal : public ParamResolverInterface {
+ public:
+  CollectiveParamResolverLocal(const DeviceMgr* dev_mgr,
+                               DeviceResolverInterface* dev_resolver,
+                               const string& task_name);
+
+  ~CollectiveParamResolverLocal() override {}
+
+  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+                           CancellationManager* cancel_mgr,
+                           const StatusCallback& done) override;
+
+  void CompleteGroupAsync(const CompleteGroupRequest* request,
+                          CompleteGroupResponse* response,
+                          CancellationManager* cancel_mgr,
+                          const StatusCallback& done) override;
+
+  void CompleteInstanceAsync(const CompleteInstanceRequest* request,
+                             CompleteInstanceResponse* response,
+                             CancellationManager* cancel_mgr,
+                             const StatusCallback& done) override;
+
+ protected:
+  // Used to complete/verify CollGroup.
+  struct GroupRec {
+    CollGroupParams group;
+    mutex mu;
+    Status status GUARDED_BY(mu);
+    std::set<string> device_set GUARDED_BY(mu);
+    std::vector<string> device_list GUARDED_BY(mu);
+    std::set<string> task_set GUARDED_BY(mu);
+    std::vector<string> task_list GUARDED_BY(mu);
+    std::vector<StatusCallback> waiting GUARDED_BY(mu);
+  };
+
+  // Finds the GroupRec that corresponds to cp->group_key.
+  // Also populates cp->group from that group_rec.
+  // Will wait until GroupRec is fully populated or an error arises before
+  // calling done.  Callback GroupRec* arg is only valid if status is ok.
+  // Ownership of GroupRec stays with this object and does not pass to the
+  // callback.
+  typedef std::function<void(const Status& s, GroupRec* gr)> GroupRecCallback;
+  void CompleteGroupLocal(const string& device, CollectiveParams* cp,
+                          const GroupRecCallback& done)
+      LOCKS_EXCLUDED(group_mu_);
+
+  // Used to complete/verify CollInstance.
+  struct InstanceRec;
+  typedef std::function<void(InstanceRec*)> IRConsumer;
+  struct InstanceRec {
+    // This structure has two mutexes so that a possibly long
+    // initialization can be done without holding the instance_mu_
+    // table lock the whole time (which can cause an excessive number
+    // of threads to block on it), and because the compiler may not
+    // permit mutex locks to be taken in more than one order.
+    //
+    // out_mu guards access to most of the fields.
+    // in_mu guards access to a queue of comsumer callbacks wanting to
+    // read the fields guarded by out_mu.
+    //
+    // The in_mu should be locked only while holding instance_mu_; the
+    // out_mu should be locked only while not holding
+    // instance_mu_.
+    //
+    // When is_init is false (the initial value) any potential user
+    // other than the creator should queue a callback on init_waiters.
+    // As soon as the shared member of this structure is fully
+    // initialized is_init will be set true and those callbacks will
+    // be invoked.
+    //
+    // Once inserted in the table this structure will never be replaced
+    // so users can capture the pointer while holding instance_mu_,
+    // drop that lock, then take a lock on out_mu before
+    // reading/modifying its values.
+    mutex in_mu;
+    bool is_init GUARDED_BY(in_mu);
+    std::vector<IRConsumer> init_waiters GUARDED_BY(in_mu);
+
+    // Values to be shared by all instances, constant after initialization.
+    mutex out_mu;
+    CollectiveParams shared GUARDED_BY(out_mu);
+    // If an error occurs during initialization this structure stays in
+    // the table with a non-OK status.  Purging the table and restarting
+    // needs to be done at a higher level.
+    Status status GUARDED_BY(out_mu);
+
+    // These fields are used to count the instances that have called
+    // in and become known while resolving broadcast source identity.
+    int source_rank GUARDED_BY(out_mu);
+    int known_count GUARDED_BY(out_mu);
+    std::vector<bool> known GUARDED_BY(out_mu);
+    std::vector<IRConsumer> known_waiters GUARDED_BY(out_mu);
+
+    InstanceRec() : is_init(false), source_rank(-1), known_count(0) {}
+  };
+
+  // Find the InstanceRec with the same instance_key as cp.  If it doesn't
+  // already exist, create and initialize from gr and cp.
+  //
+  // Precondition: *gr must be a complete GroupRec, i.e. the value set
+  // by CompleteGroupLocal. *cp must be populated with all the fields
+  // required by InitInstanceSharedParams.  Ownership of InstanceRec stays
+  // with this object and does not pass to the callback.
+  typedef std::function<void(const Status& s, InstanceRec* ir)>
+      InstanceRecCallback;
+  void FindInstanceRec(GroupRec* gr, CollectiveParams* cp,
+                       const InstanceRecCallback& done)
+      LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
+
+  // Populate *ir with device membership from gr, then initialize to be specific
+  // to cp->instance_key, i.e. order the devices and tasks.
+  //
+  // Preconditions:
+  //  cp is populated with all DeviceLocalities
+  Status InitInstanceSharedParams(GroupRec* gr, const CollectiveParams* cp,
+                                  InstanceRec* ir)
+      EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu) LOCKS_EXCLUDED(gr->mu);
+
+  // Establishes the final order of ir->shared.instance.device_names and
+  // ir->shared.instance.task_names by considering localities of all devices.
+  void CompleteDefaultRanking(GroupRec* gr, const CollectiveParams* cp,
+                              InstanceRec* ir,
+                              const std::vector<DeviceLocality>& localities)
+      EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu);
+
+  // Finish populating *cp.
+  // Precondition: *gr has been fully populated by CompleteGroupLocal.
+  void CompleteInstanceLocal(const string& device, GroupRec* gr,
+                             CollectiveParams* cp, bool is_source,
+                             const StatusCallback& done)
+      LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
+
+  // Finish populating *cp from fully initialized *ir.
+  // Precondition: *gr and *ir are fully populated.
+  void CompleteInstanceFromInitializedIRec(const string& device, GroupRec* gr,
+                                           CollectiveParams* cp,
+                                           InstanceRec* ir, bool is_source,
+                                           const StatusCallback& done)
+      LOCKS_EXCLUDED(ir->out_mu);
+
+  // Complete source data for a broadcast instance.
+  // Precondition: *cp has complete group data and default_rank.
+  void CompleteInstanceSource(InstanceRec* ir, CollectiveParams* cp,
+                              bool is_source, const IRConsumer& f)
+      LOCKS_EXCLUDED(ir->out_mu);
+
+  // If cp.device_names contains only devices local to this process
+  // populates *localities, else returns an error.
+  Status GetLocalDeviceLocalities(const CollectiveParams& cp,
+                                  std::vector<DeviceLocality>* localities);
+
+  // Sets CollTaskParams.is_local and CollectiveParams.default_rank.
+  // Precondition: cp->device_names is fully populated and in final order.
+  void CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp);
+
+  // Sets cp->instance_default_rank according to location of device in
+  // current ordering of cp->instance.device_names.
+  void SetDefaultRank(const string& device, CollectiveParams* cp);
+
+  // Helper to grab status under lock, invoke callback out of lock.
+  void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
+      LOCKS_EXCLUDED(irec->out_mu);
+
+  const DeviceMgr* dev_mgr_;
+  DeviceResolverInterface* dev_resolver_;
+  string task_name_;
+  mutex group_mu_;
+  gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
+      GUARDED_BY(group_mu_);
+  mutex instance_mu_;
+  gtl::FlatMap<int32, std::unique_ptr<InstanceRec>> instance_table_
+      GUARDED_BY(instance_mu_);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
new file mode 100644 (file)
index 0000000..4e3c712
--- /dev/null
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class CollectiveParamResolverLocalTest : public ::testing::Test {
+ protected:
+  CollectiveParamResolverLocalTest() {
+    ConfigProto cp;
+    SessionOptions options;
+    string task_name = "/job:localhost/replica:0/task:0";
+    auto* device_count = options.config.mutable_device_count();
+    device_count->insert({"CPU", NUM_DEVS});
+    TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+    device_mgr_.reset(new DeviceMgr(devices_));
+    drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
+    prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
+                                                task_name));
+  }
+
+  std::vector<Device*> devices_;
+  std::unique_ptr<DeviceMgr> device_mgr_;
+  std::unique_ptr<DeviceResolverLocal> drl_;
+  std::unique_ptr<CollectiveParamResolverLocal> prl_;
+};
+
+TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
+  CollectiveParams cps[NUM_DEVS];
+  Status statuses[NUM_DEVS];
+  Notification note[NUM_DEVS];
+  for (int i = 0; i < NUM_DEVS; ++i) {
+    CollectiveParams* cp = &cps[i];
+    cp->group.group_key = 1;
+    cp->group.group_size = 3;
+    cp->group.device_type = DeviceType("CPU");
+    cp->group.num_tasks = 1;
+    cp->instance.instance_key = 7;
+    cp->instance.type = REDUCTION_COLLECTIVE;
+    cp->instance.data_type = DataType(DT_FLOAT);
+    cp->instance.shape = TensorShape({5});
+    cp->instance.device_names.push_back(
+        strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i));
+    cp->instance.impl_details.subdiv_offsets.push_back(0);
+    cp->is_source = false;
+    Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
+      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
+                                nullptr /*CancellationManager*/,
+                                [this, &statuses, &note, i](const Status& s) {
+                                  statuses[i] = s;
+                                  note[i].Notify();
+                                });
+    });
+  }
+  for (int i = 0; i < NUM_DEVS; ++i) {
+    note[i].WaitForNotification();
+  }
+  for (int i = 0; i < NUM_DEVS; ++i) {
+    TF_ASSERT_OK(statuses[i]);
+    ASSERT_EQ(cps[i].instance.device_names.size(), 3);
+    for (int j = 0; j < NUM_DEVS; ++j) {
+      EXPECT_EQ(
+          strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
+          cps[i].instance.device_names[j]);
+      EXPECT_TRUE(cps[i].task.is_local[j]);
+    }
+    EXPECT_EQ(cps[i].subdiv_rank[0], i);
+    EXPECT_EQ(cps[i].subdiv_source_rank.size(), 0);
+    EXPECT_FALSE(cps[i].is_source);
+    EXPECT_EQ(cps[i].default_rank, i);
+  }
+}
+
+TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
+  CollectiveParams cps[NUM_DEVS];
+  Status statuses[NUM_DEVS];
+  Notification note[NUM_DEVS];
+  for (int i = 0; i < NUM_DEVS; ++i) {
+    CollectiveParams* cp = &cps[i];
+    cp->group.group_key = 1;
+    cp->group.group_size = 3;
+    cp->group.device_type = DeviceType("CPU");
+    cp->group.num_tasks = 1;
+    cp->instance.instance_key = 3;
+    cp->instance.type = BROADCAST_COLLECTIVE;
+    cp->instance.data_type = DataType(DT_FLOAT);
+    cp->instance.shape = TensorShape({5});
+    cp->instance.device_names.push_back(
+        strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i));
+    cp->instance.impl_details.subdiv_offsets.push_back(0);
+    cp->is_source = (i == 1);
+    Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
+      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
+                                nullptr /*CancellationManager*/,
+                                [this, &statuses, &note, i](const Status& s) {
+                                  statuses[i] = s;
+                                  note[i].Notify();
+                                });
+    });
+  }
+  for (int i = 0; i < NUM_DEVS; ++i) {
+    note[i].WaitForNotification();
+  }
+  for (int i = 0; i < NUM_DEVS; ++i) {
+    TF_ASSERT_OK(statuses[i]);
+    ASSERT_EQ(cps[i].instance.device_names.size(), 3);
+    for (int j = 0; j < NUM_DEVS; ++j) {
+      EXPECT_EQ(
+          strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
+          cps[i].instance.device_names[j]);
+      EXPECT_TRUE(cps[i].task.is_local[j]);
+    }
+    ASSERT_GT(cps[i].subdiv_rank.size(), 0);
+    EXPECT_EQ(cps[i].subdiv_rank[0], i);
+    ASSERT_GT(cps[i].subdiv_source_rank.size(), 0);
+    EXPECT_EQ(cps[i].subdiv_source_rank[0], 1);
+    EXPECT_EQ(cps[i].is_source, (i == 1));
+    EXPECT_EQ(cps[i].default_rank, i);
+  }
+}
+
+// TEST_F(CollectiveParamResolverLocalTest,
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc
new file mode 100644 (file)
index 0000000..ad9b32c
--- /dev/null
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+
+#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+
+namespace tensorflow {
+
+void CollectiveRemoteAccessLocal::StartAbort(const Status& s) {
+  buf_rendezvous_.StartAbort(s);
+}
+
+void CollectiveRemoteAccessLocal::RecvFromPeer(
+    const string& peer_device, const string& peer_task, bool peer_is_local,
+    const string& key, Device* to_device, DeviceContext* to_device_ctx,
+    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+    const DeviceLocality& client_locality, const StatusCallback& done) {
+  VLOG(1) << "RecvFromPeer " << this << " from " << peer_device << " key "
+          << key;
+  if (!peer_is_local) {
+    done(
+        errors::Internal("CollectiveRemoteAccessLocal::RecvFromPeer "
+                         "called with peer_is_local=false"));
+    return;
+  }
+  buf_rendezvous_.ConsumeBuf(
+      key, [this, to_tensor, to_device_ctx, to_device, to_alloc_attr, done](
+               const Status& s, BufRendezvous::Hook* hook) {
+        if (!s.ok()) {
+          done(s);
+          delete hook;
+        } else {
+          int64 recv_bytes = to_tensor->TotalBytes();
+          CHECK_EQ(recv_bytes, hook->prod_value->TotalBytes());
+          MemCpyAsync(hook->prod_ctx,    // src DeviceContext
+                      to_device_ctx,     // dst DeviceContext
+                      hook->prod_dev,    // src Device
+                      to_device,         // dst Device
+                      hook->prod_attr,   // src AllocatorAttributes
+                      to_alloc_attr,     // dst AllocatorAttributes
+                      hook->prod_value,  // src Tensor*
+                      to_tensor,         // dst Tensor*
+                      [hook, done](const Status& s) {
+                        done(s);
+                        hook->prod_cb(s);
+                        delete hook;
+                      });
+        }
+      });
+}
+
+void CollectiveRemoteAccessLocal::PostToPeer(
+    const string& peer_device, const string& peer_task, const string& key,
+    Device* from_device, DeviceContext* from_device_ctx,
+    const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor,
+    const DeviceLocality& client_locality, const StatusCallback& done) {
+  VLOG(1) << "PostToPeer " << this << " key " << key
+          << " step_id_=" << step_id_;
+  buf_rendezvous_.ProvideBuf(key, from_device, from_device_ctx, from_tensor,
+                             from_alloc_attr, done);
+}
+
+/*static*/
+void CollectiveRemoteAccessLocal::MemCpyAsync(
+    DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
+    Device* dst_dev, const AllocatorAttributes& src_attr,
+    const AllocatorAttributes& dst_attr, const Tensor* src, Tensor* dst,
+    const StatusCallback& done) {
+  // We want a real copy to happen, i.e. the bytes inside of src should be
+  // transferred to the buffer backing dst.  If src and dst are on different
+  // devices then CopyTensor::ViaDMA will do just that.  But if they're both
+  // the same CPU, then it will actually just reset dst to point to src.
+  // Since this routine is used for copying between devices and within a
+  // device, we need to detect and bypass the wrong-semantics case.
+  const DeviceType src_device_type(
+      src_attr.on_host() ? DEVICE_CPU : src_dev->attributes().device_type());
+  const DeviceType dst_device_type(
+      dst_attr.on_host() ? DEVICE_CPU : dst_dev->attributes().device_type());
+  const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
+  const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);
+  if (non_cpu_src) CHECK(src_dev_ctx);
+  if (non_cpu_dst) CHECK(dst_dev_ctx);
+  if (non_cpu_src || non_cpu_dst) {
+    CopyTensor::ViaDMA("",  // edge name (non-existent)
+                       src_dev_ctx, dst_dev_ctx, src_dev, dst_dev, src_attr,
+                       dst_attr, src, dst, done);
+  } else {
+    int64 bytes = src->TotalBytes();
+    DCHECK_EQ(dst->TotalBytes(), bytes);
+    memcpy(DMAHelper::base(dst), DMAHelper::base(src), bytes);
+    done(Status::OK());
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
new file mode 100644 (file)
index 0000000..d25dd5f
--- /dev/null
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/rendezvous.h"
+
+namespace tensorflow {
+
+// Basic implementation of PerStepCollectiveRemoteAccess.
+class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
+ public:
+  CollectiveRemoteAccessLocal(const DeviceMgr* dev_mgr,
+                              DeviceResolverInterface* dev_resolver,
+                              int64 step_id)
+      : dev_mgr_(dev_mgr),
+        dev_resolver_(dev_resolver),
+        buf_rendezvous_(step_id),
+        step_id_(step_id) {}
+
+  virtual ~CollectiveRemoteAccessLocal() {}
+
+  void StartAbort(const Status& s);
+
+  void RecvFromPeer(const string& peer_device, const string& peer_task,
+                    bool peer_is_local, const string& key, Device* to_device,
+                    DeviceContext* to_device_ctx,
+                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+                    const DeviceLocality& client_locality,
+                    const StatusCallback& done) override;
+
+  void PostToPeer(const string& peer_device, const string& peer_task,
+                  const string& key, Device* from_device,
+                  DeviceContext* from_device_ctx,
+                  const AllocatorAttributes& from_alloc_attr,
+                  const Tensor* from_tensor,
+                  const DeviceLocality& client_locality,
+                  const StatusCallback& done) override;
+
+  void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params,
+                                std::vector<DeviceLocality>* localities,
+                                const StatusCallback& done) override {
+    dev_resolver_->GetDeviceLocalitiesAsync(ci_params, localities, done);
+  }
+
+  void GetLocalityAsync(const string& device, const string& task,
+                        DeviceLocality* locality,
+                        const StatusCallback& done) override {
+    dev_resolver_->GetLocalityAsync(device, task, locality, done);
+  }
+
+  void ClearTask(const string& task) override {
+    dev_resolver_->ClearTask(task);
+  }
+
+  // Copy utility that always copies bytes from src to dst even if
+  // they are on the same device, unlike CopyTensor::ViaDMA which will
+  // just change the dst buffer pointer in that case.
+  static void MemCpyAsync(DeviceContext* src_dev_ctx,
+                          DeviceContext* dst_dev_ctx, Device* src_dev,
+                          Device* dst_dev, const AllocatorAttributes& src_attr,
+                          const AllocatorAttributes& dst_attr,
+                          const Tensor* src, Tensor* dst,
+                          const StatusCallback& done);
+
+ protected:
+  const DeviceMgr* dev_mgr_;               // not owned
+  DeviceResolverInterface* dev_resolver_;  // not owned
+  BufRendezvous buf_rendezvous_;
+  int64 step_id_;
+};
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
new file mode 100644 (file)
index 0000000..dcd4272
--- /dev/null
@@ -0,0 +1,148 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+static const int kStepId = 123;
+
+class CollectiveRemoteAccessLocalTest : public ::testing::Test {
+ protected:
+  const string kTaskName = "/job:localhost/replica:0/task:0";
+
+  CollectiveRemoteAccessLocalTest() {
+    ConfigProto cp;
+    SessionOptions options;
+    auto* device_count = options.config.mutable_device_count();
+    device_count->insert({"CPU", NUM_DEVS});
+    TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices_));
+    device_mgr_.reset(new DeviceMgr(devices_));
+    drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
+    prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
+                                                kTaskName));
+    rma_.reset(new CollectiveRemoteAccessLocal(device_mgr_.get(), drl_.get(),
+                                               kStepId));
+  }
+
+  std::vector<Device*> devices_;
+  std::unique_ptr<DeviceMgr> device_mgr_;
+  std::unique_ptr<DeviceResolverLocal> drl_;
+  std::unique_ptr<CollectiveParamResolverLocal> prl_;
+  std::unique_ptr<CollectiveRemoteAccessLocal> rma_;
+};
+
+TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) {
+  Device* cpu0 = nullptr;
+  AllocatorAttributes attr;
+  DeviceLocality dev_locality;
+  TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0));
+  Tensor sink_tensor(DT_FLOAT, TensorShape({8}));
+  Notification recv_note;
+  Status recv_status;
+  rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/,
+                     "key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/,
+                     attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
+                     [this, &recv_note, &recv_status](const Status& s) {
+                       recv_status = s;
+                       recv_note.Notify();
+                     });
+  Tensor source_tensor(DT_FLOAT, TensorShape({8}));
+  for (int i = 0; i < 8; ++i) {
+    source_tensor.flat<float>()(i) = i / 2;
+  }
+  // Tensors have distinct storage.
+  EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+  Notification send_note;
+  Status send_status;
+  rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0",
+                   cpu0 /*from_device*/, nullptr /*from_device_ctx*/,
+                   attr /*to_alloc_attr*/, &source_tensor, dev_locality,
+                   [this, &send_note, &send_status](const Status& s) {
+                     send_status = s;
+                     send_note.Notify();
+                   });
+  recv_note.WaitForNotification();
+  send_note.WaitForNotification();
+  TF_EXPECT_OK(recv_status);
+  TF_EXPECT_OK(send_status);
+  // Sink tensor gets the source tensor values.
+  for (int i = 0; i < 8; ++i) {
+    EXPECT_EQ(sink_tensor.flat<float>()(i), i / 2);
+  }
+  // And still has distinct storage.
+  EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+}
+
+TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
+  Device* cpu2 = nullptr;
+  AllocatorAttributes attr;
+  DeviceLocality dev_locality;
+  TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:2", &cpu2));
+  Tensor sink_tensor(DT_FLOAT, TensorShape({8}));
+  Notification recv_note;
+  Status recv_status;
+  rma_->RecvFromPeer(kTaskName + "/device:CPU:1", kTaskName, true /*is_local*/,
+                     "key_0", cpu2 /*to_device*/, nullptr /*to_device_ctx*/,
+                     attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
+                     [this, &recv_note, &recv_status](const Status& s) {
+                       recv_status = s;
+                       recv_note.Notify();
+                     });
+  Tensor source_tensor(DT_FLOAT, TensorShape({8}));
+  for (int i = 0; i < 8; ++i) {
+    source_tensor.flat<float>()(i) = i / 2;
+  }
+  // Tensors have distinct storage.
+  EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+  Device* cpu1 = nullptr;
+  TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:1", &cpu1));
+  Notification send_note;
+  Status send_status;
+  rma_->PostToPeer(kTaskName + "/device:CPU:2", kTaskName, "key_0",
+                   cpu1 /*from_device*/, nullptr /*from_device_ctx*/,
+                   attr /*to_alloc_attr*/, &source_tensor, dev_locality,
+                   [this, &send_note, &send_status](const Status& s) {
+                     send_status = s;
+                     send_note.Notify();
+                   });
+  recv_note.WaitForNotification();
+  send_note.WaitForNotification();
+  TF_EXPECT_OK(recv_status);
+  TF_EXPECT_OK(send_status);
+  // Sink tensor gets the source tensor values.
+  for (int i = 0; i < 8; ++i) {
+    EXPECT_EQ(sink_tensor.flat<float>()(i), i / 2);
+  }
+  // And still has distinct storage.
+  EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.cc b/tensorflow/core/common_runtime/device_resolver_local.cc
new file mode 100644 (file)
index 0000000..17ef4a2
--- /dev/null
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+
+namespace tensorflow {
+
+void DeviceResolverLocal::GetDeviceLocalitiesAsync(
+    const CollInstanceParams& ci_params,
+    std::vector<DeviceLocality>* localities, const StatusCallback& done) {
+  localities->clear();
+  for (const string& device_name : ci_params.device_names) {
+    Device* dev;
+    Status s = dev_mgr_->LookupDevice(device_name, &dev);
+    if (!s.ok()) {
+      done(s);
+      return;
+    }
+    localities->push_back(dev->attributes().locality());
+  }
+  done(Status::OK());
+}
+
+void DeviceResolverLocal::GetLocalityAsync(const string& device,
+                                           const string& task,
+                                           DeviceLocality* locality,
+                                           const StatusCallback& done) {
+  Device* dev;
+  Status s = dev_mgr_->LookupDevice(device, &dev);
+  if (s.ok()) {
+    *locality = dev->attributes().locality();
+  }
+  done(s);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h
new file mode 100644 (file)
index 0000000..098eccd
--- /dev/null
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+
+namespace tensorflow {
+class DeviceMgr;
+
+// Implements DeviceResolverInterface in a single-task context.
+class DeviceResolverLocal : public DeviceResolverInterface {
+ public:
+  DeviceResolverLocal(const DeviceMgr* dev_mgr) : dev_mgr_(dev_mgr) {}
+
+  virtual ~DeviceResolverLocal() {}
+
+  void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params,
+                                std::vector<DeviceLocality>* localities,
+                                const StatusCallback& done) override;
+
+  void GetLocalityAsync(const string& device, const string& task,
+                        DeviceLocality* locality,
+                        const StatusCallback& done) override;
+
+  void ClearTask(const string& task) override {}
+
+ protected:
+  const DeviceMgr* dev_mgr_;
+};
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/device_resolver_local_test.cc b/tensorflow/core/common_runtime/device_resolver_local_test.cc
new file mode 100644 (file)
index 0000000..f5a6471
--- /dev/null
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class DeviceResolverLocalTest : public ::testing::Test {
+ protected:
+  DeviceResolverLocalTest() {
+    ConfigProto cp;
+    SessionOptions options;
+    string task_name = "/job:localhost/replica:0/task:0";
+    auto* device_count = options.config.mutable_device_count();
+    device_count->insert({"CPU", NUM_DEVS});
+    TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+    device_mgr_.reset(new DeviceMgr(devices_));
+    drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
+  }
+
+  std::vector<Device*> devices_;
+  std::unique_ptr<DeviceMgr> device_mgr_;
+  std::unique_ptr<DeviceResolverLocal> drl_;
+};
+
+TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesKnown) {
+  CollectiveParams cp;
+  std::vector<DeviceLocality> localities;
+  cp.instance.device_names.push_back(
+      "/job:localhost/replica:0/task:0/device:CPU:1");
+  cp.instance.device_names.push_back(
+      "/job:localhost/replica:0/task:0/device:CPU:2");
+  Notification note;
+  Status status;
+  drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
+                                 [this, &note, &status](const Status& s) {
+                                   status = s;
+                                   note.Notify();
+                                 });
+  note.WaitForNotification();
+  TF_EXPECT_OK(status);
+  EXPECT_EQ(2, localities.size());
+}
+
+TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesUnknown) {
+  CollectiveParams cp;
+  std::vector<DeviceLocality> localities;
+  // In some builds there may be 1 GPU, but there should never be 9.
+  cp.instance.device_names.push_back(
+      "/job:localhost/replica:0/task:0/device:GPU:9");
+  Notification note;
+  Status status;
+  drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
+                                 [this, &note, &status](const Status& s) {
+                                   status = s;
+                                   note.Notify();
+                                 });
+  note.WaitForNotification();
+  EXPECT_FALSE(status.ok());
+  EXPECT_EQ(0, localities.size());
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc
new file mode 100644 (file)
index 0000000..a26f2c2
--- /dev/null
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/collective.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+string CollGroupParams::ToString() const {
+  return strings::StrCat("CollGroupParams {group_key=", group_key,
+                         " group_size=", group_size,
+                         " device_type=", device_type.type_string(),
+                         " num_tasks=", num_tasks, "}");
+}
+
+CollInstanceParams& CollInstanceParams::operator=(
+    const CollInstanceParams& other) {
+  if (this != &other) {
+    instance_key = other.instance_key;
+    type = other.type;
+    data_type = other.data_type;
+    shape = other.shape;
+    device_names.clear();
+    device_names.assign(other.device_names.begin(), other.device_names.end());
+    task_names.assign(other.task_names.begin(), other.task_names.end());
+    impl_details.subdiv_offsets.assign(
+        other.impl_details.subdiv_offsets.begin(),
+        other.impl_details.subdiv_offsets.end());
+    impl_details.subdiv_permutations.clear();
+    for (auto p : other.impl_details.subdiv_permutations) {
+      impl_details.subdiv_permutations.push_back(
+          std::vector<int>(p.begin(), p.end()));
+    }
+    impl_details.subdiv_source_rank.assign(
+        other.impl_details.subdiv_source_rank.begin(),
+        other.impl_details.subdiv_source_rank.end());
+  }
+  return *this;
+}
+
+string CollInstanceParams::ToString() const {
+  string v = strings::StrCat("CollInstanceParams { instance_key=", instance_key,
+                             " type=", type, " data_type=", data_type,
+                             " shape=", shape.DebugString(), " devices {");
+  for (const auto& d : device_names) {
+    strings::StrAppend(&v, d, ",");
+  }
+  strings::StrAppend(&v, "} task_names={");
+  for (const auto& n : task_names) {
+    strings::StrAppend(&v, n, ", ");
+  }
+  strings::StrAppend(&v, "}, subdiv_offsets={");
+  for (const auto& d : impl_details.subdiv_offsets) {
+    strings::StrAppend(&v, d, ",");
+  }
+  strings::StrAppend(&v, "}, subdiv_perms={");
+  for (const auto& p : impl_details.subdiv_permutations) {
+    strings::StrAppend(&v, "{");
+    for (const auto& i : p) {
+      strings::StrAppend(&v, i, ",");
+    }
+    strings::StrAppend(&v, "}");  // one subdiv
+  }
+  strings::StrAppend(&v, "}");  // all subdivs
+  return v;
+}
+
+string CollTaskParams::ToString() const {
+  string v = strings::StrCat("CollTaskParams {is_local={");
+  for (const auto& b : is_local) {
+    strings::StrAppend(&v, static_cast<int>(b), ",");
+  }
+  strings::StrAppend(&v, "}}");
+  return v;
+}
+
+string CollectiveParams::ToString() const {
+  string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
+  strings::StrAppend(&v, " ", instance.ToString());
+  strings::StrAppend(&v, " ", task.ToString());
+  strings::StrAppend(&v, " default_rank=", default_rank,
+                     " is_source=", is_source, " subdiv_rank={");
+  for (const auto& r : subdiv_rank) {
+    strings::StrAppend(&v, r, ",");
+  }
+  if (!subdiv_source_rank.empty()) {
+    strings::StrAppend(&v, " subdiv_rank={");
+    for (const auto& r : subdiv_source_rank) {
+      strings::StrAppend(&v, r, ",");
+    }
+    strings::StrAppend(&v, "}");
+  }
+  strings::StrAppend(&v, "}}");
+  return v;
+}
+
+/*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
+    OpKernelContext* ctx) {
+  return ctx->params_;
+}
+
+/*static*/
+int64 CollectiveExecutor::kInvalidId = -1;
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
new file mode 100644 (file)
index 0000000..362d345
--- /dev/null
@@ -0,0 +1,308 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+#define TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+class BufRendezvous;
+class CancellationManager;
+class CompleteGroupRequest;
+class CompleteGroupResponse;
+class CompleteInstanceRequest;
+class CompleteInstanceResponse;
+class DeviceLocality;
+class GetStepSequenceRequest;
+class GetStepSequenceResponse;
+class Op;
+class Tensor;
+
+// Types of supported collective operations.
+enum CollectiveType {
+  REDUCTION_COLLECTIVE = 0,
+  BROADCAST_COLLECTIVE,
+  UNDEFINED_COLLECTIVE,
+};
+
+// Data common to all members of a device group.
+// All members share the same device set but its order is
+// particular to an instance so it is stored there.
+struct CollGroupParams {
+  int32 group_key;
+  int32 group_size;
+  DeviceType device_type;
+  int32 num_tasks;  // number of distinct tasks in group
+  string ToString() const;
+  CollGroupParams() : device_type(DEVICE_CPU) {}
+};
+
+// The best implementation of a collective op depends on many factors
+// including the number of devices involved, the topology of
+// interconnects between them and the sizes of inputs.  This structure
+// is used in generating and representing data movement choreography
+// for each specific algorithm, hence it does not have a single, fixed
+// interpretation.  On first execution the runtime will update this
+// structure with decisions that will guide all subsequent executions.
+struct CollImplDetails {
+  std::vector<std::vector<int>> subdiv_permutations;
+  std::vector<int> subdiv_offsets;
+  // broadcast only: rank of source in each subdiv
+  std::vector<int> subdiv_source_rank;
+};
+
+// Data common to all members of a collective instance.
+struct CollInstanceParams {
+  int32 instance_key;  // Identifies all participating graph nodes.
+  CollectiveType type;
+  DataType data_type;
+  TensorShape shape;
+  // Fully qualified name of device for each member, in default rank order.
+  std::vector<string> device_names;
+  // Task name prefix of corresponding device name.
+  std::vector<string> task_names;
+  CollImplDetails impl_details;
+  string ToString() const;
+  CollInstanceParams& operator=(const struct CollInstanceParams& other);
+};
+
+// Data common to all instance members in the same task.
+struct CollTaskParams {
+  // True for devices that are local to the process, i.e. no RPC needed.
+  std::vector<bool> is_local;
+  string ToString() const;
+};
+
+// Unique to a single CollectiveOp node.
+struct CollectiveParams {
+  CollGroupParams group;
+  CollInstanceParams instance;
+  CollTaskParams task;
+
+  string name;       // node name used only for log or error messages
+  int default_rank;  // index of this op within device_names
+  bool is_source;    // broadcast only
+  // Rank of this device in each subdivision permutation.
+  std::vector<int> subdiv_rank;
+  std::vector<int> subdiv_source_rank;
+  const Tensor* in_tensor;             // kernel input
+  Tensor* out_tensor;                  // kernel output
+  std::unique_ptr<OpKernel> merge_op;  // reduction only
+  std::unique_ptr<OpKernel> final_op;  // reduction only
+  OpKernelContext* op_context;
+  string ToString() const;
+};
+
+class CollectiveExecutor;
+
+// Interface that provides resolution of device localities.
+class DeviceResolverInterface {
+ public:
+  virtual ~DeviceResolverInterface() {}
+
+  // Collects DeviceLocality protobufs from all of the devices identified
+  // in 'col_params'.
+  virtual void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
+                                        std::vector<DeviceLocality>* localities,
+                                        const StatusCallback& done) = 0;
+
+  // Populate *locality with the DeviceLocality of the specified
+  // device.
+  virtual void GetLocalityAsync(const string& device, const string& task,
+                                DeviceLocality* locality,
+                                const StatusCallback& done) = 0;
+
+  // Clear the cache of device data belonging
+  // to the specified task.
+  virtual void ClearTask(const string& task) = 0;
+};
+
+// Interface that provides resolution of shared CollectiveParams fields.
+class ParamResolverInterface {
+ public:
+  virtual ~ParamResolverInterface() {}
+
+  // Called by each collective op at first execution in order to fill out
+  // the CollectiveParams structure with data gathered from the full
+  // (maybe distributed) collection of peer nodes.
+  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+                                   CancellationManager* cancel_mgr,
+                                   const StatusCallback& done) = 0;
+
+  // Used within a distributed implementation to discover/verify
+  // data shared across a device group.
+  virtual void CompleteGroupAsync(const CompleteGroupRequest* request,
+                                  CompleteGroupResponse* response,
+                                  CancellationManager* cancel_mgr,
+                                  const StatusCallback& done) = 0;
+
+  // Used within a distributed implementation to discover/verify data
+  // shared across an instance group.
+  virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
+                                     CompleteInstanceResponse* response,
+                                     CancellationManager* cancel_mgr,
+                                     const StatusCallback& done) = 0;
+};
+
+// Graphs which utilize Collective Ops in a common instance must
+// execute with identical step_ids even if they are disjoint graphs
+// run by otherwise independent tasks.  This interface supplies
+// coordinated step_ids to use in such cases.
+class StepSequenceInterface {
+ public:
+  virtual ~StepSequenceInterface() {}
+
+  // Used with a distributed implementation to coordinate step_id
+  // sequences across tasks.
+  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+                                    GetStepSequenceResponse* response,
+                                    const StatusCallback& done) = 0;
+
+  // Refresh the local per-graph_key step_id sequence from collective
+  // group leader, if applicable.
+  virtual void RefreshStepIdSequenceAsync(int64 graph_key,
+                                          const StatusCallback& done) = 0;
+
+  // Returns the the step_id that should be used for initiating a new execution
+  // on the specified graph. May return the same step_id multiple times if
+  // RetireStepId or RefreshStepIdReservation is not called.
+  virtual int64 NextStepId(int64 graph_key) = 0;
+
+  // Reports that execution of the given step has completed successfully.
+  // Should be called immediately after a step completes with OK status,
+  // prior to calling NextStepId().  If the step fails, don't call.
+  virtual void RetireStepId(int64 graph_key, int64 step_id) = 0;
+};
+
+// Interface that provides access to per-step CollectiveExecutor
+// instances and various distributed resolution capabilities.
+class CollectiveExecutorMgrInterface : public StepSequenceInterface {
+ public:
+  virtual ~CollectiveExecutorMgrInterface() {}
+
+  // Returns the step-specific CollectiveExecutor, creating if one does not
+  // already exist.  The caller assumes ownership of one Ref on the object.
+  virtual CollectiveExecutor* FindOrCreate(int64 step_id) = 0;
+
+  // If there is a CollectiveExecutor for step_id, remove it from the
+  // table.
+  virtual void Cleanup(int64 step_id) = 0;
+
+  virtual ParamResolverInterface* GetParamResolver() const = 0;
+
+  virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
+};
+
+// Interface that a Collective Op implementation uses to exchange data
+// with peers.  Note that data exchange is currently limited to types
+// for which DMAHelper::CanUseDMA() returns true, i.e.  dense numeric
+// types.
+class PeerAccessInterface {
+ public:
+  virtual ~PeerAccessInterface() {}
+
+  virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
+                            bool peer_is_local, const string& key,
+                            Device* to_device, DeviceContext* to_device_ctx,
+                            const AllocatorAttributes& to_alloc_attr,
+                            Tensor* to_tensor,
+                            const DeviceLocality& client_locality,
+                            const StatusCallback& done) = 0;
+
+  virtual void PostToPeer(const string& peer_device, const string& peer_task,
+                          const string& key, Device* from_device,
+                          DeviceContext* from_device_ctx,
+                          const AllocatorAttributes& from_alloc_attr,
+                          const Tensor* from_tensor,
+                          const DeviceLocality& client_locality,
+                          const StatusCallback& done) = 0;
+};
+
+class PerStepCollectiveRemoteAccess;
+
+// A step-specific object that can execute a collective operation completely
+// described by a CollectiveParams object.
+class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted {
+ public:
+  virtual void StartAbort(const Status& s) {}
+
+  virtual void ExecuteAsync(OpKernelContext* ctx,
+                            const CollectiveParams& col_params,
+                            const string& exec_key, StatusCallback done) {
+    done(errors::Internal(
+        "A collective Op has been called in a context in which "
+        "a CollectiveExecutor has not been provided."));
+  }
+
+  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+                                   CancellationManager* cancel_mgr,
+                                   StatusCallback done) {
+    cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
+  }
+
+  virtual PerStepCollectiveRemoteAccess* remote_access() { return nullptr; }
+
+  // Used to designate an invalid group or instance key.
+  static int64 kInvalidId;
+
+  // Lexically scoped handle for Ref.
+  class Handle {
+   public:
+    explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
+      if (!inherit_ref) ce->Ref();
+    }
+    ~Handle() { ce_->Unref(); }
+    CollectiveExecutor* get() const { return ce_; }
+
+   private:
+    CollectiveExecutor* ce_;
+  };
+
+ protected:
+  explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
+      : cem_(cem) {}
+
+  // For use only by derived classes
+  static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
+  CollectiveExecutorMgrInterface* cem_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
+};
+
+// Interface of a helper object that provices a CollectiveExecutor with
+// all of the remote access it needs.
+class CollectiveRemoteAccess : public PeerAccessInterface,
+                               public DeviceResolverInterface {
+ public:
+  virtual ~CollectiveRemoteAccess() {}
+};
+
+// A per-step version of CollectiveRemoteAccess that cleans up outstanding
+// communications in case step execution is abandoned.
+class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
+ public:
+  virtual ~PerStepCollectiveRemoteAccess() {}
+  virtual void StartAbort(const Status& s) = 0;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
index 5ccd45e..2d97160 100644 (file)
@@ -1101,6 +1101,7 @@ class OpKernelContext {
   void NotifyUseOfPersistentTensor(const Tensor& tensor);
 
   Status status_;
+  friend class CollectiveExecutor;  // for access to params_
   Params* params_;    // not owned
   mutable mutex mu_;  // mutable so const accessors can acquire the lock
   gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);