refactor and add proto field required by POD support.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 00:43:43 +0000 (17:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 00:46:08 +0000 (17:46 -0700)
PiperOrigin-RevId: 191826636

tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
tensorflow/contrib/tpu/profiler/tpu_profiler.proto

index f2003e0..6b198db 100644 (file)
@@ -64,9 +64,11 @@ Status ValidateHostPortPair(const string& host_port) {
   return Status::OK();
 }
 
-ProfileResponse Profile(const string& service_addr, int duration_ms,
-                        const string& repository_root, const string& session_id,
-                        const ProfileOptions& opts) {
+// Returns whether the returned trace is empty.
+// Failure are handled by CHECK, i.e. abort()
+bool Profile(const string& service_addr, const string& logdir, int duration_ms,
+             const string& repository_root, const string& session_id,
+             const ProfileOptions& opts) {
   ProfileRequest request;
   request.set_duration_ms(duration_ms);
   request.set_max_events(kMaxEvents);
@@ -94,7 +96,31 @@ ProfileResponse Profile(const string& service_addr, int duration_ms,
           channel_args));
   ProfileResponse response;
   TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response)));
-  return response;
+
+  if (!response.encoded_trace().empty()) {
+    TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile(
+        logdir, session_id, "", response, &std::cout));
+    // Print this at the end so that it's not buried in irrelevant LOG messages.
+    std::cout
+        << "NOTE: using the trace duration " << duration_ms << "ms."
+        << std::endl
+        << "Set an appropriate duration (with --duration_ms) if you "
+           "don't see a full step in your trace or the captured trace is too "
+           "large."
+        << std::endl;
+  }
+
+  return response.encoded_trace().empty();
+}
+
+// Start a new profiling session that include all the hosts included in
+// hostnames, for the time interval of duration_ms. Possibly save the profiling
+// result in the directory specified by repository_root and session_id.
+bool NewSession(const string& service_addr,
+                const std::vector<tensorflow::string>& hostnames,
+                int duration_ms, const string& repository_root,
+                const string& session_id, const ProfileOptions& opts) {
+  return true;
 }
 
 }  // namespace
