Generalize the gather_indices dimension that stores indices
authorSanjoy Das <sanjoy@google.com>
Mon, 26 Feb 2018 18:17:15 +0000 (10:17 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
This is now exposed as an index_vector_dim dimension number.

Also fixed an off-by-one error in ValidateGatherDimensionNumbers in the
expression computing output_shape_rank.

PiperOrigin-RevId: 187040748

tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h
tensorflow/compiler/xla/service/hlo_instruction_test.cc
tensorflow/compiler/xla/service/shape_inference.cc
tensorflow/compiler/xla/service/shape_inference_test.cc
tensorflow/compiler/xla/xla_data.proto
tensorflow/docs_src/performance/xla/operation_semantics.md

index b7dd055..a534d8f 100644 (file)
@@ -1172,7 +1172,8 @@ bool HloInstruction::HasSideEffect() const {
 /* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers(
     tensorflow::gtl::ArraySlice<int64> output_window_dims,
     tensorflow::gtl::ArraySlice<int64> elided_window_dims,
-    tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims) {
+    tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+    int64 index_vector_dim) {
   GatherDimensionNumbers gather_dim_numbers;
   for (int64 output_window_dim : output_window_dims) {
     gather_dim_numbers.add_output_window_dims(output_window_dim);
@@ -1184,6 +1185,7 @@ bool HloInstruction::HasSideEffect() const {
     gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
   }
 
+  gather_dim_numbers.set_index_vector_dim(index_vector_dim);
   return gather_dim_numbers;
 }
 
@@ -3369,9 +3371,12 @@ string HloInstruction::GatherDimensionNumbersToString() const {
   string gather_dims_to_operand_dims = StrCat(
       "gather_dims_to_operand_dims={",
       Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
+  string index_vector_dim = StrCat(
+      "index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
 
   return Join<std::initializer_list<string>>(
-      {output_window_dims, elided_window_dims, gather_dims_to_operand_dims},
+      {output_window_dims, elided_window_dims, gather_dims_to_operand_dims,
+       index_vector_dim},
       ", ");
 }
 
index e4d22e5..e4c8621 100644 (file)
@@ -502,7 +502,8 @@ class HloInstruction {
   static GatherDimensionNumbers MakeGatherDimNumbers(
       tensorflow::gtl::ArraySlice<int64> output_window_dims,
       tensorflow::gtl::ArraySlice<int64> elided_window_dims,
-      tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims);
+      tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+      int64 index_vector_dim);
 
   // Returns the opcode for this instruction.
   HloOpcode opcode() const { return opcode_; }
index 32d3ed2..f2980d3 100644 (file)
@@ -1271,7 +1271,7 @@ TEST_F(HloInstructionTest, Stringification) {
             "true_computation=%TransposeDot, false_computation=%TransposeDot");
 }
 
-TEST_F(HloInstructionTest, StringifyGather) {
+TEST_F(HloInstructionTest, StringifyGather_0) {
   Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
   Shape gather_indices_tensor_shape =
       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
@@ -1291,7 +1291,8 @@ TEST_F(HloInstructionTest, StringifyGather) {
           HloInstruction::MakeGatherDimNumbers(
               /*output_window_dims=*/{4, 5, 6, 7, 8},
               /*elided_window_dims=*/{},
-              /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+              /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+              /*index_vector_dim=*/4),
           /*window_bounds=*/{30, 29, 28, 27, 26}));
 
   HloModule module(TestName());
