seq_lens, body_builder->Reshape(i, {1}), {1});
// Indices is the offset of the batch element in the input.
- auto indices = body_builder->Broadcast(
+ auto batch_element_indices = body_builder->Broadcast(
XlaHelpers::Zero(body_builder.get(), seq_lens_type),
{input_shape.dims()});
- indices = body_builder->DynamicUpdateSlice(
- indices, body_builder->Reshape(i, {1}),
+ batch_element_indices = body_builder->DynamicUpdateSlice(
+ batch_element_indices, body_builder->Reshape(i, {1}),
body_builder->Reshape(
XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
batch_dim_),
{1}));
- // slice_indices is the offset of the start of the reversed sequence in
- // the input.
- auto slice_indices = body_builder->DynamicUpdateSlice(
- indices,
+ // Slice out the current batch element and pad it out in the sequence
+ // dimension.
+ TensorShape slice_shape = input_shape;
+ slice_shape.set_dim(batch_dim_, 1);
+ slice_shape.set_dim(seq_dim_, max_seq_len);
+ auto slice = body_builder->DynamicSlice(output, batch_element_indices,
+ slice_shape.dim_sizes());
+ auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims());
+ padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high(
+ slice_shape.dim_size(seq_dim_));
+ slice = body_builder->Pad(
+ slice, XlaHelpers::Zero(body_builder.get(), input_type),
+ padding_config);
+
+ // Now slice out the reversed sequence from its actual start.
+ // sequence_start_indices is the offset of the start of the reversed
+ // sequence in the input. The slice will go into the padding, however, we
+ // will mask off these elements and replace them with elements from the
+ // original input so their values do not matter.
+ auto sequence_start_indices = body_builder->Broadcast(
+ XlaHelpers::Zero(body_builder.get(), seq_lens_type),
+ {slice_shape.dims()});
+ sequence_start_indices = body_builder->DynamicUpdateSlice(
+ sequence_start_indices,
body_builder->Sub(XlaHelpers::IntegerLiteral(
body_builder.get(), seq_lens_type, max_seq_len),
seq_len),
XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
seq_dim_),
{1}));
-
- // Slice out the reversed sequence. The slice will overflow the end of the
- // sequence, and the contents of the overflow are implementation-defined.
- // However, we will mask off these elements and replace them with elements
- // from the original input so their values do not matter.
- TensorShape slice_shape = input_shape;
- slice_shape.set_dim(batch_dim_, 1);
- auto slice = body_builder->DynamicSlice(output, slice_indices,
- slice_shape.dim_sizes());
+ slice = body_builder->DynamicSlice(slice, sequence_start_indices,
+ slice_shape.dim_sizes());
// Shift the reversed sequence to the left.
- output = body_builder->DynamicUpdateSlice(output, slice, indices);
+ output = body_builder->DynamicUpdateSlice(output, slice,
+ batch_element_indices);
body_builder->Tuple(
{body_builder->Add(