} break;
case F16:
return Unimplemented("unhandled primitive type: F16.");
+ case BF16: {
+ TF_ASSIGN_OR_RETURN(evaluated_[compare],
+ Compare<bfloat16>(compare->shape(), opcode,
+ lhs_literal, rhs_literal));
+ } break;
case F32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
*Evaluate({operand.get(), gather_indices.get()}));
}
+// Verifies that HloEvaluator evaluates a HLO instruction that performs
+// element-wise comparison with 2 bfloat16 operands.
+TEST_P(HloEvaluatorTest, DoesCompareBF16) {
+ // lhs >= rhs
+ auto lhs = Literal::CreateR2<bfloat16>(
+ {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
+ {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
+ auto rhs = Literal::CreateR2<bfloat16>(
+ {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
+ {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
+ auto expected =
+ Literal::CreateR2<bool>({{false, true, true}, {false, true, true}});
+ TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs),
+ std::move(rhs));
+}
+
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
::testing::ValuesIn(use_bf16_params));