From: A. Unique TensorFlower Date: Wed, 9 May 2018 23:40:03 +0000 (-0700) Subject: [XLA] Make hlo deserialization stable for HloModule by sorting by ids when creating... X-Git-Tag: upstream/v1.9.0_rc1~116^2^2~176 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c07b719ab030c46f19c8e5cdd92730eaec38a8fb;p=platform%2Fupstream%2Ftensorflow.git [XLA] Make hlo deserialization stable for HloModule by sorting by ids when creating from proto. Also, delete the HloModule parameter HloInstruction::CreateFromProto, it's not used anywhere. Also, in ToProto, set sharding to proto if there is sharding. PiperOrigin-RevId: 196049173 --- diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 17e43c3..05dceb1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -407,27 +407,37 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map) { - std::vector> instructions; tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { TF_ASSIGN_OR_RETURN( std::unique_ptr instruction, - HloInstruction::CreateFromProto(module, instruction_proto, - instruction_map, computation_map)); + HloInstruction::CreateFromProto(instruction_proto, instruction_map, + computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); instruction_map[instruction_proto.id()] = instruction.get(); + to_proto_id[instruction.get()] = instruction_proto.id(); instructions.push_back(std::move(instruction)); } TF_RET_CHECK(proto.root_id() != -1); TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); HloInstruction* root = instruction_map.at(proto.root_id()); + + // Sort the instructions in the proto id's order. + std::sort(instructions.begin(), instructions.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + return WrapUnique(new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 9898355..ba9d44a 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -157,14 +157,12 @@ class HloComputation { // Creates a computation from the given proto. Arguments: // - // module: the module which will contain the computation. The newly created - // computation is *not* added to the module, however. // proto: the proto to convert from. // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 03e0391..3ff1007 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -51,7 +51,7 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); @@ -2396,6 +2396,10 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_fft_length(fft_len); } + if (has_sharding()) { + *proto.mutable_sharding() = sharding().ToProto(); + } + proto.set_channel_name(channel_name_); proto.set_cost_estimate_ns(cost_estimate_ns_); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index ea5fc5b..2e5895e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -185,8 +185,6 @@ class HloInstruction { // Creates an instruction from the given proto. Arguments: // - // module: the module which will contain the instruction. The newly created - // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. @@ -194,7 +192,7 @@ class HloInstruction { // must contain all computations which the newly constructed instruction // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 5308fb5..fbf1d58 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -266,24 +266,44 @@ StatusOr> HloModule::CreateFromProto( << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> computations; + HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr computation, + HloComputation::CreateFromProto(computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); int64 computation_id = computation_proto.id(); TF_RET_CHECK(computation_id != -1); TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); + computation_map[computation_id] = computation.get(); + to_proto_id[computation.get()] = computation_id; + if (computation_id == proto.entry_computation_id()) { + entry = computation.get(); + } + computations.push_back(std::move(computation)); + } + TF_RET_CHECK(entry != nullptr); + + auto module = MakeUnique(proto.name(), entry_computation_handle, + module_config); + + // Sort the computations in the proto id's order. + std::sort(computations.begin(), computations.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + + // Add sorted computations to the module. + for (auto& computation : computations) { + bool is_entry = computation.get() == entry; // Don't uniquify names because we want names to be stable across // serialization and deserialization. - computation_map[computation_id] = module->AddComputationInternal( - std::move(computation), - /*is_entry=*/proto.entry_computation_id() == computation_id, - /*uniquify_names=*/false); + module->AddComputationInternal(std::move(computation), is_entry, + /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr);