[XLA] Clamp indices in DynamicSlice and DynamicUpdateSlice instead of wrapping.
authorMichael Kuperstein <mkuper@google.com>
Fri, 18 May 2018 01:09:45 +0000 (18:09 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 01:12:26 +0000 (18:12 -0700)
This implements the following index clamping in all backends (CPU, GPU, Interpreter):

for(int i = 0; i < rank; ++i)
  start_index[i] = clamp(start_index[i], 0, output_dim_size[i] - update_dim_size[i])

Which ensures the slice (or update region) is always inbounds w.r.t the input.

PiperOrigin-RevId: 197082276

tensorflow/compiler/xla/reference_util.h
tensorflow/compiler/xla/service/elemental_ir_emitter.cc
tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
tensorflow/compiler/xla/service/llvm_ir/ops.cc
tensorflow/compiler/xla/tests/dynamic_ops_test.cc

index 28d6a8c..2698ba7 100644 (file)
@@ -330,13 +330,14 @@ class ReferenceUtil {
     return result;
   }
 
-  // Slices with modulo-wrapping.
+  // Slices with index clamping
   template <typename T>
-  static std::vector<T> ModSlice1D(const tensorflow::gtl::ArraySlice<T>& input,
-                                   int64 start, int64 size) {
+  static std::vector<T> ClampSlice1D(
+      const tensorflow::gtl::ArraySlice<T>& input, int64 start, int64 size) {
+    start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
     std::vector<T> result;
     for (int64 i = 0; i < size; ++i) {
-      result.push_back(input[(start + i) % input.size()]);
+      result.push_back(input[(start + i)]);
     }
     return result;
   }
index 0a400e9..9a8bab3 100644 (file)
@@ -1547,6 +1547,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
     llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
     TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
                         operand_to_generator.at(hlo->operand(1))(dim_index));
+
+    // Clamp the start index so that the sliced portion fits in the operand:
+    // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
+
+    // TODO(b/74360564): This is implementation defined behavior, but is
+    // currently respected by all implementations. Change this if we ever decide
+    // to oficially document different behavior.
+    start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
+                                                         index[i]->getType());
+    llvm::Value* operand_dim_size = llvm::ConstantInt::get(
+        start_index_value->getType(), input_hlo->shape().dimensions(i));
+    llvm::Value* output_dim_size = llvm::ConstantInt::get(
+        start_index_value->getType(), hlo->shape().dimensions(i));
+
+    start_index_value = EmitIntegralMin(
+        ir_builder_->CreateSub(operand_dim_size, output_dim_size),
+        EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
+                        start_index_value, /*is_signed=*/true),
+        /*is_signed=*/true);
+
     start_index_value->setName(
         AsStringRef(IrName(hlo, StrCat("start_idx", i))));
     slice_start_index[i] = start_index_value;
@@ -1555,14 +1575,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
   llvm_ir::IrArray::Index input_index(rank);
   for (int64 i = 0; i < rank; ++i) {
     // Emit IR which computes:
-    //   input_index = (start_index + offset_index) % dim_size
-    // Security note: this is the code that keeps the indices in-bounds.
-    llvm::Value* dim_size = llvm::ConstantInt::get(
-        index[i]->getType(), input_hlo->shape().dimensions(i));
-    llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
-        slice_start_index[i], index[i]->getType());
-    input_index[i] = ir_builder_->CreateURem(
-        ir_builder_->CreateAdd(start_index, index[i]), dim_size);
+    //   input_index = start_index + offset_index
+    input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]);
   }
   return operand_to_generator.at(input_hlo)(input_index);
 }
@@ -1661,104 +1675,48 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
   const int64 rank = ShapeUtil::Rank(input_hlo->shape());
   llvm_ir::IrArray::Index slice_start_index(rank);
   llvm_ir::IrArray::Index slice_limit_index(rank);
