Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum(
HloInstruction* crs) {
- if (!ShapeUtil::IsTuple(crs->shape()) ||
- !bfloat16_support_->SupportsMixedPrecisions(*crs)) {
- return DefaultAction(crs);
- }
-
// First use DefaultAction() to handle the operands. It can't handle
// tuple-shaped output.
TF_RETURN_IF_ERROR(DefaultAction(crs));
+ if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) {
+ return Status::OK();
+ }
+
+ // If the output is not a tuple, we don't need special handling.
+ if (!ShapeUtil::IsTuple(crs->shape())) {
+ return Status::OK();
+ }
+
+ // If crs is the root instruction, we should keep its original output type.
+ // The root instruction implicitly has a use from being the result of the
+ // computation, and the code below does not take this use into account.
+ if (crs == computation_->root_instruction()) {
+ return Status::OK();
+ }
+
// Then do per-tuple-element handling on the output.
std::vector<std::vector<HloInstruction*>> per_tuple_element_gtes(
crs->operand_count());