}
}
+StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
+ const Shape& dest_shape) const {
+ if (!ShapeUtil::IsTuple(dest_shape)) {
+ return Convert(dest_shape.element_type());
+ }
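+ // Recursively convert each tuple element to the corresponding subshape of
+ // dest_shape, then reassemble the converted elements into a tuple literal.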
+ std::vector<Literal> elements;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
+ auto element = LiteralView::Create(*this, {i});
+ TF_ASSIGN_OR_RETURN(
+ auto new_element,
+ element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
+ elements.push_back(std::move(*new_element));
+ }
+ auto converted = MakeUnique<Literal>();
+ *converted = Literal::MoveIntoTuple(&elements);
+ return std::move(converted);
+}
+
template <typename NativeT>
bool Literal::Piece::EqualElementsInternal(
const Literal::Piece& other, std::vector<int64>* multi_index) const {
StatusOr<std::unique_ptr<Literal>> Convert(
PrimitiveType primitive_dest_type) const;
+ // Converts this literal to the given shape. Returns an error if the
+ // conversion is not possible.
+ StatusOr<std::unique_ptr<Literal>> ConvertToShape(
+ const Shape& dest_shape) const;
+
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
":hlo_dce",
":hlo_pass",
":tuple_simplifier",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
],
)
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
}
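+ // Record that the fused computation has been processed by the mutation pass.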
+ computations_visited_in_mutation_pass_.insert(
+ fusion->fused_instructions_computation());
}
-void BFloat16Propagation::AdjustFusionParameters(HloInstruction* fusion) {
- CHECK_EQ(fusion->fused_parameters().size(), fusion->operand_count());
- for (int64 i = 0; i < fusion->operand_count(); ++i) {
- auto parameter = fusion->fused_parameter(i);
- ShapeUtil::ForEachMutableSubshape(
- parameter->mutable_shape(),
- [&](Shape* subshape, const ShapeIndex& index) {
- if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
- return;
- }
- PrimitiveType operand_type =
- ShapeUtil::GetSubshape(fusion->operand(i)->shape(), index)
- .element_type();
- if (subshape->element_type() == operand_type) {
- return;
- }
- CHECK(operand_type == F32 || operand_type == BF16);
- subshape->set_element_type(operand_type);
+void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision(
+ HloInstruction* while_hlo) {
+ CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
+
+ // The while node itself has already been analyzed for whether it can output
+ // BF16, and its output shape has been adjusted accordingly. Here we update
+ // the body and condition computations to match the new output shape, and we
+ // recursively process both computations even if the output shape was not
+ // modified.
+ HloComputation* body = while_hlo->while_body();
+ auto body_root = body->root_instruction();
+ HloComputation* condition = while_hlo->while_condition();
+
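+ // Propagate BF16 leaves from the while's output shape to the body root's
+ // shape, so the body produces values in the precision the while outputs.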
+ ShapeUtil::ForEachMutableSubshape(
+ body_root->mutable_shape(),
+ [this, while_hlo, body_root](Shape* subshape, const ShapeIndex& index) {
+ if (subshape->element_type() != F32) {
+ return;
+ }
+ if (ShapeUtil::GetSubshape(while_hlo->shape(), index).element_type() ==
+ BF16) {
+ subshape->set_element_type(BF16);
changed_ = true;
- VLOG(2) << "Fused parameter " << parameter->ToString()
+ VLOG(2) << "While body root " << body_root->ToString()
<< " at shape index " << index
- << " adjusted to match operand in fusion "
- << fusion->ToString();
- });
+ << " changed to BF16 precision for while "
+ << while_hlo->ToString();
+ }
+ });
+
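+ // Determine precisions for the instructions inside the body and the
+ // condition in reverse post order, then mark both computations as visited.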
+ auto body_insts = body->MakeInstructionPostOrder();
+ for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
+ ++inst_it) {
+ DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
}
+ computations_visited_in_mutation_pass_.insert(body);
+
+ auto condition_insts = condition->MakeInstructionPostOrder();
+ for (auto inst_it = condition_insts.rbegin();
+ inst_it != condition_insts.rend(); ++inst_it) {
+ DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
+ }
+ computations_visited_in_mutation_pass_.insert(condition);
}
bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
continue;
}
for (const HloUse& use : value->uses()) {
+ if (!ContainsKey(instructions_visited_in_mutation_pass_,
+ use.instruction)) {
+ // We don't know yet whether use.instruction will consume BF16 since it
+ // hasn't been visited. Although we visit instructions in reverse
+ // topological order, this is still possible because there may be
+ // unvisited instructions that alias the same buffer. In this case, we
+ // aggressively skip this use, and if this causes inconsistency (e.g.,
+ // one use is in BF16 but another use is in F32), it will be resolved at
+ // the end of the BFloat16Propagation pass.
+ continue;
+ }
+ // Any visited user that can accept BF16 has already been updated if
+ // necessary, e.g., the output has been changed to BF16 if it propagates
+ // precision, or a called computation's parameters have been changed to
+ // BF16 for fusions or whiles.
if (use.instruction->opcode() == HloOpcode::kFusion) {
- auto fused_parameter =
+ const auto* fused_parameter =
use.instruction->fused_parameter(use.operand_number);
if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index)
.element_type() != BF16) {
return false;
}
continue;
+ } else if (use.instruction->opcode() == HloOpcode::kWhile) {
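+ // A while's operand feeds both the condition and the body, so the
+ // corresponding parameter in each computation must already be BF16.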
+ const auto* cond_parameter =
+ use.instruction->while_condition()->parameter_instruction(
+ use.operand_number);
+ if (ShapeUtil::GetSubshape(cond_parameter->shape(), use.operand_index)
+ .element_type() != BF16) {
+ return false;
+ }
+ const auto* body_parameter =
+ use.instruction->while_body()->parameter_instruction(
+ use.operand_number);
+ if (ShapeUtil::GetSubshape(body_parameter->shape(), use.operand_index)
+ .element_type() != BF16) {
+ return false;
+ }
+ continue;
}
if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
*use.instruction, use.operand_number)) {
void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
HloInstruction* hlo, bool skip_parameters) {
- // We handle any fusion computation after the instruction is handled, because
- // we need to know a fusion's output shape before propagating inside its fused
- // computation.
- auto cleaner = tensorflow::gtl::MakeCleanup([this, hlo] {
- if (hlo->opcode() == HloOpcode::kFusion) {
- DetermineAndMutateFusionComputationPrecision(hlo);
- }
- });
+ // We handle any fusion computation or while body/condition after the
+ // instruction is handled, because we need to know the output shape of a
+ // fusion or while before propagating inside its computations.
+ bool postpone_processing_called_computations = false;
+ auto cleaner = tensorflow::gtl::MakeCleanup(
+ [this, hlo, &postpone_processing_called_computations] {
+ if (!postpone_processing_called_computations) {
+ if (hlo->opcode() == HloOpcode::kFusion) {
+ DetermineAndMutateFusionComputationPrecision(hlo);
+ } else if (hlo->opcode() == HloOpcode::kWhile) {
+ DetermineAndMutateWhileComputationsPrecision(hlo);
+ }
+ }
+ instructions_visited_in_mutation_pass_.insert(hlo);
+ });
+
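+ // Do not determine precision for a while whose condition or body has more
+ // than one caller; its called computations are not processed here either.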
+ if (hlo->opcode() == HloOpcode::kWhile &&
+ (caller_counts_[hlo->while_condition()] > 1 ||
+ caller_counts_[hlo->while_body()] > 1)) {
+ postpone_processing_called_computations = true;
+ return;
+ }
// Do not change precision for instructions related to entry and exit of a
// computation, and control flow, because this pass might break the interfaces
// or assumptions for them.
if (hlo->opcode() == HloOpcode::kInfeed || //
hlo->opcode() == HloOpcode::kOutfeed || //
- hlo->opcode() == HloOpcode::kConstant || //
hlo->opcode() == HloOpcode::kCustomCall || //
hlo->opcode() == HloOpcode::kCall || //
- hlo->opcode() == HloOpcode::kWhile || //
hlo->opcode() == HloOpcode::kConditional || //
(hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
return;
return true;
}
-Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
- HloModule* module) {
- std::list<HloComputation*> computations_topological_order =
- module->MakeComputationPostOrder();
- for (auto comp_it = computations_topological_order.rbegin();
- comp_it != computations_topological_order.rend(); ++comp_it) {
- auto insts = (*comp_it)->MakeInstructionPostOrder();
- // Do the adjustment on each instruction in the computation in reverse
- // topological order.
- for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
- auto hlo = *inst_it;
- auto adjust_buffer = [this, hlo](Shape* subshape,
- const ShapeIndex& index) {
- if (subshape->element_type() != F32 &&
- subshape->element_type() != BF16) {
- return;
+void BFloat16Propagation::AdjustCalledComputationParameters(
+ HloInstruction* hlo) {
+ auto adjust_computation =
+ [this, hlo](HloComputation* computation,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ // Adjust parameters.
+ CHECK_EQ(operands.size(), computation->num_parameters());
+ for (int64 i = 0; i < operands.size(); ++i) {
+ auto parameter = computation->parameter_instruction(i);
+ ShapeUtil::ForEachMutableSubshape(
+ parameter->mutable_shape(),
+ [this, i, hlo, &operands, parameter](Shape* subshape,
+ const ShapeIndex& index) {
+ if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
+ return;
+ }
+ PrimitiveType operand_type =
+ ShapeUtil::GetSubshape(operands[i]->shape(), index)
+ .element_type();
+ if (subshape->element_type() == operand_type) {
+ return;
+ }
+ CHECK(operand_type == F32 || operand_type == BF16);
+ subshape->set_element_type(operand_type);
+ changed_ = true;
+ VLOG(2) << "Called computation parameter "
+ << parameter->ToString() << " at shape index " << index
+ << " adjusted to match operand in HLO "
+ << hlo->ToString();
+ });
}
- PrimitiveType type = BF16;
- for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
- if (value->shape().element_type() == BF16) {
- continue;
+ };
+
+ switch (hlo->opcode()) {
+ case HloOpcode::kFusion:
+ adjust_computation(hlo->fused_instructions_computation(),
+ hlo->operands());
+ break;
+ case HloOpcode::kWhile:
+ adjust_computation(hlo->while_condition(), hlo->operands());
+ adjust_computation(hlo->while_body(), hlo->operands());
+ break;
+ default:
+ break;
+ }
+}
+
+void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
+ auto adjust_computation = [this, hlo](HloComputation* computation,
+ const Shape& output_shape) {
+ // Adjust root.
+ HloInstruction* root = computation->root_instruction();
+ ShapeUtil::ForEachMutableSubshape(
+ root->mutable_shape(), [this, hlo, root, &output_shape](
+ Shape* subshape, const ShapeIndex& index) {
+ if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
+ return;
}
- CHECK_EQ(value->shape().element_type(), F32);
- type = F32;
- break;
- }
- // It's possible that a user has been changed from BF16 to F32
- // during this final adjustment pass, so we need to check
- // AllUsersConsumeBF16() again.
- if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
- type = F32;
- }
- if (type == F32) {
- for (const auto* value :
- dataflow_->GetValueSet(hlo, index).values()) {
- // We rely on the fact that this adjustment works in reverse
- // topological order. Adding the value to
- // values_that_must_be_kept_as_f32_ will ensure the correctness
- // of the adjustment for HLOs that will be processed later.
- values_that_must_be_kept_as_f32_.insert(value);
+ const PrimitiveType output_type =
+ ShapeUtil::GetSubshape(output_shape, index).element_type();
+ if (subshape->element_type() == output_type) {
+ return;
+ }
+ CHECK(output_type == F32 || output_type == BF16);
+ subshape->set_element_type(output_type);
+ // It's possible that output_type is F32, but the root instruction's
+ // type is BF16; e.g., a fusion node's output was changed to BF16
+ // initially but then adjusted back to F32, and the fusion computation
+ // is now being adjusted after the fusion node.
+ if (output_type == F32) {
+ for (const auto* value :
+ dataflow_->GetValueSet(root, index).values()) {
+ // We rely on the fact that this adjustment works in reverse
+ // topological order so that called computation will be
+ // processed later. Adding the value to
+ // values_that_must_be_kept_as_f32_ will ensure the
+ // correctness of the adjustment for HLOs that will be
+ // processed later.
+ values_that_must_be_kept_as_f32_.insert(value);
+ }
}
+ changed_ = true;
+ VLOG(2) << "Called computation root " << root->ToString()
+ << " at shape index " << index
+ << " adjusted to match output shape of " << hlo->ToString();
+ });
+ };
+
+ switch (hlo->opcode()) {
+ case HloOpcode::kFusion:
+ adjust_computation(hlo->fused_instructions_computation(), hlo->shape());
+ break;
+ case HloOpcode::kWhile:
+ adjust_computation(hlo->while_condition(), hlo->shape());
+ adjust_computation(hlo->while_body(), hlo->shape());
+ break;
+ default:
+ break;
+ }
+}
+
+bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
+ HloComputation* computation,
+ tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations) {
+ bool parameter_changed = false;
+ auto insts = computation->MakeInstructionPostOrder();
+ // Do the adjustment on each instruction in the computation in reverse
+ // topological order.
+ for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+ auto hlo = *inst_it;
+ auto adjust_hlo_output = [this, hlo, &parameter_changed](
+ Shape* subshape, const ShapeIndex& index) {
+ if (subshape->element_type() != F32 && subshape->element_type() != BF16) {
+ return;
+ }
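+ // Use BF16 only if every value aliased with this position is already BF16;
+ // otherwise fall back to F32.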
+ PrimitiveType type = BF16;
+ for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
+ if (value->shape().element_type() == BF16) {
+ continue;
}
+ CHECK_EQ(value->shape().element_type(), F32);
+ type = F32;
+ break;
+ }
+ // It's possible that a user has been changed from BF16 to F32
+ // during this final adjustment pass, so we need to check
+ // AllUsersConsumeBF16() again.
+ if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
+ type = F32;
+ }
+ if (type == F32) {
+ for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
+ // We rely on the fact that this adjustment works in reverse
+ // topological order. Adding the value to
+ // values_that_must_be_kept_as_f32_ will ensure the correctness
+ // of the adjustment for HLOs that will be processed later.
+ values_that_must_be_kept_as_f32_.insert(value);
+ }
+ }
+ if (type != subshape->element_type()) {
subshape->set_element_type(type);
- };
- ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_buffer);
- }
- // Now adjust parameters of fusions inside this computation.
- for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
- auto hlo = *inst_it;
- if (hlo->opcode() == HloOpcode::kFusion) {
- AdjustFusionParameters(hlo);
+ VLOG(2) << "HloInstruction output at shape index " << index
+ << " adjusted to " << *subshape << ": " << hlo->ToString();
+ if (hlo->opcode() == HloOpcode::kParameter) {
+ parameter_changed = true;
+ }
+ }
+ };
+ ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_hlo_output);
+ AdjustCalledComputationRoot(hlo);
+ if (hlo->opcode() == HloOpcode::kWhile) {
+ // We need to run on the while body and condition repeatedly until a fixed
+ // point is reached, i.e., the parameters do not change any more. We may
+ // need more than one iteration because the while input and output alias
+ // each other, so changing one input parameter requires changing the
+ // corresponding output element and thus may transitively require changing
+ // another input parameter. A fixed point will be reached because the
+ // parameters can only be changed from BF16 to F32, not the other way
+ // around.
+ tensorflow::gtl::FlatSet<const HloComputation*> visited_in_while;
+ while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
+ &visited_in_while) ||
+ ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
+ &visited_in_while)) {
+ visited_in_while.clear();
+ ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(),
+ adjust_hlo_output);
+ AdjustCalledComputationRoot(hlo);
}
+ visited_computations->insert(visited_in_while.begin(),
+ visited_in_while.end());
}
}
+ // Now adjust parameters of called computations.
+ for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+ AdjustCalledComputationParameters(*inst_it);
+ }
+ return parameter_changed;
+}
+
+Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
+ HloModule* module) {
+ std::list<HloComputation*> computations_topological_order =
+ module->MakeComputationPostOrder();
+ tensorflow::gtl::FlatSet<const HloComputation*> resolved;
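+ // Visit computations in reverse topological order so callers are resolved
+ // before callees; while bodies and conditions already resolved by the helper
+ // are recorded in 'resolved' and skipped.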
+ for (auto comp_it = computations_topological_order.rbegin();
+ comp_it != computations_topological_order.rend(); ++comp_it) {
+ if (ContainsKey(resolved, *comp_it)) {
+ continue;
+ }
+ ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
+ }
// We could have changed a fusion computation's root shape to have a different
// precision than the fusion node's output, if the fusion root does not
needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape());
}
}
+
+ // We may have converted some constants from F32 to BF16, so adjust the
+ // constant literals in such cases. We do this here instead of when the
+ // constant node's precision is changed because 1) the HloInstruction interface does not
+ // allow resetting the literal so we have to create a new kConstant
+ // instruction to replace the old one, which invalidates dataflow analysis,
+ // and 2) it's possible that a kConstant's output gets changed to BF16 at the
+ // beginning but later on adjusted back to F32, so converting literals here
+ // can avoid repeated conversions.
+ //
+ // TODO(b/73833576): Consider resetting literal in HloInstruction.
+ bool needs_dce = needs_tuple_simplifier;
+ for (auto computation : computations_topological_order) {
+ for (auto hlo : computation->MakeInstructionPostOrder()) {
+ if (hlo->opcode() != HloOpcode::kConstant) {
+ continue;
+ }
+ if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
+ TF_ASSIGN_OR_RETURN(auto converted_literal,
+ hlo->literal().ConvertToShape(hlo->shape()));
+ auto new_constant = computation->AddInstruction(
+ HloInstruction::CreateConstant(std::move(converted_literal)));
+ TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
+ needs_dce = true;
+ }
+ }
+ }
+
if (needs_tuple_simplifier) {
TupleSimplifier tuple_simplifier;
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
+ }
+ if (needs_dce) {
HloDCE dce;
TF_RETURN_IF_ERROR(dce.Run(module).status());
}
// be bitwise identical to that without this pass; this is possible if the
// backend already reduces precision to BF16 on some HLO instructions.
//
-// This pass will not modify the signature of any non-fusion computation.
+// This pass will not modify the signature of a computation, unless it is a
+// fusion computation or its only caller is a while.
//
// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs,
// which has two issues:
bool skip_parameters);
// Special handling in the mutation pass for fusion computations.
+ //
+ // Precondition: hlo->opcode() == kFusion
void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion);
+ // Special handling in the mutation pass for while computations.
+ //
+ // Precondition: hlo->opcode() == kWhile
+ void DetermineAndMutateWhileComputationsPrecision(HloInstruction* while_hlo);
+
+ // The set of HloInstructions that have been visited in the mutation pass.
+ tensorflow::gtl::FlatSet<const HloInstruction*>
+ instructions_visited_in_mutation_pass_;
+
+ // The set of HloComputations that have been visited in the mutation pass.
+ tensorflow::gtl::FlatSet<const HloComputation*>
+ computations_visited_in_mutation_pass_;
+
// ***************************
// Functions called by the final inconsistency resolving pass.
// same precision.
Status ResolveInconsistencyOfAliasingBuffers(HloModule* module);
- // Makes the fusion parameters match the precision of the actual parameters
- // passed to the fusion node.
- void AdjustFusionParameters(HloInstruction* fusion);
+ // Resolves inconsistency of aliasing buffers for the given computation, and
+ // recursively runs on a while instruction's condition and body until a fixed
+ // point is reached.
+ bool ResolveInconsistencyOfAliasingBuffersHelper(
+ HloComputation* computation,
+ tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations);
+
+ // Makes the parameters of called computations match how they are called by
+ // the given HLO.
+ void AdjustCalledComputationParameters(HloInstruction* hlo);
+
+ // Makes the root instructions of called computations match how they are used
+ // by the given HLO.
+ void AdjustCalledComputationRoot(HloInstruction* hlo);
// ***************************
// Functions called and state used by two or more passes.
// The set of F32 HLO values that must be kept in F32.
tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
- // ***************************
- // State used by both passes.
+ // Mapping from each HloComputation to the number of callers to it in the
+ // module. Populated at the beginning of this pass.
+ tensorflow::gtl::FlatMap<const HloComputation*, int64> caller_counts_;
+
const BFloat16Support* bfloat16_support_;
std::unique_ptr<HloDataflowAnalysis> dataflow_;
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
EXPECT_FALSE(OutputsBF16(c));
}
+// Tests that if a constant is converted to BF16 then its literal must also be
+// converted.
+TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+ Array2D<float> array_a(4, 4);
+ array_a.FillUnique(1.0f);
+ Array2D<float> array_b(4, 4);
+ array_b.FillUnique(10.0f);
+
+ HloInstruction* a = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateFromArray(array_a)));
+ HloInstruction* b = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateFromArray(array_b)));
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), dot);
+ EXPECT_TRUE(OutputsBF16(dot->operand(0)));
+ EXPECT_TRUE(OutputsBF16(dot->operand(1)));
+ EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
+ EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
+ LiteralTestUtil::ExpectEqual(
+ dot->operand(0)->literal(),
+ *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)));
+ LiteralTestUtil::ExpectEqual(
+ dot->operand(1)->literal(),
+ *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)));
+}
+
// Tests that BF16 can be propagated through nested tuples.
TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
auto builder = HloComputation::Builder(TestName());
EXPECT_TRUE(OutputsBF16(xpose));
}
+// Tests that BF16 is propagated properly through while computations.
+TEST_F(BFloat16PropagationTest, PropagateThroughWhile) {
+ auto module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "param1"));
+ HloInstruction* add0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
+
+ auto builder_cond = HloComputation::Builder("cond");
+ auto cond_param = builder_cond.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple->shape(), "cond_param"));
+ auto cond_lhs = builder_cond.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, cond_param, 0));
+ auto cond_rhs = builder_cond.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, cond_param, 1));
+ // This add should prevent RHS from using BF16.
+ auto cond_add_rhs = builder_cond.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
+ auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kDot, cond_lhs, cond_add_rhs));
+ builder_cond.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+ auto cond = module->AddEmbeddedComputation(builder_cond.Build());
+
+ auto builder_body = HloComputation::Builder("body");
+ auto body_param = builder_body.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple->shape(), "body_param"));
+ auto body_lhs = builder_body.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, body_param, 0));
+ auto body_rhs = builder_body.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, body_param, 1));
+ auto body_dot = builder_body.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
+ builder_body.AddInstruction(
+ HloInstruction::CreateTuple({body_dot, body_rhs}));
+ auto body = module->AddEmbeddedComputation(builder_body.Build());
+
+ auto while_hlo = builder.AddInstruction(
+ HloInstruction::CreateWhile(tuple->shape(), cond, body, tuple));
+
+ auto lhs = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while_hlo, 0));
+ auto rhs = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while_hlo, 1));
+ auto dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
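+ // Element 0 of the while is consumed only by dots, so it can be BF16 end to
+ // end; element 1 is also consumed by cond_add_rhs, which keeps it in F32.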
+ EXPECT_EQ(computation->root_instruction(), dot);
+ EXPECT_TRUE(OutputsBF16(lhs));
+ EXPECT_FALSE(OutputsBF16(rhs));
+ EXPECT_TRUE(OutputsBF16(body_dot));
+ EXPECT_TRUE(OutputsBF16(body_lhs));
+ EXPECT_FALSE(OutputsBF16(body_rhs));
+ EXPECT_TRUE(OutputsBF16(cond_lhs));
+ EXPECT_FALSE(OutputsBF16(cond_rhs));
+ EXPECT_TRUE(OutputsBF16(add0));
+ EXPECT_FALSE(OutputsBF16(add1));
+}
+
+// Tests that BF16 is not propagated into a computation that is shared by
+// multiple whiles when any one of those whiles prevents the propagation.
+TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
+ auto module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "param1"));
+ HloInstruction* add0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+ HloInstruction* add2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+ HloInstruction* add3 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+ HloInstruction* tuple0 =
+ builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
+ HloInstruction* tuple1 =
+ builder.AddInstruction(HloInstruction::CreateTuple({add2, add3}));
+
+ // Condition computation for the first while.
+ auto builder_cond0 = HloComputation::Builder("cond0");
+ auto cond0_param = builder_cond0.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple0->shape(), "cond0_param"));
+ auto cond0_lhs = builder_cond0.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, cond0_param, 0));
+ auto cond0_rhs = builder_cond0.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, cond0_param, 1));
+ // This add should prevent RHS from using BF16.
+ auto cond0_add_rhs =
+ builder_cond0.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
+ auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs));
+ builder_cond0.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
+ builder_cond0.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})),
+ builder_cond0.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1}))));
+ auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
+
+ // Condition computation for the second while.
+ auto builder_cond1 = HloComputation::Builder("cond1");
+ auto cond1_param = builder_cond1.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple1->shape(), "cond1_param"));
+ auto cond1_lhs = builder_cond1.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, cond1_param, 0));
+ auto cond1_rhs = builder_cond1.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, cond1_param, 1));
+ // This add should prevent LHS from using BF16.
+ auto cond1_add_lhs =
+ builder_cond1.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
+ auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs));
+ builder_cond1.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
+ builder_cond1.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})),
+ builder_cond1.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1}))));
+ auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
+
+ // Body computation shared by both whiles.
+ auto builder_body = HloComputation::Builder("body");
+ auto body_param = builder_body.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple0->shape(), "body_param"));
+ auto body_lhs = builder_body.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, body_param, 0));
+ auto body_rhs = builder_body.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, body_param, 1));
+ auto body_dot = builder_body.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
+ builder_body.AddInstruction(
+ HloInstruction::CreateTuple({body_dot, body_rhs}));
+ auto body = module->AddEmbeddedComputation(builder_body.Build());
+
+ auto while0 = builder.AddInstruction(
+ HloInstruction::CreateWhile(tuple0->shape(), cond0, body, tuple0));
+ auto while1 = builder.AddInstruction(
+ HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1));
+
+ auto lhs = builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kDot,
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while0, 0)),
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while0, 1))));
+ auto rhs = builder.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kDot,
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while1, 0)),
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while1, 1))));
+ auto dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
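+ // The two whiles share the body but impose conflicting constraints: cond0
+ // blocks BF16 on the RHS and cond1 blocks it on the LHS, so the shared body
+ // and the condition inputs stay in F32, while the adds local to each
+ // condition (which feed only dots) can still be BF16.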
+ EXPECT_FALSE(OutputsBF16(body_dot));
+ EXPECT_FALSE(OutputsBF16(body_rhs));
+ EXPECT_FALSE(OutputsBF16(body_lhs));
+ EXPECT_FALSE(OutputsBF16(cond0_lhs));
+ EXPECT_FALSE(OutputsBF16(cond0_rhs));
+ EXPECT_FALSE(OutputsBF16(cond1_lhs));
+ EXPECT_FALSE(OutputsBF16(cond1_rhs));
+ EXPECT_TRUE(OutputsBF16(cond0_add_rhs));
+ EXPECT_TRUE(OutputsBF16(cond1_add_lhs));
+ EXPECT_EQ(computation->root_instruction(), dot);
+}
+
} // namespace xla