[XLA] Support while loops and constant in HLO BF16 propagation.
authorYuanzhong Xu <yuanzx@google.com>
Fri, 2 Mar 2018 19:15:14 +0000 (11:15 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Mar 2018 19:19:27 +0000 (11:19 -0800)
PiperOrigin-RevId: 187644155

tensorflow/compiler/xla/literal_util.cc
tensorflow/compiler/xla/literal_util.h
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/bfloat16_propagation.cc
tensorflow/compiler/xla/service/bfloat16_propagation.h
tensorflow/compiler/xla/service/bfloat16_propagation_test.cc

index a345e95a8b69a1c55c650b192f1d56e193d92ada..1d1418fc2f7d2f47641bbe5806fc06dfbcb7ebd0 100644 (file)
@@ -1434,6 +1434,24 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
   }
 }
 
+StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
+    const Shape& dest_shape) const {
+  if (!ShapeUtil::IsTuple(dest_shape)) {
+    return Convert(dest_shape.element_type());
+  }
+  std::vector<Literal> elements;
+  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
+    auto element = LiteralView::Create(*this, {i});
+    TF_ASSIGN_OR_RETURN(
+        auto new_element,
+        element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
+    elements.push_back(std::move(*new_element));
+  }
+  auto converted = MakeUnique<Literal>();
+  *converted = Literal::MoveIntoTuple(&elements);
+  return std::move(converted);
+}
+
 template <typename NativeT>
 bool Literal::Piece::EqualElementsInternal(
     const Literal::Piece& other, std::vector<int64>* multi_index) const {
index 1d58f0cbc72794bed659bcba211dfdc0077ebe06..cdc5d807e09e09663f3e03d6556a4b832d9420e5 100644 (file)
@@ -333,6 +333,11 @@ class Literal {
   StatusOr<std::unique_ptr<Literal>> Convert(
       PrimitiveType primitive_dest_type) const;
 
+  // Converts this literal to the given shape. Returns an error is the
+  // conversion is not possible.
+  StatusOr<std::unique_ptr<Literal>> ConvertToShape(
+      const Shape& dest_shape) const;
+
   // Creates a scalar literal value zero of the given primitive type.
   static Literal Zero(PrimitiveType primitive_type);
 
index e4ae812532a5a6221cd15fdc5f92fc1e565f1521..d71790fb2d188c2100d317cd6bfdcd3be26dfea4 100644 (file)
@@ -129,6 +129,7 @@ cc_library(
         ":hlo_dce",
         ":hlo_pass",
         ":tuple_simplifier",
+        "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_tree",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:util",
@@ -148,6 +149,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
     ],
 )
index 6145c690b911dd3c74d2677ceb840ae3b86d5309..7708504dc998f6da35ac5b180cb043d1e83d808a 100644 (file)
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -68,33 +69,53 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision(
   for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
     DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
   }
+  computations_visited_in_mutation_pass_.insert(
+      fusion->fused_instructions_computation());
 }
 
-void BFloat16Propagation::AdjustFusionParameters(HloInstruction* fusion) {
-  CHECK_EQ(fusion->fused_parameters().size(), fusion->operand_count());
-  for (int64 i = 0; i < fusion->operand_count(); ++i) {
-    auto parameter = fusion->fused_parameter(i);
-    ShapeUtil::ForEachMutableSubshape(
-        parameter->mutable_shape(),
-        [&](Shape* subshape, const ShapeIndex& index) {
-          if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
-            return;
-          }
-          PrimitiveType operand_type =
-              ShapeUtil::GetSubshape(fusion->operand(i)->shape(), index)
-                  .element_type();
-          if (subshape->element_type() == operand_type) {
-            return;
-          }
-          CHECK(operand_type == F32 || operand_type == BF16);
-          subshape->set_element_type(operand_type);
+void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision(
+    HloInstruction* while_hlo) {
+  CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
+
+  // We are depending on the while node itself having already been analyzed for
+  // whether it can output BF16 and this has been adjusted in the output shape,
+  // and now we're looking to update the body and condition computations to
+  // match the new output shape, as well as recursively process the whole while
+  // node even if the output shape was not modified.
+  HloComputation* body = while_hlo->while_body();
+  auto body_root = body->root_instruction();
+  HloComputation* condition = while_hlo->while_condition();
+
+  ShapeUtil::ForEachMutableSubshape(
+      body_root->mutable_shape(),
+      [this, while_hlo, body_root](Shape* subshape, const ShapeIndex& index) {
+        if (subshape->element_type() != F32) {
+          return;
+        }
+        if (ShapeUtil::GetSubshape(while_hlo->shape(), index).element_type() ==
+            BF16) {
+          subshape->set_element_type(BF16);
           changed_ = true;
-          VLOG(2) << "Fused parameter " << parameter->ToString()
+          VLOG(2) << "While body root " << body_root->ToString()
                   << " at shape index " << index
-                  << " adjusted to match operand in fusion "
-                  << fusion->ToString();
-        });
+                  << " changed to BF16 precision for while "
+                  << while_hlo->ToString();
+        }
+      });
+
+  auto body_insts = body->MakeInstructionPostOrder();
+  for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
+       ++inst_it) {
+    DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
   }
+  computations_visited_in_mutation_pass_.insert(body);
+
+  auto condition_insts = condition->MakeInstructionPostOrder();
+  for (auto inst_it = condition_insts.rbegin();
+       inst_it != condition_insts.rend(); ++inst_it) {
+    DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
+  }
+  computations_visited_in_mutation_pass_.insert(condition);
 }
 
 bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
@@ -108,14 +129,45 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
       continue;
     }
     for (const HloUse& use : value->uses()) {
+      if (!ContainsKey(instructions_visited_in_mutation_pass_,
+                       use.instruction)) {
+        // We don't know yet whether use.instruction will consume BF16 since it
+        // hasn't been visited. Although we visit instructions in reverse
+        // topological order, this is still possible because there may be
+        // unvisited instruction that alias the same buffer. In this case, we
+        // aggressively skip this use, and if this causes inconsistency (e.g.,
+        // one use is in BF16 but another use is in F32), it will be resolved at
+        // the end of the BFloat16Propagation pass.
+        continue;
+      }
+      // Any visited user that can accept BF16 has already been updated if
+      // necessary, e.g., the output has been changed to BF16 if it propagates
+      // precision, or a called computation's parameters have been changed to
+      // BF16 for fusions or whiles.
       if (use.instruction->opcode() == HloOpcode::kFusion) {
-        auto fused_parameter =
+        const auto* fused_parameter =
             use.instruction->fused_parameter(use.operand_number);
         if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index)
                 .element_type() != BF16) {
           return false;
         }
         continue;
+      } else if (use.instruction->opcode() == HloOpcode::kWhile) {
+        const auto* cond_parameter =
+            use.instruction->while_condition()->parameter_instruction(
+                use.operand_number);
+        if (ShapeUtil::GetSubshape(cond_parameter->shape(), use.operand_index)
+                .element_type() != BF16) {
+          return false;
+        }
+        const auto* body_parameter =
+            use.instruction->while_body()->parameter_instruction(
+                use.operand_number);
+        if (ShapeUtil::GetSubshape(body_parameter->shape(), use.operand_index)
+                .element_type() != BF16) {
+          return false;
+        }
+        continue;
       }
       if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
               *use.instruction, use.operand_number)) {
@@ -149,24 +201,36 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
 
 void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
     HloInstruction* hlo, bool skip_parameters) {
-  // We handle any fusion computation after the instruction is handled, because
-  // we need to know a fusion's output shape before propagating inside its fused
-  // computation.
-  auto cleaner = tensorflow::gtl::MakeCleanup([this, hlo] {
-    if (hlo->opcode() == HloOpcode::kFusion) {
-      DetermineAndMutateFusionComputationPrecision(hlo);
-    }
-  });
+  // We handle any fusion computation or while body/condition after the
+  // instruction is handled, because we need to know the output shape of a
+  // fusion or while before propagating inside its  computations.
+  bool postpone_processing_called_computations = false;
+  auto cleaner = tensorflow::gtl::MakeCleanup(
+      [this, hlo, &postpone_processing_called_computations] {
+        if (!postpone_processing_called_computations) {
+          if (hlo->opcode() == HloOpcode::kFusion) {
+            DetermineAndMutateFusionComputationPrecision(hlo);
+          } else if (hlo->opcode() == HloOpcode::kWhile) {
+            DetermineAndMutateWhileComputationsPrecision(hlo);
+          }
+        }
+        instructions_visited_in_mutation_pass_.insert(hlo);
+      });
+
+  if (hlo->opcode() == HloOpcode::kWhile &&
+      (caller_counts_[hlo->while_condition()] > 1 ||
+       caller_counts_[hlo->while_body()] > 1)) {
+    postpone_processing_called_computations = true;
+    return;
+  }
 
   // Do not change precision for instructions related to entry and exit of a
   // computation, and control flow, because this pass might break the interfaces
   // or assumptions for them.
   if (hlo->opcode() == HloOpcode::kInfeed ||       //
       hlo->opcode() == HloOpcode::kOutfeed ||      //
-      hlo->opcode() == HloOpcode::kConstant ||     //
       hlo->opcode() == HloOpcode::kCustomCall ||   //
       hlo->opcode() == HloOpcode::kCall ||         //
-      hlo->opcode() == HloOpcode::kWhile ||        //
       hlo->opcode() == HloOpcode::kConditional ||  //
       (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
     return;
@@ -231,60 +295,198 @@ bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
   return true;
 }
 
-Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
-    HloModule* module) {
-  std::list<HloComputation*> computations_topological_order =
-      module->MakeComputationPostOrder();
-  for (auto comp_it = computations_topological_order.rbegin();
-       comp_it != computations_topological_order.rend(); ++comp_it) {
-    auto insts = (*comp_it)->MakeInstructionPostOrder();
-    // Do the adjustment on each instruction in the computation in reverse
-    // topological order.
-    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
-      auto hlo = *inst_it;
-      auto adjust_buffer = [this, hlo](Shape* subshape,
-                                       const ShapeIndex& index) {
-        if (subshape->element_type() != F32 &&
-            subshape->element_type() != BF16) {
-          return;
+void BFloat16Propagation::AdjustCalledComputationParameters(
+    HloInstruction* hlo) {
+  auto adjust_computation =
+      [this, hlo](HloComputation* computation,
+                  tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+        // Adjust parameters.
+        CHECK_EQ(operands.size(), computation->num_parameters());
+        for (int64 i = 0; i < operands.size(); ++i) {
+          auto parameter = computation->parameter_instruction(i);
+          ShapeUtil::ForEachMutableSubshape(
+              parameter->mutable_shape(),
+              [this, i, hlo, &operands, parameter](Shape* subshape,
+                                                   const ShapeIndex& index) {
+                if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
+                  return;
+                }
+                PrimitiveType operand_type =
+                    ShapeUtil::GetSubshape(operands[i]->shape(), index)
+                        .element_type();
+                if (subshape->element_type() == operand_type) {
+                  return;
+                }
+                CHECK(operand_type == F32 || operand_type == BF16);
+                subshape->set_element_type(operand_type);
+                changed_ = true;
+                VLOG(2) << "Called computation parameter "
+                        << parameter->ToString() << " at shape index " << index
+                        << " adjusted to match operand in HLO "
+                        << hlo->ToString();
+              });
         }
-        PrimitiveType type = BF16;
-        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
-          if (value->shape().element_type() == BF16) {
-            continue;
+      };
+
+  switch (hlo->opcode()) {
+    case HloOpcode::kFusion:
+      adjust_computation(hlo->fused_instructions_computation(),
+                         hlo->operands());
+      break;
+    case HloOpcode::kWhile:
+      adjust_computation(hlo->while_condition(), hlo->operands());
+      adjust_computation(hlo->while_body(), hlo->operands());
+      break;
+    default:
+      break;
+  }
+}
+
+void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
+  auto adjust_computation = [this, hlo](HloComputation* computation,
+                                        const Shape& output_shape) {
+    // Adjust root.
+    HloInstruction* root = computation->root_instruction();
+    ShapeUtil::ForEachMutableSubshape(
+        root->mutable_shape(), [this, hlo, root, &output_shape](
+                                   Shape* subshape, const ShapeIndex& index) {
+          if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
+            return;
           }
-          CHECK_EQ(value->shape().element_type(), F32);
-          type = F32;
-          break;
-        }
-        // It's possible that a user has been changed from BF16 to F32
-        // during this final adjustment pass, so we need to check
-        // AllUsersConsumeBF16() again.
-        if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
-          type = F32;
-        }
-        if (type == F32) {
-          for (const auto* value :
-               dataflow_->GetValueSet(hlo, index).values()) {
-            // We rely on the fact that this adjustment works in reverse
-            // topological order. Adding the value to
-            // values_that_must_be_kept_as_f32_ will ensure the correctness
-            // of the adjustment for HLOs that will be processed later.
-            values_that_must_be_kept_as_f32_.insert(value);
+          const PrimitiveType output_type =
+              ShapeUtil::GetSubshape(output_shape, index).element_type();
+          if (subshape->element_type() == output_type) {
+            return;
+          }
+          CHECK(output_type == F32 || output_type == BF16);
+          subshape->set_element_type(output_type);
+          // It's possible that output_type is F32, but the root instruction's
+          // type is BF16; e.g., a fusion node's output was changed to BF16
+          // initially but then adjusted back to F32, and the fusion computation
+          // is now being adjusted after the fusion node.
+          if (output_type == F32) {
+            for (const auto* value :
+                 dataflow_->GetValueSet(root, index).values()) {
+              // We rely on the fact that this adjustment works in reverse
+              // topological order so that called computation will be
+              // processed later. Adding the value to
+              // values_that_must_be_kept_as_f32_ will ensure the
+              // correctness of the adjustment for HLOs that will be
+              // processed later.
+              values_that_must_be_kept_as_f32_.insert(value);
+            }
           }
+          changed_ = true;
+          VLOG(2) << "Called computation root " << root->ToString()
+                  << " at shape index " << index
+                  << " adjusted to match output shape of " << hlo->ToString();
+        });
+  };
+
+  switch (hlo->opcode()) {
+    case HloOpcode::kFusion:
+      adjust_computation(hlo->fused_instructions_computation(), hlo->shape());
+      break;
+    case HloOpcode::kWhile:
+      adjust_computation(hlo->while_condition(), hlo->shape());
+      adjust_computation(hlo->while_body(), hlo->shape());
+      break;
+    default:
+      break;
+  }
+}
+
+bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
+    HloComputation* computation,
+    tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations) {
+  bool parameter_changed = false;
+  auto insts = computation->MakeInstructionPostOrder();
+  // Do the adjustment on each instruction in the computation in reverse
+  // topological order.
+  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+    auto hlo = *inst_it;
+    auto adjust_hlo_output = [this, hlo, &parameter_changed](
+                                 Shape* subshape, const ShapeIndex& index) {
+      if (subshape->element_type() != F32 && subshape->element_type() != BF16) {
+        return;
+      }
+      PrimitiveType type = BF16;
+      for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
+        if (value->shape().element_type() == BF16) {
+          continue;
         }
+        CHECK_EQ(value->shape().element_type(), F32);
+        type = F32;
+        break;
+      }
+      // It's possible that a user has been changed from BF16 to F32
+      // during this final adjustment pass, so we need to check
+      // AllUsersConsumeBF16() again.
+      if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
+        type = F32;
+      }
+      if (type == F32) {
+        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
+          // We rely on the fact that this adjustment works in reverse
+          // topological order. Adding the value to
+          // values_that_must_be_kept_as_f32_ will ensure the correctness
+          // of the adjustment for HLOs that will be processed later.
+          values_that_must_be_kept_as_f32_.insert(value);
+        }
+      }
+      if (type != subshape->element_type()) {
         subshape->set_element_type(type);
-      };
-      ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_buffer);
-    }
-    // Now adjust parameters of fusions inside this computation.
-    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
-      auto hlo = *inst_it;
-      if (hlo->opcode() == HloOpcode::kFusion) {
-        AdjustFusionParameters(hlo);
+        VLOG(2) << "HloInstruction output at shape index " << index
+                << " adjusted to " << *subshape << ": " << hlo->ToString();
+        if (hlo->opcode() == HloOpcode::kParameter) {
+          parameter_changed = true;
+        }
+      }
+    };
+    ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_hlo_output);
+    AdjustCalledComputationRoot(hlo);
+    if (hlo->opcode() == HloOpcode::kWhile) {
+      // We need to run on the while body and condition repeatedly until a fixed
+      // point is reached, i.e., the parameters do not change any more. We may
+      // need more than one iteration because the while input and output alias
+      // each other, so changing one input parameter requires changing the
+      // corresponding output element and thus may transitively require changing
+      // another input parameter. A fixed point will be reached because the
+      // parameters can only be changed from BF16 to F32, not the other way
+      // around.
+      tensorflow::gtl::FlatSet<const HloComputation*> visited_in_while;
+      while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
+                                                         &visited_in_while) ||
+             ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
+                                                         &visited_in_while)) {
+        visited_in_while.clear();
+        ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(),
+                                          adjust_hlo_output);
+        AdjustCalledComputationRoot(hlo);
       }
