From 9d2c6ff2a542b9bd89b42e3b88e6299eae9bdcc4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 22 May 2018 13:49:08 -0700 Subject: [PATCH] Collective Ops Part 7 Complete just enough of the core implementation to run multi-device collectives locally within a single process. Interfaces are still private and not available for general use. PiperOrigin-RevId: 197617132 --- tensorflow/core/common_runtime/direct_session.cc | 27 +++++++- tensorflow/core/common_runtime/direct_session.h | 3 + tensorflow/core/common_runtime/executor.cc | 18 ++++- tensorflow/core/common_runtime/executor.h | 1 + tensorflow/core/common_runtime/function.cc | 3 + tensorflow/core/common_runtime/graph_runner.cc | 3 + tensorflow/core/framework/function.h | 2 + tensorflow/core/graph/graph.cc | 3 + tensorflow/core/graph/graph.h | 2 + tensorflow/core/kernels/function_ops.cc | 1 + tensorflow/core/protobuf/config.proto | 25 ++++++- tensorflow/core/protobuf/worker.proto | 8 +++ .../tensorflow.-config-proto.-experimental.pbtxt | 80 ++++++++++++++++++++++ .../api/golden/tensorflow.-config-proto.pbtxt | 8 +++ .../tensorflow.-run-options.-experimental.pbtxt | 80 ++++++++++++++++++++++ .../tools/api/golden/tensorflow.-run-options.pbtxt | 8 +++ 16 files changed, 268 insertions(+), 4 deletions(-) create mode 100644 tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 0afbd02..07c1eaf 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -19,15 +19,19 @@ limitations under the License. 
#include #include +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/collective_param_resolver_local.h" #include "tensorflow/core/common_runtime/constant_folding.h" #include "tensorflow/core/common_runtime/debugger_state_interface.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_resolver_local.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb_text.h" @@ -443,6 +447,18 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, // Create a run state and start execution. RunState run_state(step_id, &devices_); run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); + // Set up for collectives if the RunOption declares a key. + if (run_options.experimental().collective_graph_key() > 0) { + if (!collective_executor_mgr_) { + DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get()); + collective_executor_mgr_.reset(new CollectiveExecutorMgr( + options_.config, device_mgr_.get(), drl, + new CollectiveParamResolverLocal(device_mgr_.get(), drl, + "/job:localhost/replica:0/task:0"))); + } + run_state.collective_executor.reset(new CollectiveExecutor::Handle( + collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/)); + } // Start parallel Executors. 
const size_t num_executors = executors_and_keys->items.size(); @@ -459,6 +475,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, args.step_id = step_id; args.call_frame = call_frame; args.rendezvous = run_state.rendez; + args.collective_executor = + (run_state.collective_executor ? run_state.collective_executor->get() + : nullptr); CancellationManager step_cancellation_manager; args.cancellation_manager = &step_cancellation_manager; args.session_state = &session_state_; @@ -768,6 +787,10 @@ Status DirectSession::PRunSetup(const std::vector& input_names, args.rendezvous = run_state->rendez; args.cancellation_manager = cancellation_manager_; + // Note that Collectives are not supported in partial runs + // because RunOptions is not passed in so we can't know whether + // their use is intended. + args.collective_executor = nullptr; args.runner = [this, pool](Executor::Args::Closure c) { SchedClosure(pool, std::move(c)); }; @@ -1518,11 +1541,13 @@ DirectSession::RunState::RunState( const std::vector& pending_input_names, const std::vector& pending_output_names, int64 step_id, const std::vector* devices) - : step_container(step_id, [devices](const string& name) { + : step_container(step_id, [devices, step_id](const string& name) { for (auto d : *devices) { if (!d->resource_manager()->Cleanup(name).ok()) { // Do nothing... } + ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr(); + if (sam) sam->Cleanup(step_id); } }) { // Initially all the feeds and fetches are pending. diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 6f9c1b9..72a2be4 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/session_factory.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/session_state.h" #include "tensorflow/core/framework/tensor.h" @@ -175,6 +176,7 @@ class DirectSession : public Session { mutex mu_; Status status GUARDED_BY(mu_); IntraProcessRendezvous* rendez = nullptr; + std::unique_ptr collective_executor; std::unique_ptr collector; Notification executors_done; std::unordered_map pending_inputs; // true if fed @@ -352,6 +354,7 @@ class DirectSession : public Session { DirectSessionFactory* const factory_; // not owned CancellationManager* cancellation_manager_; + std::unique_ptr collective_executor_mgr_; // Map of placed stateful nodes, i.e. nodes for which is_stateful() // is true, such as "params" and "queue" nodes. Once placed these diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 802bfee..585d777 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/control_flow.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/graph.pb.h" @@ -592,7 +593,8 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { } } } - if (fwd_status.ok() && forward_from[i] == -1) { + if (fwd_status.ok() && + forward_from[i] == OpKernelContext::Params::kNoReservation) { DCHECK_EQ(forward_input.size() % 2, 0); for (int j = 0; j < forward_input.size(); j += 2) { if (forward_input[j + 1] == i) { @@ -770,7 +772,8 @@ void GraphView::SetScopedAllocatorAttrs( << use_node->name(); continue; } - // There should be exactly one output using ScopedAllocation. + // There can be more than one output using ScopedAllocation, but this + // analysis assumes they use the same ScopedAllocator. for (const auto& e : use_node->out_edges()) { if (!e->IsControlEdge()) { AllocatorAttributes attr; @@ -887,6 +890,11 @@ Status InferAllocAttr(const Node* n, const Node* dst, << " remote type " << parsed_dst_name.type; } } + if (n->IsCollective()) { + // We'll make the sweeping assumption that any collective op is going + // to be involved in network i/o. + attr->set_nic_compatible(true); + } return s; } @@ -1289,6 +1297,7 @@ class ExecutorState { int64 step_id_; // Not owned. Rendezvous* rendezvous_; + CollectiveExecutor* collective_executor_ = nullptr; SessionState* session_state_; TensorStore* tensor_store_; // Step-local container. 
@@ -1411,6 +1420,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) log_memory_(LogMemory::IsEnabled()), step_id_(args.step_id), rendezvous_(args.rendezvous), + collective_executor_(args.collective_executor), session_state_(args.session_state), tensor_store_(args.tensor_store), step_container_(args.step_container), @@ -1621,6 +1631,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { params.log_memory = log_memory_; params.record_tensor_accesses = impl_->device_record_tensor_accesses_; params.rendezvous = rendezvous_; + params.collective_executor = collective_executor_; params.session_state = session_state_; params.tensor_store = tensor_store_; params.cancellation_manager = cancellation_manager_; @@ -2180,6 +2191,9 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node, if (rendezvous_) { rendezvous_->StartAbort(s); } + if (collective_executor_) { + collective_executor_->StartAbort(s); + } if (cancellation_manager_) { cancellation_manager_->StartCancel(); } diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index adf80a2..e5d7b7c 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -89,6 +89,7 @@ class Executor { SessionState* session_state = nullptr; TensorStore* tensor_store = nullptr; ScopedStepContainer* step_container = nullptr; + CollectiveExecutor* collective_executor = nullptr; // If true, calls Sync() on the device. bool sync_on_finish = false; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index d05564e..5d9be70 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" @@ -809,6 +810,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, exec_args->cancellation_manager = run_opts.cancellation_manager; exec_args->step_container = run_opts.step_container; exec_args->runner = *run_opts.runner; + exec_args->collective_executor = run_opts.collective_executor; Item* item = nullptr; Status s = GetOrCreateItem(handle, &item); @@ -896,6 +898,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, exec_args->rendezvous = run_opts.rendezvous; exec_args->stats_collector = run_opts.stats_collector; exec_args->cancellation_manager = run_opts.cancellation_manager; + exec_args->collective_executor = run_opts.collective_executor; exec_args->step_container = run_opts.step_container; exec_args->runner = *run_opts.runner; exec_args->call_frame = frame; diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index adf2ef6..0a1797f 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -176,6 +176,9 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID; args.runner = runner; args.rendezvous = rendez; + // NOTE: Use of graph runner is limited to single-device executions + // so a CollectiveExecutor should never be required. + args.collective_executor = nullptr; // Run the graph. 
TF_RETURN_IF_ERROR(executor->Run(args)); diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index e00399f..8729067 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -33,6 +33,7 @@ limitations under the License. namespace tensorflow { class CancellationManager; +class CollectiveExecutor; class GraphDef; class OpKernel; class ProcessFunctionLibraryRuntime; @@ -484,6 +485,7 @@ class FunctionLibraryRuntime { int64 step_id = 0; Rendezvous* rendezvous = nullptr; CancellationManager* cancellation_manager = nullptr; + CollectiveExecutor* collective_executor = nullptr; ScopedStepContainer* step_container = nullptr; StepStatsCollector* stats_collector = nullptr; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 71d0637..0f74851 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -80,6 +80,9 @@ const std::unordered_map& Node::kNodeClassTable = {"Shape", NC_METADATA}, {"Rank", NC_METADATA}, {"_ScopedAllocator", NC_SCOPED_ALLOCATOR}, + {"CollectiveReduce", NC_COLLECTIVE}, + {"CollectiveBcastSend", NC_COLLECTIVE}, + {"CollectiveBcastRecv", NC_COLLECTIVE}, }); #undef REF_CLASS diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 83a69e6..33fb7cb 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -163,6 +163,7 @@ class Node { bool IsHostSend() const { return class_ == NC_HOST_SEND; } bool IsHostRecv() const { return class_ == NC_HOST_RECV; } bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; } + bool IsCollective() const { return class_ == NC_COLLECTIVE; } bool IsMetadata() const { return class_ == NC_METADATA; } @@ -235,6 +236,7 @@ class Node { NC_DELETE_SESSION_TENSOR, NC_METADATA, NC_SCOPED_ALLOCATOR, + NC_COLLECTIVE, NC_OTHER // Not a special kind of node }; diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc 
index 8f66f0a..f272473 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -254,6 +254,7 @@ class SymbolicGradientOp : public AsyncOpKernel { opts.runner = ctx->runner(); opts.stats_collector = ctx->stats_collector(); opts.step_container = ctx->step_container(); + opts.collective_executor = ctx->collective_executor(); std::vector args; args.reserve(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); ++i) { diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 6cd067a..410ad22 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -379,7 +379,17 @@ message ConfigProto { // shared with other sessions. bool isolate_session_state = 15; - // Next: 16 + // Everything inside Experimental is subject to change and is not subject + // to API stability guarantees in + // https://www.tensorflow.org/programmers_guide/version_compat. + message Experimental { + // Task name for group resolution. + string collective_group_leader = 1; + }; + + Experimental experimental = 16; + + // Next: 17 }; // Options for a single Run() call. @@ -414,6 +424,19 @@ message RunOptions { // Enabling this option can slow down the Run() call. bool report_tensor_allocations_upon_oom = 7; + // Everything inside Experimental is subject to change and is not subject + // to API stability guarantees in + // https://www.tensorflow.org/programmers_guide/version_compat. + message Experimental { + // If non-zero, declares that this graph is going to use collective + // ops and must synchronize step_ids with any other graph with this + // same group_key value (in a distributed computation where tasks + // run disjoint graphs). 
+ int64 collective_graph_key = 1; + }; + + Experimental experimental = 8; + reserved 4; } diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 1cb84ca..b400638 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -122,6 +122,14 @@ message RegisterGraphRequest { // Field(s) used by TensorFlow Debugger (tfdbg). DebugOptions debug_options = 5; + + // If graph_def contains any collective ops this must be a positive + // integer used to coordinate execution with other graphs. All + // graphs in a distributed execution with the same + // collective_graph_key will coordinate to use the same step_id + // concurrently so that BufRendezvous entries will make the correct + // values accessible. + int64 collective_graph_key = 7; } message RegisterGraphResponse { diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt new file mode 100644 index 0000000..0a0669e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.ConfigProto.Experimental" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "COLLECTIVE_GROUP_LEADER_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: 
"MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt index 009d64a..0d53d1c 100644 --- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt @@ -27,6 +27,14 @@ tf_class { mtype: "" } member { + name: "EXPERIMENTAL_FIELD_NUMBER" + mtype: "" + } + member { + name: "Experimental" + mtype: "" + } + member { name: "Extensions" mtype: "" } diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt new file mode 100644 index 0000000..6a5e46a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.RunOptions.Experimental" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "COLLECTIVE_GRAPH_KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + 
member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt index 2f3e7f1..65e5588 100644 --- a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt @@ -11,6 +11,14 @@ tf_class { mtype: "" } member { + name: "EXPERIMENTAL_FIELD_NUMBER" + mtype: "" + } + member { + name: "Experimental" + mtype: "" + } + member { name: "Extensions" mtype: "" } -- 2.7.4