-  // Slice starts at update[index - slice_start_index_adjusted],
-  // where adjusted value = slice_start_index when in bounds, and
-  // adjusted value = slice_start_index - input_dim, when wrapping.
-  llvm_ir::IrArray::Index slice_start_index_adjusted(rank);
-
   // Slice intersection gathers (ANDs) conditions on all ranks for which
   // 'input' is set to 'update'
   llvm::Value* slice_intersection = ir_builder_->getTrue();
 
   for (int64 i = 0; i < rank; ++i) {
-    // Emit IR to read dynamic start indices from 'start_hlo'.
     llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
     TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
                         operand_to_generator.at(start_hlo)(dim_index));
-    start_index_value->setName(
-        AsStringRef(IrName(hlo, StrCat("start_idx", i))));
-    slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
-        start_index_value, index[i]->getType());
 
+    // Clamp the start index so that the update region fits in the operand.
+    // start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
+
+    // TODO(b/74360564): This is implementation defined behavior, but is
+    // currently respected by all implementations. Change this if we ever decide
+    // to oficially document different behavior.
+    start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
+                                                         index[i]->getType());
     llvm::Value* input_dim_size = llvm::ConstantInt::get(
         index[i]->getType(), input_hlo->shape().dimensions(i));
     llvm::Value* update_dim_size = llvm::ConstantInt::get(
         index[i]->getType(), update_hlo->shape().dimensions(i));
 
-    // Generate code to handle wrapping semantics:
-    // slice_start_index[i] = slice_start_index[i] % input_dim_size;
-    // slice_limit_index[i] = slice_start_index[i] + update_dim_size.
-    // slice_start_index[i] is updated in place and it will now be in
-    // range. slice_limit_index[i] may be out of range, and it's being
-    // URem-ed below if so.
-    slice_start_index[i] =
-        ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
+    start_index_value = EmitIntegralMin(
+        ir_builder_->CreateSub(input_dim_size, update_dim_size),
+        EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
+                        start_index_value, /*is_signed=*/true),
+        /*is_signed=*/true);
+
+    start_index_value->setName(
+        AsStringRef(IrName(hlo, StrCat("start_idx", i))));
+    slice_start_index[i] = start_index_value;
     slice_limit_index[i] =
         ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
 
-    // Test if slice_limit_index[i] is in bounds
-    llvm::Value* in_bounds =
-        ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
-    llvm_ir::LlvmIfData if_in_bounds =
-        llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
-
-    // Handle true BB (slice_limit_index[i] <= input_dim_size).
-    SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
-    // Check that index[i] >= slice_start_index[i] &&
-    //            index[i] < slice_limit_index[i]
-    llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
+    slice_intersection = ir_builder_->CreateAnd(
         slice_intersection,
         ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
-        "slice_intersection_in");
-    slice_intersection_in_bounds = ir_builder_->CreateAnd(
-        slice_intersection_in_bounds,
+        "slice_intersection");
+    slice_intersection = ir_builder_->CreateAnd(
+        slice_intersection,
         ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
-        "slice_intersection_in");
-
-    // Handle false BB (slice_limit_index[i] > input_dim_size).
-    SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
-    // Check that index[i] >= slice_start_index[i] ||
-    //            index[i] < slice_limit_index[i]%input_dim_size.
-    llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
-        index[i],
-        ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
-    llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
-        ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), index_wraps,
-        "slice_intersection_out");
-    llvm::Value* slice_intersection_out_of_bounds = ir_builder_->CreateAnd(
-        slice_intersection, slice_intersection_or, "slice_intersection_out");
-    // Create value for slice_start_index_adjusted[i] when out of bounds.
-    // If within out-of-bounds if.
-    llvm_ir::LlvmIfData if_start_needs_adjustment =
-        llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
-    SetToFirstInsertPoint(if_start_needs_adjustment.true_block, ir_builder_);
-    llvm::Value* slice_start_index_adjusted_oob =
-        ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
-    SetToFirstInsertPoint(if_start_needs_adjustment.after_block, ir_builder_);
-    llvm::PHINode* slice_start_index_adjusted_phi =
-        ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), 2);
-    slice_start_index_adjusted_phi->addIncoming(
-        slice_start_index_adjusted_oob, if_start_needs_adjustment.true_block);
-    slice_start_index_adjusted_phi->addIncoming(
-        slice_start_index[i], if_start_needs_adjustment.false_block);
-    // End of if within if.
-
-    // After checking in/out of bounds.
-    SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
-    llvm::PHINode* phi_slice_intersection =
-        ir_builder_->CreatePHI(slice_intersection->getType(), 2);
-    phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
-                                        if_in_bounds.true_block);
-    phi_slice_intersection->addIncoming(slice_intersection_out_of_bounds,
-                                        if_start_needs_adjustment.after_block);
-    slice_intersection = phi_slice_intersection;
-
-    llvm::PHINode* phi_index =
-        ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
-    phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
-    phi_index->addIncoming(slice_start_index_adjusted_phi,
-                           if_start_needs_adjustment.after_block);
-    slice_start_index_adjusted[i] = phi_index;
+        "slice_intersection");
   }
 
   // Emit:
