StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
const Literal& operand_literal, const Literal& update_literal,
const Literal& start_indices_literal) {
- auto start_indices_typed = start_indices_literal.data<IndexT>();
- const std::vector<int64> start(start_indices_typed.begin(),
- start_indices_typed.end());
-
auto result = operand_literal.CloneToUnique();
- std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
+ auto start_indices_typed = start_indices_literal.data<IndexT>();
+ const auto rank = ShapeUtil::Rank(result->shape());
+ std::vector<int64> start(rank, 0);
+ for (int64 i = 0; i < rank; ++i) {
+ // All other implementations currently wrap-around the index, so this
+ // should do so as well.
+ start[i] = (start_indices_typed[i] % result->shape().dimensions(i));
+ start[i] += (start[i] < 0) * result->shape().dimensions(i);
+ }
+ std::vector<int64> result_index(rank, 0);
auto func = [&](ArraySlice<int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
-
+ // Same as above, wrap-around only to match other implementations'
+ // semantics.
+ std::transform(result_index.begin(), result_index.end(),
+ result->shape().dimensions().begin(), result_index.begin(),
+ std::modulus<int64>());
result->Set<ReturnT>(result_index,
update_literal.Get<ReturnT>(update_index));
return true;
name = "dynamic_ops_test",
timeout = "moderate",
srcs = ["dynamic_ops_test.cc"],
+ tags = [
+ "enable_for_xla_interpreter",
+ ],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util",