From: David Majnemer Date: Fri, 16 Feb 2018 19:37:49 +0000 (-0800) Subject: [XLA] Factor out the code which adds operands to a fusion node X-Git-Tag: upstream/v1.7.0~31^2~616 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8dfaa05d2824290b33eb922a5269f0772f53478e;p=platform%2Fupstream%2Ftensorflow.git [XLA] Factor out the code which adds operands to a fusion node This makes it easier for Hlo passes to do interesting rewrites with new, additional parameters which were not operands to the original fusion node. PiperOrigin-RevId: 186024182 --- diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 0981f1f4fe..0d9912d07d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -801,6 +801,22 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { return instruction; } +HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { + CHECK_EQ(opcode(), HloOpcode::kFusion); + CHECK_EQ(operand_count(), + fused_instructions_computation()->parameter_instructions().size()); + const int64 param_no = operand_count(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. + string param_name = StrCat(new_operand->name(), ".param_", param_no); + HloInstruction* fused_parameter = + fused_instructions_computation()->AddParameter( + HloInstruction::CreateParameter(param_no, new_operand->shape(), + param_name)); + AppendOperand(new_operand); + return fused_parameter; +} + void HloInstruction::MergeFusionInstruction( HloInstruction* instruction_to_merge) { CHECK_EQ(opcode_, HloOpcode::kFusion); @@ -993,13 +1009,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Clone's operand was not already an operand of the fusion // instruction. Add it as an operand and add a corresponding fused // parameter instruction. - int64 param_no = fused_parameters.size(); - // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. - string param_name = StrCat(operand->name(), ".param_", param_no); - fused_param = fused_instructions_computation()->AddParameter( - CreateParameter(param_no, operand->shape(), param_name)); - AppendOperand(operand); + fused_param = AddFusionOperand(operand); } TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 50931c563a..a4c41c9de8 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -917,6 +917,9 @@ class HloInstruction { // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } + // Adds a new operand the fusion instruction. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + // Merges the fused instructions from 'instruction_to_merge' into the // fused instruction set of 'this', updating operands as necessary. //