"boosted_trees_ops",
"candidate_sampling_ops",
"checkpoint_ops",
+ "collective_ops",
"control_flow_ops",
"ctc_ops",
"data_flow_ops",
":boosted_trees_ops_op_lib",
":candidate_sampling_ops_op_lib",
":checkpoint_ops_op_lib",
+ ":collective_ops_op_lib",
":control_flow_ops_op_lib",
":ctc_ops_op_lib",
":cudnn_rnn_ops_op_lib",
"//tensorflow/core/kernels:boosted_trees_ops",
"//tensorflow/core/kernels:candidate_sampler_ops",
"//tensorflow/core/kernels:checkpoint_ops",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:ctc_ops",
"//tensorflow/core/kernels:cudnn_rnn_kernels",
CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/allocator_retry.h",
"common_runtime/bfc_allocator.h",
+ "common_runtime/buf_rendezvous.h",
+ "common_runtime/build_graph_options.h",
"common_runtime/collective_executor_mgr.h",
"common_runtime/collective_param_resolver_local.h",
"common_runtime/collective_rma_local.h",
- "common_runtime/device_resolver_local.h",
- "common_runtime/buf_rendezvous.h",
- "common_runtime/build_graph_options.h",
"common_runtime/constant_folding.h",
"common_runtime/copy_tensor.h",
"common_runtime/costmodel_manager.h",
"common_runtime/debugger_state_interface.h",
"common_runtime/device_factory.h",
+ "common_runtime/device_resolver_local.h",
"common_runtime/device_set.h",
"common_runtime/dma_helper.h",
"common_runtime/eigen_thread_pool.h",
"common_runtime/mkl_cpu_allocator.h",
"common_runtime/optimization_registry.h",
"common_runtime/pending_counts.h",
+ "common_runtime/placer.h",
"common_runtime/process_util.h",
"common_runtime/profile_handler.h",
"common_runtime/renamed_device.h",
"common_runtime/scoped_allocator.h",
"common_runtime/scoped_allocator_mgr.h",
"common_runtime/session_factory.h",
- "common_runtime/placer.h",
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
--- /dev/null
+++ b/tensorflow/core/kernels/collective_ops.cc
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace {
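+// Base class for the collective op kernels.  Holds the CollectiveParams
+// for the op and, on the first invocation, completes them asynchronously
+// before the collective itself runs.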
+class CollectiveOpKernel : public AsyncOpKernel {
+ public:
+ explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
+
+ // A string encoding instance, frame and iter to be handed off to
+ // the implementation for use in generating RecvBuf keys.
+ string GetCollectiveKey(OpKernelContext* c) {
+ return strings::StrCat(col_params_.instance.instance_key, ":",
+ c->frame_iter().frame_id, ":",
+ c->frame_iter().iter_id);
+ }
+
+  // Returns false if the calling invocation of ComputeAsync should return
+  // immediately.
+ bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
+ const DoneCallback& done) {
+ if (col_params_.group.group_size >
+ col_params_.instance.device_names.size()) {
+      // This is the first invocation: finish initializing col_params_.
+      // Do the work in a blockable thread because CompleteParamsAsync
+      // may block.
+ c->env()->SchedClosure([this, c, done, col_exec]() {
+ col_exec->CompleteParamsAsync(c->device()->name(), &col_params_,
+ c->cancellation_manager(),
+ [this, c, done](const Status& s) {
+ if (s.ok()) {
+ ComputeAsync(c, done);
+ } else {
+ c->SetStatus(s);
+ done();
+ }
+ });
+ });
+ return false;
+ }
+ return true;
+ }
+
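+  // Parameters of the collective instance.  Subclass constructors fill in
+  // the statically known fields; CompleteParamsAsync resolves the rest
+  // before the first collective runs.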
+ CollectiveParams col_params_;
+};
+
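+// CollectiveReduce: reduces the input tensor across the group.  Values
+// are combined elementwise with merge_op (Add or Mul) and the result is
+// optionally post-processed with final_op (Id or Div).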
+class CollectiveReduceOpKernel : public CollectiveOpKernel {
+ public:
+ explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
+ : CollectiveOpKernel(c) {
+ col_params_.instance.type = REDUCTION_COLLECTIVE;
+ OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+ OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("subdiv_offsets",
+ &col_params_.instance.impl_details.subdiv_offsets));
+ string merge_op_name;
+ OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
+ OP_REQUIRES(c, merge_op_name == "Add" || merge_op_name == "Mul",
+ errors::InvalidArgument(
+ "merge_op must be one of {\"Add\", \"Mul\"} but got ",
+ merge_op_name));
+ string final_op_name;
+ OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
+ OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
+ errors::InvalidArgument(
+ "final_op must be one of {\"Id\", \"Div\"} but got ",
+ final_op_name));
+ OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+
+ const NodeDef& real_node = c->def();
+ col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
+ merge_op_name, ",", final_op_name, ")");
+ col_params_.group.device_type = c->device_type();
+
+ // Find the OpKernels by name, type and device type.
+ NodeDef sub_node;
+    // The merge_op takes two inputs; reuse the op's first input for both.
+ sub_node.add_input(real_node.input(0));
+ sub_node.add_input(real_node.input(0));
+ sub_node.set_device(real_node.device());
+ SetAttrValue(col_params_.instance.data_type,
+ &(*sub_node.mutable_attr())["T"]);
+ col_params_.merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
+ col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
+ }
+
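+  // Builds the OpKernel named by merge_op or final_op from sub_node.
+  // Returns nullptr for "Id" or an empty name, which need no kernel.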
+ std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
+ const string& name,
+ NodeDef* sub_node) {
+ std::unique_ptr<OpKernel> k;
+ if (name.empty() || name == "Id") return k;
+ sub_node->set_name(name);
+ sub_node->set_op(name);
+ Status status;
+ k = CreateOpKernel(c->device_type(), c->device(),
+ c->device()->GetAllocator(AllocatorAttributes()),
+ *sub_node, c->graph_def_version(), &status);
+ if (!status.ok()) {
+ c->CtxFailureWithWarning(errors::Internal("Failed to build OpKernel for ",
+ name, " : ",
+ status.error_message()));
+ }
+ return k;
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ CollectiveExecutor* col_exec = c->collective_executor();
+ OP_REQUIRES_ASYNC(
+ c, col_exec,
+ errors::Internal(
+ "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+ col_params_.name),
+ done);
+ if (!CanProceedWithCompute(c, col_exec, done)) return;
+ // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, c->input(0).shape(), &output),
+ done);
+
+ auto actual_done = [c, col_exec, done](const Status& s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+ col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
+ CollectiveReduceOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
+ CollectiveReduceOpKernel);
+
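+// CollectiveBcastSend: the source of a group broadcast.  Forwards (or
+// copies) its input to its output; the declared shape must match the
+// input's shape.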
+class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
+ public:
+ explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
+ : CollectiveOpKernel(c) {
+ col_params_.instance.type = BROADCAST_COLLECTIVE;
+ OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+ OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+ OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ col_params_.is_source = true;
+ col_params_.instance.impl_details.subdiv_offsets = {0};
+
+ col_params_.name =
+ strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+ col_params_.group.device_type = c->device_type();
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ CollectiveExecutor* col_exec = c->collective_executor();
+ OP_REQUIRES_ASYNC(
+ c, col_exec,
+ errors::Internal(
+ "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+ col_params_.name),
+ done);
+ if (!CanProceedWithCompute(c, col_exec, done)) return;
+ OP_REQUIRES_ASYNC(
+ c, shape_.IsSameSize(c->input(0).shape()),
+ errors::Internal("Declared shape of op ", col_params_.name,
+ " does not match shape of input"),
+ done);
+    // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ c, c->forward_input_or_allocate_output({0}, 0, shape_, &output), done);
+
+ auto actual_done = [c, col_exec, done](const Status& s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+ col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+ }
+
+ private:
+ TensorShape shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
+ CollectiveBcastSendOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
+ CollectiveBcastSendOpKernel);
+
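+// CollectiveBcastRecv: a non-source participant in a group broadcast.
+// Takes no input; allocates an output of the declared shape to receive
+// the broadcast value.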
+class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
+ public:
+ explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
+ : CollectiveOpKernel(c) {
+ col_params_.instance.type = BROADCAST_COLLECTIVE;
+ OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+ OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+ OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ col_params_.is_source = false;
+ col_params_.instance.impl_details.subdiv_offsets = {0};
+
+ col_params_.name =
+ strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+ col_params_.group.device_type = c->device_type();
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ CollectiveExecutor* col_exec = c->collective_executor();
+ OP_REQUIRES_ASYNC(
+ c, col_exec,
+ errors::Internal(
+ "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+ col_params_.name),
+ done);
+ if (!CanProceedWithCompute(c, col_exec, done)) return;
+ // No input, so must allocate output.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+
+ auto actual_done = [c, col_exec, done](const Status& s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+ col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+ }
+
+ private:
+ TensorShape shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
+ CollectiveBcastRecvOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
+ CollectiveBcastRecvOpKernel);
+
+} // namespace
+} // namespace tensorflow