[TF:XLA] Do not rely on implementation-defined semantics of DynamicSlice.
authorMichael Kuperstein <mkuper@google.com>
Thu, 17 May 2018 20:37:57 +0000 (13:37 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 20:44:32 +0000 (13:44 -0700)
ReverseSequence relies on DynamicSlice wrapping around, which is implementation-defined behavior, and is not guaranteed. Pad the input instead.

PiperOrigin-RevId: 197043307

tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc

index 0ed4c47..5d1c052 100644 (file)
@@ -106,20 +106,40 @@ class ReverseSequenceOp : public XlaOpKernel {
           seq_lens, body_builder->Reshape(i, {1}), {1});
 
       // Indices is the offset of the batch element in the input.
-      auto indices = body_builder->Broadcast(
+      auto batch_element_indices = body_builder->Broadcast(
           XlaHelpers::Zero(body_builder.get(), seq_lens_type),
           {input_shape.dims()});
-      indices = body_builder->DynamicUpdateSlice(
-          indices, body_builder->Reshape(i, {1}),
+      batch_element_indices = body_builder->DynamicUpdateSlice(
+          batch_element_indices, body_builder->Reshape(i, {1}),
           body_builder->Reshape(
               XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
                                          batch_dim_),
               {1}));
 
-      // slice_indices is the offset of the start of the reversed sequence in
-      // the input.
-      auto slice_indices = body_builder->DynamicUpdateSlice(
-          indices,
+      // Slice out the current batch element and pad it out in the sequence
+      // dimension.
+      TensorShape slice_shape = input_shape;
+      slice_shape.set_dim(batch_dim_, 1);
+      slice_shape.set_dim(seq_dim_, max_seq_len);
+      auto slice = body_builder->DynamicSlice(output, batch_element_indices,
+                                              slice_shape.dim_sizes());
+      auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims());
+      padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high(
+          slice_shape.dim_size(seq_dim_));
+      slice = body_builder->Pad(
+          slice, XlaHelpers::Zero(body_builder.get(), input_type),
+          padding_config);
+
+      // Now slice out the reversed sequence from its actual start.
+      // sequence_start_indices is the offset of the start of the reversed
+      // sequence in the input. The slice will go into the padding, however, we
+      // will mask off these elements and replace them with elements from the
+      // original input so their values do not matter.
+      auto sequence_start_indices = body_builder->Broadcast(
+          XlaHelpers::Zero(body_builder.get(), seq_lens_type),
+          {slice_shape.dims()});
+      sequence_start_indices = body_builder->DynamicUpdateSlice(
+          sequence_start_indices,
           body_builder->Sub(XlaHelpers::IntegerLiteral(
                                 body_builder.get(), seq_lens_type, max_seq_len),
                             seq_len),
@@ -127,18 +147,12 @@ class ReverseSequenceOp : public XlaOpKernel {
               XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
                                          seq_dim_),
               {1}));
-
-      // Slice out the reversed sequence. The slice will overflow the end of the
-      // sequence, and the contents of the overflow are implementation-defined.
-      // However, we will mask off these elements and replace them with elements
-      // from the original input so their values do not matter.
-      TensorShape slice_shape = input_shape;
-      slice_shape.set_dim(batch_dim_, 1);
-      auto slice = body_builder->DynamicSlice(output, slice_indices,
-                                              slice_shape.dim_sizes());
+      slice = body_builder->DynamicSlice(slice, sequence_start_indices,
+                                         slice_shape.dim_sizes());
 
       // Shift the reversed sequence to the left.
-      output = body_builder->DynamicUpdateSlice(output, slice, indices);
+      output = body_builder->DynamicUpdateSlice(output, slice,
+                                                batch_element_indices);
 
       body_builder->Tuple(
           {body_builder->Add(