From a888a0ab8cb20ca310a1eec9aab006eaf11309b7 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 10 May 2018 17:06:27 -0700 Subject: [PATCH] Add a HLO evaluator test case for gather PiperOrigin-RevId: 196193959 --- .../compiler/xla/service/hlo_evaluator_test.cc | 25 ++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index cc16446..8e9688c 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -2005,6 +2005,31 @@ ENTRY main { *Evaluate({operand.get(), gather_indices.get()})); } +TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { + const string hlo_text = R"( +HloModule GatherXd + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1} +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = Literal::CreateR1({0, 1, 2}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0}, {1}}, {{2}, {1}}}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{0, 1}, {2, 1}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. TEST_P(HloEvaluatorTest, DoesCompareBF16) { -- 2.7.4