[XLA] Factor out the code which adds operands to a fusion node
authorDavid Majnemer <majnemer@google.com>
Fri, 16 Feb 2018 19:37:49 +0000 (11:37 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 16 Feb 2018 19:42:14 +0000 (11:42 -0800)
This makes it easier for Hlo passes to do interesting rewrites with new,
additional parameters which were not operands to the original fusion node.

PiperOrigin-RevId: 186024182

tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h

index 0981f1f4fe57751d5b7059b4b08099385369e4b9..0d9912d07d0ddc8d8bf23281acf18ed4eee39573 100644 (file)
@@ -801,6 +801,22 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) {
   return instruction;
 }
 
+HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
+  CHECK_EQ(opcode(), HloOpcode::kFusion);
+  CHECK_EQ(operand_count(),
+           fused_instructions_computation()->parameter_instructions().size());
+  const int64 param_no = operand_count();
+  // Name the parameter after the instruction it represents in the outer
+  // (non-fusion) computation.
+  string param_name = StrCat(new_operand->name(), ".param_", param_no);
+  HloInstruction* fused_parameter =
+      fused_instructions_computation()->AddParameter(
+          HloInstruction::CreateParameter(param_no, new_operand->shape(),
+                                          param_name));
+  AppendOperand(new_operand);
+  return fused_parameter;
+}
+
 void HloInstruction::MergeFusionInstruction(
     HloInstruction* instruction_to_merge) {
   CHECK_EQ(opcode_, HloOpcode::kFusion);
@@ -993,13 +1009,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
       // Clone's operand was not already an operand of the fusion
       // instruction. Add it as an operand and add a corresponding fused
       // parameter instruction.
-      int64 param_no = fused_parameters.size();
-      // Name the parameter after the instruction it represents in the outer
-      // (non-fusion) computation.
-      string param_name = StrCat(operand->name(), ".param_", param_no);
-      fused_param = fused_instructions_computation()->AddParameter(
-          CreateParameter(param_no, operand->shape(), param_name));
-      AppendOperand(operand);
+      fused_param = AddFusionOperand(operand);
     }
     TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
   }
index 50931c563a5319ecdf8493ee5507eb2551b34673..a4c41c9de843f49cd9d8402f8bfece483c176653 100644 (file)
@@ -917,6 +917,9 @@ class HloInstruction {
   // Return true if this operator has a sharding assigned.
   bool has_sharding() const { return sharding_ != nullptr; }
 
+  // Adds a new operand the fusion instruction.
+  HloInstruction* AddFusionOperand(HloInstruction* new_operand);
+
   // Merges the fused instructions from 'instruction_to_merge' into the
   // fused instruction set of 'this', updating operands as necessary.
   //