Avoid generating degenerate dimensions during gather expansions
author     Sanjoy Das <sanjoy@google.com>
           Wed, 18 Apr 2018 16:03:21 +0000 (09:03 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Wed, 18 Apr 2018 16:06:11 +0000 (09:06 -0700)
This gets rid of two cases that used to introduce degenerate dimensions
(dimensions with bound = 1) into the while loop state:

 - Previously we'd explicitly reshape the indices of gathers that use scalar
   indices to have a minor degenerate dimension.  With this CL we no longer do
   that - instead we push this reshape into the code that looks up the index
   vector from the gather indices tensor.
 - Previously we'd have the accumulator (the tensor we're
   dynamic-update-slice-ing into) contain all of the degenerate window dims that
   the gather op would later elide (after the while loop).  With this CL we
   eagerly elide these dimensions as we slice out individual windows from the
   operand.
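
For example (using the shapes from the test case added below): when expanding
a gather of an s32[3,3] operand with scalar indices s32[2], window_bounds={3,1}
and elided_window_dims={1}, the while loop used to carry the indices reshaped
to s32[2,1] and an accumulator of shape s32[2,3,1]; with this CL the loop
state stays at (s32[3,3], s32[2], s32[2,3]), plus the induction variable.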

PiperOrigin-RevId: 193365863

tensorflow/compiler/xla/service/gather_expander.cc
tensorflow/compiler/xla/service/gather_expander_test.cc

diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index 1239f56..2d3e4b1 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -28,9 +28,15 @@ using tensorflow::gtl::ArraySlice;
 static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
     HloInstruction* gather_indices, int64 index_vector_dim) {
   const Shape& gather_indices_shape = gather_indices->shape();
+
+  if (gather_indices_shape.dimensions_size() == index_vector_dim) {
+    return gather_indices;
+  }
+
   if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) {
     return gather_indices;
   }
+
   std::vector<int64> permutation;
   permutation.reserve(gather_indices_shape.dimensions_size());
   for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
@@ -42,54 +48,35 @@ static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
   return MakeTransposeHlo(gather_indices, permutation);
 }
 
-// If the gather_indices holds scalar indices (i.e. gather_indices has rank N
-// and index_vector_dim is N) then reshape it to have a trailing degenerate
-// dimension.  This makes the code for slicing out the index vector more
-// uniform.
-static StatusOr<HloInstruction*> DeScalarizeGatherIndices(
-    HloInstruction* gather_indices, int64 index_vector_dim) {
-  const Shape& gather_indices_shape = gather_indices->shape();
-  if (index_vector_dim != gather_indices_shape.dimensions_size()) {
-    return gather_indices;
-  }
-
-  DCHECK_EQ(index_vector_dim, gather_indices_shape.dimensions_size());
-
-  std::vector<int64> result_shape_dims;
-  c_copy(gather_indices_shape.dimensions(),
-         std::back_inserter(result_shape_dims));
-  result_shape_dims.push_back(1);
-
-  return MakeReshapeHlo(result_shape_dims, gather_indices);
-}
-
 // Canonicalizes the gather_indices tensor so that we only have to deal with
 // some specific cases in the while loop that does the heavy lifting.
 //
 // See the "High Level Algorithm" section for a broader picture.
 static StatusOr<HloInstruction*> CanonicalizeGatherIndices(
     HloInstruction* gather_indices, int64 index_vector_dim) {
-  // If gather_indices holds scalar indices, normalize it to hold index vectors
-  // of size 1.
+  // Transpose the non-index-vector dimensions to the front.
   TF_ASSIGN_OR_RETURN(
-      HloInstruction * descalarized_gather_indices,
-      DeScalarizeGatherIndices(gather_indices, index_vector_dim));
+      HloInstruction * transposed_gather_indices,
+      TransposeIndexVectorDimToLast(gather_indices, index_vector_dim));
+  bool indices_are_scalar =
+      index_vector_dim == gather_indices->shape().dimensions_size();
 
-  // Transpose the non-index-vector dimensions to the front.
-  TF_ASSIGN_OR_RETURN(HloInstruction * transposed_gather_indices,
-                      TransposeIndexVectorDimToLast(descalarized_gather_indices,
-                                                    index_vector_dim));
+  // The number of dimensions in gather_indices that are index dimensions.
+  const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1;
 
   // If there is only one index (i.e. gather_indices has rank 1 and this gather
   // is really just a dynamic slice) add a leading degenerate dimension for
   // uniformity.  Otherwise create a "collapsed" leading dimension that subsumes
   // all of the non-index-vector dimensions.
   const Shape& shape = transposed_gather_indices->shape();
-  if (shape.dimensions_size() == 1) {
+  if (shape.dimensions_size() == index_dims_in_gather_indices) {
     return PrependDegenerateDims(transposed_gather_indices, 1);
   } else {
-    return CollapseFirstNDims(transposed_gather_indices,
-                              shape.dimensions_size() - 1);
+    // Collapse all but the trailing (0 or 1) dimensions of gather_indices
+    // that contain the index vectors.
+    return CollapseFirstNDims(
+        transposed_gather_indices,
+        shape.dimensions_size() - index_dims_in_gather_indices);
   }
 }
 
@@ -156,48 +143,73 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
 static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
     const HloInstruction& gather, HloInstruction* induction_var,
     const std::vector<HloInstruction*>& incoming_loop_state) {
+  const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers();
   CHECK_EQ(incoming_loop_state.size(), 3);
   HloInstruction* const operand = incoming_loop_state[0];
   HloInstruction* const gather_indices = incoming_loop_state[1];
   HloInstruction* const output_accumulator = incoming_loop_state[2];
 
-  int64 index_vector_size = gather_indices->shape().dimensions(1);
+  bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1;
+  CHECK_EQ(has_scalar_indices,
+           dim_numbers.index_vector_dim() ==
+               gather.operand(1)->shape().dimensions_size());
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * induction_var_as_vector,
       MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
                        /*result_shape_bounds=*/{1}));
 
-  TF_ASSIGN_OR_RETURN(
-      HloInstruction * index_into_gather_indices,
-      PadVectorWithZeros(induction_var_as_vector,
-                         /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
-
-  TF_ASSIGN_OR_RETURN(
-      HloInstruction * index_vector_2d,
-      MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
-                          {1, index_vector_size}));
+  HloInstruction* index_vector;
 
-  TF_ASSIGN_OR_RETURN(HloInstruction * index_vector,
-                      ElideDegenerateDims(index_vector_2d, {0}));
+  if (has_scalar_indices) {
+    // In this case gather_indices has rank 1 and induction_var_as_vector (of
+    // shape {1}) is an index into this rank 1 tensor.
+    TF_ASSIGN_OR_RETURN(
+        index_vector,
+        MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1}));
+  } else {
+    // In this case gather_indices has rank 2 and induction_var_as_vector (of
+    // shape {1}) is an index into just the first dimension of this rank 2
+    // tensor.
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * index_into_gather_indices,
+        PadVectorWithZeros(induction_var_as_vector,
+                           /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
+
+    int64 index_vector_size = gather_indices->shape().dimensions(1);
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * index_vector_2d,
+        MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
+                            {1, index_vector_size}));
+
+    TF_ASSIGN_OR_RETURN(index_vector,
+                        ElideDegenerateDims(index_vector_2d, {0}));
+  }
 
-  TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice_start,
-                      ExpandIndexVectorIntoOperandSpace(
-                          index_vector, gather.gather_dimension_numbers(),
-                          operand->shape().dimensions_size()));
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * gathered_slice_start,
+      ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
+                                        operand->shape().dimensions_size()));
 
   TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
                       MakeDynamicSliceHlo(operand, gathered_slice_start,
                                           gather.gather_window_bounds()));
 
-  TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice_for_update,
-                      PrependDegenerateDims(gathered_slice, 1));
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * gathered_slice_with_dims_elided,
+      ElideDegenerateDims(gathered_slice,
+                          AsInt64Slice(dim_numbers.elided_window_dims())));
+
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * gathered_slice_for_update,
+      PrependDegenerateDims(gathered_slice_with_dims_elided, 1));
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * index_vector_into_accumulator,
       PadVectorWithZeros(
           induction_var_as_vector, /*zeros_to_prepend=*/0,
-          /*zeros_to_append=*/gathered_slice->shape().dimensions_size()));
+          /*zeros_to_append=*/
+          gathered_slice_with_dims_elided->shape().dimensions_size()));
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * updated_accumulator,
@@ -213,26 +225,20 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
 
 static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
     HloComputation* computation, PrimitiveType element_type,
-    ArraySlice<int64> window_bounds, int64 gather_loop_trip_count) {
+    ArraySlice<int64> window_bounds, int64 gather_loop_trip_count,
+    const GatherDimensionNumbers& dim_numbers) {
   std::vector<int64> accumulator_state_shape_dims;
   accumulator_state_shape_dims.reserve(1 + window_bounds.size());
   accumulator_state_shape_dims.push_back(gather_loop_trip_count);
-  c_copy(window_bounds, std::back_inserter(accumulator_state_shape_dims));
+  for (int64 i = 0; i < window_bounds.size(); i++) {
+    if (!c_binary_search(dim_numbers.elided_window_dims(), i)) {
+      accumulator_state_shape_dims.push_back(window_bounds[i]);
+    }
+  }
   return BroadcastZeros(computation, element_type,
                         accumulator_state_shape_dims);
 }
 
-static StatusOr<HloInstruction*> ElideWindowDimsFromAccumulator(
-    HloInstruction* accumulator, const GatherDimensionNumbers& dim_numbers) {
-  std::vector<int64> dims_to_elide;
-  dims_to_elide.reserve(dim_numbers.elided_window_dims_size());
-  for (int64 elided_window_dim : dim_numbers.elided_window_dims()) {
-    dims_to_elide.push_back(elided_window_dim + 1);
-  }
-
-  return ElideDegenerateDims(accumulator, dims_to_elide);
-}
-
 // `accumulator` is almost the tensor the gather operation would have produced,
 // except that it has the dimensions in the wrong order -- the gather dimensions
 // are the major dimensions and the window dimensions are the minor dimensions.
@@ -331,7 +337,8 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
       HloInstruction * accumulator_init,
       CreateGatherLoopAccumulatorInitValue(
           computation, output_shape.element_type(),
-          gather_instr->gather_window_bounds(), gather_loop_trip_count));
+          gather_instr->gather_window_bounds(), gather_loop_trip_count,
+          gather_instr->gather_dimension_numbers()));
 
   StatusOr<std::vector<HloInstruction*>> gather_loop_result_or_error =
       WhileUtil::MakeCountedLoop(
@@ -346,14 +353,10 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
                       gather_loop_result_or_error);
 
   HloInstruction* accumulator_result = gather_loop_result.back();
-  TF_ASSIGN_OR_RETURN(
-      HloInstruction * accumulator_with_window_dims_elided,
-      ElideWindowDimsFromAccumulator(accumulator_result, dim_numbers));
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * accumulator_with_output_gather_dims_decanonicalized,
-      AdjustGatherDimsInAccumulator(gather_indices->shape(),
-                                    accumulator_with_window_dims_elided,
+      AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result,
                                     dim_numbers.index_vector_dim()));
 
   return PermuteGatherAndWindowDims(
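
To make the new accumulator shape concrete, here is a minimal standalone
sketch (plain C++; AccumulatorShapeDims is a hypothetical stand-in for the
logic inside CreateGatherLoopAccumulatorInitValue, not the actual XLA code) of
how the loop state now skips elided window dimensions up front instead of
carrying them as bound-1 dimensions and eliding them after the loop:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> AccumulatorShapeDims(
    const std::vector<int64_t>& window_bounds,
    // Assumed sorted, matching the c_binary_search use above.
    const std::vector<int64_t>& elided_window_dims,
    int64_t gather_loop_trip_count) {
  std::vector<int64_t> dims;
  dims.reserve(1 + window_bounds.size());
  // The major dimension indexes the gather loop's iterations.
  dims.push_back(gather_loop_trip_count);
  // Window dimensions the gather elides never enter the loop state.
  for (int64_t i = 0; i < static_cast<int64_t>(window_bounds.size()); i++) {
    if (!std::binary_search(elided_window_dims.begin(),
                            elided_window_dims.end(), i)) {
      dims.push_back(window_bounds[i]);
    }
  }
  return dims;
}

int main() {
  // Shapes from the AvoidDegenerateDims test below: window_bounds={3,1},
  // elided_window_dims={1}, trip count 2  =>  {2, 3} rather than {2, 3, 1}.
  for (int64_t dim : AccumulatorShapeDims({3, 1}, {1}, 2)) {
    std::cout << dim << " ";
  }
  std::cout << "\n";
}
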
diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc
index ba41ee8..1c72ca0 100644
--- a/tensorflow/compiler/xla/service/gather_expander_test.cc
+++ b/tensorflow/compiler/xla/service/gather_expander_test.cc
@@ -47,5 +47,62 @@ ENTRY main {
                            "indices are not supported."));
 }
 
+TEST(GatherExpanderTest, AvoidDegenerateDims) {
+  const string hlo_text = R"(
+HloModule TensorFlowGatherV2
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  ROOT gather = s32[3,2] gather(operand, indices),
+      output_window_dims={0},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=1,
+      window_bounds={3, 1}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          tools::Parse(hlo_text));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
+  ASSERT_TRUE(changed);
+
+  HloInstruction* while_instr = nullptr;
+  for (auto* instr : module->entry_computation()->instructions()) {
+    if (instr->opcode() == HloOpcode::kWhile) {
+      ASSERT_EQ(while_instr, nullptr)
+          << "Expected exactly one while instruction in the entry computation "
+             "after gather expansion";
+      while_instr = instr;
+    }
+  }
+
+  ASSERT_NE(while_instr, nullptr)
+      << "Expected exactly one while instruction in the entry computation "
+         "after gather expansion";
+
+  // We want to avoid creating a while loop with shapes that have degenerate
+  // dimensions for TF gather.  In this case we expect the loop state to be of
+  // the shape (sNN[], s32[3,3]{1,0}, s32[2]{0}, s32[2,3]{1,0}).  The leading
+  // sNN is an implementation detail from WhileUtil::MakeCountedLoop so we don't
+  // check it here (though in theory the form of the while loop state is itself
+  // an implementation detail from WhileUtil::MakeCountedLoop).
+
+  const Shape& while_shape = while_instr->shape();
+  ASSERT_TRUE(ShapeUtil::IsTuple(while_shape));
+  ASSERT_EQ(ShapeUtil::TupleElementCount(while_shape), 4);
+
+  EXPECT_TRUE(ShapeUtil::SameDimensions(
+      ShapeUtil::MakeShape(S32, {3, 3}),
+      ShapeUtil::GetTupleElementShape(while_shape, 1)));
+
+  EXPECT_TRUE(ShapeUtil::SameDimensions(
+      ShapeUtil::MakeShape(S32, {2}),
+      ShapeUtil::GetTupleElementShape(while_shape, 2)));
+
+  EXPECT_TRUE(ShapeUtil::SameDimensions(
+      ShapeUtil::MakeShape(S32, {2, 3}),
+      ShapeUtil::GetTupleElementShape(while_shape, 3)));
+}
 }  // namespace
 }  // namespace xla
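
Similarly, a minimal sketch (plain C++; CanonicalIndicesShape is a
hypothetical helper, not part of XLA) of the shape effect of the reworked
CanonicalizeGatherIndices: scalar indices now stay rank 1 instead of being
reshaped to carry a trailing bound-1 index-vector dimension.

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Computes only the *shape* produced by CanonicalizeGatherIndices: the
// index-vector dimension (if any) is moved last, and all other dimensions
// are collapsed into a single leading dimension.
std::vector<int64_t> CanonicalIndicesShape(std::vector<int64_t> dims,
                                           int64_t index_vector_dim) {
  const bool indices_are_scalar =
      index_vector_dim == static_cast<int64_t>(dims.size());
  int64_t index_vector_size = 1;
  if (!indices_are_scalar) {
    index_vector_size = dims[index_vector_dim];
    dims.erase(dims.begin() + index_vector_dim);  // "Transpose" it to last.
  }
  // Collapse the non-index-vector dimensions; the empty product is 1, which
  // matches the degenerate dimension prepended for rank-1 index vectors.
  const int64_t collapsed = std::accumulate(
      dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
  std::vector<int64_t> result{collapsed};
  if (!indices_are_scalar) result.push_back(index_vector_size);
  return result;
}

int main() {
  // Scalar indices from the test: s32[2] with index_vector_dim=1 stays {2};
  // before this CL it would have been canonicalized to {2, 1}.
  for (int64_t dim : CanonicalIndicesShape({2}, 1)) std::cout << dim << " ";
  std::cout << "\n";
}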