+      visited_computations->insert(visited_in_while.begin(),
+                                   visited_in_while.end());
     }
   }
+  // Now adjust parameters of called computations.
+  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+    AdjustCalledComputationParameters(*inst_it);
+  }
+  return parameter_changed;
+}
+
+Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
+    HloModule* module) {
+  std::list<HloComputation*> computations_topological_order =
+      module->MakeComputationPostOrder();
+  tensorflow::gtl::FlatSet<const HloComputation*> resolved;
+  for (auto comp_it = computations_topological_order.rbegin();
+       comp_it != computations_topological_order.rend(); ++comp_it) {
+    if (ContainsKey(resolved, *comp_it)) {
+      continue;
+    }
+    ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
+  }
 
   // We could have changed a fusion computation's root shape to have a different
   // precision than the fusion node's output, if the fusion root does not
@@ -382,9 +584,39 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
       needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape());
     }
   }
+
+  // We may have converted some constants from F32 to BF16, so adjust the
+  // constant literals in such cases. We do this here instead of when the
+  // constant node's is changed because 1) the HloInstruction interface does not
+  // allow resetting the literal so we have to create a new kConstant
+  // instruction to replace the old one, which invalidates dataflow analysis,
+  // and 2) it's possible that a kConstant's output gets changed to BF16 at the
+  // beginning but later on adjusted back to F32, so converting literals here
+  // can avoid repeated conversions.
+  //
+  // TODO(b/73833576): Consider resetting literal in HloInstruction.
+  bool needs_dce = needs_tuple_simplifier;
+  for (auto computation : computations_topological_order) {
+    for (auto hlo : computation->MakeInstructionPostOrder()) {
+      if (hlo->opcode() != HloOpcode::kConstant) {
+        continue;
+      }
+      if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
+        TF_ASSIGN_OR_RETURN(auto converted_literal,
+                            hlo->literal().ConvertToShape(hlo->shape()));
+        auto new_constant = computation->AddInstruction(
+            HloInstruction::CreateConstant(std::move(converted_literal)));
+        TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
+        needs_dce = true;
+      }
+    }
+  }
+
   if (needs_tuple_simplifier) {
     TupleSimplifier tuple_simplifier;
     TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
+  }
+  if (needs_dce) {
     HloDCE dce;
     TF_RETURN_IF_ERROR(dce.Run(module).status());
   }
index ccf77d7b4eb6bd7b76b1b6743bd724f42c141f08..89a5ac5db1549877a135182ae8df57fa6bf9d579 100644 (file)
@@ -38,7 +38,8 @@ namespace xla {
 // be bitwise identical to that without this pass; this is possible if the
 // backend already reduces precision to BF16 on some HLO instructions.
 //
-// This pass will not modify the signature of any non-fusion computation.
+// This pass will not modify the signature of a computation, unless it is a
+// fusion computation or its only caller is a while.
 //
 // !!! WARNING !!! This pass can introduce mixed precision in individual HLOs,
 // which has two issues:
@@ -92,8 +93,23 @@ class BFloat16Propagation : public HloPassInterface {
                                               bool skip_parameters);
 
   // Special handling in the mutation pass for fusion computations.
+  //
+  // Precondition: hlo->opcode() == kFusion
   void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion);
 
+  // Special handling in the mutation pass for while computations.
+  //
+  // Precondition: hlo->opcode() == kWhile
+  void DetermineAndMutateWhileComputationsPrecision(HloInstruction* while_hlo);
+
+  // The set of HloInstructions that have been visited in the mutation pass.
+  tensorflow::gtl::FlatSet<const HloInstruction*>
+      instructions_visited_in_mutation_pass_;
+
+  // The set of HloComputations that have been visited in the mutation pass.
+  tensorflow::gtl::FlatSet<const HloComputation*>
+      computations_visited_in_mutation_pass_;
+
   // ***************************
   // Functions called by the final inconsistency resolving pass.
 
@@ -102,9 +118,20 @@ class BFloat16Propagation : public HloPassInterface {
   // same precision.
   Status ResolveInconsistencyOfAliasingBuffers(HloModule* module);
 
-  // Makes the fusion parameters match the precision of the actual parameters
-  // passed to the fusion node.
-  void AdjustFusionParameters(HloInstruction* fusion);
+  // Resolves inconsistency of aliasing buffers for the given computation, and
+  // recursively runs on a while instruction's condition and body until a fixed
+  // point is reached.
+  bool ResolveInconsistencyOfAliasingBuffersHelper(
+      HloComputation* computation,
+      tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations);
+
+  // Makes the parameters of called computations match how they are called by
+  // the given HLO.
+  void AdjustCalledComputationParameters(HloInstruction* hlo);
+
+  // Makes the root instructions of called computations match how they are used
+  // by the given HLO.
+  void AdjustCalledComputationRoot(HloInstruction* hlo);
 
   // ***************************
   // Functions called and state used by two or more passes.
@@ -117,8 +144,10 @@ class BFloat16Propagation : public HloPassInterface {
   // The set of F32 HLO values that must be kept in F32.
   tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
 
-  // ***************************
-  // State used by both passes.
+  // Mapping from each HloComputation to the number of callers to it in the
+  // module. Populated at the beginning of this pass.
+  tensorflow::gtl::FlatMap<const HloComputation*, int64> caller_counts_;
+
   const BFloat16Support* bfloat16_support_;
   std::unique_ptr<HloDataflowAnalysis> dataflow_;
 
index 2047e2053a1a819a2d534f34fc4ba2f8768dc861..5950b004b3da439c442eec6e5e09ea2307fcb018 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
@@ -121,6 +122,41 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
   EXPECT_FALSE(OutputsBF16(c));
 }
 
+// Tests that if a constant is converted to BF16 then its literal must also be
+// converted.
+TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+  Array2D<float> array_a(4, 4);
+  array_a.FillUnique(1.0f);
+  Array2D<float> array_b(4, 4);
+  array_b.FillUnique(10.0f);
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateFromArray(array_a)));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateFromArray(array_b)));
+  HloInstruction* dot = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(PropagatePrecision(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), dot);
+  EXPECT_TRUE(OutputsBF16(dot->operand(0)));
+  EXPECT_TRUE(OutputsBF16(dot->operand(1)));
+  EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
+  EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
+  LiteralTestUtil::ExpectEqual(
+      dot->operand(0)->literal(),
+      *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)));
+  LiteralTestUtil::ExpectEqual(
+      dot->operand(1)->literal(),
+      *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)));
+}
+
 // Tests that BF16 can be propagated through nested tuples.
 TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
   auto builder = HloComputation::Builder(TestName());
