From 49ee96a60bea1b595cff3cb550cfc8d2ade5ed8b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 20 Mar 2018 16:13:58 -0700 Subject: [PATCH] [XLA] Use IDs instead of names to represent the edges of HLO graph in hlo.proto. PiperOrigin-RevId: 189831057 --- .../compiler/xla/client/xla_client/xla_builder.cc | 17 ++++--- tensorflow/compiler/xla/service/hlo.proto | 34 ++++++++------ tensorflow/compiler/xla/service/hlo_computation.cc | 27 +++++------ tensorflow/compiler/xla/service/hlo_computation.h | 18 +++++--- tensorflow/compiler/xla/service/hlo_instruction.cc | 54 ++++++++++++---------- tensorflow/compiler/xla/service/hlo_instruction.h | 8 ++-- tensorflow/compiler/xla/service/hlo_module.cc | 18 ++++++-- 7 files changed, 104 insertions(+), 72 deletions(-) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 6328a4f..8829fc6 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -99,16 +99,17 @@ StatusOr XlaBuilder::Build() { // Not all instructions can be roots. Walk backwards from the last added // instruction until a valid root is found. + entry.set_root_id(-1); for (int64 i = instructions_.size() - 1; i >= 0; i--) { TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(instructions_[i].opcode())); if (CanBeRoot(opcode)) { - entry.set_root_name(instructions_[i].name()); + entry.set_root_id(instructions_[i].id()); *program_shape->mutable_result() = instructions_[i].shape(); break; } } - if (entry.root_name().empty()) { + if (entry.root_id() == -1) { return FailedPrecondition("no root instruction was found"); } @@ -141,7 +142,9 @@ StatusOr XlaBuilder::Build() { XlaComputation computation(id); HloModuleProto* module = computation.mutable_proto(); module->set_name(entry.name()); + module->set_id(entry.id()); module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); *module->mutable_program_shape() = entry.program_shape(); for (auto& e : embedded_) { module->add_computations()->Swap(&e.second); @@ -162,8 +165,8 @@ XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, lhs_instr->shape(), rhs_instr->shape(), broadcast_dimensions)); - instr.add_operand_names(lhs_instr->name()); - instr.add_operand_names(rhs_instr->name()); + instr.add_operand_ids(lhs_instr->id()); + instr.add_operand_ids(rhs_instr->id()); return AddInstruction(std::move(instr)); }; return NoteErrorOrReturn(op()); @@ -195,11 +198,12 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, // Add input operands. for (const auto& operand : operands) { TF_ASSIGN_OR_RETURN(auto operand_instr, LookUpInstruction(operand)); - instr.add_operand_names(operand_instr->name()); + instr.add_operand_ids(operand_instr->id()); } // Add called computation. - *instr.add_called_computation_names() = computation.proto().name(); + instr.add_called_computation_ids( + computation.proto().entry_computation_id()); for (const HloComputationProto& e : computation.proto().computations()) { embedded_.insert({e.id(), e}); } @@ -229,6 +233,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) { const int64 handle = instructions_.size(); + instr.set_id(handle); if (instr.name().empty()) { instr.set_name(StrCat(instr.opcode(), ".", handle)); } else { diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index b86fbd8..406fead 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// DO NOT USE THESE PROTO MESSAGES FOR ANYTHING OTHER THAN DEBUGGING. -// -// Don't use these protos in the real compilation or execution codepaths. The -// data format is meant for debugging only, and may change without notice. +// This proto file defines messages which represent the HLO module. This is a +// full fidelity serialization of the c++ HLO constructs. // // Many of the protos below are simple 1-to-1 serializations of the -// corresponding C++ classes. +// corresponding C++ classes, e.g., HloModule, HloComputation, and +// HloInstruction. // // FIELD NAMES ARE IMPORTANT // @@ -40,16 +39,17 @@ message HloInstructionProto { reserved "parameter_name"; reserved 12; reserved "fused_instructions_computation"; + reserved 4; + reserved "operand_names"; + reserved 5; + reserved "control_predecessor_names"; + reserved 6; + reserved "called_computation_names"; string name = 1; string opcode = 2; xla.Shape shape = 3; - // TODO(b/67782397): Replace instruction names with HloInstruction ids. - repeated string operand_names = 4; - repeated string control_predecessor_names = 5; - repeated string called_computation_names = 6; - xla.OpMetadata metadata = 7; // Literal, only present for kConstant. @@ -137,30 +137,38 @@ message HloInstructionProto { // The id of this instruction. int64 id = 35; + + repeated int64 operand_ids = 36; + repeated int64 control_predecessor_ids = 37; + repeated int64 called_computation_ids = 38; } // Serialization of HloComputation. message HloComputationProto { + reserved 3; + reserved "root_name"; + string name = 1; // The array of instructions is always in a valid dependency order, where // operands appear before their users. repeated HloInstructionProto instructions = 2; - // The name of the root of the computation. - string root_name = 3; - // The program shape (with layout) of this computation. xla.ProgramShape program_shape = 4; // The id of this computation. int64 id = 5; + + // The id of the root of the computation. + int64 root_id = 6; } // Serialization of HloModule. message HloModuleProto { string name = 1; string entry_computation_name = 2; + int64 entry_computation_id = 6; // The array of computations is always in a valid dependency order, where // callees appear before their callers. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 4e85219..6f983d0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -65,6 +65,7 @@ HloComputation::HloComputation( std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) : name_(name), + unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { param_instructions_.resize(parameter_count, nullptr); @@ -101,7 +102,7 @@ HloInstruction* HloComputation::AddInstructionInternal( instruction->UniquifyName(&parent()->instruction_name_uniquer()); instruction->SetUniqueId(parent()->NewUniqueInstructionId()); } - Reparent(instruction.get()); + instruction->set_parent(this); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = instructions_.insert(instructions_.end(), std::move(instruction)); @@ -158,10 +159,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } -void HloComputation::Reparent(HloInstruction* instruction) { - instruction->set_parent(this); -} - bool HloComputation::IsRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for @@ -393,12 +390,16 @@ string HloComputation::ToString(const HloPrintOptions& options) const { HloComputationProto HloComputation::ToProto() const { HloComputationProto proto; + CHECK(unique_id_ != -1) + << "This computation does not have a valid id. Please make sure the " + "computation is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); for (const HloInstruction* instruction : MakeInstructionPostOrder()) { HloInstructionProto instruction_proto = instruction->ToProto(); proto.add_instructions()->Swap(&instruction_proto); } - proto.set_root_name(root_instruction()->name()); + proto.set_root_id(root_instruction()->unique_id()); *proto.mutable_program_shape() = ComputeProgramShape(); return proto; } @@ -406,9 +407,9 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map) { + const tensorflow::gtl::FlatMap& computation_map) { std::vector> instructions; - tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { TF_ASSIGN_OR_RETURN( @@ -418,14 +419,14 @@ HloComputation::CreateFromProto( if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } - TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); - instruction_map[instruction->name()] = instruction.get(); + TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); + instruction_map[instruction_proto.id()] = instruction.get(); instructions.push_back(std::move(instruction)); } - TF_RET_CHECK(!proto.root_name().empty()); - TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); - HloInstruction* root = instruction_map.at(proto.root_name()); + TF_RET_CHECK(proto.root_id() != -1); + TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); + HloInstruction* root = instruction_map.at(proto.root_id()); return WrapUnique(new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 630d367..9d3f6e9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -160,12 +160,12 @@ class HloComputation { // module: the module which will contain the computation. The newly created // computation is *not* added to the module, however. // proto: the proto to convert from. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. static StatusOr> CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map); + const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. // @@ -334,6 +334,15 @@ class HloComputation { fusion_instruction_ = fusion_instruction; } + // The id of this computation should be unique within the module. + void SetUniqueId(int64 id) { + CHECK_EQ(unique_id_, -1); + CHECK_GE(id, 0); + unique_id_ = id; + } + + int64 unique_id() const { return unique_id_; } + private: explicit HloComputation( const string& name, int parameter_count, @@ -344,10 +353,6 @@ class HloComputation { HloInstruction* AddInstructionInternal( std::unique_ptr instruction); - // Helper for setting the parent of instructions that are added to this - // computation. - void Reparent(HloInstruction* instruction); - // Fuses HLOs in instructions_to_fuse into fusion_instruction. // // Pre-condition: fusion_instruction's opcode is kFusion. @@ -365,6 +370,7 @@ class HloComputation { std::vector CollectUnreachableRoots() const; string name_; + int64 unique_id_; HloInstruction* root_instruction_; // If this computation is a fusion computation, this field points to the diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 83fcc5d..a2a2c1e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -52,22 +52,22 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map) { + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); - for (const string& operand_name : proto.operand_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_name)) - << "No instruction named " << operand_name; - instruction->AppendOperand(instruction_map.at(operand_name)); - } - for (const string& predecessor_name : proto.control_predecessor_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name)) - << "No instruction named " << predecessor_name; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name) + for (const int64 operand_id : proto.operand_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) + << "No instruction with id " << operand_id; + instruction->AppendOperand(instruction_map.at(operand_id)); + } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) ->AddControlDependencyTo(instruction.get())); } @@ -80,21 +80,21 @@ StatusOr> HloInstruction::CreateFromProto( StringToFusionKind(proto.fusion_kind())); // Find the fused computation and set its fusion instruction. - TF_RET_CHECK(proto.called_computation_names_size() == 1) + TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Expect 1 called computation for fusion instruction, but sees " - << proto.called_computation_names_size(); - const string& fusion_name = proto.called_computation_names(0); - auto* fused_computation = FindPtrOrNull(computation_map, fusion_name); + << proto.called_computation_ids_size(); + const int64 fusion_id = proto.called_computation_ids(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) - << "No fusion computation named " << fusion_name; + << "No fusion computation with id " << fusion_id; fused_computation->SetFusionInstruction(instruction.get()); instruction->called_computations_.push_back(fused_computation); } else { - for (const string& computation_name : proto.called_computation_names()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_name)) - << "No computation named " << computation_name; + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; instruction->called_computations_.push_back( - computation_map.at(computation_name)); + computation_map.at(computation_id)); } } @@ -2315,14 +2315,18 @@ string HloInstruction::ToShortString() const { HloInstructionProto HloInstruction::ToProto() const { HloInstructionProto proto; + CHECK(unique_id_ != -1) + << "This instruction does not have a valid id. Please make sure the " + "instruction is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_opcode(HloOpcodeString(opcode_)); *proto.mutable_shape() = shape_; for (const HloInstruction* operand : operands_) { - *proto.add_operand_names() = operand->name(); + proto.add_operand_ids(operand->unique_id()); } for (const HloInstruction* control : control_predecessors_) { - *proto.add_control_predecessor_names() = control->name(); + proto.add_control_predecessor_ids(control->unique_id()); } *proto.mutable_metadata() = metadata_; @@ -2332,11 +2336,11 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_parameter_number(parameter_number_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); - *proto.add_called_computation_names() = - fused_instructions_computation()->name(); + proto.add_called_computation_ids( + fused_instructions_computation()->unique_id()); } else { for (const HloComputation* computation : called_computations_) { - *proto.add_called_computation_names() = computation->name(); + proto.add_called_computation_ids(computation->unique_id()); } } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a111e1e..a94ba14 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -179,15 +179,15 @@ class HloInstruction { // module: the module which will contain the instruction. The newly created // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. - // instruction_map: a map from instruction name to HloInstruction*. This map + // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed instruction // calls. static StatusOr> CreateFromProto( HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map); + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 4091ebb..2037764 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -83,6 +83,11 @@ HloComputation* HloModule::AddComputationInternal( for (auto* instruction : computation->instructions()) { instruction->SetUniqueId(NewUniqueInstructionId()); } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); + computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -204,8 +209,10 @@ string HloModule::ToString(const HloPrintOptions& options) const { HloModuleProto HloModule::ToProto() const { HloModuleProto proto; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_entry_computation_name(entry_computation_->name()); + proto.set_entry_computation_id(entry_computation_->unique_id()); for (const HloComputation* computation : MakeComputationPostOrder()) { HloComputationProto computation_proto = computation->ToProto(); if (computation->name() == entry_computation_->name()) { @@ -249,19 +256,20 @@ StatusOr> HloModule::CreateFromProto( auto module = MakeUnique(proto.name(), entry_computation_handle, module_config); - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap computation_map; for (const HloComputationProto& computation_proto : proto.computations()) { TF_ASSIGN_OR_RETURN(std::unique_ptr computation, HloComputation::CreateFromProto( module.get(), computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); - TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); - string computation_name = computation->name(); + int64 computation_id = computation_proto.id(); + TF_RET_CHECK(computation_id != -1); + TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); // Don't uniquify names because we want names to be stable across // serialization and deserialization. - computation_map[computation_name] = module->AddComputationInternal( + computation_map[computation_id] = module->AddComputationInternal( std::move(computation), - /*is_entry=*/proto.entry_computation_name() == computation_name, + /*is_entry=*/proto.entry_computation_id() == computation_id, /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); -- 2.7.4