From: Yuanzhong Xu
Date: Tue, 22 May 2018 20:59:48 +0000 (-0700)
Subject: [XLA] Skip BF16 output conversion folding when CRS is the root.
X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~200
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8a362a264e2219872b390eb8c22286acba32d39f;p=platform%2Fupstream%2Ftensorflow.git

[XLA] Skip BF16 output conversion folding when CRS is the root.

PiperOrigin-RevId: 197618934
---

diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
index 08d0152..1b8b2d2 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
@@ -182,15 +182,26 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
 
 Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum(
     HloInstruction* crs) {
-  if (!ShapeUtil::IsTuple(crs->shape()) ||
-      !bfloat16_support_->SupportsMixedPrecisions(*crs)) {
-    return DefaultAction(crs);
-  }
-
   // First use DefaultAction() to handle the operands. It can't handle
   // tuple-shaped output.
   TF_RETURN_IF_ERROR(DefaultAction(crs));
 
+  if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) {
+    return Status::OK();
+  }
+
+  // If the output is not a tuple, we don't need special handling.
+  if (!ShapeUtil::IsTuple(crs->shape())) {
+    return Status::OK();
+  }
+
+  // If crs is the root instruction, we should keep its original output type.
+  // The root instruction implicitly has a use from being the result of the
+  // computation, and the code below does not take this use into account.
+  if (crs == computation_->root_instruction()) {
+    return Status::OK();
+  }
+
   // Then do per-tuple-element handling on the output.
   std::vector<std::vector<HloInstruction*>> per_tuple_element_gtes(
       crs->operand_count());
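
For context, the invariant behind the new early return can be read as the
following minimal sketch. It is not code from this patch; the helper name
SafeToRewriteOutputType is hypothetical, and only the included XLA headers
and the root_instruction() accessor are assumed from the real tree.

// Minimal sketch, assuming XLA's HLO API as of this commit. The helper
// name is hypothetical and exists only to state the rule the patch relies
// on: the root instruction has an implicit use, because its shape is the
// computation's result shape.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {

// Returns false when rewriting `inst`'s output element type is unsafe even
// if every explicit user could absorb the change: callers of the
// computation still observe the root's original result type.
bool SafeToRewriteOutputType(const HloInstruction* inst,
                             const HloComputation* computation) {
  return inst != computation->root_instruction();
}

}  // namespace xla

With this guard, a tuple-shaped cross-replica-sum that is the root keeps its
original (e.g. F32) output type even when all of its explicit users are
conversions to BF16, which is exactly the case the folding pass would
otherwise mishandle.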