R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
: "";
+ const string include_hlo_profile_printer_data_proto =
+ opts.gen_hlo_profile_printer_data
+ ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")"
+ : "";
+
+ // When HLO profiling is disabled we only forward declare the
+ // HloProfilePrinter protobuf. So we can only conditionally emit this code
+ // calling HloProfilePrinter::profile_counters_size.
+ const string assign_profile_counters_size =
+ opts.gen_hlo_profile_printer_data
+ ? "data->profile_counters_size = "
+ "data->hlo_profile_printer_data->profile_counters_size();"
+ : "";
+
// Use a poor-man's text templating mechanism; first populate the full header
// with placeholder tokens, and then rewrite the tokens with real values.
*header =
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
{{INCLUDE_XLA_DATA_PROTO}}
+{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
data->arg_names = StaticArgNames();
data->result_names = StaticResultNames();
data->program_shape = StaticProgramShape();
+ data->hlo_profile_printer_data = StaticHloProfilePrinterData();
+ {{ASSIGN_PROFILE_COUNTERS_SIZE}}
return data;
}();
return *kStaticData;
static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
return kShape;
}
+
+ // Metadata that can be used to pretty-print profile counters.
+ static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
+ static const xla::HloProfilePrinterData* kHloProfilePrinterData =
+ {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}};
+ return kHloProfilePrinterData;
+ }
};
{{NS_END}}
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
+ {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
+ {"{{DECLS_FROM_OBJ_FILE}}",
+ str_util::Join(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point},
+ {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
+ metadata_result.hlo_profile_printer_data_access_shim},
{"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
+ {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}",
+ include_hlo_profile_printer_data_proto},
{"{{METHODS_ARG}}\n", methods_arg},
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
+ {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
+ metadata_result.program_shape_access_shim},
{"{{RESULT_INDEX}}", strings::StrCat(result_index)},
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
- {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")},
- {"{{DECLS_FROM_OBJ_FILE}}",
- str_util::Join(metadata_result.header_variable_decls, "\n")},
- {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
- metadata_result.program_shape_access_shim}};
+ {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}};
str_util::ReplaceAllPairs(header, rewrites);
return Status::OK();
}
-static string CreateUniqueIdentifierForProgramShape(const CodegenOpts& opts) {
+static string CreateUniqueIdentifier(const CodegenOpts& opts,
+ StringPiece suffix) {
string result = "__tfcompile";
for (const string& n : opts.namespaces) {
strings::StrAppend(&result, "_", n);
}
- strings::StrAppend(&result, "_", opts.class_name, "_ProgramShape");
+ strings::StrAppend(&result, "_", opts.class_name, "_", suffix);
return result;
}
// When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives
// a shim that evaluates to nullptr, which is what we want.
+ ProtobufToEmbed program_shape_protobuf{
+ CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
+ program_shape.get()};
+
+ ProtobufToEmbed hlo_profile_printer_data_protobuf{
+ CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
+ "xla::HloProfilePrinterData",
+ compile_result.aot->hlo_profile_printer_data()};
+
TF_ASSIGN_OR_RETURN(
- EmbeddedProtocolBuffer embedded_program_shape,
- CreateEmbeddedProtocolBuffer(opts.target_triple,
- CreateUniqueIdentifierForProgramShape(opts),
- "xla::ProgramShape", program_shape.get()));
+ EmbeddedProtocolBuffers embedded_protobufs,
+ CreateEmbeddedProtocolBuffers(
+ opts.target_triple,
+ {program_shape_protobuf, hlo_profile_printer_data_protobuf}));
metadata_result->program_shape_access_shim =
- std::move(embedded_program_shape.cpp_shim_expression);
+ std::move(embedded_protobufs.cpp_shims[0].expression);
+ metadata_result->hlo_profile_printer_data_access_shim =
+ std::move(embedded_protobufs.cpp_shims[1].expression);
+ metadata_result->header_variable_decls.emplace_back(
+ std::move(embedded_protobufs.cpp_shims[0].variable_decl));
metadata_result->header_variable_decls.emplace_back(
- std::move(embedded_program_shape.cpp_variable_decl));
+ std::move(embedded_protobufs.cpp_shims[1].variable_decl));
metadata_result->object_file_data =
- std::move(embedded_program_shape.object_file_data);
+ std::move(embedded_protobufs.object_file_data);
return Status::OK();
}
// If true, generate program shape data for the ProgramShape method.
bool gen_program_shape = false;
+
+ // If true, emit a serialized HloProfilePrinterData protobuf that can be used
+ // to pretty print HLO profile counters.
+ bool gen_hlo_profile_printer_data = false;
};
// Describes a generated metadata object file.
// GenerateMetadata.
string program_shape_access_shim;
+ // hlo_profile_printer_data_access_shim is a C++ expression that constructs
+ // the xla::HloProfilePrinterData instance for the CompileResult passed to
+ // GenerateMetadata. If the xla::HloProfilePrinterData is null then this is a
+ // C++ expression that evaluates to nullptr at runtime.
+ string hlo_profile_printer_data_access_shim;
+
// The contents of the object (".o") file.
string object_file_data;
};
fetch->set_name("myfetch");
CompileResult compile_result;
compile_result.aot.reset(
- new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5));
+ new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {}));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
{
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
#define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
#include "tensorflow/compiler/xla/xla_data.pb.h"
+
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
+
namespace foo {
namespace bar {
data->arg_names = StaticArgNames();
data->result_names = StaticResultNames();
data->program_shape = StaticProgramShape();
+ data->hlo_profile_printer_data = StaticHloProfilePrinterData();
+
return data;
}();
return *kStaticData;
}();
return kShape;
}
+
+ // Metadata that can be used to pretty-print profile counters.
+ static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
+ static const xla::HloProfilePrinterData* kHloProfilePrinterData =
+ nullptr;
+ return kHloProfilePrinterData;
+ }
};
} // end namespace bar
flags.target_triple, flags.target_cpu, flags.target_features,
flags.entry_point,
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
+
return CompileXla(client, computation, aot_opts, compile_result);
}
using xla::llvm_ir::AsStringRef;
-static std::unique_ptr<llvm::Module> CreateModuleWithEmbeddedProtocolBuffer(
- llvm::LLVMContext* llvm_context, llvm::TargetMachine* target_machine,
- const ::tensorflow::protobuf::MessageLite& proto,
+static void AddEmbeddedProtocolBufferToLlvmModule(
+ llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
StringPiece unique_identifier, string* protobuf_array_symbol_name,
int64* protobuf_array_size) {
string protobuf_array_contents = proto.SerializeAsString();
strings::StrCat(unique_identifier, "_protobuf_array_contents");
*protobuf_array_size = protobuf_array_contents.size();
- std::unique_ptr<llvm::Module> module =
- MakeUnique<llvm::Module>("embedded_data_module", *llvm_context);
-
llvm::Constant* protobuf_array_initializer =
- llvm::ConstantDataArray::getString(*llvm_context,
+ llvm::ConstantDataArray::getString(module->getContext(),
AsStringRef(protobuf_array_contents),
/*AddNull=*/false);
new llvm::GlobalVariable(
*module, protobuf_array_initializer->getType(),
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage,
protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
-
- return module;
}
static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
/*Features=*/"", llvm::TargetOptions(), llvm::None));
}
-StatusOr<EmbeddedProtocolBuffer> CreateEmbeddedProtocolBuffer(
- StringPiece target_triple, StringPiece symbol_prefix,
- StringPiece qualified_cpp_protobuf_name,
- const ::tensorflow::protobuf::MessageLite* proto) {
+StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
+ StringPiece target_triple,
+ gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
GetTargetMachineFromTriple(target_triple));
llvm::LLVMContext llvm_context;
- string object_file, cpp_shim, cpp_variable_decl;
-
- if (proto) {
- string protobuf_array_symbol_name;
- int64 protobuf_array_size;
-
- std::unique_ptr<llvm::Module> module_with_serialized_proto =
- CreateModuleWithEmbeddedProtocolBuffer(
- &llvm_context, target_machine.get(), *proto, symbol_prefix,
- &protobuf_array_symbol_name, &protobuf_array_size);
- TF_ASSIGN_OR_RETURN(object_file,
- CodegenModule(target_machine.get(),
- std::move(module_with_serialized_proto)));
- cpp_shim = CreateCPPShimExpression(qualified_cpp_protobuf_name,
- protobuf_array_symbol_name,
- protobuf_array_size);
-
- cpp_variable_decl = strings::StrCat("extern \"C\" char ",
- protobuf_array_symbol_name, "[];");
- } else {
- TF_ASSIGN_OR_RETURN(
- object_file,
- CodegenModule(target_machine.get(),
- MakeUnique<llvm::Module>("empty_module", llvm_context)));
- cpp_shim = "nullptr";
+ std::unique_ptr<llvm::Module> module_with_serialized_proto =
+ MakeUnique<llvm::Module>("embedded_data_module", llvm_context);
+
+ EmbeddedProtocolBuffers result;
+
+ for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) {
+ string cpp_shim, cpp_variable_decl;
+ if (protobuf_to_embed.message) {
+ string protobuf_array_symbol_name;
+ int64 protobuf_array_size;
+
+ AddEmbeddedProtocolBufferToLlvmModule(
+ module_with_serialized_proto.get(), *protobuf_to_embed.message,
+ protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name,
+ &protobuf_array_size);
+ cpp_shim = CreateCPPShimExpression(
+ protobuf_to_embed.qualified_cpp_protobuf_name,
+ protobuf_array_symbol_name, protobuf_array_size);
+
+ cpp_variable_decl = strings::StrCat("extern \"C\" char ",
+ protobuf_array_symbol_name, "[];");
+ } else {
+ cpp_shim = "nullptr";
+ }
+ result.cpp_shims.push_back({cpp_shim, cpp_variable_decl});
}
- return {{cpp_shim, cpp_variable_decl, object_file}};
+ TF_ASSIGN_OR_RETURN(result.object_file_data,
+ CodegenModule(target_machine.get(),
+ std::move(module_with_serialized_proto)));
+ return result;
}
} // namespace tfcompile
#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace tfcompile {
using xla::StatusOr;
-// Represents a protocol buffer embedded into an object file and describes a way
-// to access it at runtime.
-struct EmbeddedProtocolBuffer {
- // cpp_shim_expression is a C++ expression that creates an instance of said
- // protocol buffer when executed.
- string cpp_shim_expression;
-
- // cpp_variable_decl is an "extern C" array declaration that is used in
- // cpp_shim_expression. It must be visible wherever cpp_shim_expression is
- // emitted.
- string cpp_variable_decl;
-
- // The contents of the object (".o") file the protocol buffer is embbed in.
- // This needs to be linked in to any program that wants to execute
- // cpp_variable_decl .
+// Represents a set of protocol buffers embedded into an object file and
+// describes how to access them at runtime.
+struct EmbeddedProtocolBuffers {
+  // Each CPPShim instance describes how to generate C++ code to instantiate a
+  // protobuf instance from the corresponding static data emitted into the
+  // object file.
+ struct CPPShim {
+ // `expression` is a C++ expression that creates an instance of said
+ // protocol buffer when executed.
+ string expression;
+
+ // `variable_decl` is an "extern C" array declaration that is used in
+ // `expression`. It must be visible wherever `expression` is emitted.
+ string variable_decl;
+ };
+
+ // Each cpp_shim corresponds to one embedded protocol buffer.
+ std::vector<CPPShim> cpp_shims;
+
+  // The contents of the object (".o") file the protocol buffers are embedded
+  // in. This needs to be linked in to any program that wants to execute any of
+  // the expressions in `cpp_shims`.
string object_file_data;
};
-// Creates an object file that contains `proto`.
-//
-// `proto` is allowed to be nullptr, in which case the generated C++ shim
-// expression is just `nullptr`, and the generated object file does not define
-// any symbols.
+// Describes a protocol buffer to embed into an object file.
+struct ProtobufToEmbed {
+  // `symbol_prefix` is a prefix that is guaranteed to be unique across the
+  // binary or DSO the generated object file will be linked into.
+ string symbol_prefix;
+
+ // `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++
+ // namespace qualified) protocol buffer name. This is only used in
+ // CPPShim::expression so relatively qualified names are fine as long as
+ // they're valid wherever CPPShim::expression is emitted.
+ string qualified_cpp_protobuf_name;
+
+ // `message` is the protocol buffer to be embedded. It is allowed to be
+ // nullptr, in which case the generated C++ shim expression is just `nullptr`,
+ // and the generated object file does not define any symbols.
+ const ::tensorflow::protobuf::MessageLite* message;
+};
+
+// Embeds a sequence of protocol buffers into an object file.
//
// `target_triple` is the target triple for the target architecture for the
// generated object file.
//
-// `symbol_prefix` is prefix that is guaranteed to be unique across the binary
-// or DSO the generated object file will be linked into.
-//
-// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++
-// namespace qualified) protocol buffer name. This needs is only used in
-// EmbeddedProtocolBuffer::cpp_shim_expression so relatively qualified
-// names are fine as long as they're valid wherever cpp_shim_expression
-// is emitted.
-StatusOr<EmbeddedProtocolBuffer> CreateEmbeddedProtocolBuffer(
- StringPiece target_triple, StringPiece symbol_prefix,
- StringPiece qualified_cpp_protobuf_name,
- const ::tensorflow::protobuf::MessageLite* proto);
+// `protobufs_to_embed` describes the protocol buffers to embed into the
+// resulting object file. The C++ shim for protobufs_to_embed[i] is
+// cpp_shims[i] in the returned EmbeddedProtocolBuffers instance. The contents
+// of all the protocol buffers are embedded into a single .o file whose content
+// is stored in the object_file_data field in the returned
+// EmbeddedProtocolBuffers instance.
+StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
+ StringPiece target_triple,
+ gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed);
} // namespace tfcompile
} // namespace tensorflow
)
tf_library(
+ name = "test_graph_tfmatmulandadd_with_profiling",
+ testonly = 1,
+ config = "test_graph_tfmatmulandadd.config.pbtxt",
+ cpp_class = "MatMulAndAddCompWithProfiling",
+ enable_xla_hlo_profiling = True,
+ graph = "test_graph_tfmatmulandadd.pb",
+)
+
+tf_library(
name = "test_graph_tfsplits",
testonly = 1,
config = "test_graph_tfsplits.config.pbtxt",
":test_graph_tfgather",
":test_graph_tfmatmul",
":test_graph_tfmatmulandadd",
+ ":test_graph_tfmatmulandadd_with_profiling",
":test_graph_tfsplits",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_profile_printer",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace {
+using ::testing::HasSubstr;
+using ::testing::UnorderedElementsAre;
+
TEST(TFCompileTest, Add) {
AddComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]);
EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2));
}
+TEST(TFCompileTest, HloProfiling) {
+ Eigen::ThreadPool tp(1);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ MatMulAndAddCompWithProfiling fn;
+ ASSERT_TRUE(fn.hlo_profiling_enabled());
+
+ fn.set_thread_pool(&device);
+
+ // x = [[1, 2], [3, 4]]
+ fn.arg0(0, 0) = 1;
+ fn.arg0(0, 1) = 2;
+ fn.arg0(1, 0) = 3;
+ fn.arg0(1, 1) = 4;
+
+ // y = [[10, 20], [30, 40]]
+ fn.arg1(0, 0) = 10;
+ fn.arg1(0, 1) = 20;
+ fn.arg1(1, 0) = 30;
+ fn.arg1(1, 1) = 40;
+
+ EXPECT_TRUE(fn.Run());
+
+ string hlo_profile_as_string =
+ xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
+ /*clock_rate_ghz=*/1.0);
+ VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
+
+ std::vector<string> hlo_profile_lines =
+ tensorflow::str_util::Split(hlo_profile_as_string, '\n');
+
+ auto header = HasSubstr("Execution profile for");
+ auto total_cycles_profile_line = HasSubstr("[total]");
+ auto dot_profile_line = HasSubstr(
+ "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+ auto add_profile_line = HasSubstr(
+ "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+ auto tuple_profile_line = HasSubstr(
+ "%tuple.2 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
+ "f32[2,2]{1,0} %add)");
+ auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
+ auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
+
+ hlo_profile_lines.erase(hlo_profile_lines.begin() + 7,
+ hlo_profile_lines.end());
+
+ EXPECT_THAT(
+ hlo_profile_lines,
+ UnorderedElementsAre(header, total_cycles_profile_line, dot_profile_line,
+ add_profile_line, tuple_profile_line,
+ arg0_profile_line, arg1_profile_line));
+}
+
} // namespace
} // namespace tfcompile
} // namespace tensorflow
visibility=None, testonly=None,
tfcompile_flags=None,
tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
- include_standard_runtime_deps=True, deps=None, tags=None):
+ include_standard_runtime_deps=True,
+ enable_xla_hlo_profiling=False, deps=None, tags=None):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
Given an invocation of tf_library(name="foo", ...), generates the following
include_standard_runtime_deps: If True, the standard list of kernel/runtime
deps is added to deps. If False, deps must contain the full set of deps
needed by the generated library.
+ enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program,
+ and emit metadata that lets us pretty-print the gathered profile counters.
deps: a list of deps to include on the build rules for the generated
library, added to the standard deps if standard_runtime_deps is True.
tags: tags to apply to subsidiary build rules.
flags = tfcompile_flags
else:
flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
+ if enable_xla_hlo_profiling:
+ profiling_flag = "--xla_hlo_profile"
+ else:
+ profiling_flag = ""
native.genrule(
name=("gen_" + name),
srcs=[
" --out_header=$(@D)/" + header_file +
" --out_metadata_object=$(@D)/" + metadata_object_file +
" --out_function_object=$(@D)/" + function_object_file +
- " " + flags),
+ " " + flags + " " + profiling_flag),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,
] + (need_xla_data_proto and [
# If we're generating the program shape, we must depend on the proto.
"//tensorflow/compiler/xla:xla_data_proto",
+ ] or []) + (enable_xla_hlo_profiling and [
+ "//tensorflow/compiler/xla/service:hlo_profile_printer_data"
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
+ codegen_opts.gen_hlo_profile_printer_data =
+ xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
CpuAotCompilationResult::CpuAotCompilationResult(
ObjectFileData object_file_data, BufferSizes buffer_sizes,
- int64 result_buffer_index)
+ int64 result_buffer_index,
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
: object_file_data_(std::move(object_file_data)),
buffer_sizes_(std::move(buffer_sizes)),
- result_buffer_index_(result_buffer_index) {}
+ result_buffer_index_(result_buffer_index),
+ hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}
CpuAotCompilationResult::~CpuAotCompilationResult() = default;
public:
static StatusOr<std::unordered_map<const HloInstruction*, int64>>
GetCandidatesForComputation(
- HloComputation* computation,
+ const HloComputation& computation,
const std::unordered_map<const HloInstruction*, int64>&
assigned_indices) {
std::unordered_map<const HloInstruction*, int64> hlo_to_profile_idx;
CollectProfileCandidates profile_candidates_for_computation(
&hlo_to_profile_idx, assigned_indices);
- TF_RETURN_IF_ERROR(
- computation->Accept(&profile_candidates_for_computation));
+ TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation));
return hlo_to_profile_idx;
}
return Status::OK();
}
+Status CreateHloProfilingArtifacts(
+ const HloModule& module,
+ std::unordered_map<const HloInstruction*, int64>*
+ instruction_to_profile_idx,
+ std::unordered_map<const HloComputation*, int64>*
+ computation_to_profile_idx,
+ std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
+ std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
+ *hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(module);
+ const HloComputation& entry_computation = *module.entry_computation();
+
+ TF_ASSIGN_OR_RETURN(
+ *instruction_to_profile_idx,
+ CollectProfileCandidates::GetCandidatesForComputation(
+ entry_computation,
+ (*hlo_profile_index_map)->instruction_to_profile_idx()));
+
+ auto shape_size_bytes = [](const Shape& shape) {
+ // On the cpu, opaques are pointers.
+ if (ShapeUtil::IsOpaque(shape)) {
+ return static_cast<int64>(sizeof(void*));
+ }
+ return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+ };
+
+ HloCostAnalysis cost_analysis(shape_size_bytes);
+ TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
+ *hlo_profile_printer_data =
+ CreateHloProfilePrinterData(**hlo_profile_index_map, cost_analysis);
+ *computation_to_profile_idx =
+ (*hlo_profile_index_map)->computation_to_profile_idx();
+
+ return Status::OK();
+}
+
} // namespace
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
if (module->config().hlo_profiling_enabled()) {
- hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
-
- TF_ASSIGN_OR_RETURN(
- instruction_to_profile_idx,
- CollectProfileCandidates::GetCandidatesForComputation(
- entry_computation,
- hlo_profile_index_map->instruction_to_profile_idx()));
-
- auto shape_size_bytes = [](const Shape& shape) {
- // On the cpu, opaques are pointers.
- if (ShapeUtil::IsOpaque(shape)) {
- return static_cast<int64>(sizeof(void*));
- }
- return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
- };
-
- HloCostAnalysis cost_analysis(shape_size_bytes);
- TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis));
- hlo_profile_printer_data =
- CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis);
- computation_to_profile_idx =
- hlo_profile_index_map->computation_to_profile_idx();
+ TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
+ *module, &instruction_to_profile_idx, &computation_to_profile_idx,
+ &hlo_profile_index_map, &hlo_profile_printer_data));
}
std::unique_ptr<Executable> cpu_executable;
proto, xla_dump_optimized_hlo_proto_to, module->name()));
}
+ std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
+ std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
+ std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
+
+ if (module->config().hlo_profiling_enabled()) {
+ TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
+ *module, &instruction_to_profile_idx, &computation_to_profile_idx,
+ &hlo_profile_index_map, &hlo_profile_printer_data));
+ }
+
IrEmitter ir_emitter(*module, *assignment, &llvm_module,
- /*instruction_to_profile_idx=*/
- std::unordered_map<const HloInstruction*, int64>{},
- /*computation_to_profile_idx=*/
- std::unordered_map<const HloComputation*, int64>{},
+ std::move(instruction_to_profile_idx),
+ std::move(computation_to_profile_idx),
target_machine.get(),
/*external_constant_pool=*/nullptr);
HloComputation* computation = module->entry_computation();
results.emplace_back(MakeUnique<CpuAotCompilationResult>(
std::move(object_file_data), std::move(buffer_sizes),
- result_slice.index()));
+ result_slice.index(), std::move(hlo_profile_printer_data)));
}
VLOG(1) << "Compilation finished";
class CpuAotCompilationResult : public AotCompilationResult {
public:
- CpuAotCompilationResult(ObjectFileData object_file_data,
- BufferSizes buffer_sizes, int64 result_buffer_index);
+ CpuAotCompilationResult(
+ ObjectFileData object_file_data, BufferSizes buffer_sizes,
+ int64 result_buffer_index,
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data);
~CpuAotCompilationResult();
+ HloProfilePrinterData* hlo_profile_printer_data() const {
+ return hlo_profile_printer_data_.get();
+ }
+
const ObjectFileData& object_file_data() const { return object_file_data_; }
const BufferSizes& buffer_sizes() const { return buffer_sizes_; }
int64 result_buffer_index() const { return result_buffer_index_; }
// result of the computation. This buffer should be passed into the output
// parameter when calling the compiled computation.
const int64 result_buffer_index_;
+
+ // Contains an instance of HloProfilePrinterData if HLO profiling is enabled,
+ // otherwise is nullptr.
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data_;
};
// CPU-targeting implementation of the XLA Compiler interface.