static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
HloInstruction* gather_indices, int64 index_vector_dim) {
const Shape& gather_indices_shape = gather_indices->shape();
+
+ // Scalar-indices case: index_vector_dim == rank of gather_indices means
+ // there is no index vector dimension to move to the back, so return the
+ // input unchanged.
+ if (gather_indices_shape.dimensions_size() == index_vector_dim) {
+ return gather_indices;
+ }
+
if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) {
return gather_indices;
}
+
std::vector<int64> permutation;
permutation.reserve(gather_indices_shape.dimensions_size());
for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
return MakeTransposeHlo(gather_indices, permutation);
}
-// If the gather_indices holds scalar indices (i.e. gather_indices has rank N
-// and index_vector_dim is N) then reshape it to have a trailing degenerate
-// dimension. This makes the code for slicing out the index vector more
-// uniform.
-static StatusOr<HloInstruction*> DeScalarizeGatherIndices(
- HloInstruction* gather_indices, int64 index_vector_dim) {
- const Shape& gather_indices_shape = gather_indices->shape();
- if (index_vector_dim != gather_indices_shape.dimensions_size()) {
- return gather_indices;
- }
-
- DCHECK_EQ(index_vector_dim, gather_indices_shape.dimensions_size());
-
- std::vector<int64> result_shape_dims;
- c_copy(gather_indices_shape.dimensions(),
- std::back_inserter(result_shape_dims));
- result_shape_dims.push_back(1);
-
- return MakeReshapeHlo(result_shape_dims, gather_indices);
-}
-
// Canonicalizes the gather_indices tensors so that we only have deal with some
// specific cases in the while loop that does the heavy lifting.
//
// See the "High Level Algorithm" section for a broader picture.
static StatusOr<HloInstruction*> CanonicalizeGatherIndices(
HloInstruction* gather_indices, int64 index_vector_dim) {
- // If gather_indices holds scalar indices, normalize it to hold index vectors
- // of size 1.
+ // Transpose the non-index-vector dimensions to the front.
TF_ASSIGN_OR_RETURN(
- HloInstruction * descalarized_gather_indices,
- DeScalarizeGatherIndices(gather_indices, index_vector_dim));
+ HloInstruction * transposed_gather_indices,
+ TransposeIndexVectorDimToLast(gather_indices, index_vector_dim));
+ // NOTE(review): scalar-ness is computed from the ORIGINAL (pre-transpose)
+ // gather_indices shape; TransposeIndexVectorDimToLast is a no-op in the
+ // scalar case, so transposed_gather_indices has the same rank here.
+ bool indices_are_scalar =
+ index_vector_dim == gather_indices->shape().dimensions_size();
- // Transpose the non-index-vector dimensions to the front.
- TF_ASSIGN_OR_RETURN(HloInstruction * transposed_gather_indices,
- TransposeIndexVectorDimToLast(descalarized_gather_indices,
- index_vector_dim));
+ // The number of dimensions in gather_indices that are index dimensions.
+ const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1;
// If there is only one index (i.e. gather_indices has rank 1 and this gather
// is really just a dynamic slice) add a leading degenerate dimension for
// uniformity. Otherwise create a "collapsed" leading dimension that subsumes
// all of the non-index-vector dimensions.
const Shape& shape = transposed_gather_indices->shape();
- if (shape.dimensions_size() == 1) {
+ if (shape.dimensions_size() == index_dims_in_gather_indices) {
return PrependDegenerateDims(transposed_gather_indices, 1);
} else {
- return CollapseFirstNDims(transposed_gather_indices,
- shape.dimensions_size() - 1);
+ // Collapse all but the dimensions (0 or 1) in gather_indices containing the
+ // index vectors.
+ return CollapseFirstNDims(
+ transposed_gather_indices,
+ shape.dimensions_size() - index_dims_in_gather_indices);
}
}
static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
const HloInstruction& gather, HloInstruction* induction_var,
const std::vector<HloInstruction*>& incoming_loop_state) {
+ const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers();
CHECK_EQ(incoming_loop_state.size(), 3);
HloInstruction* const operand = incoming_loop_state[0];
HloInstruction* const gather_indices = incoming_loop_state[1];
HloInstruction* const output_accumulator = incoming_loop_state[2];
- int64 index_vector_size = gather_indices->shape().dimensions(1);
+ // After canonicalization gather_indices is rank 1 (scalar indices) or
+ // rank 2 (a batch of index vectors); cross-check against the rank of the
+ // gather's original indices operand.
+ bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1;
+ CHECK_EQ(has_scalar_indices,
+ dim_numbers.index_vector_dim() ==
+ gather.operand(1)->shape().dimensions_size());
TF_ASSIGN_OR_RETURN(
HloInstruction * induction_var_as_vector,
MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/{1}));
- TF_ASSIGN_OR_RETURN(
- HloInstruction * index_into_gather_indices,
- PadVectorWithZeros(induction_var_as_vector,
- /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
-
- TF_ASSIGN_OR_RETURN(
- HloInstruction * index_vector_2d,
- MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
- {1, index_vector_size}));
+ HloInstruction* index_vector;
- TF_ASSIGN_OR_RETURN(HloInstruction * index_vector,
- ElideDegenerateDims(index_vector_2d, {0}));
+ if (has_scalar_indices) {
+ // In this case gather_indices has rank 1 and induction_var_as_vector (of
+ // shape {1}) is an index into this rank 1 tensor.
+ TF_ASSIGN_OR_RETURN(
+ index_vector,
+ MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1}));
+ } else {
+ // In this case gather_indices has rank 2 and induction_var_as_vector (of
+ // shape {1}) is an index into just the first dimension of this rank 2
+ // tensor.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * index_into_gather_indices,
+ PadVectorWithZeros(induction_var_as_vector,
+ /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
+
+ int64 index_vector_size = gather_indices->shape().dimensions(1);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * index_vector_2d,
+ MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
+ {1, index_vector_size}));
+
+ TF_ASSIGN_OR_RETURN(index_vector,
+ ElideDegenerateDims(index_vector_2d, {0}));
+ }
- TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice_start,
- ExpandIndexVectorIntoOperandSpace(
- index_vector, gather.gather_dimension_numbers(),
- operand->shape().dimensions_size()));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * gathered_slice_start,
+ ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
+ operand->shape().dimensions_size()));
TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
MakeDynamicSliceHlo(operand, gathered_slice_start,
gather.gather_window_bounds()));
- TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice_for_update,
- PrependDegenerateDims(gathered_slice, 1));
+ // Drop the degenerate elided-window dims from the slice before appending
+ // it to the accumulator, so the loop state carries no size-1 dims.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * gathered_slice_with_dims_elided,
+ ElideDegenerateDims(gathered_slice,
+ AsInt64Slice(dim_numbers.elided_window_dims())));
+
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * gathered_slice_for_update,
+ PrependDegenerateDims(gathered_slice_with_dims_elided, 1));
TF_ASSIGN_OR_RETURN(
HloInstruction * index_vector_into_accumulator,
PadVectorWithZeros(
induction_var_as_vector, /*zeros_to_prepend=*/0,
- /*zeros_to_append=*/gathered_slice->shape().dimensions_size()));
+ /*zeros_to_append=*/
+ gathered_slice_with_dims_elided->shape().dimensions_size()));
TF_ASSIGN_OR_RETURN(
HloInstruction * updated_accumulator,
static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> window_bounds, int64 gather_loop_trip_count) {
+ ArraySlice<int64> window_bounds, int64 gather_loop_trip_count,
+ const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> accumulator_state_shape_dims;
accumulator_state_shape_dims.reserve(1 + window_bounds.size());
accumulator_state_shape_dims.push_back(gather_loop_trip_count);
- c_copy(window_bounds, std::back_inserter(accumulator_state_shape_dims));
+ // Skip the window bounds of elided dimensions so the accumulator never
+ // carries degenerate dims. NOTE(review): c_binary_search requires
+ // elided_window_dims to be sorted -- confirm the GatherDimensionNumbers
+ // invariant guarantees this.
+ for (int64 i = 0; i < window_bounds.size(); i++) {
+ if (!c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ accumulator_state_shape_dims.push_back(window_bounds[i]);
+ }
+ }
return BroadcastZeros(computation, element_type,
accumulator_state_shape_dims);
}
-static StatusOr<HloInstruction*> ElideWindowDimsFromAccumulator(
- HloInstruction* accumulator, const GatherDimensionNumbers& dim_numbers) {
- std::vector<int64> dims_to_elide;
- dims_to_elide.reserve(dim_numbers.elided_window_dims_size());
- for (int64 elided_window_dim : dim_numbers.elided_window_dims()) {
- dims_to_elide.push_back(elided_window_dim + 1);
- }
-
- return ElideDegenerateDims(accumulator, dims_to_elide);
-}
-
// `accumulator` is almost the tensor the gather operation would have produced,
// except that it has the dimensions in the wrong order -- the gather dimensions
// are the major dimensions and the window dimensions are the minor dimensions.
HloInstruction * accumulator_init,
CreateGatherLoopAccumulatorInitValue(
computation, output_shape.element_type(),
- gather_instr->gather_window_bounds(), gather_loop_trip_count));
+ gather_instr->gather_window_bounds(), gather_loop_trip_count,
+ gather_instr->gather_dimension_numbers()));
StatusOr<std::vector<HloInstruction*>> gather_loop_result_or_error =
WhileUtil::MakeCountedLoop(
gather_loop_result_or_error);
HloInstruction* accumulator_result = gather_loop_result.back();
+ // The elided window dims were never added to the accumulator (see
+ // CreateGatherLoopAccumulatorInitValue), so no ElideDegenerateDims
+ // post-processing step is needed here anymore.
- TF_ASSIGN_OR_RETURN(
- HloInstruction * accumulator_with_window_dims_elided,
- ElideWindowDimsFromAccumulator(accumulator_result, dim_numbers));
TF_ASSIGN_OR_RETURN(
HloInstruction * accumulator_with_output_gather_dims_decanonicalized,
- AdjustGatherDimsInAccumulator(gather_indices->shape(),
- accumulator_with_window_dims_elided,
+ AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result,
dim_numbers.index_vector_dim()));
return PermuteGatherAndWindowDims(
"indices are not supported."));
}
+
+// Regression test: gather expansion should not introduce degenerate (size-1)
+// dimensions into the while-loop state.
TEST(GatherExpanderTest, AvoidDegenerateDims) {
+ const string hlo_text = R"(
+HloModule TensorFlowGatherV2

+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ ROOT gather = s32[3,2] gather(operand, indices),
+ output_window_dims={0},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={3, 1}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_text));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
+ ASSERT_TRUE(changed);
+
+ HloInstruction* while_instr = nullptr;
+ for (auto* instr : module->entry_computation()->instructions()) {
+ if (instr->opcode() == HloOpcode::kWhile) {
+ ASSERT_EQ(while_instr, nullptr)
+ << "Expected exactly one while instruction in the entry computation "
+ "after gather expansion";
+ while_instr = instr;
+ }
+ }
+
+ ASSERT_NE(while_instr, nullptr)
+ << "Expected exactly one while instruction in the entry computation "
+ "after gather expansion";
+
+ // We want to avoid creating a while loop with shapes that have degenerate
+ // dimensions for TF gather. In this case we expect the loop state to be of
+ // the shape (sNN[], s32[3,3]{1,0}, s32[2]{0}, s32[2,3]{1,0}). The leading
+ // sNN is an implementation detail from WhileUtil::MakeCountedLoop so we don't
+ // check it here (though in theory the form of the while loop state is itself
+ // an implementation detail from WhileUtil::MakeCountedLoop).
+
+ const Shape& while_shape = while_instr->shape();
+ ASSERT_TRUE(ShapeUtil::IsTuple(while_shape));
+ ASSERT_EQ(ShapeUtil::TupleElementCount(while_shape), 4);
+
+ EXPECT_TRUE(ShapeUtil::SameDimensions(
+ ShapeUtil::MakeShape(S32, {3, 3}),
+ ShapeUtil::GetTupleElementShape(while_shape, 1)));
+
+ EXPECT_TRUE(ShapeUtil::SameDimensions(
+ ShapeUtil::MakeShape(S32, {2}),
+ ShapeUtil::GetTupleElementShape(while_shape, 2)));
+
+ EXPECT_TRUE(ShapeUtil::SameDimensions(
+ ShapeUtil::MakeShape(S32, {2, 3}),
+ ShapeUtil::GetTupleElementShape(while_shape, 3)));
+}
} // namespace
} // namespace xla