[XLA] Use IDs instead of names to represent the edges of HLO graph in hlo.proto.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 20 Mar 2018 23:13:58 +0000 (16:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 23:17:13 +0000 (16:17 -0700)
PiperOrigin-RevId: 189831057

tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/service/hlo.proto
tensorflow/compiler/xla/service/hlo_computation.cc
tensorflow/compiler/xla/service/hlo_computation.h
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h
tensorflow/compiler/xla/service/hlo_module.cc

index 6328a4f..8829fc6 100644 (file)
@@ -99,16 +99,17 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
 
   // Not all instructions can be roots. Walk backwards from the last added
   // instruction until a valid root is found.
+  entry.set_root_id(-1);
   for (int64 i = instructions_.size() - 1; i >= 0; i--) {
     TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                         StringToHloOpcode(instructions_[i].opcode()));
     if (CanBeRoot(opcode)) {
-      entry.set_root_name(instructions_[i].name());
+      entry.set_root_id(instructions_[i].id());
       *program_shape->mutable_result() = instructions_[i].shape();
       break;
     }
   }
-  if (entry.root_name().empty()) {
+  if (entry.root_id() == -1) {
     return FailedPrecondition("no root instruction was found");
   }
 
@@ -141,7 +142,9 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
   XlaComputation computation(id);
   HloModuleProto* module = computation.mutable_proto();
   module->set_name(entry.name());
+  module->set_id(entry.id());
   module->set_entry_computation_name(entry.name());
+  module->set_entry_computation_id(entry.id());
   *module->mutable_program_shape() = entry.program_shape();
   for (auto& e : embedded_) {
     module->add_computations()->Swap(&e.second);
@@ -162,8 +165,8 @@ XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
                         ShapeInference::InferBinaryOpShape(
                             HloOpcode::kAdd, lhs_instr->shape(),
                             rhs_instr->shape(), broadcast_dimensions));
-    instr.add_operand_names(lhs_instr->name());
-    instr.add_operand_names(rhs_instr->name());
+    instr.add_operand_ids(lhs_instr->id());
+    instr.add_operand_ids(rhs_instr->id());
     return AddInstruction(std::move(instr));
   };
   return NoteErrorOrReturn(op());
@@ -195,11 +198,12 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
     // Add input operands.
     for (const auto& operand : operands) {
       TF_ASSIGN_OR_RETURN(auto operand_instr, LookUpInstruction(operand));
-      instr.add_operand_names(operand_instr->name());
+      instr.add_operand_ids(operand_instr->id());
     }
 
     // Add called computation.
-    *instr.add_called_computation_names() = computation.proto().name();
+    instr.add_called_computation_ids(
+        computation.proto().entry_computation_id());
     for (const HloComputationProto& e : computation.proto().computations()) {
       embedded_.insert({e.id(), e});
     }