@@ -1303,7 +1304,43 @@ TEST_F(HloInstructionTest, StringifyGather) {
             "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), "
             "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
             "gather_dims_to_operand_dims={0,1,2,3,4}, "
-            "window_bounds={30,29,28,27,26}");
+            "index_vector_dim=4, window_bounds={30,29,28,27,26}");
+}
+
+TEST_F(HloInstructionTest, StringifyGather_1) {
+  Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
+  Shape gather_indices_tensor_shape =
+      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
+  Shape gather_result_shape =
+      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
+
+  HloComputation::Builder builder("Gather");
+  HloInstruction* input = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
+  HloInstruction* gather_indices =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, gather_indices_tensor_shape, "gather_indices"));
+
+  HloInstruction* gather_instruction =
+      builder.AddInstruction(HloInstruction::CreateGather(
+          gather_result_shape, input, gather_indices,
+          HloInstruction::MakeGatherDimNumbers(
+              /*output_window_dims=*/{4, 5, 6, 7, 8},
+              /*elided_window_dims=*/{},
+              /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+              /*index_vector_dim=*/2),
+          /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+  HloModule module(TestName());
+  module.AddEntryComputation(builder.Build());
+
+  EXPECT_EQ(gather_instruction->ToString(),
+            "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
+            "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
+            "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), "
+            "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
+            "gather_dims_to_operand_dims={0,1,2,3,4}, "
+            "index_vector_dim=2, window_bounds={30,29,28,27,26}");
 }
 
 }  // namespace
index c969275..607a672 100644 (file)
@@ -2467,27 +2467,27 @@ static Status ValidateGatherDimensionNumbers(
 
   const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
   const int64 output_shape_rank =
-      output_window_dim_count + gather_indices_shape.size();
+      output_window_dim_count + gather_indices_shape.size() - 1;
 
   for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
     int64 window_index = dim_numbers.output_window_dims(i);
     if (window_index < 0 || window_index >= output_shape_rank) {
       return InvalidArgument(
           "Window index %d in gather op is out of bounds; got %lld, but should "
-          "have been in"
-          "[0,%lld)",
+          "have been in [0,%lld)",
           i, window_index, output_shape_rank);
     }
   }
 
   if (dim_numbers.gather_dims_to_operand_dims_size() !=
-      gather_indices_shape.back()) {
+      gather_indices_shape[dim_numbers.index_vector_dim()]) {
     return InvalidArgument(
-        "There must be exactly as many elements in gather_dims_to_operand_dims "
-        "as there are elements in the last dimension of %%gather_indices; got: "
-        "%d, expected %lld",
+        "Gather op has %d elements in gather_dims_to_operand_dims and the "
+        "bound of dimension index_vector_dim=%lld of gather_indices is "
+        "%lld. These two numbers must be equal.",
         dim_numbers.gather_dims_to_operand_dims_size(),
-        gather_indices_shape.back());
+        dim_numbers.index_vector_dim(),
+        gather_indices_shape[dim_numbers.index_vector_dim()]);
   }
 
   for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
@@ -2550,24 +2550,33 @@ static Status ValidateGatherDimensionNumbers(
   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
       gather_indices_shape, "gather indices operand of gather op"));
 
-  if (gather_indices_shape.dimensions_size() < 1) {
+  if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
     return InvalidArgument(
-        "Gather indices parameter must at least of rank 1; got %s",
+        "Gather indices parameter must be an integral tensor; got %s",
         ShapeUtil::HumanString(gather_indices_shape).c_str());
   }
 
-  if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
+  // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
+  // index_vector_dim is rank(P).  The bounds of this expanded shape are
+  // stored in expanded_gather_indices_shape.
+
+  if (gather_indices_shape.dimensions_size() <
+          gather_dim_numbers.index_vector_dim() ||
+      gather_dim_numbers.index_vector_dim() < 0) {
     return InvalidArgument(
-        "Gather indices parameter must be an integral tensor; got %s",
-        ShapeUtil::HumanString(gather_indices_shape).c_str());
+        "Gather index leaf dimension must be within [0, rank(gather_indices) + "
+        "1). rank(gather_indices) is %d and gather index leaf dimension is "
+        "%lld.",
+        gather_indices_shape.dimensions_size(),
+        gather_dim_numbers.index_vector_dim());
   }
 
   std::vector<int64> expanded_gather_indices_shape;
-  // We implicitly reshape gather indices of shape P[N] to P[N,1].
   expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
   c_copy(gather_indices_shape.dimensions(),
          std::back_inserter(expanded_gather_indices_shape));
-  if (expanded_gather_indices_shape.size() == 1) {
+  if (expanded_gather_indices_shape.size() ==
+      gather_dim_numbers.index_vector_dim()) {
     expanded_gather_indices_shape.push_back(1);
   }
 
@@ -2632,6 +2641,9 @@ static Status ValidateGatherDimensionNumbers(
       }
       current_bound = window_bounds[window_dims_seen++];
     } else {
+      if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
+        gather_dims_seen++;
+      }
       current_bound = expanded_gather_indices_shape[gather_dims_seen++];
     }
 
