[XLA] Make HloEvaluator use wrap-around semantics for DynamicUpdateSlice.
authorMichael Kuperstein <mkuper@google.com>
Mon, 5 Mar 2018 22:33:22 +0000 (14:33 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Mar 2018 22:40:09 +0000 (14:40 -0800)
PiperOrigin-RevId: 187923671

tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/tests/BUILD

index 42de7ad..534433b 100644 (file)
@@ -1970,17 +1970,26 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
   StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
       const Literal& operand_literal, const Literal& update_literal,
       const Literal& start_indices_literal) {
-    auto start_indices_typed = start_indices_literal.data<IndexT>();
-    const std::vector<int64> start(start_indices_typed.begin(),
-                                   start_indices_typed.end());
-
     auto result = operand_literal.CloneToUnique();
-    std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
+    auto start_indices_typed = start_indices_literal.data<IndexT>();
+    const auto rank = ShapeUtil::Rank(result->shape());
+    std::vector<int64> start(rank, 0);
+    for (int64 i = 0; i < rank; ++i) {
+      // All other implementations currently wrap-around the index, so this
+      // should do so as well.
+      start[i] = (start_indices_typed[i] % result->shape().dimensions(i));
+      start[i] += (start[i] < 0) * result->shape().dimensions(i);
+    }
+    std::vector<int64> result_index(rank, 0);
 
     auto func = [&](ArraySlice<int64> update_index) {
       std::transform(update_index.begin(), update_index.end(), start.begin(),
                      result_index.begin(), std::plus<int64>());
-
+      // Same as above, wrap-around only to match other implementations'
+      // semantics.
+      std::transform(result_index.begin(), result_index.end(),
+                     result->shape().dimensions().begin(), result_index.begin(),
+                     std::modulus<int64>());
       result->Set<ReturnT>(result_index,
                            update_literal.Get<ReturnT>(update_index));
       return true;
index 7c95b03..1b2008a 100644 (file)
@@ -948,6 +948,9 @@ xla_test(
     name = "dynamic_ops_test",
     timeout = "moderate",
     srcs = ["dynamic_ops_test.cc"],
+    tags = [
+        "enable_for_xla_interpreter",
+    ],
     deps = [
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:reference_util",