Status XlaGather(const xla::ComputationDataHandle& input,
const TensorShape& input_shape,
const xla::ComputationDataHandle& indices,
- TensorShape indices_shape, int64 axis, bool indices_are_nd,
- DataType dtype, DataType index_type,
+ const TensorShape& indices_shape, int64 axis,
+ bool indices_are_nd, DataType dtype, DataType index_type,
xla::ComputationBuilder* builder,
xla::ComputationDataHandle* gather_output) {
+ // There is no deep reason why we need this precondition, but this is the only
+ // combination that is used and tested today.
+ CHECK(!indices_are_nd || axis == 0);
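+ // For example (an illustrative case, not from the change itself): an N-D
+ // lookup along axis = 2 would trip this CHECK; N-D indices are only
+ // supported along axis 0.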
+
+ // num_index_dims is the number of components in each index in the indices
+ // tensor.
+ //
+ // num_indices is the total number of (N-dimensional or scalar) indices in the
+ // indices tensor.
+ //
// If the indices are N-dimensional, then the minor dimension of indices
// should be of size N and correspond to the N indices.
- int64 num_index_dims = 1;
+ int64 num_index_dims;
+ int64 num_indices = 1;
if (indices_are_nd) {
CHECK_GE(indices_shape.dims(), 1);
num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
- indices_shape.RemoveLastDims(1);
+ for (int64 i = 0, e = indices_shape.dims() - 1; i < e; i++) {
+ num_indices *= indices_shape.dim_size(i);
+ }
+ } else {
+ num_index_dims = 1;
+ for (int64 i = 0, e = indices_shape.dims(); i < e; i++) {
+ num_indices *= indices_shape.dim_size(i);
+ }
}
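+
+ // For example (illustrative shapes, not taken from a test): with
+ // indices_are_nd and indices_shape = [5, 4, 2], each index is a 2-element
+ // vector, so num_index_dims = 2 and num_indices = 5 * 4 = 20.  With scalar
+ // indices of shape [5, 4], num_index_dims = 1 and num_indices = 20.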
- // Although the indices Tensor is flattened into rank 1 during the lookup,
- // and each scalar entry is used as an index into the first dimension of the
- // input, the output is returned with shape:
- // input.shape[:axis] + indices.shape + input.shape[axis+1:]
-
- const int64 num_indices = indices_shape.num_elements();
- TensorShape input_shape_pre_axis(input_shape);
- input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
- TensorShape input_shape_post_axis(input_shape);
- input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);
- // Each slice of the input tensor has shape:
- // [<input_shape_pre_axis>, 1, ..., 1, <input shape_post_axis>]
- TensorShape slice_shape(input_shape);
- for (int64 i = 0; i < num_index_dims; ++i) {
- slice_shape.set_dim(axis + i, 1);
- }
+ // Degenerate case: empty indices.
+ if (num_indices == 0) {
+ TensorShape input_shape_pre_axis{input_shape};
+ input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
+ TensorShape input_shape_post_axis{input_shape};
+ input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);
- TensorShape loop_out_shape;
- loop_out_shape.AppendShape(input_shape_pre_axis);
- loop_out_shape.AddDim(num_indices);
- loop_out_shape.AppendShape(input_shape_post_axis);
- TensorShape loop_out_slice_shape;
- loop_out_slice_shape.AppendShape(input_shape_pre_axis);
- loop_out_slice_shape.AddDim(1);
- loop_out_slice_shape.AppendShape(input_shape_post_axis);
+ TensorShape indices_shape_no_index_vectors{indices_shape};
+ if (indices_are_nd) {
+ indices_shape_no_index_vectors.RemoveLastDims(1);
+ }
- TensorShape out_shape;
- out_shape.AppendShape(input_shape_pre_axis);
- out_shape.AppendShape(indices_shape);
- out_shape.AppendShape(input_shape_post_axis);
+ TensorShape out_shape;
+ out_shape.AppendShape(input_shape_pre_axis);
+ out_shape.AppendShape(indices_shape_no_index_vectors);
+ out_shape.AppendShape(input_shape_post_axis);
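+
+ // For example (illustrative shapes): input_shape = [4, 5, 6], axis = 1, and
+ // scalar indices of shape [0] give input_shape_pre_axis = [4],
+ // input_shape_post_axis = [6], and out_shape = [4, 0, 6]: an empty tensor
+ // of the expected rank.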
- // Degenerate case: empty indices.
- if (num_indices == 0) {
*gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
out_shape.dim_sizes());
return Status::OK();
}
- // Flatten the major dimensions of indices into a single dimension for ease of
- // iteration. If there is an axis dimension, we must leave it alone.
- std::vector<int64> flat_indices_shape = {num_indices};
- if (indices_are_nd) {
- flat_indices_shape.push_back(num_index_dims);
- }
-
- // Specify the shape of the loop-carried Tensor tuple.
-
- // Construct the initial values of the loop-carried Tensors.
- auto flat_indices = builder->Reshape(indices, flat_indices_shape);
- auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
- loop_out_shape.dim_sizes());
- auto init = {input, flat_indices, init_out};
-
- // Construct the while loop body's function. The implementation of gather is:
- // for i in range(num_indices):
- // index = dynamic-slice(indices, i)
- // xi = dynamic-slice(input, index)
- // output = dynamic-update-slice(output, xi, i)
- auto body_fn = [&](xla::ComputationDataHandle i,
- gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
- xla::ComputationBuilder* bodyb) {
- auto input = loop_vars[0];
- auto indices = loop_vars[1];
- auto output = loop_vars[2];
-
- auto zero_index = XlaHelpers::Zero(bodyb, index_type);
-
- // Slice the i-th index from the indices array.
- xla::ComputationDataHandle index;
- auto indices_offset = bodyb->Reshape(i, {1});
- if (indices_are_nd) {
- // Slice out the entire nd index, if applicable.
- indices_offset = bodyb->Pad(indices_offset, zero_index,
- xla::MakeEdgePaddingConfig({{0, 1}}));
- index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims});
- index = bodyb->Collapse(index, {0, 1});
+ // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a
+ // tensor of shape [3,3].
+ //
+ // operand = s32[3,3] parameter(0)
+ // indices = s32[2] parameter(1)
+ // gather = s32[3,2] gather(operand, indices),
+ // output_window_dims={0},
+ // elided_window_dims={1},
+ // gather_dims_to_operand_dims={1},
+ // index_vector_dim=1,
+ // window_bounds={3, 1}
+ //
+ //
+ // Example of an N-D gather pulling slices of shape [1,1,2] out of a
+ // tensor of shape [3,3,2].
+ //
+ // operand = s32[3,3,2] parameter(0)
+ // indices = s32[2,2] parameter(1)
+ // gather = s32[2,2] gather(operand, indices),
+ // output_window_dims={1},
+ // elided_window_dims={0,1},
+ // gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ // window_bounds={1,1,2}
+
+ xla::GatherDimensionNumbers dim_numbers;
+ std::vector<int64> window_bounds;
+ window_bounds.reserve(input_shape.dims());
+ for (int64 i = 0; i < input_shape.dims(); i++) {
+ int64 window_bound;
+ if (axis <= i && i < (axis + num_index_dims)) {
+ dim_numbers.add_elided_window_dims(i);
+ window_bound = 1;
} else {
- index = bodyb->DynamicSlice(indices, indices_offset, {1});
+ window_bound = input_shape.dim_size(i);
+ }
+
+ window_bounds.push_back(window_bound);
+
+ if (i < axis) {
+ dim_numbers.add_output_window_dims(i);
+ } else if (i >= (axis + num_index_dims)) {
+ int64 indices_rank =
+ indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
+ dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims);
}
+ }
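+
+ // Tracing this loop for the 1-D example above (input_shape = [3, 3],
+ // axis = 1, num_index_dims = 1): dimension 0 lies before the axis, so it
+ // becomes output window dim 0 with window bound 3; dimension 1 is the
+ // gathered axis, so it is elided with window bound 1.  The result is
+ // output_window_dims={0}, elided_window_dims={1}, window_bounds={3, 1},
+ // matching the HLO shown above.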
+
+ dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
+ : indices_shape.dims());
+ for (int64 i = axis; i < axis + num_index_dims; i++) {
+ dim_numbers.add_gather_dims_to_operand_dims(i);
+ }
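+
+ // For the 1-D example above, indices_shape = [2] and indices_are_nd is
+ // false, so index_vector_dim = 1 (one past the last indices dimension,
+ // i.e. the indices are scalars) and gather_dims_to_operand_dims = {1}.
+ // For the N-D example, indices_shape = [2, 2] gives index_vector_dim = 1
+ // and gather_dims_to_operand_dims = {0, 1}.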
- // Slice the corresponding data from the input array.
- auto start_indices = bodyb->Pad(
- index, zero_index,
- xla::MakeEdgePaddingConfig(
- {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}}));
- auto slice_i = bodyb->Reshape(
- bodyb->DynamicSlice(input, start_indices, slice_shape.dim_sizes()),
- loop_out_slice_shape.dim_sizes());
-
- // Construct the index into the output Tensor 0, ..., <index>, 0, ...
- std::vector<xla::ComputationDataHandle> out_index_vals(
- loop_out_shape.dims(), bodyb->Reshape(zero_index, {1}));
- out_index_vals[input_shape_pre_axis.dims()] = bodyb->Reshape(i, {1});
- auto out_index = bodyb->ConcatInDim(out_index_vals, 0);
-
- // Update the output Tensor
- auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index);
-
- return std::vector<xla::ComputationDataHandle>{input, indices,
- updated_output};
- };
-
- // Construct the While loop, extract and reshape the output.
- xla::PrimitiveType ptype;
- TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype));
- TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn,
- init, "gather", builder));
- *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes());
+ *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds);
return Status::OK();
}