index 7eb1208..029d2b3 100644 (file)
@@ -1530,11 +1530,17 @@ TEST_F(ShapeInferenceTest, BadSlice) {
 
 class GatherShapeInferenceTest : public ShapeInferenceTest {
  protected:
+  const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
+  const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
   const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
   const Shape s64_4d_tensor_10_9_8_7_1_ =
       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
   const Shape s64_4d_tensor_10_9_8_7_5_ =
       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
+  const Shape s64_4d_tensor_5_10_9_7_6_ =
+      ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
+  const Shape s64_4d_tensor_10_9_5_7_6_ =
+      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
   const Shape f32_5d_tensor_50_49_48_47_46_ =
       ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
@@ -1548,7 +1554,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
                                        HloInstruction::MakeGatherDimNumbers(
                                            /*output_window_dims=*/{0},
                                            /*elided_window_dims=*/{1},
-                                           /*gather_dims_to_operand_dims=*/{1}),
+                                           /*gather_dims_to_operand_dims=*/{1},
+                                           /*index_vector_dim=*/1),
                                        /*window_bounds=*/{64, 1}));
   EXPECT_TRUE(
       ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
@@ -1562,7 +1569,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
                                        HloInstruction::MakeGatherDimNumbers(
                                            /*output_window_dims=*/{1},
                                            /*elided_window_dims=*/{0},
-                                           /*gather_dims_to_operand_dims=*/{0}),
+                                           /*gather_dims_to_operand_dims=*/{0},
+                                           /*index_vector_dim=*/1),
                                        /*window_bounds=*/{1, 48}));
   EXPECT_TRUE(
       ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
@@ -1576,7 +1584,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
                                        HloInstruction::MakeGatherDimNumbers(
                                            /*output_window_dims=*/{4},
                                            /*elided_window_dims=*/{0},
-                                           /*gather_dims_to_operand_dims=*/{0}),
+                                           /*gather_dims_to_operand_dims=*/{0},
+                                           /*index_vector_dim=*/4),
                                        /*window_bounds=*/{1, 48}));
   EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                                ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
