I discovered these when changing the tf2xla bridge to directly emit gather
operations.
- DeScalarizeGatherIndices was assuming that gather_indices must be of at least
rank 1. Fix this to be more general.
- We were passing in the wrong version of gather indices to
ExpandFirstDimIntoNDims. We don't strictly need to pass in
transposed_gather_indices (since if transposed_gather_indices is rank 1 then
the transpose has to be an identity transpose), passing in
descalarized_gather_indices would also have been fine, but
transposed_gather_indices seems more uniform.
- ExpandGatherDimsInAccumulator was assuming that gather_indices must be of at
least rank 1 (by calling CollapseFirstNDims). Fix this to be more general.
- We were trying to go through with emitting zero sized gather operations. I
don't think it is worth dealing with all of the edge cases this would expose
so now we just punt to ZeroSizedHloElimination.
PiperOrigin-RevId: 189696444
return gather_indices;
}
- int64 last_index = gather_indices_shape.dimensions(
- gather_indices_shape.dimensions_size() - 1);
- return ExpandLastDimIntoNDims(gather_indices, {last_index, 1});
+ DCHECK_EQ(index_vector_dim, gather_indices_shape.dimensions_size());
+
+ std::vector<int64> result_shape_dims;
+ c_copy(gather_indices_shape.dimensions(),
+ std::back_inserter(result_shape_dims));
+ result_shape_dims.push_back(1);
+
+ return MakeReshapeHlo(result_shape_dims, gather_indices);
}
// Canonicalizes the gather_indices tensors so that we only have to deal with
// a single, collapsed non-index-vector dimension.
const Shape& shape = transposed_gather_indices->shape();
if (shape.dimensions_size() == 1) {
- return ExpandFirstDimIntoNDims(gather_indices, {1, shape.dimensions(0)});
+ return ExpandFirstDimIntoNDims(transposed_gather_indices,
+ {1, shape.dimensions(0)});
} else {
return CollapseFirstNDims(transposed_gather_indices,
shape.dimensions_size() - 1);
}
}
-// Expands out the gather dimensions in the accumulator produced by the while
-// loop.
-static StatusOr<HloInstruction*> ExpandGatherDimsInAccumulator(
+// Expands out or contracts away the gather dimensions in the accumulator
+// produced by the while loop.
+static StatusOr<HloInstruction*> AdjustGatherDimsInAccumulator(
const Shape& gather_indices_shape, HloInstruction* accumulator,
int64 index_vector_dim) {
std::vector<int64> output_gather_dim_bounds;
if (output_gather_dim_bounds.empty()) {
// If output_gather_dim_bounds is empty we must be lowering a (effectively)
- // dynamic-slice.
+ // dynamic-slice. In that case, there is a leading degenerate gather
+ // dimension that we added to make this special case play well with the
+ // general while loop which we need to remove now.
CHECK_EQ(accumulator->shape().dimensions(0), 1);
- return CollapseFirstNDims(accumulator, 2);
+ ArraySlice<int64> reshaped_dim_sizes =
+ AsInt64Slice(accumulator->shape().dimensions());
+ reshaped_dim_sizes.remove_prefix(1);
+ return MakeReshapeHlo(reshaped_dim_sizes, accumulator);
}
return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds);
StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloInstruction* gather_instr) {
+ CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape()));
+
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
HloInstruction* gather_indices = gather_instr->mutable_operand(1);
TF_ASSIGN_OR_RETURN(
HloInstruction * accumulator_with_output_gather_dims_decanonicalized,
- ExpandGatherDimsInAccumulator(gather_indices->shape(),
+ AdjustGatherDimsInAccumulator(gather_indices->shape(),
accumulator_with_window_dims_elided,
dim_numbers.index_vector_dim()));
}
StatusOr<bool> GatherExpander::Run(HloModule* module) {
+ auto is_nontrivial_gather = [](HloInstruction* inst) {
+ return inst->opcode() == HloOpcode::kGather &&
+ // Avoid expanding gather ops that produce zero sized tensors,
+ // instead punt these to ZeroSizedHloElimination.
+ !ShapeUtil::HasZeroElements(inst->shape());
+ };
+
std::vector<HloInstruction*> gather_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
c_copy_if(computation->instructions(), std::back_inserter(gather_instrs),
- [](HloInstruction* inst) {
- return inst->opcode() == HloOpcode::kGather;
- });
+ is_nontrivial_gather);
}
for (HloInstruction* inst : gather_instrs) {
return MakeReshapeHlo(new_shape, operand);
}
-StatusOr<HloInstruction*> ExpandLastDimIntoNDims(
- HloInstruction* operand, ArraySlice<int64> expanded_dims) {
- CHECK_GT(operand->shape().dimensions_size(), 0);
- CHECK_EQ(operand->shape().dimensions(operand->shape().dimensions_size() - 1),
- Product(expanded_dims));
-
- std::vector<int64> expanded_shape_dim_bounds;
- expanded_shape_dim_bounds.reserve(expanded_dims.size() +
- operand->shape().dimensions_size() - 1);
- std::copy(operand->shape().dimensions().begin(),
- operand->shape().dimensions().end() - 1,
- std::back_inserter(expanded_shape_dim_bounds));
- c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
- Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
- expanded_shape_dim_bounds);
- return MakeReshapeHlo(new_shape, operand);
-}
-
StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
ArraySlice<int64> dims_to_elide) {
CHECK(c_is_sorted(dims_to_elide));
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> expanded_dims);
-// Expands (via reshape) the last (logical) dimension of `operand` into a
-// sequence of `expanded_dims` dimensions. `operand` must at least be of rank 1
-// and the number of elements in its last dimension must be equal to the
-// product of `expanded_dims`.
-//
-// For instance if `operand` has shape f32[9,7,200] and expanded_dims is
-// {2,5,20} the result is `operand` reshaped to [9,7,2,5,20].
-StatusOr<HloInstruction*> ExpandLastDimIntoNDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> expanded_dims);
-
// Elides (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_elide` from `operand`. Every dimension in
// `dims_to_elide` must be a degenerate dimension. `dims_to_elide` must be
{operand.get(), gather_indices.get(), in_bounds_mask.get()});
}
+// Gathers with a scalar (rank-0) index: `index` is s32[] and
+// index_vector_dim=0 equals the rank of the indices, so there is no
+// explicit index-vector dimension. Exercises the scalar-gather_indices
+// path in DeScalarizeGatherIndices.
+XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
+  const char* hlo_text = R"(
+HloModule OneScalarIndex
+
+ENTRY main {
+  operand = s32[2,3,2]{2,1,0} parameter(0)
+  index = s32[] parameter(1)
+  ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index),
+    output_window_dims={0,1,2},
+    elided_window_dims={},
+    gather_dims_to_operand_dims={0},
+    index_vector_dim=0,
+    window_bounds={1,3,2}
+}
+)";
+  std::unique_ptr<Literal> operand = Literal::CreateR3<int32>(
+      {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
+  // Scalar index 1 should select the [1,3,2] window starting at operand[1].
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+// Gathers a single element to a rank-0 result: both the index and the
+// output are scalars (output_window_dims={} and the window dimension is
+// elided), so neither operand of the gather carries a gather dimension.
+XLA_TEST_F(GatherOperationTest, ScalarResult) {
+  const char* hlo_text = R"(
+HloModule ScalarResult
+
+ENTRY main {
+  operand = s32[4]{0} parameter(0)
+  index = s32[] parameter(1)
+  ROOT gather = s32[] gather(operand, index),
+    output_window_dims={},
+    elided_window_dims={0},
+    gather_dims_to_operand_dims={0},
+    index_vector_dim=0,
+    window_bounds={1}
+}
+)";
+  std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({1, 2, 3, 4});
+  // Scalar index 1 should select element operand[1].
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+// Gathers with an empty index vector, producing a zero-element s32[0,3]
+// result. GatherExpander deliberately skips zero-sized gathers (they are
+// punted to ZeroSizedHloElimination), so this verifies the op still
+// executes correctly end-to-end.
+XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
+  // Use `const char*` for the HLO text, matching the sibling tests in this
+  // file (the original draft inconsistently used `const string`).
+  const char* hlo_text = R"(
+HloModule ZeroSizedResult
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[0] parameter(1)
+  ROOT gather = s32[0,3] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1, 3}
+}
+)";
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  // No indices at all: the gather must produce an empty result.
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({});
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
} // namespace
} // namespace xla