[XLA] Skip BF16 output conversion folding when CRS is the root.
authorYuanzhong Xu <yuanzx@google.com>
Tue, 22 May 2018 20:59:48 +0000 (13:59 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 21:02:24 +0000 (14:02 -0700)
PiperOrigin-RevId: 197618934

tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc

index 08d0152..1b8b2d2 100644 (file)
@@ -182,15 +182,26 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
 
 Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum(
     HloInstruction* crs) {
-  if (!ShapeUtil::IsTuple(crs->shape()) ||
-      !bfloat16_support_->SupportsMixedPrecisions(*crs)) {
-    return DefaultAction(crs);
-  }
-
   // First use DefaultAction() to handle the operands. It can't handle
   // tuple-shaped output.
   TF_RETURN_IF_ERROR(DefaultAction(crs));
 
+  if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) {
+    return Status::OK();
+  }
+
+  // If the output is not a tuple, we don't need special handling.
+  if (!ShapeUtil::IsTuple(crs->shape())) {
+    return Status::OK();
+  }
+
+  // If crs is the root instruction, we should keep its original output type.
+  // The root instruction implicitly has a use from being the result of the
+  // computation, and the code below does not take this use into account.
+  if (crs == computation_->root_instruction()) {
+    return Status::OK();
+  }
+
   // Then do per-tuple-element handling on the output.
   std::vector<std::vector<HloInstruction*>> per_tuple_element_gtes(
       crs->operand_count());