HLO profiling for tfcompile.
authorSanjoy Das <sanjoy@google.com>
Sat, 28 Apr 2018 03:06:35 +0000 (20:06 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 28 Apr 2018 03:09:38 +0000 (20:09 -0700)
This CL extends the --xla_hlo_profile knob to tfcompile.  tf_library rules can
now set enable_xla_hlo_profiling to True to:

 - Have the generated code update per-HLO profile counters as it executes.
 - Have tfcompile generate and serialize an instance of HloProfilePrinterData with
   a compiled model that can be used to pretty-print the collected profile
   counters.

PiperOrigin-RevId: 194627272

13 files changed:
tensorflow/compiler/aot/codegen.cc
tensorflow/compiler/aot/codegen.h
tensorflow/compiler/aot/codegen_test.cc
tensorflow/compiler/aot/codegen_test_h.golden
tensorflow/compiler/aot/compile.cc
tensorflow/compiler/aot/embedded_protocol_buffers.cc
tensorflow/compiler/aot/embedded_protocol_buffers.h
tensorflow/compiler/aot/tests/BUILD
tensorflow/compiler/aot/tests/tfcompile_test.cc
tensorflow/compiler/aot/tfcompile.bzl
tensorflow/compiler/aot/tfcompile_main.cc
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
tensorflow/compiler/xla/service/cpu/cpu_compiler.h

index 2cae85e..0025842 100644 (file)
@@ -333,6 +333,20 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
           R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
           : "";
 
+  const string include_hlo_profile_printer_data_proto =
+      opts.gen_hlo_profile_printer_data
+          ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")"
+          : "";
+
+  // When HLO profiling is disabled we only forward declare the
+  // HloProfilePrinterData protobuf, so we can only conditionally emit the code
+  // that calls HloProfilePrinterData::profile_counters_size.
+  const string assign_profile_counters_size =
+      opts.gen_hlo_profile_printer_data
+          ? "data->profile_counters_size = "
+            "data->hlo_profile_printer_data->profile_counters_size();"
+          : "";
+
   // Use a poor-man's text templating mechanism; first populate the full header
   // with placeholder tokens, and then rewrite the tokens with real values.
   *header =
@@ -348,6 +362,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
 #define TFCOMPILE_GENERATED_{{ENTRY}}_H_  // NOLINT(build/header_guard)
 
 {{INCLUDE_XLA_DATA_PROTO}}
+{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
 #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -418,6 +433,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
       data->arg_names = StaticArgNames();
       data->result_names = StaticResultNames();
       data->program_shape = StaticProgramShape();
+      data->hlo_profile_printer_data = StaticHloProfilePrinterData();
+      {{ASSIGN_PROFILE_COUNTERS_SIZE}}
       return data;
     }();
     return *kStaticData;
@@ -487,6 +504,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
     static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
     return kShape;
   }
+
+  // Metadata that can be used to pretty-print profile counters.
+  static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
+    static const xla::HloProfilePrinterData* kHloProfilePrinterData =
+      {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}};
+    return kHloProfilePrinterData;
+  }
 };
 {{NS_END}}
 
