"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
- "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
],
)
return Literal::CreateFromProto(response.literal());
}
-StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) {
- LoadComputationSnapshotRequest request;
- *request.mutable_module() = module;
- LoadComputationSnapshotResponse response;
-
- Status s = stub_->LoadComputationSnapshot(&request, &response);
- if (!s.ok()) {
- return s;
- }
-
- VLOG(1) << "load snapshot response: " << response.ShortDebugString();
- return Computation(stub_, response.computation());
-}
-
StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
return XlaComputation(module.hlo().hlo_module());
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
// two computations via a pair of Send and Recv instructions.
StatusOr<ChannelHandle> CreateChannelHandle();
- StatusOr<Computation> LoadSnapshot(const SessionModule& module);
-
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
StatusOr<XlaComputation> LoadSnapshot(const HloSnapshot& module);
ServiceInterface* stub() { return stub_; }
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service",
- "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
],
)
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
- "//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
],
)
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing",
"//tensorflow/compiler/xla/service:hlo_proto",
- "//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:computation_tracker",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
- "//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
],
)
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
- "//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
],
)
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
- "//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
],
)
//
// Dumps a graphviz URL for a snapshot computation to the command line.
//
-// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
// ServiceInterface::SnapshotComputation to disk.
//
// The GraphViz URL is placed into the log stderr, whereas computation
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
- SessionModule module;
+ HloSnapshot module;
TF_CHECK_OK(
tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
- Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
+ XlaComputation computation =
+ client->LoadSnapshot(module).ConsumeValueOrDie();
DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags();
debug_options.set_xla_generate_hlo_graph(".*");
ComputationStats stats =
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
for (char* arg : args) {
- SessionModule session_module;
+ HloSnapshot snapshot;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
- &session_module));
- auto computation_status = client->LoadSnapshot(session_module);
+ &snapshot));
+ auto computation_status = client->LoadSnapshot(snapshot);
if (!computation_status.ok()) {
fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
computation_status.status().ToString().c_str());
continue;
}
- Computation computation = computation_status.ConsumeValueOrDie();
+ XlaComputation computation = computation_status.ConsumeValueOrDie();
std::unique_ptr<ProgramShape> program_shape =
client->GetComputationShape(computation).ConsumeValueOrDie();
build_options.set_device_ordinal(0);
build_options.set_result_layout(program_shape->result());
StatusOr<std::unique_ptr<Executable>> executable =
- local_service->CompileExecutable(computation.handle(), layouts,
- build_options);
+ local_service->CompileExecutable(computation, layouts, build_options);
const HloModule& module = executable.ValueOrDie()->module();
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
for (char* arg : args) {
- SessionModule session_module;
+ HloSnapshot snapshot;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
- &session_module));
- auto computation_status = client->LoadSnapshot(session_module);
+ &snapshot));
+ auto computation_status = client->LoadSnapshot(snapshot);
if (!computation_status.ok()) {
fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
computation_status.status().ToString().c_str());
continue;
}
- Computation computation = computation_status.ConsumeValueOrDie();
+ XlaComputation computation = computation_status.ConsumeValueOrDie();
if (compile) {
std::unique_ptr<ProgramShape> program_shape =
build_options.set_device_ordinal(0);
build_options.set_result_layout(program_shape->result());
StatusOr<std::unique_ptr<Executable>> executable =
- local_service->CompileExecutable(computation.handle(), layouts,
- build_options);
+ local_service->CompileExecutable(computation, layouts, build_options);
const HloModule& module = executable.ValueOrDie()->module();
local_service->backend().platform()->Name().c_str(),
module.ToString(HloPrintOptions::ShortParsable()).c_str());
} else {
- const ComputationTracker& tracker = local_service->computation_tracker();
- UserComputation* user_computation =
- tracker.Resolve(computation.handle()).ConsumeValueOrDie();
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandle();
+ auto config = HloModule::CreateModuleConfigFromProto(computation.proto(),
+ DebugOptions())
+ .ConsumeValueOrDie();
std::unique_ptr<HloModule> module =
- tracker.BuildHloModule(versioned_handle, HloModuleConfig())
+ HloModule::CreateFromProto(computation.proto(), config)
.ConsumeValueOrDie();
fprintf(stdout, "%s\n",
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
- SessionModule module;
+ HloSnapshot module;
TF_CHECK_OK(
tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
- Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
+ XlaComputation computation =
+ client->LoadSnapshot(module).ConsumeValueOrDie();
DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags();
debug_options.set_xla_generate_hlo_graph(".*");
debug_options.set_xla_hlo_dump_as_graphdef(true);
//
// Replays computations and shows the results on the command line.
//
-// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
// ServiceInterface::SnapshotComputation to disk.
//
// Computations that require arguments can be replayed using fake data by
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
//
// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
// otherwise, no infeed is performed.
-template <typename ModuleT>
-StatusOr<std::unique_ptr<Literal>> ReplayComputation(const ModuleT& module,
+StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
Client* client,
const Options& opts) {
- static_assert(std::is_same<ModuleT, HloSnapshot>::value ||
- std::is_same<ModuleT, SessionModule>::value,
- "Proto must be in HloSnapshot or SessionModule format");
TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module));
std::vector<std::unique_ptr<GlobalData>> arguments;
for (char* arg : args) {
HloSnapshot snapshot;
auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot);
- if (status.ok()) {
- StatusOr<std::unique_ptr<Literal>> result_status =
- ReplayComputation(snapshot, client, opts);
- if (!result_status.ok()) {
- fprintf(stderr, "%s: error: %s\n", arg,
- result_status.status().ToString().c_str());
- exit_status = EXIT_FAILURE;
- continue;
- }
-
- std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
- if (result != nullptr) {
- fprintf(stdout, "%s: %s :: %s:%s\n", arg,
- snapshot.hlo().hlo_module().name().c_str(),
- ShapeUtil::HumanString(result->shape()).c_str(),
- result->ToString().c_str());
- if (snapshot.has_result()) {
- std::unique_ptr<Literal> literal =
- Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
- fprintf(stdout, "was %s:%s\n",
- ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
- literal->ToString().c_str());
- }
- }
-
+ if (!status.ok()) {
+ fprintf(stderr, "%s: is not HloSnapshot: %s.\n", arg,
+ status.ToString().c_str());
continue;
}
- fprintf(stderr, "%s: is not HloSnapshot: %s. Trying as SessionModule...\n",
- arg, status.ToString().c_str());
-
- SessionModule module;
- TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
StatusOr<std::unique_ptr<Literal>> result_status =
- ReplayComputation(module, client, opts);
+ ReplayComputation(snapshot, client, opts);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
result_status.status().ToString().c_str());
std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
if (result != nullptr) {
- fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(),
+ fprintf(stdout, "%s: %s :: %s:%s\n", arg,
+ snapshot.hlo().hlo_module().name().c_str(),
ShapeUtil::HumanString(result->shape()).c_str(),
result->ToString().c_str());
- if (module.has_result()) {
+ if (snapshot.has_result()) {
std::unique_ptr<Literal> literal =
- Literal::CreateFromProto(module.result()).ConsumeValueOrDie();
+ Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
fprintf(stdout, "was %s:%s\n",
- ShapeUtil::HumanString(module.result().shape()).c_str(),
+ ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
literal->ToString().c_str());
}
}
// Shows the signature (ProgramShape) of binary snapshot proto(s) on the command
// line.
//
-// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
// ServiceInterface::SnapshotComputation to disk.
//
// The output format is:
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
- SessionModule module;
+ HloSnapshot module;
TF_CHECK_OK(
tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
- Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
+ auto computation = client->LoadSnapshot(module).ConsumeValueOrDie();
std::unique_ptr<ProgramShape> shape =
client->GetComputationShape(computation).ConsumeValueOrDie();
- fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(),
+ fprintf(stdout, "%s: %s :: %s\n", arg,
+ module.hlo().hlo_module().name().c_str(),
ShapeUtil::HumanString(*shape).c_str());
}
}