bool use_fake_data = false;
bool print_result = true;
int num_runs = 1;
- bool xla_hlo_profile_last_run = false;
};
// Invokes the given computation passing arbitrary data for every (unbound)
//
// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
// no infeed is performed.
-StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
- Client* client,
- const Options& opts) {
+StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
+ LocalClient* client, const Options& opts) {
XlaComputation computation(module.hlo().hlo_module());
- std::vector<std::unique_ptr<GlobalData>> arguments;
+ // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our
+ // arguments. This is a bit involved, because we may have to convert from
+ // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our
+ // objects.
+ std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments;
+ std::vector<std::unique_ptr<GlobalData>> global_data_arguments;
+ std::vector<const ShapedBuffer*> argument_ptrs;
if (opts.use_fake_data) {
- arguments = MakeFakeArgumentsOrDie(computation, client);
+ global_data_arguments = MakeFakeArgumentsOrDie(computation, client);
+ for (const auto& data : global_data_arguments) {
+ argument_ptrs.push_back(
+ client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0)
+ .ValueOrDie());
+ }
} else { // use recorded data if available
for (const auto& proto : module.arguments()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
Literal::CreateFromProto(proto));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
- client->TransferToServer(*literal));
- arguments.push_back(std::move(data));
+ TF_ASSIGN_OR_RETURN(
+ ScopedShapedBuffer data,
+ client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+ scoped_shaped_buffer_arguments.push_back(std::move(data));
+ }
+ for (const auto& argument : scoped_shaped_buffer_arguments) {
+ argument_ptrs.push_back(&argument);
}
}
});
}
- std::vector<GlobalData*> execute_arguments;
- execute_arguments.reserve(arguments.size());
- for (auto& argument : arguments) {
- execute_arguments.push_back(argument.get());
+ std::vector<const Shape*> argument_layouts;
+ for (const auto& param : computation.proto().program_shape().parameters()) {
+ argument_layouts.push_back(¶m);
}
+ std::unique_ptr<LocalExecutable> executable =
+ client->Compile(computation, argument_layouts, ExecutableBuildOptions())
+ .ValueOrDie();
// Run the computation num_runs times, and return the result from the last
// execution.
- std::unique_ptr<Literal> result;
+ StreamExecutorMemoryAllocator allocator(
+ client->platform(),
+ {client->platform()->ExecutorForDevice(0).ValueOrDie()});
+ tensorflow::gtl::optional<ScopedShapedBuffer> result;
for (int i = 0; i < opts.num_runs; ++i) {
ExecutionProfile profile;
- ExecutionOptions execution_options = CreateDefaultExecutionOptions();
- if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) {
- execution_options.mutable_debug_options()->set_xla_hlo_profile(true);
- }
+ ExecutableRunOptions run_options;
+ run_options.set_execution_profile(&profile);
+ run_options.set_allocator(&allocator);
- if (opts.print_result) {
- TF_ASSIGN_OR_RETURN(
- result, client->ExecuteAndTransfer(computation, execute_arguments,
- &execution_options, &profile));
- } else {
- // If we're not printing the result, execute the computation but don't
- // bother retrieving the result. This can be a significant speedup.
- TF_RETURN_IF_ERROR(client
- ->Execute(computation, execute_arguments,
- &execution_options, &profile)
- .status());
- }
+ TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options));
LOG(INFO) << "Execution took "
<< static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
}
- return std::move(result);
+ // Check that --num_runs > 0, otherwise *result below will fail with an
+ // unhelpful error (because the loop didn't run any iterations).
+ CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0";
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+ client->ShapedBufferToLiteral(*result));
+ return std::move(*result_literal);
}
int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
- Client* client = ClientLibrary::LocalClientOrDie();
+ LocalClient* client = ClientLibrary::LocalClientOrDie();
tensorflow::Env* env = tensorflow::Env::Default();
int exit_status = EXIT_SUCCESS;
for (char* arg : args) {
CHECK(opts.use_fake_data)
<< "HloProto input must be handled with --use_fake_data";
}
- StatusOr<std::unique_ptr<Literal>> result_status =
- ReplayComputation(snapshot, client, opts);
+
+ StatusOr<Literal> result_status = ReplayComputation(snapshot, client, opts);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
result_status.status().ToString().c_str());
continue;
}
- std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
- if (result != nullptr) {
+ if (opts.print_result) {
+ Literal result = std::move(result_status).ValueOrDie();
fprintf(stdout, "%s: %s :: %s:%s\n", arg,
snapshot.hlo().hlo_module().name().c_str(),
- ShapeUtil::HumanString(result->shape()).c_str(),
- result->ToString().c_str());
+ ShapeUtil::HumanString(result.shape()).c_str(),
+ result.ToString().c_str());
if (snapshot.has_result()) {
std::unique_ptr<Literal> literal =
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
"Whether a fake infeed shape should be generated "
"derived from the computation"),
- tensorflow::Flag(
- "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run,
- "Pass --xla_hlo_profile the last time we run the computation."),
};
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);