@@ -501,35 +525,41 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
       {"{{ARG_NAMES_CODE}}", arg_names_code},
       {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
       {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
+      {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
       {"{{CLASS}}", opts.class_name},
+      {"{{DECLS_FROM_OBJ_FILE}}",
+       str_util::Join(metadata_result.header_variable_decls, "\n")},
       {"{{ENTRY}}", compile_result.entry_point},
+      {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
+       metadata_result.hlo_profile_printer_data_access_shim},
       {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
+      {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}",
+       include_hlo_profile_printer_data_proto},
       {"{{METHODS_ARG}}\n", methods_arg},
       {"{{METHODS_RESULT}}\n", methods_result},
       {"{{NS_END}}\n", ns_end},
       {"{{NS_START}}\n", ns_start},
       {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
+      {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
+       metadata_result.program_shape_access_shim},
       {"{{RESULT_INDEX}}", strings::StrCat(result_index)},
       {"{{RESULT_NAMES_CODE}}", result_names_code},
       {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
       {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
       {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
-      {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")},
-      {"{{DECLS_FROM_OBJ_FILE}}",
-       str_util::Join(metadata_result.header_variable_decls, "\n")},
-      {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
-       metadata_result.program_shape_access_shim}};
+      {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}};
   str_util::ReplaceAllPairs(header, rewrites);
   return Status::OK();
 }
 
-static string CreateUniqueIdentifierForProgramShape(const CodegenOpts& opts) {
+static string CreateUniqueIdentifier(const CodegenOpts& opts,
+                                     StringPiece suffix) {
   string result = "__tfcompile";
   for (const string& n : opts.namespaces) {
     strings::StrAppend(&result, "_", n);
   }
 
-  strings::StrAppend(&result, "_", opts.class_name, "_ProgramShape");
+  strings::StrAppend(&result, "_", opts.class_name, "_", suffix);
   return result;
 }
 
@@ -550,18 +580,31 @@ Status GenerateMetadata(const CodegenOpts& opts,
   // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives
   // a shim that evaluates to nullptr, which is what we want.
 
+  ProtobufToEmbed program_shape_protobuf{
+      CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
+      program_shape.get()};
+
+  ProtobufToEmbed hlo_profile_printer_data_protobuf{
+      CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
+      "xla::HloProfilePrinterData",
+      compile_result.aot->hlo_profile_printer_data()};
+
   TF_ASSIGN_OR_RETURN(
-      EmbeddedProtocolBuffer embedded_program_shape,
-      CreateEmbeddedProtocolBuffer(opts.target_triple,
-                                   CreateUniqueIdentifierForProgramShape(opts),
-                                   "xla::ProgramShape", program_shape.get()));
+      EmbeddedProtocolBuffers embedded_protobufs,
+      CreateEmbeddedProtocolBuffers(
+          opts.target_triple,
+          {program_shape_protobuf, hlo_profile_printer_data_protobuf}));
 
   metadata_result->program_shape_access_shim =
-      std::move(embedded_program_shape.cpp_shim_expression);
+      std::move(embedded_protobufs.cpp_shims[0].expression);
+  metadata_result->hlo_profile_printer_data_access_shim =
+      std::move(embedded_protobufs.cpp_shims[1].expression);
+  metadata_result->header_variable_decls.emplace_back(
+      std::move(embedded_protobufs.cpp_shims[0].variable_decl));
   metadata_result->header_variable_decls.emplace_back(
-      std::move(embedded_program_shape.cpp_variable_decl));
+      std::move(embedded_protobufs.cpp_shims[1].variable_decl));
   metadata_result->object_file_data =
-      std::move(embedded_program_shape.object_file_data);
+      std::move(embedded_protobufs.object_file_data);
   return Status::OK();
 }
 
index 3430b1f..83f2d3e 100644 (file)
@@ -44,6 +44,10 @@ struct CodegenOpts {
 
   // If true, generate program shape data for the ProgramShape method.
   bool gen_program_shape = false;
+
+  // If true, emit a serialized HloProfilePrinterData protobuf that can be used
+  // to pretty print HLO profile counters.
+  bool gen_hlo_profile_printer_data = false;
 };
 
 // Describes a generated metadata object file.
@@ -57,6 +61,12 @@ struct MetadataResult {
   // GenerateMetadata.
   string program_shape_access_shim;
 
+  // hlo_profile_printer_data_access_shim is a C++ expression that constructs
+  // the xla::HloProfilePrinterData instance for the CompileResult passed to
+  // GenerateMetadata.  If the xla::HloProfilePrinterData is null then this is a
+  // C++ expression that evaluates to nullptr at runtime.
+  string hlo_profile_printer_data_access_shim;
+
   // The contents of the object (".o") file.
   string object_file_data;
 };
index 2642536..29bc9c1 100644 (file)
@@ -172,7 +172,7 @@ TEST(CodegenTest, Golden) {
   fetch->set_name("myfetch");
   CompileResult compile_result;
   compile_result.aot.reset(
-      new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5));
+      new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {}));
   compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
       {
           xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
index ac3b587..6e050cf 100644 (file)
@@ -10,6 +10,7 @@
 #define TFCOMPILE_GENERATED_entry_point_H_  // NOLINT(build/header_guard)
 
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+
 #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -23,6 +24,7 @@ extern "C" void entry_point(
 
 extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
 
+
 namespace foo {
 namespace bar {
 
@@ -82,6 +84,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
       data->arg_names = StaticArgNames();
       data->result_names = StaticResultNames();
       data->program_shape = StaticProgramShape();
+      data->hlo_profile_printer_data = StaticHloProfilePrinterData();
+      
       return data;
     }();
     return *kStaticData;
@@ -243,6 +247,13 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
   }();
     return kShape;
   }
+
+  // Metadata that can be used to pretty-print profile counters.
+  static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
+    static const xla::HloProfilePrinterData* kHloProfilePrinterData =
+      nullptr;
+    return kHloProfilePrinterData;
+  }
 };
 
 }  // end namespace bar
