Collective Ops Part 2
author A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Apr 2018 17:56:29 +0000 (10:56 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 17:59:33 +0000 (10:59 -0700)
Kernel/Op defs for reduction and broadcast.

Note that kernels just set up CollectiveParams and don't
define detailed algorithms.

This change is part of a series of changes introducing infrastructure
for collective ops and initial implementations of reduction and broadcast.

PiperOrigin-RevId: 192151715

tensorflow/core/BUILD
tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt [new file with mode: 0644]
tensorflow/core/framework/collective.h
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/collective_ops.cc [new file with mode: 0644]
tensorflow/core/ops/collective_ops.cc [new file with mode: 0644]

index 6f2391c..5a0535f 100644 (file)
@@ -687,6 +687,7 @@ tf_gen_op_libs(
         "boosted_trees_ops",
         "candidate_sampling_ops",
         "checkpoint_ops",
+        "collective_ops",
         "control_flow_ops",
         "ctc_ops",
         "data_flow_ops",
@@ -803,6 +804,7 @@ cc_library(
         ":boosted_trees_ops_op_lib",
         ":candidate_sampling_ops_op_lib",
         ":checkpoint_ops_op_lib",
+        ":collective_ops_op_lib",
         ":control_flow_ops_op_lib",
         ":ctc_ops_op_lib",
         ":cudnn_rnn_ops_op_lib",
@@ -948,6 +950,7 @@ cc_library(
         "//tensorflow/core/kernels:boosted_trees_ops",
         "//tensorflow/core/kernels:candidate_sampler_ops",
         "//tensorflow/core/kernels:checkpoint_ops",
+        "//tensorflow/core/kernels:collective_ops",
         "//tensorflow/core/kernels:control_flow_ops",
         "//tensorflow/core/kernels:ctc_ops",
         "//tensorflow/core/kernels:cudnn_rnn_kernels",
@@ -2249,17 +2252,17 @@ tf_cuda_library(
 CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/allocator_retry.h",
     "common_runtime/bfc_allocator.h",
+    "common_runtime/buf_rendezvous.h",
+    "common_runtime/build_graph_options.h",
     "common_runtime/collective_executor_mgr.h",
     "common_runtime/collective_param_resolver_local.h",
     "common_runtime/collective_rma_local.h",
-    "common_runtime/device_resolver_local.h",
-    "common_runtime/buf_rendezvous.h",
-    "common_runtime/build_graph_options.h",
     "common_runtime/constant_folding.h",
     "common_runtime/copy_tensor.h",
     "common_runtime/costmodel_manager.h",
     "common_runtime/debugger_state_interface.h",
     "common_runtime/device_factory.h",
+    "common_runtime/device_resolver_local.h",
     "common_runtime/device_set.h",
     "common_runtime/dma_helper.h",
     "common_runtime/eigen_thread_pool.h",
@@ -2270,6 +2273,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/mkl_cpu_allocator.h",
     "common_runtime/optimization_registry.h",
     "common_runtime/pending_counts.h",
+    "common_runtime/placer.h",
     "common_runtime/process_util.h",
     "common_runtime/profile_handler.h",
     "common_runtime/renamed_device.h",
@@ -2278,7 +2282,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/scoped_allocator.h",
     "common_runtime/scoped_allocator_mgr.h",
     "common_runtime/session_factory.h",
-    "common_runtime/placer.h",
     "common_runtime/stats_publisher_interface.h",
     "common_runtime/step_stats_collector.h",
     "common_runtime/threadpool_device.h",
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt
new file mode 100644 (file)
index 0000000..88049bc
--- /dev/null
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "CollectiveBcastRecv"
+  visibility: SKIP
+  summary: "Receives a tensor value broadcast from another device."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt
new file mode 100644 (file)
index 0000000..7ff70f5
--- /dev/null
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "CollectiveBcastSend"
+  visibility: SKIP
+  summary: "Broadcasts a tensor value to one or more other devices."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt
new file mode 100644 (file)
index 0000000..10d9771
--- /dev/null
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "CollectiveReduce"
+  visibility: SKIP
+  summary: "Mutually reduces multiple tensors of identical type and shape."
+}
index 362d345..5810c7f 100644 (file)
@@ -103,11 +103,8 @@ struct CollectiveParams {
   // Rank of this device in each subdivision permutation.
   std::vector<int> subdiv_rank;
   std::vector<int> subdiv_source_rank;
-  const Tensor* in_tensor;             // kernel input
-  Tensor* out_tensor;                  // kernel output
   std::unique_ptr<OpKernel> merge_op;  // reduction only
   std::unique_ptr<OpKernel> final_op;  // reduction only
-  OpKernelContext* op_context;
   string ToString() const;
 };
 
index b931f79..1018e8d 100644 (file)
@@ -132,6 +132,17 @@ tf_kernel_library(
 )
 
 tf_kernel_library(
+    name = "collective_ops",
+    prefix = "collective_ops",
+    deps = [
+        "//tensorflow/core:collective_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
+tf_kernel_library(
     name = "concat_lib",
     srcs = [
         "concat_lib_cpu.cc",
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
new file mode 100644 (file)
index 0000000..5de41ba
--- /dev/null
@@ -0,0 +1,266 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace {
+class CollectiveOpKernel : public AsyncOpKernel {
+ public:
+  explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
+
+  // A string encoding instance, frame and iter to be handed off to
+  // the implementation for use in generating RecvBuf keys.
+  string GetCollectiveKey(OpKernelContext* c) {
+    return strings::StrCat(col_params_.instance.instance_key, ":",
+                           c->frame_iter().frame_id, ":",
+                           c->frame_iter().iter_id);
+  }
+
+  // Returns false if calling invocation of ComputeAsync should return
+  // immediately.
+  bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
+                             const DoneCallback& done) {
+    if (col_params_.group.group_size >
+        col_params_.instance.device_names.size()) {
+      // This is the first invocation: Finish initializing col_params_.
+      // Call in a blockable thread because it's not guaranteed that
+      // this call cannot block.
+      c->env()->SchedClosure([this, c, done, col_exec]() {
+        col_exec->CompleteParamsAsync(c->device()->name(), &col_params_,
+                                      c->cancellation_manager(),
+                                      [this, c, done](const Status& s) {
+                                        if (s.ok()) {
+                                          ComputeAsync(c, done);
+                                        } else {
+                                          c->SetStatus(s);
+                                          done();
+                                        }
+                                      });
+      });
+      return false;
+    }
+    return true;
+  }
+
+  CollectiveParams col_params_;
+};
+
+class CollectiveReduceOpKernel : public CollectiveOpKernel {
+ public:
+  explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
+      : CollectiveOpKernel(c) {
+    col_params_.instance.type = REDUCTION_COLLECTIVE;
+    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("subdiv_offsets",
+                      &col_params_.instance.impl_details.subdiv_offsets));
+    string merge_op_name;
+    OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
+    OP_REQUIRES(c, merge_op_name == "Add" || merge_op_name == "Mul",
+                errors::InvalidArgument(
+                    "merge_op must be one of {\"Add\", \"Mul\"} but got ",
+                    merge_op_name));
+    string final_op_name;
+    OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
+    OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
+                errors::InvalidArgument(
+                    "final_op must be one of {\"Id\", \"Div\"} but got ",
+                    final_op_name));
+    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+
+    const NodeDef& real_node = c->def();
+    col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
+                                       merge_op_name, ",", final_op_name, ")");
+    col_params_.group.device_type = c->device_type();
+
+    // Find the OpKernels by name, type and device type.
+    NodeDef sub_node;
+    // The merge_op takes two inputs
+    sub_node.add_input(real_node.input(0));
+    sub_node.add_input(real_node.input(0));
+    sub_node.set_device(real_node.device());
+    SetAttrValue(col_params_.instance.data_type,
+                 &(*sub_node.mutable_attr())["T"]);
+    col_params_.merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
+    col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
+  }
+
+  std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
+                                          const string& name,
+                                          NodeDef* sub_node) {
+    std::unique_ptr<OpKernel> k;
+    if (name.empty() || name == "Id") return k;
+    sub_node->set_name(name);
+    sub_node->set_op(name);
+    Status status;
+    k = CreateOpKernel(c->device_type(), c->device(),
+                       c->device()->GetAllocator(AllocatorAttributes()),
+                       *sub_node, c->graph_def_version(), &status);
+    if (!status.ok()) {
+      c->CtxFailureWithWarning(errors::Internal("Failed to build OpKernel for ",
+                                                name, " : ",
+                                                status.error_message()));
+    }
+    return k;
+  }
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            col_params_.name),
+        done);
+    if (!CanProceedWithCompute(c, col_exec, done)) return;
+    // Allocate the output tensor, trying to reuse the input.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(c,
+                         c->forward_input_or_allocate_output(
+                             {0}, 0, c->input(0).shape(), &output),
+                         done);
+
+    auto actual_done = [c, col_exec, done](const Status& s) {
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
+                        CollectiveReduceOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
+                        CollectiveReduceOpKernel);
+
+class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
+ public:
+  explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
+      : CollectiveOpKernel(c) {
+    col_params_.instance.type = BROADCAST_COLLECTIVE;
+    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+    OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+    col_params_.is_source = true;
+    col_params_.instance.impl_details.subdiv_offsets = {0};
+
+    col_params_.name =
+        strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+    col_params_.group.device_type = c->device_type();
+  }
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            col_params_.name),
+        done);
+    if (!CanProceedWithCompute(c, col_exec, done)) return;
+    OP_REQUIRES_ASYNC(
+        c, shape_.IsSameSize(c->input(0).shape()),
+        errors::Internal("Declared shape of op ", col_params_.name,
+                         " does not match shape of input"),
+        done);
+    // Allocate the output Tensor, trying to reuse the input.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(
+        c, c->forward_input_or_allocate_output({0}, 0, shape_, &output), done);
+
+    auto actual_done = [c, col_exec, done](const Status& s) {
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+  }
+
+ private:
+  TensorShape shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
+                        CollectiveBcastSendOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
+                        CollectiveBcastSendOpKernel);
+
+class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
+ public:
+  explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
+      : CollectiveOpKernel(c) {
+    col_params_.instance.type = BROADCAST_COLLECTIVE;
+    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+    OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+    col_params_.is_source = false;
+    col_params_.instance.impl_details.subdiv_offsets = {0};
+
+    col_params_.name =
+        strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+    col_params_.group.device_type = c->device_type();
+  }
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            col_params_.name),
+        done);
+    if (!CanProceedWithCompute(c, col_exec, done)) return;
+    // No input, so must allocate output.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+
+    auto actual_done = [c, col_exec, done](const Status& s) {
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+  }
+
+ private:
+  TensorShape shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
+                        CollectiveBcastRecvOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
+                        CollectiveBcastRecvOpKernel);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/collective_ops.cc b/tensorflow/core/ops/collective_ops.cc
new file mode 100644 (file)
index 0000000..d6157a6
--- /dev/null
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("CollectiveReduce")
+    .Input("input: T")
+    .Output("data: T")
+    .Attr("T: {float, float16, float64, int32, int64}")
+    .Attr("group_size: int")
+    .Attr("group_key: int")
+    .Attr("instance_key: int")
+    .Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
+    .Attr("final_op: {'Id', 'Div'}")
+    .Attr("subdiv_offsets: list(int)")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("CollectiveBcastSend")
+    .Input("input: T")
+    .Output("data: T")
+    .Attr("T: {float, float16, float64, int32, int64}")
+    .Attr("group_size: int")
+    .Attr("group_key: int")
+    .Attr("instance_key: int")
+    .Attr("shape: shape")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ExplicitShape);
+
+REGISTER_OP("CollectiveBcastRecv")
+    .Output("data: T")
+    .Attr("T: {float, float16, float64, int32, int64}")
+    .Attr("group_size: int")
+    .Attr("group_key: int")
+    .Attr("instance_key: int")
+    .Attr("shape: shape")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ExplicitShape);
+
+}  // namespace tensorflow