[XLA] Make hlo deserialization stable for HloModule by sorting by ids when creating computations and instructions from protos.
author: A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 9 May 2018 23:40:03 +0000 (16:40 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 23:42:52 +0000 (16:42 -0700)
Also, delete the HloModule parameter of HloInstruction::CreateFromProto; it's not used anywhere.

Also, in ToProto, set sharding to proto if there is sharding.

PiperOrigin-RevId: 196049173

tensorflow/compiler/xla/service/hlo_computation.cc
tensorflow/compiler/xla/service/hlo_computation.h
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h
tensorflow/compiler/xla/service/hlo_module.cc

index 17e43c3..05dceb1 100644 (file)
@@ -407,27 +407,37 @@ HloComputationProto HloComputation::ToProto() const {
 
 /* static */ StatusOr<std::unique_ptr<HloComputation>>
 HloComputation::CreateFromProto(
-    HloModule* module, const HloComputationProto& proto,
+    const HloComputationProto& proto,
     const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
-  std::vector<std::unique_ptr<HloInstruction>> instructions;
   tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
+  tensorflow::gtl::FlatMap<HloInstruction*, int64> to_proto_id;
+  std::vector<std::unique_ptr<HloInstruction>> instructions;
   int64 parameter_count = 0;
   for (const HloInstructionProto& instruction_proto : proto.instructions()) {
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<HloInstruction> instruction,
-        HloInstruction::CreateFromProto(module, instruction_proto,
-                                        instruction_map, computation_map));
+        HloInstruction::CreateFromProto(instruction_proto, instruction_map,
+                                        computation_map));
     if (instruction->opcode() == HloOpcode::kParameter) {
       parameter_count++;
     }
     TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
     instruction_map[instruction_proto.id()] = instruction.get();
+    to_proto_id[instruction.get()] = instruction_proto.id();
     instructions.push_back(std::move(instruction));
   }
 
   TF_RET_CHECK(proto.root_id() != -1);
   TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
   HloInstruction* root = instruction_map.at(proto.root_id());
+
+  // Sort the instructions in the proto id's order.
+  std::sort(instructions.begin(), instructions.end(),
+            [&](const std::unique_ptr<HloInstruction>& a,
+                const std::unique_ptr<HloInstruction>& b) {
+              return to_proto_id[a.get()] < to_proto_id[b.get()];
+            });
+
   return WrapUnique(new HloComputation(proto.name(), parameter_count,
                                        &instructions, root,
                                        /*fusion_instruction=*/nullptr));
index 9898355..ba9d44a 100644 (file)
@@ -157,14 +157,12 @@ class HloComputation {
 
   // Creates a computation from the given proto. Arguments:
   //
-  //   module: the module which will contain the computation. The newly created
-  //     computation is *not* added to the module, however.
   //   proto: the proto to convert from.
   //   computation_map: a map from computation id to HloComputation*. This map
   //     must contain all computations which the newly constructed computation
   //     calls.
   static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
-      HloModule* module, const HloComputationProto& proto,
+      const HloComputationProto& proto,
       const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
 
   // Gets the instructions in this computation.
index 03e0391..3ff1007 100644 (file)
@@ -51,7 +51,7 @@ using ::tensorflow::strings::StrCat;
 
 /* static */
 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
-    HloModule* module, const HloInstructionProto& proto,
+    const HloInstructionProto& proto,
     const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
     const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
   TF_RET_CHECK(!proto.opcode().empty());
@@ -2396,6 +2396,10 @@ HloInstructionProto HloInstruction::ToProto() const {
     proto.add_fft_length(fft_len);
   }
 
+  if (has_sharding()) {
+    *proto.mutable_sharding() = sharding().ToProto();
+  }
+
   proto.set_channel_name(channel_name_);
   proto.set_cost_estimate_ns(cost_estimate_ns_);
 
index ea5fc5b..2e5895e 100644 (file)
@@ -185,8 +185,6 @@ class HloInstruction {
 
   // Creates an instruction from the given proto. Arguments:
   //
-  //   module: the module which will contain the instruction. The newly created
-  //     instruction is *not* added to the module or any computation, however.
   //   proto: the proto to convert from.
   //   instruction_map: a map from instruction id to HloInstruction*. This map
   //     must contain all operands of the newly constructed instruction.
@@ -194,7 +192,7 @@ class HloInstruction {
   //     must contain all computations which the newly constructed instruction
   //     calls.
   static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
-      HloModule* module, const HloInstructionProto& proto,
+      const HloInstructionProto& proto,
       const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
       const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
 
index 5308fb5..fbf1d58 100644 (file)
@@ -266,24 +266,44 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
       << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
       << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
 
-  auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
-                                      module_config);
-
   tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
+  tensorflow::gtl::FlatMap<HloComputation*, int64> to_proto_id;
+  std::vector<std::unique_ptr<HloComputation>> computations;
+  HloComputation* entry = nullptr;
   for (const HloComputationProto& computation_proto : proto.computations()) {
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
-                        HloComputation::CreateFromProto(
-                            module.get(), computation_proto, computation_map));
+    TF_ASSIGN_OR_RETURN(
+        std::unique_ptr<HloComputation> computation,
+        HloComputation::CreateFromProto(computation_proto, computation_map));
     CHECK_NE(computation.get(), nullptr);
     int64 computation_id = computation_proto.id();
     TF_RET_CHECK(computation_id != -1);
     TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
+    computation_map[computation_id] = computation.get();
+    to_proto_id[computation.get()] = computation_id;
+    if (computation_id == proto.entry_computation_id()) {
+      entry = computation.get();
+    }
+    computations.push_back(std::move(computation));
+  }
+  TF_RET_CHECK(entry != nullptr);
+
+  auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
+                                      module_config);
+
+  // Sort the computations in the proto id's order.
+  std::sort(computations.begin(), computations.end(),
+            [&](const std::unique_ptr<HloComputation>& a,
+                const std::unique_ptr<HloComputation>& b) {
+              return to_proto_id[a.get()] < to_proto_id[b.get()];
+            });
+
+  // Add sorted computations to the module.
+  for (auto& computation : computations) {
+    bool is_entry = computation.get() == entry;
     // Don't uniquify names because we want names to be stable across
     // serialization and deserialization.
-    computation_map[computation_id] = module->AddComputationInternal(
-        std::move(computation),
-        /*is_entry=*/proto.entry_computation_id() == computation_id,
-        /*uniquify_names=*/false);
+    module->AddComputationInternal(std::move(computation), is_entry,
+                                   /*uniquify_names=*/false);
   }
   TF_RET_CHECK(module->entry_computation_ != nullptr);