index e17a7c4..31044ff 100644 (file)
@@ -110,6 +110,7 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
       flags.target_triple, flags.target_cpu, flags.target_features,
       flags.entry_point,
       xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
+
   return CompileXla(client, computation, aot_opts, compile_result);
 }
 
index 0048eec..63d22de 100644 (file)
@@ -36,9 +36,8 @@ namespace tfcompile {
 
 using xla::llvm_ir::AsStringRef;
 
-static std::unique_ptr<llvm::Module> CreateModuleWithEmbeddedProtocolBuffer(
-    llvm::LLVMContext* llvm_context, llvm::TargetMachine* target_machine,
-    const ::tensorflow::protobuf::MessageLite& proto,
+static void AddEmbeddedProtocolBufferToLlvmModule(
+    llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
     StringPiece unique_identifier, string* protobuf_array_symbol_name,
     int64* protobuf_array_size) {
   string protobuf_array_contents = proto.SerializeAsString();
@@ -46,19 +45,14 @@ static std::unique_ptr<llvm::Module> CreateModuleWithEmbeddedProtocolBuffer(
       strings::StrCat(unique_identifier, "_protobuf_array_contents");
   *protobuf_array_size = protobuf_array_contents.size();
 
-  std::unique_ptr<llvm::Module> module =
-      MakeUnique<llvm::Module>("embedded_data_module", *llvm_context);
-
   llvm::Constant* protobuf_array_initializer =
-      llvm::ConstantDataArray::getString(*llvm_context,
+      llvm::ConstantDataArray::getString(module->getContext(),
                                          AsStringRef(protobuf_array_contents),
                                          /*AddNull=*/false);
   new llvm::GlobalVariable(
       *module, protobuf_array_initializer->getType(),
       /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage,
       protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
-
-  return module;
 }
 
 static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
@@ -115,42 +109,44 @@ GetTargetMachineFromTriple(StringPiece target_triple) {
       /*Features=*/"", llvm::TargetOptions(), llvm::None));
 }
 
-StatusOr<EmbeddedProtocolBuffer> CreateEmbeddedProtocolBuffer(
-    StringPiece target_triple, StringPiece symbol_prefix,
-    StringPiece qualified_cpp_protobuf_name,
-    const ::tensorflow::protobuf::MessageLite* proto) {
+StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
+    StringPiece target_triple,
+    gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed) {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
                       GetTargetMachineFromTriple(target_triple));
 
   llvm::LLVMContext llvm_context;
-  string object_file, cpp_shim, cpp_variable_decl;
-
-  if (proto) {
-    string protobuf_array_symbol_name;
-    int64 protobuf_array_size;
-
-    std::unique_ptr<llvm::Module> module_with_serialized_proto =
-        CreateModuleWithEmbeddedProtocolBuffer(
-            &llvm_context, target_machine.get(), *proto, symbol_prefix,
-            &protobuf_array_symbol_name, &protobuf_array_size);
-    TF_ASSIGN_OR_RETURN(object_file,
-                        CodegenModule(target_machine.get(),
-                                      std::move(module_with_serialized_proto)));
-    cpp_shim = CreateCPPShimExpression(qualified_cpp_protobuf_name,
-                                       protobuf_array_symbol_name,
-                                       protobuf_array_size);
-
-    cpp_variable_decl = strings::StrCat("extern \"C\" char ",
-                                        protobuf_array_symbol_name, "[];");
-  } else {
-    TF_ASSIGN_OR_RETURN(
-        object_file,
-        CodegenModule(target_machine.get(),
-                      MakeUnique<llvm::Module>("empty_module", llvm_context)));
-    cpp_shim = "nullptr";
+  std::unique_ptr<llvm::Module> module_with_serialized_proto =
+      MakeUnique<llvm::Module>("embedded_data_module", llvm_context);
+
+  EmbeddedProtocolBuffers result;
+
+  for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) {
+    string cpp_shim, cpp_variable_decl;
+    if (protobuf_to_embed.message) {
+      string protobuf_array_symbol_name;
+      int64 protobuf_array_size;
+
+      AddEmbeddedProtocolBufferToLlvmModule(
+          module_with_serialized_proto.get(), *protobuf_to_embed.message,
+          protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name,
+          &protobuf_array_size);
+      cpp_shim = CreateCPPShimExpression(
+          protobuf_to_embed.qualified_cpp_protobuf_name,
+          protobuf_array_symbol_name, protobuf_array_size);
+
+      cpp_variable_decl = strings::StrCat("extern \"C\" char ",
+                                          protobuf_array_symbol_name, "[];");
+    } else {
+      cpp_shim = "nullptr";
+    }
+    result.cpp_shims.push_back({cpp_shim, cpp_variable_decl});
   }
 
-  return {{cpp_shim, cpp_variable_decl, object_file}};
+  TF_ASSIGN_OR_RETURN(result.object_file_data,
+                      CodegenModule(target_machine.get(),
+                                    std::move(module_with_serialized_proto)));
+  return result;
 }
 
 }  // namespace tfcompile
index 8436e0f..ebfe480 100644 (file)
@@ -21,51 +21,70 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
 
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/protobuf.h"
 
 namespace tensorflow {
 namespace tfcompile {
 using xla::StatusOr;
 
-// Represents a protocol buffer embedded into an object file and describes a way
-// to access it at runtime.
-struct EmbeddedProtocolBuffer {
-  // cpp_shim_expression is a C++ expression that creates an instance of said
-  // protocol buffer when executed.
-  string cpp_shim_expression;
-
-  // cpp_variable_decl is an "extern C" array declaration that is used in
-  // cpp_shim_expression.  It must be visible wherever cpp_shim_expression is
-  // emitted.
-  string cpp_variable_decl;
-
-  // The contents of the object (".o") file the protocol buffer is embbed in.
-  // This needs to be linked in to any program that wants to execute
-  // cpp_variable_decl .
+// Represents a set of protocol buffers embedded into an object file and
+// describes how to access them at runtime.
+struct EmbeddedProtocolBuffers {
+  // Each CPPShim instance describes how to generate C++ code to instantiate a
+  // protobuf instance from the corresponding static data emitted into the
+  // object file.
+  struct CPPShim {
+    // `expression` is a C++ expression that creates an instance of said
+    // protocol buffer when executed.
+    string expression;
+
+    // `variable_decl` is an "extern C" array declaration that is used in
+    // `expression`.  It must be visible wherever `expression` is emitted.
+    string variable_decl;
+  };
+
+  // Each cpp_shim corresponds to one embedded protocol buffer.
+  std::vector<CPPShim> cpp_shims;
+
+  // The contents of the object (".o") file the protocol buffers are embbed in.
+  // This needs to be linked in to any program that wants to execute any of the
+  // expressions in `cpp_shims`.
   string object_file_data;
 };
 
-// Creates an object file that contains `proto`.
-//
-// `proto` is allowed to be nullptr, in which case the generated C++ shim
-// expression is just `nullptr`, and the generated object file does not define
-// any symbols.
+// Describes a protocol buffer to embed into an object file.
+struct ProtobufToEmbed {
+  // `symbol_prefix` is a prefix that is guaranteed to be unique across the binary
+  // or DSO the generated object file will be linked into.
+  string symbol_prefix;
+
+  // `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++
+  // namespace qualified) protocol buffer name.  This is only used in
+  // CPPShim::expression so relatively qualified names are fine as long as
+  // they're valid wherever CPPShim::expression is emitted.
+  string qualified_cpp_protobuf_name;
+
+  // `message` is the protocol buffer to be embedded.  It is allowed to be
+  // nullptr, in which case the generated C++ shim expression is just `nullptr`,
+  // and the generated object file does not define any symbols.
+  const ::tensorflow::protobuf::MessageLite* message;
+};
+
+// Embeds a sequence of protocol buffers into an object file.
 //
 // `target_triple` is the target triple for the target architecture for the
 // generated object file.
 //
-// `symbol_prefix` is prefix that is guaranteed to be unique across the binary
-// or DSO the generated object file will be linked into.
-//
-// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++
-// namespace qualified) protocol buffer name.  This needs is only used in
-// EmbeddedProtocolBuffer::cpp_shim_expression so relatively qualified
-// names are fine as long as they're valid wherever cpp_shim_expression
-// is emitted.
-StatusOr<EmbeddedProtocolBuffer> CreateEmbeddedProtocolBuffer(
-    StringPiece target_triple, StringPiece symbol_prefix,
-    StringPiece qualified_cpp_protobuf_name,
-    const ::tensorflow::protobuf::MessageLite* proto);
+// `protobufs_to_embed` describes the protocol buffers to embed into the
+// resulting object file.  The C++ shim for protobufs_to_embed[i] is
+// cpp_shims[i] in the returned EmbeddedProtocolBuffers instance.  The contents
+// of all the protocol buffers are embedded into a single .o file whose content
+// is stored in the object_file_data field in the returned
+// EmbeddedProtocolBuffers instance.
+StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
+    StringPiece target_triple,
+    gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed);
 
 }  // namespace tfcompile
 }  // namespace tensorflow
index bb73cb1..222e268 100644 (file)
@@ -164,6 +164,15 @@ tf_library(
 )
 
 tf_library(
+    name = "test_graph_tfmatmulandadd_with_profiling",
+    testonly = 1,
+    config = "test_graph_tfmatmulandadd.config.pbtxt",
+    cpp_class = "MatMulAndAddCompWithProfiling",
+    enable_xla_hlo_profiling = True,
+    graph = "test_graph_tfmatmulandadd.pb",
+)
+
+tf_library(
     name = "test_graph_tfsplits",
     testonly = 1,
     config = "test_graph_tfsplits.config.pbtxt",
@@ -189,9 +198,13 @@ tf_cc_test(
         ":test_graph_tfgather",
         ":test_graph_tfmatmul",
         ":test_graph_tfmatmulandadd",
+        ":test_graph_tfmatmulandadd_with_profiling",
         ":test_graph_tfsplits",
         "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo_profile_printer",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//third_party/eigen3",
index 67dbd64..aa9d968 100644 (file)
@@ -25,15 +25,22 @@ limitations under the License.
 #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
 namespace tfcompile {
 namespace {
 
+using ::testing::HasSubstr;
+using ::testing::UnorderedElementsAre;
+
 TEST(TFCompileTest, Add) {
   AddComp add;
   EXPECT_EQ(add.arg0_data(), add.args()[0]);
@@ -484,6 +491,59 @@ TEST(TFCompileTest, ProgramShape) {
   EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2));
 }
 
+TEST(TFCompileTest, HloProfiling) {
+  Eigen::ThreadPool tp(1);
+  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+  MatMulAndAddCompWithProfiling fn;
+  ASSERT_TRUE(fn.hlo_profiling_enabled());
+
+  fn.set_thread_pool(&device);
+
+  // x = [[1, 2], [3, 4]]
+  fn.arg0(0, 0) = 1;
+  fn.arg0(0, 1) = 2;
+  fn.arg0(1, 0) = 3;
+  fn.arg0(1, 1) = 4;
+
+  // y = [[10, 20], [30, 40]]
+  fn.arg1(0, 0) = 10;
+  fn.arg1(0, 1) = 20;
+  fn.arg1(1, 0) = 30;
+  fn.arg1(1, 1) = 40;
+
+  EXPECT_TRUE(fn.Run());
+
+  string hlo_profile_as_string =
+      xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
+                           /*clock_rate_ghz=*/1.0);
+  VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
+
+  std::vector<string> hlo_profile_lines =
+      tensorflow::str_util::Split(hlo_profile_as_string, '\n');
+
+  auto header = HasSubstr("Execution profile for");
+  auto total_cycles_profile_line = HasSubstr("[total]");
+  auto dot_profile_line = HasSubstr(
+      "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+  auto add_profile_line = HasSubstr(
+      "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+  auto tuple_profile_line = HasSubstr(
+      "%tuple.2 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
+      "f32[2,2]{1,0} %add)");
+  auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
+  auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
+
+  hlo_profile_lines.erase(hlo_profile_lines.begin() + 7,
+                          hlo_profile_lines.end());
+
+  EXPECT_THAT(
+      hlo_profile_lines,
+      UnorderedElementsAre(header, total_cycles_profile_line, dot_profile_line,
+                           add_profile_line, tuple_profile_line,
+                           arg0_profile_line, arg1_profile_line));
+}
+
 }  // namespace
 }  // namespace tfcompile
 }  // namespace tensorflow
