return instruction;
}
+HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
+ CHECK_EQ(opcode(), HloOpcode::kFusion);
+ CHECK_EQ(operand_count(),
+ fused_instructions_computation()->parameter_instructions().size());
+ const int64 param_no = operand_count();
+ // Name the parameter after the instruction it represents in the outer
+ // (non-fusion) computation.
+ string param_name = StrCat(new_operand->name(), ".param_", param_no);
+ HloInstruction* fused_parameter =
+ fused_instructions_computation()->AddParameter(
+ HloInstruction::CreateParameter(param_no, new_operand->shape(),
+ param_name));
+ AppendOperand(new_operand);
+ return fused_parameter;
+}
+
void HloInstruction::MergeFusionInstruction(
HloInstruction* instruction_to_merge) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
// Clone's operand was not already an operand of the fusion
// instruction. Add it as an operand and add a corresponding fused
// parameter instruction.
- int64 param_no = fused_parameters.size();
- // Name the parameter after the instruction it represents in the outer
- // (non-fusion) computation.
- string param_name = StrCat(operand->name(), ".param_", param_no);
- fused_param = fused_instructions_computation()->AddParameter(
- CreateParameter(param_no, operand->shape(), param_name));
- AppendOperand(operand);
+ fused_param = AddFusionOperand(operand);
}
TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
}
// Return true if this operator has a sharding assigned.
bool has_sharding() const { return sharding_ != nullptr; }
+ // Adds a new operand the fusion instruction.
+ HloInstruction* AddFusionOperand(HloInstruction* new_operand);
+
// Merges the fused instructions from 'instruction_to_merge' into the
// fused instruction set of 'this', updating operands as necessary.
//