[XLA] Add tests for R1 PRED Slices.
authorMichael Kuperstein <mkuper@google.com>
Tue, 27 Mar 2018 03:18:03 +0000 (20:18 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 03:20:07 +0000 (20:20 -0700)
PiperOrigin-RevId: 190566036

tensorflow/compiler/xla/tests/slice_test.cc

index fe36df1..a14a365 100644 (file)
@@ -193,7 +193,9 @@ class SliceR1Test : public ClientLibraryTestBase,
  protected:
   template <typename NativeT>
   void Run(const R1Spec& spec) {
-    std::vector<NativeT> input(spec.input_dim0);
+    // This can't be an std::vector, since you can't grab an ArraySlice of a
+    // vector<bool>.
+    tensorflow::gtl::InlinedVector<NativeT, 1> input(spec.input_dim0);
     std::iota(input.begin(), input.end(), NativeT());
 
     ComputationBuilder builder(client_, TestName());
@@ -201,7 +203,8 @@ class SliceR1Test : public ClientLibraryTestBase,
     builder.Slice(original, {spec.slice_start}, {spec.slice_limit},
                   {spec.slice_stride});
 
-    std::vector<NativeT> expected;
+    // Ditto.
+    tensorflow::gtl::InlinedVector<NativeT, 1> expected;
     for (int i = spec.slice_start; i < spec.slice_limit;
          i += spec.slice_stride) {
       expected.push_back(i);
@@ -230,6 +233,8 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); }
 
 XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }
 
+XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run<bool>(GetParam()); }
+
 // Tests for R1 slice ops.
 // The format for each testcase is {input size, start, limit, stride}.
 // clang-format off