Add metadata for gathering information about host compute transfers while compiling...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 7 Mar 2018 00:46:54 +0000 (16:46 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Mar 2018 00:53:12 +0000 (16:53 -0800)
PiperOrigin-RevId: 188102740

tensorflow/compiler/tf2xla/BUILD
tensorflow/compiler/tf2xla/host_compute_metadata.proto [new file with mode: 0644]
tensorflow/compiler/tf2xla/xla_compiler.cc
tensorflow/compiler/tf2xla/xla_compiler.h

index fb82c26..eb20ca5 100644 (file)
@@ -58,6 +58,15 @@ xla_proto_library(
     ],
 )
 
+xla_proto_library(
+    name = "host_compute_metadata_proto",
+    srcs = ["host_compute_metadata.proto"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
 cc_library(
     name = "tf2xla",
     srcs = ["tf2xla.cc"],
@@ -149,6 +158,7 @@ cc_library(
         ":common",
         ":dump_graph",
         ":functionalize_control_flow",
+        ":host_compute_metadata_proto",
         ":sharding_util",
         ":tf2xla_util",
         "//tensorflow/compiler/tf2xla/lib:util",
diff --git a/tensorflow/compiler/tf2xla/host_compute_metadata.proto b/tensorflow/compiler/tf2xla/host_compute_metadata.proto
new file mode 100644 (file)
index 0000000..43ab371
--- /dev/null
@@ -0,0 +1,38 @@
+syntax = "proto3";
+
+package tensorflow.tf2xla;
+option cc_enable_arenas = true;
+option java_outer_classname = "Tf2XlaProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.tf2xla";
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+// TensorMetadata indicates the type and shape of a Tensor that is
+// part of a host compute transfer.
+message TensorMetadata {
+  DataType type = 1;
+  TensorShapeProto shape = 2;
+}
+
+// HostTransferMetadata describes a transfer either from host to device
+// or device to host. It has a key that is unique to the computation,
+// and metadata about the list of tensors being transferred.
+message HostTransferMetadata {
+  // The key used to identify this transfer.
+  string key = 1;
+
+  // For each Tensor being transferred, its type and shape.
+  repeated TensorMetadata metadata = 2;
+}
+
+// HostComputeMetadata describes all the sends and recvs
+// from all host compute transfer ops in a computation.
+message HostComputeMetadata {
+  // Metadata about each device_to_host transfer
+  repeated HostTransferMetadata device_to_host = 1;
+
+  // Metadata about each host_to_device transfer
+  repeated HostTransferMetadata host_to_device = 2;
+}
index 5ec05c4..0dc5118 100644 (file)
@@ -674,6 +674,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   VLOG(2) << "XLA output shape: "
           << xla::ShapeUtil::HumanString(result->xla_output_shape);
 
+  // Copy the host transfer metadata to the result.
+  for (const auto& send : host_compute_sends_) {
+    *result->host_compute_metadata.add_device_to_host() = send.second;
+  }
+  for (const auto& recv : host_compute_recvs_) {
+    *result->host_compute_metadata.add_host_to_device() = recv.second;
+  }
+
   // Tensorflow expects a major-to-minor order of results.
   xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
 
@@ -708,4 +716,59 @@ Status XlaCompiler::GetChannelHandle(const string& key,
   return Status::OK();
 }
 
+namespace {
+
+void SetTransfer(const string& key, const std::vector<DataType>& types,
+                 const std::vector<TensorShape>& shapes,
+                 tf2xla::HostTransferMetadata* transfer) {
+  transfer->set_key(key);
+  CHECK(types.size() == shapes.size());
+  for (int i = 0; i < types.size(); ++i) {
+    tf2xla::TensorMetadata* metadata = transfer->add_metadata();
+    metadata->set_type(types[i]);
+    shapes[i].AsProto(metadata->mutable_shape());
+  }
+}
+
+}  // namespace
+
+Status XlaCompiler::SetDeviceToHostMetadata(
+    const string& key, const std::vector<DataType>& types,
+    const std::vector<TensorShape>& shapes) {
+  if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
+    return errors::InvalidArgument(
+        "Duplicate calls to SetDeviceToHostMetadata with key ", key);
+  }
+  tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
+  SetTransfer(key, types, shapes, &transfer);
+  return Status::OK();
+}
+
+Status XlaCompiler::GetDeviceToHostShapes(
+    const string& key, std::vector<TensorShape>* shapes) const {
+  const auto iter = host_compute_sends_.find(key);
+  if (iter == host_compute_sends_.end()) {
+    return errors::InvalidArgument(
+        "No host compute send shapes registered for key ", key);
+  }
+  shapes->clear();
+  for (int i = 0; i < iter->second.metadata_size(); ++i) {
+    TensorShape shape(iter->second.metadata(i).shape());
+    shapes->push_back(shape);
+  }
+  return Status::OK();
+}
+
+Status XlaCompiler::SetHostToDeviceMetadata(
+    const string& key, const std::vector<DataType>& types,
+    const std::vector<TensorShape>& shapes) {
+  if (host_compute_recvs_.find(key) != host_compute_sends_.end()) {
+    return errors::InvalidArgument(
+        "Duplicate calls to SetHostToDeviceMetadata with key ", key);
+  }
+  tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
+  SetTransfer(key, types, shapes, &transfer);
+  return Status::OK();
+}
+
 }  // namespace tensorflow
index c4449bc..a70d263 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device.h"
@@ -216,6 +217,10 @@ class XlaCompiler {
     // containing both constant and non-constant results.
     std::vector<OutputDescription> outputs;
 
+    // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
+    // matching RecvAtHost/SendFromHost Ops in the outer graph.
+    tf2xla::HostComputeMetadata host_compute_metadata;
+
     // Resources whose values were updated by the computation, ordered
     // by return value position. Resource updates follow the non-constant
     // results in the outputs of XLA computation.
@@ -296,6 +301,22 @@ class XlaCompiler {
   // same XlaCompiler.
   Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
 
+  // Sets the shapes and types for the device to host transfer associated with
+  // 'key'.
+  Status SetDeviceToHostMetadata(const string& key,
+                                 const std::vector<DataType>& types,
+                                 const std::vector<TensorShape>& shapes);
+
+  // Gets the shapes the device to host transfer associated with 'key'.
+  Status GetDeviceToHostShapes(const string& key,
+                               std::vector<TensorShape>* shapes) const;
+
+  // Sets the shapes and types for the host to device transfer associated with
+  // 'key'.
+  Status SetHostToDeviceMetadata(const string& key,
+                                 const std::vector<DataType>& types,
+                                 const std::vector<TensorShape>& shapes);
+
   const Options& options() const { return options_; }
   xla::Client* client() const { return options_.client; }
   FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
@@ -359,6 +380,9 @@ class XlaCompiler {
 
   std::unordered_map<string, xla::ChannelHandle> channels_;
 
+  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
+  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
 };