From 546d1d467372a176f337f2614165c6d754a386da Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 20 Mar 2018 14:33:19 -0700 Subject: [PATCH] [XLA] Simplify the HLO proto: don't nest the fusion computation in a fusion HloInstructionProto. PiperOrigin-RevId: 189811729 --- tensorflow/compiler/xla/service/hlo.proto | 3 ++- tensorflow/compiler/xla/service/hlo_computation.cc | 18 +++++++------- tensorflow/compiler/xla/service/hlo_computation.h | 10 +------- tensorflow/compiler/xla/service/hlo_instruction.cc | 28 ++++++++++++---------- tensorflow/compiler/xla/service/hlo_instruction.h | 7 +----- tensorflow/compiler/xla/service/hlo_module.cc | 22 +++---------- 6 files changed, 30 insertions(+), 58 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index bf903d6..b86fbd8 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -38,6 +38,8 @@ option cc_enable_arenas = true; message HloInstructionProto { reserved 10; reserved "parameter_name"; + reserved 12; + reserved "fused_instructions_computation"; string name = 1; string opcode = 2; @@ -58,7 +60,6 @@ message HloInstructionProto { // Fusion state, only present for kFusion. string fusion_kind = 11; - HloComputationProto fused_instructions_computation = 12; // Index for kGetTupleElement. 
int64 tuple_index = 13; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index f99c7cf..4e85219 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -406,18 +406,15 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction) { + const tensorflow::gtl::FlatMap& computation_map) { std::vector> instructions; tensorflow::gtl::FlatMap instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr instruction, - HloInstruction::CreateFromProto( - module, instruction_proto, instruction_map, - computation_map, add_fused_computation)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr instruction, + HloInstruction::CreateFromProto(module, instruction_proto, + instruction_map, computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } @@ -429,8 +426,9 @@ HloComputation::CreateFromProto( TF_RET_CHECK(!proto.root_name().empty()); TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); HloInstruction* root = instruction_map.at(proto.root_name()); - return WrapUnique(new HloComputation( - proto.name(), parameter_count, &instructions, root, fusion_instruction)); + return WrapUnique(new HloComputation(proto.name(), parameter_count, + &instructions, root, + /*fusion_instruction=*/nullptr)); } void HloComputation::FuseInstructionsInto( diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index dd9d346..630d367 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ 
b/tensorflow/compiler/xla/service/hlo_computation.h @@ -163,17 +163,9 @@ class HloComputation { // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used only when the instruction is a fusion instruction. - // fusion_instruction: if non-null then the newly created computation will - // be constructed as a fused computation with this instruction as its - // fusion parent. static StatusOr> CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction = nullptr); + const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. // diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index d33add2..83fcc5d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -52,9 +53,7 @@ using ::tensorflow::strings::StrCat; StatusOr> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation) { + const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -76,17 +75,20 @@ StatusOr> HloInstruction::CreateFromProto( // HloInstructionProto and do not appear as an HloComputationProto within the // HloModuleProto. if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(proto.has_fused_instructions_computation()); TF_RET_CHECK(!proto.fusion_kind().empty()); TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN(std::unique_ptr fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), - computation_map, add_fused_computation, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back(fused_computation.get()); - add_fused_computation(std::move(fused_computation)); + + // Find the fused computation and set its fusion instruction. 
+ TF_RET_CHECK(proto.called_computation_names_size() == 1) + << "Expect 1 called computation for fusion instruction, but sees " + << proto.called_computation_names_size(); + const string& fusion_name = proto.called_computation_names(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_name); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation named " << fusion_name; + fused_computation->SetFusionInstruction(instruction.get()); + instruction->called_computations_.push_back(fused_computation); } else { for (const string& computation_name : proto.called_computation_names()) { TF_RET_CHECK(ContainsKey(computation_map, computation_name)) @@ -2330,8 +2332,8 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_parameter_number(parameter_number_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); - *proto.mutable_fused_instructions_computation() = - fused_instructions_computation()->ToProto(); + *proto.add_called_computation_names() = + fused_instructions_computation()->name(); } else { for (const HloComputation* computation : called_computations_) { *proto.add_called_computation_names() = computation->name(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e4c8621..a111e1e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -184,15 +184,10 @@ class HloInstruction { // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed instruction // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used (clearly) when the instruction is a fusion - // instruction. 
static StatusOr> CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation); + const tensorflow::gtl::FlatMap& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index cdea3d5..4091ebb 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -207,11 +207,6 @@ HloModuleProto HloModule::ToProto() const { proto.set_name(name_); proto.set_entry_computation_name(entry_computation_->name()); for (const HloComputation* computation : MakeComputationPostOrder()) { - // Fusion computations are added when the fusion instructions are created by - // HloInstruction::CreateFromProto. - if (computation->IsFusionComputation()) { - continue; - } HloComputationProto computation_proto = computation->ToProto(); if (computation->name() == entry_computation_->name()) { *proto.mutable_program_shape() = computation_proto.program_shape(); @@ -256,16 +251,9 @@ StatusOr> HloModule::CreateFromProto( tensorflow::gtl::FlatMap computation_map; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map, - /*add_fused_computation=*/ - [&module](std::unique_ptr fused_computation) { - module->AddComputationInternal(std::move(fused_computation), - /*is_entry=*/false, - /*uniquify_names=*/false); - })); + TF_ASSIGN_OR_RETURN(std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); string 
computation_name = computation->name(); @@ -283,10 +271,6 @@ StatusOr> HloModule::CreateFromProto( tensorflow::gtl::FlatSet computation_names; tensorflow::gtl::FlatSet instruction_names; for (HloComputation* computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); -- 2.7.4