From 0a799feaea50d4e48e8daa1f3954427fdccd76f1 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 26 Feb 2018 10:17:15 -0800 Subject: [PATCH] Generalize the gather_indices dimension that stores indices This is now exposed as a index_vector_dim dimension number. Also fixed an off-by-one error in ValidateGatherDimensionNumbers in the expression computing output_shape_rank. PiperOrigin-RevId: 187040748 --- tensorflow/compiler/xla/service/hlo_instruction.cc | 9 +- tensorflow/compiler/xla/service/hlo_instruction.h | 3 +- .../compiler/xla/service/hlo_instruction_test.cc | 43 ++++- tensorflow/compiler/xla/service/shape_inference.cc | 42 +++-- .../compiler/xla/service/shape_inference_test.cc | 191 +++++++++++++++++---- tensorflow/compiler/xla/xla_data.proto | 4 + .../performance/xla/operation_semantics.md | 61 ++++--- 7 files changed, 274 insertions(+), 79 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index b7dd055..a534d8f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1172,7 +1172,8 @@ bool HloInstruction::HasSideEffect() const { /* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims) { + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; for (int64 output_window_dim : output_window_dims) { gather_dim_numbers.add_output_window_dims(output_window_dim); @@ -1184,6 +1185,7 @@ bool HloInstruction::HasSideEffect() const { gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); } + gather_dim_numbers.set_index_vector_dim(index_vector_dim); return gather_dim_numbers; } @@ -3369,9 +3371,12 @@ string HloInstruction::GatherDimensionNumbersToString() const { string gather_dims_to_operand_dims = StrCat( "gather_dims_to_operand_dims={", Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); return Join>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims}, + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, ", "); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e4d22e5..e4c8621 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -502,7 +502,8 @@ class HloInstruction { static GatherDimensionNumbers MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims); + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim); // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 32d3ed2..f2980d3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1271,7 +1271,7 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } -TEST_F(HloInstructionTest, StringifyGather) { +TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); Shape gather_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); @@ -1291,7 +1291,8 @@ TEST_F(HloInstructionTest, StringifyGather) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26})); HloModule module(TestName()); @@ -1303,7 +1304,43 @@ TEST_F(HloInstructionTest, StringifyGather) { "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " "gather_dims_to_operand_dims={0,1,2,3,4}, " - "window_bounds={30,29,28,27,26}"); + "index_vector_dim=4, window_bounds={30,29,28,27,26}"); +} + +TEST_F(HloInstructionTest, StringifyGather_1) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape gather_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape gather_result_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Gather"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* gather_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, gather_indices_tensor_shape, "gather_indices")); + + HloInstruction* gather_instruction = + builder.AddInstruction(HloInstruction::CreateGather( + gather_result_shape, input, gather_indices, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gather_instruction->ToString(), + "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " + "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " + "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " + "gather_dims_to_operand_dims={0,1,2,3,4}, " + "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } } // namespace diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c969275..607a672 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2467,27 +2467,27 @@ static Status ValidateGatherDimensionNumbers( const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); const int64 output_shape_rank = - output_window_dim_count + gather_indices_shape.size(); + output_window_dim_count + gather_indices_shape.size() - 1; for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { int64 window_index = dim_numbers.output_window_dims(i); if (window_index < 0 || window_index >= output_shape_rank) { return InvalidArgument( "Window index %d in gather op is out of bounds; got %lld, but should " - "have been in" - "[0,%lld)", + "have been in [0,%lld)", i, window_index, output_shape_rank); } } if (dim_numbers.gather_dims_to_operand_dims_size() != - gather_indices_shape.back()) { + gather_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( - "There must be exactly as many elements in gather_dims_to_operand_dims " - "as there are elements in the last dimension of %%gather_indices; got: " - "%d, expected %lld", + "Gather op has %d elements in gather_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of gather_indices is " + "%lld. These two numbers must be equal.", dim_numbers.gather_dims_to_operand_dims_size(), - gather_indices_shape.back()); + dim_numbers.index_vector_dim(), + gather_indices_shape[dim_numbers.index_vector_dim()]); } for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { @@ -2550,24 +2550,33 @@ static Status ValidateGatherDimensionNumbers( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( gather_indices_shape, "gather indices operand of gather op")); - if (gather_indices_shape.dimensions_size() < 1) { + if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { return InvalidArgument( - "Gather indices parameter must at least of rank 1; got %s", + "Gather indices parameter must be an integral tensor; got %s", ShapeUtil::HumanString(gather_indices_shape).c_str()); } - if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if + // index_vector_dim is rank(P). The bounds of this expanded shape is + // stored in expanded_gather_indices_shape. + + if (gather_indices_shape.dimensions_size() < + gather_dim_numbers.index_vector_dim() || + gather_dim_numbers.index_vector_dim() < 0) { return InvalidArgument( - "Gather indices parameter must be an integral tensor; got %s", - ShapeUtil::HumanString(gather_indices_shape).c_str()); + "Gather index leaf dimension must be within [0, rank(gather_indices) + " + "1). rank(gather_indices) is %d and gather index leaf dimension is " + "%lld.", + gather_indices_shape.dimensions_size(), + gather_dim_numbers.index_vector_dim()); } std::vector expanded_gather_indices_shape; - // We implicitly reshape gather indices of shape P[N] to P[N,1]. expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); c_copy(gather_indices_shape.dimensions(), std::back_inserter(expanded_gather_indices_shape)); - if (expanded_gather_indices_shape.size() == 1) { + if (expanded_gather_indices_shape.size() == + gather_dim_numbers.index_vector_dim()) { expanded_gather_indices_shape.push_back(1); } @@ -2632,6 +2641,9 @@ static Status ValidateGatherDimensionNumbers( } current_bound = window_bounds[window_dims_seen++]; } else { + if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { + gather_dims_seen++; + } current_bound = expanded_gather_indices_shape[gather_dims_seen++]; } diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 7eb1208..029d2b3 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1530,11 +1530,17 @@ TEST_F(ShapeInferenceTest, BadSlice) { class GatherShapeInferenceTest : public ShapeInferenceTest { protected: + const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); + const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32}); const Shape s64_4d_tensor_10_9_8_7_1_ = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}); const Shape s64_4d_tensor_10_9_8_7_5_ = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + const Shape s64_4d_tensor_5_10_9_7_6_ = + ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6}); + const Shape s64_4d_tensor_10_9_5_7_6_ = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); const Shape f32_5d_tensor_50_49_48_47_46_ = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( @@ -1548,7 +1554,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) @@ -1562,7 +1569,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{1}, /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}), + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), /*window_bounds=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) @@ -1576,7 +1584,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4}, /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}), + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), /*window_bounds=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) @@ -1591,7 +1600,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1599,12 +1609,85 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { << ShapeUtil::HumanString(gather_shape); } +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { + // This is equivalent to a dynamic slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3, 4}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { + // The gather indices "tensor" is a scalar S here that's used to slice out + // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0), + /*window_bounds=*/{1, 30, 29, 28, 27})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) + << ShapeUtil::HumanString(gather_shape); +} + TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1617,7 +1700,8 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { s64_vector_32_, tuple_shape_, HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1625,25 +1709,13 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) { - StatusOr statusor = ShapeInference::InferGatherShape( - s64_vector_32_, s32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), - /*window_bounds=*/{64, 1}); - ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Gather indices parameter must at least of rank 1")) - << statusor.status(); -} - TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1658,7 +1730,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 8, 7}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1674,7 +1747,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 7}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1690,7 +1764,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 99, 100, 101}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1699,13 +1774,30 @@ TEST_F(GatherShapeInferenceTest, } TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 9}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window index 4 in gather op is out of bounds")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{4}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1722,7 +1814,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 19}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1738,7 +1831,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 3}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1755,15 +1849,15 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "There must be exactly as many elements in " - "gather_dims_to_operand_dims " - "as there are elements in the last dimension of %gather_indices")) + HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of " + "gather_indices is 5. These two numbers must be equal.")) << statusor.status(); } @@ -1774,7 +1868,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1791,7 +1886,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1808,7 +1904,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{2, 1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{1, 1, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1822,7 +1919,8 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{2}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 1, 300, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1838,7 +1936,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1855,7 +1954,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1864,5 +1964,22 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } +TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/32), + /*window_bounds=*/{30, 29, 28, 27, 26}); + + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather index leaf dimension must be within [0, " + "rank(gather_indices) + 1)")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 28620c3..1f16e6d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -418,6 +418,10 @@ message GatherDimensionNumbers { // transforms the gather index looked up from the gather_indices tensor into // the starting index in the input space. repeated int64 gather_dims_to_operand_dims = 3; + + // The dimension in the gather_indices input that contains the starting + // indices. + int64 index_vector_dim = 4; } // Operation requests that are all collected as a tagged union with a oneof diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index b0abf5f..b2190c5 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -1050,6 +1050,9 @@ For a more intuitive description, see the "Informal Description" section below. : : : indices of the slices we're : : : : we're stitching together into : : : : the output tensor. : +|`index_vector_dim` | `int64` | The dimension in | +: : : `gather_indices` that contains : +: : : the starting indices. : |`output_window_dims` | `ArraySlice` | The set of dimensions in the | : : : output shape that are _window : : : : dimensions_ (defined below). : @@ -1066,22 +1069,20 @@ For a more intuitive description, see the "Informal Description" section below. : : : `output_window_dims`) and the window : : : : dimensions that are elided (via : : : : `elided_window_dims`). : -|`gather_dims_to_operand_dims` | `ArraySlice` | A dimension map (the | +|`gather_dims_to_operand_dims` | `ArraySlice` | A dimension map (the | : : : array is interpreted as mapping `i` to : : : : `gather_dims_to_operand_dims[i]`) from : : : : the gather indices in `gather_indices` to : : : : the operand index space. It has to be : : : : one-to-one and total. : -If `gather_indices` is a vector with `N` elements then we implicitly reshape it -to a tensor of shape `[N,1]` before proceeding. - For every index `Out` in the output tensor, we compute two things (more precisely described later): - - An index into the first `gather_indices.rank` - `1` dimensions of - `gather_indices`, which gives us a starting index of a slice, _operand - slice_, in the operand tensor. + - An index into `gather_indices.rank` - `1` dimensions of `gather_indices`, + which gives us a starting index of a slice, _operand slice_, in the operand + tensor. These `gather_indices.rank` - `1` dimensions are all the dimensions + in `gather_indices` except `index_vector_dim`. - A _window index_ that has the same rank as the operand. This index is composed of the values in `Out` at dimensions `output_window_dims`, embedded @@ -1093,29 +1094,42 @@ should be present in the output at index `Out`. The output is a tensor of rank `output_window_dims.size` + `gather_indices.rank` - `1`. Additionally, as a shorthand, we define `output_gather_dims` of type `ArraySlice` as the set of dimensions in the output shape but not in -`output_window_dims`, in ascending order. E.g. if the output tensor has rank 5, -`output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`, `3`} +`output_window_dims`, in ascending order. E.g. if the output tensor has rank +`5`, `output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`, +`3`} + +If `index_vector_dim` is equal to `gather_indices.rank` we implicitly +consider `gather_indices` to have a trailing `1` dimension (i.e. if +`gather_indices` was of shape `[6,7]` and `index_vector_dim` is `2` then +we implicitly consider the shape of `gather_indices` to be `[6,7,1]`). The bounds for the output tensor along dimension `i` is computed as follows: 1. If `i` is present in `output_gather_dims` (i.e. is equal to - `output_gather_dims[k]` for some `k`) then we pick the corresponding - dimension bounds out of `gather_indices.shape` (i.e. pick - `gather_indices.shape.dims[k]`). + `output_gather_dims[k]` for some `k`) then we pick the corresponding + dimension bounds out of `gather_indices.shape`, skipping + `index_vector_dim` (i.e. pick `gather_indices.shape.dims`[`k`] if `k` + < `index_vector_dim` and `gather_indices.shape.dims`[`k`+`1`] + otherwise). 2. If `i` is present in `output_window_dims` (i.e. equal to - `output_window_dims[k]` for some `k`) then we pick the corresponding bound - out of `window_bounds` after accounting for `elided_window_dims` (i.e. we - pick `adjusted_window_bounds[k]` where `adjusted_window_bounds` is - `window_bounds` with the bounds at indices `elided_window_dims` removed). + `output_window_dims`[`k`] for some `k`) then we pick the corresponding + bound out of `window_bounds` after accounting for `elided_window_dims` + (i.e. we pick `adjusted_window_bounds`[`k`] where `adjusted_window_bounds` + is `window_bounds` with the bounds at indices `elided_window_dims` + removed). The operand index `In` corresponding to an output index `Out` is computed as follows: 1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }. Use `G` to slice - out vector `S` such that `S`[`i`] = `gather_indices`[`G`, `i`]. - 2. Create an index, `S``in`, into `operand` using `S` by scattering - `S` using the `gather_dims_to_operand_dims` map (`S``in` is the - starting indices for _operand slice_ mentioned above.). More precisely: + out vector `S` such that `S`[`i`] = `gather_indices`[Combine(`G`, `i`)] + where Combine(A, b) inserts b at position `index_vector_dim` into A. + Note that this is well defined even if `G` is empty -- if `G` is empty then + `S` = `gather_indices`. + 2. Create an index, `S``in`, into `operand` using `S` by + scattering `S` using the `gather_dims_to_operand_dims` map + (`S``in` is the starting indices for _operand slice_ mentioned + above). More precisely: 1. `S``in`[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` < `gather_dims_to_operand_dims.size`. 2. `S``in`[`_`] = `0` otherwise. @@ -1136,7 +1150,12 @@ follows: `operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}. -### Informal Description +### Informal Description and Examples + +`index_vector_dim` is set to `gather_indices.rank` - `1` in all of the +examples that follow. More interesting values for `index_vector_dim` +does not change the operation fundamentally, but makes the visual representation +more cumbersome. To get an intuition on how all of the above fits together, let's look at an example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor. The -- 2.7.4