[TF:XLA:INTERPRETER] implement bfloat16 comparisons
authorNick Desaulniers <ndesaulniers@google.com>
Fri, 27 Apr 2018 23:14:49 +0000 (16:14 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 27 Apr 2018 23:17:22 +0000 (16:17 -0700)
PiperOrigin-RevId: 194608854

tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/service/hlo_evaluator_test.cc

index c5e3014..f1dcef1 100644 (file)
@@ -2536,6 +2536,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
     } break;
     case F16:
       return Unimplemented("unhandled primitive type: F16.");
+    case BF16: {
+      TF_ASSIGN_OR_RETURN(evaluated_[compare],
+                          Compare<bfloat16>(compare->shape(), opcode,
+                                            lhs_literal, rhs_literal));
+    } break;
     case F32: {
       TF_ASSIGN_OR_RETURN(
           evaluated_[compare],
index dd14dd3..230147a 100644 (file)
@@ -2005,6 +2005,22 @@ ENTRY main {
       *Evaluate({operand.get(), gather_indices.get()}));
 }
 
+// Verifies that HloEvaluator evaluates a HLO instruction that performs
+// element-wise comparison with 2 bfloat16 operands.
+TEST_P(HloEvaluatorTest, DoesCompareBF16) {
+  // lhs >= rhs
+  auto lhs = Literal::CreateR2<bfloat16>(
+      {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
+       {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
+  auto rhs = Literal::CreateR2<bfloat16>(
+      {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
+       {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
+  auto expected =
+      Literal::CreateR2<bool>({{false, true, true}, {false, true, true}});
+  TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs),
+               std::move(rhs));
+}
+
 INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
                         ::testing::ValuesIn(use_bf16_params));