@@ -104,12 +130,16 @@ ProfileResponse Profile(const string& service_addr, int duration_ms,
 int main(int argc, char** argv) {
   tensorflow::string FLAGS_service_addr;
   tensorflow::string FLAGS_logdir;
+  tensorflow::string FLAGS_workers_list;
   int FLAGS_duration_ms = 2000;
   int FLAGS_num_tracing_attempts = 3;
   bool FLAGS_include_dataset_ops = true;
   std::vector<tensorflow::Flag> flag_list = {
       tensorflow::Flag("service_addr", &FLAGS_service_addr,
                        "Address of TPU profiler service e.g. localhost:8466"),
+      tensorflow::Flag("workers_list", &FLAGS_workers_list,
+                       "The list of worker TPUs that we are about to profile "
+                       "in the current session."),
       tensorflow::Flag("logdir", &FLAGS_logdir,
                        "Path of TensorBoard log directory e.g. /tmp/tb_log, "
                        "gs://tb_bucket"),
@@ -153,18 +183,30 @@ int main(int argc, char** argv) {
   constexpr char kProfilePluginDirectory[] = "plugins/profile/";
   tensorflow::string repository_root =
       ::tensorflow::io::JoinPath(FLAGS_logdir, kProfilePluginDirectory);
+  std::vector<tensorflow::string> hostnames =
+      tensorflow::str_util::Split(FLAGS_workers_list, ",");
+
+  bool empty_trace = false;
   while (true) {
     std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. "
               << "Remaining attempt(s): " << remaining_attempts-- << std::endl;
-    response = tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms,
-                                        repository_root, session_id, opts);
-    if (remaining_attempts <= 0 || !response.encoded_trace().empty()) break;
+    if (hostnames.empty()) {
+      empty_trace = tensorflow::tpu::Profile(FLAGS_service_addr, FLAGS_logdir,
+                                             duration_ms, repository_root,
+                                             session_id, opts);
+    } else {
+      tensorflow::string tpu_master = FLAGS_service_addr;
+      empty_trace =
+          tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms,
+                                      repository_root, session_id, opts);
+    }
+    if (remaining_attempts <= 0 || !empty_trace) break;
     std::cout << "No trace event is collected. Automatically retrying."
               << std::endl
               << std::endl;
   }
 
-  if (response.encoded_trace().empty()) {
+  if (empty_trace) {
     std::cout << "No trace event is collected after "
               << FLAGS_num_tracing_attempts << " attempt(s). "
               << "Perhaps, you want to try again (with more attempts?)."
@@ -175,13 +217,5 @@ int main(int argc, char** argv) {
     return 0;
   }
 
-  TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile(
-      FLAGS_logdir, session_id, response, &std::cout));
-  // Print this at the end so that it's not buried in irrelevant LOG messages.
-  std::cout
-      << "NOTE: using the trace duration " << duration_ms << "ms." << std::endl
-      << "Set an appropriate duration (with --duration_ms) if you "
-         "don't see a full step in your trace or the captured trace is too "
-         "large."
-      << std::endl;
+  return 0;
 }
index ebd6185..ae50858 100644 (file)
@@ -41,6 +41,7 @@ namespace {
 using ::tensorflow::io::JoinPath;
 using ::tensorflow::protobuf::util::JsonOptions;
 using ::tensorflow::protobuf::util::MessageToJsonString;
+using ::tensorflow::strings::StrCat;
 
 constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph.";
 constexpr char kJsonOpProfileFileName[] = "op_profile.json";
@@ -61,28 +62,33 @@ Status WriteGzippedDataToFile(const string& filename, const string& data) {
   return Status::OK();
 }
 
-Status DumpTraceToLogDirectory(StringPiece run_dir, const string& encoded_trace,
-                               std::ostream* os) {
+Status DumpTraceToLogDirectory(StringPiece run_dir, const string& host_prefix,
+                               const string& encoded_trace, std::ostream* os) {
   string proto_path = JoinPath(run_dir, kProtoTraceFileName);
   TF_RETURN_IF_ERROR(
       WriteStringToFile(Env::Default(), proto_path, encoded_trace));
   LOG(INFO) << "Dumped raw-proto trace data to " << proto_path;
 
-  string json_path = JoinPath(run_dir, kJsonTraceFileName);
+  string json_path = JoinPath(run_dir, StrCat(host_prefix, kJsonTraceFileName));
   Trace trace;
   trace.ParseFromString(encoded_trace);
-  *os << "Trace contains " << trace.trace_events_size() << " events."
-      << std::endl;
+  if (os) {
+    *os << "Trace contains " << trace.trace_events_size() << " events."
+        << std::endl;
+  }
   TF_RETURN_IF_ERROR(
       WriteGzippedDataToFile(json_path, TraceEventsToJson(trace)));
-  *os << "Dumped JSON trace data to " << json_path << std::endl;
+  if (os) {
+    *os << "Dumped JSON trace data to " << json_path << std::endl;
+  }
   return Status::OK();
 }
 
 Status DumpOpProfileToLogDirectory(StringPiece run_dir,
+                                   const string& host_prefix,
                                    const tpu::op_profile::Profile& profile,
                                    std::ostream* os) {
-  string path = JoinPath(run_dir, kJsonOpProfileFileName);
+  string path = JoinPath(run_dir, StrCat(host_prefix, kJsonOpProfileFileName));
   string json;
   JsonOptions options;
   options.always_print_primitive_fields = true;
@@ -93,49 +99,20 @@ Status DumpOpProfileToLogDirectory(StringPiece run_dir,
         string(status.error_message()));
   }
   TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, json));
-  *os << "Dumped json op profile data to " << path << std::endl;
+  if (os) {
+    *os << "Dumped json op profile data to " << path << std::endl;
+  }
   return Status::OK();
 }
 
 Status DumpToolDataToLogDirectory(StringPiece run_dir,
+                                  const string& host_prefix,
                                   const tensorflow::ProfileToolData& tool,
                                   std::ostream* os) {
-  string path = JoinPath(run_dir, tool.name());
+  string path = JoinPath(run_dir, StrCat(host_prefix, tool.name()));
   TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data()));
