From 8a362a264e2219872b390eb8c22286acba32d39f Mon Sep 17 00:00:00 2001
From: Yuanzhong Xu
Date: Tue, 22 May 2018 13:59:48 -0700
Subject: [PATCH] [XLA] Skip BF16 output conversion folding when CRS is the root.

PiperOrigin-RevId: 197618934
---
 .../xla/service/bfloat16_conversion_folding.cc | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
index 08d0152..1b8b2d2 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
@@ -182,15 +182,26 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
 
 Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum(
     HloInstruction* crs) {
-  if (!ShapeUtil::IsTuple(crs->shape()) ||
-      !bfloat16_support_->SupportsMixedPrecisions(*crs)) {
-    return DefaultAction(crs);
-  }
-
   // First use DefaultAction() to handle the operands. It can't handle
   // tuple-shaped output.
   TF_RETURN_IF_ERROR(DefaultAction(crs));
 
+  if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) {
+    return Status::OK();
+  }
+
+  // If the output is not a tuple, we don't need special handling.
+  if (!ShapeUtil::IsTuple(crs->shape())) {
+    return Status::OK();
+  }
+
+  // If crs is the root instruction, we should keep its original output type.
+  // The root instruction implicitly has a use from being the result of the
+  // computation, and the code below does not take this use into account.
+  if (crs == computation_->root_instruction()) {
+    return Status::OK();
+  }
+
   // Then do per-tuple-element handling on the output.
   std::vector<std::vector<HloInstruction*>> per_tuple_element_gtes(
       crs->operand_count());
-- 
2.7.4