Introduce a ShapeUtil::ForEachIndexWithStatus, change index type to ArraySlice
author: Sanjoy Das <sanjoy@google.com>
Wed, 28 Feb 2018 19:07:10 +0000 (11:07 -0800)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Feb 2018 19:11:26 +0000 (11:11 -0800)
This is not used yet, but I need it in a later CL.  I don't specifically need
the argument to be an ArraySlice, but it seemed cleaner than taking a const ref
to a vector.

No functional change intended.

PiperOrigin-RevId: 187352376

tensorflow/compiler/xla/literal_util.cc
tensorflow/compiler/xla/literal_util.h
tensorflow/compiler/xla/literal_util_test.cc
tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/shape_util.h
tensorflow/compiler/xla/shape_util_test.cc

index 823da43b5ab2e9c8e80181efc993735877a2c363..3962a9b31618003f39d297c68adb9ec228c72427 100644 (file)
@@ -223,7 +223,7 @@ Status Literal::CopySliceFromInternal(
     Literal::StrideConfig stride_config(src_literal.shape(), shape(),
                                         copy_size);
 
-    auto copy_proc = [&](const std::vector<int64>& indexes) {
+    auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
       // Map from multi-dimensional index, to source index.
       std::transform(indexes.begin(), indexes.end(), src_base.begin(),
                      src_indexes.begin(), std::plus<int64>());
index d5ae3fd72322fe243f0156dfbe236b6d62ab8c9d..1d58f0cbc72794bed659bcba211dfdc0077ebe06 100644 (file)
@@ -1269,7 +1269,7 @@ Status Literal::Populate(const FnType& generator) {
     int64 minor_dimension_size =
         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
 
-    auto init_function = [&](const std::vector<int64>& indexes) {
+    auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
       const int64 index =
           IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
index ee2f4fe87440428c7364fe2924003c5124f4eaa2..9ff0771110077f0c8cbf15c4767392f1371a8480 100644 (file)
@@ -30,6 +30,7 @@ limitations under the License.
 namespace xla {
 namespace {
 
+using tensorflow::gtl::ArraySlice;
 using ::testing::ElementsAre;
 using ::testing::HasSubstr;
 
@@ -214,11 +215,11 @@ TEST_F(LiteralUtilTest, CreateSparse) {
   std::vector<int64> expected_values = {8, 9, 7, 10};
 
   EXPECT_EQ(literal->sparse_indices()->data(),
-            tensorflow::gtl::ArraySlice<int64>(
-                expected_indices.data(), expected_indices.num_elements()));
-  EXPECT_EQ(tensorflow::gtl::ArraySlice<int64>(literal->data<int64>().data(),
-                                               expected_values.size()),
-            tensorflow::gtl::ArraySlice<int64>(expected_values));
+            ArraySlice<int64>(expected_indices.data(),
+                              expected_indices.num_elements()));
+  EXPECT_EQ(
+      ArraySlice<int64>(literal->data<int64>().data(), expected_values.size()),
+      ArraySlice<int64>(expected_values));
 }
 
 TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@@ -290,7 +291,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
   // clang-format on
   std::vector<std::tuple<int64, int64, string>> seen;
   literal->EachCellAsString(
-      [&seen](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
+      [&seen](ArraySlice<int64> indices, const string& value) {
         seen.emplace_back(indices[0], indices[1], value);
       });
 
@@ -622,11 +623,10 @@ TEST_F(LiteralUtilTest, TransposeR4) {
   // clang-format on
   auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
 
-  reshape->EachCell<float>(
-      [&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
-        EXPECT_EQ(value, original->Get<float>(
-                             {indices[2], indices[3], indices[0], indices[1]}));
-      });
+  reshape->EachCell<float>([&](ArraySlice<int64> indices, float value) {
+    EXPECT_EQ(value, original->Get<float>(
+                         {indices[2], indices[3], indices[0], indices[1]}));
+  });
 }
 
 TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
@@ -863,7 +863,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
     const int64 zero_base[] = {0, 0, 0, 0};
     const int64 step[] = {1, 1, 1, 1};
     uint32 seqnr = 0;
-    auto init_proc = [&](const std::vector<int64>& indexes) {
+    auto init_proc = [&](ArraySlice<int64> indexes) {
       source->Set(indexes, ++seqnr);
       return true;
     };
@@ -879,7 +879,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
     std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
     std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
     bool matched = true;
-    auto check_proc = [&](const std::vector<int64>& indexes) {
+    auto check_proc = [&](ArraySlice<int64> indexes) {
       std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
       std::transform(source_indexes.begin(), source_indexes.end(), src_base,
                      source_indexes.begin(), std::plus<int64>());
@@ -1067,7 +1067,7 @@ TEST_F(LiteralUtilTest, Populate) {
         primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
         data.layout);
     auto literal = Literal::CreateFromShape(shape);
-    auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
+    auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
       // Offsets from linear index just to avoid R0 literals to be initialized
       // with zero.
       return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
@@ -1079,7 +1079,7 @@ TEST_F(LiteralUtilTest, Populate) {
     std::vector<int64> zero_base(data.dimensions.size(), 0);
     std::vector<int64> step(data.dimensions.size(), 1);
     bool matched = true;
-    auto check_function = [&](const std::vector<int64>& indexes) {
+    auto check_function = [&](ArraySlice<int64> indexes) {
       auto value = literal->Get<uint32>(indexes);
       matched = matched && (value == generator(indexes));
       return matched;
index c3a3251b7ddcc00f9cbf8e021cff830f6f8bd02f..edb1ad236090bc92e65f4636edc20c733c38b378 100644 (file)
@@ -1222,7 +1222,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     // corresponding index of the resulting padded literal.
     const PaddingConfig& pad_config = pad->padding_config();
 
-    auto func = [&](const std::vector<int64>& input_index) {
+    auto func = [&](ArraySlice<int64> input_index) {
       for (auto i = 0; i < input_index.size(); ++i) {
         // Interior padding occurs logically before edge padding, so in the case
         // of negative edge padding elements are removed from the
@@ -1518,7 +1518,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
             base[result_to_arg_index[i]] = multi_index[i];
           }
 
-          auto func = [&](const std::vector<int64>& input_index) {
+          auto func = [&](ArraySlice<int64> input_index) {
             auto curr_val = arg_literal.Get<ReturnT>(input_index);
 
             // Evaluate computation with specified literal operands.
@@ -1954,7 +1954,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     auto result = operand_literal.CloneToUnique();
     std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
 
-    auto func = [&](const std::vector<int64>& update_index) {
+    auto func = [&](ArraySlice<int64> update_index) {
       std::transform(update_index.begin(), update_index.end(), start.begin(),
                      result_index.begin(), std::plus<int64>());
 
index 8ee263fe5e5fc20edf6d8ce1f56fe72b27b645d0..923315e001f6a4c439be18ee5954517fba2db8d2 100644 (file)
@@ -24,6 +24,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -564,16 +565,16 @@ class ShapeUtil {
  // The visitor_function should return true if it wants to
  // continue, or false otherwise.
   //
-  // visitor_function must be a callable of type bool(const std::vector<int64>&)
-  // or compatible.
+  // visitor_function must be a callable of type
+  // StatusOr<bool>(ArraySlice<int64>) or compatible.
   template <typename FnType>
-  static void ForEachIndex(const Shape& shape,
-                           tensorflow::gtl::ArraySlice<int64> base,
-                           tensorflow::gtl::ArraySlice<int64> count,
-                           tensorflow::gtl::ArraySlice<int64> incr,
-                           const FnType& visitor_function) {
+  static Status ForEachIndexWithStatus(const Shape& shape,
+                                       tensorflow::gtl::ArraySlice<int64> base,
+                                       tensorflow::gtl::ArraySlice<int64> count,
+                                       tensorflow::gtl::ArraySlice<int64> incr,
+                                       const FnType& visitor_function) {
     if (ShapeUtil::HasZeroElements(shape)) {
-      return;
+      return Status::OK();
     }
     CHECK_EQ(Rank(shape), base.size());
     CHECK_EQ(incr.size(), base.size());
@@ -583,7 +584,11 @@ class ShapeUtil {
     // once with the proper empty indexes.
     int64 n = -1;
     std::vector<int64> indexes(base.begin(), base.end());
-    while (n < rank && visitor_function(indexes)) {
+    while (n < rank) {
+      TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
+      if (!should_continue) {
+        break;
+      }
       // Increments dimensions in minor to major order.
       for (n = 0; n < rank; ++n) {
         int64 dim = LayoutUtil::Minor(shape.layout(), n);
@@ -594,6 +599,21 @@ class ShapeUtil {
         indexes[dim] = base[dim];
       }
     }
+
+    return Status::OK();
+  }
+
+  template <typename FnType>
+  static void ForEachIndex(const Shape& shape,
+                           tensorflow::gtl::ArraySlice<int64> base,
+                           tensorflow::gtl::ArraySlice<int64> count,
+                           tensorflow::gtl::ArraySlice<int64> incr,
+                           const FnType& visitor_function) {
+    ForEachIndexWithStatus(shape, base, count, incr,
+                           [&](tensorflow::gtl::ArraySlice<int64> indices) {
+                             return StatusOr<bool>(visitor_function(indices));
+                           })
+        .IgnoreError();
   }
 
  private:
index 4db97d45b20b86dc60531845c6e28a223203ff7f..a3574156983cfcb53cd240bdf83feef107f11d7e 100644 (file)
@@ -573,10 +573,11 @@ TEST(ShapeUtilTest, ForEachIndex) {
     Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
     // Increments at every invocation.
     int invocations = 0;
-    auto increment_func = [&invocations](const std::vector<int64>& indexes) {
-      invocations++;
-      return true;
-    };
+    auto increment_func =
+        [&invocations](tensorflow::gtl::ArraySlice<int64> indexes) {
+          invocations++;
+          return true;
+        };
 
     std::vector<int64> zero_base(data.dimensions.size(), 0);
     std::vector<int64> step(data.dimensions.size(), 1);
@@ -588,6 +589,29 @@ TEST(ShapeUtilTest, ForEachIndex) {
   }
 }
 
+TEST(ShapeUtilTest, ForEachIndexWithStatus) {
+  Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
+  // Increments at every invocation.
+  int invocations = 0;
+  auto increment_func =
+      [&invocations](
+          tensorflow::gtl::ArraySlice<int64> indexes) -> StatusOr<bool> {
+    if (++invocations == 5) {
+      return Unimplemented("Cannot increment beyond 5.");
+    }
+    return true;
+  };
+
+  Status error_status = ShapeUtil::ForEachIndexWithStatus(
+      shape, /*base=*/{0, 0}, /*count=*/{10, 10}, /*incr=*/{0, 1},
+      increment_func);
+
+  EXPECT_FALSE(error_status.ok());
+  EXPECT_THAT(error_status.error_message(),
+              ::testing::HasSubstr("Cannot increment beyond 5."));
+  EXPECT_EQ(invocations, 5);
+}
+
 TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
   // All output dimensions should be unmodified. One of the input dimensions is
   // modified because the input rank is larger by one.