@@ -229,6 +233,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
 
 XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) {
   const int64 handle = instructions_.size();
+  instr.set_id(handle);
   if (instr.name().empty()) {
     instr.set_name(StrCat(instr.opcode(), ".", handle));
   } else {
index b86fbd8..406fead 100644 (file)
@@ -13,13 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// DO NOT USE THESE PROTO MESSAGES FOR ANYTHING OTHER THAN DEBUGGING.
-//
-// Don't use these protos in the real compilation or execution codepaths. The
-// data format is meant for debugging only, and may change without notice.
+// This proto file defines messages which represent the HLO module. This is a
+// full fidelity serialization of the c++ HLO constructs.
 //
 // Many of the protos below are simple 1-to-1 serializations of the
-// corresponding C++ classes.
+// corresponding C++ classes, e.g., HloModule, HloComputation, and
+// HloInstruction.
 //
 // FIELD NAMES ARE IMPORTANT
 //
@@ -40,16 +39,17 @@ message HloInstructionProto {
   reserved "parameter_name";
   reserved 12;
   reserved "fused_instructions_computation";
+  reserved 4;
+  reserved "operand_names";
+  reserved 5;
+  reserved "control_predecessor_names";
+  reserved 6;
+  reserved "called_computation_names";
 
   string name = 1;
   string opcode = 2;
   xla.Shape shape = 3;
 
-  // TODO(b/67782397): Replace instruction names with HloInstruction ids.
-  repeated string operand_names = 4;
-  repeated string control_predecessor_names = 5;
-  repeated string called_computation_names = 6;
-
   xla.OpMetadata metadata = 7;
 
   // Literal, only present for kConstant.
@@ -137,30 +137,38 @@ message HloInstructionProto {
 
   // The id of this instruction.
   int64 id = 35;
+
+  repeated int64 operand_ids = 36;
+  repeated int64 control_predecessor_ids = 37;
+  repeated int64 called_computation_ids = 38;
 }
 
 // Serialization of HloComputation.
 message HloComputationProto {
+  reserved 3;
+  reserved "root_name";
+
   string name = 1;
 
   // The array of instructions is always in a valid dependency order, where
   // operands appear before their users.
   repeated HloInstructionProto instructions = 2;
 
-  // The name of the root of the computation.
-  string root_name = 3;
-
   // The program shape (with layout) of this computation.
   xla.ProgramShape program_shape = 4;
 
   // The id of this computation.
   int64 id = 5;
+
+  // The id of the root of the computation.
+  int64 root_id = 6;
 }
 
 // Serialization of HloModule.
 message HloModuleProto {
   string name = 1;
   string entry_computation_name = 2;
+  int64 entry_computation_id = 6;
 
   // The array of computations is always in a valid dependency order, where
   // callees appear before their callers.
index 4e85219..6f983d0 100644 (file)
@@ -65,6 +65,7 @@ HloComputation::HloComputation(
     std::vector<std::unique_ptr<HloInstruction>>* instructions,
     HloInstruction* root_instruction, HloInstruction* fusion_instruction)
     : name_(name),
+      unique_id_(-1),
       root_instruction_(root_instruction),
       fusion_instruction_(fusion_instruction) {
   param_instructions_.resize(parameter_count, nullptr);
@@ -101,7 +102,7 @@ HloInstruction* HloComputation::AddInstructionInternal(
     instruction->UniquifyName(&parent()->instruction_name_uniquer());
     instruction->SetUniqueId(parent()->NewUniqueInstructionId());
   }
-  Reparent(instruction.get());
+  instruction->set_parent(this);
   HloInstruction* pinst = instruction.get();
   instruction_iterators_[pinst] =
       instructions_.insert(instructions_.end(), std::move(instruction));
@@ -158,10 +159,6 @@ Status HloComputation::RemoveParameter(int64 param_no) {
   return Status::OK();
 }
 
-void HloComputation::Reparent(HloInstruction* instruction) {
-  instruction->set_parent(this);
-}
-
 bool HloComputation::IsRemovable(const HloInstruction* instruction) {
   // If the instruction has control predecessors or successors then we cannot
   // remove the instruction without violating ordering constraints (added, for
@@ -393,12 +390,16 @@ string HloComputation::ToString(const HloPrintOptions& options) const {
 
 HloComputationProto HloComputation::ToProto() const {
   HloComputationProto proto;
+  CHECK(unique_id_ != -1)
+      << "This computation does not have a valid id. Please make sure the "
+         "computation is inside a module before dumping it.";
+  proto.set_id(unique_id_);
   proto.set_name(name_);
   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
     HloInstructionProto instruction_proto = instruction->ToProto();
     proto.add_instructions()->Swap(&instruction_proto);
   }
-  proto.set_root_name(root_instruction()->name());
+  proto.set_root_id(root_instruction()->unique_id());
   *proto.mutable_program_shape() = ComputeProgramShape();
   return proto;
 }
@@ -406,9 +407,9 @@ HloComputationProto HloComputation::ToProto() const {
 /* static */ StatusOr<std::unique_ptr<HloComputation>>
 HloComputation::CreateFromProto(
     HloModule* module, const HloComputationProto& proto,
-    const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map) {
+    const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
   std::vector<std::unique_ptr<HloInstruction>> instructions;
-  tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
+  tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
   int64 parameter_count = 0;
   for (const HloInstructionProto& instruction_proto : proto.instructions()) {
     TF_ASSIGN_OR_RETURN(
@@ -418,14 +419,14 @@ HloComputation::CreateFromProto(
     if (instruction->opcode() == HloOpcode::kParameter) {
       parameter_count++;
     }
-    TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name()));
-    instruction_map[instruction->name()] = instruction.get();
+    TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
+    instruction_map[instruction_proto.id()] = instruction.get();
     instructions.push_back(std::move(instruction));
   }
 
-  TF_RET_CHECK(!proto.root_name().empty());
-  TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name()));
-  HloInstruction* root = instruction_map.at(proto.root_name());
+  TF_RET_CHECK(proto.root_id() != -1);
+  TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
+  HloInstruction* root = instruction_map.at(proto.root_id());
   return WrapUnique(new HloComputation(proto.name(), parameter_count,
                                        &instructions, root,
                                        /*fusion_instruction=*/nullptr));
index 630d367..9d3f6e9 100644 (file)
@@ -160,12 +160,12 @@ class HloComputation {
   //   module: the module which will contain the computation. The newly created
   //     computation is *not* added to the module, however.
   //   proto: the proto to convert from.
-  //   computation_map: a map from computation name to HloComputation*. This map
+  //   computation_map: a map from computation id to HloComputation*. This map
   //     must contain all computations which the newly constructed computation
   //     calls.
   static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
       HloModule* module, const HloComputationProto& proto,
-      const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map);
+      const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
 
   // Gets the instructions in this computation.
   //
@@ -334,6 +334,15 @@ class HloComputation {
     fusion_instruction_ = fusion_instruction;
   }
 
+  // The id of this computation should be unique within the module.
+  void SetUniqueId(int64 id) {
+    CHECK_EQ(unique_id_, -1);
+    CHECK_GE(id, 0);
+    unique_id_ = id;
+  }
+
+  int64 unique_id() const { return unique_id_; }
+
  private:
   explicit HloComputation(
       const string& name, int parameter_count,
@@ -344,10 +353,6 @@ class HloComputation {
   HloInstruction* AddInstructionInternal(
       std::unique_ptr<HloInstruction> instruction);
 
-  // Helper for setting the parent of instructions that are added to this
-  // computation.
-  void Reparent(HloInstruction* instruction);
-
   // Fuses HLOs in instructions_to_fuse into fusion_instruction.
   //
   // Pre-condition: fusion_instruction's opcode is kFusion.
@@ -365,6 +370,7 @@ class HloComputation {
   std::vector<HloInstruction*> CollectUnreachableRoots() const;
 
   string name_;
+  int64 unique_id_;
   HloInstruction* root_instruction_;
 
   // If this computation is a fusion computation, this field points to the
index 83fcc5d..a2a2c1e 100644 (file)
@@ -52,22 +52,22 @@ using ::tensorflow::strings::StrCat;
 /* static */
 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
     HloModule* module, const HloInstructionProto& proto,
-    const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
-    const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map) {
+    const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
+    const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
   TF_RET_CHECK(!proto.opcode().empty());
   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
   TF_RET_CHECK(proto.has_shape());
 
   auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
-  for (const string& operand_name : proto.operand_names()) {
-    TF_RET_CHECK(ContainsKey(instruction_map, operand_name))
-        << "No instruction named " << operand_name;
-    instruction->AppendOperand(instruction_map.at(operand_name));
-  }
-  for (const string& predecessor_name : proto.control_predecessor_names()) {
-    TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name))
-        << "No instruction named " << predecessor_name;
-    TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name)
+  for (const int64 operand_id : proto.operand_ids()) {
+    TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
+        << "No instruction with id " << operand_id;
+    instruction->AppendOperand(instruction_map.at(operand_id));
+  }
+  for (const int64 predecessor_id : proto.control_predecessor_ids()) {
+    TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
+        << "No instruction with id " << predecessor_id;
+    TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
                            ->AddControlDependencyTo(instruction.get()));
   }
 
@@ -80,21 +80,21 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
                         StringToFusionKind(proto.fusion_kind()));
 
     // Find the fused computation and set its fusion instruction.
-    TF_RET_CHECK(proto.called_computation_names_size() == 1)
+    TF_RET_CHECK(proto.called_computation_ids_size() == 1)
         << "Expect 1 called computation for fusion instruction, but sees "
-        << proto.called_computation_names_size();
-    const string& fusion_name = proto.called_computation_names(0);
-    auto* fused_computation = FindPtrOrNull(computation_map, fusion_name);
+        << proto.called_computation_ids_size();
+    const int64 fusion_id = proto.called_computation_ids(0);
+    auto* fused_computation = FindPtrOrNull(computation_map, fusion_id);
     TF_RET_CHECK(fused_computation != nullptr)
-        << "No fusion computation named " << fusion_name;
+        << "No fusion computation with id " << fusion_id;
     fused_computation->SetFusionInstruction(instruction.get());
     instruction->called_computations_.push_back(fused_computation);
   } else {
-    for (const string& computation_name : proto.called_computation_names()) {
-      TF_RET_CHECK(ContainsKey(computation_map, computation_name))
-          << "No computation named " << computation_name;
+    for (const int64 computation_id : proto.called_computation_ids()) {
+      TF_RET_CHECK(ContainsKey(computation_map, computation_id))
+          << "No computation with id " << computation_id;
       instruction->called_computations_.push_back(
-          computation_map.at(computation_name));
+          computation_map.at(computation_id));
     }
   }
 
@@ -2315,14 +2315,18 @@ string HloInstruction::ToShortString() const {
 
 HloInstructionProto HloInstruction::ToProto() const {
   HloInstructionProto proto;
+  CHECK(unique_id_ != -1)
+      << "This instruction does not have a valid id. Please make sure the "
+         "instruction is inside a module before dumping it.";
+  proto.set_id(unique_id_);
   proto.set_name(name_);
   proto.set_opcode(HloOpcodeString(opcode_));
   *proto.mutable_shape() = shape_;
   for (const HloInstruction* operand : operands_) {
-    *proto.add_operand_names() = operand->name();
+    proto.add_operand_ids(operand->unique_id());
   }
   for (const HloInstruction* control : control_predecessors_) {
-    *proto.add_control_predecessor_names() = control->name();
+    proto.add_control_predecessor_ids(control->unique_id());
   }
 
   *proto.mutable_metadata() = metadata_;
@@ -2332,11 +2336,11 @@ HloInstructionProto HloInstruction::ToProto() const {
   proto.set_parameter_number(parameter_number_);
   if (opcode() == HloOpcode::kFusion) {
     proto.set_fusion_kind(xla::ToString(fusion_kind()));
-    *proto.add_called_computation_names() =
-        fused_instructions_computation()->name();
+    proto.add_called_computation_ids(
+        fused_instructions_computation()->unique_id());
   } else {
     for (const HloComputation* computation : called_computations_) {
-      *proto.add_called_computation_names() = computation->name();
+      proto.add_called_computation_ids(computation->unique_id());
     }
   }
 
index a111e1e..a94ba14 100644 (file)
@@ -179,15 +179,15 @@ class HloInstruction {
   //   module: the module which will contain the instruction. The newly created
   //     instruction is *not* added to the module or any computation, however.
   //   proto: the proto to convert from.
-  //   instruction_map: a map from instruction name to HloInstruction*. This map
+  //   instruction_map: a map from instruction id to HloInstruction*. This map
   //     must contain all operands of the newly constructed instruction.
-  //   computation_map: a map from computation name to HloComputation*. This map
+  //   computation_map: a map from computation id to HloComputation*. This map
   //     must contain all computations which the newly constructed instruction
   //     calls.
   static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
       HloModule* module, const HloInstructionProto& proto,
-      const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
-      const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map);
+      const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
+      const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
 
   // Creates a parameter-retrieving instruction.
   static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
index 4091ebb..2037764 100644 (file)
@@ -83,6 +83,11 @@ HloComputation* HloModule::AddComputationInternal(
   for (auto* instruction : computation->instructions()) {
     instruction->SetUniqueId(NewUniqueInstructionId());
   }
+  // Set unique id to this computation.
+  CHECK_NE(computation->root_instruction()->unique_id(), -1)
+      << "Root has no valid id: " << computation->ToString();
+  computation->SetUniqueId(computation->root_instruction()->unique_id());
+
   computation->set_parent(this);
   computations_.push_back(std::move(computation));
   return computations_.back().get();
@@ -204,8 +209,10 @@ string HloModule::ToString(const HloPrintOptions& options) const {
 
 HloModuleProto HloModule::ToProto() const {
   HloModuleProto proto;
+  proto.set_id(unique_id_);
   proto.set_name(name_);
   proto.set_entry_computation_name(entry_computation_->name());
+  proto.set_entry_computation_id(entry_computation_->unique_id());
   for (const HloComputation* computation : MakeComputationPostOrder()) {
     HloComputationProto computation_proto = computation->ToProto();
     if (computation->name() == entry_computation_->name()) {
@@ -249,19 +256,20 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
   auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
                                       module_config);
 
-  tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
+  tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
   for (const HloComputationProto& computation_proto : proto.computations()) {
     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
                         HloComputation::CreateFromProto(
                             module.get(), computation_proto, computation_map));
     CHECK_NE(computation.get(), nullptr);
-    TF_RET_CHECK(!ContainsKey(computation_map, computation->name()));
-    string computation_name = computation->name();
+    int64 computation_id = computation_proto.id();
+    TF_RET_CHECK(computation_id != -1);
+    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
     // Don't uniquify names because we want names to be stable across
     // serialization and deserialization.
-    computation_map[computation_name] = module->AddComputationInternal(
+    computation_map[computation_id] = module->AddComputationInternal(
         std::move(computation),
-        /*is_entry=*/proto.entry_computation_name() == computation_name,
+        /*is_entry=*/proto.entry_computation_id() == computation_id,
         /*uniquify_names=*/false);
   }
   TF_RET_CHECK(module->entry_computation_ != nullptr);