From 8a247976484173059aedc17bfd8d770b8d1a70e1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 12 Apr 2018 09:46:34 -0700 Subject: [PATCH] Collective Ops Part 3 BaseCollectiveExecutor and RingReducer. This change is part of a series of changes introducing infrastructure for collective ops and initial implementations of reduction and broadcast. PiperOrigin-RevId: 192624521 --- tensorflow/core/BUILD | 33 ++ .../common_runtime/base_collective_executor.cc | 257 +++++++++ .../core/common_runtime/base_collective_executor.h | 144 +++++ .../core/common_runtime/collective_executor_mgr.cc | 38 +- tensorflow/core/common_runtime/dma_helper.h | 3 + tensorflow/core/common_runtime/ring_reducer.cc | 542 ++++++++++++++++++ tensorflow/core/common_runtime/ring_reducer.h | 146 +++++ .../core/common_runtime/ring_reducer_test.cc | 606 +++++++++++++++++++++ .../common_runtime/test_collective_executor_mgr.h | 116 ++++ 9 files changed, 1851 insertions(+), 34 deletions(-) create mode 100644 tensorflow/core/common_runtime/base_collective_executor.cc create mode 100644 tensorflow/core/common_runtime/base_collective_executor.h create mode 100644 tensorflow/core/common_runtime/ring_reducer.cc create mode 100644 tensorflow/core/common_runtime/ring_reducer.h create mode 100644 tensorflow/core/common_runtime/ring_reducer_test.cc create mode 100644 tensorflow/core/common_runtime/test_collective_executor_mgr.h diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 55b0040..1189552 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1064,6 +1064,7 @@ cc_library( hdrs = [ "common_runtime/function_testlib.h", "common_runtime/kernel_benchmark_testlib.h", + "common_runtime/test_collective_executor_mgr.h", "framework/fake_input.h", "framework/function_testlib.h", "framework/shape_inference_testutil.h", @@ -2261,6 +2262,7 @@ tf_cuda_library( CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/allocator_retry.h", + "common_runtime/base_collective_executor.h", "common_runtime/bfc_allocator.h", "common_runtime/buf_rendezvous.h", "common_runtime/build_graph_options.h", @@ -2289,6 +2291,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/renamed_device.h", "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", + "common_runtime/ring_reducer.h", "common_runtime/scoped_allocator.h", "common_runtime/scoped_allocator_mgr.h", "common_runtime/session_factory.h", @@ -2306,6 +2309,7 @@ tf_cuda_library( srcs = [ "common_runtime/accumulate_n_optimizer.cc", "common_runtime/allocator_retry.cc", + "common_runtime/base_collective_executor.cc", "common_runtime/bfc_allocator.cc", "common_runtime/buf_rendezvous.cc", "common_runtime/build_graph_options.cc", @@ -2336,6 +2340,7 @@ tf_cuda_library( "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", + "common_runtime/ring_reducer.cc", "common_runtime/scoped_allocator.cc", "common_runtime/scoped_allocator_mgr.cc", "common_runtime/session.cc", @@ -3101,6 +3106,34 @@ tf_cc_test( ], ) +tf_cc_tests_gpu( + name = "ring_reducer_test", + size = "medium", + srcs = [ + "common_runtime/ring_reducer_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":all_kernels", + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":gpu_runtime", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":protos_test_cc", + ":test", + ":test_main", + ":testlib", + ], +) 
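For orientation before diving into the new files: RingReducer implements the standard two-pass ring all-reduce, a reduce-scatter pass in which each chunk is accumulated as it travels around the ring, followed by an all-gather pass that circulates the fully reduced chunk back to every member. Below is a minimal single-process sketch of that data movement (plain C++, no TensorFlow; one element per chunk, sequential hops instead of the asynchronous per-subdivision callbacks the real implementation uses; all names are illustrative):

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int kRanks = 4;  // one "device" per rank, one single-element chunk per rank
  // data[r][c]: rank r's local contribution to chunk c.
  std::vector<std::vector<double>> data(kRanks, std::vector<double>(kRanks));
  std::vector<double> expected(kRanks, 0.0);
  for (int r = 0; r < kRanks; ++r) {
    for (int c = 0; c < kRanks; ++c) {
      data[r][c] = 100.0 * (r + 1) + c;
      expected[c] += data[r][c];
    }
  }
  // Pass 1 (reduce-scatter): chunk c starts at rank c and is accumulated as it
  // moves to the next rank; after kRanks-1 hops, rank (c-1) mod kRanks holds
  // the fully reduced value.
  for (int c = 0; c < kRanks; ++c) {
    for (int hop = 0; hop < kRanks - 1; ++hop) {
      int src = (c + hop) % kRanks;
      int dst = (src + 1) % kRanks;
      data[dst][c] += data[src][c];
    }
  }
  // Pass 2 (all-gather): the reduced chunk is forwarded around the ring so
  // every rank ends up with the same value. A final op such as Div by the
  // group size would turn this sum into a mean.
  for (int c = 0; c < kRanks; ++c) {
    int owner = (c + kRanks - 1) % kRanks;
    for (int hop = 0; hop < kRanks - 1; ++hop) {
      int src = (owner + hop) % kRanks;
      int dst = (src + 1) % kRanks;
      data[dst][c] = data[src][c];
    }
  }
  for (int r = 0; r < kRanks; ++r)
    for (int c = 0; c < kRanks; ++c) assert(data[r][c] == expected[c]);
  std::printf("ring all-reduce ok\n");
  return 0;
}

Each of the 2*(N-1) hops per chunk corresponds to one DispatchSend/DispatchRecv pair in ring_reducer.cc; merge_op stands in for the `+=` above, and final_op (e.g. Div, giving a mean) runs once on each fully reduced chunk before the all-gather pass.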
+ tf_cc_test_mkl( name = "mkl_runtime_tests", size = "small", diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc new file mode 100644 index 0000000..f6332fa --- /dev/null +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -0,0 +1,257 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/base_collective_executor.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/ring_reducer.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/strings/str_util.h" + +#define VALUE_IN_DEBUG_STRING false + +namespace tensorflow { +/*static*/ +int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts, + int64 num_chunks) { + DCHECK_GT(num_chunks, 0); + int64 base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks; + if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts; + if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) { + // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES + DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES); + return base_chunk_elts; + } + // elt_bytes < EIGEN_MAX_ALIGN_BYTES, which + // must be a common multiple of the various atomic data types. + DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes) + << "total_elts=" << total_elts << " num_chunks=" << num_chunks + << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES + << " elt_bytes=" << elt_bytes; + // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES. + int64 chunk_bytes = base_chunk_elts * elt_bytes; + int64 diff = + (chunk_bytes < EIGEN_MAX_ALIGN_BYTES) + ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes) + : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES)); + CHECK_EQ(0, diff % elt_bytes); + base_chunk_elts += (diff / elt_bytes); + DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES)) + << "total_elts=" << total_elts << " num_chunks=" << num_chunks + << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES + << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes; + return base_chunk_elts; +} + +namespace { +template +class CollectiveAdapterImpl : public CollectiveAdapter { + public: + // Takes ownership of output and prepares to properly alias its chunks. + // Ownership is taken because the shape may temporarily change. 
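// To make the alignment arithmetic in AlignedChunkElts above concrete, here is
// a standalone sketch (plain C++, no Eigen or TensorFlow; the function name and
// the 64-byte kAlign constant are illustrative stand-ins, not TF symbols) that
// reproduces the rounding rule and the clamped per-chunk sizes that ChunkElts
// later derives from it:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Same rounding rule as AlignedChunkElts: pad the per-chunk element count so
// that every chunk boundary stays a multiple of the alignment.
int64_t AlignedChunkEltsSketch(int64_t elt_bytes, int64_t total_elts,
                               int64_t num_chunks, int64_t align_bytes) {
  int64_t base = (total_elts + num_chunks - 1) / num_chunks;
  if (align_bytes <= elt_bytes) return base;
  int64_t chunk_bytes = base * elt_bytes;
  int64_t diff = (chunk_bytes < align_bytes)
                     ? (align_bytes - chunk_bytes)
                     : (align_bytes - (chunk_bytes % align_bytes));
  return base + diff / elt_bytes;
}

int main() {
  const int64_t kAlign = 64;  // hypothetical value of EIGEN_MAX_ALIGN_BYTES
  // 100 floats split 3 ways: base 34 elts (136 B) is padded to 48 elts
  // (192 B = 3 * 64), so the chunk sizes come out 48, 48, 4.
  // 10 floats split 3 ways: base 4 elts is padded to 16, so chunk 0 takes all
  // 10 elements and the two tail chunks end up empty.
  for (int64_t total : {int64_t{100}, int64_t{10}}) {
    int64_t chunk_elts =
        AlignedChunkEltsSketch(sizeof(float), total, 3, kAlign);
    std::printf("total=%lld chunk_elts=%lld sizes:", (long long)total,
                (long long)chunk_elts);
    for (int i = 0; i < 3; ++i) {
      int64_t start = std::min(total, i * chunk_elts);
      int64_t end = std::min(total, start + chunk_elts);
      std::printf(" %lld", (long long)(end - start));
    }
    std::printf("\n");
  }
  return 0;
}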
+ CollectiveAdapterImpl(Tensor* output, int64 num_chunks, Allocator* allocator) + : output_(std::move(*output)), + dt_(output_.dtype()), + old_shape_(output_.shape()), + num_chunks_(num_chunks), + allocator_(allocator), + total_elts_(output_.NumElements()), + chunk_elts_(AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)), + data_start_(reinterpret_cast(DMAHelper::base(&output_))), + data_end_(data_start_ + total_elts_) { + CHECK_GT(chunk_elts_, 0); + Flatten(); + } + + ~CollectiveAdapterImpl() override {} + + const Tensor& Value() const override { return output_; } + + // If necessary, flatten output. + void Flatten() { + if (old_shape_.dims() > 1) { + TensorShape new_shape = TensorShape({old_shape_.num_elements()}); + DMAHelper::UnsafeSetShape(&output_, new_shape); + } + } + + void ConsumeFinalValue(Tensor* output) override { + if (old_shape_ != output_.shape()) { + DMAHelper::UnsafeSetShape(&output_, old_shape_); + } + *output = std::move(output_); + } + + // Number of T elements in a particular chunk. + inline int64 ChunkElts(int i) const { + DCHECK_LT(i, num_chunks_); + const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_); + const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_); + return chunk_end - chunk_start; + } + + int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); } + + // Returns a new Tensor that aliases the required chunk. + Tensor ChunkAlias(int i) override { + int64 start = chunk_elts_ * i; + int64 num_elts = ChunkElts(i); + // If this chunk is empty the prior chunk might also be short + // so always take an empty slice from the front of the tensor + // to avoid an illegal offset check failure somewhere. + return (num_elts > 0) ? output_.Slice(start, start + num_elts) + : output_.Slice(0, 0); + } + + Tensor TempChunk(int i) const override { + AllocationAttributes empty; + return Tensor(allocator_, dt_, {ChunkElts(i)}, empty); + } + + string DebugString() const override { + return strings::StrCat( + "base addr ", reinterpret_cast(DMAHelper::base(&output_)), + " num_chunks ", num_chunks_, " total_elts ", total_elts_, " chunk_elts", + chunk_elts_, " value ", + VALUE_IN_DEBUG_STRING ? 
output_.SummarizeValue(1024) : ""); + } + + string TBounds(const Tensor& t) const override { + int64 base_addr = reinterpret_cast(DMAHelper::base(&t)); + return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()), + ")"); + } + + Tensor Scalar(int v) const override { + Tensor t(dt_, TensorShape({})); + t.scalar()() = v; + return t; + } + + Tensor Scalar(Allocator* a) const override { + Tensor t(a, dt_, TensorShape({})); + return t; + } + + Tensor output_; + const DataType dt_; + const TensorShape old_shape_; + const int64 num_chunks_; + Allocator* allocator_; + const int64 total_elts_; + const int64 chunk_elts_; + const T* data_start_; + const T* data_end_; +}; + +} // namespace + +CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks, + Allocator* allocator) { + switch (output->dtype()) { + case DT_FLOAT: + return new CollectiveAdapterImpl(output, num_chunks, allocator); + break; + case DT_DOUBLE: + return new CollectiveAdapterImpl(output, num_chunks, allocator); + break; + case DT_INT32: + return new CollectiveAdapterImpl(output, num_chunks, allocator); + break; + case DT_INT64: + return new CollectiveAdapterImpl(output, num_chunks, allocator); + break; + default: + LOG(FATAL) << "Unsupported type " << output->dtype() + << " to MakeCollectiveAdapter"; + return nullptr; + } +} + +BaseCollectiveExecutor::~BaseCollectiveExecutor() {} + +void BaseCollectiveExecutor::StartAbort(const Status& s) { + LOG(WARNING) << "BaseCollectiveExecutor::StartAbort " << s; + remote_access_->StartAbort(s); +} + +void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, + const CollectiveParams& col_params, + const string& exec_key, + StatusCallback done) { + const Tensor* input = &ctx->input(0); + Tensor* output = ctx->mutable_output(0); + string error; + switch (col_params.instance.type) { + case REDUCTION_COLLECTIVE: { + // TODO(tucker): support other reduction algorithms, + // e.g. tree-reduce, hybrid tree/ring, delegate-to-NCCL, etc. + RingReducer* reducer = + CreateReducer(ctx, CtxParams(ctx), col_params, exec_key, step_id_, + input, output, &error); + if (!reducer) { + done(errors::Internal(error)); + return; + } + // Run in an I/O thread, so as not to starve the executor threads. + // TODO(tucker): Instead of forking every per-device Collective + // Op off into its own thread, consider queuing them on a + // fixed-size thread-pool dedicated to running CollectiveOps. 
+ SchedClosure([reducer, done]() { + reducer->Run([reducer, done](const Status& s) { + done(s); + delete reducer; + }); + }); + } break; + case BROADCAST_COLLECTIVE: + done(errors::Internal("Collective Broadcast unimplemented")); + break; + default: + done(errors::Internal("Unimplemented CollectiveType ", + col_params.instance.type)); + } +} + +RingReducer* BaseCollectiveExecutor::CreateReducer( + OpKernelContext* ctx, OpKernelContext::Params* params, + const CollectiveParams& col_params, const string& exec_key, int64 step_id, + const Tensor* input, Tensor* output, string* error) { + switch (col_params.instance.data_type) { + case DT_INT32: + if (col_params.group.device_type == DEVICE_GPU) { + *error = + "Collective Reduce does not support datatype DT_INT32 on " + "DEVICE_GPU"; + return nullptr; + } + TF_FALLTHROUGH_INTENDED; + case DT_FLOAT: + case DT_DOUBLE: + case DT_INT64: + return new RingReducer(this, dev_mgr_, ctx, params, col_params, exec_key, + step_id, input, output); + break; + default: + *error = strings::StrCat("Collective Reduce does not support datatype ", + col_params.instance.data_type); + return nullptr; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h new file mode 100644 index 0000000..58eaf31 --- /dev/null +++ b/tensorflow/core/common_runtime/base_collective_executor.h @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_ + +#include +#include "tensorflow/core/common_runtime/buf_rendezvous.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" + +namespace tensorflow { +class DeviceMgr; +class RingReducer; + +// Helper interface that aliases regular subfields of a Tensor as separate +// Tensors for in-place update. +class CollectiveAdapter { + public: + virtual ~CollectiveAdapter() {} + + // Move the backing tensor to 'output' with its original storage and + // shape. After this call this CollectiveAdapter object should be + // deleted immediately without calling any of its other methods. + virtual void ConsumeFinalValue(Tensor* output) = 0; + + // const access to entire intermediate value for debugging + virtual const Tensor& Value() const = 0; + + // Returns tensor for chunk i which aliases the backing buffer. + virtual Tensor ChunkAlias(int i) = 0; + + // Returns tensor allocated on the same device but with its own + // separate backing buffer. Will have same type and size as + // chunk i. + virtual Tensor TempChunk(int i) const = 0; + + // Bytes in chunk i + virtual int64 ChunkBytes(int i) const = 0; + + // Generate a CPU RAM scalar tensor of the same DataType as the + // backing tensor with the given integer value. 
+ virtual Tensor Scalar(int v) const = 0; + + // Generate a scalar tensor of same DataType and on the same device + // as the backing tensor. + virtual Tensor Scalar(Allocator* a) const = 0; + + // Debugging string describing buffer location + virtual string TBounds(const Tensor& t) const = 0; + + virtual string DebugString() const = 0; + + // Computes the number of elements per alias chunk tensor. + // + // A CHECK in tensor.cc expects that the memory buffer backing a + // Tensor will be aligned according to EIGEN_MAX_ALIGN_BYTES. To + // ensure that all chunk aliasing Tensors maintain this alignment we + // need to pick a chunk size that preserves it. Note than in extreme + // cases (impractical, but possible with very small tensors) one or + // more tail chunks can end up emptby. + static int64 AlignedChunkElts(int64 elt_bytes, int64 total_elts, + int64 num_chunks); +}; + +// Create a CollectiveAdaptor wrapping 'output', specialized to its +// data-type and shape. +CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks, + Allocator* allocator); + +// Default implementation of CollectiveExecutor. Delegates the actual +// work of moving data to a class specialized for the operation type, +// arguments and device+interconnect topology. +class BaseCollectiveExecutor : public CollectiveExecutor { + public: + BaseCollectiveExecutor(CollectiveExecutorMgrInterface* cem, + PerStepCollectiveRemoteAccess* remote_access, + int64 step_id, const DeviceMgr* dev_mgr) + : CollectiveExecutor(cem), + step_id_(step_id), + dev_mgr_(dev_mgr), + remote_access_(remote_access) {} + + ~BaseCollectiveExecutor() override; + + void StartAbort(const Status& s) override; + + void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params, + const string& exec_key, StatusCallback done) override; + + PerStepCollectiveRemoteAccess* remote_access() override { + return remote_access_.get(); + } + + void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + const StatusCallback& done) override { + remote_access_->RecvFromPeer(peer_device, peer_task, peer_is_local, key, + to_device, to_device_ctx, to_alloc_attr, + to_tensor, client_locality, done); + } + + void PostToPeer(const string& peer_device, const string& peer_task, + const string& key, Device* from_device, + DeviceContext* from_device_ctx, + const AllocatorAttributes& from_alloc_attr, + const Tensor* from_tensor, + const DeviceLocality& client_locality, + const StatusCallback& done) override { + remote_access_->PostToPeer(peer_device, peer_task, key, from_device, + from_device_ctx, from_alloc_attr, from_tensor, + client_locality, done); + } + + protected: + const int64 step_id_; + const DeviceMgr* dev_mgr_; // Not owned. 
+ std::unique_ptr remote_access_; + + private: + RingReducer* CreateReducer(OpKernelContext* ctx, + OpKernelContext::Params* params, + const CollectiveParams& col_params, + const string& exec_key, int64 step_id, + const Tensor* input, Tensor* output, + string* error); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_ diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc index a5c4946..e07829b 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.cc +++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/base_collective_executor.h" #include "tensorflow/core/common_runtime/build_graph_options.h" #include "tensorflow/core/common_runtime/collective_rma_local.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -21,39 +22,6 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { -namespace { -// TODO(tucker): Temporary class just until a real CollectiveExecutor -// implementation is submitted in a later CL. -class DummyCollectiveExecutor : public CollectiveExecutor { - public: - explicit DummyCollectiveExecutor(CollectiveExecutorMgr* ce_mgr) - : CollectiveExecutor(ce_mgr) {} - - ~DummyCollectiveExecutor() override {} - - void RecvFromPeer(const string& peer_device, const string& peer_task, - bool peer_is_local, const string& key, Device* to_device, - DeviceContext* to_device_ctx, - const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, - const DeviceLocality& client_locality, - const StatusCallback& done) override { - done(errors::Internal("Unimplemented")); - } - - void PostToPeer(const string& peer_device, const string& peer_task, - const string& key, Device* from_device, - DeviceContext* from_device_ctx, - const AllocatorAttributes& from_alloc_attr, - const Tensor* from_tensor, - const DeviceLocality& client_locality, - const StatusCallback& done) override { - done(errors::Internal("Unimplemented")); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(DummyCollectiveExecutor); -}; -} // namespace CollectiveExecutorMgr::CollectiveExecutorMgr( const ConfigProto& config, const DeviceMgr* dev_mgr, @@ -77,7 +45,9 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) { if (it != executor_table_.end()) { ce = it->second; } else { - ce = new DummyCollectiveExecutor(this); + CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal( + dev_mgr_, dev_resolver_.get(), step_id); + ce = new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); executor_table_[step_id] = ce; } ce->Ref(); diff --git a/tensorflow/core/common_runtime/dma_helper.h b/tensorflow/core/common_runtime/dma_helper.h index 1cc8b9e..cdfce1f 100644 --- a/tensorflow/core/common_runtime/dma_helper.h +++ b/tensorflow/core/common_runtime/dma_helper.h @@ -28,6 +28,9 @@ class DMAHelper { static void* base(Tensor* t) { return t->base(); } static TensorBuffer* buffer(Tensor* t) { return t->buf_; } static const TensorBuffer* buffer(const Tensor* t) { return t->buf_; } + static void UnsafeSetShape(Tensor* t, const TensorShape& s) { + t->set_shape(s); + } }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/ring_reducer.cc 
b/tensorflow/core/common_runtime/ring_reducer.cc new file mode 100644 index 0000000..79d03a2 --- /dev/null +++ b/tensorflow/core/common_runtime/ring_reducer.cc @@ -0,0 +1,542 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/ring_reducer.h" + +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" + +// Set true for greater intelligibility of debug mode log messages. +#define READABLE_KEYS false + +namespace tensorflow { +namespace { +// Each CollectiveOp implementation is free to define its own +// BufRendezvous key format. This function produces the key used by +// RingReducer. +string RingReduceBufKey(const string& exec_key, int pass, int section, + int source_rank) { + if (READABLE_KEYS) { + return strings::StrCat("rred(", exec_key, "):pass(", pass, "):section(", + section, "):srcrank(", source_rank, ")"); + } else { + // TODO(tucker): Try out some kind of denser encoding, e.g. 128 bit hash. 
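// For reference, a small standalone sketch of the two key layouts this
// function can emit (plain std::string instead of strings::StrCat; the name
// BufKeySketch and the sample arguments are illustrative). With exec_key
// "17:0:0", pass 1, section 3, source rank 2 the readable form is
// "rred(17:0:0):pass(1):section(3):srcrank(2)" and the compact form is
// "17:0:0:1:3:2":

#include <iostream>
#include <string>

std::string BufKeySketch(bool readable, const std::string& exec_key, int pass,
                         int section, int source_rank) {
  if (readable) {
    return "rred(" + exec_key + "):pass(" + std::to_string(pass) +
           "):section(" + std::to_string(section) + "):srcrank(" +
           std::to_string(source_rank) + ")";
  }
  return exec_key + ":" + std::to_string(pass) + ":" +
         std::to_string(section) + ":" + std::to_string(source_rank);
}

int main() {
  std::cout << BufKeySketch(true, "17:0:0", 1, 3, 2) << "\n"
            << BufKeySketch(false, "17:0:0", 1, 3, 2) << "\n";
  return 0;
}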
+ return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank); + } +} + +} // namespace + +void RingReducer::PCQueue::Enqueue(RingField* rf) { + mutex_lock l(pcq_mu_); + deque_.push_back(rf); + if (waiter_count_ > 0) { + cv_.notify_one(); + } +} + +RingReducer::RingField* RingReducer::PCQueue::Dequeue() { + mutex_lock l(pcq_mu_); + if (deque_.empty()) { + ++waiter_count_; + while (deque_.empty()) { + cv_.wait(l); + } + --waiter_count_; + } + RingField* rf = deque_.front(); + deque_.pop_front(); + return rf; +} + +RingReducer::RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr, + OpKernelContext* ctx, + OpKernelContext::Params* op_params, + const CollectiveParams& col_params, + const string& exec_key, int64 step_id, + const Tensor* input, Tensor* output) + : col_exec_(col_exec), + dev_mgr_(dev_mgr), + ctx_(ctx), + op_params_(op_params), + col_params_(col_params), + exec_key_(exec_key), + input_(input), + output_(output), + rank_(col_params.subdiv_rank[0]), + step_id_(step_id), + group_size_(col_params.group.group_size), + num_subdivs_(static_cast( + col_params.instance.impl_details.subdiv_permutations.size())), + done_(nullptr), + device_(nullptr), + device_name_( + col_params_.instance.device_names[col_params_.default_rank]) { + CHECK_GT(group_size_, 0); + CHECK_GT(num_subdivs_, 0); +} + +string RingReducer::TensorDebugString(Tensor tensor) { + const DeviceBase::GpuDeviceInfo* gpu_device_info = + ctx_->device()->tensorflow_gpu_device_info(); + if (gpu_device_info) { + Tensor cpu_tensor(tensor.dtype(), tensor.shape()); + Notification note; + gpu_device_info->default_context->CopyDeviceTensorToCPU( + &tensor, "" /*tensor_name*/, device_, &cpu_tensor, + [¬e](const Status& s) { + CHECK(s.ok()); + note.Notify(); + }); + note.WaitForNotification(); + return cpu_tensor.SummarizeValue(64); + } else { + return tensor.SummarizeValue(64); + } +} + +void RingReducer::Run(StatusCallback done) { + done_ = std::move(done); + + // Get local execution device. + if (VLOG_IS_ON(1)) { + string buf; + for (int r = 0; r < col_params_.instance.device_names.size(); ++r) { + strings::StrAppend(&buf, "dev ", r, " : ", + col_params_.instance.device_names[r], "\n"); + } + for (int sd = 0; + sd < col_params_.instance.impl_details.subdiv_permutations.size(); + ++sd) { + strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: "); + for (auto x : col_params_.instance.impl_details.subdiv_permutations[sd]) { + strings::StrAppend(&buf, x, ", "); + } + } + VLOG(1) << "RingReducer::Run for device " << device_name_ + << " default_rank " << col_params_.default_rank << "\n" + << buf; + } + CHECK(dev_mgr_); + Status status = dev_mgr_->LookupDevice( + col_params_.instance.device_names[col_params_.default_rank], &device_); + if (!status.ok()) { + LOG(ERROR) << "Failed to find device " + << col_params_.instance.device_names[col_params_.default_rank]; + for (auto d : dev_mgr_->ListDevices()) { + LOG(ERROR) << "Available device " << d->name(); + } + done_(status); + return; + } + CHECK(device_); + device_locality_ = device_->attributes().locality(); + + VLOG(1) << this << " default_rank " << col_params_.default_rank << " cp " + << &col_params_ << ": " << col_params_.ToString(); + + // Start by copying input to output if they're not already the same, i.e. if + // we're not computing in-place on the input tensor. 
+ if ((input_ != output_) && + (DMAHelper::base(input_) != DMAHelper::base(output_))) { + CollectiveRemoteAccessLocal::MemCpyAsync( + ctx_->input_device_context(0), ctx_->op_device_context(), device_, + device_, ctx_->input_alloc_attr(0), ctx_->output_alloc_attr(0), input_, + output_, [this](const Status& s) { + if (!s.ok()) { + done_(s); + } else { + ContinueAfterInputCopy(); + } + }); + } else { + ContinueAfterInputCopy(); + } +} + +void RingReducer::ContinueAfterInputCopy() { + AllocatorAttributes attr = ctx_->output_alloc_attr(0); + ca_.reset(MakeCollectiveAdapter(output_, group_size_ * num_subdivs_, + device_->GetAllocator(attr))); + + if (col_params_.final_op) { + // Create an on-device scalar value from group_size_ that may be needed + // later. + // TODO(tucker): Cache and reuse across invocations? Or maybe the scalar + // can be provided to the kernel in host memory? + Tensor group_size_val = ca_->Scalar(group_size_); + if (col_params_.group.device_type != "CPU") { + group_size_tensor_ = + ca_->Scalar(device_->GetAllocator(ctx_->input_alloc_attr(0))); + DeviceContext* op_dev_ctx = ctx_->op_device_context(); + op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, device_, + &group_size_tensor_, + [this](const Status& s) { + if (!s.ok()) { + StartAbort(s); + } + group_size_tensor_ready_.Notify(); + }); + } else { + group_size_tensor_ = group_size_val; + group_size_tensor_ready_.Notify(); + } + } + Finish(RunAsyncParts()); +} + +void RingReducer::StartAbort(const Status& s) { + // In abort mode we stop issuing additional ProvideBuf + // and ConsumeBuf calls, but we need to wait for all of the + // outstanding callbacks to be invoked before quitting. + bool abort_started = false; + { + mutex_lock l(status_mu_); + if (status_.ok()) { + LOG(ERROR) << "Aborting RingReduce with " << s; + abort_started = true; + status_.Update(s); + } + } + // If this is the initial entry to abort mode then invoke StartAbort + // on the CollectiveExecutor that invoked us. That should start + // cancellation on all of the outstanding CollectiveRemoteAccess + // actions. + if (abort_started) { + col_exec_->StartAbort(s); + } +} + +void RingReducer::Finish(bool ok) { + if (ok) { + // Recover the output from the adaptor. + ca_->ConsumeFinalValue(output_); + } + Status s; + { + mutex_lock l(status_mu_); + s = status_; + } + done_(s); +} + +RingReducer::SubContext::SubContext(OpKernelContext* ctx, + OpKernelContext::Params* params, + OpKernel* op, Tensor* output, Tensor* input) + : sub_params_(*params), + sub_inputs_({output, input}), + sub_input_attr_({ctx->input_alloc_attr(0), ctx->input_alloc_attr(0)}), + sub_input_dc_( + {ctx->input_device_context(0), ctx->input_device_context(0)}) { + sub_params_.op_kernel = op; + sub_params_.inputs = &sub_inputs_; + sub_params_.input_alloc_attrs = &sub_input_attr_; + sub_params_.input_device_contexts = &sub_input_dc_; + sub_params_.eigen_gpu_device = nullptr; + sub_params_.ensure_eigen_gpu_device(); + sub_ctx_ = new OpKernelContext(&sub_params_, 1); +} + +Status RingReducer::ComputeBinOp(Device* device, OpKernel* op, Tensor* output, + Tensor* input) { + // Prepare an OpKernelContext that is identical to that of the original Op + // (i.e. the collective), except for the input output sizes and identities and + // the Op itself. + // TODO(tucker): Is it possible to cache and reuse these objects? They're + // mostly identical inside one device execution. 
+ std::unique_ptr sub_ctx( + new SubContext(ctx_, op_params_, op, output, input)); + device->Compute(op, sub_ctx->sub_ctx_); + return sub_ctx->sub_ctx_->status(); +} + +// At the beginning of the algorithm initialize a RingField struct for +// every independent field of the tensor. +void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx, + int field_idx) { + // Note on field indexing: There are group_size_ devices in the + // instance, implying the same number of chunks per tensor, where a + // chunk is the unit of data transferred in a time step. However, if + // a device can simultaenously send data by 2 or more independent + // channels we can speed up the transfer by subdividing chunks and + // processing multiple subdivisions at once. So the actual number + // of RingFields is group_size_ * num_subdivs_. + DCHECK_EQ(field_idx, (chunk_idx * num_subdivs_) + subdiv_idx); + rf->chunk_idx = chunk_idx; + rf->subdiv_idx = subdiv_idx; + rf->sc_idx = field_idx; + rf->rank = col_params_.subdiv_rank[subdiv_idx]; + rf->second_pass = false; + rf->action = RF_INIT; + // Recv from the device with preceding rank within the subdivision. + int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_; + int send_to_rank = (rf->rank + 1) % group_size_; + rf->recv_dev_idx = col_params_.instance.impl_details + .subdiv_permutations[subdiv_idx][recv_from_rank]; + int send_dev_idx = col_params_.instance.impl_details + .subdiv_permutations[subdiv_idx][send_to_rank]; + rf->recv_is_remote = !col_params_.task.is_local[rf->recv_dev_idx]; + rf->send_is_remote = !col_params_.task.is_local[send_dev_idx]; + if (ca_->ChunkBytes(rf->sc_idx) > 0) { + // In pass 0 we skip Recv when rank = chunk_idx + rf->do_recv = (rf->chunk_idx != rf->rank); + // In pass 0 we skip Send when rank = chunk_idx-1 + rf->do_send = + (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_)); + } + rf->is_final = + (rf->rank == ((rf->chunk_idx + (group_size_ - 1)) % group_size_)); + if (rf->do_send || rf->do_recv) { + rf->chunk = ca_->ChunkAlias(rf->sc_idx); + CHECK(rf->chunk.IsAligned()) << rf->DebugString(); + } + if (rf->do_recv) { + rf->tmp_chunk = ca_->TempChunk(rf->sc_idx); + CHECK(rf->tmp_chunk.IsAligned()) << rf->DebugString(); + } + VLOG(2) << this << " InitRingField " << rf->DebugString() << " chunk " + << ca_->TBounds(rf->chunk); +} + +// When a RingField transitions from first to second recompute the +// do_send and do_recv values. +void RingReducer::AdvanceToSecondPass(RingField* rf) { + VLOG(3) << "IncrRingField old value " << rf->DebugString(); + CHECK(!rf->second_pass); + rf->second_pass = true; + rf->action = RF_INIT; + if (ca_->ChunkBytes(rf->sc_idx) > 0) { + // In pass 1 the send/no-send boundary moves down 1 place. 
+ rf->do_recv = + (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_)); + rf->do_send = + (rf->rank != ((rf->chunk_idx + (group_size_ - 2)) % group_size_)); + } + rf->is_final = + (rf->rank == ((rf->chunk_idx + (group_size_ - 2)) % group_size_)); + VLOG(3) << "IncrRingField new value " << rf->DebugString(); +} + +string RingReducer::RingField::DebugString() const { + string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx, + " subdiv=", subdiv_idx, " sc_idx=", sc_idx, + " action=", action); + strings::StrAppend(&rv, " pass=", second_pass); + strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv, + " is_final=", is_final, " recv_is_remote=", recv_is_remote, + " recv_dev_idx=", recv_dev_idx, " sc_idx=", sc_idx); + return rv; +} + +void RingReducer::DispatchSend(RingField* rf, const StatusCallback& done) { + CHECK(rf->do_send); + string send_buf_key = + RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx, rf->rank); + VLOG(3) << "DispatchSend rank=" << col_params_.default_rank << " send key " + << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx " + << rf->sc_idx; + int send_to_rank = (rf->rank + 1) % group_size_; + int send_to_dev_idx = col_params_.instance.impl_details + .subdiv_permutations[rf->subdiv_idx][send_to_rank]; + col_exec_->PostToPeer(col_params_.instance.device_names[send_to_dev_idx], + col_params_.instance.task_names[send_to_dev_idx], + send_buf_key, device_, ctx_->op_device_context(), + ctx_->output_alloc_attr(0), &rf->chunk, + device_locality_, done); +} + +void RingReducer::DispatchRecv(RingField* rf, const StatusCallback& done) { + CHECK(rf->do_recv); + string recv_buf_key = + RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx, + (rf->rank + (group_size_ - 1)) % group_size_); + VLOG(3) << "DispatchRecv rank=" << col_params_.default_rank << " recv key " + << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into " + << ((col_params_.merge_op != nullptr) ? "tmp_chunk" : "chunk"); + Tensor* dst_tensor = (!rf->second_pass && (col_params_.merge_op != nullptr)) + ? &rf->tmp_chunk + : &rf->chunk; + col_exec_->RecvFromPeer(col_params_.instance.device_names[rf->recv_dev_idx], + col_params_.instance.task_names[rf->recv_dev_idx], + col_params_.task.is_local[rf->recv_dev_idx], + recv_buf_key, device_, ctx_->op_device_context(), + ctx_->output_alloc_attr(0), dst_tensor, + device_locality_, done); +} + +string RingReducer::FieldState() { + string s = strings::StrCat("RingReducer ", + strings::Hex(reinterpret_cast(this)), + " exec ", exec_key_, " step_id=", step_id_, + " state of all ", rfv_.size(), " fields:"); + for (int i = 0; i < rfv_.size(); ++i) { + s.append("\n"); + s.append(rfv_[i].DebugString()); + } + return s; +} + +bool RingReducer::RunAsyncParts() { + // This function orchestrates RingReduce actions on behalf of a + // single device. It is entered by a blockable thread that + // loops within it until all actions assigned to that device + // complete. Hence function local variables are accessible only by that + // one thread and do not require an explicit mutex. 
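// Before the orchestration loop below, it can help to tabulate how
// InitRingField and AdvanceToSecondPass set the per-field flags. This is a
// standalone sketch (plain C++; one subdivision, a hypothetical group_size of
// 4, non-empty chunks assumed) using exactly the formulas above: in pass 0
// (the reduce-scatter) chunk c is injected by rank c and is fully reduced at
// rank (c-1) mod group_size, the rank marked is_final; in pass 1 (the
// all-gather) the reduced chunk is forwarded around the ring and rank
// (c-2) mod group_size is the last to receive it.

#include <cstdio>

int main() {
  const int group_size = 4;
  for (int pass = 0; pass < 2; ++pass) {
    std::printf("pass %d\n", pass);
    for (int chunk = 0; chunk < group_size; ++chunk) {
      for (int rank = 0; rank < group_size; ++rank) {
        bool do_recv, do_send, is_final;
        if (pass == 0) {
          do_recv = (chunk != rank);
          do_send = (rank != (chunk + group_size - 1) % group_size);
          is_final = (rank == (chunk + group_size - 1) % group_size);
        } else {
          do_recv = (rank != (chunk + group_size - 1) % group_size);
          do_send = (rank != (chunk + group_size - 2) % group_size);
          is_final = (rank == (chunk + group_size - 2) % group_size);
        }
        std::printf("  chunk %d rank %d: recv=%d send=%d final=%d\n", chunk,
                    rank, static_cast<int>(do_recv), static_cast<int>(do_send),
                    static_cast<int>(is_final));
      }
    }
  }
  return 0;
}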
+ rfv_.clear(); + rfv_.resize(group_size_ * num_subdivs_); + PCQueue ready_queue; + int field_done_count = 0; + int send_pending_count = 0; + int recv_pending_count = 0; + std::atomic aborted(false); + field_done_count = 0; + send_pending_count = 0; + recv_pending_count = 0; + for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) { + for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) { + int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx; + InitRingField(&rfv_[rf_index], chunk_idx, subdiv_idx, rf_index); + ready_queue.Enqueue(&rfv_[rf_index]); + } + } + + // Loop until all RingFields have advanced to completion. + while (field_done_count < rfv_.size()) { + VLOG(4) << FieldState(); + // Wait for a RingField to appear in the ready_queue. + RingField* rf = ready_queue.Dequeue(); + // Advance the RingField to its next action and execute, repeating + // until either an async action has been started or the RingField + // is done. + bool dispatched = false; // true if async action was initiated + do { + if (aborted) break; + switch (rf->action) { + case RF_INIT: + if (rf->do_recv) { + rf->action = RF_RECV; + auto requeue = [this, rf, &ready_queue, &aborted](Status s) { + if (!s.ok()) { + aborted = true; + StartAbort(s); + } + ready_queue.Enqueue(rf); + }; + DispatchRecv(rf, requeue); + dispatched = true; + ++recv_pending_count; + } else { + rf->action = RF_SEND_READY; + } + break; + case RF_RECV: + CHECK_GT(recv_pending_count, 0); + --recv_pending_count; + if (!rf->second_pass) { + rf->action = RF_REDUCE; + Status s = ComputeBinOp(device_, col_params_.merge_op.get(), + &rf->chunk, &rf->tmp_chunk); + if (!s.ok()) { + aborted = true; + StartAbort(s); + } + } else { + rf->action = RF_SEND_READY; + } + break; + case RF_REDUCE: + if (!rf->second_pass && col_params_.final_op.get() && rf->is_final) { + rf->action = RF_FINALIZE; + group_size_tensor_ready_.WaitForNotification(); + Status s = ComputeBinOp(device_, col_params_.final_op.get(), + &rf->chunk, &group_size_tensor_); + if (!s.ok()) { + aborted = true; + StartAbort(s); + } + } else { + rf->action = RF_SEND_READY; + } + break; + case RF_FINALIZE: + rf->action = RF_DONE; + break; + case RF_SEND_READY: + if (rf->do_send) { + rf->action = RF_SEND; + auto send_complete = [this, rf, &ready_queue, &aborted](Status s) { + if (!s.ok()) { + aborted = true; + StartAbort(s); + } + ready_queue.Enqueue(rf); + }; + DispatchSend(rf, send_complete); + dispatched = true; + ++send_pending_count; + } else { + rf->action = RF_DONE; + } + break; + case RF_SEND: + CHECK_GT(send_pending_count, 0); + --send_pending_count; + rf->action = RF_DONE; + break; + case RF_DONE: + break; + } + if (rf->action == RF_DONE) { + if (rf->second_pass) { + ++field_done_count; + break; // from do while(!dispatched) + } else { + AdvanceToSecondPass(rf); + } + } + } while (!dispatched); + if (aborted) break; + } // while (field_done_count < number of fields) + + if (aborted) { + // All of the pending data actions should be aborted; field the + // callbacks and clear the queue before quitting. 
+ while ((send_pending_count > 0) || (recv_pending_count > 0)) { + RingField* rf = ready_queue.Dequeue(); + switch (rf->action) { + case RF_RECV: + --recv_pending_count; + break; + case RF_SEND: + --send_pending_count; + break; + default: {} // Ignore any other actions + } + } + } + + CHECK_EQ(send_pending_count, 0); + CHECK_EQ(recv_pending_count, 0); + + VLOG(2) << this << " rank=" << rank_ << " finish;" + << " final value " << TensorDebugString(ca_->Value()); + return !aborted; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/ring_reducer.h b/tensorflow/core/common_runtime/ring_reducer.h new file mode 100644 index 0000000..8fde18d --- /dev/null +++ b/tensorflow/core/common_runtime/ring_reducer.h @@ -0,0 +1,146 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_ + +#include + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" + +namespace tensorflow { +class DeviceMgr; + +// Ring-algorithm implementation of collective all-reduce. +class RingReducer { + public: + RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr, + OpKernelContext* ctx, OpKernelContext::Params* op_params, + const CollectiveParams& col_params, const string& exec_key, + int64 step_id, const Tensor* input, Tensor* output); + + virtual ~RingReducer() {} + + void Run(StatusCallback done); + + private: + // Called when a bad status is received that implies we should terminate + // execution and return a bad status. + void StartAbort(const Status& s); + void ContinueAfterInputCopy(); + void Finish(bool ok); + Status ComputeBinOp(Device* device, OpKernel* op, Tensor* output, + Tensor* input); + bool RunAsyncParts(); + + // Used for executing a sub-operation, e.g. a merge_op instance, with + // an OpKernelContext based on the one passed into this Op. + class SubContext { + public: + OpKernelContext::Params sub_params_; + gtl::InlinedVector sub_inputs_; + gtl::InlinedVector sub_input_attr_; + gtl::InlinedVector sub_input_dc_; + // Used only for Binary and Unary Ops for which we require + // the calculation to be in-place on the first input. + int forward_from_ = 0; + OpKernelContext* sub_ctx_; + SubContext(OpKernelContext* ctx, OpKernelContext::Params* params, + OpKernel* op, Tensor* output, Tensor* input); + ~SubContext() { delete sub_ctx_; } + }; + + // Current status of a RingField + enum RingFieldAction { + RF_INIT = 0, // Just initialized for a pass + RF_RECV, // Recv pending + RF_REDUCE, // Reduce pending + RF_FINALIZE, // FinalOp pending + RF_SEND_READY, // Ready to send + RF_SEND, // Send pending + RF_DONE, // No more work + }; + + // Tracks progress of actions on a single subfield of the entire tensor. 
+ struct RingField { + int16 chunk_idx; // major division index + int16 subdiv_idx; // minor division index + int16 sc_idx; // subchunk index + int16 rank; // rank within subdiv permutation + int16 recv_dev_idx; // dev from which value should be recv'd + RingFieldAction action; + bool second_pass; + bool recv_is_remote = false; + bool send_is_remote = false; + bool do_send = false; // is the value sent in this pass? + bool do_recv = false; // is the value recv'd in this pass? + bool is_final = false; // is the last field in the pass for this rank + Tensor chunk; // alias to field values + Tensor tmp_chunk; + Status status; + string DebugString() const; + }; + void AdvanceToSecondPass(RingField* rf); + void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx, + int field_idx); + void DispatchSend(RingField* rf, const StatusCallback& done); + void DispatchRecv(RingField* rf, const StatusCallback& done); + + // For constructing log messages for debugging. + string FieldState(); + string TensorDebugString(Tensor tensor); + + // Producer/Consumer Queue of RingField structs. + class PCQueue { + public: + void Enqueue(RingField* rf); + RingField* Dequeue(); + + private: + mutex pcq_mu_; + condition_variable cv_; + int waiter_count_ GUARDED_BY(pcq_mu_) = 0; + std::deque deque_ GUARDED_BY(pcq_mu_); + }; + + CollectiveExecutor* col_exec_; // Not owned + const DeviceMgr* dev_mgr_; // Not owned + OpKernelContext* ctx_; // Not owned + OpKernelContext::Params* op_params_; // Not owned + const CollectiveParams& col_params_; + const string exec_key_; + const Tensor* input_; // Not owned + Tensor* output_; // Not owned + const int rank_; + const int64 step_id_; + const int group_size_; + const int num_subdivs_; + Tensor group_size_tensor_; + Notification group_size_tensor_ready_; + std::unique_ptr ca_; + StatusCallback done_; + Device* device_; // The device for which this instance labors + const string device_name_; + DeviceLocality device_locality_; + + mutex status_mu_; + Status status_ GUARDED_BY(status_mu_); + + std::vector rfv_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_ diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc new file mode 100644 index 0000000..e4387a0 --- /dev/null +++ b/tensorflow/core/common_runtime/ring_reducer_test.cc @@ -0,0 +1,606 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/common_runtime/ring_reducer.h" + +#include +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/device_resolver_local.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/threadpool_device.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { +namespace { + +// Wraps CollectiveRemoteAccessLocal with the ability to return an +// error status to the N'th action. +class FailTestRMA : public CollectiveRemoteAccessLocal { + public: + FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver, + int64 step_id, int fail_after) + : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id), + fail_after_(fail_after) {} + + bool MaybeFail(const StatusCallback& done) { + bool fail_now = false; + { + mutex_lock l(mu_); + if (fail_after_ > 0) { + fail_now = (--fail_after_ == 0); + } + } + if (fail_now) { + done(errors::Internal("Deliberate failure")); + return true; + } + return false; + } + + void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + const StatusCallback& done) override { + if (MaybeFail(done)) return; + CollectiveRemoteAccessLocal::RecvFromPeer( + peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, done); + } + + void PostToPeer(const string& peer_device, const string& peer_task, + const string& key, Device* from_device, + DeviceContext* from_device_ctx, + const AllocatorAttributes& from_alloc_attr, + const Tensor* from_tensor, + const DeviceLocality& client_locality, + const StatusCallback& done) override { + if (MaybeFail(done)) return; + CollectiveRemoteAccessLocal::PostToPeer( + peer_device, peer_task, key, from_device, from_device_ctx, + from_alloc_attr, from_tensor, client_locality, done); + } + + mutex mu_; + int fail_after_ GUARDED_BY(mu_); +}; + +std::unique_ptr GetKernel(const NodeDef& node, + const DeviceType& device_type, + DeviceBase* device) { + Status status; + std::unique_ptr k = CreateOpKernel( + device_type, device, device->GetAllocator(AllocatorAttributes()), node, + TF_GRAPH_DEF_VERSION, &status); + if (!status.ok()) { + LOG(FATAL) << status; + } + return k; +} + +std::unique_ptr GetAdd(DataType dtype, const DeviceType& device_type, + DeviceBase* device) { + NodeDef node_def; + NodeDefBuilder builder("add_node", "Add"); + 
TF_CHECK_OK(builder.Attr("T", dtype) + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Finalize(&node_def)); + return GetKernel(node_def, device_type, device); +} + +std::unique_ptr GetDiv(DataType dtype, const DeviceType& device_type, + DeviceBase* device) { + NodeDef node_def; + NodeDefBuilder builder("add_node", "Div"); + TF_CHECK_OK(builder.Attr("T", dtype) + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Finalize(&node_def)); + return GetKernel(node_def, device_type, device); +} + +static int64 kStepId = 123; + +class RingReducerTest : public ::testing::Test { + protected: + RingReducerTest() : device_type_(DEVICE_CPU) {} + + void SetUp() override { +#if GOOGLE_CUDA + auto device_factory = DeviceFactory::GetFactory("GPU"); + CHECK(device_factory); + SessionOptions options; + Status s = device_factory->CreateDevices( + options, "/job:worker/replica:0/task:0", &gpu_devices_); + CHECK(s.ok()); +#endif + } + + ~RingReducerTest() override { + stop_ = true; + for (auto i : instances_) { + delete i; + } + if (col_exec_) col_exec_->Unref(); + } + + void Init(int num_workers, int num_devices, DataType dtype, + const DeviceType& device_type, int num_subdivs, int fail_after) { + device_type_ = device_type; + std::vector local_devices; + SessionOptions sess_opts; + sess_opts.env = Env::Default(); + Bytes mem_limit(4 << 20); + DeviceLocality dev_locality; + for (int wi = 0; wi < num_workers; ++wi) { + for (int di = 0; di < num_devices; ++di) { + if (device_type == DEVICE_CPU) { + string dev_name = + strings::StrCat("/job:worker/replica:0/task:", wi, "/cpu:", di); + local_devices.push_back(new ThreadPoolDevice( + sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator())); + } else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) { + int dev_idx = (wi * num_devices) + di; + if (dev_idx >= static_cast(gpu_devices_.size())) { + LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more " + "than one ring node."; + } else { + local_devices.push_back(gpu_devices_[dev_idx]); + } + } else { + LOG(FATAL) << "Unsupported device_type " << device_type; + } + } + } + if (!dev_mgr_ || device_type == DEVICE_CPU) { + LOG(ERROR) << "resetting dev_mgr for " << local_devices.size() + << " devices: "; + dev_mgr_.reset(new DeviceMgr(local_devices)); + } + dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get())); + rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId, + fail_after); + col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, + dev_mgr_.get()); + col_params_.name = "test_collective"; + static const int kGroupKey = 5; + col_params_.group.group_key = kGroupKey; + col_params_.group.device_type = device_type; + col_params_.group.group_size = num_workers * num_devices; + static const int kInstanceKey = 17; + col_params_.instance.instance_key = kInstanceKey; + col_params_.instance.impl_details.subdiv_offsets.clear(); + col_params_.instance.type = REDUCTION_COLLECTIVE; + col_params_.instance.data_type = dtype; + col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); + col_params_.subdiv_rank.resize(num_subdivs); + int subdiv_stride = num_devices / num_subdivs; + for (int sdi = 0; sdi < num_subdivs; ++sdi) { + col_params_.instance.impl_details.subdiv_offsets.push_back(sdi * + subdiv_stride); + col_params_.subdiv_rank[sdi] = sdi * subdiv_stride; + } + + // Set up a local device ring order that's not just 0,1,2... 
+ std::vector local_ring_order; + for (int di = 0; di < num_devices; ++di) { + local_ring_order.push_back(di); + } + for (int di = 0; di < num_devices; ++di) { + bool is_odd = ((di % 2) == 1); + int other = (di + (is_odd ? 7 : 3)) % num_devices; + if (di == other) continue; + iter_swap(local_ring_order.begin() + di, + local_ring_order.begin() + other); + } + string lro_buf; + for (auto d : local_ring_order) strings::StrAppend(&lro_buf, d, ", "); + VLOG(1) << "local_ring_order " << lro_buf; + + // Set up all of the fake device contexts. + for (int wi = 0; wi < num_workers; ++wi) { + for (int di = 0; di < num_devices; ++di) { + string task_name = strings::StrCat("/job:worker/replica:0/task:", wi); + string dev_name = strings::StrCat(task_name, "/cpu:", di); + if (device_type == DEVICE_GPU) { + dev_name = + strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size()); + } + col_params_.instance.device_names.push_back(dev_name); + col_params_.instance.task_names.push_back(task_name); + // Normally each device would set is_local to its own perspective but + // this test runs in a single process so is_local is always true. + col_params_.task.is_local.push_back(true); + for (int sdi = 0; sdi < num_subdivs; ++sdi) { + int rotated_di = + (di + col_params_.instance.impl_details.subdiv_offsets[sdi]) % + num_devices; + col_params_.instance.impl_details.subdiv_permutations[sdi].push_back( + wi * num_devices + local_ring_order[rotated_di]); + } + } + } + for (int wi = 0; wi < num_workers; ++wi) { + for (int di = 0; di < num_devices; ++di) { + int rank = wi * num_devices + di; + instances_.push_back(new DeviceInstance( + rank, col_params_.instance.device_names[rank], device_type_, this)); + } + } + } + + void Reduce() { + std::atomic done(0); + for (auto di : instances_) { + SchedClosure([di, &done] { + di->DoReduce(); + ++done; + }); + } + while (done < static_cast(instances_.size())) { + if (stop_) break; + Env::Default()->SleepForMicroseconds(1000); + } + } + + template + void RunTest(DataType dtype, const DeviceType& device_type, int num_workers, + int num_devices, int num_subdivs, int tensor_len, + int fail_after) { + Init(num_workers, num_devices, dtype, device_type, num_subdivs, fail_after); + std::vector expected(tensor_len, 0.0); + for (int di = 0; di < static_cast(instances_.size()); ++di) { + DeviceInstance* instance = instances_[di]; + instance->InitTensor( + dtype, TensorShape({tensor_len}), [&expected, dtype, di](Tensor* t) { + for (size_t i = 0; i < t->NumElements(); ++i) { + // The cast is necessary to prevent clang-tidy from insisting + // that a faster non-open source function be substituted. + float value = pow(10, static_cast(di)) * i; + if (dtype == DT_INT32 || dtype == DT_INT64) { + value = di * 10 + i; + } + t->flat()(i) = static_cast(value); + expected[i] += value; + } + }); + } + Reduce(); + if (fail_after > 0) { + // Confirm that every device terminated with the expected error status. + for (int di = 0; di < static_cast(instances_.size()); ++di) { + EXPECT_EQ("Deliberate failure", + instances_[di]->status_.error_message()); + } + } else { + // Confirm that every device computed the same correct reduction value. 
+ for (int i = 0; i < tensor_len; ++i) { + expected[i] /= (num_workers * num_devices); + } + for (int di = 0; di < static_cast(instances_.size()); ++di) { + TF_EXPECT_OK(instances_[di]->status_); + Tensor* inst = &instances_[di]->tensor_; + CHECK(inst); + Tensor actual(dtype, TensorShape({tensor_len})); + if (device_type_ == DEVICE_CPU) { + CHECK(actual.CopyFrom(*inst, inst->shape())); + VLOG(1) << "actual " << actual.SummarizeValue(100); + } else if (device_type_ == DEVICE_GPU) { + Notification note; + Device* dev = instances_[di]->device_; + auto* dev_info = dev->tensorflow_gpu_device_info(); + CHECK(dev_info); + dev_info->default_context->CopyDeviceTensorToCPU( + inst, "" /*tensor_name*/, dev, &actual, [¬e](const Status& s) { + CHECK(s.ok()); + note.Notify(); + }); + note.WaitForNotification(); + } + + for (int i = 0; i < tensor_len; ++i) { + switch (dtype) { + case DT_FLOAT: + EXPECT_FLOAT_EQ(expected[i], actual.template flat()(i)) + << "Mismatch at device " << di << " index " << i; + break; + case DT_DOUBLE: + EXPECT_DOUBLE_EQ(expected[i], actual.template flat()(i)) + << "Mismatch at device " << di << " index " << i; + break; + case DT_INT32: + case DT_INT64: + EXPECT_EQ(expected[i], actual.template flat()(i)) + << "Mismatch at device " << di << " index " << i; + break; + default: + LOG(FATAL) << "unimplemented"; + } + } + } + } + } + + std::unique_ptr GetCollectiveReduce(const CollectiveParams& params, + Tensor* input, + const DeviceType& device_type, + DeviceBase* device) { + mutex_lock l(mu_); + NodeDef node_def; + NodeDefBuilder builder( + strings::StrCat("collective_reduce_", reduce_counter_++), + "CollectiveReduce"); + TF_CHECK_OK( + builder.Attr("T", params.instance.data_type) + .Attr("merge_op", "Add") + .Attr("final_op", "Id") + .Attr("group_size", params.group.group_size) + .Attr("group_key", params.group.group_key) + .Attr("instance_key", params.instance.instance_key) + .Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets) + .Input(FakeInput(params.instance.data_type)) + .Finalize(&node_def)); + return GetKernel(node_def, device_type, device); + } + + class DeviceInstance { + public: + DeviceInstance(int rank, const string& dev_name, + const DeviceType& device_type, RingReducerTest* parent) + : parent_(parent), + dev_name_(dev_name), + device_type_(device_type), + rank_(rank) { + TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)) + << "Couldn't find device " << dev_name + << " existing devices: " << parent_->dev_mgr_->DebugString(); + col_params_.name = parent_->col_params_.name; + col_params_.group.group_key = parent_->col_params_.group.group_key; + col_params_.group.device_type = parent_->col_params_.group.device_type; + col_params_.group.group_size = parent_->col_params_.group.group_size; + col_params_.instance = parent->col_params_.instance; + col_params_.task.is_local = parent_->col_params_.task.is_local; + col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; + + int num_subdivs = static_cast(col_params_.subdiv_rank.size()); + int group_size = col_params_.group.group_size; + CHECK_EQ(group_size, + static_cast(col_params_.instance.device_names.size())); + // Id of this device is at rank position in first subdiv perm. + int my_device_id = + col_params_.instance.impl_details.subdiv_permutations[0][rank]; + col_params_.default_rank = my_device_id; + // Set rank for all other subdivs by finding that device_id. 
+      for (int sdi = 0; sdi < num_subdivs; ++sdi) {
+        for (int r = 0;
+             r < static_cast<int>(col_params_.instance.impl_details
+                                      .subdiv_permutations[sdi].size());
+             ++r) {
+          if (my_device_id ==
+              col_params_.instance.impl_details.subdiv_permutations[sdi][r]) {
+            col_params_.subdiv_rank[sdi] = r;
+            break;
+          }
+        }
+      }
+    }
+
+    void InitTensor(DataType dtype, const TensorShape& shape,
+                    const std::function<void(Tensor*)>& init_f) {
+      tensor_ =
+          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
+      if (device_type_ == DEVICE_CPU) {
+        init_f(&tensor_);
+      } else if (device_type_ == DEVICE_GPU) {
+        Tensor cpu_tensor(dtype, shape);
+        init_f(&cpu_tensor);
+        auto* dev_info = device_->tensorflow_gpu_device_info();
+        CHECK(dev_info);
+        Notification note;
+        dev_info->default_context->CopyCPUTensorToDevice(
+            &cpu_tensor, device_, &tensor_, [&note](const Status& s) {
+              CHECK(s.ok());
+              note.Notify();
+            });
+        note.WaitForNotification();
+      } else {
+        LOG(FATAL) << "Unsupported device_type " << device_type_;
+      }
+    }
+
+    void DoReduce() {
+      col_params_.merge_op =
+          GetAdd(col_params_.instance.data_type, device_type_, device_);
+      col_params_.final_op =
+          GetDiv(col_params_.instance.data_type, device_type_, device_);
+
+      // Prepare an OpKernelContext.
+      OpKernelContext::Params op_params;
+      op_params.step_id = kStepId;
+      op_params.device = device_;
+      gtl::InlinedVector<TensorValue, 4> inputs;
+      inputs.push_back(TensorValue(&tensor_));
+      op_params.inputs = &inputs;
+      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
+          {AllocatorAttributes()});
+      op_params.input_alloc_attrs = &input_aa;
+      gtl::InlinedVector<DeviceContext*, 4> input_dc;
+      DeviceContext* dev_ctx = nullptr;
+      auto* dev_info = device_->tensorflow_gpu_device_info();
+      if (dev_info) {
+        dev_ctx = dev_info->default_context;
+        dev_ctx->Ref();
+      } else {
+        dev_ctx = new DeviceContext;
+      }
+      input_dc.push_back(dev_ctx);
+      op_params.input_device_contexts = &input_dc;
+      op_params.op_device_context = dev_ctx;
+      int forward_from = 0;
+      op_params.forward_from_array = &forward_from;
+      AllocatorAttributes generic_alloc_attr;
+      op_params.output_attr_array = &generic_alloc_attr;
+      std::unique_ptr<OpKernel> op = parent_->GetCollectiveReduce(
+          col_params_, &tensor_, DEVICE_CPU, device_);
+      op_params.op_kernel = op.get();
+      OpKernelContext ctx(&op_params, 1);
+
+      // We never actually execute the kernel, so we need to do the output
+      // allocation that it would do, ourselves.
+      Tensor* output_tensor_ptr = nullptr;
+      TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, tensor_.shape(),
+                                                       &output_tensor_ptr));
+      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
+
+      // Prepare a RingReducer instance.
+      string exec_key =
+          strings::StrCat(col_params_.instance.instance_key, ":0:0");
+      RingReducer rr(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx,
+                     &op_params, col_params_, exec_key, kStepId, &tensor_,
+                     &tensor_);
+
+      // Start execution in a threadpool then wait for completion.
+      Notification notification;
+      SchedClosure([this, &notification, &rr]() {
+        rr.Run([this, &notification](Status s) {
+          status_ = s;
+          notification.Notify();
+        });
+      });
+      notification.WaitForNotification();
+      CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+
+      dev_ctx->Unref();
+    }
+
+    const Tensor& tensor() { return tensor_; }
+
+    RingReducerTest* parent_;
+    string dev_name_;
+    DeviceType device_type_;
+    int rank_;
+    Tensor tensor_;
+    Device* device_;
+    CollectiveParams col_params_;
+    std::unique_ptr<CollectiveAdapter> ca_;
+    std::unique_ptr<OpKernelContext> ctx_;
+    Status status_;
+  };
+
+  bool stop_ = false;
+  DeviceType device_type_;
+  TestCollectiveExecutorMgr col_exec_mgr_;
+  CollectiveExecutor* col_exec_;
+  CollectiveRemoteAccessLocal* rma_;
+  std::unique_ptr<DeviceResolverLocal> dev_resolver_;
+  std::vector<DeviceInstance*> instances_;
+  CollectiveParams col_params_;
+  std::vector<Device*> gpu_devices_;
+  std::unique_ptr<DeviceMgr> dev_mgr_;
+  mutex mu_;
+  int32 reduce_counter_ GUARDED_BY(mu_) = 0;
+};
+
+#define DEF_TEST(B, T, W, D, S, L, A)                                         \
+  TEST_F(RingReducerTest,                                                     \
+         DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Sdiv##S##_Len##L##_Abrt##A) { \
+    DataType dtype = DT_##B;                                                  \
+    switch (dtype) {                                                          \
+      case DT_FLOAT: {                                                        \
+        RunTest<float>(dtype, DEVICE_##T, W, D, S, L, A);                     \
+      } break;                                                                \
+      case DT_DOUBLE: {                                                       \
+        RunTest<double>(dtype, DEVICE_##T, W, D, S, L, A);                    \
+      } break;                                                                \
+      case DT_INT32: {                                                        \
+        RunTest<int32>(dtype, DEVICE_##T, W, D, S, L, A);                     \
+      } break;                                                                \
+      case DT_INT64: {                                                        \
+        RunTest<int64>(dtype, DEVICE_##T, W, D, S, L, A);                     \
+      } break;                                                                \
+      default:                                                                \
+        LOG(FATAL) << "Unimplemented";                                        \
+    }                                                                         \
+  }
+
+#ifndef GOOGLE_CUDA
+// Success tests
+DEF_TEST(FLOAT, CPU, 1, 2, 1, 1, 0)
+DEF_TEST(FLOAT, CPU, 1, 2, 1, 2, 0)
+DEF_TEST(FLOAT, CPU, 1, 2, 1, 8, 0)
+DEF_TEST(FLOAT, CPU, 1, 2, 1, 16, 0)
+DEF_TEST(FLOAT, CPU, 1, 2, 1, 1001, 0)
+DEF_TEST(FLOAT, CPU, 2, 4, 1, 128, 0)
+DEF_TEST(FLOAT, CPU, 2, 8, 1, 1001, 0)
+DEF_TEST(FLOAT, CPU, 2, 8, 1, 4096, 0)
+DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 0)
+DEF_TEST(FLOAT, CPU, 2, 8, 3, 4095, 0)
+DEF_TEST(FLOAT, CPU, 2, 8, 3, 1045991, 0)
+DEF_TEST(FLOAT, CPU, 4, 4, 4, 1045991, 0)
+DEF_TEST(DOUBLE, CPU, 1, 2, 1, 1001, 0)
+DEF_TEST(DOUBLE, CPU, 2, 8, 3, 4095, 0)
+DEF_TEST(INT32, CPU, 1, 2, 1, 1001, 0)
+DEF_TEST(INT32, CPU, 2, 8, 3, 4095, 0)
+DEF_TEST(INT64, CPU, 1, 2, 1, 1001, 0)
+DEF_TEST(INT64, CPU, 2, 8, 3, 4095, 0)
+
+// Failure tests
+DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 7)
+DEF_TEST(FLOAT, CPU, 2, 8, 2, 9408, 11)
+#endif
+
+#ifdef GOOGLE_CUDA
+// GPU tests. So long as the device names are all in a single task we
+// bypass inter-worker routing code and can fake multiple GPUs with a single
+// GPU, from the perspective of the RingReducer logic. So these tests
+// are all single-worker.
+DEF_TEST(FLOAT, GPU, 1, 2, 1, 1, 0)
+DEF_TEST(FLOAT, GPU, 1, 2, 1, 2, 0)
+DEF_TEST(FLOAT, GPU, 1, 2, 1, 8, 0)
+DEF_TEST(FLOAT, GPU, 1, 2, 1, 16, 0)
+DEF_TEST(FLOAT, GPU, 1, 2, 1, 1001, 0)
+DEF_TEST(FLOAT, GPU, 1, 8, 1, 1001, 0)
+DEF_TEST(FLOAT, GPU, 1, 8, 1, 4096, 0)
+DEF_TEST(FLOAT, GPU, 1, 8, 3, 4095, 0)
+DEF_TEST(FLOAT, GPU, 1, 8, 3, 1045991, 0)
+DEF_TEST(FLOAT, GPU, 1, 4, 4, 1045991, 0)
+DEF_TEST(DOUBLE, GPU, 1, 2, 1, 1001, 0)
+// INT32 values are never on the GPU.
+// DEF_TEST(INT32, GPU, 1, 2, 1, 1001, 0)
+DEF_TEST(INT64, GPU, 1, 2, 1, 1001, 0)
+
+// Failure tests
+DEF_TEST(FLOAT, GPU, 1, 8, 1, 9408, 2)
+DEF_TEST(FLOAT, GPU, 1, 8, 2, 9408, 5)
+#endif
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
new file mode 100644
index 0000000..d0d4f24
--- /dev/null
+++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
@@ -0,0 +1,116 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+
+// Mock objects that can't actually execute a Collective, but satisfy
+// general infrastructure expectations within tests that don't require
+// full functionality.
+
+class TestCollectiveExecutor : public CollectiveExecutor {
+ public:
+  explicit TestCollectiveExecutor(CollectiveExecutorMgrInterface* cem)
+      : CollectiveExecutor(cem) {}
+
+  void RecvFromPeer(const string& peer_device, const string& peer_task,
+                    bool peer_is_local, const string& key, Device* to_device,
+                    DeviceContext* to_device_ctx,
+                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+                    const DeviceLocality& client_locality,
+                    const StatusCallback& done) override {
+    done(errors::Internal("Unimplemented"));
+  }
+
+  void PostToPeer(const string& peer_device, const string& peer_task,
+                  const string& key, Device* from_device,
+                  DeviceContext* from_device_ctx,
+                  const AllocatorAttributes& from_alloc_attr,
+                  const Tensor* from_tensor,
+                  const DeviceLocality& client_locality,
+                  const StatusCallback& done) override {
+    done(errors::Internal("Unimplemented"));
+  }
+};
+
+class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
+ public:
+  TestCollectiveExecutorMgr() {}
+
+  ~TestCollectiveExecutorMgr() override {
+    for (auto& iter : table_) {
+      iter.second->Unref();
+    }
+  }
+
+  CollectiveExecutor* FindOrCreate(int64 step_id) override {
+    mutex_lock l(mu_);
+    CollectiveExecutor* ce = nullptr;
+    auto iter = table_.find(step_id);
+    if (iter != table_.end()) {
+      ce = iter->second;
+    } else {
+      ce = new TestCollectiveExecutor(this);
+      table_[step_id] = ce;
+    }
+    ce->Ref();
+    return ce;
+  }
+
+  void Cleanup(int64 step_id) override {
+    mutex_lock l(mu_);
+    auto iter = table_.find(step_id);
+    if (iter != table_.end()) {
+      iter->second->Unref();
+      table_.erase(iter);
+    }
+  }
+
+  ParamResolverInterface* GetParamResolver() const override {
+    LOG(FATAL);
+    return nullptr;
+  }
+
+  DeviceResolverInterface* GetDeviceResolver() const override {
+    LOG(FATAL);
+    return nullptr;
+  }
+
+  void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+                            GetStepSequenceResponse* response,
+                            const StatusCallback& done) override {
+    done(errors::Internal("unimplemented"));
+  }
+
+  void RefreshStepIdSequenceAsync(int64 graph_key,
+                                  const StatusCallback& done) override {
+    done(errors::Internal("unimplemented"));
+  }
+
+  int64 NextStepId(int64 graph_key) override {
+    return CollectiveExecutor::kInvalidId;
+  }
+
+  void RetireStepId(int64 graph_key, int64 step_id) override {}
+
+  mutex mu_;
+  gtl::FlatMap<int64, CollectiveExecutor*> table_ GUARDED_BY(mu_);
+};
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
-- 
2.7.4
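
Note on the DEF_TEST macro in ring_reducer_test.cc above: each invocation expands into a single TEST_F that switches on the requested DataType and calls the matching RunTest<T> instantiation; a non-zero final argument (fail_after) turns the case into a failure test that expects the "Deliberate failure" status checked in RunTest(). A rough, illustrative sketch of what DEF_TEST(FLOAT, CPU, 2, 8, 3, 4095, 0) expands to, assuming only the macro definition shown above:

// Illustrative expansion only; the test name is built by token-pasting the
// macro arguments.
TEST_F(RingReducerTest, DaTyFLOAT_DevTyCPU_Wkr2_Dev8_Sdiv3_Len4095_Abrt0) {
  DataType dtype = DT_FLOAT;
  switch (dtype) {
    case DT_FLOAT: {
      // 2 workers x 8 devices per worker, 3 ring subdivisions, tensors of
      // length 4095, fail_after == 0 (success test).
      RunTest<float>(dtype, DEVICE_CPU, 2, 8, 3, 4095, 0);
    } break;
    // The DT_DOUBLE / DT_INT32 / DT_INT64 branches are also generated but
    // are unreachable here since dtype is DT_FLOAT.
    default:
      LOG(FATAL) << "Unimplemented";
  }
}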