// Not all instructions can be roots. Walk backwards from the last added
// instruction until a valid root is found.
+ entry.set_root_id(-1);
for (int64 i = instructions_.size() - 1; i >= 0; i--) {
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
StringToHloOpcode(instructions_[i].opcode()));
if (CanBeRoot(opcode)) {
- entry.set_root_name(instructions_[i].name());
+ entry.set_root_id(instructions_[i].id());
*program_shape->mutable_result() = instructions_[i].shape();
break;
}
}
- if (entry.root_name().empty()) {
+ if (entry.root_id() == -1) {
return FailedPrecondition("no root instruction was found");
}
XlaComputation computation(id);
HloModuleProto* module = computation.mutable_proto();
module->set_name(entry.name());
+ module->set_id(entry.id());
module->set_entry_computation_name(entry.name());
+ module->set_entry_computation_id(entry.id());
*module->mutable_program_shape() = entry.program_shape();
for (auto& e : embedded_) {
module->add_computations()->Swap(&e.second);
ShapeInference::InferBinaryOpShape(
HloOpcode::kAdd, lhs_instr->shape(),
rhs_instr->shape(), broadcast_dimensions));
- instr.add_operand_names(lhs_instr->name());
- instr.add_operand_names(rhs_instr->name());
+ instr.add_operand_ids(lhs_instr->id());
+ instr.add_operand_ids(rhs_instr->id());
return AddInstruction(std::move(instr));
};
return NoteErrorOrReturn(op());
// Add input operands.
for (const auto& operand : operands) {
TF_ASSIGN_OR_RETURN(auto operand_instr, LookUpInstruction(operand));
- instr.add_operand_names(operand_instr->name());
+ instr.add_operand_ids(operand_instr->id());
}
// Add called computation.
- *instr.add_called_computation_names() = computation.proto().name();
+ instr.add_called_computation_ids(
+ computation.proto().entry_computation_id());
for (const HloComputationProto& e : computation.proto().computations()) {
embedded_.insert({e.id(), e});
}
XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) {
const int64 handle = instructions_.size();
+ instr.set_id(handle);
if (instr.name().empty()) {
instr.set_name(StrCat(instr.opcode(), ".", handle));
} else {
limitations under the License.
==============================================================================*/
-// DO NOT USE THESE PROTO MESSAGES FOR ANYTHING OTHER THAN DEBUGGING.
-//
-// Don't use these protos in the real compilation or execution codepaths. The
-// data format is meant for debugging only, and may change without notice.
+// This proto file defines messages that represent the HLO module. This is a
+// full-fidelity serialization of the C++ HLO constructs.
//
// Many of the protos below are simple 1-to-1 serializations of the
-// corresponding C++ classes.
+// corresponding C++ classes, e.g., HloModule, HloComputation, and
+// HloInstruction.
//
// FIELD NAMES ARE IMPORTANT
//
reserved "parameter_name";
reserved 12;
reserved "fused_instructions_computation";
+ reserved 4;
+ reserved "operand_names";
+ reserved 5;
+ reserved "control_predecessor_names";
+ reserved 6;
+ reserved "called_computation_names";
string name = 1;
string opcode = 2;
xla.Shape shape = 3;
- // TODO(b/67782397): Replace instruction names with HloInstruction ids.
- repeated string operand_names = 4;
- repeated string control_predecessor_names = 5;
- repeated string called_computation_names = 6;
-
xla.OpMetadata metadata = 7;
// Literal, only present for kConstant.
// The id of this instruction.
int64 id = 35;
+
+ repeated int64 operand_ids = 36;
+ repeated int64 control_predecessor_ids = 37;
+ repeated int64 called_computation_ids = 38;
}
// Serialization of HloComputation.
message HloComputationProto {
+ reserved 3;
+ reserved "root_name";
+
string name = 1;
// The array of instructions is always in a valid dependency order, where
// operands appear before their users.
repeated HloInstructionProto instructions = 2;
- // The name of the root of the computation.
- string root_name = 3;
-
// The program shape (with layout) of this computation.
xla.ProgramShape program_shape = 4;
// The id of this computation.
int64 id = 5;
+
+ // The id of the root of the computation.
+ int64 root_id = 6;
}
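// Illustrative sketch (not part of the definitions above; all names and ids
// are hypothetical): a computation now refers to its root and to instruction
// operands by id rather than by name. In text format this looks roughly like:
//
//   name: "add_two"
//   root_id: 2
//   instructions { name: "x" opcode: "parameter" id: 0 }
//   instructions { name: "y" opcode: "parameter" id: 1 }
//   instructions {
//     name: "add.2"
//     opcode: "add"
//     id: 2
//     operand_ids: 0
//     operand_ids: 1
//   }
//
// Note that the operands appear before their user, as required by the
// dependency-order invariant on `instructions`.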
// Serialization of HloModule.
message HloModuleProto {
string name = 1;
string entry_computation_name = 2;
+ int64 entry_computation_id = 6;
// The array of computations is always in a valid dependency order, where
// callees appear before their callers.
std::vector<std::unique_ptr<HloInstruction>>* instructions,
HloInstruction* root_instruction, HloInstruction* fusion_instruction)
: name_(name),
+ unique_id_(-1),
root_instruction_(root_instruction),
fusion_instruction_(fusion_instruction) {
param_instructions_.resize(parameter_count, nullptr);
instruction->UniquifyName(&parent()->instruction_name_uniquer());
instruction->SetUniqueId(parent()->NewUniqueInstructionId());
}
- Reparent(instruction.get());
+ instruction->set_parent(this);
HloInstruction* pinst = instruction.get();
instruction_iterators_[pinst] =
instructions_.insert(instructions_.end(), std::move(instruction));
return Status::OK();
}
-void HloComputation::Reparent(HloInstruction* instruction) {
- instruction->set_parent(this);
-}
-
bool HloComputation::IsRemovable(const HloInstruction* instruction) {
// If the instruction has control predecessors or successors then we cannot
// remove the instruction without violating ordering constraints (added, for
HloComputationProto HloComputation::ToProto() const {
HloComputationProto proto;
+ CHECK(unique_id_ != -1)
+ << "This computation does not have a valid id. Please make sure the "
+ "computation is inside a module before dumping it.";
+ proto.set_id(unique_id_);
proto.set_name(name_);
for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
HloInstructionProto instruction_proto = instruction->ToProto();
proto.add_instructions()->Swap(&instruction_proto);
}
- proto.set_root_name(root_instruction()->name());
+ proto.set_root_id(root_instruction()->unique_id());
*proto.mutable_program_shape() = ComputeProgramShape();
return proto;
}
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
HloModule* module, const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map) {
+ const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
std::vector<std::unique_ptr<HloInstruction>> instructions;
- tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
+ tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
TF_ASSIGN_OR_RETURN(
if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++;
}
- TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name()));
- instruction_map[instruction->name()] = instruction.get();
+ TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
+ instruction_map[instruction_proto.id()] = instruction.get();
instructions.push_back(std::move(instruction));
}
- TF_RET_CHECK(!proto.root_name().empty());
- TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name()));
- HloInstruction* root = instruction_map.at(proto.root_name());
+ TF_RET_CHECK(proto.root_id() != -1);
+ TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
+ HloInstruction* root = instruction_map.at(proto.root_id());
return WrapUnique(new HloComputation(proto.name(), parameter_count,
&instructions, root,
/*fusion_instruction=*/nullptr));
// module: the module which will contain the computation. The newly created
// computation is *not* added to the module, however.
// proto: the proto to convert from.
- // computation_map: a map from computation name to HloComputation*. This map
+ // computation_map: a map from computation id to HloComputation*. This map
// must contain all computations which the newly constructed computation
// calls.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
HloModule* module, const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map);
+ const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
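// A rough usage sketch (illustrative only; names such as `module_proto` and
// `module` are hypothetical): computations are deserialized in dependency
// order, and each one is registered in the id-keyed map before any of its
// callers is built.
//
//   tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
//   for (const HloComputationProto& proto : module_proto.computations()) {
//     TF_ASSIGN_OR_RETURN(
//         std::unique_ptr<HloComputation> computation,
//         HloComputation::CreateFromProto(module, proto, computation_map));
//     computation_map[proto.id()] = computation.get();
//     // Ownership of `computation` is then handed to the module.
//   }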
// Gets the instructions in this computation.
//
fusion_instruction_ = fusion_instruction;
}
+ // The id of this computation should be unique within the module.
+ void SetUniqueId(int64 id) {
+ CHECK_EQ(unique_id_, -1);
+ CHECK_GE(id, 0);
+ unique_id_ = id;
+ }
+
+ int64 unique_id() const { return unique_id_; }
+
private:
explicit HloComputation(
const string& name, int parameter_count,
HloInstruction* AddInstructionInternal(
std::unique_ptr<HloInstruction> instruction);
- // Helper for setting the parent of instructions that are added to this
- // computation.
- void Reparent(HloInstruction* instruction);
-
// Fuses HLOs in instructions_to_fuse into fusion_instruction.
//
// Pre-condition: fusion_instruction's opcode is kFusion.
std::vector<HloInstruction*> CollectUnreachableRoots() const;
string name_;
+ int64 unique_id_;
HloInstruction* root_instruction_;
// If this computation is a fusion computation, this field points to the
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
- const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map) {
+ const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
+ const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
TF_RET_CHECK(!proto.opcode().empty());
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
- for (const string& operand_name : proto.operand_names()) {
- TF_RET_CHECK(ContainsKey(instruction_map, operand_name))
- << "No instruction named " << operand_name;
- instruction->AppendOperand(instruction_map.at(operand_name));
- }
- for (const string& predecessor_name : proto.control_predecessor_names()) {
- TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name))
- << "No instruction named " << predecessor_name;
- TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name)
+ for (const int64 operand_id : proto.operand_ids()) {
+ TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
+ << "No instruction with id " << operand_id;
+ instruction->AppendOperand(instruction_map.at(operand_id));
+ }
+ for (const int64 predecessor_id : proto.control_predecessor_ids()) {
+ TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
+ << "No instruction with id " << predecessor_id;
+ TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
->AddControlDependencyTo(instruction.get()));
}
StringToFusionKind(proto.fusion_kind()));
// Find the fused computation and set its fusion instruction.
- TF_RET_CHECK(proto.called_computation_names_size() == 1)
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Expect 1 called computation for fusion instruction, but sees "
- << proto.called_computation_names_size();
- const string& fusion_name = proto.called_computation_names(0);
- auto* fused_computation = FindPtrOrNull(computation_map, fusion_name);
+ << proto.called_computation_ids_size();
+ const int64 fusion_id = proto.called_computation_ids(0);
+ auto* fused_computation = FindPtrOrNull(computation_map, fusion_id);
TF_RET_CHECK(fused_computation != nullptr)
- << "No fusion computation named " << fusion_name;
+ << "No fusion computation with id " << fusion_id;
fused_computation->SetFusionInstruction(instruction.get());
instruction->called_computations_.push_back(fused_computation);
} else {
- for (const string& computation_name : proto.called_computation_names()) {
- TF_RET_CHECK(ContainsKey(computation_map, computation_name))
- << "No computation named " << computation_name;
+ for (const int64 computation_id : proto.called_computation_ids()) {
+ TF_RET_CHECK(ContainsKey(computation_map, computation_id))
+ << "No computation with id " << computation_id;
instruction->called_computations_.push_back(
- computation_map.at(computation_name));
+ computation_map.at(computation_id));
}
}
HloInstructionProto HloInstruction::ToProto() const {
HloInstructionProto proto;
+ CHECK(unique_id_ != -1)
+ << "This instruction does not have a valid id. Please make sure the "
+ "instruction is inside a module before dumping it.";
+ proto.set_id(unique_id_);
proto.set_name(name_);
proto.set_opcode(HloOpcodeString(opcode_));
*proto.mutable_shape() = shape_;
for (const HloInstruction* operand : operands_) {
- *proto.add_operand_names() = operand->name();
+ proto.add_operand_ids(operand->unique_id());
}
for (const HloInstruction* control : control_predecessors_) {
- *proto.add_control_predecessor_names() = control->name();
+ proto.add_control_predecessor_ids(control->unique_id());
}
*proto.mutable_metadata() = metadata_;
proto.set_parameter_number(parameter_number_);
if (opcode() == HloOpcode::kFusion) {
proto.set_fusion_kind(xla::ToString(fusion_kind()));
- *proto.add_called_computation_names() =
- fused_instructions_computation()->name();
+ proto.add_called_computation_ids(
+ fused_instructions_computation()->unique_id());
} else {
for (const HloComputation* computation : called_computations_) {
- *proto.add_called_computation_names() = computation->name();
+ proto.add_called_computation_ids(computation->unique_id());
}
}
// module: the module which will contain the instruction. The newly created
// instruction is *not* added to the module or any computation, however.
// proto: the proto to convert from.
- // instruction_map: a map from instruction name to HloInstruction*. This map
+ // instruction_map: a map from instruction id to HloInstruction*. This map
// must contain all operands of the newly constructed instruction.
- // computation_map: a map from computation name to HloComputation*. This map
+ // computation_map: a map from computation id to HloComputation*. This map
// must contain all computations which the newly constructed instruction
// calls.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
- const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map);
+ const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
+ const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
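// A rough usage sketch (illustrative only; `computation_proto` and the two
// maps are assumed to come from the enclosing deserialization loop):
// instructions arrive in dependency order, so each one is added to the
// id-keyed map before any of its users is constructed.
//
//   tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
//   for (const HloInstructionProto& proto : computation_proto.instructions()) {
//     TF_ASSIGN_OR_RETURN(
//         std::unique_ptr<HloInstruction> instruction,
//         HloInstruction::CreateFromProto(module, proto, instruction_map,
//                                         computation_map));
//     instruction_map[proto.id()] = instruction.get();
//   }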
// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
for (auto* instruction : computation->instructions()) {
instruction->SetUniqueId(NewUniqueInstructionId());
}
+ // Set the unique id of this computation.
+ CHECK_NE(computation->root_instruction()->unique_id(), -1)
+ << "Root has no valid id: " << computation->ToString();
+ computation->SetUniqueId(computation->root_instruction()->unique_id());
+
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
HloModuleProto HloModule::ToProto() const {
HloModuleProto proto;
+ proto.set_id(unique_id_);
proto.set_name(name_);
proto.set_entry_computation_name(entry_computation_->name());
+ proto.set_entry_computation_id(entry_computation_->unique_id());
for (const HloComputation* computation : MakeComputationPostOrder()) {
HloComputationProto computation_proto = computation->ToProto();
if (computation->name() == entry_computation_->name()) {
auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
module_config);
- tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
+ tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
for (const HloComputationProto& computation_proto : proto.computations()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
HloComputation::CreateFromProto(
module.get(), computation_proto, computation_map));
CHECK_NE(computation.get(), nullptr);
- TF_RET_CHECK(!ContainsKey(computation_map, computation->name()));
- string computation_name = computation->name();
+ int64 computation_id = computation_proto.id();
+ TF_RET_CHECK(computation_id != -1);
+ TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
- computation_map[computation_name] = module->AddComputationInternal(
+ computation_map[computation_id] = module->AddComputationInternal(
std::move(computation),
- /*is_entry=*/proto.entry_computation_name() == computation_name,
+ /*is_entry=*/proto.entry_computation_id() == computation_id,
/*uniquify_names=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);