@@ -1591,7 +1600,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
           HloInstruction::MakeGatherDimNumbers(
               /*output_window_dims=*/{4, 5, 6, 7, 8},
               /*elided_window_dims=*/{},
-              /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+              /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+              /*index_vector_dim=*/4),
           /*window_bounds=*/{30, 29, 28, 27, 26}));
   EXPECT_TRUE(ShapeUtil::Equal(
       gather_shape,
@@ -1599,12 +1609,85 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
       << ShapeUtil::HumanString(gather_shape);
 }
 
+TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape gather_shape,
+      ShapeInference::InferGatherShape(
+          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
+          HloInstruction::MakeGatherDimNumbers(
+              /*output_window_dims=*/{4, 5, 6, 7, 8},
+              /*elided_window_dims=*/{},
+              /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+              /*index_vector_dim=*/2),
+          /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+  EXPECT_TRUE(ShapeUtil::Equal(
+      gather_shape,
+      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
+      << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape gather_shape,
+      ShapeInference::InferGatherShape(
+          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
+          HloInstruction::MakeGatherDimNumbers(
+              /*output_window_dims=*/{4, 5, 6, 7, 8},
+              /*elided_window_dims=*/{},
+              /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+              /*index_vector_dim=*/0),
+          /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+  EXPECT_TRUE(ShapeUtil::Equal(
+      gather_shape,
+      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
+      << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
+  // This is equivalent to a dynamic slice.
+  TF_ASSERT_OK_AND_ASSIGN(
+      Shape gather_shape,
+      ShapeInference::InferGatherShape(
+          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
+          HloInstruction::MakeGatherDimNumbers(
+              /*output_window_dims=*/{0, 1, 2, 3, 4},
+              /*elided_window_dims=*/{},
+              /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+              /*index_vector_dim=*/0),
+          /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
+                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
+      << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
+  // The gather indices "tensor" is a scalar S here that's used to slice out
+  // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
+  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+                          ShapeInference::InferGatherShape(
+                              f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
+                              HloInstruction::MakeGatherDimNumbers(
+                                  /*output_window_dims=*/{0, 1, 2, 3},
+                                  /*elided_window_dims=*/{0},
+                                  /*gather_dims_to_operand_dims=*/{0},
+                                  /*index_vector_dim=*/0),
+                              /*window_bounds=*/{1, 30, 29, 28, 27}));
+
+  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
+                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
+      << ShapeUtil::HumanString(gather_shape);
+}
+
 TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
       tuple_shape_, s64_vector_32_,
       HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
                                            /*elided_window_dims=*/{1},
-                                           /*gather_dims_to_operand_dims=*/{1}),
+                                           /*gather_dims_to_operand_dims=*/{1},
+                                           /*index_vector_dim=*/1),
       /*window_bounds=*/{64, 1});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(statusor.status().error_message(),
@@ -1617,7 +1700,8 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
       s64_vector_32_, tuple_shape_,
       HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
                                            /*elided_window_dims=*/{1},
-                                           /*gather_dims_to_operand_dims=*/{1}),
+                                           /*gather_dims_to_operand_dims=*/{1},
+                                           /*index_vector_dim=*/0),
       /*window_bounds=*/{64, 1});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(statusor.status().error_message(),
@@ -1625,25 +1709,13 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
       << statusor.status();
 }
 
-TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) {
-  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
-      s64_vector_32_, s32_,
-      HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
-                                           /*elided_window_dims=*/{1},
-                                           /*gather_dims_to_operand_dims=*/{1}),
-      /*window_bounds=*/{64, 1});
-  ASSERT_FALSE(statusor.ok());
-  EXPECT_THAT(statusor.status().error_message(),
-              HasSubstr("Gather indices parameter must at least of rank 1"))
-      << statusor.status();
-}
-
 TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
       s64_vector_32_, vector_32_,
       HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
                                            /*elided_window_dims=*/{1},
-                                           /*gather_dims_to_operand_dims=*/{1}),
+                                           /*gather_dims_to_operand_dims=*/{1},
+                                           /*index_vector_dim=*/0),
       /*window_bounds=*/{64, 1});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(statusor.status().error_message(),
@@ -1658,7 +1730,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 8, 7},
           /*elided_window_dims=*/{},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(
@@ -1674,7 +1747,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7, 7},
           /*elided_window_dims=*/{},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(
@@ -1690,7 +1764,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 99, 100, 101},
           /*elided_window_dims=*/{},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(statusor.status().error_message(),
@@ -1699,13 +1774,30 @@ TEST_F(GatherShapeInferenceTest,
 }
 
 TEST_F(GatherShapeInferenceTest,
+       InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
+  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+      HloInstruction::MakeGatherDimNumbers(
+          /*output_window_dims=*/{4, 5, 6, 7, 9},
+          /*elided_window_dims=*/{},
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
+      /*window_bounds=*/{30, 29, 28, 27, 26});
+  ASSERT_FALSE(statusor.ok());
+  EXPECT_THAT(statusor.status().error_message(),
+              HasSubstr("Window index 4 in gather op is out of bounds"))
+      << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
        InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7, 8},
           /*elided_window_dims=*/{4},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(
@@ -1722,7 +1814,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7, 8},
           /*elided_window_dims=*/{0, 1, 2, 3, 19},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(statusor.status().error_message(),
@@ -1738,7 +1831,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7, 8},
           /*elided_window_dims=*/{0, 1, 2, 3, 3},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(
@@ -1755,15 +1849,15 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7, 8},
           /*elided_window_dims=*/{},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(
       statusor.status().error_message(),
-      HasSubstr(
-          "There must be exactly as many elements in "
-          "gather_dims_to_operand_dims "
-          "as there are elements in the last dimension of %gather_indices"))
+      HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and "
+                "the bound of dimension index_vector_dim=4 of "
+                "gather_indices is 5. These two numbers must be equal."))
       << statusor.status();
 }
 