index 3a877c5..5c57fee 100644 (file)
@@ -25,7 +25,8 @@ def tf_library(name, graph, config,
                visibility=None, testonly=None,
                tfcompile_flags=None,
                tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
-               include_standard_runtime_deps=True, deps=None, tags=None):
+               include_standard_runtime_deps=True,
+               enable_xla_hlo_profiling=False, deps=None, tags=None):
   """Runs tfcompile to compile a TensorFlow graph into executable code.
 
   Given an invocation of tf_library(name="foo", ...), generates the following
@@ -68,6 +69,8 @@ def tf_library(name, graph, config,
     include_standard_runtime_deps: If True, the standard list of kernel/runtime
       deps is added to deps.  If False, deps must contain the full set of deps
       needed by the generated library.
+    enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program,
+      and emit metadata that lets us pretty-print the gathered profile counters.
     deps: a list of deps to include on the build rules for the generated
       library, added to the standard deps if standard_runtime_deps is True.
     tags: tags to apply to subsidiary build rules.
@@ -137,6 +140,10 @@ def tf_library(name, graph, config,
     flags = tfcompile_flags
   else:
     flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
+  if enable_xla_hlo_profiling:
+    profiling_flag = "--xla_hlo_profile"
+  else:
+    profiling_flag = ""
   native.genrule(
       name=("gen_" + name),
       srcs=[
@@ -157,7 +164,7 @@ def tf_library(name, graph, config,
            " --out_header=$(@D)/" + header_file +
            " --out_metadata_object=$(@D)/" + metadata_object_file +
            " --out_function_object=$(@D)/" + function_object_file +
-           " " + flags),
+           " " + flags + " " + profiling_flag),
       tools=[tfcompile_tool],
       visibility=visibility,
       testonly=testonly,
@@ -220,6 +227,8 @@ def tf_library(name, graph, config,
       ] + (need_xla_data_proto and [
           # If we're generating the program shape, we must depend on the proto.
           "//tensorflow/compiler/xla:xla_data_proto",
+      ] or []) + (enable_xla_hlo_profiling and [
+          "//tensorflow/compiler/xla/service:hlo_profile_printer_data"
       ] or []) + (include_standard_runtime_deps and [
           # TODO(cwhipkey): only depend on kernel code that the model actually needed.
           "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
index 8ea014c..839e158 100644 (file)
@@ -100,6 +100,8 @@ Status Main(const MainFlags& flags) {
   if (flags.cpp_class.empty()) {
     return errors::InvalidArgument("Must specify --cpp_class");
   }
+  codegen_opts.gen_hlo_profile_printer_data =
+      xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile();
   TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
                                    &codegen_opts.namespaces));
 
index 150c12e..ec2bb6c 100644 (file)
@@ -118,10 +118,12 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
 
 CpuAotCompilationResult::CpuAotCompilationResult(
     ObjectFileData object_file_data, BufferSizes buffer_sizes,
-    int64 result_buffer_index)
+    int64 result_buffer_index,
+    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
     : object_file_data_(std::move(object_file_data)),
       buffer_sizes_(std::move(buffer_sizes)),
-      result_buffer_index_(result_buffer_index) {}
+      result_buffer_index_(result_buffer_index),
+      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}
 
 CpuAotCompilationResult::~CpuAotCompilationResult() = default;
 
@@ -171,14 +173,13 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
  public:
   static StatusOr<std::unordered_map<const HloInstruction*, int64>>
   GetCandidatesForComputation(
-      HloComputation* computation,
+      const HloComputation& computation,
       const std::unordered_map<const HloInstruction*, int64>&
           assigned_indices) {
     std::unordered_map<const HloInstruction*, int64> hlo_to_profile_idx;
     CollectProfileCandidates profile_candidates_for_computation(
         &hlo_to_profile_idx, assigned_indices);
-    TF_RETURN_IF_ERROR(
-        computation->Accept(&profile_candidates_for_computation));
+    TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation));
     return hlo_to_profile_idx;
   }
 
@@ -424,6 +425,41 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) {
   return Status::OK();
 }
 
+Status CreateHloProfilingArtifacts(
+    const HloModule& module,
+    std::unordered_map<const HloInstruction*, int64>*
+        instruction_to_profile_idx,
+    std::unordered_map<const HloComputation*, int64>*
+        computation_to_profile_idx,
+    std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
+    std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
+  *hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(module);
+  const HloComputation& entry_computation = *module.entry_computation();
+
+  TF_ASSIGN_OR_RETURN(
+      *instruction_to_profile_idx,
+      CollectProfileCandidates::GetCandidatesForComputation(
+          entry_computation,
+          (*hlo_profile_index_map)->instruction_to_profile_idx()));
+
+  auto shape_size_bytes = [](const Shape& shape) {
+    // On the cpu, opaques are pointers.
+    if (ShapeUtil::IsOpaque(shape)) {
+      return static_cast<int64>(sizeof(void*));
+    }
+    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+  };
+
+  HloCostAnalysis cost_analysis(shape_size_bytes);
+  TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
+  *hlo_profile_printer_data =
+      CreateHloProfilePrinterData(**hlo_profile_index_map, cost_analysis);
+  *computation_to_profile_idx =
+      (*hlo_profile_index_map)->computation_to_profile_idx();
+
+  return Status::OK();
+}
+
 }  // namespace
 
 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
@@ -478,28 +514,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
   std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
   std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
   if (module->config().hlo_profiling_enabled()) {
-    hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
-
-    TF_ASSIGN_OR_RETURN(
-        instruction_to_profile_idx,
-        CollectProfileCandidates::GetCandidatesForComputation(
-            entry_computation,
-            hlo_profile_index_map->instruction_to_profile_idx()));
-
-    auto shape_size_bytes = [](const Shape& shape) {
-      // On the cpu, opaques are pointers.
-      if (ShapeUtil::IsOpaque(shape)) {
-        return static_cast<int64>(sizeof(void*));
-      }
-      return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
-    };
-
-    HloCostAnalysis cost_analysis(shape_size_bytes);
-    TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis));
-    hlo_profile_printer_data =
-        CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis);
-    computation_to_profile_idx =
-        hlo_profile_index_map->computation_to_profile_idx();
+    TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
+        *module, &instruction_to_profile_idx, &computation_to_profile_idx,
+        &hlo_profile_index_map, &hlo_profile_printer_data));
   }
 
   std::unique_ptr<Executable> cpu_executable;
@@ -715,11 +732,20 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
           proto, xla_dump_optimized_hlo_proto_to, module->name()));
     }
 
+    std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
+    std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
+    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
+    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
+
+    if (module->config().hlo_profiling_enabled()) {
+      TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
+          *module, &instruction_to_profile_idx, &computation_to_profile_idx,
+          &hlo_profile_index_map, &hlo_profile_printer_data));
+    }
+
     IrEmitter ir_emitter(*module, *assignment, &llvm_module,
-                         /*instruction_to_profile_idx=*/
-                         std::unordered_map<const HloInstruction*, int64>{},
-                         /*computation_to_profile_idx=*/
-                         std::unordered_map<const HloComputation*, int64>{},
+                         std::move(instruction_to_profile_idx),
+                         std::move(computation_to_profile_idx),
                          target_machine.get(),
                          /*external_constant_pool=*/nullptr);
     HloComputation* computation = module->entry_computation();
@@ -794,7 +820,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
 
     results.emplace_back(MakeUnique<CpuAotCompilationResult>(
         std::move(object_file_data), std::move(buffer_sizes),
-        result_slice.index()));
+        result_slice.index(), std::move(hlo_profile_printer_data)));
   }
 
   VLOG(1) << "Compilation finished";
index 151af38..65b05f0 100644 (file)
@@ -76,10 +76,16 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
 
 class CpuAotCompilationResult : public AotCompilationResult {
  public:
-  CpuAotCompilationResult(ObjectFileData object_file_data,
-                          BufferSizes buffer_sizes, int64 result_buffer_index);
+  CpuAotCompilationResult(
+      ObjectFileData object_file_data, BufferSizes buffer_sizes,
+      int64 result_buffer_index,
+      std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data);
   ~CpuAotCompilationResult();
 
+  HloProfilePrinterData* hlo_profile_printer_data() const {
+    return hlo_profile_printer_data_.get();
+  }
+
   const ObjectFileData& object_file_data() const { return object_file_data_; }
   const BufferSizes& buffer_sizes() const { return buffer_sizes_; }
   int64 result_buffer_index() const { return result_buffer_index_; }
@@ -97,6 +103,10 @@ class CpuAotCompilationResult : public AotCompilationResult {
   // result of the computation.  This buffer should be passed into the output
   // parameter when calling the compiled computation.
   const int64 result_buffer_index_;
+
+  // Contains an instance of HloProfilePrinterData if HLO profiling is enabled,
+  // otherwise is nullptr.
+  std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data_;
 };
 
 // CPU-targeting implementation of the XLA Compiler interface.