Collective Ops Part 4
Author:      A. Unique TensorFlower <gardener@tensorflow.org>
Author date: Thu, 19 Apr 2018 20:09:07 +0000 (13:09 -0700)
Committer:   TensorFlower Gardener <gardener@tensorflow.org>
Commit date: Thu, 19 Apr 2018 20:11:50 +0000 (13:11 -0700)
Add Broadcaster.
A few minor adjustments to CollectiveParams and RMA.

This change is part of a series of changes introducing infrastructure
for collective ops and initial implementations of reduction and broadcast.

PiperOrigin-RevId: 193562391

tensorflow/core/BUILD
tensorflow/core/common_runtime/base_collective_executor.cc
tensorflow/core/common_runtime/base_collective_executor.h
tensorflow/core/common_runtime/broadcaster.cc [new file with mode: 0644]
tensorflow/core/common_runtime/broadcaster.h [new file with mode: 0644]
tensorflow/core/common_runtime/broadcaster_test.cc [new file with mode: 0644]
tensorflow/core/common_runtime/collective_param_resolver_local.cc
tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
tensorflow/core/common_runtime/collective_rma_local.h
tensorflow/core/framework/collective.cc
tensorflow/core/framework/collective.h

index 54e7ab3..c15e7de 100644 (file)
@@ -2256,6 +2256,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/allocator_retry.h",
     "common_runtime/base_collective_executor.h",
     "common_runtime/bfc_allocator.h",
+    "common_runtime/broadcaster.h",
     "common_runtime/buf_rendezvous.h",
     "common_runtime/build_graph_options.h",
     "common_runtime/collective_executor_mgr.h",
@@ -2303,6 +2304,7 @@ tf_cuda_library(
         "common_runtime/allocator_retry.cc",
         "common_runtime/base_collective_executor.cc",
         "common_runtime/bfc_allocator.cc",
+        "common_runtime/broadcaster.cc",
         "common_runtime/buf_rendezvous.cc",
         "common_runtime/build_graph_options.cc",
         "common_runtime/collective_executor_mgr.cc",
@@ -3140,6 +3142,34 @@ tf_cc_tests_gpu(
     ],
 )
 
+tf_cc_tests_gpu(
+    name = "broadcaster_test",
+    size = "small",
+    srcs = [
+        "common_runtime/broadcaster_test.cc",
+    ],
+    linkstatic = tf_kernel_tests_linkstatic(),
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":all_kernels",
+        ":core",
+        ":core_cpu",
+        ":core_cpu_internal",
+        ":direct_session_internal",
+        ":framework",
+        ":framework_internal",
+        ":gpu_runtime",
+        ":lib",
+        ":lib_internal",
+        ":ops",
+        ":protos_all_cc",
+        ":protos_test_cc",
+        ":test",
+        ":test_main",
+        ":testlib",
+    ],
+)
+
 tf_cc_test_mkl(
     name = "mkl_runtime_tests",
     size = "small",
index f6332fa..637b43c 100644 (file)
@@ -14,14 +14,13 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/common_runtime/base_collective_executor.h"
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/common_runtime/broadcaster.h"
 #include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/ring_reducer.h"
 #include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/lib/strings/str_util.h"
 
 #define VALUE_IN_DEBUG_STRING false
 
@@ -194,37 +193,68 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                           const CollectiveParams& col_params,
                                           const string& exec_key,
                                           StatusCallback done) {
-  const Tensor* input = &ctx->input(0);
+  // On any individual collective Op failure we need to abort the
+  // BufRendezvous so that other Ops in the instance don't hang
+  // waiting for transmissions that will never happen.  Do so after a
+  // delay so that the original error status is more likely to
+  // propagate up, and peers are unlikely to re-create the purged
+  // BufRendezvous by late-arriving requests.
+  StatusCallback done_safe = [this, done](const Status& s) {
+    if (!s.ok()) {
+      Ref();  // Ensure this lasts until the closure executes.
+      SchedNonBlockingClosureAfter(1000000, [this, s] {
+        remote_access_->buf_rendezvous()->StartAbort(s);
+        Unref();
+      });
+    }
+    done(s);
+  };
+
   Tensor* output = ctx->mutable_output(0);
   string error;
   switch (col_params.instance.type) {
     case REDUCTION_COLLECTIVE: {
       // TODO(tucker): support other reduction algorithms,
       // e.g. tree-reduce, hybrid tree/ring, delegate-to-NCCL, etc.
+      const Tensor* input = &ctx->input(0);
       RingReducer* reducer =
           CreateReducer(ctx, CtxParams(ctx), col_params, exec_key, step_id_,
                         input, output, &error);
       if (!reducer) {
-        done(errors::Internal(error));
+        done_safe(errors::Internal(error));
         return;
       }
       // Run in an I/O thread, so as not to starve the executor threads.
       // TODO(tucker): Instead of forking every per-device Collective
       // Op off into its own thread, consider queuing them on a
       // fixed-size thread-pool dedicated to running CollectiveOps.
-      SchedClosure([reducer, done]() {
-        reducer->Run([reducer, done](const Status& s) {
-          done(s);
+      SchedClosure([reducer, done_safe]() {
+        reducer->Run([reducer, done_safe](const Status& s) {
+          done_safe(s);
           delete reducer;
         });
       });
     } break;
-    case BROADCAST_COLLECTIVE:
-      done(errors::Internal("Collective Broadcast unimplemented"));
-      break;
+
+    case BROADCAST_COLLECTIVE: {
+      Broadcaster* broadcaster = CreateBroadcaster(
+          ctx, CtxParams(ctx), col_params, exec_key, step_id_, output, &error);
+      if (!broadcaster) {
+        done_safe(errors::Internal(error));
+        return;
+      }
+      // Run in an I/O thread, so as not to starve the executor threads.
+      SchedClosure([broadcaster, done_safe]() {
+        broadcaster->Run([broadcaster, done_safe](const Status& s) {
+          done_safe(s);
+          delete broadcaster;
+        });
+      });
+    } break;
+
     default:
-      done(errors::Internal("Unimplemented CollectiveType ",
-                            col_params.instance.type));
+      done_safe(errors::Internal("Unimplemented CollectiveType ",
+                                 col_params.instance.type));
   }
 }
 
@@ -254,4 +284,31 @@ RingReducer* BaseCollectiveExecutor::CreateReducer(
   }
 }
 
+Broadcaster* BaseCollectiveExecutor::CreateBroadcaster(
+    OpKernelContext* ctx, OpKernelContext::Params* params,
+    const CollectiveParams& col_params, const string& exec_key, int64 step_id,
+    Tensor* output, string* error) {
+  switch (col_params.instance.data_type) {
+    case DT_INT32:
+      if (col_params.group.device_type == DEVICE_GPU) {
+        *error =
+            "Collective Broadcast does not support datatype DT_INT32 on "
+            "DEVICE_GPU";
+        return nullptr;
+      }
+      TF_FALLTHROUGH_INTENDED;
+    case DT_FLOAT:
+    case DT_DOUBLE:
+    case DT_INT64: {
+      return new Broadcaster(this, dev_mgr_, ctx, params, col_params, exec_key,
+                             step_id, output);
+    } break;
+    default:
+      *error =
+          strings::StrCat("Collective Broadcast does not support datatype ",
+                          DataTypeString(col_params.instance.data_type));
+      return nullptr;
+  }
+}
+
 }  // namespace tensorflow
index 58eaf31..462d6b7 100644 (file)
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/framework/device_attributes.pb.h"
 
 namespace tensorflow {
+class Broadcaster;
 class DeviceMgr;
 class RingReducer;
 
@@ -138,6 +139,12 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
                              const string& exec_key, int64 step_id,
                              const Tensor* input, Tensor* output,
                              string* error);
+
+  Broadcaster* CreateBroadcaster(OpKernelContext* ctx,
+                                 OpKernelContext::Params* params,
+                                 const CollectiveParams& col_params,
+                                 const string& exec_key, int64 step_id,
+                                 Tensor* output, string* error);
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/broadcaster.cc b/tensorflow/core/common_runtime/broadcaster.cc
new file mode 100644 (file)
index 0000000..5e8af86
--- /dev/null
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/broadcaster.h"
+
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/platform/env.h"
+
+// Set true for greater intelligibility of debug mode log messages.
+#define READABLE_KEYS false
+
+namespace tensorflow {
+
+namespace {
+// Key to be used for BufRendezvous by Broadcaster.
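+// With READABLE_KEYS false the key is simply
+// "<exec_key>:<src_rank>:<dst_rank>", e.g. "17:0:0:2:0" for a send from
+// rank 2 to rank 0 under exec_key "17:0:0".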
+string BroadcastBufKey(const string& exec_key, int src_rank, int dst_rank) {
+  if (READABLE_KEYS) {
+    return strings::StrCat("broadcast(", exec_key, "):src(", src_rank, "):dst(",
+                           dst_rank, ")");
+  } else {
+    // TODO(tucker): Try a denser format, e.g. a 64 or 128 bit hash.
+    return strings::StrCat(exec_key, ":", src_rank, ":", dst_rank);
+  }
+}
+}  // namespace
+
+Broadcaster::Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
+                         OpKernelContext* ctx, OpKernelContext::Params* params,
+                         const CollectiveParams& col_params,
+                         const string& exec_key, int64 step_id, Tensor* output)
+    : col_exec_(col_exec),
+      dev_mgr_(dev_mgr),
+      ctx_(ctx),
+      col_params_(col_params),
+      exec_key_(exec_key),
+      rank_(col_params.subdiv_rank[0]),
+      is_source_(col_params.is_source),
+      output_(output),
+      done_(nullptr),
+      device_(nullptr) {}
+
+void Broadcaster::Run(StatusCallback done) {
+  // The optimal data transfer choreography is going to be very platform
+  // dependent.  That will be addressed by later improvements here or by
+  // platform-specific overrides of collective broadcast.  The initial
+  // version is simply a binary tree that completely ignores DeviceLocality.
+  done_ = std::move(done);
+
+  // Get the device for which we're executing and look up its locality.
+  status_ = dev_mgr_->LookupDevice(
+      col_params_.instance.device_names[col_params_.default_rank], &device_);
+  if (!status_.ok()) {
+    done_(status_);
+    return;
+  }
+  CHECK(device_);
+  device_locality_ = device_->attributes().locality();
+
+  RunTree();
+}
+
+// Binary tree parent/child relations are trivial to calculate, i.e.
+// device at rank r is the parent of 2r+1 and 2r+2.  The one exception
+// is if the source is not rank 0.  We treat that case as though the
+// source is appended to the front of the rank ordering as well as
+// continuing to occupy its current position.  Hence we calculate as
+// though each device's rank is actually r+1, then subtract 1 again to
+// get the descendant ranks.  If the source is not rank 0 then its
+// descendants include both {0,1} and the descendants of its current
+// position.  Where a non-0-rank source is a descendant of another
+// device, no send to it is necessary.
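+//
+// For example (cf. the 4-device cases in broadcaster_test.cc), with
+// group_size 4 and source rank 2: rank 2 sends to {0, 1}; rank 0 receives
+// from 2 and sends to {3}; ranks 1 and 3 receive from 2 and 0 respectively
+// and send nothing further.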
+
+/* static*/
+int Broadcaster::TreeRecvFrom(const CollectiveParams& cp) {
+  DCHECK_EQ(1, cp.subdiv_rank.size());
+  if (cp.is_source) return -1;
+  int source_rank = cp.instance.impl_details.subdiv_source_rank[0];
+  int my_rank = cp.subdiv_rank[0];
+  if (source_rank == 0) {
+    return (my_rank - 1) / 2;
+  } else {
+    int predecessor_rank = (my_rank / 2) - 1;
+    return (predecessor_rank < 0) ? source_rank : predecessor_rank;
+  }
+}
+
+/* static */
+void Broadcaster::TreeSendTo(const CollectiveParams& cp,
+                             std::vector<int>* targets) {
+  DCHECK_EQ(1, cp.subdiv_rank.size());
+  targets->clear();
+  int my_rank = cp.subdiv_rank[0];
+  DCHECK_EQ(1, cp.instance.impl_details.subdiv_source_rank.size());
+  int source_rank = cp.instance.impl_details.subdiv_source_rank[0];
+  int successor_rank = 0;
+  if (source_rank == 0) {
+    successor_rank = (2 * my_rank) + 1;
+  } else {
+    successor_rank = (2 * (my_rank + 1));
+  }
+  DCHECK_NE(successor_rank, my_rank);
+  if (cp.is_source && source_rank != 0) {
+    // The source sends to ranks 0 and 1 in addition to its positional
+    // descendants.
+    if (cp.group.group_size > 1) {
+      targets->push_back(0);
+    }
+    if (cp.group.group_size > 2 && source_rank != 1) {
+      targets->push_back(1);
+    }
+  }
+  for (int i = 0; i < 2; ++i) {
+    if (successor_rank < cp.group.group_size && successor_rank != source_rank) {
+      targets->push_back(successor_rank);
+    }
+    ++successor_rank;
+  }
+}
+
+// Execute a tree broadcast, i.e. each non-source device receives from
+// one other and sends to up to two others.
+void Broadcaster::RunTree() {
+  mutex mu;
+  int pending_count = 0;  // GUARDED_BY(mu)
+  condition_variable all_done;
+  std::vector<int> send_to_ranks;
+  TreeSendTo(col_params_, &send_to_ranks);
+
+  if (!is_source_) {
+    // Begin by receiving the value.
+    int recv_from_rank = TreeRecvFrom(col_params_);
+    Notification note;
+    DispatchRecv(recv_from_rank, output_,
+                 [this, recv_from_rank, &mu, &note](const Status& s) {
+                   mutex_lock l(mu);
+                   status_.Update(s);
+                   note.Notify();
+                 });
+    note.WaitForNotification();
+  }
+
+  // Then forward value to all descendent devices.
+  if (status_.ok()) {
+    for (int i = 0; i < send_to_ranks.size(); ++i) {
+      int target_rank = send_to_ranks[i];
+      {
+        mutex_lock l(mu);
+        ++pending_count;
+      }
+      DispatchSend(
+          target_rank, output_,
+          [this, target_rank, &mu, &pending_count, &all_done](const Status& s) {
+            status_.Update(s);
+            {
+              mutex_lock l(mu);
+              --pending_count;
+              if (pending_count == 0) {
+                all_done.notify_all();
+              }
+            }
+          });
+    }
+  }
+
+  if (status_.ok() && is_source_) {
+    // Meanwhile, copy input to output if we weren't lucky enough to
+    // be able to reuse input as output.
+    const Tensor* input = &ctx_->input(0);
+    if (input != output_ &&
+        (DMAHelper::base(input) != DMAHelper::base(output_))) {
+      {
+        mutex_lock l(mu);
+        ++pending_count;
+      }
+      DeviceContext* op_dev_ctx = ctx_->op_device_context();
+      CollectiveRemoteAccessLocal::MemCpyAsync(
+          op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
+          ctx_->output_alloc_attr(0), input, output_,
+          [this, &mu, &pending_count, &all_done](const Status& s) {
+            status_.Update(s);
+            {
+              mutex_lock l(mu);
+              --pending_count;
+              if (0 == pending_count) {
+                all_done.notify_all();
+              }
+            }
+          });
+    }
+  }
+
+  // Then wait for all pending actions to complete.
+  {
+    mutex_lock l(mu);
+    if (pending_count > 0) {
+      all_done.wait(l);
+    }
+  }
+
+  VLOG(2) << "return status " << status_;
+  done_(status_);
+}
+
+void Broadcaster::DispatchSend(int dst_rank, const Tensor* src_tensor,
+                               const StatusCallback& done) {
+  string send_buf_key = BroadcastBufKey(exec_key_, rank_, dst_rank);
+  VLOG(1) << "DispatchSend " << send_buf_key << " from_device "
+          << device_->name();
+  int dst_idx =
+      col_params_.instance.impl_details.subdiv_permutations[0][dst_rank];
+  col_exec_->PostToPeer(col_params_.instance.device_names[dst_idx],
+                        col_params_.instance.task_names[dst_idx], send_buf_key,
+                        device_, ctx_->op_device_context(),
+                        ctx_->output_alloc_attr(0), src_tensor,
+                        device_locality_, done);
+}
+
+void Broadcaster::DispatchRecv(int src_rank, Tensor* dst_tensor,
+                               const StatusCallback& done) {
+  string recv_buf_key = BroadcastBufKey(exec_key_, src_rank, rank_);
+  int src_idx =
+      col_params_.instance.impl_details.subdiv_permutations[0][src_rank];
+  VLOG(1) << "DispatchRecv " << recv_buf_key << " from_device "
+          << col_params_.instance.device_names[src_idx];
+  int dst_idx = col_params_.instance.impl_details.subdiv_permutations[0][rank_];
+  CHECK_EQ(col_params_.instance.device_names[dst_idx], device_->name());
+  col_exec_->RecvFromPeer(col_params_.instance.device_names[src_idx],
+                          col_params_.instance.task_names[src_idx],
+                          col_params_.task.is_local[src_idx], recv_buf_key,
+                          device_, ctx_->op_device_context(),
+                          ctx_->output_alloc_attr(0), dst_tensor,
+                          device_locality_, done);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/broadcaster.h b/tensorflow/core/common_runtime/broadcaster.h
new file mode 100644 (file)
index 0000000..bdf68f1
--- /dev/null
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
+
+#include <vector>
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+
+namespace tensorflow {
+
+// Tree-algorithm implementation of collective broadcast.
+class Broadcaster {
+ public:
+  Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
+              OpKernelContext* ctx, OpKernelContext::Params* params,
+              const CollectiveParams& col_params, const string& exec_key,
+              int64 step_id, Tensor* output);
+
+  void Run(StatusCallback done);
+
+  // Returns the rank of the device from which this device should receive
+  // its value, or -1 if no value should be received.
+  static int TreeRecvFrom(const CollectiveParams& cp);
+
+  // Populates targets with the ranks of the devices to which this device
+  // should forward the value.
+  static void TreeSendTo(const CollectiveParams& cp, std::vector<int>* targets);
+
+ private:
+  void DispatchSend(int dst_rank, const Tensor* src_tensor,
+                    const StatusCallback& done);
+  void DispatchRecv(int src_rank, Tensor* dst_tensor,
+                    const StatusCallback& done);
+  void RunTree();
+
+  Status status_;
+  CollectiveExecutor* col_exec_;  // Not owned
+  const DeviceMgr* dev_mgr_;      // Not owned
+  OpKernelContext* ctx_;          // Not owned
+  const CollectiveParams& col_params_;
+  const string exec_key_;
+  const int rank_;
+  const bool is_source_;
+  Tensor* output_;  // Not owned
+  std::unique_ptr<CollectiveAdapter> ca_;
+  StatusCallback done_;
+  Device* device_;  // The device for which this instance labors
+  DeviceLocality device_locality_;
+};
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
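For reference, a minimal sketch of how the static topology helpers above are exercised, mirroring the DEF_TL_TEST setup in broadcaster_test.cc below (the ExampleTreeLinks wrapper is illustrative only and not part of this change):

    #include <vector>
    #include "tensorflow/core/common_runtime/broadcaster.h"

    namespace tensorflow {
    // 4 devices, source rank 2, this device at rank 0: per the table in
    // broadcaster_test.cc, rank 0 receives from rank 2 and forwards to rank 3.
    void ExampleTreeLinks() {
      CollectiveParams cp;
      cp.group.group_size = 4;
      cp.instance.impl_details.subdiv_source_rank = {2};
      cp.subdiv_rank = {0};
      cp.is_source = false;
      int recv_from = Broadcaster::TreeRecvFrom(cp);  // -> 2
      std::vector<int> send_to;
      Broadcaster::TreeSendTo(cp, &send_to);          // -> {3}
    }
    }  // namespace tensorflow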
diff --git a/tensorflow/core/common_runtime/broadcaster_test.cc b/tensorflow/core/common_runtime/broadcaster_test.cc
new file mode 100644 (file)
index 0000000..89d3914
--- /dev/null
@@ -0,0 +1,741 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/broadcaster.h"
+
+#include <algorithm>
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/threadpool_device.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace {
+
+static int64 kStepId = 123;
+static int32 kNumSubdivs = 1;  // Subdiv not yet meaningful for broadcast
+
+// The test harness won't allow a mixture of fixture and non-fixture
+// tests in one file, so this is a trival fixture for tests that don't
+// need the heavy-weight BroadcasterTest fixture.
+class TrivialTest : public ::testing::Test {
+ protected:
+  TrivialTest() {}
+};
+
+// Tests of static TreeSendTo() and TreeRecvFrom() functions.
+// D = number of devices
+// S = source rank
+// R = tested rank
+// RF = receive-from rank
+// ST = send_to rank vector
+#define DEF_TL_TEST(D, S, R, RF, ST)                               \
+  TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \
+    CollectiveParams cp;                                           \
+    cp.group.group_size = D;                                       \
+    cp.instance.impl_details.subdiv_source_rank = {S};             \
+    cp.subdiv_rank = {R};                                          \
+    cp.is_source = (S == R);                                       \
+    EXPECT_EQ(RF, Broadcaster::TreeRecvFrom(cp));                  \
+    std::vector<int> expected = ST;                                \
+    std::vector<int> send_to;                                      \
+    Broadcaster::TreeSendTo(cp, &send_to);                         \
+    ASSERT_EQ(expected.size(), send_to.size());                    \
+    for (int i = 0; i < expected.size(); ++i) {                    \
+      EXPECT_EQ(expected[i], send_to[i]);                          \
+    }                                                              \
+  }
+
+#define V(...) std::vector<int>({__VA_ARGS__})
+
+//          D  S  R  RF  ST
+// 2 device cases
+DEF_TL_TEST(2, 0, 0, -1, V(1))
+DEF_TL_TEST(2, 1, 0, 1, V())
+DEF_TL_TEST(2, 0, 1, 0, V())
+DEF_TL_TEST(2, 1, 1, -1, V(0))
+// 3 device cases
+DEF_TL_TEST(3, 0, 0, -1, V(1, 2))
+DEF_TL_TEST(3, 0, 1, 0, V())
+DEF_TL_TEST(3, 0, 2, 0, V())
+DEF_TL_TEST(3, 1, 0, 1, V(2))
+DEF_TL_TEST(3, 1, 1, -1, V(0))
+DEF_TL_TEST(3, 1, 2, 0, V())
+DEF_TL_TEST(3, 2, 0, 2, V())
+DEF_TL_TEST(3, 2, 1, 2, V())
+DEF_TL_TEST(3, 2, 2, -1, V(0, 1))
+// 4 device cases
+DEF_TL_TEST(4, 0, 0, -1, V(1, 2))
+DEF_TL_TEST(4, 0, 1, 0, V(3))
+DEF_TL_TEST(4, 0, 2, 0, V())
+DEF_TL_TEST(4, 0, 3, 1, V())
+DEF_TL_TEST(4, 1, 0, 1, V(2, 3))
+DEF_TL_TEST(4, 1, 1, -1, V(0))
+DEF_TL_TEST(4, 1, 2, 0, V())
+DEF_TL_TEST(4, 1, 3, 0, V())
+DEF_TL_TEST(4, 2, 0, 2, V(3))
+DEF_TL_TEST(4, 2, 1, 2, V())
+DEF_TL_TEST(4, 2, 2, -1, V(0, 1))
+DEF_TL_TEST(4, 2, 3, 0, V())
+DEF_TL_TEST(4, 3, 0, 3, V(2))
+DEF_TL_TEST(4, 3, 1, 3, V())
+DEF_TL_TEST(4, 3, 2, 0, V())
+DEF_TL_TEST(4, 3, 3, -1, V(0, 1))
+// 8 device cases
+//          D  S  R  RF  ST
+DEF_TL_TEST(8, 0, 0, -1, V(1, 2))
+DEF_TL_TEST(8, 0, 1, 0, V(3, 4))
+DEF_TL_TEST(8, 0, 2, 0, V(5, 6))
+DEF_TL_TEST(8, 0, 3, 1, V(7))
+DEF_TL_TEST(8, 0, 4, 1, V())
+DEF_TL_TEST(8, 0, 5, 2, V())
+DEF_TL_TEST(8, 0, 6, 2, V())
+DEF_TL_TEST(8, 0, 7, 3, V())
+DEF_TL_TEST(8, 7, 0, 7, V(2, 3))
+DEF_TL_TEST(8, 7, 1, 7, V(4, 5))
+DEF_TL_TEST(8, 7, 2, 0, V(6))
+DEF_TL_TEST(8, 7, 3, 0, V())
+DEF_TL_TEST(8, 7, 4, 1, V())
+DEF_TL_TEST(8, 7, 5, 1, V())
+DEF_TL_TEST(8, 7, 6, 2, V())
+DEF_TL_TEST(8, 7, 7, -1, V(0, 1))
+#undef DEF_TL_TEST
+#undef V
+
+// Wraps CollectiveRemoteAccessLocal with the ability to return an
+// error status to the N'th action.
+// TODO(tucker): factor out of this file and ring_reducer_test.cc
+// into a single common source.
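+// fail_after == N (N > 0) means the N'th RecvFromPeer/PostToPeer call made
+// through this object returns a "Deliberate failure" error.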
+class FailTestRMA : public CollectiveRemoteAccessLocal {
+ public:
+  FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
+              int64 step_id, int fail_after)
+      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+        fail_after_(fail_after) {}
+
+  bool MaybeFail(const StatusCallback& done) {
+    bool fail_now = false;
+    {
+      mutex_lock l(mu_);
+      if (fail_after_ > 0) {
+        fail_now = (--fail_after_ == 0);
+      }
+    }
+    if (fail_now) {
+      auto error = errors::Internal("Deliberate failure");
+      LOG(INFO) << "triggering failure " << error;
+      SchedNonBlockingClosureAfter(
+          1000, [this, error] { buf_rendezvous()->StartAbort(error); });
+      done(error);
+      return true;
+    }
+    return false;
+  }
+
+  void RecvFromPeer(const string& peer_device, const string& peer_task,
+                    bool peer_is_local, const string& key, Device* to_device,
+                    DeviceContext* to_device_ctx,
+                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+                    const DeviceLocality& client_locality,
+                    const StatusCallback& done) override {
+    if (MaybeFail(done)) return;
+    CollectiveRemoteAccessLocal::RecvFromPeer(
+        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
+        to_alloc_attr, to_tensor, client_locality, done);
+  }
+
+  void PostToPeer(const string& peer_device, const string& peer_task,
+                  const string& key, Device* from_device,
+                  DeviceContext* from_device_ctx,
+                  const AllocatorAttributes& from_alloc_attr,
+                  const Tensor* from_tensor,
+                  const DeviceLocality& client_locality,
+                  const StatusCallback& done) override {
+    if (MaybeFail(done)) return;
+    CollectiveRemoteAccessLocal::PostToPeer(
+        peer_device, peer_task, key, from_device, from_device_ctx,
+        from_alloc_attr, from_tensor, client_locality, done);
+  }
+
+  mutex mu_;
+  int fail_after_ GUARDED_BY(mu_);
+};
+
+class BroadcasterTest : public ::testing::Test {
+ protected:
+  BroadcasterTest() : device_type_(DEVICE_CPU) {}
+
+  ~BroadcasterTest() override {
+    stop_ = true;
+    for (auto i : instances_) {
+      delete i;
+    }
+    if (col_exec_) col_exec_->Unref();
+  }
+
+  void SetUp() override {
+#if GOOGLE_CUDA
+    auto device_factory = DeviceFactory::GetFactory("GPU");
+    CHECK(device_factory);
+    SessionOptions options;
+    Status s = device_factory->CreateDevices(
+        options, "/job:worker/replica:0/task:0", &gpu_devices_);
+    CHECK(s.ok());
+#endif
+  }
+
+  void Init(int num_workers, int num_devices, DataType dtype,
+            const DeviceType& device_type, int fail_after) {
+    device_type_ = device_type;
+    std::vector<Device*> local_devices;
+    SessionOptions sess_opts;
+    sess_opts.env = Env::Default();
+    Bytes mem_limit(4 << 20);
+    DeviceLocality dev_locality;
+    for (int wi = 0; wi < num_workers; ++wi) {
+      for (int di = 0; di < num_devices; ++di) {
+        if (device_type == DEVICE_CPU) {
+          string dev_name = strings::StrCat("/job:worker/replica:0/task:", wi,
+                                            "/device:CPU:", di);
+          local_devices.push_back(new ThreadPoolDevice(
+              sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
+        } else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
+          int dev_idx = (wi * num_devices) + di;
+          if (dev_idx >= static_cast<int>(gpu_devices_.size())) {
+            LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
+                         "than one ring node.";
+          } else {
+            local_devices.push_back(gpu_devices_[dev_idx]);
+          }
+        } else {
+          LOG(FATAL) << "Unsupported device_type " << device_type;
+        }
+      }
+    }
+    if (!dev_mgr_ || device_type == DEVICE_CPU) {
+      dev_mgr_.reset(new DeviceMgr(local_devices));
+    }
+    dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
+    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId,
+                           fail_after);
+    col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
+                                           dev_mgr_.get());
+    col_params_.name = "test_collective";
+    col_params_.instance.data_type = dtype;
+    static const int kGroupKey = 5;
+    col_params_.group.group_key = kGroupKey;
+    static const int kInstanceKey = 17;
+    col_params_.instance.instance_key = kInstanceKey;
+    col_params_.group.device_type = device_type;
+    col_params_.group.group_size = num_workers * num_devices;
+    col_params_.instance.impl_details.subdiv_offsets.clear();
+    col_params_.instance.type = BROADCAST_COLLECTIVE;
+    col_params_.instance.impl_details.subdiv_permutations.resize(kNumSubdivs);
+    col_params_.subdiv_rank.resize(kNumSubdivs);
+    int subdiv_stride = num_devices / kNumSubdivs;
+    for (int sdi = 0; sdi < kNumSubdivs; ++sdi) {
+      col_params_.instance.impl_details.subdiv_offsets.push_back(sdi *
+                                                                 subdiv_stride);
+      col_params_.subdiv_rank[sdi] = sdi * subdiv_stride;
+    }
+
+    // Set up a local device ring order that's not just 0,1,2...
+    std::vector<int> local_ring_order;
+    for (int di = 0; di < num_devices; ++di) {
+      local_ring_order.push_back(di);
+    }
+    for (int di = 0; di < num_devices; ++di) {
+      bool is_odd = ((di % 2) == 1);
+      int other = (di + (is_odd ? 7 : 3)) % num_devices;
+      if (di == other) continue;
+      iter_swap(local_ring_order.begin() + di,
+                local_ring_order.begin() + other);
+    }
+    broadcast_dev_id_ = local_ring_order[0];
+    string lro_buf;
+    for (auto d : local_ring_order) strings::StrAppend(&lro_buf, d, ", ");
+    VLOG(1) << "local_ring_order " << lro_buf;
+
+    // Set up all of the fake device contexts.
+    for (int wi = 0; wi < num_workers; ++wi) {
+      for (int di = 0; di < num_devices; ++di) {
+        string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
+        string dev_name = strings::StrCat(task_name, "/device:CPU:", di);
+        if (device_type == DEVICE_GPU) {
+          dev_name = strings::StrCat(task_name, "/device:GPU:0");
+        }
+        col_params_.instance.device_names.push_back(dev_name);
+        col_params_.instance.task_names.push_back(task_name);
+        // Normally each device would set is_local from its own perspective,
+        // but this test runs in a single process so is_local is always true.
+        col_params_.task.is_local.push_back(true);
+        for (int sdi = 0; sdi < kNumSubdivs; ++sdi) {
+          int rotated_di =
+              (di + col_params_.instance.impl_details.subdiv_offsets[sdi]) %
+              num_devices;
+          col_params_.instance.impl_details.subdiv_permutations[sdi].push_back(
+              wi * num_devices + local_ring_order[rotated_di]);
+        }
+      }
+    }
+    for (int wi = 0; wi < num_workers; ++wi) {
+      for (int di = 0; di < num_devices; ++di) {
+        int rank = wi * num_devices + di;
+        instances_.push_back(new DeviceInstance(
+            rank, col_params_.instance.device_names[rank], device_type_, this));
+      }
+    }
+  }
+
+  typedef std::function<void(Tensor*)> InitFunc;
+
+  void Broadcast() {
+    std::atomic<int> done(0);
+    for (auto di : instances_) {
+      SchedClosure([di, &done] {
+        di->DoBroadcast();
+        ++done;
+      });
+    }
+    while (done < instances_.size()) {
+      if (stop_) break;
+      Env::Default()->SleepForMicroseconds(1000);
+    }
+  }
+
+  std::unique_ptr<OpKernel> GetKernel(const NodeDef& node,
+                                      const DeviceType& device_type,
+                                      DeviceBase* device) {
+    Status status;
+    std::unique_ptr<OpKernel> k = CreateOpKernel(
+        device_type, device, device->GetAllocator(AllocatorAttributes()), node,
+        TF_GRAPH_DEF_VERSION, &status);
+    if (!status.ok()) {
+      LOG(FATAL) << status;
+    }
+    return k;
+  }
+
+  std::unique_ptr<OpKernel> GetCollectiveBcastSend(
+      const CollectiveParams& params, Tensor* input,
+      const DeviceType& device_type, DeviceBase* device) {
+    mutex_lock l(mu_);
+    NodeDef node_def;
+    NodeDefBuilder builder(
+        strings::StrCat("collective_bcast_send_", bcast_send_counter_++),
+        "CollectiveBcastSend");
+    TF_CHECK_OK(builder.Attr("T", input->dtype())
+                    .Attr("group_size", params.group.group_size)
+                    .Attr("group_key", params.group.group_key)
+                    .Attr("instance_key", params.instance.instance_key)
+                    .Attr("shape", input->shape())
+                    .Input(FakeInput(params.instance.data_type))
+                    .Finalize(&node_def));
+    return GetKernel(node_def, device_type, device);
+  }
+
+  std::unique_ptr<OpKernel> GetCollectiveBcastRecv(
+      const CollectiveParams& params, const TensorShape& shape,
+      const DeviceType& device_type, DeviceBase* device) {
+    mutex_lock l(mu_);
+    NodeDef node_def;
+    NodeDefBuilder builder(
+        strings::StrCat("collective_bcast_recv_", bcast_recv_counter_++),
+        "CollectiveBcastRecv");
+    TF_CHECK_OK(builder.Attr("T", params.instance.data_type)
+                    .Attr("group_size", params.group.group_size)
+                    .Attr("group_key", params.group.group_key)
+                    .Attr("instance_key", params.instance.instance_key)
+                    .Attr("shape", shape)
+                    .Finalize(&node_def));
+    return GetKernel(node_def, device_type, device);
+  }
+
+  void BuildColParams() {}
+
+  template <typename T>
+  void RunTest(DataType dtype, const DeviceType& device_type, int num_workers,
+               int num_devices, int tensor_len, int fail_after) {
+    Init(num_workers, num_devices, dtype, device_type, fail_after);
+
+    // Initialize each instance tensor with distinct values.
+    for (int di = 0; di < instances_.size(); ++di) {
+      DeviceInstance* instance = instances_[di];
+      instance->InitTensor(
+          dtype, TensorShape({tensor_len}), [di, dtype](Tensor* t) {
+            for (size_t i = 0; i < t->NumElements(); ++i) {
+              // The cast is necessary to prevent clang-tidy from insisting
+              // that a faster non-open source function be substituted.
+              float value = pow(10, static_cast<double>(di)) * i;
+              t->flat<T>()(i) = value;
+            }
+          });
+    }
+
+    // Copy the expected value from the broadcast source tensor
+    std::vector<T> expected(tensor_len, 0.0);
+    const CollectiveParams& cp = instances_[0]->col_params_;
+    int broadcast_dev_id =
+        cp.instance.impl_details.subdiv_permutations
+            [0][cp.instance.impl_details.subdiv_source_rank[0]];
+    const Tensor* t = &instances_[broadcast_dev_id]->tensor_;
+    Tensor cpu_copy(dtype, TensorShape({tensor_len}));
+    if (device_type == DEVICE_GPU) {
+      Notification notification;
+      Device* dev = instances_[broadcast_dev_id]->device_;
+      auto* dev_info = dev->tensorflow_gpu_device_info();
+      CHECK(dev_info);
+      dev_info->default_context->CopyDeviceTensorToCPU(
+          t, "" /*tensor_name*/, dev, &cpu_copy,
+          [this, &notification](Status s) {
+            TF_CHECK_OK(s);
+            notification.Notify();
+          });
+      notification.WaitForNotification();
+      t = &cpu_copy;
+    }
+    for (size_t i = 0; i < t->NumElements(); ++i) {
+      expected[i] = t->flat<T>()(i);
+    }
+
+    Broadcast();
+
+    // At this point all of the ops have terminated.
+    for (int di = 0; di < instances_.size(); ++di) {
+      if (!instances_[di]->status_.ok()) {
+        ASSERT_GT(fail_after, 0);
+        ASSERT_EQ(instances_[di]->status_.error_message(),
+                  "Deliberate failure");
+        mutex_lock l(mu_);
+        ++failure_count_;
+        continue;
+      }
+      Tensor* inst = &instances_[di]->tensor_;
+      Tensor actual(dtype, TensorShape({tensor_len}));
+      if (device_type_ == DEVICE_CPU) {
+        CHECK(actual.CopyFrom(*inst, inst->shape()));
+      } else if (device_type_ == DEVICE_GPU) {
+        Notification notification;
+        Device* dev = instances_[di]->device_;
+        auto* dev_info = dev->tensorflow_gpu_device_info();
+        CHECK(dev_info);
+        dev_info->default_context->CopyDeviceTensorToCPU(
+            inst, "" /*tensor_name*/, dev, &actual,
+            [this, &notification](Status s) {
+              TF_CHECK_OK(s);
+              notification.Notify();
+            });
+        notification.WaitForNotification();
+      }
+      for (int i = 0; i < tensor_len; ++i) {
+        switch (dtype) {
+          case DT_FLOAT:
+            EXPECT_FLOAT_EQ(expected[i], actual.template flat<T>()(i))
+                << "Mismatch at device " << di << " index " << i;
+            break;
+          case DT_DOUBLE:
+            EXPECT_DOUBLE_EQ(expected[i], actual.template flat<T>()(i))
+                << "Mismatch at device " << di << " index " << i;
+            break;
+          case DT_INT32:
+          case DT_INT64:
+            EXPECT_EQ(expected[i], actual.template flat<T>()(i))
+                << "Mismatch at device " << di << " index " << i;
+            break;
+          default:
+            LOG(FATAL) << "unimplemented";
+        }
+      }
+    }
+
+    // Note that the order of operations during broadcast is
+    // non-deterministic and unlike the reduce case some Ops in the
+    // instance may succeed while others fail, even if a transmission
+    // failure occurs early in the operation chain.  So, when an abort
+    // is specified we need to verify that at least one Op fails with
+    // the expected status and any Op that succeeds yields the correct
+    // value.
+    if (fail_after > 0) {
+      mutex_lock l(mu_);
+      EXPECT_GT(failure_count_, 0);
+    }
+  }
+
+  class DeviceInstance {
+   public:
+    DeviceInstance(int rank, const string& dev_name,
+                   const DeviceType& device_type, BroadcasterTest* parent)
+        : parent_(parent),
+          dev_name_(dev_name),
+          device_type_(device_type),
+          rank_(rank) {
+      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_));
+      col_params_.name = parent_->col_params_.name;
+      col_params_.instance.data_type = parent_->col_params_.instance.data_type;
+      col_params_.group.group_key = parent_->col_params_.group.group_key;
+      col_params_.instance.instance_key =
+          parent_->col_params_.instance.instance_key;
+      col_params_.group.device_type = parent_->col_params_.group.device_type;
+      col_params_.group.group_size = parent_->col_params_.group.group_size;
+      col_params_.instance.device_names =
+          parent_->col_params_.instance.device_names;
+      col_params_.instance.task_names =
+          parent_->col_params_.instance.task_names;
+      col_params_.task.is_local = parent_->col_params_.task.is_local;
+      col_params_.instance.impl_details.subdiv_permutations =
+          parent_->col_params_.instance.impl_details.subdiv_permutations;
+      col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;
+
+      int group_size = col_params_.group.group_size;
+      CHECK_EQ(group_size, col_params_.instance.device_names.size());
+      // Default rank is order in device_names.
+      col_params_.default_rank = rank;
+      // perm_rank is order in subdiv[0]:
+      int perm_rank = -1;
+      for (int i = 0;
+           i < col_params_.instance.impl_details.subdiv_permutations[0].size();
+           ++i) {
+        if (rank ==
+            col_params_.instance.impl_details.subdiv_permutations[0][i]) {
+          perm_rank = i;
+          break;
+        }
+      }
+      CHECK_GE(perm_rank, 0);
+      col_params_.instance.impl_details.subdiv_source_rank.resize(1, 0);
+      col_params_.is_source =
+          (perm_rank ==
+           col_params_.instance.impl_details.subdiv_source_rank[0]);
+      // Set rank in all subdivs by finding that default_rank.
+      for (int sdi = 0; sdi < kNumSubdivs; ++sdi) {
+        for (int r = 0;
+             r <
+             col_params_.instance.impl_details.subdiv_permutations[sdi].size();
+             ++r) {
+          if (col_params_.default_rank ==
+              col_params_.instance.impl_details.subdiv_permutations[sdi][r]) {
+            col_params_.subdiv_rank[sdi] = r;
+            CHECK_EQ(0, sdi);
+            CHECK_EQ(perm_rank, col_params_.subdiv_rank[sdi]);
+            break;
+          }
+        }
+      }
+      CHECK_EQ(group_size, col_params_.task.is_local.size());
+      CHECK_EQ(group_size, col_params_.instance.task_names.size());
+    }
+
+    void InitTensor(DataType dtype, const TensorShape& shape,
+                    const InitFunc& f) {
+      tensor_ =
+          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
+      if (device_type_ == DEVICE_CPU) {
+        f(&tensor_);
+      } else if (device_type_ == DEVICE_GPU) {
+        Tensor cpu_tensor(dtype, shape);
+        f(&cpu_tensor);
+        Notification notification;
+        auto* dev_info = device_->tensorflow_gpu_device_info();
+        CHECK(dev_info);
+        dev_info->default_context->CopyCPUTensorToDevice(
+            &cpu_tensor, device_, &tensor_, [this, &notification](Status s) {
+              TF_CHECK_OK(s);
+              notification.Notify();
+            });
+        notification.WaitForNotification();
+      } else {
+        LOG(FATAL) << "Unsupported device_type " << device_type_;
+      }
+    }
+
+    void DoBroadcast() {
+      // Prepare an OpKernelContext.
+      OpKernelContext::Params op_params;
+      op_params.step_id = parent_->step_id_;
+      op_params.device = device_;
+      gtl::InlinedVector<TensorValue, 4> inputs;
+      inputs.push_back(TensorValue(&tensor_));
+      op_params.inputs = &inputs;
+      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
+          {AllocatorAttributes()});
+      op_params.input_alloc_attrs = &input_aa;
+      gtl::InlinedVector<DeviceContext*, 4> input_dc;
+      DeviceContext* dev_ctx = nullptr;
+      auto* dev_info = device_->tensorflow_gpu_device_info();
+      if (dev_info) {
+        dev_ctx = dev_info->default_context;
+        dev_ctx->Ref();
+      } else {
+        dev_ctx = new DeviceContext;
+      }
+      input_dc.push_back(dev_ctx);
+      op_params.input_device_contexts = &input_dc;
+      op_params.op_device_context = dev_ctx;
+      int forward_from[] = {0};
+      if (col_params_.is_source) {
+        op_params.forward_from_array = &forward_from[0];
+      }
+      AllocatorAttributes generic_alloc_attr;
+      op_params.output_attr_array = &generic_alloc_attr;
+      std::unique_ptr<OpKernel> op =
+          col_params_.is_source
+              ? parent_->GetCollectiveBcastSend(col_params_, &tensor_,
+                                                DEVICE_CPU, device_)
+              : parent_->GetCollectiveBcastRecv(col_params_, tensor_.shape(),
+                                                DEVICE_CPU, device_);
+      op_params.op_kernel = op.get();
+      OpKernelContext ctx(&op_params, 1);
+
+      Tensor* output_tensor_ptr = nullptr;
+      if (col_params_.is_source) {
+        TF_CHECK_OK(ctx.forward_input_or_allocate_output(
+            {0}, 0, tensor_.shape(), &output_tensor_ptr));
+      } else {
+        TF_CHECK_OK(
+            ctx.allocate_output(0, tensor_.shape(), &output_tensor_ptr));
+      }
+      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
+
+      // Prepare a Broadcaster instance.
+      string exec_key =
+          strings::StrCat(col_params_.instance.instance_key, ":0:0");
+      Broadcaster broadcaster(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx,
+                              &op_params, col_params_, exec_key, kStepId,
+                              output_tensor_ptr);
+
+      // Start execution in a threadpool then wait for completion.
+      Notification notification;
+      broadcaster.Run([this, &notification](Status s) {
+        status_ = s;
+        notification.Notify();
+      });
+      notification.WaitForNotification();
+      if (status_.ok()) {
+        CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+      }
+
+      dev_ctx->Unref();
+    }
+
+    BroadcasterTest* parent_;
+    string dev_name_;
+    DeviceType device_type_ = DEVICE_CPU;
+    int rank_;
+    Tensor tensor_;
+    Device* device_;
+    CollectiveParams col_params_;
+    std::unique_ptr<CollectiveAdapter> ca_;
+    std::unique_ptr<OpKernelContext> ctx_;
+    Status status_;
+  };  // class DeviceInstance
+
+  bool stop_ = false;
+  int64 step_id_ = kStepId;
+  int broadcast_dev_id_ = 0;
+  DeviceType device_type_;
+  TestCollectiveExecutorMgr col_exec_mgr_;
+  CollectiveExecutor* col_exec_ = nullptr;
+  CollectiveRemoteAccessLocal* rma_;
+  std::unique_ptr<DeviceResolverLocal> dev_resolver_;
+  std::vector<DeviceInstance*> instances_;
+  CollectiveParams col_params_;
+  std::vector<tensorflow::Device*> gpu_devices_;
+  std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
+  mutex mu_;
+  int bcast_recv_counter_ GUARDED_BY(mu_) = 0;
+  int bcast_send_counter_ GUARDED_BY(mu_) = 0;
+  int failure_count_ GUARDED_BY(mu_) = 0;
+};
+
+// Tests of full broadcast algorithm, with different device and
+// data types.
+// B = data element type
+// T = device type
+// W = number of workers
+// D = number of devices per worker
+// L = tensor length
+// A = abort after count
+#define DEF_TEST(B, T, W, D, L, A)                                 \
+  TEST_F(BroadcasterTest,                                          \
+         DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Len##L##_Abt##A) { \
+    DataType dtype = DT_##B;                                       \
+    switch (dtype) {                                               \
+      case DT_FLOAT: {                                             \
+        RunTest<float>(dtype, DEVICE_##T, W, D, L, A);             \
+      } break;                                                     \
+      case DT_DOUBLE: {                                            \
+        RunTest<double>(dtype, DEVICE_##T, W, D, L, A);            \
+      } break;                                                     \
+      case DT_INT32: {                                             \
+        RunTest<int32>(dtype, DEVICE_##T, W, D, L, A);             \
+      } break;                                                     \
+      case DT_INT64: {                                             \
+        RunTest<int64>(dtype, DEVICE_##T, W, D, L, A);             \
+      } break;                                                     \
+      default:                                                     \
+        LOG(FATAL) << "Unimplemented";                             \
+    }                                                              \
+  }
+
+#ifndef GOOGLE_CUDA
+//       B      T    W  D  L  A
+DEF_TEST(FLOAT, CPU, 1, 2, 1, 0)
+DEF_TEST(FLOAT, CPU, 1, 2, 1001, 0)
+DEF_TEST(FLOAT, CPU, 2, 1, 128, 0)
+DEF_TEST(FLOAT, CPU, 2, 4, 128, 0)
+DEF_TEST(FLOAT, CPU, 2, 8, 4095, 0)
+DEF_TEST(FLOAT, CPU, 4, 4, 1045991, 0)
+
+DEF_TEST(DOUBLE, CPU, 2, 4, 128, 0)
+DEF_TEST(INT32, CPU, 2, 4, 128, 0)
+DEF_TEST(INT64, CPU, 2, 4, 128, 0)
+
+// Failure cases
+DEF_TEST(FLOAT, CPU, 2, 4, 128, 1)
+DEF_TEST(FLOAT, CPU, 2, 4, 128, 5)
+#endif
+
+#ifdef GOOGLE_CUDA
+// Can only set W=1 for GPU tests.
+//       B      T    W  D  L  A
+DEF_TEST(FLOAT, GPU, 1, 2, 1, 0)
+DEF_TEST(FLOAT, GPU, 1, 2, 33, 0)
+DEF_TEST(FLOAT, GPU, 1, 3, 64, 0)
+DEF_TEST(FLOAT, GPU, 1, 8, 1001, 0)
+DEF_TEST(FLOAT, GPU, 1, 8, 4095, 0)
+DEF_TEST(FLOAT, GPU, 1, 8, 1045991, 0)
+
+DEF_TEST(DOUBLE, GPU, 1, 8, 1001, 0)
+DEF_TEST(INT64, GPU, 1, 8, 1001, 0)
+
+// Failure cases
+DEF_TEST(FLOAT, GPU, 1, 8, 128, 6)
+#endif
+
+}  // namespace
+}  // namespace tensorflow
index 393d3f8..bdddf92 100644 (file)
@@ -250,6 +250,38 @@ GlobalDeviceMap EstablishGlobalRank(
   return gdm;
 }
 
+// Count the devices associated with each task and set
+// cp->instance.same_num_devices_per_task.  Requires cp->instance.task_names
+// to be sorted.
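+//
+// For example, task_names {"/task:0", "/task:0", "/task:1", "/task:1"}
+// yields same_num_devices_per_task == true (two devices per task), while
+// {"/task:0", "/task:0", "/task:1"} leaves it false.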
+void SetDevPerTask(CollectiveParams* cp) {
+  cp->instance.same_num_devices_per_task = false;
+  if (cp->instance.task_names.empty()) return;
+  int dev_per_task = -1;
+  int count = 0;
+  const string* last_task_name = &cp->instance.task_names[0];
+  for (const string& task_name : cp->instance.task_names) {
+    if (task_name != *last_task_name) {
+      CHECK_GT(count, 0);
+      if (dev_per_task < 0) {
+        dev_per_task = count;
+      } else {
+        CHECK_GT(dev_per_task, 0);
+        if (count != dev_per_task) return;
+      }
+      count = 1;
+      last_task_name = &task_name;
+    } else {
+      ++count;
+    }
+  }
+  CHECK_GT(count, 0);
+  if ((dev_per_task > 0) && (count != dev_per_task)) {
+    return;
+  }
+  cp->instance.same_num_devices_per_task = true;
+  CHECK_EQ((cp->group.group_size % cp->group.num_tasks), 0);
+}
+
 // Sort cp->instance.device_names lexicographically, but do so by first
 // computing a reordering permutation so we can keep cp->instance.task_names
 // in corresponding order.
@@ -278,6 +310,7 @@ void SortDevicesAndTasks(CollectiveParams* cp) {
   cp->instance.device_names = std::move(new_devs);
   cp->instance.task_names = std::move(new_tasks);
   VLOG(1) << "Modified device_names on " << cp;
+  SetDevPerTask(cp);
 }
 
 // Establish the requested number of subdivision permutations based on the
@@ -343,17 +376,18 @@ void GenerateSubdivPerms(const string& device, int source_rank,
 
   if (cp->instance.type == BROADCAST_COLLECTIVE) {
     CHECK_GE(source_rank, 0);
-    cp->subdiv_source_rank.resize(
+    cp->instance.impl_details.subdiv_source_rank.resize(
         cp->instance.impl_details.subdiv_offsets.size(), -1);
-    for (int sdi = 0; sdi < cp->subdiv_source_rank.size(); ++sdi) {
+    for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_source_rank.size();
+         ++sdi) {
       for (int j = 0; j < cp->group.group_size; ++j) {
         if (cp->instance.impl_details.subdiv_permutations[sdi][j] ==
             source_rank) {
-          cp->subdiv_source_rank[sdi] = j;
+          cp->instance.impl_details.subdiv_source_rank[sdi] = j;
           break;
         }
       }
-      CHECK_GE(cp->subdiv_source_rank[sdi], 0);
+      CHECK_GE(cp->instance.impl_details.subdiv_source_rank[sdi], 0);
     }
   }
 
index 4e3c712..4e33c47 100644 (file)
@@ -91,9 +91,10 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
       EXPECT_TRUE(cps[i].task.is_local[j]);
     }
     EXPECT_EQ(cps[i].subdiv_rank[0], i);
-    EXPECT_EQ(cps[i].subdiv_source_rank.size(), 0);
+    EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
     EXPECT_FALSE(cps[i].is_source);
     EXPECT_EQ(cps[i].default_rank, i);
+    EXPECT_TRUE(cps[i].instance.same_num_devices_per_task);
   }
 }
 
@@ -138,10 +139,11 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
     }
     ASSERT_GT(cps[i].subdiv_rank.size(), 0);
     EXPECT_EQ(cps[i].subdiv_rank[0], i);
-    ASSERT_GT(cps[i].subdiv_source_rank.size(), 0);
-    EXPECT_EQ(cps[i].subdiv_source_rank[0], 1);
+    ASSERT_GT(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
+    EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank[0], 1);
     EXPECT_EQ(cps[i].is_source, (i == 1));
     EXPECT_EQ(cps[i].default_rank, i);
+    EXPECT_TRUE(cps[i].instance.same_num_devices_per_task);
   }
 }
 
index d25dd5f..716e23b 100644 (file)
@@ -67,6 +67,8 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
     dev_resolver_->ClearTask(task);
   }
 
+  BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }
+
   // Copy utility that always copies bytes from src to dst even if
   // they are on the same device, unlike CopyTensor::ViaDMA which will
   // just change the dst buffer pointer in that case.
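For context, the buf_rendezvous() accessor added here is what lets BaseCollectiveExecutor::ExecuteAsync (above) abort peers that would otherwise hang after a failed collective op; a distilled sketch of that use, with names as in this change:

    // On a per-op failure, purge the per-step BufRendezvous so peers blocked
    // in RecvFromPeer/PostToPeer observe the error instead of hanging.
    if (!status.ok()) {
      remote_access_->buf_rendezvous()->StartAbort(status);
    }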
index a26f2c2..d4ac50c 100644 (file)
@@ -38,6 +38,7 @@ CollInstanceParams& CollInstanceParams::operator=(
     device_names.clear();
     device_names.assign(other.device_names.begin(), other.device_names.end());
     task_names.assign(other.task_names.begin(), other.task_names.end());
+    same_num_devices_per_task = other.same_num_devices_per_task;
     impl_details.subdiv_offsets.assign(
         other.impl_details.subdiv_offsets.begin(),
         other.impl_details.subdiv_offsets.end());
@@ -76,6 +77,13 @@ string CollInstanceParams::ToString() const {
     }
     strings::StrAppend(&v, "}");  // one subdiv
   }
+  if (!impl_details.subdiv_source_rank.empty()) {
+    strings::StrAppend(&v, " subdiv_source_rank={");
+    for (const auto& r : impl_details.subdiv_source_rank) {
+      strings::StrAppend(&v, r, ",");
+    }
+    strings::StrAppend(&v, "}");
+  }
   strings::StrAppend(&v, "}");  // all subdivs
   return v;
 }
@@ -98,13 +106,6 @@ string CollectiveParams::ToString() const {
   for (const auto& r : subdiv_rank) {
     strings::StrAppend(&v, r, ",");
   }
-  if (!subdiv_source_rank.empty()) {
-    strings::StrAppend(&v, " subdiv_rank={");
-    for (const auto& r : subdiv_source_rank) {
-      strings::StrAppend(&v, r, ",");
-    }
-    strings::StrAppend(&v, "}");
-  }
   strings::StrAppend(&v, "}}");
   return v;
 }
index 5810c7f..40d82ab 100644 (file)
@@ -79,6 +79,8 @@ struct CollInstanceParams {
   std::vector<string> device_names;
   // Task name prefix of corresponding device name.
   std::vector<string> task_names;
+  // True if every task has the same number of devices.
+  bool same_num_devices_per_task;
   CollImplDetails impl_details;
   string ToString() const;
   CollInstanceParams& operator=(const struct CollInstanceParams& other);
@@ -102,7 +104,6 @@ struct CollectiveParams {
   bool is_source;    // broadcast only
   // Rank of this device in each subdivision permutation.
   std::vector<int> subdiv_rank;
-  std::vector<int> subdiv_source_rank;
   std::unique_ptr<OpKernel> merge_op;  // reduction only
   std::unique_ptr<OpKernel> final_op;  // reduction only
   string ToString() const;
@@ -284,12 +285,14 @@ class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted {
   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
 };
 
-// Interface of a helper object that provices a CollectiveExecutor with
+// Interface of a helper object that provides a CollectiveExecutor with
 // all of the remote access it needs.
 class CollectiveRemoteAccess : public PeerAccessInterface,
                                public DeviceResolverInterface {
  public:
   virtual ~CollectiveRemoteAccess() {}
+
+  virtual BufRendezvous* buf_rendezvous() = 0;
 };
 
 // A per-step version of CollectiveRemoteAccess that cleans up outstanding