@@ -1774,7 +1868,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7, 8},
           /*elided_window_dims=*/{},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(
@@ -1791,7 +1886,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7, 8},
           /*elided_window_dims=*/{},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(
@@ -1808,7 +1904,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7, 8},
           /*elided_window_dims=*/{2, 1},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{1, 1, 28, 27, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(statusor.status().error_message(),
@@ -1822,7 +1919,8 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7},
           /*elided_window_dims=*/{2},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 1, 300, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(statusor.status().error_message(),
@@ -1838,7 +1936,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7, 8},
           /*elided_window_dims=*/{},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 26});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(
@@ -1855,7 +1954,8 @@ TEST_F(GatherShapeInferenceTest,
       HloInstruction::MakeGatherDimNumbers(
           /*output_window_dims=*/{4, 5, 6, 7},
           /*elided_window_dims=*/{1},
-          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/4),
       /*window_bounds=*/{30, 29, 28, 26, 20});
   ASSERT_FALSE(statusor.ok());
   EXPECT_THAT(statusor.status().error_message(),
@@ -1864,5 +1964,22 @@ TEST_F(GatherShapeInferenceTest,
       << statusor.status();
 }
 
+TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
+  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
+      HloInstruction::MakeGatherDimNumbers(
+          /*output_window_dims=*/{4, 5, 6, 7, 8},
+          /*elided_window_dims=*/{},
+          /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+          /*index_vector_dim=*/32),
+      /*window_bounds=*/{30, 29, 28, 27, 26});
+
+  ASSERT_FALSE(statusor.ok());
+  EXPECT_THAT(statusor.status().error_message(),
+              HasSubstr("Gather index leaf dimension must be within [0, "
+                        "rank(gather_indices) + 1)"))
+      << statusor.status();
+}
+
 }  // namespace
 }  // namespace xla
index 28620c3..1f16e6d 100644 (file)
@@ -418,6 +418,10 @@ message GatherDimensionNumbers {
   // transforms the gather index looked up from the gather_indices tensor into
   // the starting index in the input space.
   repeated int64 gather_dims_to_operand_dims = 3;
+
+  // The dimension in the gather_indices input that contains the starting
+  // indices.
+  int64 index_vector_dim = 4;
 }
 
 // Operation requests that are all collected as a tagged union with a oneof
index b0abf5f..b2190c5 100644 (file)
@@ -1050,6 +1050,9 @@ For a more intuitive description, see the "Informal Description" section below.
 :                  :                         : indices of the slices we're     :
 :                  :                         : stitching together into         :
 :                  :                         : the output tensor.              :
+|`index_vector_dim`  | `int64`               | The dimension in                |
+:                  :                         : `gather_indices` that contains  :
+:                  :                         : the starting indices.           :
 |`output_window_dims` | `ArraySlice<int64>`  | The set of dimensions in the    |
 :                  :                         : output shape that are _window   :
 :                  :                         : dimensions_ (defined below).    :
@@ -1066,22 +1069,20 @@ For a more intuitive description, see the "Informal Description" section below.
 :                  :            : `output_window_dims`) and the window         :
 :                  :            : dimensions that are elided (via              :
 :                  :            : `elided_window_dims`).                       :
