[TF:XLA] Improve readability of HLO graphs when rendered via Tensorboard.
authorPeter Hawkins <phawkins@google.com>
Thu, 22 Feb 2018 19:34:16 +0000 (11:34 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Feb 2018 19:38:14 +0000 (11:38 -0800)
Add operator metadata around the computation arguments and retvals, so they are grouped together.
Teach the batchnorm expander pass to propagate the operator metadata from the original batch norm operators.

PiperOrigin-RevId: 186648547

tensorflow/compiler/tf2xla/xla_compiler.cc
tensorflow/compiler/xla/service/batchnorm_expander.cc

index 15bba46..5ec05c4 100644 (file)
@@ -365,6 +365,13 @@ Status BuildComputation(
               return a->arg_num() < b->arg_num();
             });
 
+  // Attach a common operator name as metadata. This has no semantic effect — it
+  // merely makes the HLO graph more readable when visualized via TensorBoard,
+  // since TensorBoard forms groups out of operators with similar names.
+  xla::OpMetadata retval_metadata;
+  retval_metadata.set_op_name("XLA_Retvals");
+  builder->SetOpMetadata(retval_metadata);
+
   for (const XlaResource* resource : arg_resources) {
     const XlaCompiler::Argument& arg = args[resource->arg_num()];
     const int core = arg_cores[resource->arg_num()];
@@ -412,6 +419,8 @@ Status BuildComputation(
 
   // Builds the XLA computation.
   builder->Tuple(elems);
+  builder->ClearOpMetadata();
+
   xla::StatusOr<xla::Computation> computation_status = builder->Build();
   if (!computation_status.ok()) {
     return computation_status.status();
@@ -514,6 +523,13 @@ Status XlaCompiler::BuildArguments(
     }
   }
 
+  // Attach a common operator name as metadata. This has no semantic effect — it
+  // merely makes the HLO graph more readable when visualized via TensorBoard,
+  // since TensorBoard forms groups out of operators with similar names.
+  xla::OpMetadata arg_metadata;
+  arg_metadata.set_op_name("XLA_Args");
+  builder->SetOpMetadata(arg_metadata);
+
   // Build parameter handles for non-constant arguments.
   std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
   if (use_tuple_arg) {
@@ -552,6 +568,8 @@ Status XlaCompiler::BuildArguments(
     }
   }
 
+  builder->ClearOpMetadata();
+
   // Fill in the handles in non-constant arguments.
   VLOG(2) << "XLA computation inputs:";
   for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
index 27ddfd4..84c9db3 100644 (file)
@@ -153,6 +153,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
   std::vector<HloInstruction*> added_instructions;
   auto add = [&](std::unique_ptr<HloInstruction> inst) {
     HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
+    added_inst->set_metadata(batch_norm->metadata());
     added_instructions.push_back(added_inst);
     return added_inst;
   };
@@ -334,6 +335,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
   std::vector<HloInstruction*> added_instructions;
   auto add = [&](std::unique_ptr<HloInstruction> inst) {
     HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
+    added_inst->set_metadata(batch_norm->metadata());
     added_instructions.push_back(added_inst);
     return added_inst;
   };
@@ -419,6 +421,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
   std::vector<HloInstruction*> added_instructions;
   auto add = [&](std::unique_ptr<HloInstruction> inst) {
     HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
+    added_inst->set_metadata(batch_norm->metadata());
     added_instructions.push_back(added_inst);
     return added_inst;
   };