@@ -390,4 +426,195 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) {
   EXPECT_TRUE(OutputsBF16(xpose));
 }
 
+// Tests that BF16 is propagated properly through while computations.
+TEST_F(BFloat16PropagationTest, PropagateThroughWhile) {
+  auto module = CreateNewModule();
+  auto builder = HloComputation::Builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, shape, "param1"));
+  HloInstruction* add0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  HloInstruction* add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  HloInstruction* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
+
+  auto builder_cond = HloComputation::Builder("cond");
+  auto cond_param = builder_cond.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple->shape(), "cond_param"));
+  auto cond_lhs = builder_cond.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, cond_param, 0));
+  auto cond_rhs = builder_cond.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, cond_param, 1));
+  // This add should prevent RHS from using BF16
+  auto cond_add_rhs = builder_cond.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
+  auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
+      shape, HloOpcode::kDot, cond_lhs, cond_add_rhs));
+  builder_cond.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
+      builder_cond.AddInstruction(HloInstruction::CreateSlice(
+          ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
+      builder_cond.AddInstruction(HloInstruction::CreateSlice(
+          ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+  auto cond = module->AddEmbeddedComputation(builder_cond.Build());
+
+  auto builder_body = HloComputation::Builder("body");
+  auto body_param = builder_body.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple->shape(), "body_param"));
+  auto body_lhs = builder_body.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
+  auto body_rhs = builder_body.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, body_param, 1));
+  auto body_dot = builder_body.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
+  builder_body.AddInstruction(
+      HloInstruction::CreateTuple({body_dot, body_rhs}));
+  auto body = module->AddEmbeddedComputation(builder_body.Build());
+
+  auto while_hlo = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple->shape(), cond, body, tuple));
+
+  auto lhs = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, while_hlo, 0));
+  auto rhs = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, while_hlo, 1));
+  auto dot = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(PropagatePrecision(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), dot);
+  EXPECT_TRUE(OutputsBF16(lhs));
+  EXPECT_FALSE(OutputsBF16(rhs));
+  EXPECT_TRUE(OutputsBF16(body_dot));
+  EXPECT_TRUE(OutputsBF16(body_lhs));
+  EXPECT_FALSE(OutputsBF16(body_rhs));
+  EXPECT_TRUE(OutputsBF16(cond_lhs));
+  EXPECT_FALSE(OutputsBF16(cond_rhs));
+  EXPECT_TRUE(OutputsBF16(add0));
+  EXPECT_FALSE(OutputsBF16(add1));
+}
+
+// Tests that BF16 is not propagated through multiple whiles that invoke the
+// same computation as long as one while prevents the propagation.
+TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
+  auto module = CreateNewModule();
+  auto builder = HloComputation::Builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, shape, "param1"));
+  HloInstruction* add0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  HloInstruction* add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  HloInstruction* add2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  HloInstruction* add3 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  HloInstruction* tuple0 =
+      builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
+  HloInstruction* tuple1 =
+      builder.AddInstruction(HloInstruction::CreateTuple({add2, add3}));
+
+  // Condition computation for the first while.
+  auto builder_cond0 = HloComputation::Builder("cond0");
+  auto cond0_param = builder_cond0.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple0->shape(), "cond0_param"));
+  auto cond0_lhs = builder_cond0.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, cond0_param, 0));
+  auto cond0_rhs = builder_cond0.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, cond0_param, 1));
+  // This add should prevent RHS from using BF16
+  auto cond0_add_rhs =
+      builder_cond0.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
+  auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary(
+      shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs));
+  builder_cond0.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
+      builder_cond0.AddInstruction(HloInstruction::CreateSlice(
+          ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})),
+      builder_cond0.AddInstruction(HloInstruction::CreateSlice(
+          ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1}))));
+  auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
+
+  // Condition computation for the second while.
+  auto builder_cond1 = HloComputation::Builder("cond1");
+  auto cond1_param = builder_cond1.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple1->shape(), "cond1_param"));
+  auto cond1_lhs = builder_cond1.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, cond1_param, 0));
+  auto cond1_rhs = builder_cond1.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, cond1_param, 1));
+  // This add should prevent LHS from using BF16
+  auto cond1_add_lhs =
+      builder_cond1.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
+  auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary(
+      shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs));
+  builder_cond1.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
+      builder_cond1.AddInstruction(HloInstruction::CreateSlice(
+          ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})),
+      builder_cond1.AddInstruction(HloInstruction::CreateSlice(
+          ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1}))));
+  auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
+
+  // Body computation shared by both whiles.
+  auto builder_body = HloComputation::Builder("body");
+  auto body_param = builder_body.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple0->shape(), "body_param"));
+  auto body_lhs = builder_body.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
+  auto body_rhs = builder_body.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, body_param, 1));
+  auto body_dot = builder_body.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
+  builder_body.AddInstruction(
+      HloInstruction::CreateTuple({body_dot, body_rhs}));
+  auto body = module->AddEmbeddedComputation(builder_body.Build());
+
+  auto while0 = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple0->shape(), cond0, body, tuple0));
+  auto while1 = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1));
+
+  auto lhs = builder.AddInstruction(HloInstruction::CreateBinary(
+      shape, HloOpcode::kDot,
+      builder.AddInstruction(
+          HloInstruction::CreateGetTupleElement(shape, while0, 0)),
+      builder.AddInstruction(
+          HloInstruction::CreateGetTupleElement(shape, while0, 1))));
+  auto rhs = builder.AddInstruction(HloInstruction::CreateBinary(
+      shape, HloOpcode::kDot,
+      builder.AddInstruction(
+          HloInstruction::CreateGetTupleElement(shape, while1, 0)),
+      builder.AddInstruction(
+          HloInstruction::CreateGetTupleElement(shape, while1, 1))));
+  auto dot = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_FALSE(OutputsBF16(body_dot));
+  EXPECT_FALSE(OutputsBF16(body_rhs));
+  EXPECT_FALSE(OutputsBF16(body_lhs));
+  EXPECT_FALSE(OutputsBF16(cond0_lhs));
+  EXPECT_FALSE(OutputsBF16(cond0_rhs));
+  EXPECT_FALSE(OutputsBF16(cond1_lhs));
+  EXPECT_FALSE(OutputsBF16(cond1_rhs));
+  EXPECT_TRUE(OutputsBF16(cond0_add_rhs));
+  EXPECT_TRUE(OutputsBF16(cond1_add_lhs));
+  EXPECT_EQ(computation->root_instruction(), dot);
+}
+
 }  // namespace xla