Support shrink_axis_mask argument of StridedSlice Op for TfLite.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 29 Jan 2018 20:43:46 +0000 (12:43 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 29 Jan 2018 20:54:24 +0000 (12:54 -0800)
PiperOrigin-RevId: 183709796

tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
tensorflow/contrib/lite/kernels/strided_slice.cc
tensorflow/contrib/lite/kernels/strided_slice_test.cc
tensorflow/contrib/lite/testing/BUILD
tensorflow/contrib/lite/testing/generate_examples.py

index 31bade26f98274e64fc7e224a16d5b78bc8bbe68..4bcf4993e92136388b3febe7e04030943f4b54f3 100644 (file)
@@ -2370,13 +2370,15 @@ inline int StartIndex(int start, int stride, int dim, bool masked) {
   return masked ? (stride > 0 ? 0 : dim - 1) : start;
 }
 
-inline int StopIndex(int stop, int stride, int dim, bool masked) {
-  return masked ? (stride > 0 ? dim : -1) : stop;
+inline int StopIndex(int start, int stop, int stride, int dim, bool masked,
+                     bool shrink_axis_masked) {
+  return shrink_axis_masked ? stride > 0 ? start + 1 : start - 1
+                            : masked ? (stride > 0 ? dim : -1) : stop;
 }
 
 template <typename T>
 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
-                         int begin_mask, int end_mask,
+                         int begin_mask, int end_mask, int shrink_axis_mask,
                          const std::vector<int>& starts,
                          const std::vector<int>& stops,
                          const std::vector<int>& strides, T* output_data,
@@ -2387,19 +2389,23 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
   const int start_b =
       StartIndex(starts[3], strides[3], input_dims.sizes[3], begin_mask & 8);
   const int stop_b =
-      StopIndex(stops[3], strides[3], input_dims.sizes[3], end_mask & 8);
+      StopIndex(start_b, stops[3], strides[3], input_dims.sizes[3],
+                end_mask & 8, shrink_axis_mask & 8);
   const int start_h =
       StartIndex(starts[2], strides[2], input_dims.sizes[2], begin_mask & 4);
   const int stop_h =
-      StopIndex(stops[2], strides[2], input_dims.sizes[2], end_mask & 4);
+      StopIndex(start_h, stops[2], strides[2], input_dims.sizes[2],
+                end_mask & 4, shrink_axis_mask & 4);
   const int start_w =
       StartIndex(starts[1], strides[1], input_dims.sizes[1], begin_mask & 2);
   const int stop_w =
-      StopIndex(stops[1], strides[1], input_dims.sizes[1], end_mask & 2);
+      StopIndex(start_w, stops[1], strides[1], input_dims.sizes[1],
+                end_mask & 2, shrink_axis_mask & 2);
   const int start_d =
       StartIndex(starts[0], strides[0], input_dims.sizes[0], begin_mask & 1);
   const int stop_d =
-      StopIndex(stops[0], strides[0], input_dims.sizes[0], end_mask & 1);
+      StopIndex(start_d, stops[0], strides[0], input_dims.sizes[0],
+                end_mask & 1, shrink_axis_mask & 1);
 
   T* out_ptr = output_data;
   for (int in_b = start_b; LoopCondition(in_b, stop_b, strides[3]);
@@ -2417,6 +2423,18 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
   }
 }
 
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+                         int begin_mask, int end_mask,
+                         const std::vector<int>& starts,
+                         const std::vector<int>& stops,
+                         const std::vector<int>& strides, T* output_data,
+                         const Dims<4>& output_dims) {
+  StridedSlice(input_data, input_dims, begin_mask, end_mask,
+               /*shrink_axis_mask=*/0, starts, stops, strides, output_data,
+               output_dims);
+}
+
 template <typename T>
 inline void Slice(const T* input_data, const Dims<4>& input_dims,
                   const std::vector<int>& begin, const std::vector<int>& size,
index 91ba4a9b7851c35a5138f4ccea307c810a4731a1..c510ee3b9f4555c79a49df5932b49ef735d2feef 100644 (file)
@@ -81,8 +81,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                      "ellipsis_mask is not implemented yet.");
   TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
                      "new_axis_mask is not implemented yet.");
-  TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0,
-                     "shrink_axis_mask is not implemented yet.");
 
   // TODO(soroosh): optimize for constant tensors to do allocation in Prepare
   op_context.output->allocation_type = kTfLiteDynamic;
@@ -153,9 +151,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   std::vector<int> starts;
   std::vector<int> stops;
   std::vector<int> strides;
+  std::vector<int> output_shape_vector;
 
