#include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h"
#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
+#include "tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.grpc.pb.h"
#include "tensorflow/contrib/tpu/profiler/version.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tpu {
namespace {
+using ::tensorflow::grpc::TPUProfileAnalysis;
using ::tensorflow::TPUProfiler;
constexpr uint64 kMaxEvents = 1000000;
return Status::OK();
}
-// Returns whether the returned trace is empty.
-// Failure are handled by CHECK, i.e. abort()
-bool Profile(const string& service_addr, const string& logdir, int duration_ms,
- const string& repository_root, const string& session_id,
- const ProfileOptions& opts) {
+ProfileRequest PopulateProfileRequest(int duration_ms,
+ const string& repository_root,
+ const string& session_id,
+ const ProfileOptions& opts) {
ProfileRequest request;
request.set_duration_ms(duration_ms);
request.set_max_events(kMaxEvents);
*request.mutable_opts() = opts;
std::cout << "Limiting the number of trace events to " << kMaxEvents
<< std::endl;
+ return request;
+}
+
+// Returns whether the returned trace is empty.
+// Failure are handled by CHECK, i.e. abort()
+bool Profile(const string& service_addr, const string& logdir, int duration_ms,
+ const string& repository_root, const string& session_id,
+ const ProfileOptions& opts) {
+ ProfileRequest request =
+ PopulateProfileRequest(duration_ms, repository_root, session_id, opts);
+
::grpc::ClientContext context;
::grpc::ChannelArguments channel_args;
// TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available.
const std::vector<tensorflow::string>& hostnames,
int duration_ms, const string& repository_root,
const string& session_id, const ProfileOptions& opts) {
- return true;
+ NewProfileSessionRequest new_session_request;
+ *new_session_request.mutable_request() =
+ PopulateProfileRequest(duration_ms, repository_root, session_id, opts);
+ new_session_request.set_repository_root(repository_root);
+ new_session_request.set_session_id(session_id);
+ std::copy(
+ hostnames.begin(), hostnames.end(),
+ proto2::RepeatedFieldBackInserter(new_session_request.mutable_hosts()));
+
+ ::grpc::ClientContext context;
+ ::grpc::ChannelArguments channel_args;
+ // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
+ // `ValidateHostPortPair` checks for empty host string case.
+ channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
+ // TODO(jiesun): GRPC support following relevant naming scheme:
+ // 1. dns:///host:port
+ // 2. ipv4:host:port or ipv6:[host]:port
+ // We might need to change the prefix which depends on what TPU name resolver
+ // will give us.
+ std::unique_ptr<TPUProfileAnalysis::Stub> stub =
+ TPUProfileAnalysis::NewStub(::grpc::CreateCustomChannel(
+ "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
+ channel_args));
+ NewProfileSessionResponse new_session_response;
+ TF_QCHECK_OK(FromGrpcStatus(
+ stub->NewSession(&context, new_session_request, &new_session_response)));
+
+ std::cout << "Profile session succeed for hosts:"
+ << str_util::Join(hostnames, ",");
+ return new_session_response.empty_trace();
}
} // namespace