From 49eb5240c2270502d2ff4426b0ce80de91ab27f0 Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Thu, 22 Feb 2018 11:34:16 -0800
Subject: [PATCH] [TF:XLA] Improve readability of HLO graphs when rendered via
 Tensorboard.

Add operator metadata around the computation arguments and retvals, so
they are grouped together.

Teach the batchnorm expander pass to propagate the operator metadata
from the original batch norm operators.

PiperOrigin-RevId: 186648547
---
 tensorflow/compiler/tf2xla/xla_compiler.cc            | 18 ++++++++++++++++++
 tensorflow/compiler/xla/service/batchnorm_expander.cc |  3 +++
 2 files changed, 21 insertions(+)

diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 15bba46..5ec05c4 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -365,6 +365,13 @@ Status BuildComputation(
               return a->arg_num() < b->arg_num();
             });
 
+  // Attach a common operator name as metadata. This has no semantic effect — it
+  // merely makes the HLO graph more readable when visualized via TensorBoard,
+  // since TensorBoard forms groups out of operators with similar names.
+  xla::OpMetadata retval_metadata;
+  retval_metadata.set_op_name("XLA_Retvals");
+  builder->SetOpMetadata(retval_metadata);
+
   for (const XlaResource* resource : arg_resources) {
     const XlaCompiler::Argument& arg = args[resource->arg_num()];
     const int core = arg_cores[resource->arg_num()];
@@ -412,6 +419,8 @@ Status BuildComputation(
   // Builds the XLA computation.
   builder->Tuple(elems);
+  builder->ClearOpMetadata();
+
   xla::StatusOr<xla::Computation> computation_status = builder->Build();
   if (!computation_status.ok()) {
     return computation_status.status();
   }
@@ -514,6 +523,13 @@ Status XlaCompiler::BuildArguments(
     }
   }
 
+  // Attach a common operator name as metadata. This has no semantic effect — it
+  // merely makes the HLO graph more readable when visualized via TensorBoard,
+  // since TensorBoard forms groups out of operators with similar names.
+  xla::OpMetadata arg_metadata;
+  arg_metadata.set_op_name("XLA_Args");
+  builder->SetOpMetadata(arg_metadata);
+
   // Build parameter handles for non-constant arguments.
   std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
   if (use_tuple_arg) {
@@ -552,6 +568,8 @@ Status XlaCompiler::BuildArguments(
     }
   }
 
+  builder->ClearOpMetadata();
+
   // Fill in the handles in non-constant arguments.
   VLOG(2) << "XLA computation inputs:";
   for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index 27ddfd4..84c9db3 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -153,6 +153,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
   std::vector<HloInstruction*> added_instructions;
   auto add = [&](std::unique_ptr<HloInstruction> inst) {
     HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
+    added_inst->set_metadata(batch_norm->metadata());
     added_instructions.push_back(added_inst);
     return added_inst;
   };
@@ -334,6 +335,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
   std::vector<HloInstruction*> added_instructions;
   auto add = [&](std::unique_ptr<HloInstruction> inst) {
     HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
+    added_inst->set_metadata(batch_norm->metadata());
     added_instructions.push_back(added_inst);
     return added_inst;
   };
@@ -419,6 +421,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
   std::vector<HloInstruction*> added_instructions;
   auto add = [&](std::unique_ptr<HloInstruction> inst) {
     HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
+    added_inst->set_metadata(batch_norm->metadata());
     added_instructions.push_back(added_inst);
     return added_inst;
   };
-- 
2.7.4