Collective Ops Part 7
author A. Unique TensorFlower <gardener@tensorflow.org>
Tue, 22 May 2018 20:49:08 +0000 (13:49 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 20:51:22 +0000 (13:51 -0700)
Complete just enough of the core implementation to run
multi-device collectives locally within a single process.
Interfaces are still private and not available for general use.

PiperOrigin-RevId: 197617132

16 files changed:
tensorflow/core/common_runtime/direct_session.cc
tensorflow/core/common_runtime/direct_session.h
tensorflow/core/common_runtime/executor.cc
tensorflow/core/common_runtime/executor.h
tensorflow/core/common_runtime/function.cc
tensorflow/core/common_runtime/graph_runner.cc
tensorflow/core/framework/function.h
tensorflow/core/graph/graph.cc
tensorflow/core/graph/graph.h
tensorflow/core/kernels/function_ops.cc
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/worker.proto
tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt [new file with mode: 0644]
tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt [new file with mode: 0644]
tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt

index 0afbd02..07c1eaf 100644 (file)
@@ -19,15 +19,19 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
 #include "tensorflow/core/common_runtime/constant_folding.h"
 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
 #include "tensorflow/core/common_runtime/executor.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/graph_optimizer.h"
 #include "tensorflow/core/common_runtime/memory_types.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
 #include "tensorflow/core/common_runtime/step_stats_collector.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb_text.h"
@@ -443,6 +447,18 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
   // Create a run state and start execution.
   RunState run_state(step_id, &devices_);
   run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
+  // Set up for collectives if the RunOption declares a key.
+  if (run_options.experimental().collective_graph_key() > 0) {
+    if (!collective_executor_mgr_) {
+      DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+      collective_executor_mgr_.reset(new CollectiveExecutorMgr(
+          options_.config, device_mgr_.get(), drl,
+          new CollectiveParamResolverLocal(device_mgr_.get(), drl,
+                                           "/job:localhost/replica:0/task:0")));
+    }
+    run_state.collective_executor.reset(new CollectiveExecutor::Handle(
+        collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
+  }
 
   // Start parallel Executors.
   const size_t num_executors = executors_and_keys->items.size();
@@ -459,6 +475,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
   args.step_id = step_id;
   args.call_frame = call_frame;
   args.rendezvous = run_state.rendez;
+  args.collective_executor =
+      (run_state.collective_executor ? run_state.collective_executor->get()
+                                     : nullptr);
   CancellationManager step_cancellation_manager;
   args.cancellation_manager = &step_cancellation_manager;
   args.session_state = &session_state_;
@@ -768,6 +787,10 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
 
   args.rendezvous = run_state->rendez;
   args.cancellation_manager = cancellation_manager_;
+  // Note that Collectives are not supported in partial runs
+  // because RunOptions is not passed in so we can't know whether
+  // their use is intended.
+  args.collective_executor = nullptr;
   args.runner = [this, pool](Executor::Args::Closure c) {
     SchedClosure(pool, std::move(c));
   };
@@ -1518,11 +1541,13 @@ DirectSession::RunState::RunState(
     const std::vector<string>& pending_input_names,
     const std::vector<string>& pending_output_names, int64 step_id,
     const std::vector<Device*>* devices)
-    : step_container(step_id, [devices](const string& name) {
+    : step_container(step_id, [devices, step_id](const string& name) {
         for (auto d : *devices) {
           if (!d->resource_manager()->Cleanup(name).ok()) {
             // Do nothing...
           }
+          ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
+          if (sam) sam->Cleanup(step_id);
         }
       }) {
   // Initially all the feeds and fetches are pending.
index 6f9c1b9..72a2be4 100644 (file)
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/common_runtime/session_factory.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/session_state.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -175,6 +176,7 @@ class DirectSession : public Session {
     mutex mu_;
     Status status GUARDED_BY(mu_);
     IntraProcessRendezvous* rendez = nullptr;
+    std::unique_ptr<CollectiveExecutor::Handle> collective_executor;
     std::unique_ptr<StepStatsCollector> collector;
     Notification executors_done;
     std::unordered_map<string, bool> pending_inputs;   // true if fed
@@ -352,6 +354,7 @@ class DirectSession : public Session {
 
   DirectSessionFactory* const factory_;  // not owned
   CancellationManager* cancellation_manager_;
+  std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
 
   // Map of placed stateful nodes, i.e. nodes for which is_stateful()
   // is true, such as "params" and "queue" nodes.  Once placed these
index 802bfee..585d777 100644 (file)
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/control_flow.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/framework/graph.pb.h"
@@ -592,7 +593,8 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
           }
         }
       }
-      if (fwd_status.ok() && forward_from[i] == -1) {
+      if (fwd_status.ok() &&
+          forward_from[i] == OpKernelContext::Params::kNoReservation) {
         DCHECK_EQ(forward_input.size() % 2, 0);
         for (int j = 0; j < forward_input.size(); j += 2) {
           if (forward_input[j + 1] == i) {
@@ -770,7 +772,8 @@ void GraphView::SetScopedAllocatorAttrs(
                 << use_node->name();
         continue;
       }
-      // There should be exactly one output using ScopedAllocation.
+      // There can be more than one output using ScopedAllocation, but this
+      // analysis assumes they use the same ScopedAllocator.
       for (const auto& e : use_node->out_edges()) {
         if (!e->IsControlEdge()) {
           AllocatorAttributes attr;
@@ -887,6 +890,11 @@ Status InferAllocAttr(const Node* n, const Node* dst,
               << " remote type " << parsed_dst_name.type;
     }
   }
+  if (n->IsCollective()) {
+    // We'll make the sweeping assumption that any collective op is going
+    // to be involved in network i/o.
+    attr->set_nic_compatible(true);
+  }
   return s;
 }
 
@@ -1289,6 +1297,7 @@ class ExecutorState {
   int64 step_id_;
   // Not owned.
   Rendezvous* rendezvous_;
+  CollectiveExecutor* collective_executor_ = nullptr;
   SessionState* session_state_;
   TensorStore* tensor_store_;
   // Step-local container.
@@ -1411,6 +1420,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
       log_memory_(LogMemory::IsEnabled()),
       step_id_(args.step_id),
       rendezvous_(args.rendezvous),
+      collective_executor_(args.collective_executor),
       session_state_(args.session_state),
       tensor_store_(args.tensor_store),
       step_container_(args.step_container),
@@ -1621,6 +1631,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
   params.log_memory = log_memory_;
   params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
   params.rendezvous = rendezvous_;
+  params.collective_executor = collective_executor_;
   params.session_state = session_state_;
   params.tensor_store = tensor_store_;
   params.cancellation_manager = cancellation_manager_;
@@ -2180,6 +2191,9 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
     if (rendezvous_) {
       rendezvous_->StartAbort(s);
     }
+    if (collective_executor_) {
+      collective_executor_->StartAbort(s);
+    }
     if (cancellation_manager_) {
       cancellation_manager_->StartCancel();
     }
index adf80a2..e5d7b7c 100644 (file)
@@ -89,6 +89,7 @@ class Executor {
     SessionState* session_state = nullptr;
     TensorStore* tensor_store = nullptr;
     ScopedStepContainer* step_container = nullptr;
+    CollectiveExecutor* collective_executor = nullptr;
 
     // If true, calls Sync() on the device.
     bool sync_on_finish = false;
index d05564e..5d9be70 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/graph_optimizer.h"
 #include "tensorflow/core/common_runtime/memory_types.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -809,6 +810,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
   exec_args->cancellation_manager = run_opts.cancellation_manager;
   exec_args->step_container = run_opts.step_container;
   exec_args->runner = *run_opts.runner;
+  exec_args->collective_executor = run_opts.collective_executor;
 
   Item* item = nullptr;
   Status s = GetOrCreateItem(handle, &item);
@@ -896,6 +898,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
   exec_args->rendezvous = run_opts.rendezvous;
   exec_args->stats_collector = run_opts.stats_collector;
   exec_args->cancellation_manager = run_opts.cancellation_manager;
+  exec_args->collective_executor = run_opts.collective_executor;
   exec_args->step_container = run_opts.step_container;
   exec_args->runner = *run_opts.runner;
   exec_args->call_frame = frame;
index adf2ef6..0a1797f 100644 (file)
@@ -176,6 +176,9 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
   args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID;
   args.runner = runner;
   args.rendezvous = rendez;
+  // NOTE: Use of graph runner is limited to single-device executions
+  // so a CollectiveExecutor should never be required.
+  args.collective_executor = nullptr;
 
   // Run the graph.
   TF_RETURN_IF_ERROR(executor->Run(args));
index e00399f..8729067 100644 (file)
@@ -33,6 +33,7 @@ limitations under the License.
 namespace tensorflow {
 
 class CancellationManager;
+class CollectiveExecutor;
 class GraphDef;
 class OpKernel;
 class ProcessFunctionLibraryRuntime;
@@ -484,6 +485,7 @@ class FunctionLibraryRuntime {
     int64 step_id = 0;
     Rendezvous* rendezvous = nullptr;
     CancellationManager* cancellation_manager = nullptr;
+    CollectiveExecutor* collective_executor = nullptr;
     ScopedStepContainer* step_container = nullptr;
     StepStatsCollector* stats_collector = nullptr;
 
index 71d0637..0f74851 100644 (file)
@@ -80,6 +80,9 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
         {"Shape", NC_METADATA},
         {"Rank", NC_METADATA},
         {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
+        {"CollectiveReduce", NC_COLLECTIVE},
+        {"CollectiveBcastSend", NC_COLLECTIVE},
+        {"CollectiveBcastRecv", NC_COLLECTIVE},
     });
 
 #undef REF_CLASS
index 83a69e6..33fb7cb 100644 (file)
@@ -163,6 +163,7 @@ class Node {
   bool IsHostSend() const { return class_ == NC_HOST_SEND; }
   bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
   bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
+  bool IsCollective() const { return class_ == NC_COLLECTIVE; }
 
   bool IsMetadata() const { return class_ == NC_METADATA; }
 
@@ -235,6 +236,7 @@ class Node {
     NC_DELETE_SESSION_TENSOR,
     NC_METADATA,
     NC_SCOPED_ALLOCATOR,
+    NC_COLLECTIVE,
     NC_OTHER  // Not a special kind of node
   };
 
index 8f66f0a..f272473 100644 (file)
@@ -254,6 +254,7 @@ class SymbolicGradientOp : public AsyncOpKernel {
     opts.runner = ctx->runner();
     opts.stats_collector = ctx->stats_collector();
     opts.step_container = ctx->step_container();
+    opts.collective_executor = ctx->collective_executor();
     std::vector<Tensor> args;
     args.reserve(ctx->num_inputs());
     for (int i = 0; i < ctx->num_inputs(); ++i) {
index 6cd067a..410ad22 100644 (file)
@@ -379,7 +379,17 @@ message ConfigProto {
   // shared with other sessions.
   bool isolate_session_state = 15;
 
-  // Next: 16
+  // Everything inside Experimental is subject to change and is not subject
+  // to API stability guarantees in
+  // https://www.tensorflow.org/programmers_guide/version_compat.
+  message Experimental {
+    // Task name for group resolution.
+    string collective_group_leader = 1;
+  };
+
+  Experimental experimental = 16;
+
+  // Next: 17
 };
 
 // Options for a single Run() call.
@@ -414,6 +424,19 @@ message RunOptions {
   // Enabling this option can slow down the Run() call.
   bool report_tensor_allocations_upon_oom = 7;
 
+  // Everything inside Experimental is subject to change and is not subject
+  // to API stability guarantees in
+  // https://www.tensorflow.org/programmers_guide/version_compat.
+  message Experimental {
+    // If non-zero, declares that this graph is going to use collective
+    // ops and must synchronize step_ids with any other graph with this
+    // same group_key value (in a distributed computation where tasks
+    // run disjoint graphs).
+    int64 collective_graph_key = 1;
+  };
+
+  Experimental experimental = 8;
+
   reserved 4;
 }
 
index 1cb84ca..b400638 100644 (file)
@@ -122,6 +122,14 @@ message RegisterGraphRequest {
 
   // Field(s) used by TensorFlow Debugger (tfdbg).
   DebugOptions debug_options = 5;
+
+  // If graph_def contains any collective ops this must be a positive
+  // integer used to coordinate execution with other graphs.  All
+  // graphs in a distributed execution with the same
+  // collective_graph_key will coordinate to use the same step_id
+  // concurrently so that BufRendezvous entries will make the correct
+  // values accessible.
+  int64 collective_graph_key = 7;
 }
 
 message RegisterGraphResponse {
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
new file mode 100644 (file)
index 0000000..0a0669e
--- /dev/null
@@ -0,0 +1,80 @@
+path: "tensorflow.ConfigProto.Experimental"
+tf_class {
+  is_instance: "<class \'tensorflow.core.protobuf.config_pb2.Experimental\'>"
+  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
+  member {
+    name: "COLLECTIVE_GROUP_LEADER_FIELD_NUMBER"
+    mtype: "<type \'int\'>"
+  }
+  member {
+    name: "DESCRIPTOR"
+    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
+  }
+  member {
+    name: "Extensions"
+    mtype: "<type \'getset_descriptor\'>"
+  }
+  member_method {
+    name: "ByteSize"
+  }
+  member_method {
+    name: "Clear"
+  }
+  member_method {
+    name: "ClearExtension"
+  }
+  member_method {
+    name: "ClearField"
+  }
+  member_method {
+    name: "CopyFrom"
+  }
+  member_method {
+    name: "DiscardUnknownFields"
+  }
+  member_method {
+    name: "FindInitializationErrors"
+  }
+  member_method {
+    name: "FromString"
+  }
+  member_method {
+    name: "HasExtension"
+  }
+  member_method {
+    name: "HasField"
+  }
+  member_method {
+    name: "IsInitialized"
+  }
+  member_method {
+    name: "ListFields"
+  }
+  member_method {
+    name: "MergeFrom"
+  }
+  member_method {
+    name: "MergeFromString"
+  }
+  member_method {
+    name: "ParseFromString"
+  }
+  member_method {
+    name: "RegisterExtension"
+  }
+  member_method {
+    name: "SerializePartialToString"
+  }
+  member_method {
+    name: "SerializeToString"
+  }
+  member_method {
+    name: "SetInParent"
+  }
+  member_method {
+    name: "WhichOneof"
+  }
+  member_method {
+    name: "__init__"
+  }
+}
index 009d64a..0d53d1c 100644 (file)
@@ -27,6 +27,14 @@ tf_class {
     mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
   }
   member {
+    name: "EXPERIMENTAL_FIELD_NUMBER"
+    mtype: "<type \'int\'>"
+  }
+  member {
+    name: "Experimental"
+    mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
+  }
+  member {
     name: "Extensions"
     mtype: "<type \'getset_descriptor\'>"
   }
diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt
new file mode 100644 (file)
index 0000000..6a5e46a
--- /dev/null
@@ -0,0 +1,80 @@
+path: "tensorflow.RunOptions.Experimental"
+tf_class {
+  is_instance: "<class \'tensorflow.core.protobuf.config_pb2.Experimental\'>"
+  is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
+  member {
+    name: "COLLECTIVE_GRAPH_KEY_FIELD_NUMBER"
+    mtype: "<type \'int\'>"
+  }
+  member {
+    name: "DESCRIPTOR"
+    mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
+  }
+  member {
+    name: "Extensions"
+    mtype: "<type \'getset_descriptor\'>"
+  }
+  member_method {
+    name: "ByteSize"
+  }
+  member_method {
+    name: "Clear"
+  }
+  member_method {
+    name: "ClearExtension"
+  }
+  member_method {
+    name: "ClearField"
+  }
+  member_method {
+    name: "CopyFrom"
+  }
+  member_method {
+    name: "DiscardUnknownFields"
+  }
+  member_method {
+    name: "FindInitializationErrors"
+  }
+  member_method {
+    name: "FromString"
+  }
+  member_method {
+    name: "HasExtension"
+  }
+  member_method {
+    name: "HasField"
+  }
+  member_method {
+    name: "IsInitialized"
+  }
+  member_method {
+    name: "ListFields"
+  }
+  member_method {
+    name: "MergeFrom"
+  }
+  member_method {
+    name: "MergeFromString"
+  }
+  member_method {
+    name: "ParseFromString"
+  }
+  member_method {
+    name: "RegisterExtension"
+  }
+  member_method {
+    name: "SerializePartialToString"
+  }
+  member_method {
+    name: "SerializeToString"
+  }
+  member_method {
+    name: "SetInParent"
+  }
+  member_method {
+    name: "WhichOneof"
+  }
+  member_method {
+    name: "__init__"
+  }
+}
index 2f3e7f1..65e5588 100644 (file)
@@ -11,6 +11,14 @@ tf_class {
     mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
   }
   member {
+    name: "EXPERIMENTAL_FIELD_NUMBER"
+    mtype: "<type \'int\'>"
+  }
+  member {
+    name: "Experimental"
+    mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
+  }
+  member {
     name: "Extensions"
     mtype: "<type \'getset_descriptor\'>"
   }