return Status::OK();
}
-ProfileResponse Profile(const string& service_addr, int duration_ms,
- const string& repository_root, const string& session_id,
- const ProfileOptions& opts) {
+// Returns whether the returned trace is empty.
+// Failures are handled by CHECK, i.e. abort().
+bool Profile(const string& service_addr, const string& logdir, int duration_ms,
+ const string& repository_root, const string& session_id,
+ const ProfileOptions& opts) {
ProfileRequest request;
request.set_duration_ms(duration_ms);
request.set_max_events(kMaxEvents);
channel_args));
ProfileResponse response;
TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response)));
- return response;
+
+ if (!response.encoded_trace().empty()) {
+ TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile(
+ logdir, session_id, "", response, &std::cout));
+ // Print this at the end so that it's not buried in irrelevant LOG messages.
+ std::cout
+ << "NOTE: using the trace duration " << duration_ms << "ms."
+ << std::endl
+ << "Set an appropriate duration (with --duration_ms) if you "
+ "don't see a full step in your trace or the captured trace is too "
+ "large."
+ << std::endl;
+ }
+
+ return response.encoded_trace().empty();
+}
+
+// Starts a new profiling session that includes all the hosts in `hostnames`,
+// for a time interval of duration_ms. May save the profiling result in the
+// directory specified by repository_root and session_id.
+bool NewSession(const string& service_addr,
+ const std::vector<tensorflow::string>& hostnames,
+ int duration_ms, const string& repository_root,
+ const string& session_id, const ProfileOptions& opts) {
+ return true;
}
} // namespace
int main(int argc, char** argv) {
tensorflow::string FLAGS_service_addr;
tensorflow::string FLAGS_logdir;
+ tensorflow::string FLAGS_workers_list;
int FLAGS_duration_ms = 2000;
int FLAGS_num_tracing_attempts = 3;
bool FLAGS_include_dataset_ops = true;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("service_addr", &FLAGS_service_addr,
"Address of TPU profiler service e.g. localhost:8466"),
+ tensorflow::Flag("workers_list", &FLAGS_workers_list,
+ "The list of worker TPUs that we are about to profile "
+ "in the current session."),
tensorflow::Flag("logdir", &FLAGS_logdir,
"Path of TensorBoard log directory e.g. /tmp/tb_log, "
"gs://tb_bucket"),
constexpr char kProfilePluginDirectory[] = "plugins/profile/";
tensorflow::string repository_root =
::tensorflow::io::JoinPath(FLAGS_logdir, kProfilePluginDirectory);
+ std::vector<tensorflow::string> hostnames =
+ tensorflow::str_util::Split(FLAGS_workers_list, ",");
+
+ bool empty_trace = false;
while (true) {
std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. "
<< "Remaining attempt(s): " << remaining_attempts-- << std::endl;
- response = tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms,
- repository_root, session_id, opts);
- if (remaining_attempts <= 0 || !response.encoded_trace().empty()) break;
+ if (hostnames.empty()) {
+ empty_trace = tensorflow::tpu::Profile(FLAGS_service_addr, FLAGS_logdir,
+ duration_ms, repository_root,
+ session_id, opts);
+ } else {
+ tensorflow::string tpu_master = FLAGS_service_addr;
+ empty_trace =
+ tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms,
+ repository_root, session_id, opts);
+ }
+ if (remaining_attempts <= 0 || !empty_trace) break;
std::cout << "No trace event is collected. Automatically retrying."
<< std::endl
<< std::endl;
}
- if (response.encoded_trace().empty()) {
+ if (empty_trace) {
std::cout << "No trace event is collected after "
<< FLAGS_num_tracing_attempts << " attempt(s). "
<< "Perhaps, you want to try again (with more attempts?)."
return 0;
}
- TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile(
- FLAGS_logdir, session_id, response, &std::cout));
- // Print this at the end so that it's not buried in irrelevant LOG messages.
- std::cout
- << "NOTE: using the trace duration " << duration_ms << "ms." << std::endl
- << "Set an appropriate duration (with --duration_ms) if you "
- "don't see a full step in your trace or the captured trace is too "
- "large."
- << std::endl;
+ return 0;
}
using ::tensorflow::io::JoinPath;
using ::tensorflow::protobuf::util::JsonOptions;
using ::tensorflow::protobuf::util::MessageToJsonString;
+using ::tensorflow::strings::StrCat;
constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph.";
constexpr char kJsonOpProfileFileName[] = "op_profile.json";
return Status::OK();
}
-Status DumpTraceToLogDirectory(StringPiece run_dir, const string& encoded_trace,
- std::ostream* os) {
+Status DumpTraceToLogDirectory(StringPiece run_dir, const string& host_prefix,
+ const string& encoded_trace, std::ostream* os) {
string proto_path = JoinPath(run_dir, kProtoTraceFileName);
TF_RETURN_IF_ERROR(
WriteStringToFile(Env::Default(), proto_path, encoded_trace));
LOG(INFO) << "Dumped raw-proto trace data to " << proto_path;
- string json_path = JoinPath(run_dir, kJsonTraceFileName);
+ string json_path = JoinPath(run_dir, StrCat(host_prefix, kJsonTraceFileName));
Trace trace;
trace.ParseFromString(encoded_trace);
- *os << "Trace contains " << trace.trace_events_size() << " events."
- << std::endl;
+ if (os) {
+ *os << "Trace contains " << trace.trace_events_size() << " events."
+ << std::endl;
+ }
TF_RETURN_IF_ERROR(
WriteGzippedDataToFile(json_path, TraceEventsToJson(trace)));
- *os << "Dumped JSON trace data to " << json_path << std::endl;
+ if (os) {
+ *os << "Dumped JSON trace data to " << json_path << std::endl;
+ }
return Status::OK();
}
Status DumpOpProfileToLogDirectory(StringPiece run_dir,
+ const string& host_prefix,
const tpu::op_profile::Profile& profile,
std::ostream* os) {
- string path = JoinPath(run_dir, kJsonOpProfileFileName);
+ string path = JoinPath(run_dir, StrCat(host_prefix, kJsonOpProfileFileName));
string json;
JsonOptions options;
options.always_print_primitive_fields = true;
string(status.error_message()));
}
TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, json));
- *os << "Dumped json op profile data to " << path << std::endl;
+ if (os) {
+ *os << "Dumped json op profile data to " << path << std::endl;
+ }
return Status::OK();
}
Status DumpToolDataToLogDirectory(StringPiece run_dir,
+ const string& host_prefix,
const tensorflow::ProfileToolData& tool,
std::ostream* os) {
- string path = JoinPath(run_dir, tool.name());
+ string path = JoinPath(run_dir, StrCat(host_prefix, tool.name()));
TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data()));
- *os << "Dumped tool data for " << tool.name() << " to " << path << std::endl;
- return Status::OK();
-}
-
-Status DumpGraphEvents(const string& logdir, const string& run,
- const ProfileResponse& response, std::ostream* os) {
- int num_graphs = response.computation_graph_size();
- if (response.computation_graph_size() == 0) return Status::OK();
- // The server might generates multiple graphs for one program; we simply
- // pick the first one.
- if (num_graphs > 1) {
- *os << num_graphs
- << " TPU program variants observed over the profiling period. "
- << "One computation graph will be chosen arbitrarily." << std::endl;
- }
- // The graph plugin expects the graph in <logdir>/<run>/<event.file>.
- string run_dir = JoinPath(logdir, strings::StrCat(kGraphRunPrefix, run));
- TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(run_dir));
- EventsWriter event_writer(JoinPath(run_dir, "events"));
- Event event;
- // Add the computation graph.
- event.set_graph_def(response.computation_graph(0).SerializeAsString());
- event_writer.WriteEvent(event);
- *os << "Wrote a HLO graph to " << event_writer.FileName() << std::endl;
-
- if (response.has_hlo_metadata()) {
- tensorflow::TaggedRunMetadata tagged_run_metadata;
- tagged_run_metadata.set_tag(run);
- tagged_run_metadata.set_run_metadata(
- response.hlo_metadata().SerializeAsString());
- tensorflow::Event meta_event;
- *meta_event.mutable_tagged_run_metadata() = tagged_run_metadata;
- event_writer.WriteEvent(meta_event);
- *os << "Wrote HLO ops run metadata to " << event_writer.FileName()
+ if (os) {
+ *os << "Dumped tool data for " << tool.name() << " to " << path
<< std::endl;
}
return Status::OK();
} // namespace
Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
+ const string& host,
const ProfileResponse& response,
std::ostream* os) {
// Dumps profile data to <logdir>/plugins/profile/<run>/.
+ string host_prefix = host.empty() ? "" : StrCat(host, ".");
string profile_run_dir = JoinPath(logdir, kProfilePluginDirectory, run);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir));
// Ignore computation_graph for now.
if (!response.encoded_trace().empty()) {
LOG(INFO) << "Converting trace events to TraceViewer JSON.";
- TF_RETURN_IF_ERROR(
- DumpTraceToLogDirectory(profile_run_dir, response.encoded_trace(), os));
+ TF_RETURN_IF_ERROR(DumpTraceToLogDirectory(profile_run_dir, host_prefix,
+ response.encoded_trace(), os));
}
if (response.has_op_profile() &&
(response.op_profile().has_by_program_structure() ||
response.op_profile().has_by_category())) {
- TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir,
+ TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, host_prefix,
response.op_profile(), os));
}
for (const auto& tool_data : response.tool_data()) {
- TF_RETURN_IF_ERROR(
- DumpToolDataToLogDirectory(profile_run_dir, tool_data, os));
+ TF_RETURN_IF_ERROR(DumpToolDataToLogDirectory(profile_run_dir, host_prefix,
+ tool_data, os));
}
return Status::OK();
// Note: this function creates a directory even when all fields in
// ProfileResponse are unset/empty.
Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
+ const string& host,
const ProfileResponse& response,
std::ostream* os);
// next-field: 2
}
+message ToolRequestOptions {
+  // Required format for the tool; it should be one of "json", "proto", "raw",
+  // etc. If not specified (for backward compatibility), the default format is
+  // used, i.e. most tools use the json format.
+ string output_formats = 2;
+
+  // Whether to save the result directly to the repository or pass it back to
+  // the caller. Defaults to false for backward compatibility.
+ bool save_to_repo = 3;
+}
+
message ProfileRequest {
// In future, the caller will be able to customize when profiling starts and
// stops. For now, it collects `duration_ms` milliseconds worth of data.
// events.
uint64 max_events = 2;
- // required profiling tools name such as "input_pipeline_analyzer" etc
+  // Names of the required profiling tools, e.g. "input_pipeline_analyzer".
repeated string tools = 3;
+  // Specifies the requirements for each tool.
+ map<string, ToolRequestOptions> tool_options = 8;
+
// Optional profiling options that control how a TF session will be profiled.
ProfileOptions opts = 4;
// The user provided profile session identifier.
string session_id = 6;
+  // The hostname of the system where the profiling should happen.
+  // We use it as an identifier in part of the output filename.
+ string host_name = 7;
+
// In future, the caller will indicate which TF session is being profiled, and
// only data relating to that program will be returned. For now, we assume
// all activity during the profiling period is relevant.
- // next-field: 7
+ // next-field: 9
}
message ProfileToolData {