Literal::StrideConfig stride_config(src_literal.shape(), shape(),
copy_size);
- auto copy_proc = [&](const std::vector<int64>& indexes) {
+ auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
// Map from multi-dimensional index, to source index.
std::transform(indexes.begin(), indexes.end(), src_base.begin(),
src_indexes.begin(), std::plus<int64>());
int64 minor_dimension_size =
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
- auto init_function = [&](const std::vector<int64>& indexes) {
+ auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
const int64 index =
IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
namespace xla {
namespace {
+using tensorflow::gtl::ArraySlice;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
std::vector<int64> expected_values = {8, 9, 7, 10};
EXPECT_EQ(literal->sparse_indices()->data(),
- tensorflow::gtl::ArraySlice<int64>(
- expected_indices.data(), expected_indices.num_elements()));
- EXPECT_EQ(tensorflow::gtl::ArraySlice<int64>(literal->data<int64>().data(),
- expected_values.size()),
- tensorflow::gtl::ArraySlice<int64>(expected_values));
+ ArraySlice<int64>(expected_indices.data(),
+ expected_indices.num_elements()));
+ EXPECT_EQ(
+ ArraySlice<int64>(literal->data<int64>().data(), expected_values.size()),
+ ArraySlice<int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
// clang-format on
std::vector<std::tuple<int64, int64, string>> seen;
literal->EachCellAsString(
- [&seen](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
+ [&seen](ArraySlice<int64> indices, const string& value) {
seen.emplace_back(indices[0], indices[1], value);
});
// clang-format on
auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
- reshape->EachCell<float>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
- EXPECT_EQ(value, original->Get<float>(
- {indices[2], indices[3], indices[0], indices[1]}));
- });
+ reshape->EachCell<float>([&](ArraySlice<int64> indices, float value) {
+ EXPECT_EQ(value, original->Get<float>(
+ {indices[2], indices[3], indices[0], indices[1]}));
+ });
}
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
const int64 zero_base[] = {0, 0, 0, 0};
const int64 step[] = {1, 1, 1, 1};
uint32 seqnr = 0;
- auto init_proc = [&](const std::vector<int64>& indexes) {
+ auto init_proc = [&](ArraySlice<int64> indexes) {
source->Set(indexes, ++seqnr);
return true;
};
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
bool matched = true;
- auto check_proc = [&](const std::vector<int64>& indexes) {
+ auto check_proc = [&](ArraySlice<int64> indexes) {
std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
std::transform(source_indexes.begin(), source_indexes.end(), src_base,
source_indexes.begin(), std::plus<int64>());
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
auto literal = Literal::CreateFromShape(shape);
- auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
+ auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
- auto check_function = [&](const std::vector<int64>& indexes) {
+ auto check_function = [&](ArraySlice<int64> indexes) {
auto value = literal->Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
// corresponding index of the resulting padded literal.
const PaddingConfig& pad_config = pad->padding_config();
- auto func = [&](const std::vector<int64>& input_index) {
+ auto func = [&](ArraySlice<int64> input_index) {
for (auto i = 0; i < input_index.size(); ++i) {
// Interior padding occurs logically before edge padding, so in the case
// of negative edge padding elements are removed from the
base[result_to_arg_index[i]] = multi_index[i];
}
- auto func = [&](const std::vector<int64>& input_index) {
+ auto func = [&](ArraySlice<int64> input_index) {
auto curr_val = arg_literal.Get<ReturnT>(input_index);
// Evaluate computation with specified literal operands.
auto result = operand_literal.CloneToUnique();
std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
- auto func = [&](const std::vector<int64>& update_index) {
+ auto func = [&](ArraySlice<int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
// The visitor_function visitor function should return true if it wants to
// continue, or false otherwise.
//
- // visitor_function must be a callable of type bool(const std::vector<int64>&)
- // or compatible.
+ // visitor_function must be a callable of type
+ // StatusOr<bool>(ArraySlice<int64>) or compatible.
template <typename FnType>
- static void ForEachIndex(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
- const FnType& visitor_function) {
+ // Status-returning variant of ForEachIndex: the visitor returns
+ // StatusOr<bool>. The walk ends early either when the visitor returns
+ // false (the function then returns Status::OK()) or when it returns a
+ // non-OK status, which is propagated to the caller via
+ // TF_ASSIGN_OR_RETURN below.
+ static Status ForEachIndexWithStatus(const Shape& shape,
+ tensorflow::gtl::ArraySlice<int64> base,
+ tensorflow::gtl::ArraySlice<int64> count,
+ tensorflow::gtl::ArraySlice<int64> incr,
+ const FnType& visitor_function) {
if (ShapeUtil::HasZeroElements(shape)) {
- return;
+ // Nothing to visit in an empty shape; report success immediately.
+ return Status::OK();
}
CHECK_EQ(Rank(shape), base.size());
CHECK_EQ(incr.size(), base.size());
// once with the proper empty indexes.
int64 n = -1;
std::vector<int64> indexes(base.begin(), base.end());
- while (n < rank && visitor_function(indexes)) {
+ while (n < rank) {
+ // A non-OK status from the visitor aborts the loop via early return;
+ // a false value is a normal request to stop and is not an error.
+ TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
+ if (!should_continue) {
+ break;
+ }
// Increments dimensions in minor to major order.
for (n = 0; n < rank; ++n) {
int64 dim = LayoutUtil::Minor(shape.layout(), n);
indexes[dim] = base[dim];
}
}
+
+ // Either every index was visited or the visitor returned false; both
+ // count as success.
+ return Status::OK();
+ }
+
+ // Backward-compatible overload keeping the original bool-returning
+ // visitor contract: adapts visitor_function into a StatusOr<bool>
+ // visitor and delegates to ForEachIndexWithStatus.
+ template <typename FnType>
+ static void ForEachIndex(const Shape& shape,
+ tensorflow::gtl::ArraySlice<int64> base,
+ tensorflow::gtl::ArraySlice<int64> count,
+ tensorflow::gtl::ArraySlice<int64> incr,
+ const FnType& visitor_function) {
+ ForEachIndexWithStatus(shape, base, count, incr,
+ [&](tensorflow::gtl::ArraySlice<int64> indices) {
+ return StatusOr<bool>(visitor_function(indices));
+ })
+ // The adapter above only ever wraps a bool, so the returned status
+ // is always OK and can safely be discarded.
+ .IgnoreError();
+ }
private:
Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
// Increments at every invocation.
int invocations = 0;
- auto increment_func = [&invocations](const std::vector<int64>& indexes) {
- invocations++;
- return true;
- };
+ auto increment_func =
+ [&invocations](tensorflow::gtl::ArraySlice<int64> indexes) {
+ invocations++;
+ return true;
+ };
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
}
}
+// Verifies that a non-OK status returned by the visitor both stops the
+// iteration immediately and is propagated out of ForEachIndexWithStatus.
+TEST(ShapeUtilTest, ForEachIndexWithStatus) {
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
+ // Increments at every invocation.
+ int invocations = 0;
+ // Visitor succeeds for the first four calls, then fails on the fifth.
+ auto increment_func =
+ [&invocations](
+ tensorflow::gtl::ArraySlice<int64> indexes) -> StatusOr<bool> {
+ if (++invocations == 5) {
+ return Unimplemented("Cannot increment beyond 5.");
+ }
+ return true;
+ };
+
+ // NOTE(review): incr of {0, 1} steps only the last dimension; that is
+ // fine here because the error at the fifth invocation ends the walk
+ // long before dimension 0 would need to advance — confirm if reused.
+ Status error_status = ShapeUtil::ForEachIndexWithStatus(
+ shape, /*base=*/{0, 0}, /*count=*/{10, 10}, /*incr=*/{0, 1},
+ increment_func);
+
+ EXPECT_FALSE(error_status.ok());
+ EXPECT_THAT(error_status.error_message(),
+ ::testing::HasSubstr("Cannot increment beyond 5."));
+ // Exactly five visits: the erroring call must be the last one.
+ EXPECT_EQ(invocations, 5);
+}
+
TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
// All output dimensions should be unmodified. One of the input dimensions is
// modified because the input rank is larger by one.