-|`gather_dims_to_operand_dims` | `ArraySlice<int64>` | A dimension map (the  |
+|`gather_dims_to_operand_dims` | `ArraySlice<int64>` | A dimension map (the    |
 :                  :            : array is interpreted as mapping `i` to       :
 :                  :            : `gather_dims_to_operand_dims[i]`)  from      :
 :                  :            : the gather indices in `gather_indices` to    :
 :                  :            : the operand index space.  It has to be       :
 :                  :            : one-to-one and total.                        :
 
-If `gather_indices` is a vector with `N` elements then we implicitly reshape it
-to a tensor of shape `[N,1]` before proceeding.
-
 For every index `Out` in the output tensor, we compute two things (more
 precisely described later):
 
-  - An index into the first `gather_indices.rank` - `1` dimensions of
-    `gather_indices`, which gives us a starting index of a slice, _operand
-    slice_, in the operand tensor.
+  - An index into `gather_indices.rank` - `1` dimensions of `gather_indices`,
+    which gives us a starting index of a slice, _operand slice_, in the operand
+    tensor.  These `gather_indices.rank` - `1` dimensions are all the dimensions
+    in `gather_indices` except `index_vector_dim`.
 
   - A _window index_ that has the same rank as the operand.  This index is
     composed of the values in `Out` at dimensions `output_window_dims`, embedded
@@ -1093,29 +1094,42 @@ should be present in the output at index `Out`.
 The output is a tensor of rank `output_window_dims.size` + `gather_indices.rank`
 - `1`.  Additionally, as a shorthand, we define `output_gather_dims` of type
 `ArraySlice<int64>` as the set of dimensions in the output shape but not in
-`output_window_dims`, in ascending order.  E.g. if the output tensor has rank 5,
-`output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`, `3`}
+`output_window_dims`, in ascending order.  E.g. if the output tensor has rank
+`5`, `output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`,
+`3`}
+
+If `index_vector_dim` is equal to `gather_indices.rank` we implicitly
+consider `gather_indices` to have a trailing `1` dimension (i.e. if
+`gather_indices` was of shape `[6,7]` and `index_vector_dim` is `2` then
+we implicitly consider the shape of `gather_indices` to be `[6,7,1]`).
 
 The bounds for the output tensor along dimension `i` is computed as follows:
 
   1. If `i` is present in `output_gather_dims` (i.e. is equal to
-    `output_gather_dims[k]` for some `k`) then we pick the corresponding
-    dimension bounds out of `gather_indices.shape` (i.e. pick
-    `gather_indices.shape.dims[k]`).
+     `output_gather_dims[k]` for some `k`) then we pick the corresponding
+     dimension bounds out of `gather_indices.shape`, skipping
+     `index_vector_dim` (i.e. pick `gather_indices.shape.dims`[`k`] if `k`
+     < `index_vector_dim` and `gather_indices.shape.dims`[`k`+`1`]
+     otherwise).
   2. If `i` is present in `output_window_dims` (i.e. equal to
-     `output_window_dims[k]` for some `k`) then we pick the corresponding bound
-     out of `window_bounds` after accounting for `elided_window_dims` (i.e. we
-     pick `adjusted_window_bounds[k]` where `adjusted_window_bounds` is
-     `window_bounds` with the bounds at indices `elided_window_dims` removed).
+     `output_window_dims`[`k`] for some `k`) then we pick the corresponding
+     bound out of `window_bounds` after accounting for `elided_window_dims`
+     (i.e. we pick `adjusted_window_bounds`[`k`] where `adjusted_window_bounds`
+     is `window_bounds` with the bounds at indices `elided_window_dims`
+     removed).
 
 The operand index `In` corresponding to an output index `Out` is computed as
 follows:
 
   1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }.  Use `G` to slice
-     out vector `S` such that `S`[`i`] = `gather_indices`[`G`, `i`].
-  2. Create an index, `S`<sub>`in`</sub>, into `operand` using `S` by scattering
-     `S` using the `gather_dims_to_operand_dims` map (`S`<sub>`in`</sub> is the
-     starting indices for _operand slice_ mentioned above.).  More precisely:
+     out vector `S` such that `S`[`i`] = `gather_indices`[Combine(`G`, `i`)]
+     where Combine(A, b) inserts b at position `index_vector_dim` into A.
+     Note that this is well defined even if `G` is empty -- if `G` is empty then
+     `S` = `gather_indices`.
+  2. Create an index, `S`<sub>`in`</sub>, into `operand` using `S` by
+     scattering `S` using the `gather_dims_to_operand_dims` map
+     (`S`<sub>`in`</sub> is the starting indices for _operand slice_ mentioned
+     above).  More precisely:
        1. `S`<sub>`in`</sub>[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` <
           `gather_dims_to_operand_dims.size`.
        2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
@@ -1136,7 +1150,12 @@ follows:
 `operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then
 `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
 
-### Informal Description
+### Informal Description and Examples
+
+`index_vector_dim` is set to `gather_indices.rank` - `1` in all of the
+examples that follow.  More interesting values for `index_vector_dim`
+do not change the operation fundamentally, but make the visual
+representation more cumbersome.
 
 To get an intuition on how all of the above fits together, let's look at an
 example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor.  The