-  // Determine size of output tensor and map indices
-  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims);
   for (int idx = op_context.dims - 1; idx >= 0; --idx) {
     int dim = op_context.input->dims->data[idx];
     int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx];
@@ -174,14 +171,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                            pos_stride);
 
     // This is valid for both positive and negative strides
-    output_shape->data[idx] = ceil((end - begin) / static_cast<float>(stride));
-    output_shape->data[idx] =
-        output_shape->data[idx] < 0 ? 0 : output_shape->data[idx];
+    int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride));
+    dim_shape = dim_shape < 0 ? 0 : dim_shape;
+
+    if (!(op_context.params->shrink_axis_mask & (1 << idx))) {
+      output_shape_vector.push_back(dim_shape);
+    }
+
     starts.emplace_back(begin);
     stops.emplace_back(end);
     strides.emplace_back(stride);
   }
 
+  TfLiteIntArray* output_shape =
+      TfLiteIntArrayCreate(output_shape_vector.size());
+
+  std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(),
+                    output_shape->data);
+
   for (int i = op_context.dims; i < kMaxDim; i++) {
     starts.emplace_back(0);
     stops.emplace_back(1);
@@ -202,13 +209,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
   op_context.params->end_mask =
       ReverseMaskBits(op_context.params->end_mask, op_context.dims);
-
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type)                 \
-  kernel_type::StridedSlice(                                          \
-      GetTensorData<data_type>(op_context.input),                     \
-      GetTensorDims(op_context.input), op_context.params->begin_mask, \
-      op_context.params->end_mask, starts, stops, strides,            \
-      GetTensorData<data_type>(op_context.output),                    \
+  op_context.params->shrink_axis_mask =
+      ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type)                      \
+  kernel_type::StridedSlice(                                               \
+      GetTensorData<data_type>(op_context.input),                          \
+      GetTensorDims(op_context.input), op_context.params->begin_mask,      \
+      op_context.params->end_mask, op_context.params->shrink_axis_mask,    \
+      starts, stops, strides, GetTensorData<data_type>(op_context.output), \
       GetTensorDims(op_context.output))
 
   switch (op_context.input->type) {
index cd4a364682c0e66b2ceec92c0b34461945caf779..5bc7dc353b4904bdb182ce029b9b7c654e4a5f33 100644 (file)
@@ -79,8 +79,6 @@ TEST(StridedSliceOpTest, UnssupportedArgs) {
                "ellipsis_mask is not implemented yet.");
   EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
                "new_axis_mask is not implemented yet.");
-  EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 0, 1),
-               "shrink_axis_mask is not implemented yet.");
 }
 
 TEST(StridedSliceOpTest, In1D) {
@@ -213,6 +211,7 @@ TEST(StridedSliceOpTest, In1D_EndMask) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
 }
+
 TEST(StridedSliceOpTest, In1D_NegStride) {
   StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
   m.SetInput({1, 2, 3});
@@ -234,6 +233,7 @@ TEST(StridedSliceOpTest, In1D_EvenLenStride2) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
 }
+
 TEST(StridedSliceOpTest, In1D_OddLenStride2) {
   StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
   m.SetInput({1, 2, 3});
@@ -255,6 +255,7 @@ TEST(StridedSliceOpTest, In2D_Identity) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
+
 TEST(StridedSliceOpTest, In2D) {
   StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
   m.SetInput({1, 2, 3, 4, 5, 6});
@@ -320,6 +321,7 @@ TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4}));
 }
+
 TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
   StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
   m.SetInput({1, 2, 3, 4, 5, 6});
@@ -354,6 +356,7 @@ TEST(StridedSliceOpTest, In3D_NegStride) {
   EXPECT_THAT(m.GetOutput(),
               ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}));
 }
+
 TEST(StridedSliceOpTest, In3D_Strided2) {
   StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
   m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
@@ -365,6 +368,159 @@ TEST(StridedSliceOpTest, In3D_Strided2) {
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5}));
 }
 
+TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
+  StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+  m.SetInput({1, 2, 3, 4});
+  m.SetBegin({1});
+  m.SetEnd({3});
+  m.SetStrides({1});
+  m.Invoke();
+  EXPECT_TRUE(m.GetOutputShape().empty());
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
+}
+
+TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) {
+  StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+  m.SetInput({1, 2, 3, 4});
+  m.SetBegin({2});
+  m.SetEnd({1});
+  m.SetStrides({1});
+  m.Invoke();
+  EXPECT_TRUE(m.GetOutputShape().empty());
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+}
+
+TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
+  StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
+  m.SetInput({1, 2, 3, 4});
+  m.SetBegin({1});
+  m.SetEnd({3});
+  m.SetStrides({1});
+  m.Invoke();
+  EXPECT_TRUE(m.GetOutputShape().empty());
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) {
+  StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+  m.SetInput({1, 2, 3, 4});
+  m.SetBegin({-2});
+  m.SetEnd({-3});
+  m.SetStrides({-1});
+  m.Invoke();
+  EXPECT_TRUE(m.GetOutputShape().empty());
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
+  StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
+  m.SetInput({1, 2, 3, 4, 5, 6});
+  m.SetBegin({0, 0});
+  m.SetEnd({2, 3});
+  m.SetStrides({1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) {
+  StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2);
+  m.SetInput({1, 2, 3, 4, 5, 6});
+  m.SetBegin({0, 0});
+  m.SetEnd({2, 3});
+  m.SetStrides({1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) {
+  StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
+  m.SetInput({1, 2, 3, 4, 5, 6});
+  m.SetBegin({0, 0});
+  m.SetEnd({2, 3});
+  m.SetStrides({1, 1});
+  m.Invoke();
+  EXPECT_TRUE(m.GetOutputShape().empty());
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) {
+  StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  m.SetBegin({0, 0, 0});
+  m.SetEnd({2, 3, 2});
+  m.SetStrides({1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) {
+  StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  m.SetBegin({0, 0, 0});
+  m.SetEnd({2, 3, 2});
+  m.SetStrides({1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 7, 8}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) {
+  StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  m.SetBegin({0, 0, 0});
+  m.SetEnd({2, 3, 2});
+  m.SetStrides({1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
+  StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  m.SetBegin({0, 0, 0});
+  m.SetEnd({2, 3, 2});
+  m.SetStrides({1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) {
+  StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  m.SetBegin({0, 0, 0});
+  m.SetEnd({2, 3, 2});
+  m.SetStrides({1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) {
+  StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  m.SetBegin({0, 0, 0});
+  m.SetEnd({2, 3, 2});
+  m.SetStrides({1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 7}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) {
+  StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  m.SetBegin({0, 0, 0});
+  m.SetEnd({2, 3, 2});
+  m.SetStrides({1, 1, 1});
+  m.Invoke();
+  EXPECT_TRUE(m.GetOutputShape().empty());
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
+}
 }  // namespace
 }  // namespace tflite
 
index 50e8ca75f8efd600d4773b83cd2c8de11c9d13ca..7f84a0ab9bb5b9fbacba1728f8bfe4df25f13f86 100644 (file)
@@ -197,7 +197,7 @@ cc_binary(
 
 tf_cc_test(
     name = "generated_examples_zip_test",
-    size = "medium",
+    size = "large",
     srcs = ["generated_examples_zip_test.cc"],
     args = [
         "--zip_files_dir=tensorflow/contrib/lite/testing/optest",
index fc8149bef98323743763ef69237926b1f4481144..f75d7c4bb990ae30c7f1a84be64f8151458ab6cc 100644 (file)
@@ -1533,8 +1533,9 @@ def make_strided_slice_tests(zip_path):
           "begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
           "end": [[8, 2, 2, 3], [12, 2, 2, 5]],
           "strides": [None, [1, 1, 1, 1], [2, 1, 3, 1]],
-          "begin_mask": [None, 0, 1, 2, 8],
-          "end_mask": [None, 0, 1, 2, 8],
+          "begin_mask": [None, 1, 2, 8],
+          "end_mask": [None, 1, 2, 8],
+          "shrink_axis_mask": [None, 1, 2, 4, 8, 11, 15, -1],
       },
       # 2-D
       {
@@ -1544,8 +1545,9 @@ def make_strided_slice_tests(zip_path):
           "begin": [[0, 0], [1, 0]],
           "end": [[2, 3], [2, 2]],
           "strides": [None, [1, 1], [2, 2]],
-          "begin_mask": [None, 0, 1, 2],
-          "end_mask": [None, 0, 1, 2],
+          "begin_mask": [None, 1, 2],
+          "end_mask": [None, 1, 2],
+          "shrink_axis_mask": [None, 1, 2, 3, -1],
       },
       # Negative strides
       {
@@ -1555,8 +1557,9 @@ def make_strided_slice_tests(zip_path):
           "begin": [[0, -1]],
           "end": [[2, -3]],
           "strides": [[1, -1]],
-          "begin_mask": [None, 0, 1, 2],
-          "end_mask": [None, 0, 1, 2],
+          "begin_mask": [None, 1, 2],
+          "end_mask": [None, 1, 2],
+          "shrink_axis_mask": [None, 1, 2, 3, -1],
       },
   ]