-  *os << "Dumped tool data for " << tool.name() << " to " << path << std::endl;
-  return Status::OK();
-}
-
-Status DumpGraphEvents(const string& logdir, const string& run,
-                       const ProfileResponse& response, std::ostream* os) {
-  int num_graphs = response.computation_graph_size();
-  if (response.computation_graph_size() == 0) return Status::OK();
-  // The server might generates multiple graphs for one program; we simply
-  // pick the first one.
-  if (num_graphs > 1) {
-    *os << num_graphs
-        << " TPU program variants observed over the profiling period. "
-        << "One computation graph will be chosen arbitrarily." << std::endl;
-  }
-  // The graph plugin expects the graph in <logdir>/<run>/<event.file>.
-  string run_dir = JoinPath(logdir, strings::StrCat(kGraphRunPrefix, run));
-  TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(run_dir));
-  EventsWriter event_writer(JoinPath(run_dir, "events"));
-  Event event;
-  // Add the computation graph.
-  event.set_graph_def(response.computation_graph(0).SerializeAsString());
-  event_writer.WriteEvent(event);
-  *os << "Wrote a HLO graph to " << event_writer.FileName() << std::endl;
-
-  if (response.has_hlo_metadata()) {
-    tensorflow::TaggedRunMetadata tagged_run_metadata;
-    tagged_run_metadata.set_tag(run);
-    tagged_run_metadata.set_run_metadata(
-        response.hlo_metadata().SerializeAsString());
-    tensorflow::Event meta_event;
-    *meta_event.mutable_tagged_run_metadata() = tagged_run_metadata;
-    event_writer.WriteEvent(meta_event);
-    *os << "Wrote HLO ops run metadata to " << event_writer.FileName()
+  if (os) {
+    *os << "Dumped tool data for " << tool.name() << " to " << path
         << std::endl;
   }
   return Status::OK();
@@ -144,27 +121,29 @@ Status DumpGraphEvents(const string& logdir, const string& run,
 }  // namespace
 
 Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
+                                  const string& host,
                                   const ProfileResponse& response,
                                   std::ostream* os) {
   // Dumps profile data to <logdir>/plugins/profile/<run>/.
+  string host_prefix = host.empty() ? "" : StrCat(host, ".");
   string profile_run_dir = JoinPath(logdir, kProfilePluginDirectory, run);
   TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir));
 
   // Ignore computation_graph for now.
   if (!response.encoded_trace().empty()) {
     LOG(INFO) << "Converting trace events to TraceViewer JSON.";
-    TF_RETURN_IF_ERROR(
-        DumpTraceToLogDirectory(profile_run_dir, response.encoded_trace(), os));
+    TF_RETURN_IF_ERROR(DumpTraceToLogDirectory(profile_run_dir, host_prefix,
+                                               response.encoded_trace(), os));
   }
   if (response.has_op_profile() &&
       (response.op_profile().has_by_program_structure() ||
        response.op_profile().has_by_category())) {
-    TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir,
+    TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, host_prefix,
                                                    response.op_profile(), os));
   }
   for (const auto& tool_data : response.tool_data()) {
-    TF_RETURN_IF_ERROR(
-        DumpToolDataToLogDirectory(profile_run_dir, tool_data, os));
+    TF_RETURN_IF_ERROR(DumpToolDataToLogDirectory(profile_run_dir, host_prefix,
+                                                  tool_data, os));
   }
 
   return Status::OK();
index 29ef977..ecf21b1 100644 (file)
@@ -32,6 +32,7 @@ namespace tpu {
 // Note: this function creates a directory even when all fields in
 // ProfileResponse are unset/empty.
 Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
+                                  const string& host,
                                   const ProfileResponse& response,
                                   std::ostream* os);
 
index cddc3cd..8505c4b 100644 (file)
@@ -21,6 +21,17 @@ message ProfileOptions {
   // next-field: 2
 }
 
+message ToolRequestOptions {
+  // Required formats for the tool, it should be one of "json", "proto", "raw"
+  // etc. If not specified (backward compatible), use default format, i.e. most
+  // tools use json format.
+  string output_formats = 2;
+
+  // Whether save the result directly to repository or pass it back to caller.
+  // Default to false for backward compatibilities.
+  bool save_to_repo = 3;
+}
+
 message ProfileRequest {
   // In future, the caller will be able to customize when profiling starts and
   // stops. For now, it collects `duration_ms` milliseconds worth of data.
@@ -30,9 +41,12 @@ message ProfileRequest {
   // events.
   uint64 max_events = 2;
 
-  // required profiling tools name such as "input_pipeline_analyzer" etc
+  // Required profiling tools name such as "input_pipeline_analyzer" etc
   repeated string tools = 3;
 
+  // Specifies the requirement for each tools.
+  map<string, ToolRequestOptions> tool_options = 8;
+
   // Optional profiling options that control how a TF session will be profiled.
   ProfileOptions opts = 4;
 
@@ -43,10 +57,14 @@ message ProfileRequest {
   // The user provided profile session identifier.
   string session_id = 6;
 
+  // The hostname of system where the profile should happen.
+  // We use it as identifier in part of our output filename.
+  string host_name = 7;
+
   // In future, the caller will indicate which TF session is being profiled, and
   // only data relating to that program will be returned. For now, we assume
   // all activity during the profiling period is relevant.
-  // next-field: 7
+  // next-field: 9
 }
 
 message ProfileToolData {