@@ -1775,12 +1733,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
   // Compute update index for intersection case.
   llvm_ir::IrArray::Index update_index(rank);
   for (int64 i = 0; i < rank; ++i) {
-    llvm::Value* update_dim_size = llvm::ConstantInt::get(
-        index[i]->getType(), update_hlo->shape().dimensions(i));
-    // NOTE: Subtraction will be positive due to bounds checking above.
-    update_index[i] = ir_builder_->CreateURem(
-        ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
-        update_dim_size);
+    update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]);
   }
   TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
                       operand_to_generator.at(update_hlo)(update_index));
index 52db58f..024e875 100644 (file)
@@ -1986,17 +1986,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     std::vector<int64> start(start_indices_typed.begin(),
                              start_indices_typed.end());
 
-    std::vector<int64> operand_indices(start.size());
+    // Clamp the start indices so the slice is in-bounds w.r.t the operand.
+
+    // TODO(b/74360564): This is implementation defined behavior, but is
+    // currently respected by all implementations. Change this if we ever decide
+    // to oficially document different behavior.
+    for (int64 i = 0; i < start.size(); ++i) {
+      start[i] = std::min<int64>(
+          std::max(0LL, start[i]),
+          operand_literal.shape().dimensions(i) - result_shape.dimensions(i));
+    }
 
+    std::vector<int64> operand_indices(start.size());
     auto result = MakeUnique<Literal>(result_shape);
     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
           for (int64 i = 0; i < operand_indices.size(); ++i) {
             CHECK_GE(multi_index[i] + start[i], 0);
-            // Mod is only used here to be consistent with the existing
-            // backends' behavior.
-            operand_indices[i] = (multi_index[i] + start[i]) %
-                                 operand_literal.shape().dimensions(i);
+            operand_indices[i] = multi_index[i] + start[i];
           }
 
           auto result = operand_literal.Get<ReturnT>(operand_indices);
@@ -2013,23 +2020,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     auto result = operand_literal.CloneToUnique();
     auto start_indices_typed = start_indices_literal.data<IndexT>();
     const auto rank = ShapeUtil::Rank(result->shape());
-    std::vector<int64> start(rank, 0);
+    std::vector<int64> start(start_indices_typed.begin(),
+                             start_indices_typed.end());
+    // Clamp the update start indices so the slice is in-bounds w.r.t the
+    // operand.
+
+    // TODO(b/74360564): This is implementation defined behavior, but is
+    // currently respected by all implementations. Change this if we ever decide
+    // to oficially document different behavior.
     for (int64 i = 0; i < rank; ++i) {
-      // All other implementations currently wrap-around the index, so this
-      // should do so as well.
-      start[i] = (start_indices_typed[i] % result->shape().dimensions(i));
-      start[i] += (start[i] < 0) * result->shape().dimensions(i);
+      start[i] = std::min<int64>(
+          std::max<int64>(0, start[i]),
+          result->shape().dimensions(i) - update_literal.shape().dimensions(i));
     }
     std::vector<int64> result_index(rank, 0);
 
     auto func = [&](tensorflow::gtl::ArraySlice<int64> update_index) {
       std::transform(update_index.begin(), update_index.end(), start.begin(),
                      result_index.begin(), std::plus<int64>());
-      // Same as above, wrap-around only to match other implementations'
-      // semantics.
-      std::transform(result_index.begin(), result_index.end(),
-                     result->shape().dimensions().begin(), result_index.begin(),
-                     std::modulus<int64>());
       result->Set<ReturnT>(result_index,
                            update_literal.Get<ReturnT>(update_index));
       return true;
index 34899b7..dacc547 100644 (file)
@@ -49,22 +49,41 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
   for (int64 i = 0; i < rank; ++i) {
     IrArray::Index dim_index({ir_builder->getInt64(i)});
     TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index));
