support profiling multiple tpu through one grpc and one session.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 11 Apr 2018 22:59:47 +0000 (15:59 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 11 Apr 2018 23:06:18 +0000 (16:06 -0700)
data are saved with host prefix.

PiperOrigin-RevId: 192523668

tensorflow/contrib/tpu/profiler/BUILD
tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
tensorflow/contrib/tpu/profiler/tpu_profiler.proto
tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto

index 1c32993..dbf1ab6 100644 (file)
@@ -46,6 +46,7 @@ tf_cc_binary(
     visibility = ["//visibility:public"],
     deps = [
         ":dump_tpu_profile",
+        ":tpu_profiler_analysis_proto_cc",
         ":tpu_profiler_proto_cc",
         ":version",
         "//tensorflow/core:framework_internal",
index 6b198db..a535884 100644 (file)
@@ -26,6 +26,7 @@ limitations under the License.
 
 #include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h"
 #include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
+#include "tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.grpc.pb.h"
 #include "tensorflow/contrib/tpu/profiler/version.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -40,6 +41,7 @@ namespace tensorflow {
 namespace tpu {
 namespace {
 
+using ::tensorflow::grpc::TPUProfileAnalysis;
 using ::tensorflow::TPUProfiler;
 
 constexpr uint64 kMaxEvents = 1000000;
@@ -64,11 +66,10 @@ Status ValidateHostPortPair(const string& host_port) {
   return Status::OK();
 }
 
-// Returns whether the returned trace is empty.
-// Failure are handled by CHECK, i.e. abort()
-bool Profile(const string& service_addr, const string& logdir, int duration_ms,
-             const string& repository_root, const string& session_id,
-             const ProfileOptions& opts) {
+ProfileRequest PopulateProfileRequest(int duration_ms,
+                                      const string& repository_root,
+                                      const string& session_id,
+                                      const ProfileOptions& opts) {
   ProfileRequest request;
   request.set_duration_ms(duration_ms);
   request.set_max_events(kMaxEvents);
@@ -83,6 +84,17 @@ bool Profile(const string& service_addr, const string& logdir, int duration_ms,
   *request.mutable_opts() = opts;
   std::cout << "Limiting the number of trace events to " << kMaxEvents
             << std::endl;
+  return request;
+}
+
+// Returns whether the returned trace is empty.
+// Failure are handled by CHECK, i.e. abort()
+bool Profile(const string& service_addr, const string& logdir, int duration_ms,
+             const string& repository_root, const string& session_id,
+             const ProfileOptions& opts) {
+  ProfileRequest request =
+      PopulateProfileRequest(duration_ms, repository_root, session_id, opts);
+
   ::grpc::ClientContext context;
   ::grpc::ChannelArguments channel_args;
   // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available.
@@ -120,7 +132,36 @@ bool NewSession(const string& service_addr,
                 const std::vector<tensorflow::string>& hostnames,
                 int duration_ms, const string& repository_root,
                 const string& session_id, const ProfileOptions& opts) {
-  return true;
+  NewProfileSessionRequest new_session_request;
+  *new_session_request.mutable_request() =
+      PopulateProfileRequest(duration_ms, repository_root, session_id, opts);
+  new_session_request.set_repository_root(repository_root);
+  new_session_request.set_session_id(session_id);
+  std::copy(
+      hostnames.begin(), hostnames.end(),
+      proto2::RepeatedFieldBackInserter(new_session_request.mutable_hosts()));
+
+  ::grpc::ClientContext context;
+  ::grpc::ChannelArguments channel_args;
+  // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
+  // `ValidateHostPortPair` checks for empty host string case.
+  channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
+  // TODO(jiesun): GRPC support following relevant naming scheme:
+  // 1. dns:///host:port
+  // 2. ipv4:host:port or ipv6:[host]:port
+  // We might need to change the prefix which depends on what TPU name resolver
+  // will give us.
+  std::unique_ptr<TPUProfileAnalysis::Stub> stub =
+      TPUProfileAnalysis::NewStub(::grpc::CreateCustomChannel(
+          "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
+          channel_args));
+  NewProfileSessionResponse new_session_response;
+  TF_QCHECK_OK(FromGrpcStatus(
+      stub->NewSession(&context, new_session_request, &new_session_response)));
+
+  std::cout << "Profile session succeed for hosts:"
+            << str_util::Join(hostnames, ",");
+  return new_session_response.empty_trace();
 }
 
 }  // namespace
index ae50858..b53f9be 100644 (file)
@@ -64,7 +64,8 @@ Status WriteGzippedDataToFile(const string& filename, const string& data) {
 
 Status DumpTraceToLogDirectory(StringPiece run_dir, const string& host_prefix,
                                const string& encoded_trace, std::ostream* os) {
-  string proto_path = JoinPath(run_dir, kProtoTraceFileName);
+  string proto_path =
+      JoinPath(run_dir, StrCat(host_prefix, kProtoTraceFileName));
   TF_RETURN_IF_ERROR(
       WriteStringToFile(Env::Default(), proto_path, encoded_trace));
   LOG(INFO) << "Dumped raw-proto trace data to " << proto_path;
index 8505c4b..7be694e 100644 (file)
@@ -96,5 +96,10 @@ message ProfileResponse {
 
   // Data payload for each required tools.
   repeated ProfileToolData tool_data = 6;
-  // next-field: 7
+
+  // When we write profiling data directly to repository directory, we need a
+  // way to figure out whether the captured trace is empty (due to idle TPU).
+  bool empty_trace = 7;
+
+  // next-field: 8
 }
index a4fc8d4..8b0bbde 100644 (file)
@@ -7,13 +7,15 @@ message NewProfileSessionRequest {
   ProfileRequest request = 1;
   string repository_root = 2;
   repeated string hosts = 3;
+  string session_id = 4;
 }
 
 message NewProfileSessionResponse {
   // Auxiliary error_message.
   string error_message = 1;
-  // If success, return session identifier for future reference.
-  string session_id = 2;
+
+  // Whether all hosts had returned a empty trace.
+  bool empty_trace = 2;
 }
 
 message EnumProfileSessionsAndToolsRequest {