+    llvm::Value* output_dim_size = llvm::ConstantInt::get(
+        start_index[i]->getType(), output_shape.dimensions(i));
+    llvm::Value* update_dim_size = llvm::ConstantInt::get(
+        start_index[i]->getType(), update_shape.dimensions(i));
+
+    // Clamp the start index so that the update region fits in the operand.
+    // start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
+
+    // TODO(b/74360564): This is implementation defined behavior, but is
+    // currently respected by all implementations. Change this if we ever decide
+    // to oficially document different behavior.
+    llvm::Value* max_bound =
+        ir_builder->CreateSub(output_dim_size, update_dim_size);
+    llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0);
+    start_index[i] = ir_builder->CreateSelect(
+        ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]),
+        zero, start_index[i]);
+
+    start_index[i] = ir_builder->CreateSelect(
+        ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound,
+                               start_index[i]),
+        max_bound, start_index[i]);
   }
 
   auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
     // Calculate output_index, where we'll write the value from update.  For
     // each dimension,
     //
-    //   output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size.
+    //   output_index[dim] = start_index[dim] + update_index[dim]
     //
     IrArray::Index output_index(rank);
     for (int64 i = 0; i < rank; ++i) {
-      llvm::Value* dim_size = llvm::ConstantInt::get(
-          update_index[i]->getType(), output_shape.dimensions(i));
-      llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast(
+      llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast(
           start_index[i], update_index[i]->getType());
-      output_index[i] = ir_builder->CreateURem(
-          ir_builder->CreateAdd(start_index0, update_index[i]), dim_size);
+      output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]);
     }
 
     // Do output[output_index] = update[update_index].
index bfb83fa..49f3a10 100644 (file)
@@ -53,9 +53,9 @@ class DynamicSliceTest : public ClientLibraryTestBase {
   }
 
   template <typename IndexT, typename DataT>
-  void TestR1Wrap() {
-    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
-    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1});
+  void TestR1OOB() {
+    // Slice at dimension boundaries, but with out of bounds indices.
+    RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7});
   }
 
   template <typename IndexT, typename DataT>
@@ -78,10 +78,10 @@ class DynamicSliceTest : public ClientLibraryTestBase {
   }
 
   template <typename IndexT, typename DataT>
-  void TestR2Wrap() {
-    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+  void TestR2OOB() {
+    // Slice at dimension boundaries, but with out of bounds indices.
     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3},
-                         {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}});
+                         {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   }
 
   template <typename IndexT, typename DataT>
@@ -106,11 +106,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
   }
 
   template <typename IndexT, typename DataT>
-  void TestR3Wrap() {
-    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+  void TestR3OOB() {
+    // Slice at dimension boundaries, but with out of bounds indices.
     RunR3<IndexT, DataT>(
         {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1},
-        {2, 1, 2}, {{{6, 5}}, {{12, 11}}});
+        {2, 1, 2}, {{{5, 6}}, {{11, 12}}});
   }
 
   template <typename IndexT, typename DataT>
@@ -199,19 +199,19 @@ class DynamicSliceTest : public ClientLibraryTestBase {
 
 XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1<int32, bfloat16>(); }
 XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32, int32>(); }
-XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap<int32, int32>(); }
+XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB<int32, int32>(); }
 XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64, float>(); }
 XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64, float>(); }
 
 XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2<int32, bfloat16>(); }
 XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32, int32>(); }
-XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap<int32, int32>(); }
+XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB<int32, int32>(); }
 XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64, float>(); }
 XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
 
 XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3<int32, bfloat16>(); }
 XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32, float>(); }
-XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap<int32, float>(); }
+XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB<int32, float>(); }
 XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64, float>(); }
 XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64, float>(); }
 
@@ -332,17 +332,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
   }
 
   template <typename IndexT, typename DataT>
-  void TestWrap() {
-    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+  void TestOOB() {
+    // // Slice at dimension boundaries, but with out of bounds indices.
     RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6},
-                         {10, 1, 2, 3, 4, 5, 8, 9});
+                         {0, 1, 2, 3, 4, 8, 9, 10});
     // R2 Shape: [3, 3]
     RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2},
-                         {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}});
+                         {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
     // R3 Shape: [2, 3, 2]
     RunR3<IndexT, DataT>(
         {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
-        {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}});
+        {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
   }
 
   template <typename IndexT, typename DataT>
@@ -476,20 +476,19 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
     Array3D<T> input_values(kSeq, kBatch, kDim);
     Array3D<T> update_values(size, kBatch, kDim);
     Array3D<T> expected_values(kSeq, kBatch, kDim);
+    index = std::min(std::max(0, index), kSeq - size);
 
     input_values.FillIota(static_cast<T>(0));
     T value = static_cast<T>(10);
     update_values.FillIota(static_cast<T>(value));
 
     // TODO(b/34128753) Expected values may vary depending on backend when
-    // the update wraps. According to documentation, the results are technically
-    // implementation specific where the update is out of bounds, and hence
-    // we don't really know what to pass into ComputeAndCompareR3.
+    // the indices are out of bounds.
     expected_values.FillIota(static_cast<T>(0));
     for (int i = 0; i < size; i++) {
       for (int j = 0; j < kBatch; j++) {
         for (int k = 0; k < kDim; k++) {
-          expected_values((index + i) % kSeq, j, k) = value++;
+          expected_values(index + i, j, k) = value++;
         }
       }
     }
@@ -547,12 +546,10 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32, float>(); }
 XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64, int64>(); }
 XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64, uint64>(); }
 
-XLA_TEST_F(DynamicUpdateSliceTest, Int32WrapBF16) {
-  TestWrap<int32, bfloat16>();
-}
-XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap<int32, float>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap<int64, int64>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap<uint64, uint64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB<int32, bfloat16>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB<int32, float>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB<int64, int64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB<uint64, uint64>(); }
 
 XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) {
   // Slice at dimension start.
@@ -615,37 +612,37 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) {
 // Tests for simple R3 case where the update is contiguous (i.e. the minor
 // two dimensions are not sliced).
 XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
-  // Single element, no wrap.
+  // Single element, index in-bounds
   std::vector<int32> operand_shape({4, 5, 2});
   RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
 }
 
 XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) {
-  // Single element, no wrap.
+  // Single element, index in-bounds
   std::vector<int32> operand_shape({4, 5, 2});
   RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
 }
 
 XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
-  // Multiple element, no wrap.
+  // Multiples element, index in-bounds.
   std::vector<int32> operand_shape({4, 5, 2});
   RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/2);
 }
 
 XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) {
-  // Multiple element, no wrap.
+  // Multiples element, index in-bounds.
   std::vector<int32> operand_shape({4, 5, 2});
   RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/2);
 }
 
-XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) {
-  // Multiple element, wrapping.
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) {
+  // Multiple element, index out of bounds.
   std::vector<int32> operand_shape({4, 5, 2});
   RunR3Contiguous<float>(operand_shape, /*index=*/3, /*size=*/2);
 }
 
-XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) {
-  // Multiple element, wrapping.
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) {
+  // Multiple element, index out of bounds.
   std::vector<int32> operand_shape({4, 5, 2});
   RunR3Contiguous<bfloat16>(operand_shape, /*index=*/3, /*size=*/2);
 }